aroluo committed (verified)
Commit 64d4f68 · 1 Parent(s): ab144d7

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +15 -0
  2. .gitignore +174 -0
  3. LICENSE +201 -0
  4. README.md +342 -6
  5. cli/SparkTTS.py +236 -0
  6. cli/inference.py +116 -0
  7. example/infer.sh +47 -0
  8. example/prompt_audio.wav +3 -0
  9. requirements.txt +11 -0
  10. runtime/triton_trtllm/Dockerfile.server +5 -0
  11. runtime/triton_trtllm/README.md +94 -0
  12. runtime/triton_trtllm/client_grpc.py +831 -0
  13. runtime/triton_trtllm/client_http.py +165 -0
  14. runtime/triton_trtllm/docker-compose.yml +20 -0
  15. runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +137 -0
  16. runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt +58 -0
  17. runtime/triton_trtllm/model_repo/spark_tts/1/model.py +404 -0
  18. runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt +86 -0
  19. runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep +0 -0
  20. runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt +857 -0
  21. runtime/triton_trtllm/model_repo/vocoder/1/model.py +106 -0
  22. runtime/triton_trtllm/model_repo/vocoder/config.pbtxt +53 -0
  23. runtime/triton_trtllm/run.sh +109 -0
  24. runtime/triton_trtllm/scripts/convert_checkpoint.py +335 -0
  25. runtime/triton_trtllm/scripts/fill_template.py +70 -0
  26. sparktts/models/audio_tokenizer.py +163 -0
  27. sparktts/models/bicodec.py +247 -0
  28. sparktts/modules/blocks/layers.py +73 -0
  29. sparktts/modules/blocks/samper.py +115 -0
  30. sparktts/modules/blocks/vocos.py +373 -0
  31. sparktts/modules/encoder_decoder/feat_decoder.py +115 -0
  32. sparktts/modules/encoder_decoder/feat_encoder.py +105 -0
  33. sparktts/modules/encoder_decoder/wave_generator.py +88 -0
  34. sparktts/modules/fsq/finite_scalar_quantization.py +251 -0
  35. sparktts/modules/fsq/residual_fsq.py +355 -0
  36. sparktts/modules/speaker/ecapa_tdnn.py +267 -0
  37. sparktts/modules/speaker/perceiver_encoder.py +360 -0
  38. sparktts/modules/speaker/pooling_layers.py +298 -0
  39. sparktts/modules/speaker/speaker_encoder.py +136 -0
  40. sparktts/modules/vq/factorized_vector_quantize.py +187 -0
  41. sparktts/utils/__init__.py +0 -0
  42. sparktts/utils/audio.py +271 -0
  43. sparktts/utils/file.py +221 -0
  44. sparktts/utils/parse_options.sh +97 -0
  45. sparktts/utils/token_parser.py +187 -0
  46. src/demos/trump/trump_en.wav +3 -0
  47. src/demos/zhongli/zhongli_en.wav +3 -0
  48. src/demos/余承东/yuchengdong_zh.wav +3 -0
  49. src/demos/刘德华/dehua_zh.wav +3 -0
  50. src/demos/哪吒/nezha_zh.wav +3 -0
.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example/prompt_audio.wav filter=lfs diff=lfs merge=lfs -text
37
+ src/demos/trump/trump_en.wav filter=lfs diff=lfs merge=lfs -text
38
+ src/demos/zhongli/zhongli_en.wav filter=lfs diff=lfs merge=lfs -text
39
+ src/demos/余承东/yuchengdong_zh.wav filter=lfs diff=lfs merge=lfs -text
40
+ src/demos/刘德华/dehua_zh.wav filter=lfs diff=lfs merge=lfs -text
41
+ src/demos/哪吒/nezha_zh.wav filter=lfs diff=lfs merge=lfs -text
42
+ src/demos/徐志胜/zhisheng_zh.wav filter=lfs diff=lfs merge=lfs -text
43
+ src/demos/李靖/lijing_zh.wav filter=lfs diff=lfs merge=lfs -text
44
+ src/demos/杨澜/yanglan_zh.wav filter=lfs diff=lfs merge=lfs -text
45
+ src/demos/马云/mayun_zh.wav filter=lfs diff=lfs merge=lfs -text
46
+ src/demos/鲁豫/luyu_zh.wav filter=lfs diff=lfs merge=lfs -text
47
+ src/figures/infer_control.png filter=lfs diff=lfs merge=lfs -text
48
+ src/figures/infer_voice_cloning.png filter=lfs diff=lfs merge=lfs -text
49
+ src/logo/mobvoi.png filter=lfs diff=lfs merge=lfs -text
50
+ src/logo/SparkTTS.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ pretrained_models/
6
+ results/
7
+ demo/
8
+ # C extensions
9
+ *.so
10
+ .gradio/
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+ webui_test.py
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # UV
101
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ #uv.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ #poetry.lock
112
+
113
+ # pdm
114
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115
+ #pdm.lock
116
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117
+ # in version control.
118
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
119
+ .pdm.toml
120
+ .pdm-python
121
+ .pdm-build/
122
+
123
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124
+ __pypackages__/
125
+
126
+ # Celery stuff
127
+ celerybeat-schedule
128
+ celerybeat.pid
129
+
130
+ # SageMath parsed files
131
+ *.sage.py
132
+
133
+ # Environments
134
+ .env
135
+ .venv
136
+ env/
137
+ venv/
138
+ ENV/
139
+ env.bak/
140
+ venv.bak/
141
+
142
+ # Spyder project settings
143
+ .spyderproject
144
+ .spyproject
145
+
146
+ # Rope project settings
147
+ .ropeproject
148
+
149
+ # mkdocs documentation
150
+ /site
151
+
152
+ # mypy
153
+ .mypy_cache/
154
+ .dmypy.json
155
+ dmypy.json
156
+
157
+ # Pyre type checker
158
+ .pyre/
159
+
160
+ # pytype static type analyzer
161
+ .pytype/
162
+
163
+ # Cython debug symbols
164
+ cython_debug/
165
+
166
+ # PyCharm
167
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
170
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171
+ #.idea/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,348 @@
1
  ---
2
  title: SPKTTS
3
- emoji: 💻
4
- colorFrom: yellow
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.25.2
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: SPKTTS
3
+ app_file: webui.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.25.2
 
 
6
  ---
7
+ <div align="center">
8
+ <h1>
9
+ Spark-TTS
10
+ </h1>
11
+ <p>
12
+ Official PyTorch code for inference of <br>
13
+ <b><em>Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens</em></b>
14
+ </p>
15
+ <p>
16
+ <img src="src/logo/SparkTTS.jpg" alt="Spark-TTS Logo" style="width: 200px; height: 200px;">
17
+ </p>
18
+ <p>
19
+ <img src="src/logo/HKUST.jpg" alt="Institution 1" style="width: 200px; height: 60px;">
20
+ <img src="src/logo/mobvoi.jpg" alt="Institution 2" style="width: 200px; height: 60px;">
21
+ <img src="src/logo/SJU.jpg" alt="Institution 3" style="width: 200px; height: 60px;">
22
+ </p>
23
+ <p>
24
+ <img src="src/logo/NTU.jpg" alt="Institution 4" style="width: 200px; height: 60px;">
25
+ <img src="src/logo/NPU.jpg" alt="Institution 5" style="width: 200px; height: 60px;">
26
+ <img src="src/logo/SparkAudio2.jpg" alt="Institution 6" style="width: 200px; height: 60px;">
27
+ </p>
28
+ <p>
29
+ </p>
30
+ <a href="https://arxiv.org/pdf/2503.01710"><img src="https://img.shields.io/badge/Paper-ArXiv-red" alt="paper"></a>
31
+ <a href="https://sparkaudio.github.io/spark-tts/"><img src="https://img.shields.io/badge/Demo-Page-lightgrey" alt="version"></a>
32
+ <a href="https://huggingface.co/SparkAudio/Spark-TTS-0.5B"><img src="https://img.shields.io/badge/Hugging%20Face-Model%20Page-yellow" alt="Hugging Face"></a>
33
+ <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/Platform-linux-lightgrey" alt="version"></a>
34
+ <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/Python-3.12+-orange" alt="version"></a>
35
+ <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/PyTorch-2.5+-brightgreen" alt="python"></a>
36
+ <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="mit"></a>
37
+ </div>
38
 
39
+
40
+ ## Spark-TTS 🔥
41
+
42
+ ### Overview
43
+
44
+ Spark-TTS is an advanced text-to-speech system that uses the power of large language models (LLMs) for highly accurate and natural-sounding voice synthesis. It is designed to be efficient, flexible, and powerful for both research and production use.
45
+
46
+ ### Key Features
47
+
48
+ - **Simplicity and Efficiency**: Built entirely on Qwen2.5, Spark-TTS eliminates the need for additional generation models like flow matching. Instead of relying on separate models to generate acoustic features, it directly reconstructs audio from the code predicted by the LLM. This approach streamlines the process, improving efficiency and reducing complexity.
49
+ - **High-Quality Voice Cloning**: Supports zero-shot voice cloning, which means it can replicate a speaker's voice even without specific training data for that voice. This is ideal for cross-lingual and code-switching scenarios, allowing for seamless transitions between languages and voices without requiring separate training for each one.
50
+ - **Bilingual Support**: Supports both Chinese and English, and is capable of zero-shot voice cloning for cross-lingual and code-switching scenarios, enabling the model to synthesize speech in multiple languages with high naturalness and accuracy.
51
+ - **Controllable Speech Generation**: Supports creating virtual speakers by adjusting parameters such as gender, pitch, and speaking rate.
52
+
53
+ ---
54
+
55
+ <table align="center">
56
+ <tr>
57
+ <td align="center"><b>Inference Overview of Voice Cloning</b><br><img src="src/figures/infer_voice_cloning.png" width="80%" /></td>
58
+ </tr>
59
+ <tr>
60
+ <td align="center"><b>Inference Overview of Controlled Generation</b><br><img src="src/figures/infer_control.png" width="80%" /></td>
61
+ </tr>
62
+ </table>
63
+
64
+
65
+ ## 🚀 News
66
+
67
+ - **[2025-03-04]** Our paper on this project has been published! You can read it here: [Spark-TTS](https://arxiv.org/pdf/2503.01710).
68
+
69
+ - **[2025-03-12]** Nvidia Triton Inference Serving is now supported. See the Runtime section below for more details.
70
+
71
+
72
+ ## Install
73
+ **Clone and Install**
74
+
75
+ Here are instructions for installing on Linux. If you're on Windows, please refer to the [Windows Installation Guide](https://github.com/SparkAudio/Spark-TTS/issues/5).
76
+ *(Thanks to [@AcTePuKc](https://github.com/AcTePuKc) for the detailed Windows instructions!)*
77
+
78
+
79
+ - Clone the repo
80
+ ``` sh
81
+ git clone https://github.com/SparkAudio/Spark-TTS.git
82
+ cd Spark-TTS
83
+ ```
84
+
85
+ - Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
86
+ - Create Conda env:
87
+
88
+ ``` sh
89
+ conda create -n sparktts -y python=3.12
90
+ conda activate sparktts
91
+ pip install -r requirements.txt
92
+ # If you are in mainland China, you can set the mirror as follows:
93
+ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
94
+ ```
95
+
96
+ **Model Download**
97
+
98
+ Download via python:
99
+ ```python
100
+ from huggingface_hub import snapshot_download
101
+
102
+ snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
103
+ ```
104
+
105
+ Download via git clone:
106
+ ```sh
107
+ mkdir -p pretrained_models
108
+
109
+ # Make sure you have git-lfs installed (https://git-lfs.com)
110
+ git lfs install
111
+
112
+ git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B
113
+ ```
114
+
115
+ **Basic Usage**
116
+
117
+ You can simply run the demo with the following commands:
118
+ ``` sh
119
+ cd example
120
+ bash infer.sh
121
+ ```
122
+
123
+ Alternatively, you can directly execute the following command in the command line to perform inference:
124
+
125
+ ``` sh
126
+ python -m cli.inference \
127
+ --text "text to synthesis." \
128
+ --device 0 \
129
+ --save_dir "path/to/save/audio" \
130
+ --model_dir pretrained_models/Spark-TTS-0.5B \
131
+ --prompt_text "transcript of the prompt audio" \
132
+ --prompt_speech_path "path/to/prompt_audio"
133
+ ```
134
+
135
+ **Web UI Usage**
136
+
137
+ You can start the web UI by running `python webui.py --device 0`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or recording audio directly.
138
+
139
+
140
+ | **Voice Cloning** | **Voice Creation** |
141
+ |:-------------------:|:-------------------:|
142
+ | ![Image 1](src/figures/gradio_TTS.png) | ![Image 2](src/figures/gradio_control.png) |
143
+
144
+
145
+ **Optional Methods**
146
+
147
+ For additional CLI and Web UI methods, including alternative implementations and extended functionalities, you can refer to:
148
+
149
+ - [CLI and UI by AcTePuKc](https://github.com/SparkAudio/Spark-TTS/issues/10)
150
+
151
+
152
+ ## Runtime
153
+
154
+ **Nvidia Triton Inference Serving**
155
+
156
+ We now provide a reference for deploying Spark-TTS with Nvidia Triton and TensorRT-LLM. The table below presents benchmark results on a single L20 GPU, using 26 different prompt_audio/target_text pairs (totalling 169 seconds of audio):
157
+
158
+ | Model | Note | Concurrency | Avg Latency | RTF |
159
+ |-------|-----------|-----------------------|---------|--|
160
+ | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms | 0.1362|
161
+ | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms | 0.0737|
162
+ | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms | 0.0704|
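+ 
+ *RTF is the real-time factor: synthesis processing time divided by the duration of the generated audio. As a rough check, an RTF of 0.0737 over the 169-second test set corresponds to about 169 s × 0.0737 ≈ 12.5 s of processing time.*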
163
+
164
+
165
+ Please see the detailed instructions in [runtime/triton_trtllm/README.md](runtime/triton_trtllm/README.md) for more information.
166
+
167
+
168
+ ## **Demos**
169
+
170
+ Here are some demos generated by Spark-TTS using zero-shot voice cloning. For more demos, visit our [demo page](https://sparkaudio.github.io/spark-tts/).
171
+
172
+ ---
173
+
174
+ <table>
175
+ <tr>
176
+ <td align="center">
177
+
178
+ **Donald Trump**
179
+ </td>
180
+ <td align="center">
181
+
182
+ **Zhongli (Genshin Impact)**
183
+ </td>
184
+ </tr>
185
+
186
+ <tr>
187
+ <td align="center">
188
+
189
+ [Donald Trump](https://github.com/user-attachments/assets/fb225780-d9fe-44b2-9b2e-54390cb3d8fd)
190
+
191
+ </td>
192
+ <td align="center">
193
+
194
+ [Zhongli](https://github.com/user-attachments/assets/80eeb9c7-0443-4758-a1ce-55ac59e64bd6)
195
+
196
+ </td>
197
+ </tr>
198
+ </table>
199
+
200
+ ---
201
+
202
+ <table>
203
+
204
+ <tr>
205
+ <td align="center">
206
+
207
+ **陈鲁豫 Chen Luyu**
208
+ </td>
209
+ <td align="center">
210
+
211
+ **杨澜 Yang Lan**
212
+ </td>
213
+ </tr>
214
+
215
+ <tr>
216
+ <td align="center">
217
+
218
+ [陈鲁豫Chen_Luyu.webm](https://github.com/user-attachments/assets/5c6585ae-830d-47b1-992d-ee3691f48cf4)
219
+ </td>
220
+ <td align="center">
221
+
222
+ [Yang_Lan.webm](https://github.com/user-attachments/assets/2fb3d00c-abc3-410e-932f-46ba204fb1d7)
223
+ </td>
224
+ </tr>
225
+ </table>
226
+
227
+ ---
228
+
229
+
230
+ <table>
231
+ <tr>
232
+ <td align="center">
233
+
234
+ **余承东 Richard Yu**
235
+ </td>
236
+ <td align="center">
237
+
238
+ **马云 Jack Ma**
239
+ </td>
240
+ </tr>
241
+
242
+ <tr>
243
+ <td align="center">
244
+
245
+ [Yu_Chengdong.webm](https://github.com/user-attachments/assets/78feca02-84bb-4d3a-a770-0cfd02f1a8da)
246
+
247
+ </td>
248
+ <td align="center">
249
+
250
+ [Ma_Yun.webm](https://github.com/user-attachments/assets/2d54e2eb-cec4-4c2f-8c84-8fe587da321b)
251
+
252
+ </td>
253
+ </tr>
254
+ </table>
255
+
256
+ ---
257
+
258
+
259
+ <table>
260
+ <tr>
261
+ <td align="center">
262
+
263
+ **刘德华 Andy Lau**
264
+ </td>
265
+ <td align="center">
266
+
267
+ **徐志胜 Xu Zhisheng**
268
+ </td>
269
+ </tr>
270
+
271
+ <tr>
272
+ <td align="center">
273
+
274
+ [Liu_Dehua.webm](https://github.com/user-attachments/assets/195b5e97-1fee-4955-b954-6d10fa04f1d7)
275
+
276
+ </td>
277
+ <td align="center">
278
+
279
+ [Xu_Zhisheng.webm](https://github.com/user-attachments/assets/dd812af9-76bd-4e26-9988-9cdb9ccbb87b)
280
+
281
+ </td>
282
+ </tr>
283
+ </table>
284
+
285
+
286
+ ---
287
+
288
+ <table>
289
+ <tr>
290
+ <td align="center">
291
+
292
+ **哪吒 Nezha**
293
+ </td>
294
+ <td align="center">
295
+
296
+ **李靖 Li Jing**
297
+ </td>
298
+ </tr>
299
+
300
+ <tr>
301
+ <td align="center">
302
+
303
+ [Ne_Zha.webm](https://github.com/user-attachments/assets/8c608037-a17a-46d4-8588-4db34b49ed1d)
304
+ </td>
305
+ <td align="center">
306
+
307
+ [Li_Jing.webm](https://github.com/user-attachments/assets/aa8ba091-097c-4156-b4e3-6445da5ea101)
308
+
309
+ </td>
310
+ </tr>
311
+ </table>
312
+
313
+
314
+ ## To-Do List
315
+
316
+ - [x] Release the Spark-TTS paper.
317
+ - [ ] Release the training code.
318
+ - [ ] Release the training dataset, VoxBox.
319
+
320
+
321
+ ## Citation
322
+
323
+ ```
324
+ @misc{wang2025sparktts,
325
+ title={Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens},
326
+ author={Xinsheng Wang and Mingqi Jiang and Ziyang Ma and Ziyu Zhang and Songxiang Liu and Linqin Li and Zheng Liang and Qixi Zheng and Rui Wang and Xiaoqin Feng and Weizhen Bian and Zhen Ye and Sitong Cheng and Ruibin Yuan and Zhixian Zhao and Xinfa Zhu and Jiahao Pan and Liumeng Xue and Pengcheng Zhu and Yunlin Chen and Zhifei Li and Xie Chen and Lei Xie and Yike Guo and Wei Xue},
327
+ year={2025},
328
+ eprint={2503.01710},
329
+ archivePrefix={arXiv},
330
+ primaryClass={cs.SD},
331
+ url={https://arxiv.org/abs/2503.01710},
332
+ }
333
+ ```
334
+
335
+
336
+ ## ⚠️ Usage Disclaimer
337
+
338
+ This project provides a zero-shot voice cloning TTS model intended for academic research, educational purposes, and legitimate applications, such as personalized speech synthesis, assistive technologies, and linguistic research.
339
+
340
+ Please note:
341
+
342
+ - Do not use this model for unauthorized voice cloning, impersonation, fraud, scams, deepfakes, or any illegal activities.
343
+
344
+ - Ensure compliance with local laws and regulations when using this model and uphold ethical standards.
345
+
346
+ - The developers assume no liability for any misuse of this model.
347
+
348
+ We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles in AI research and applications. If you have any concerns regarding ethics or misuse, please contact us.
cli/SparkTTS.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ import torch
18
+ from typing import Tuple
19
+ from pathlib import Path
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM
21
+
22
+ from sparktts.utils.file import load_config
23
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
24
+ from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
25
+
26
+
27
+ class SparkTTS:
28
+ """
29
+ Spark-TTS for text-to-speech generation.
30
+ """
31
+
32
+ def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
33
+ """
34
+ Initializes the SparkTTS model with the provided configurations and device.
35
+
36
+ Args:
37
+ model_dir (Path): Directory containing the model and config files.
38
+ device (torch.device): The device (CPU/GPU) to run the model on.
39
+ """
40
+ self.device = device
41
+ self.model_dir = model_dir
42
+ self.configs = load_config(f"{model_dir}/config.yaml")
43
+ self.sample_rate = self.configs["sample_rate"]
44
+ self._initialize_inference()
45
+
46
+ def _initialize_inference(self):
47
+ """Initializes the tokenizer, model, and audio tokenizer for inference."""
48
+ self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
49
+ self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
50
+ self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
51
+ self.model.to(self.device)
52
+
53
+ def process_prompt(
54
+ self,
55
+ text: str,
56
+ prompt_speech_path: Path,
57
+ prompt_text: str = None,
58
+ ) -> Tuple[str, torch.Tensor]:
59
+ """
60
+ Process input for voice cloning.
61
+
62
+ Args:
63
+ text (str): The text input to be converted to speech.
64
+ prompt_speech_path (Path): Path to the audio file used as a prompt.
65
+ prompt_text (str, optional): Transcript of the prompt audio.
66
+
67
+ Returns:
68
+ Tuple[str, torch.Tensor]: Input prompt; global tokens
69
+ """
70
+
71
+ global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
72
+ prompt_speech_path
73
+ )
74
+ global_tokens = "".join(
75
+ [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
76
+ )
77
+
78
+ # Prepare the input tokens for the model
79
+ if prompt_text is not None:
80
+ semantic_tokens = "".join(
81
+ [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
82
+ )
83
+ inputs = [
84
+ TASK_TOKEN_MAP["tts"],
85
+ "<|start_content|>",
86
+ prompt_text,
87
+ text,
88
+ "<|end_content|>",
89
+ "<|start_global_token|>",
90
+ global_tokens,
91
+ "<|end_global_token|>",
92
+ "<|start_semantic_token|>",
93
+ semantic_tokens,
94
+ ]
95
+ else:
96
+ inputs = [
97
+ TASK_TOKEN_MAP["tts"],
98
+ "<|start_content|>",
99
+ text,
100
+ "<|end_content|>",
101
+ "<|start_global_token|>",
102
+ global_tokens,
103
+ "<|end_global_token|>",
104
+ ]
105
+
106
+ inputs = "".join(inputs)
107
+
108
+ return inputs, global_token_ids
109
+
110
+ def process_prompt_control(
111
+ self,
112
+ gender: str,
113
+ pitch: str,
114
+ speed: str,
115
+ text: str,
116
+ ):
117
+ """
118
+ Process input for voice creation.
119
+
120
+ Args:
121
+ gender (str): female | male.
122
+ pitch (str): very_low | low | moderate | high | very_high
123
+ speed (str): very_low | low | moderate | high | very_high
124
+ text (str): The text input to be converted to speech.
125
+
126
+ Return:
127
+ str: Input prompt
128
+ """
129
+ assert gender in GENDER_MAP.keys()
130
+ assert pitch in LEVELS_MAP.keys()
131
+ assert speed in LEVELS_MAP.keys()
132
+
133
+ gender_id = GENDER_MAP[gender]
134
+ pitch_level_id = LEVELS_MAP[pitch]
135
+ speed_level_id = LEVELS_MAP[speed]
136
+
137
+ pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
138
+ speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
139
+ gender_tokens = f"<|gender_{gender_id}|>"
140
+
141
+ attribute_tokens = "".join(
142
+ [gender_tokens, pitch_label_tokens, speed_label_tokens]
143
+ )
144
+
145
+ control_tts_inputs = [
146
+ TASK_TOKEN_MAP["controllable_tts"],
147
+ "<|start_content|>",
148
+ text,
149
+ "<|end_content|>",
150
+ "<|start_style_label|>",
151
+ attribute_tokens,
152
+ "<|end_style_label|>",
153
+ ]
154
+
155
+ return "".join(control_tts_inputs)
156
+
157
+ @torch.no_grad()
158
+ def inference(
159
+ self,
160
+ text: str,
161
+ prompt_speech_path: Path = None,
162
+ prompt_text: str = None,
163
+ gender: str = None,
164
+ pitch: str = None,
165
+ speed: str = None,
166
+ temperature: float = 0.8,
167
+ top_k: float = 50,
168
+ top_p: float = 0.95,
169
+ ) -> torch.Tensor:
170
+ """
171
+ Performs inference to generate speech from text, incorporating prompt audio and/or text.
172
+
173
+ Args:
174
+ text (str): The text input to be converted to speech.
175
+ prompt_speech_path (Path): Path to the audio file used as a prompt.
176
+ prompt_text (str, optional): Transcript of the prompt audio.
177
+ gender (str): female | male.
178
+ pitch (str): very_low | low | moderate | high | very_high
179
+ speed (str): very_low | low | moderate | high | very_high
180
+ temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
181
+ top_k (float, optional): Top-k sampling parameter. Default is 50.
182
+ top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
183
+
184
+ Returns:
185
+ torch.Tensor: Generated waveform as a tensor.
186
+ """
187
+ if gender is not None:
188
+ prompt = self.process_prompt_control(gender, pitch, speed, text)
189
+
190
+ else:
191
+ prompt, global_token_ids = self.process_prompt(
192
+ text, prompt_speech_path, prompt_text
193
+ )
194
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
195
+
196
+ # Generate speech using the model
197
+ generated_ids = self.model.generate(
198
+ **model_inputs,
199
+ max_new_tokens=3000,
200
+ do_sample=True,
201
+ top_k=top_k,
202
+ top_p=top_p,
203
+ temperature=temperature,
204
+ )
205
+
206
+ # Trim the output tokens to remove the input tokens
207
+ generated_ids = [
208
+ output_ids[len(input_ids) :]
209
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
210
+ ]
211
+
212
+ # Decode the generated tokens into text
213
+ predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
214
+
215
+ # Extract semantic token IDs from the generated text
216
+ pred_semantic_ids = (
217
+ torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
218
+ .long()
219
+ .unsqueeze(0)
220
+ )
221
+
222
+ if gender is not None:
223
+ global_token_ids = (
224
+ torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
225
+ .long()
226
+ .unsqueeze(0)
227
+ .unsqueeze(0)
228
+ )
229
+
230
+ # Convert semantic tokens back to waveform
231
+ wav = self.audio_tokenizer.detokenize(
232
+ global_token_ids.to(self.device).squeeze(0),
233
+ pred_semantic_ids.to(self.device),
234
+ )
235
+
236
+ return wav
cli/inference.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import argparse
19
+ import torch
20
+ import soundfile as sf
21
+ import logging
22
+ from datetime import datetime
23
+ import platform
24
+
25
+ from cli.SparkTTS import SparkTTS
26
+
27
+
28
+ def parse_args():
29
+ """Parse command-line arguments."""
30
+ parser = argparse.ArgumentParser(description="Run TTS inference.")
31
+
32
+ parser.add_argument(
33
+ "--model_dir",
34
+ type=str,
35
+ default="pretrained_models/Spark-TTS-0.5B",
36
+ help="Path to the model directory",
37
+ )
38
+ parser.add_argument(
39
+ "--save_dir",
40
+ type=str,
41
+ default="example/results",
42
+ help="Directory to save generated audio files",
43
+ )
44
+ parser.add_argument("--device", type=int, default=0, help="CUDA device number")
45
+ parser.add_argument(
46
+ "--text", type=str, required=True, help="Text for TTS generation"
47
+ )
48
+ parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
49
+ parser.add_argument(
50
+ "--prompt_speech_path",
51
+ type=str,
52
+ help="Path to the prompt audio file",
53
+ )
54
+ parser.add_argument("--gender", choices=["male", "female"])
55
+ parser.add_argument(
56
+ "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
57
+ )
58
+ parser.add_argument(
59
+ "--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
60
+ )
61
+ return parser.parse_args()
62
+
63
+
64
+ def run_tts(args):
65
+ """Perform TTS inference and save the generated audio."""
66
+ logging.info(f"Using model from: {args.model_dir}")
67
+ logging.info(f"Saving audio to: {args.save_dir}")
68
+
69
+ # Ensure the save directory exists
70
+ os.makedirs(args.save_dir, exist_ok=True)
71
+
72
+ # Convert device argument to torch.device
73
+ if platform.system() == "Darwin" and torch.backends.mps.is_available():
74
+ # macOS with MPS support (Apple Silicon)
75
+ device = torch.device(f"mps:{args.device}")
76
+ logging.info(f"Using MPS device: {device}")
77
+ elif torch.cuda.is_available():
78
+ # System with CUDA support
79
+ device = torch.device(f"cuda:{args.device}")
80
+ logging.info(f"Using CUDA device: {device}")
81
+ else:
82
+ # Fall back to CPU
83
+ device = torch.device("cpu")
84
+ logging.info("GPU acceleration not available, using CPU")
85
+
86
+ # Initialize the model
87
+ model = SparkTTS(args.model_dir, device)
88
+
89
+ # Generate unique filename using timestamp
90
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
91
+ save_path = os.path.join(args.save_dir, f"{timestamp}.wav")
92
+
93
+ logging.info("Starting inference...")
94
+
95
+ # Perform inference and save the output audio
96
+ with torch.no_grad():
97
+ wav = model.inference(
98
+ args.text,
99
+ args.prompt_speech_path,
100
+ prompt_text=args.prompt_text,
101
+ gender=args.gender,
102
+ pitch=args.pitch,
103
+ speed=args.speed,
104
+ )
105
+ sf.write(save_path, wav, samplerate=16000)
106
+
107
+ logging.info(f"Audio saved at: {save_path}")
108
+
109
+
110
+ if __name__ == "__main__":
111
+ logging.basicConfig(
112
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
113
+ )
114
+
115
+ args = parse_args()
116
+ run_tts(args)
example/infer.sh ADDED
@@ -0,0 +1,47 @@
1
+ #!/bin/bash
2
+
3
+ # Copyright (c) 2025 SparkAudio
4
+ # 2025 Xinsheng Wang ([email protected])
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ # Get the absolute path of the script's directory
20
+ script_dir=$(dirname "$(realpath "$0")")
21
+
22
+ # Get the root directory
23
+ root_dir=$(dirname "$script_dir")
24
+
25
+ # Set default parameters
26
+ device=0
27
+ save_dir='example/results'
28
+ model_dir="pretrained_models/Spark-TTS-0.5B"
29
+ text="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。"
30
+ prompt_text="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。"
31
+ prompt_speech_path="example/prompt_audio.wav"
32
+
33
+ # Change directory to the root directory
34
+ cd "$root_dir" || exit
35
+
36
+ source sparktts/utils/parse_options.sh
37
+
38
+ # Run inference
39
+ python -m cli.inference \
40
+ --text "${text}" \
41
+ --device "${device}" \
42
+ --save_dir "${save_dir}" \
43
+ --model_dir "${model_dir}" \
44
+ --prompt_text "${prompt_text}" \
45
+ --prompt_speech_path "${prompt_speech_path}"
46
+
47
+
example/prompt_audio.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:335e7f7789b231cd90d9670292d561ecfe6a6bdd5e737a7bc6c29730741852de
3
+ size 318550
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ einops==0.8.1
2
+ einx==0.3.0
3
+ numpy==2.2.3
4
+ omegaconf==2.3.0
5
+ packaging==24.2
6
+ safetensors==0.5.2
7
+ soundfile==0.12.1
8
+ soxr==0.5.0.post1
9
+ tqdm==4.66.5
10
+ transformers==4.46.2
11
+ gradio==5.18.0
runtime/triton_trtllm/Dockerfile.server ADDED
@@ -0,0 +1,5 @@
1
+ FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
2
+ RUN apt-get update && apt-get install -y cmake
3
+ RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
4
+ RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa
5
+ WORKDIR /workspace
runtime/triton_trtllm/README.md ADDED
@@ -0,0 +1,94 @@
1
+ ## Nvidia Triton Inference Serving Best Practice for Spark TTS
2
+
3
+ ### Quick Start
4
+ Directly launch the service using docker compose.
5
+ ```sh
6
+ docker compose up
7
+ ```
8
+
9
+ ### Build Image
10
+ Build the docker image from scratch.
11
+ ```sh
12
+ docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
13
+ ```
14
+
15
+ ### Create Docker Container
16
+ ```sh
17
+ your_mount_dir=/mnt:/mnt
18
+ docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
19
+ ```
20
+
21
+ ### Understanding `run.sh`
22
+
23
+ The `run.sh` script automates various steps using stages. You can run specific stages using:
24
+ ```sh
25
+ bash run.sh <start_stage> <stop_stage> [service_type]
26
+ ```
27
+ - `<start_stage>`: The stage to begin execution from (0-5).
28
+ - `<stop_stage>`: The stage to end execution at (0-5).
29
+ - `[service_type]`: Optional; specifies the service type ('streaming' or 'offline'). Required for stages 4 and 5; for other stages, defaults may apply based on the script logic.
30
+
31
+ Stages:
32
+ - **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace.
33
+ - **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
34
+ - **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline).
35
+ - **Stage 3**: Launch the Triton Inference Server.
36
+ - **Stage 4**: Run the gRPC benchmark client.
37
+ - **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline).
38
+
39
+ ### Export Models to TensorRT-LLM and Launch Server
40
+ Inside the docker container, you can prepare the models and launch the Triton server by running stages 0 through 3. This involves downloading the original model, converting it to TensorRT-LLM format, building the optimized TensorRT engines, creating the necessary model repository structure for Triton, and finally starting the server.
41
+ ```sh
42
+ # This runs stages 0, 1, 2, and 3
43
+ bash run.sh 0 3
44
+ ```
45
+ *Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.*
46
+
47
+
48
+ ### Single Utterance Client
49
+ Run a single inference request. Specify `streaming` or `offline` as the third argument.
50
+
51
+ **Streaming Mode (gRPC):**
52
+ ```sh
53
+ bash run.sh 5 5 streaming
54
+ ```
55
+ This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode.
56
+
57
+ **Offline Mode (HTTP):**
58
+ ```sh
59
+ bash run.sh 5 5 offline
60
+ ```
61
+
62
+ ### Benchmark using Dataset
63
+ Run the benchmark client against the running Triton server. Specify `streaming` or `offline` as the third argument.
64
+ ```sh
65
+ # Run benchmark in streaming mode
66
+ bash run.sh 4 4 streaming
67
+
68
+ # Run benchmark in offline mode
69
+ bash run.sh 4 4 offline
70
+
71
+ # You can also customize parameters like num_task directly in client_grpc.py or via args if supported
72
+ # Example from run.sh (streaming):
73
+ # python3 client_grpc.py \
74
+ # --server-addr localhost \
75
+ # --model-name spark_tts \
76
+ # --num-tasks 2 \
77
+ # --mode streaming \
78
+ # --log-dir ./log_concurrent_tasks_2_streaming_new
79
+
80
+ # Example customizing dataset (requires modifying client_grpc.py or adding args):
81
+ # python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline]
82
+ ```
83
+
84
+ ### Benchmark Results
85
+ Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts), total audio duration 169 secs.
86
+
87
+ | Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF |
88
+ |-------|-----------|-----------------------|---------|----------------|-|
89
+ | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms |-| 0.1362|
90
+ | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms |-|0.0737|
91
+ | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms |-| 0.0704|
92
+ | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms |210.42 ms| 0.1501 |
93
+ | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms |226.08 ms |0.0862 |
94
+ | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms |1017.70 ms| 0.0824 |
runtime/triton_trtllm/client_grpc.py ADDED
@@ -0,0 +1,831 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
3
+ # 2023 Nvidia (authors: Yuekai Zhang)
4
+ # 2023 Recurrent.ai (authors: Songtao Shi)
5
+ # See LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ This script loads a dataset from the Hugging Face Hub and sends it to the server
+ for decoding, issuing multiple requests in parallel.
21
+
22
+ Usage:
23
+ num_task=2
24
+
25
+ # For offline F5-TTS
26
+ python3 client_grpc.py \
27
+ --server-addr localhost \
28
+ --model-name f5_tts \
29
+ --num-tasks $num_task \
30
+ --huggingface-dataset yuekai/seed_tts \
31
+ --split-name test_zh \
32
+ --log-dir ./log_concurrent_tasks_${num_task}
33
+
34
+ # For offline Spark-TTS-0.5B
35
+ python3 client_grpc.py \
36
+ --server-addr localhost \
37
+ --model-name spark_tts \
38
+ --num-tasks $num_task \
39
+ --huggingface-dataset yuekai/seed_tts \
40
+ --split-name wenetspeech4tts \
41
+ --log-dir ./log_concurrent_tasks_${num_task}
42
+ """
43
+
44
+ import argparse
45
+ import asyncio
46
+ import json
47
+ import queue # Added
48
+ import uuid # Added
49
+ import functools # Added
50
+
51
+ import os
52
+ import time
53
+ import types
54
+ from pathlib import Path
55
+
56
+ import numpy as np
57
+ import soundfile as sf
58
+ import tritonclient
59
+ import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
60
+ import tritonclient.grpc as grpcclient_sync # Added sync client import
61
+ from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
62
+
63
+
64
+ # --- Added UserData and callback ---
65
+ class UserData:
66
+ def __init__(self):
67
+ self._completed_requests = queue.Queue()
68
+ self._first_chunk_time = None
69
+ self._start_time = None
70
+
71
+ def record_start_time(self):
72
+ self._start_time = time.time()
73
+
74
+ def get_first_chunk_latency(self):
75
+ if self._first_chunk_time and self._start_time:
76
+ return self._first_chunk_time - self._start_time
77
+ return None
78
+
79
+ def callback(user_data, result, error):
80
+ if user_data._first_chunk_time is None and not error:
81
+ user_data._first_chunk_time = time.time() # Record time of first successful chunk
82
+ if error:
83
+ user_data._completed_requests.put(error)
84
+ else:
85
+ user_data._completed_requests.put(result)
86
+ # --- End Added UserData and callback ---
87
+
88
+
89
+ def write_triton_stats(stats, summary_file):
90
+ with open(summary_file, "w") as summary_f:
91
+ model_stats = stats["model_stats"]
92
+ # Add a note explaining that the log comes from triton_client.get_inference_statistics().
+ summary_f.write(
+ "The log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
95
+ )
96
+ summary_f.write("To learn more about the log, please refer to: \n")
97
+ summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
98
+ summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
99
+ summary_f.write(
100
+ "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
101
+ )
102
+ summary_f.write(
103
+ "However, there is a trade-off between the increased queue time and the increased batch size. \n"
104
+ )
105
+ summary_f.write(
106
+ "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
107
+ )
108
+ summary_f.write(
109
+ "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
110
+ )
111
+ for model_state in model_stats:
112
+ if "last_inference" not in model_state:
113
+ continue
114
+ summary_f.write(f"model name is {model_state['name']} \n")
115
+ model_inference_stats = model_state["inference_stats"]
116
+ total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
117
+ total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
118
+ total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
119
+ total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
120
+ summary_f.write(
121
+ f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
122
+ )
123
+ model_batch_stats = model_state["batch_stats"]
124
+ for batch in model_batch_stats:
125
+ batch_size = int(batch["batch_size"])
126
+ compute_input = batch["compute_input"]
127
+ compute_output = batch["compute_output"]
128
+ compute_infer = batch["compute_infer"]
129
+ batch_count = int(compute_infer["count"])
130
+ assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
131
+ compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
132
+ compute_input_time_ms = int(compute_input["ns"]) / 1e6
133
+ compute_output_time_ms = int(compute_output["ns"]) / 1e6
134
+ summary_f.write(
135
+ f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
136
+ )
137
+ summary_f.write(
138
+ f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
139
+ )
140
+ summary_f.write(
141
+ f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
142
+ )
143
+
144
+
145
+ def get_args():
146
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
147
+
148
+ parser.add_argument(
149
+ "--server-addr",
150
+ type=str,
151
+ default="localhost",
152
+ help="Address of the server",
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--server-port",
157
+ type=int,
158
+ default=8001,
159
+ help="Grpc port of the triton server, default is 8001",
160
+ )
161
+
162
+ parser.add_argument(
163
+ "--reference-audio",
164
+ type=str,
165
+ default=None,
166
+ help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--reference-text",
171
+ type=str,
172
+ default="",
173
+ help="",
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--target-text",
178
+ type=str,
179
+ default="",
180
+ help="",
181
+ )
182
+
183
+ parser.add_argument(
184
+ "--huggingface-dataset",
185
+ type=str,
186
+ default="yuekai/seed_tts",
187
+ help="dataset name in huggingface dataset hub",
188
+ )
189
+
190
+ parser.add_argument(
191
+ "--split-name",
192
+ type=str,
193
+ default="wenetspeech4tts",
194
+ choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
195
+ help="dataset split name, default is 'test'",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--manifest-path",
200
+ type=str,
201
+ default=None,
202
+ help="Path to the manifest dir which includes wav.scp trans.txt files.",
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--model-name",
207
+ type=str,
208
+ default="f5_tts",
209
+ choices=["f5_tts", "spark_tts"],
210
+ help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--num-tasks",
215
+ type=int,
216
+ default=1,
217
+ help="Number of concurrent tasks for sending",
218
+ )
219
+
220
+ parser.add_argument(
221
+ "--log-interval",
222
+ type=int,
223
+ default=5,
224
+ help="Controls how frequently we print the log.",
225
+ )
226
+
227
+ parser.add_argument(
228
+ "--compute-wer",
229
+ action="store_true",
230
+ default=False,
231
+ help="""True to compute WER.
232
+ """,
233
+ )
234
+
235
+ parser.add_argument(
236
+ "--log-dir",
237
+ type=str,
238
+ required=False,
239
+ default="./tmp",
240
+ help="log directory",
241
+ )
242
+
243
+ # --- Added arguments ---
244
+ parser.add_argument(
245
+ "--mode",
246
+ type=str,
247
+ default="offline",
248
+ choices=["offline", "streaming"],
249
+ help="Select offline or streaming benchmark mode."
250
+ )
251
+ parser.add_argument(
252
+ "--chunk-overlap-duration",
253
+ type=float,
254
+ default=0.1,
255
+ help="Chunk overlap duration for streaming reconstruction (in seconds)."
256
+ )
257
+ # --- End Added arguments ---
258
+
259
+ return parser.parse_args()
260
+
261
+
262
+ def load_audio(wav_path, target_sample_rate=16000):
263
+ assert target_sample_rate == 16000, "hard coding in server"
264
+ if isinstance(wav_path, dict):
265
+ waveform = wav_path["array"]
266
+ sample_rate = wav_path["sampling_rate"]
267
+ else:
268
+ waveform, sample_rate = sf.read(wav_path)
269
+ if sample_rate != target_sample_rate:
270
+ from scipy.signal import resample
271
+
272
+ num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
273
+ waveform = resample(waveform, num_samples)
274
+ return waveform, target_sample_rate
275
+
276
+ def prepare_request_input_output(
277
+ protocol_client, # Can be grpcclient_aio or grpcclient_sync
278
+ waveform,
279
+ reference_text,
280
+ target_text,
281
+ sample_rate=16000,
282
+ padding_duration: int = None # Optional padding for offline mode
283
+ ):
284
+ """Prepares inputs for Triton inference (offline or streaming)."""
285
+ assert len(waveform.shape) == 1, "waveform should be 1D"
286
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
287
+
288
+ # Apply padding only if padding_duration is provided (for offline)
289
+ if padding_duration:
290
+ duration = len(waveform) / sample_rate
291
+ # Estimate target duration based on text length ratio (crude estimation)
292
+ # Avoid division by zero if reference_text is empty
293
+ if reference_text:
294
+ estimated_target_duration = duration / len(reference_text) * len(target_text)
295
+ else:
296
+ estimated_target_duration = duration # Assume target duration similar to reference if no text
297
+
298
+ # Calculate required samples based on estimated total duration
299
+ required_total_samples = padding_duration * sample_rate * (
300
+ (int(estimated_target_duration + duration) // padding_duration) + 1
301
+ )
302
+ samples = np.zeros((1, required_total_samples), dtype=np.float32)
303
+ samples[0, : len(waveform)] = waveform
304
+ else:
305
+ # No padding for streaming or if padding_duration is None
306
+ samples = waveform.reshape(1, -1).astype(np.float32)
307
+
308
+ # Common input creation logic
309
+ inputs = [
310
+ protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
311
+ protocol_client.InferInput(
312
+ "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
313
+ ),
314
+ protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
315
+ protocol_client.InferInput("target_text", [1, 1], "BYTES"),
316
+ ]
317
+ inputs[0].set_data_from_numpy(samples)
318
+ inputs[1].set_data_from_numpy(lengths)
319
+
320
+ input_data_numpy = np.array([reference_text], dtype=object)
321
+ input_data_numpy = input_data_numpy.reshape((1, 1))
322
+ inputs[2].set_data_from_numpy(input_data_numpy)
323
+
324
+ input_data_numpy = np.array([target_text], dtype=object)
325
+ input_data_numpy = input_data_numpy.reshape((1, 1))
326
+ inputs[3].set_data_from_numpy(input_data_numpy)
327
+
328
+ outputs = [protocol_client.InferRequestedOutput("waveform")]
329
+
330
+ return inputs, outputs
331
+
332
+ def run_sync_streaming_inference(
333
+ sync_triton_client: tritonclient.grpc.InferenceServerClient,
334
+ model_name: str,
335
+ inputs: list,
336
+ outputs: list,
337
+ request_id: str,
338
+ user_data: UserData,
339
+ chunk_overlap_duration: float,
340
+ save_sample_rate: int,
341
+ audio_save_path: str,
342
+ ):
343
+ """Helper function to run the blocking sync streaming call."""
344
+ start_time_total = time.time()
345
+ user_data.record_start_time() # Record start time for first chunk latency calculation
346
+
347
+ # Establish stream
348
+ sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
349
+
350
+ # Send request
351
+ sync_triton_client.async_stream_infer(
352
+ model_name,
353
+ inputs,
354
+ request_id=request_id,
355
+ outputs=outputs,
356
+ enable_empty_final_response=True,
357
+ )
358
+
359
+ # Process results
360
+ audios = []
361
+ while True:
362
+ try:
363
+ result = user_data._completed_requests.get(timeout=60)  # timeout (60 s is an arbitrary choice) so the queue.Empty handler below can fire
364
+ if isinstance(result, InferenceServerException):
365
+ print(f"Received InferenceServerException: {result}")
366
+ sync_triton_client.stop_stream()
367
+ return None, None, None # Indicate error
368
+ # Get response metadata
369
+ response = result.get_response()
370
+ final = response.parameters["triton_final_response"].bool_param
371
+ if final is True:
372
+ break
373
+
374
+ audio_chunk = result.as_numpy("waveform").reshape(-1)
375
+ if audio_chunk.size > 0: # Only append non-empty chunks
376
+ audios.append(audio_chunk)
377
+ else:
378
+ print("Warning: received empty audio chunk.")
379
+
380
+ except queue.Empty:
381
+ print(f"Timeout waiting for response for request id {request_id}")
382
+ sync_triton_client.stop_stream()
383
+ return None, None, None # Indicate error
384
+
385
+ sync_triton_client.stop_stream()
386
+ end_time_total = time.time()
387
+ total_request_latency = end_time_total - start_time_total
388
+ first_chunk_latency = user_data.get_first_chunk_latency()
389
+
390
+ # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
391
+ actual_duration = 0
392
+ if audios:
393
+ cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
394
+ fade_out = np.linspace(1, 0, cross_fade_samples)
395
+ fade_in = np.linspace(0, 1, cross_fade_samples)
396
+ reconstructed_audio = None
397
+
398
+ # Simplified reconstruction based on client_grpc_streaming.py
399
+ if not audios:
400
+ print("Warning: No audio chunks received.")
401
+ reconstructed_audio = np.array([], dtype=np.float32) # Empty array
402
+ elif len(audios) == 1:
403
+ reconstructed_audio = audios[0]
404
+ else:
405
+ reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
406
+ for i in range(1, len(audios)):
407
+ # Cross-fade section
408
+ cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
409
+ audios[i - 1][-cross_fade_samples:] * fade_out)
410
+ # Middle section of the current chunk
411
+ middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
412
+ # Concatenate
413
+ reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
414
+ # Add the last part of the final chunk
415
+ reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
416
+
417
+ if reconstructed_audio is not None and reconstructed_audio.size > 0:
418
+ actual_duration = len(reconstructed_audio) / save_sample_rate
419
+ # Save reconstructed audio
420
+ os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
421
+ sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
422
+ else:
423
+ print("Warning: No audio chunks received or reconstructed.")
424
+ actual_duration = 0 # Set duration to 0 if no audio
425
+
426
+ else:
427
+ print("Warning: No audio chunks received.")
428
+ actual_duration = 0
429
+
430
+ return total_request_latency, first_chunk_latency, actual_duration
431
+
432
+
433
+ async def send_streaming(
434
+ manifest_item_list: list,
435
+ name: str,
436
+ server_url: str, # Changed from sync_triton_client
437
+ protocol_client: types.ModuleType,
438
+ log_interval: int,
439
+ model_name: str,
440
+ audio_save_dir: str = "./",
441
+ save_sample_rate: int = 16000,
442
+ chunk_overlap_duration: float = 0.1,
443
+ padding_duration: int = None,
444
+ ):
445
+ total_duration = 0.0
446
+ latency_data = []
447
+ task_id = int(name[5:])
448
+ sync_triton_client = None # Initialize client variable
449
+
450
+ try: # Wrap in try...finally to ensure client closing
451
+ print(f"{name}: Initializing sync client for streaming...")
452
+ sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
453
+
454
+ print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
455
+ for i, item in enumerate(manifest_item_list):
456
+ if i % log_interval == 0:
457
+ print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
458
+
459
+ try:
460
+ waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
461
+ reference_text, target_text = item["reference_text"], item["target_text"]
462
+
463
+ inputs, outputs = prepare_request_input_output(
464
+ protocol_client,
465
+ waveform,
466
+ reference_text,
467
+ target_text,
468
+ sample_rate,
469
+ padding_duration=padding_duration
470
+ )
471
+ request_id = str(uuid.uuid4())
472
+ user_data = UserData()
473
+
474
+ audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
475
+
476
+ total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
477
+ run_sync_streaming_inference,
478
+ sync_triton_client,
479
+ model_name,
480
+ inputs,
481
+ outputs,
482
+ request_id,
483
+ user_data,
484
+ chunk_overlap_duration,
485
+ save_sample_rate,
486
+ audio_save_path
487
+ )
488
+
489
+ if total_request_latency is not None:
490
+ print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
491
+ latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
492
+ total_duration += actual_duration
493
+ else:
494
+ print(f"{name}: Item {i} failed.")
495
+
496
+
497
+ except FileNotFoundError:
498
+ print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
499
+ except Exception as e:
500
+ print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
501
+ import traceback
502
+ traceback.print_exc()
503
+
504
+
505
+ finally: # Ensure client is closed
506
+ if sync_triton_client:
507
+ try:
508
+ print(f"{name}: Closing sync client...")
509
+ sync_triton_client.close()
510
+ except Exception as e:
511
+ print(f"{name}: Error closing sync client: {e}")
512
+
513
+
514
+ print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
515
+ return total_duration, latency_data
516
+
517
+ async def send(
518
+ manifest_item_list: list,
519
+ name: str,
520
+ triton_client: tritonclient.grpc.aio.InferenceServerClient,
521
+ protocol_client: types.ModuleType,
522
+ log_interval: int,
523
+ model_name: str,
524
+ padding_duration: int = None,
525
+ audio_save_dir: str = "./",
526
+ save_sample_rate: int = 16000,
527
+ ):
528
+ total_duration = 0.0
529
+ latency_data = []
530
+ task_id = int(name[5:])
531
+
532
+ print(f"manifest_item_list: {manifest_item_list}")
533
+ for i, item in enumerate(manifest_item_list):
534
+ if i % log_interval == 0:
535
+ print(f"{name}: {i}/{len(manifest_item_list)}")
536
+ waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
537
+ reference_text, target_text = item["reference_text"], item["target_text"]
538
+
539
+ inputs, outputs = prepare_request_input_output(
540
+ protocol_client,
541
+ waveform,
542
+ reference_text,
543
+ target_text,
544
+ sample_rate,
545
+ padding_duration=padding_duration
546
+ )
547
+ sequence_id = 100000000 + i + task_id * 10
548
+ start = time.time()
549
+ response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
550
+
551
+ audio = response.as_numpy("waveform").reshape(-1)
552
+ actual_duration = len(audio) / save_sample_rate
553
+
554
+ end = time.time() - start
555
+
556
+ audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
557
+ sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
558
+
559
+ latency_data.append((end, actual_duration))
560
+ total_duration += actual_duration
561
+
562
+ return total_duration, latency_data
563
+
564
+
565
+ def load_manifests(manifest_path):
566
+ with open(manifest_path, "r") as f:
567
+ manifest_list = []
568
+ for line in f:
569
+ assert len(line.strip().split("|")) == 4
570
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
571
+ utt = Path(utt).stem
572
+ # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
573
+ if not os.path.isabs(prompt_wav):
574
+ prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
575
+ manifest_list.append(
576
+ {
577
+ "audio_filepath": prompt_wav,
578
+ "reference_text": prompt_text,
579
+ "target_text": gt_text,
580
+ "target_audio_path": utt,
581
+ }
582
+ )
583
+ return manifest_list
584
+
585
+
586
+ def split_data(data, k):
587
+ n = len(data)
588
+ if n < k:
589
+ print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
590
+ k = n
591
+
592
+ quotient = n // k
593
+ remainder = n % k
594
+
595
+ result = []
596
+ start = 0
597
+ for i in range(k):
598
+ if i < remainder:
599
+ end = start + quotient + 1
600
+ else:
601
+ end = start + quotient
602
+
603
+ result.append(data[start:end])
604
+ start = end
605
+
606
+ return result
607
+
608
+ async def main():
609
+ args = get_args()
610
+ url = f"{args.server_addr}:{args.server_port}"
611
+
612
+ # --- Client Initialization based on mode ---
613
+ triton_client = None
614
+ protocol_client = None
615
+ if args.mode == "offline":
616
+ print("Initializing gRPC client for offline mode...")
617
+ # Use the async client for offline tasks
618
+ triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
619
+ protocol_client = grpcclient_aio
620
+ elif args.mode == "streaming":
621
+ print("Initializing gRPC client for streaming mode...")
622
+ # Use the sync client for streaming tasks, handled via asyncio.to_thread
623
+ # We will create one sync client instance PER TASK inside send_streaming.
624
+ # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
625
+ protocol_client = grpcclient_sync # protocol client for input prep
626
+ else:
627
+ raise ValueError(f"Invalid mode: {args.mode}")
628
+ # --- End Client Initialization ---
629
+
630
+ if args.reference_audio:
631
+ args.num_tasks = 1
632
+ args.log_interval = 1
633
+ manifest_item_list = [
634
+ {
635
+ "reference_text": args.reference_text,
636
+ "target_text": args.target_text,
637
+ "audio_filepath": args.reference_audio,
638
+ "target_audio_path": "test",
639
+ }
640
+ ]
641
+ elif args.huggingface_dataset:
642
+ import datasets
643
+
644
+ dataset = datasets.load_dataset(
645
+ args.huggingface_dataset,
646
+ split=args.split_name,
647
+ trust_remote_code=True,
648
+ )
649
+ manifest_item_list = []
650
+ for i in range(len(dataset)):
651
+ manifest_item_list.append(
652
+ {
653
+ "audio_filepath": dataset[i]["prompt_audio"],
654
+ "reference_text": dataset[i]["prompt_text"],
655
+ "target_audio_path": dataset[i]["id"],
656
+ "target_text": dataset[i]["target_text"],
657
+ }
658
+ )
659
+ else:
660
+ manifest_item_list = load_manifests(args.manifest_path)
661
+
662
+ num_tasks = min(args.num_tasks, len(manifest_item_list))
663
+ manifest_item_list = split_data(manifest_item_list, num_tasks)
664
+
665
+ os.makedirs(args.log_dir, exist_ok=True)
666
+ tasks = []
667
+ start_time = time.time()
668
+ for i in range(num_tasks):
669
+ # --- Task Creation based on mode ---
670
+ if args.mode == "offline":
671
+ task = asyncio.create_task(
672
+ send(
673
+ manifest_item_list[i],
674
+ name=f"task-{i}",
675
+ triton_client=triton_client,
676
+ protocol_client=protocol_client,
677
+ log_interval=args.log_interval,
678
+ model_name=args.model_name,
679
+ audio_save_dir=args.log_dir,
680
+ padding_duration=1,
681
+ save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
682
+ )
683
+ )
684
+ elif args.mode == "streaming":
685
+ task = asyncio.create_task(
686
+ send_streaming(
687
+ manifest_item_list[i],
688
+ name=f"task-{i}",
689
+ server_url=url, # Pass URL instead of client
690
+ protocol_client=protocol_client,
691
+ log_interval=args.log_interval,
692
+ model_name=args.model_name,
693
+ audio_save_dir=args.log_dir,
694
+ padding_duration=10,
695
+ save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
696
+ chunk_overlap_duration=args.chunk_overlap_duration,
697
+ )
698
+ )
699
+ # --- End Task Creation ---
700
+ tasks.append(task)
701
+
702
+ ans_list = await asyncio.gather(*tasks)
703
+
704
+ end_time = time.time()
705
+ elapsed = end_time - start_time
706
+
707
+ total_duration = 0.0
708
+ latency_data = []
709
+ for ans in ans_list:
710
+ if ans:
711
+ total_duration += ans[0]
712
+ latency_data.extend(ans[1]) # Use extend for list of lists
713
+ else:
714
+ print("Warning: A task returned None, possibly due to an error.")
715
+
716
+
717
+ if total_duration == 0:
718
+ print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
719
+ rtf = float('inf')
720
+ else:
721
+ rtf = elapsed / total_duration
722
+
723
+ s = f"Mode: {args.mode}\n"
724
+ s += f"RTF: {rtf:.4f}\n"
725
+ s += f"total_duration: {total_duration:.3f} seconds\n"
726
+ s += f"({total_duration / 3600:.2f} hours)\n"
727
+ s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
728
+
729
+ # --- Statistics Reporting based on mode ---
730
+ if latency_data:
731
+ if args.mode == "offline":
732
+ # Original offline latency calculation
733
+ latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
734
+ if latency_list:
735
+ latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
736
+ latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
737
+ s += f"latency_variance: {latency_variance:.2f}\n"
738
+ s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
739
+ s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
740
+ s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
741
+ s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
742
+ s += f"average_latency_ms: {latency_ms:.2f}\n"
743
+ else:
744
+ s += "No latency data collected for offline mode.\n"
745
+
746
+ elif args.mode == "streaming":
747
+ # Calculate stats for total request latency and first chunk latency
748
+ total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
749
+ first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
750
+
751
+ s += "\n--- Total Request Latency ---\n"
752
+ if total_latency_list:
753
+ avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
754
+ variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
755
+ s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
756
+ s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
757
+ s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
758
+ s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
759
+ s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
760
+ s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
761
+ else:
762
+ s += "No total request latency data collected.\n"
763
+
764
+ s += "\n--- First Chunk Latency ---\n"
765
+ if first_chunk_latency_list:
766
+ avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
767
+ variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
768
+ s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
769
+ s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
770
+ s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
771
+ s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
772
+ s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
773
+ s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
774
+ else:
775
+ s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
776
+ else:
777
+ s += "No latency data collected.\n"
778
+ # --- End Statistics Reporting ---
779
+
780
+ print(s)
781
+ if args.manifest_path:
782
+ name = Path(args.manifest_path).stem
783
+ elif args.split_name:
784
+ name = args.split_name
785
+ elif args.reference_audio:
786
+ name = Path(args.reference_audio).stem
787
+ else:
788
+ name = "results" # Default name if no manifest/split/audio provided
789
+ with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
790
+ f.write(s)
791
+
792
+ # --- Statistics Fetching using temporary Async Client ---
793
+ # Use a separate async client for fetching stats regardless of mode
794
+ stats_client = None
795
+ try:
796
+ print("Initializing temporary async client for fetching stats...")
797
+ stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
798
+ print("Fetching inference statistics...")
799
+ # Fetching for all models, filtering might be needed depending on server setup
800
+ stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
801
+ print("Fetching model config...")
802
+ metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
803
+
804
+ write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
805
+
806
+ with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
807
+ json.dump(metadata, f, indent=4)
808
+
809
+ except Exception as e:
810
+ print(f"Could not retrieve statistics or config: {e}")
811
+ finally:
812
+ if stats_client:
813
+ try:
814
+ print("Closing temporary async stats client...")
815
+ await stats_client.close()
816
+ except Exception as e:
817
+ print(f"Error closing async stats client: {e}")
818
+ # --- End Statistics Fetching ---
819
+
820
+
821
+ if __name__ == "__main__":
822
+ # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
823
+ async def run_main():
824
+ try:
825
+ await main()
826
+ except Exception as e:
827
+ print(f"An error occurred in main: {e}")
828
+ import traceback
829
+ traceback.print_exc()
830
+
831
+ asyncio.run(run_main())
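For quick manual checks outside this benchmark harness, a single offline request can be issued with the synchronous gRPC client along the lines of the sketch below. It mirrors `prepare_request_input_output()` above; the server address, file path, and texts are placeholders, not values from this repository.

```python
# Minimal single-request sketch (offline mode) against the spark_tts pipeline.
# Mirrors prepare_request_input_output() above; URL, path, and texts are placeholders.
import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

client = grpcclient.InferenceServerClient(url="localhost:8001")

waveform, sr = sf.read("prompt.wav")  # expected to be 16 kHz mono
samples = waveform.reshape(1, -1).astype(np.float32)
lengths = np.array([[samples.shape[1]]], dtype=np.int32)

inputs = [
    grpcclient.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
    grpcclient.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
    grpcclient.InferInput("reference_text", [1, 1], "BYTES"),
    grpcclient.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
inputs[2].set_data_from_numpy(np.array([["transcript of the prompt audio"]], dtype=object))
inputs[3].set_data_from_numpy(np.array([["text to synthesize"]], dtype=object))

result = client.infer("spark_tts", inputs,
                      outputs=[grpcclient.InferRequestedOutput("waveform")])
audio = result.as_numpy("waveform").reshape(-1)
sf.write("output.wav", audio, 16000, "PCM_16")
```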
runtime/triton_trtllm/client_http.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import requests
27
+ import soundfile as sf
28
+ import json
29
+ import numpy as np
30
+ import argparse
31
+
32
+ def get_args():
33
+ parser = argparse.ArgumentParser(
34
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
35
+ )
36
+
37
+ parser.add_argument(
38
+ "--server-url",
39
+ type=str,
40
+ default="localhost:8000",
41
+ help="Address of the server",
42
+ )
43
+
44
+ parser.add_argument(
45
+ "--reference-audio",
46
+ type=str,
47
+ default="../../example/prompt_audio.wav",
48
+ help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--reference-text",
53
+ type=str,
54
+ default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
55
+ help="",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--target-text",
60
+ type=str,
61
+ default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
62
+ help="",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--model-name",
67
+ type=str,
68
+ default="spark_tts",
69
+ choices=[
70
+ "f5_tts", "spark_tts"
71
+ ],
72
+ help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--output-audio",
77
+ type=str,
78
+ default="output.wav",
79
+ help="Path to save the output audio",
80
+ )
81
+ return parser.parse_args()
82
+
83
+ def prepare_request(
84
+ waveform,
85
+ reference_text,
86
+ target_text,
87
+ sample_rate=16000,
88
+ padding_duration: int = None,
89
+ audio_save_dir: str = "./",
90
+ ):
91
+ assert len(waveform.shape) == 1, "waveform should be 1D"
92
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
93
+ if padding_duration:
+ # pad to the next multiple of padding_duration seconds
+ duration = len(waveform) / sample_rate
+ samples = np.zeros(
+ (
+ 1,
+ padding_duration
+ * sample_rate
+ * ((int(duration) // padding_duration) + 1),
101
+ ),
102
+ dtype=np.float32,
103
+ )
104
+
105
+ samples[0, : len(waveform)] = waveform
106
+ else:
107
+ samples = waveform
108
+
109
+ samples = samples.reshape(1, -1).astype(np.float32)
110
+
111
+ data = {
112
+ "inputs":[
113
+ {
114
+ "name": "reference_wav",
115
+ "shape": samples.shape,
116
+ "datatype": "FP32",
117
+ "data": samples.tolist()
118
+ },
119
+ {
120
+ "name": "reference_wav_len",
121
+ "shape": lengths.shape,
122
+ "datatype": "INT32",
123
+ "data": lengths.tolist(),
124
+ },
125
+ {
126
+ "name": "reference_text",
127
+ "shape": [1, 1],
128
+ "datatype": "BYTES",
129
+ "data": [reference_text]
130
+ },
131
+ {
132
+ "name": "target_text",
133
+ "shape": [1, 1],
134
+ "datatype": "BYTES",
135
+ "data": [target_text]
136
+ }
137
+ ]
138
+ }
139
+
140
+ return data
141
+
142
+ if __name__ == "__main__":
143
+ args = get_args()
144
+ server_url = args.server_url
145
+ if not server_url.startswith(("http://", "https://")):
146
+ server_url = f"http://{server_url}"
147
+
148
+ url = f"{server_url}/v2/models/{args.model_name}/infer"
149
+ waveform, sr = sf.read(args.reference_audio)
150
+ assert sr == 16000, "sample rate hardcoded in server"
151
+
152
+ samples = np.array(waveform, dtype=np.float32)
153
+ data = prepare_request(samples, args.reference_text, args.target_text)
154
+
155
+ rsp = requests.post(
156
+ url,
157
+ headers={"Content-Type": "application/json"},
158
+ json=data,
159
+ verify=False,
160
+ params={"request_id": '0'}
161
+ )
162
+ result = rsp.json()
163
+ audio = result["outputs"][0]["data"]
164
+ audio = np.array(audio, dtype=np.float32)
165
+ sf.write(args.output_audio, audio, 16000, "PCM_16")
runtime/triton_trtllm/docker-compose.yml ADDED
@@ -0,0 +1,20 @@
1
+ services:
2
+ tts:
3
+ image: soar97/triton-spark-tts:25.02
4
+ shm_size: '1gb'
5
+ ports:
6
+ - "8000:8000"
7
+ - "8001:8001"
8
+ - "8002:8002"
9
+ environment:
10
+ - PYTHONIOENCODING=utf-8
11
+ - MODEL_ID=${MODEL_ID}
12
+ deploy:
13
+ resources:
14
+ reservations:
15
+ devices:
16
+ - driver: nvidia
17
+ device_ids: ['0']
18
+ capabilities: [gpu]
19
+ command: >
20
+ /bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3"
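Because the container clones the repo and builds the TensorRT-LLM engine on first start, it can take a while before the server accepts requests. One way to wait is to poll Triton's standard KServe readiness endpoint on the mapped HTTP port (8000); a small polling sketch follows, with the base URL and timeout as assumptions.

```python
# Poll Triton's KServe readiness endpoint on the HTTP port mapped by docker-compose (8000).
# Base URL and timeout are assumptions; adjust to your deployment.
import time
import requests

def wait_until_ready(base_url: str = "http://localhost:8000", timeout_s: float = 600.0) -> bool:
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{base_url}/v2/health/ready", timeout=5).status_code == 200:
                return True
        except requests.RequestException:
            pass  # server still starting up (engine build can take several minutes)
        time.sleep(5)
    return False

print("server ready:", wait_until_ready())
```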
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import json
27
+ import torch
28
+ from torch.utils.dlpack import to_dlpack
29
+
30
+ import triton_python_backend_utils as pb_utils
31
+
32
+ import os
33
+ import numpy as np
34
+
35
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
36
+
37
+ class TritonPythonModel:
38
+ """Triton Python model for audio tokenization.
39
+
40
+ This model takes reference audio input and extracts semantic and global tokens
41
+ using BiCodec tokenizer.
42
+ """
43
+
44
+ def initialize(self, args):
45
+ """Initialize the model.
46
+
47
+ Args:
48
+ args: Dictionary containing model configuration
49
+ """
50
+ # Parse model parameters
51
+ parameters = json.loads(args['model_config'])['parameters']
52
+ model_params = {k: v["string_value"] for k, v in parameters.items()}
53
+
54
+ # Initialize tokenizer
55
+ self.device = torch.device("cuda")
56
+ self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
57
+ device=self.device)
58
+
59
+ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
60
+ """Extract reference audio clip for speaker embedding.
61
+
62
+ Args:
63
+ wav: Input waveform array
64
+
65
+ Returns:
66
+ Reference clip of fixed duration
67
+ """
68
+ SAMPLE_RATE = 16000
69
+ REF_SEGMENT_DURATION = 6 # seconds
70
+ LATENT_HOP_LENGTH = 320
71
+
72
+ ref_segment_length = (
73
+ int(SAMPLE_RATE * REF_SEGMENT_DURATION)
74
+ // LATENT_HOP_LENGTH
75
+ * LATENT_HOP_LENGTH
76
+ )
77
+ wav_length = len(wav)
78
+
79
+ if ref_segment_length > wav_length:
80
+ # Repeat and truncate if input is too short
81
+ repeat_times = ref_segment_length // wav_length + 1
82
+ wav = np.tile(wav, repeat_times)
83
+
84
+ return wav[:ref_segment_length]
85
+
86
+ def execute(self, requests):
87
+ """Execute inference on the batched requests.
88
+
89
+ Args:
90
+ requests: List of inference requests
91
+
92
+ Returns:
93
+ List of inference responses containing tokenized outputs
94
+ """
95
+ reference_wav_list = []
96
+ reference_wav_ref_clip_list = []
97
+
98
+ # Process each request in batch
99
+ for request in requests:
100
+ # Extract input tensors
101
+ wav_array = pb_utils.get_input_tensor_by_name(
102
+ request, "reference_wav").as_numpy()
103
+ wav_len = pb_utils.get_input_tensor_by_name(
104
+ request, "reference_wav_len").as_numpy().item()
105
+
106
+ # Prepare inputs
107
+ wav = wav_array[:, :wav_len].squeeze(0)
108
+ reference_wav_list.append(wav)
109
+
110
+ wav_ref_clip = self.get_ref_clip(wav)
111
+ reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))
112
+
113
+ # Batch process through tokenizer
114
+ ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
115
+ wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
116
+ reference_wav_list)
117
+
118
+ audio_tokenizer_input = {
119
+ "ref_wav": ref_wav_clip_tensor.to(self.device),
120
+ "feat": wav2vec2_features.to(self.device),
121
+ }
122
+ semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
123
+ audio_tokenizer_input)
124
+
125
+ # Prepare responses
126
+ responses = []
127
+ for i in range(len(requests)):
128
+ global_tokens_tensor = pb_utils.Tensor.from_dlpack(
129
+ "global_tokens", to_dlpack(global_tokens[i]))
130
+ semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
131
+ "semantic_tokens", to_dlpack(semantic_tokens[i]))
132
+
133
+ inference_response = pb_utils.InferenceResponse(
134
+ output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
135
+ responses.append(inference_response)
136
+
137
+ return responses
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "audio_tokenizer"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ parameters [
22
+ {
23
+ key: "model_dir",
24
+ value: {string_value:"${model_dir}"}
25
+ }
26
+ ]
27
+
28
+ input [
29
+ {
30
+ name: "reference_wav"
31
+ data_type: TYPE_FP32
32
+ dims: [-1]
33
+ },
34
+ {
35
+ name: "reference_wav_len"
36
+ data_type: TYPE_INT32
37
+ dims: [1]
38
+ }
39
+ ]
40
+ output [
41
+ {
42
+ name: "global_tokens"
43
+ data_type: TYPE_INT32
44
+ dims: [-1]
45
+ },
46
+ {
47
+ name: "semantic_tokens"
48
+ data_type: TYPE_INT32
49
+ dims: [-1]
50
+ }
51
+ ]
52
+
53
+ instance_group [
54
+ {
55
+ count: 1
56
+ kind: KIND_CPU
57
+ }
58
+ ]
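Since audio_tokenizer is a regular model in the served model_repo, it can also be queried on its own (assuming the server exposes it directly), which is handy for inspecting the global/semantic token IDs produced for a prompt. A sketch matching the input/output definitions above, with the server address and file path as placeholders:

```python
# Sketch: call the audio_tokenizer model directly to inspect its token outputs.
# Input/output names and dtypes follow the config.pbtxt above; URL and path are placeholders.
import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

client = grpcclient.InferenceServerClient(url="localhost:8001")

wav, sr = sf.read("prompt.wav")  # 16 kHz mono, as assumed by the backend
samples = wav.reshape(1, -1).astype(np.float32)
lengths = np.array([[samples.shape[1]]], dtype=np.int32)

inputs = [
    grpcclient.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
    grpcclient.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)

result = client.infer(
    "audio_tokenizer",
    inputs,
    outputs=[
        grpcclient.InferRequestedOutput("global_tokens"),
        grpcclient.InferRequestedOutput("semantic_tokens"),
    ],
)
print("global_tokens shape:", result.as_numpy("global_tokens").shape)
print("semantic_tokens shape:", result.as_numpy("semantic_tokens").shape)
```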
runtime/triton_trtllm/model_repo/spark_tts/1/model.py ADDED
@@ -0,0 +1,404 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ import json
28
+ import math
29
+ import os
30
+ import re
31
+ from typing import Dict, List, Tuple, Optional, Union
32
+
33
+ import numpy as np
34
+ import torch
35
+ from torch.utils.dlpack import from_dlpack, to_dlpack
36
+ import triton_python_backend_utils as pb_utils
37
+ from transformers import AutoTokenizer
38
+
39
+ from sparktts.utils.token_parser import TASK_TOKEN_MAP
40
+
41
+ def process_prompt(
42
+ text: str,
43
+ prompt_text: Optional[str] = None,
44
+ global_token_ids: torch.Tensor = None,
45
+ semantic_token_ids: torch.Tensor = None,
46
+ ) -> Tuple[str, torch.Tensor]:
47
+ """
48
+ Process input for voice cloning.
49
+
50
+ Args:
51
+ text: The text input to be converted to speech.
52
+ prompt_text: Transcript of the prompt audio.
53
+ global_token_ids: Global token IDs extracted from reference audio.
54
+ semantic_token_ids: Semantic token IDs extracted from reference audio.
55
+
56
+ Returns:
57
+ Tuple containing the formatted input prompt and global token IDs.
58
+ """
59
+ # Convert global tokens to string format
60
+ global_tokens = "".join(
61
+ [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
62
+ )
63
+
64
+
65
+ # Prepare the input tokens for the model
66
+ if prompt_text is not None:
67
+ # Include semantic tokens when prompt text is provided
68
+ semantic_tokens = "".join(
69
+ [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
70
+ )
71
+
72
+ inputs = [
73
+ TASK_TOKEN_MAP["tts"],
74
+ "<|start_content|>",
75
+ prompt_text,
76
+ text,
77
+ "<|end_content|>",
78
+ "<|start_global_token|>",
79
+ global_tokens,
80
+ "<|end_global_token|>",
81
+ "<|start_semantic_token|>",
82
+ semantic_tokens,
83
+ ]
84
+ else:
85
+ # Without prompt text, exclude semantic tokens
86
+ inputs = [
87
+ TASK_TOKEN_MAP["tts"],
88
+ "<|start_content|>",
89
+ text,
90
+ "<|end_content|>",
91
+ "<|start_global_token|>",
92
+ global_tokens,
93
+ "<|end_global_token|>",
94
+ ]
95
+
96
+ # Join all input components into a single string
97
+ inputs = "".join(inputs)
98
+ return inputs, global_token_ids
99
+
100
+
101
+ class TritonPythonModel:
102
+ """Triton Python model for Spark TTS.
103
+
104
+ This model orchestrates the end-to-end TTS pipeline by coordinating
105
+ between audio tokenizer, LLM, and vocoder components.
106
+ """
107
+
108
+ def initialize(self, args):
109
+ """Initialize the model.
110
+
111
+ Args:
112
+ args: Dictionary containing model configuration
113
+ """
114
+ self.logger = pb_utils.Logger
115
+ # Parse model parameters
116
+ self.model_config = json.loads(args['model_config'])
117
+ parameters = self.model_config['parameters']
118
+ model_params = {k: v["string_value"] for k, v in parameters.items()}
119
+ self.logger.log_info(f"model_params:{model_params}")
120
+ # streaming TTS parameters
121
+ assert (
122
+ float(model_params["audio_chunk_duration"]) >= 0.5
123
+ ), f"audio_chunk_duration at least 0.5 seconds"
124
+ self.audio_chunk_duration = float(model_params["audio_chunk_duration"])
125
+ self.max_audio_chunk_duration = float(model_params["max_audio_chunk_duration"])
126
+ assert (
127
+ float(model_params["audio_chunk_size_scale_factor"]) >= 1.0
128
+ ), "audio_chunk_size_scale_factor should be greater than 1, change it according to your actual rtf"
129
+ self.audio_chunk_size_scale_factor = float(model_params["audio_chunk_size_scale_factor"]) # scale speed
130
+ self.audio_chunk_overlap_duration = float(model_params["audio_chunk_overlap_duration"])
131
+ self.audio_tokenizer_frame_rate = int(model_params["audio_tokenizer_frame_rate"])
132
+
133
+ # Initialize tokenizer
134
+ llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
135
+ self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
136
+ self.device = torch.device("cuda")
137
+ self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
138
+
139
+ def forward_llm(self, input_ids):
140
+ """
141
+ Prepares the response from the language model based on the provided
142
+ inputs. Creates a `pb_utils.InferenceRequest` object with passed
143
+ `llm_request_inputs` to send to a decoupled TensorRTLLM model.
144
+ For each response from the language model:
145
+ - Checks for errors and raise an exception if any are found.
146
+ - Extracts the "output_ids" tensor from the response.
147
+ - Determines the finish reason based on the presence of the
148
+ end-of-sequence token or reaching the maximum length.
149
+ - Appends the generated token IDs to `output_ids`.
150
+ - If the finish reason is determined, decodes the output IDs to text
151
+ and prepares the final response.
152
+
153
+ The final response includes the generated text, finish reason,
154
+ completion tokens, prompt tokens, and total tokens.
155
+
156
+ Parameters
157
+ ----------
158
+ - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
159
+
160
+ Returns
161
+ -------
162
+ - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
163
+ """
164
+ # convert input_ids to numpy, with shape [1, sequence_length]
165
+ input_ids = input_ids.cpu().numpy()
166
+ max_tokens = 512
167
+ input_dict = {
168
+ "request_output_len": np.array([[max_tokens]], dtype=np.int32),
169
+ "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
170
+ "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
171
+ "streaming": np.array([[self.decoupled]], dtype=np.bool_),
172
+ "runtime_top_p": np.array([[0.95]], dtype=np.float32),
173
+ "runtime_top_k": np.array([[50]], dtype=np.int32),
174
+ "temperature": np.array([[0.8]], dtype=np.float32),
175
+ "input_ids": input_ids,
176
+ "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
177
+ }
178
+
179
+ # Convert inputs to Triton tensors
180
+ input_tensor_list = [
181
+ pb_utils.Tensor(k, v) for k, v in input_dict.items()
182
+ ]
183
+
184
+ # Create and execute inference request
185
+ llm_request = pb_utils.InferenceRequest(
186
+ model_name="tensorrt_llm",
187
+ requested_output_names=["output_ids", "sequence_length"],
188
+ inputs=input_tensor_list,
189
+ )
190
+
191
+ llm_responses = llm_request.exec(decoupled=self.decoupled)
192
+ if self.decoupled:
193
+ for llm_response in llm_responses:
194
+ if llm_response.has_error():
195
+ raise pb_utils.TritonModelException(llm_response.error().message())
196
+
197
+ # Extract and process output
198
+ output_ids = pb_utils.get_output_tensor_by_name(
199
+ llm_response, "output_ids").as_numpy()
200
+ seq_lens = pb_utils.get_output_tensor_by_name(
201
+ llm_response, "sequence_length").as_numpy()
202
+
203
+ # Get actual output IDs up to the sequence length
204
+ actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
205
+
206
+ yield actual_output_ids
207
+ else:
208
+ llm_response = llm_responses
209
+ if llm_response.has_error():
210
+ raise pb_utils.TritonModelException(llm_response.error().message())
211
+
212
+ # Extract and process output
213
+ output_ids = pb_utils.get_output_tensor_by_name(
214
+ llm_response, "output_ids").as_numpy()
215
+ seq_lens = pb_utils.get_output_tensor_by_name(
216
+ llm_response, "sequence_length").as_numpy()
217
+
218
+ # Get actual output IDs up to the sequence length
219
+ actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
220
+
221
+ yield actual_output_ids
222
+
223
+ def forward_audio_tokenizer(self, wav, wav_len):
224
+ """Forward pass through the audio tokenizer component.
225
+
226
+ Args:
227
+ wav: Input waveform tensor
228
+ wav_len: Waveform length tensor
229
+
230
+ Returns:
231
+ Tuple of global and semantic tokens
232
+ """
233
+ inference_request = pb_utils.InferenceRequest(
234
+ model_name='audio_tokenizer',
235
+ requested_output_names=['global_tokens', 'semantic_tokens'],
236
+ inputs=[wav, wav_len]
237
+ )
238
+
239
+ inference_response = inference_request.exec()
240
+ if inference_response.has_error():
241
+ raise pb_utils.TritonModelException(inference_response.error().message())
242
+
243
+ # Extract and convert output tensors
244
+ global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
245
+ global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
246
+
247
+ semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
248
+ semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
249
+
250
+ return global_tokens, semantic_tokens
251
+
252
+ def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
253
+ """Forward pass through the vocoder component.
254
+
255
+ Args:
256
+ global_token_ids: Global token IDs tensor
257
+ pred_semantic_ids: Predicted semantic token IDs tensor
258
+
259
+ Returns:
260
+ Generated waveform tensor
261
+ """
262
+ # Convert tensors to Triton format
263
+ global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
264
+ pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
265
+
266
+ # Create and execute inference request
267
+ inference_request = pb_utils.InferenceRequest(
268
+ model_name='vocoder',
269
+ requested_output_names=['waveform'],
270
+ inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
271
+ )
272
+
273
+ inference_response = inference_request.exec()
274
+ if inference_response.has_error():
275
+ raise pb_utils.TritonModelException(inference_response.error().message())
276
+
277
+ # Extract and convert output waveform
278
+ waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
279
+ waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
280
+
281
+ return waveform
282
+
283
+ def token2wav(self, generated_token_ids, global_token_ids):
284
+ # Decode and extract semantic token IDs from generated text
285
+ predicted_text = self.tokenizer.batch_decode(
286
+ [generated_token_ids],
287
+ skip_special_tokens=True,
288
+ )[0]
289
+ pred_semantic_ids = (
290
+ torch.tensor(
291
+ [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)]
292
+ )
293
+ .unsqueeze(0)
294
+ .to(torch.int32)
295
+ )
296
+
297
+ # Generate audio with vocoder
298
+ audio = self.forward_vocoder(
299
+ global_token_ids.to(self.device),
300
+ pred_semantic_ids.to(self.device),
301
+ )
302
+
303
+ return audio
304
+
305
+ def execute(self, requests):
306
+ """Execute inference on the batched requests.
307
+
308
+ Args:
309
+ requests: List of inference requests
310
+
311
+ Returns:
312
+ List of inference responses containing generated audio
313
+ """
314
+ responses = []
315
+
316
+ for request in requests:
317
+ # Extract input tensors
318
+ wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
319
+ wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
320
+
321
+ # Process reference audio through audio tokenizer
322
+ global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
323
+
324
+ # Extract text inputs
325
+ reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
326
+ reference_text = reference_text[0][0].decode('utf-8')
327
+
328
+ target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
329
+ target_text = target_text[0][0].decode('utf-8')
330
+
331
+ # Prepare prompt for LLM
332
+ prompt, global_token_ids = process_prompt(
333
+ text=target_text,
334
+ prompt_text=reference_text,
335
+ global_token_ids=global_tokens,
336
+ semantic_token_ids=semantic_tokens,
337
+ )
338
+
339
+
340
+ # Tokenize prompt for LLM
341
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
342
+ input_ids = model_inputs.input_ids.to(torch.int32)
343
+
344
+ # Generate semantic tokens with LLM
345
+ generated_ids_iter = self.forward_llm(input_ids)
346
+
347
+ if self.decoupled:
348
+ response_sender = request.get_response_sender()
349
+ request_id = request.request_id()
350
+ semantic_token_ids_arr = []
351
+ max_chunk_size = math.ceil(self.max_audio_chunk_duration * self.audio_tokenizer_frame_rate)
352
+ chunk_size = math.ceil(self.audio_chunk_duration * self.audio_tokenizer_frame_rate)
353
+ overlap_chunk_size = math.ceil(self.audio_chunk_overlap_duration * self.audio_tokenizer_frame_rate)
354
+ self.logger.log_info(
355
+ f"[{request_id}] init chunk_size: {chunk_size} max_chunk_size: {max_chunk_size}"
356
+ )
357
+ for generated_ids in generated_ids_iter:
358
+ if generated_ids is None or len(generated_ids) == 0:
359
+ break
360
+
361
+ semantic_token_ids_arr.append(generated_ids)
362
+ if len(semantic_token_ids_arr) >= chunk_size:
363
+ chunk = semantic_token_ids_arr[:chunk_size]
364
+ generated_semantic_token_ids = np.hstack(chunk)
365
+ # Process each chunk
366
+ sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
367
+ # Prepare response to send
368
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
369
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
370
+ response_sender.send(inference_response)
371
+
372
+ semantic_token_ids_arr = semantic_token_ids_arr[chunk_size - overlap_chunk_size:]
373
+ # increase chunk size for better speech quality
374
+ chunk_size = min(max_chunk_size, int(chunk_size * self.audio_chunk_size_scale_factor))
375
+ self.logger.log_info(f"[{request_id}] increase chunk_size: {chunk_size}")
376
+
377
+ if len(semantic_token_ids_arr) > 0: # flush the tokens left over at the end of the stream
378
+ generated_semantic_token_ids = np.hstack(semantic_token_ids_arr)
379
+ # Convert the remaining tokens to audio
380
+ sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
381
+ # Prepare response to send
382
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
383
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
384
+ response_sender.send(inference_response)
385
+ self.logger.log_info(f"[{request_id}] last chunk len: {len(semantic_token_ids_arr)}")
386
+ else:
387
+ generated_ids = next(generated_ids_iter)
388
+ if generated_ids is None or len(generated_ids) == 0:
389
+ raise pb_utils.TritonModelException("Generated IDs is None or empty")
390
+
391
+ audio = self.token2wav(generated_ids, global_token_ids)
392
+
393
+ # Prepare response
394
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
395
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
396
+ responses.append(inference_response)
397
+
398
+ if self.decoupled:
399
+ response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
400
+ self.logger.log_info(f"[{request_id}] sent TRITONSERVER_RESPONSE_COMPLETE_FINAL flag")
401
+
402
+ if not self.decoupled:
403
+ return responses
404
+
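In decoupled mode the chunks sent above are decoded from token windows that overlap by `audio_chunk_overlap_duration`, so consecutive waveform chunks also overlap and a client is expected to blend them rather than simply concatenating. Below is a minimal client-side sketch of one way to do that with a linear crossfade; the 16 kHz rate, float32 chunks, and the assumption that every chunk is longer than the overlap are placeholders, and this is not the logic of the repo's `client_grpc.py`.

```python
# Hypothetical client-side stitching of streamed "waveform" chunks.
# Assumptions: 1-D float32 numpy chunks, 16 kHz output, each chunk longer
# than the overlap region; see client_grpc.py for the real client.
import numpy as np


def crossfade_concat(chunks, sample_rate=16000, overlap_duration=0.1):
    """Concatenate overlapping chunks, linearly crossfading the overlap."""
    overlap = int(sample_rate * overlap_duration)
    fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
    fade_out = 1.0 - fade_in
    out = np.asarray(chunks[0], dtype=np.float32)
    for chunk in chunks[1:]:
        chunk = np.asarray(chunk, dtype=np.float32)
        head, tail = out[:-overlap], out[-overlap:]
        blended = tail * fade_out + chunk[:overlap] * fade_in
        out = np.concatenate([head, blended, chunk[overlap:]])
    return out
```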
runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "spark_tts"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ model_transaction_policy {
22
+ decoupled: ${decoupled_mode}
23
+ }
24
+ parameters [
25
+ {
26
+ key: "llm_tokenizer_dir",
27
+ value: {string_value:"${llm_tokenizer_dir}"}
28
+ },
29
+ {
30
+ key: "audio_chunk_duration",
31
+ value: {string_value:"${audio_chunk_duration}"}
32
+ },
33
+ {
34
+ key: "audio_chunk_size_scale_factor",
35
+ value: {string_value:"${audio_chunk_size_scale_factor}"}
36
+ },
37
+ {
38
+ key: "max_audio_chunk_duration",
39
+ value: {string_value:"${max_audio_chunk_duration}"}
40
+ },
41
+ {
42
+ key: "audio_chunk_overlap_duration",
43
+ value: {string_value:"${audio_chunk_overlap_duration}"}
44
+ },
45
+ {
46
+ key: "audio_tokenizer_frame_rate",
47
+ value: {string_value:"50"}
48
+ }
49
+ ]
50
+
51
+ input [
52
+ {
53
+ name: "reference_wav"
54
+ data_type: TYPE_FP32
55
+ dims: [-1]
56
+ },
57
+ {
58
+ name: "reference_wav_len"
59
+ data_type: TYPE_INT32
60
+ dims: [1]
61
+ },
62
+ {
63
+ name: "reference_text"
64
+ data_type: TYPE_STRING
65
+ dims: [1]
66
+ },
67
+ {
68
+ name: "target_text"
69
+ data_type: TYPE_STRING
70
+ dims: [1]
71
+ }
72
+ ]
73
+ output [
74
+ {
75
+ name: "waveform"
76
+ data_type: TYPE_FP32
77
+ dims: [ -1 ]
78
+ }
79
+ ]
80
+
81
+ instance_group [
82
+ {
83
+ count: ${bls_instance_num}
84
+ kind: KIND_CPU
85
+ }
86
+ ]
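The config above is the single public entry point of the pipeline: four inputs (a float32 reference waveform, its length, and two string tensors) and one float32 `waveform` output. A minimal offline (non-decoupled) request could look like the sketch below; it uses the generic `tritonclient` HTTP API rather than the repo's `client_http.py`, and the server URL, audio path, and the assumption of a 16 kHz mono prompt are placeholders.

```python
# Hedged sketch of an offline request against the "spark_tts" BLS model.
# URL, file path, and 16 kHz mono input are assumptions; the repo's
# client_http.py / client_grpc.py are the reference clients.
import numpy as np
import soundfile as sf
import tritonclient.http as httpclient

waveform, sr = sf.read("example/prompt_audio.wav", dtype="float32")  # mono prompt

wav = waveform[np.newaxis, :]                               # [1, num_samples]
wav_len = np.array([[wav.shape[1]]], dtype=np.int32)        # [1, 1]
ref_text = np.array([["transcript of the prompt audio"]], dtype=object)
tgt_text = np.array([["text to synthesize"]], dtype=object)

inputs = [
    httpclient.InferInput("reference_wav", list(wav.shape), "FP32"),
    httpclient.InferInput("reference_wav_len", list(wav_len.shape), "INT32"),
    httpclient.InferInput("reference_text", list(ref_text.shape), "BYTES"),
    httpclient.InferInput("target_text", list(tgt_text.shape), "BYTES"),
]
for inp, arr in zip(inputs, [wav, wav_len, ref_text, tgt_text]):
    inp.set_data_from_numpy(arr)

client = httpclient.InferenceServerClient(url="localhost:8000")
result = client.infer("spark_tts", inputs)
audio = result.as_numpy("waveform")  # float32 samples produced by the vocoder
```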
runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep ADDED
File without changes
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt ADDED
@@ -0,0 +1,857 @@
1
+ # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "tensorrt_llm"
28
+ backend: "${triton_backend}"
29
+ max_batch_size: ${triton_max_batch_size}
30
+
31
+ model_transaction_policy {
32
+ decoupled: ${decoupled_mode}
33
+ }
34
+
35
+ dynamic_batching {
36
+ preferred_batch_size: [ ${triton_max_batch_size} ]
37
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
38
+ default_queue_policy: { max_queue_size: ${max_queue_size} }
39
+ }
40
+
41
+ input [
42
+ {
43
+ name: "input_ids"
44
+ data_type: TYPE_INT32
45
+ dims: [ -1 ]
46
+ allow_ragged_batch: true
47
+ optional: true
48
+ },
49
+ {
50
+ name: "encoder_input_features"
51
+ data_type: ${encoder_input_features_data_type}
52
+ dims: [ -1, -1 ]
53
+ allow_ragged_batch: true
54
+ optional: true
55
+ },
56
+ {
57
+ name: "encoder_output_lengths"
58
+ data_type: TYPE_INT32
59
+ dims: [ 1 ]
60
+ reshape: { shape: [ ] }
61
+ optional: true
62
+ },
63
+ {
64
+ name: "input_lengths"
65
+ data_type: TYPE_INT32
66
+ dims: [ 1 ]
67
+ reshape: { shape: [ ] }
68
+ },
69
+ {
70
+ name: "request_output_len"
71
+ data_type: TYPE_INT32
72
+ dims: [ 1 ]
73
+ reshape: { shape: [ ] }
74
+ },
75
+ {
76
+ name: "num_return_sequences"
77
+ data_type: TYPE_INT32
78
+ dims: [ 1 ]
79
+ reshape: { shape: [ ] }
80
+ optional: true
81
+ },
82
+ {
83
+ name: "draft_input_ids"
84
+ data_type: TYPE_INT32
85
+ dims: [ -1 ]
86
+ optional: true
87
+ allow_ragged_batch: true
88
+ },
89
+ {
90
+ name: "decoder_input_ids"
91
+ data_type: TYPE_INT32
92
+ dims: [ -1 ]
93
+ optional: true
94
+ allow_ragged_batch: true
95
+ },
96
+ {
97
+ name: "decoder_input_lengths"
98
+ data_type: TYPE_INT32
99
+ dims: [ 1 ]
100
+ optional: true
101
+ reshape: { shape: [ ] }
102
+ },
103
+ {
104
+ name: "draft_logits"
105
+ data_type: ${logits_datatype}
106
+ dims: [ -1, -1 ]
107
+ optional: true
108
+ allow_ragged_batch: true
109
+ },
110
+ {
111
+ name: "draft_acceptance_threshold"
112
+ data_type: TYPE_FP32
113
+ dims: [ 1 ]
114
+ reshape: { shape: [ ] }
115
+ optional: true
116
+ },
117
+ {
118
+ name: "end_id"
119
+ data_type: TYPE_INT32
120
+ dims: [ 1 ]
121
+ reshape: { shape: [ ] }
122
+ optional: true
123
+ },
124
+ {
125
+ name: "pad_id"
126
+ data_type: TYPE_INT32
127
+ dims: [ 1 ]
128
+ reshape: { shape: [ ] }
129
+ optional: true
130
+ },
131
+ {
132
+ name: "stop_words_list"
133
+ data_type: TYPE_INT32
134
+ dims: [ 2, -1 ]
135
+ optional: true
136
+ allow_ragged_batch: true
137
+ },
138
+ {
139
+ name: "bad_words_list"
140
+ data_type: TYPE_INT32
141
+ dims: [ 2, -1 ]
142
+ optional: true
143
+ allow_ragged_batch: true
144
+ },
145
+ {
146
+ name: "embedding_bias"
147
+ data_type: TYPE_FP32
148
+ dims: [ -1 ]
149
+ optional: true
150
+ allow_ragged_batch: true
151
+ },
152
+ {
153
+ name: "beam_width"
154
+ data_type: TYPE_INT32
155
+ dims: [ 1 ]
156
+ reshape: { shape: [ ] }
157
+ optional: true
158
+ },
159
+ {
160
+ name: "temperature"
161
+ data_type: TYPE_FP32
162
+ dims: [ 1 ]
163
+ reshape: { shape: [ ] }
164
+ optional: true
165
+ },
166
+ {
167
+ name: "runtime_top_k"
168
+ data_type: TYPE_INT32
169
+ dims: [ 1 ]
170
+ reshape: { shape: [ ] }
171
+ optional: true
172
+ },
173
+ {
174
+ name: "runtime_top_p"
175
+ data_type: TYPE_FP32
176
+ dims: [ 1 ]
177
+ reshape: { shape: [ ] }
178
+ optional: true
179
+ },
180
+ {
181
+ name: "runtime_top_p_min"
182
+ data_type: TYPE_FP32
183
+ dims: [ 1 ]
184
+ reshape: { shape: [ ] }
185
+ optional: true
186
+ },
187
+ {
188
+ name: "runtime_top_p_decay"
189
+ data_type: TYPE_FP32
190
+ dims: [ 1 ]
191
+ reshape: { shape: [ ] }
192
+ optional: true
193
+ },
194
+ {
195
+ name: "runtime_top_p_reset_ids"
196
+ data_type: TYPE_INT32
197
+ dims: [ 1 ]
198
+ reshape: { shape: [ ] }
199
+ optional: true
200
+ },
201
+ {
202
+ name: "len_penalty"
203
+ data_type: TYPE_FP32
204
+ dims: [ 1 ]
205
+ reshape: { shape: [ ] }
206
+ optional: true
207
+ },
208
+ {
209
+ name: "early_stopping"
210
+ data_type: TYPE_BOOL
211
+ dims: [ 1 ]
212
+ reshape: { shape: [ ] }
213
+ optional: true
214
+ },
215
+ {
216
+ name: "repetition_penalty"
217
+ data_type: TYPE_FP32
218
+ dims: [ 1 ]
219
+ reshape: { shape: [ ] }
220
+ optional: true
221
+ },
222
+ {
223
+ name: "min_length"
224
+ data_type: TYPE_INT32
225
+ dims: [ 1 ]
226
+ reshape: { shape: [ ] }
227
+ optional: true
228
+ },
229
+ {
230
+ name: "beam_search_diversity_rate"
231
+ data_type: TYPE_FP32
232
+ dims: [ 1 ]
233
+ reshape: { shape: [ ] }
234
+ optional: true
235
+ },
236
+ {
237
+ name: "presence_penalty"
238
+ data_type: TYPE_FP32
239
+ dims: [ 1 ]
240
+ reshape: { shape: [ ] }
241
+ optional: true
242
+ },
243
+ {
244
+ name: "frequency_penalty"
245
+ data_type: TYPE_FP32
246
+ dims: [ 1 ]
247
+ reshape: { shape: [ ] }
248
+ optional: true
249
+ },
250
+ {
251
+ name: "random_seed"
252
+ data_type: TYPE_UINT64
253
+ dims: [ 1 ]
254
+ reshape: { shape: [ ] }
255
+ optional: true
256
+ },
257
+ {
258
+ name: "return_log_probs"
259
+ data_type: TYPE_BOOL
260
+ dims: [ 1 ]
261
+ reshape: { shape: [ ] }
262
+ optional: true
263
+ },
264
+ {
265
+ name: "return_context_logits"
266
+ data_type: TYPE_BOOL
267
+ dims: [ 1 ]
268
+ reshape: { shape: [ ] }
269
+ optional: true
270
+ },
271
+ {
272
+ name: "return_generation_logits"
273
+ data_type: TYPE_BOOL
274
+ dims: [ 1 ]
275
+ reshape: { shape: [ ] }
276
+ optional: true
277
+ },
278
+ {
279
+ name: "return_perf_metrics"
280
+ data_type: TYPE_BOOL
281
+ dims: [ 1 ]
282
+ reshape: { shape: [ ] }
283
+ optional: true
284
+ },
285
+ {
286
+ name: "exclude_input_in_output"
287
+ data_type: TYPE_BOOL
288
+ dims: [ 1 ]
289
+ reshape: { shape: [ ] }
290
+ optional: true
291
+ },
292
+ {
293
+ name: "stop"
294
+ data_type: TYPE_BOOL
295
+ dims: [ 1 ]
296
+ reshape: { shape: [ ] }
297
+ optional: true
298
+ },
299
+ {
300
+ name: "streaming"
301
+ data_type: TYPE_BOOL
302
+ dims: [ 1 ]
303
+ reshape: { shape: [ ] }
304
+ optional: true
305
+ },
306
+ {
307
+ name: "prompt_embedding_table"
308
+ data_type: TYPE_FP16
309
+ dims: [ -1, -1 ]
310
+ optional: true
311
+ allow_ragged_batch: true
312
+ },
313
+ {
314
+ name: "prompt_table_extra_ids"
315
+ data_type: TYPE_UINT64
316
+ dims: [ -1 ]
317
+ optional: true
318
+ allow_ragged_batch: true
319
+ },
320
+ {
321
+ name: "prompt_vocab_size"
322
+ data_type: TYPE_INT32
323
+ dims: [ 1 ]
324
+ reshape: { shape: [ ] }
325
+ optional: true
326
+ },
327
+ # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
328
+ {
329
+ name: "cross_attention_mask"
330
+ data_type: TYPE_BOOL
331
+ dims: [ -1, -1 ]
332
+ optional: true
333
+ allow_ragged_batch: true
334
+ },
335
+ # Mrope param when mrope is used
336
+ {
337
+ name: "mrope_rotary_cos_sin"
338
+ data_type: TYPE_FP32
339
+ dims: [ -1 ]
340
+ optional: true
341
+ },
342
+ {
343
+ name: "mrope_position_deltas"
344
+ data_type: TYPE_INT64
345
+ dims: [ 1 ]
346
+ optional: true
347
+ },
348
+ # the unique task ID for the given LoRA.
349
+ # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
350
+ # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
351
+ # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
352
+ {
353
+ name: "lora_task_id"
354
+ data_type: TYPE_UINT64
355
+ dims: [ 1 ]
356
+ reshape: { shape: [ ] }
357
+ optional: true
358
+ },
359
+ # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
360
+ # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
361
+ # each of the in / out tensors are first flattened and then concatenated together in the format above.
362
+ # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
363
+ {
364
+ name: "lora_weights"
365
+ data_type: TYPE_FP16
366
+ dims: [ -1, -1 ]
367
+ optional: true
368
+ allow_ragged_batch: true
369
+ },
370
+ # module identifier (same size as the first dimension of lora_weights)
371
+ # See LoraModule::ModuleType for model id mapping
372
+ #
373
+ # "attn_qkv": 0 # combined qkv adapter
374
+ # "attn_q": 1 # q adapter
375
+ # "attn_k": 2 # k adapter
376
+ # "attn_v": 3 # v adapter
377
+ # "attn_dense": 4 # adapter for the dense layer in attention
378
+ # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
379
+ # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
380
+ # "mlp_gate": 7 # for llama2 adapter for gated mlp layer after attention / RMSNorm: gate
381
+ #
382
+ # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
383
+ {
384
+ name: "lora_config"
385
+ data_type: TYPE_INT32
386
+ dims: [ -1, 3 ]
387
+ optional: true
388
+ allow_ragged_batch: true
389
+ },
390
+ {
391
+ name: "context_phase_params"
392
+ data_type: TYPE_UINT8
393
+ dims: [ -1 ]
394
+ optional: true
395
+ allow_ragged_batch: true
396
+ },
397
+ # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
398
+ {
399
+ name: "skip_cross_attn_blocks"
400
+ data_type: TYPE_BOOL
401
+ dims: [ 1 ]
402
+ optional: true
403
+ allow_ragged_batch: true
404
+ },
405
+ {
406
+ name: "retention_token_range_starts"
407
+ data_type: TYPE_INT32
408
+ dims: [ -1 ]
409
+ optional: true
410
+ allow_ragged_batch: true
411
+ },
412
+ {
413
+ name: "retention_token_range_ends"
414
+ data_type: TYPE_INT32
415
+ dims: [ -1 ]
416
+ optional: true
417
+ allow_ragged_batch: true
418
+ },
419
+ {
420
+ name: "retention_token_range_priorities"
421
+ data_type: TYPE_INT32
422
+ dims: [ -1 ]
423
+ optional: true
424
+ allow_ragged_batch: true
425
+ },
426
+ {
427
+ name: "retention_token_range_durations_ms"
428
+ data_type: TYPE_INT32
429
+ dims: [ -1 ]
430
+ optional: true
431
+ allow_ragged_batch: true
432
+ },
433
+ {
434
+ name: "retention_decode_priority"
435
+ data_type: TYPE_INT32
436
+ dims: [ 1 ]
437
+ optional: true
438
+ allow_ragged_batch: true
439
+ },
440
+ {
441
+ name: "retention_decode_duration_ms"
442
+ data_type: TYPE_INT32
443
+ dims: [ 1 ]
444
+ optional: true
445
+ allow_ragged_batch: true
446
+ },
447
+ {
448
+ name: "guided_decoding_guide_type"
449
+ data_type: TYPE_STRING
450
+ dims: [ 1 ]
451
+ optional: true
452
+ allow_ragged_batch: true
453
+ },
454
+ {
455
+ name: "guided_decoding_guide"
456
+ data_type: TYPE_STRING
457
+ dims: [ 1 ]
458
+ optional: true
459
+ allow_ragged_batch: true
460
+ },
461
+ {
462
+ name: "lookahead_window_size"
463
+ data_type: TYPE_INT32
464
+ dims: [ 1 ]
465
+ optional: true
466
+ allow_ragged_batch: true
467
+ },
468
+ {
469
+ name: "lookahead_ngram_size"
470
+ data_type: TYPE_INT32
471
+ dims: [ 1 ]
472
+ optional: true
473
+ allow_ragged_batch: true
474
+ },
475
+ {
476
+ name: "lookahead_verification_set_size"
477
+ data_type: TYPE_INT32
478
+ dims: [ 1 ]
479
+ optional: true
480
+ allow_ragged_batch: true
481
+ }
482
+ ]
483
+ output [
484
+ {
485
+ name: "output_ids"
486
+ data_type: TYPE_INT32
487
+ dims: [ -1, -1 ]
488
+ },
489
+ {
490
+ name: "sequence_length"
491
+ data_type: TYPE_INT32
492
+ dims: [ -1 ]
493
+ },
494
+ {
495
+ name: "cum_log_probs"
496
+ data_type: TYPE_FP32
497
+ dims: [ -1 ]
498
+ },
499
+ {
500
+ name: "output_log_probs"
501
+ data_type: TYPE_FP32
502
+ dims: [ -1, -1 ]
503
+ },
504
+ {
505
+ name: "context_logits"
506
+ data_type: ${logits_datatype}
507
+ dims: [ -1, -1 ]
508
+ },
509
+ {
510
+ name: "generation_logits"
511
+ data_type: ${logits_datatype}
512
+ dims: [ -1, -1, -1 ]
513
+ },
514
+ {
515
+ name: "batch_index"
516
+ data_type: TYPE_INT32
517
+ dims: [ 1 ]
518
+ },
519
+ {
520
+ name: "sequence_index"
521
+ data_type: TYPE_INT32
522
+ dims: [ 1 ]
523
+ },
524
+ {
525
+ name: "context_phase_params"
526
+ data_type: TYPE_UINT8
527
+ dims: [ -1 ]
528
+ },
529
+ {
530
+ name: "kv_cache_alloc_new_blocks"
531
+ data_type: TYPE_INT32
532
+ dims: [ 1 ]
533
+ },
534
+ {
535
+ name: "kv_cache_reused_blocks"
536
+ data_type: TYPE_INT32
537
+ dims: [ 1 ]
538
+ },
539
+ {
540
+ name: "kv_cache_alloc_total_blocks"
541
+ data_type: TYPE_INT32
542
+ dims: [ 1 ]
543
+ },
544
+ {
545
+ name: "arrival_time_ns"
546
+ data_type: TYPE_INT64
547
+ dims: [ 1 ]
548
+ },
549
+ {
550
+ name: "first_scheduled_time_ns"
551
+ data_type: TYPE_INT64
552
+ dims: [ 1 ]
553
+ },
554
+ {
555
+ name: "first_token_time_ns"
556
+ data_type: TYPE_INT64
557
+ dims: [ 1 ]
558
+ },
559
+ {
560
+ name: "last_token_time_ns"
561
+ data_type: TYPE_INT64
562
+ dims: [ 1 ]
563
+ },
564
+ {
565
+ name: "acceptance_rate"
566
+ data_type: TYPE_FP32
567
+ dims: [ 1 ]
568
+ },
569
+ {
570
+ name: "total_accepted_draft_tokens"
571
+ data_type: TYPE_INT32
572
+ dims: [ 1 ]
573
+ },
574
+ {
575
+ name: "total_draft_tokens"
576
+ data_type: TYPE_INT32
577
+ dims: [ 1 ]
578
+ }
579
+ ]
580
+ instance_group [
581
+ {
582
+ count: 1
583
+ kind : KIND_CPU
584
+ }
585
+ ]
586
+ parameters: {
587
+ key: "max_beam_width"
588
+ value: {
589
+ string_value: "${max_beam_width}"
590
+ }
591
+ }
592
+ parameters: {
593
+ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
594
+ value: {
595
+ string_value: "no"
596
+ }
597
+ }
598
+ parameters: {
599
+ key: "gpt_model_type"
600
+ value: {
601
+ string_value: "${batching_strategy}"
602
+ }
603
+ }
604
+ parameters: {
605
+ key: "gpt_model_path"
606
+ value: {
607
+ string_value: "${engine_dir}"
608
+ }
609
+ }
610
+ parameters: {
611
+ key: "encoder_model_path"
612
+ value: {
613
+ string_value: "${encoder_engine_dir}"
614
+ }
615
+ }
616
+ parameters: {
617
+ key: "max_tokens_in_paged_kv_cache"
618
+ value: {
619
+ string_value: "${max_tokens_in_paged_kv_cache}"
620
+ }
621
+ }
622
+ parameters: {
623
+ key: "max_attention_window_size"
624
+ value: {
625
+ string_value: "${max_attention_window_size}"
626
+ }
627
+ }
628
+ parameters: {
629
+ key: "sink_token_length"
630
+ value: {
631
+ string_value: "${sink_token_length}"
632
+ }
633
+ }
634
+ parameters: {
635
+ key: "batch_scheduler_policy"
636
+ value: {
637
+ string_value: "${batch_scheduler_policy}"
638
+ }
639
+ }
640
+ parameters: {
641
+ key: "kv_cache_free_gpu_mem_fraction"
642
+ value: {
643
+ string_value: "${kv_cache_free_gpu_mem_fraction}"
644
+ }
645
+ }
646
+ parameters: {
647
+ key: "cross_kv_cache_fraction"
648
+ value: {
649
+ string_value: "${cross_kv_cache_fraction}"
650
+ }
651
+ }
652
+ parameters: {
653
+ key: "kv_cache_host_memory_bytes"
654
+ value: {
655
+ string_value: "${kv_cache_host_memory_bytes}"
656
+ }
657
+ }
658
+ # kv_cache_onboard_blocks is for internal implementation.
659
+ parameters: {
660
+ key: "kv_cache_onboard_blocks"
661
+ value: {
662
+ string_value: "${kv_cache_onboard_blocks}"
663
+ }
664
+ }
665
+ # enable_trt_overlap is deprecated and doesn't have any effect on the runtime
666
+ # parameters: {
667
+ # key: "enable_trt_overlap"
668
+ # value: {
669
+ # string_value: "${enable_trt_overlap}"
670
+ # }
671
+ # }
672
+ parameters: {
673
+ key: "exclude_input_in_output"
674
+ value: {
675
+ string_value: "${exclude_input_in_output}"
676
+ }
677
+ }
678
+ parameters: {
679
+ key: "cancellation_check_period_ms"
680
+ value: {
681
+ string_value: "${cancellation_check_period_ms}"
682
+ }
683
+ }
684
+ parameters: {
685
+ key: "stats_check_period_ms"
686
+ value: {
687
+ string_value: "${stats_check_period_ms}"
688
+ }
689
+ }
690
+ parameters: {
691
+ key: "iter_stats_max_iterations"
692
+ value: {
693
+ string_value: "${iter_stats_max_iterations}"
694
+ }
695
+ }
696
+ parameters: {
697
+ key: "request_stats_max_iterations"
698
+ value: {
699
+ string_value: "${request_stats_max_iterations}"
700
+ }
701
+ }
702
+ parameters: {
703
+ key: "enable_kv_cache_reuse"
704
+ value: {
705
+ string_value: "${enable_kv_cache_reuse}"
706
+ }
707
+ }
708
+ parameters: {
709
+ key: "normalize_log_probs"
710
+ value: {
711
+ string_value: "${normalize_log_probs}"
712
+ }
713
+ }
714
+ parameters: {
715
+ key: "enable_chunked_context"
716
+ value: {
717
+ string_value: "${enable_chunked_context}"
718
+ }
719
+ }
720
+ parameters: {
721
+ key: "gpu_device_ids"
722
+ value: {
723
+ string_value: "${gpu_device_ids}"
724
+ }
725
+ }
726
+ parameters: {
727
+ key: "participant_ids"
728
+ value: {
729
+ string_value: "${participant_ids}"
730
+ }
731
+ }
732
+ parameters: {
733
+ key: "lora_cache_optimal_adapter_size"
734
+ value: {
735
+ string_value: "${lora_cache_optimal_adapter_size}"
736
+ }
737
+ }
738
+ parameters: {
739
+ key: "lora_cache_max_adapter_size"
740
+ value: {
741
+ string_value: "${lora_cache_max_adapter_size}"
742
+ }
743
+ }
744
+ parameters: {
745
+ key: "lora_cache_gpu_memory_fraction"
746
+ value: {
747
+ string_value: "${lora_cache_gpu_memory_fraction}"
748
+ }
749
+ }
750
+ parameters: {
751
+ key: "lora_cache_host_memory_bytes"
752
+ value: {
753
+ string_value: "${lora_cache_host_memory_bytes}"
754
+ }
755
+ }
756
+ parameters: {
757
+ key: "lora_prefetch_dir"
758
+ value: {
759
+ string_value: "${lora_prefetch_dir}"
760
+ }
761
+ }
762
+ parameters: {
763
+ key: "decoding_mode"
764
+ value: {
765
+ string_value: "${decoding_mode}"
766
+ }
767
+ }
768
+ parameters: {
769
+ key: "executor_worker_path"
770
+ value: {
771
+ string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
772
+ }
773
+ }
774
+ parameters: {
775
+ key: "lookahead_window_size"
776
+ value: {
777
+ string_value: "${lookahead_window_size}"
778
+ }
779
+ }
780
+ parameters: {
781
+ key: "lookahead_ngram_size"
782
+ value: {
783
+ string_value: "${lookahead_ngram_size}"
784
+ }
785
+ }
786
+ parameters: {
787
+ key: "lookahead_verification_set_size"
788
+ value: {
789
+ string_value: "${lookahead_verification_set_size}"
790
+ }
791
+ }
792
+ parameters: {
793
+ key: "medusa_choices"
794
+ value: {
795
+ string_value: "${medusa_choices}"
796
+ }
797
+ }
798
+ parameters: {
799
+ key: "eagle_choices"
800
+ value: {
801
+ string_value: "${eagle_choices}"
802
+ }
803
+ }
804
+ parameters: {
805
+ key: "gpu_weights_percent"
806
+ value: {
807
+ string_value: "${gpu_weights_percent}"
808
+ }
809
+ }
810
+ parameters: {
811
+ key: "enable_context_fmha_fp32_acc"
812
+ value: {
813
+ string_value: "${enable_context_fmha_fp32_acc}"
814
+ }
815
+ }
816
+ parameters: {
817
+ key: "multi_block_mode"
818
+ value: {
819
+ string_value: "${multi_block_mode}"
820
+ }
821
+ }
822
+ parameters: {
823
+ key: "cuda_graph_mode"
824
+ value: {
825
+ string_value: "${cuda_graph_mode}"
826
+ }
827
+ }
828
+ parameters: {
829
+ key: "cuda_graph_cache_size"
830
+ value: {
831
+ string_value: "${cuda_graph_cache_size}"
832
+ }
833
+ }
834
+ parameters: {
835
+ key: "speculative_decoding_fast_logits"
836
+ value: {
837
+ string_value: "${speculative_decoding_fast_logits}"
838
+ }
839
+ }
840
+ parameters: {
841
+ key: "tokenizer_dir"
842
+ value: {
843
+ string_value: "${tokenizer_dir}"
844
+ }
845
+ }
846
+ parameters: {
847
+ key: "guided_decoding_backend"
848
+ value: {
849
+ string_value: "${guided_decoding_backend}"
850
+ }
851
+ }
852
+ parameters: {
853
+ key: "xgrammar_tokenizer_info_path"
854
+ value: {
855
+ string_value: "${xgrammar_tokenizer_info_path}"
856
+ }
857
+ }
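Most of the optional inputs declared above are generic TensorRT-LLM backend features (LoRA, speculative decoding, mllama cross-attention, KV-cache retention, and so on) that this TTS pipeline leaves unset. For orientation, a Python-backend (BLS) caller only needs to populate a handful of them; the sketch below is a simplified stand-in for `forward_llm` in `spark_tts/1/model.py`, and the exact tensors that model sets may differ.

```python
# Hedged sketch of a BLS-side request to the "tensorrt_llm" model above;
# see spark_tts/1/model.py forward_llm() for the authoritative version.
import numpy as np
import triton_python_backend_utils as pb_utils


def request_llm(input_ids, max_new_tokens, end_id, decoupled):
    """input_ids: int32 numpy array of shape [1, prompt_len]."""
    inputs = [
        pb_utils.Tensor("input_ids", input_ids.astype(np.int32)),
        pb_utils.Tensor("input_lengths",
                        np.array([[input_ids.shape[1]]], dtype=np.int32)),
        pb_utils.Tensor("request_output_len",
                        np.array([[max_new_tokens]], dtype=np.int32)),
        pb_utils.Tensor("end_id", np.array([[end_id]], dtype=np.int32)),
        pb_utils.Tensor("streaming", np.array([[decoupled]], dtype=np.bool_)),
    ]
    llm_request = pb_utils.InferenceRequest(
        model_name="tensorrt_llm",
        requested_output_names=["output_ids", "sequence_length"],
        inputs=inputs,
    )
    # decoupled=True yields an iterable stream of responses, False a single one.
    return llm_request.exec(decoupled=decoupled)
```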
runtime/triton_trtllm/model_repo/vocoder/1/model.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ import json
28
+ import os
29
+ import logging
30
+ from typing import List, Dict
31
+
32
+ import torch
33
+ from torch.utils.dlpack import to_dlpack
34
+
35
+ import triton_python_backend_utils as pb_utils
36
+
37
+ from sparktts.models.bicodec import BiCodec
38
+
39
+ # Configure logging
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
41
+ logger = logging.getLogger(__name__)
42
+
43
+ class TritonPythonModel:
44
+ """Triton Python model for vocoder.
45
+
46
+ This model takes global and semantic tokens as input and generates audio waveforms
47
+ using the BiCodec vocoder.
48
+ """
49
+
50
+ def initialize(self, args):
51
+ """Initialize the model.
52
+
53
+ Args:
54
+ args: Dictionary containing model configuration
55
+ """
56
+ # Parse model parameters
57
+ parameters = json.loads(args['model_config'])['parameters']
58
+ model_params = {key: value["string_value"] for key, value in parameters.items()}
59
+ model_dir = model_params["model_dir"]
60
+
61
+ # Initialize device and vocoder
62
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+ logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
64
+
65
+ self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
66
+ del self.vocoder.encoder, self.vocoder.postnet
67
+ self.vocoder.eval().to(self.device) # Set model to evaluation mode
68
+
69
+ logger.info("Vocoder initialized successfully")
70
+
71
+
72
+ def execute(self, requests):
73
+ """Execute inference on the batched requests.
74
+
75
+ Args:
76
+ requests: List of inference requests
77
+
78
+ Returns:
79
+ List of inference responses containing generated waveforms
80
+ """
81
+ global_tokens_list, semantic_tokens_list = [], []
82
+
83
+ # Process each request in batch
84
+ for request in requests:
85
+ global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
86
+ semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
87
+ global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
88
+ semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))
89
+
90
+ # Concatenate tokens for batch processing
91
+ global_tokens = torch.cat(global_tokens_list, dim=0)
92
+ semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
93
+
94
+
95
+ # Generate waveforms
96
+ with torch.no_grad():
97
+ wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
98
+
99
+ # Prepare responses
100
+ responses = []
101
+ for i in range(len(requests)):
102
+ wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
103
+ inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
104
+ responses.append(inference_response)
105
+
106
+ return responses
runtime/triton_trtllm/model_repo/vocoder/config.pbtxt ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "vocoder"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ parameters [
22
+ {
23
+ key: "model_dir",
24
+ value: {string_value:"${model_dir}"}
25
+ }
26
+ ]
27
+
28
+ input [
29
+ {
30
+ name: "global_tokens"
31
+ data_type: TYPE_INT32
32
+ dims: [-1]
33
+ },
34
+ {
35
+ name: "semantic_tokens"
36
+ data_type: TYPE_INT32
37
+ dims: [-1]
38
+ }
39
+ ]
40
+ output [
41
+ {
42
+ name: "waveform"
43
+ data_type: TYPE_FP32
44
+ dims: [ -1 ]
45
+ }
46
+ ]
47
+
48
+ instance_group [
49
+ {
50
+ count: 1
51
+ kind: KIND_CPU
52
+ }
53
+ ]
runtime/triton_trtllm/run.sh ADDED
@@ -0,0 +1,109 @@
1
+ export PYTHONPATH=../../../Spark-TTS/
2
+ export CUDA_VISIBLE_DEVICES=0
3
+ stage=$1
4
+ stop_stage=$2
5
+ service_type=$3
6
+ echo "Start stage: $stage, Stop stage: $stop_stage, Service type: $service_type"
7
+
8
+ huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
9
+ trt_dtype=bfloat16
10
+ trt_weights_dir=./tllm_checkpoint_${trt_dtype}
11
+ trt_engines_dir=./trt_engines_${trt_dtype}
12
+
13
+ model_repo=./model_repo_test
14
+
15
+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
16
+ echo "Downloading Spark-TTS-0.5B from HuggingFace"
17
+ huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
18
+ fi
19
+
20
+
21
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
22
+ echo "Converting checkpoint to TensorRT weights"
23
+ python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
24
+ --output_dir $trt_weights_dir \
25
+ --dtype $trt_dtype || exit 1
26
+
27
+ echo "Building TensorRT engines"
28
+ trtllm-build --checkpoint_dir $trt_weights_dir \
29
+ --output_dir $trt_engines_dir \
30
+ --max_batch_size 16 \
31
+ --max_num_tokens 32768 \
32
+ --gemm_plugin $trt_dtype || exit 1
33
+ fi
34
+
35
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
36
+ echo "Creating model repository"
37
+ rm -rf $model_repo
38
+ mkdir -p $model_repo
39
+ spark_tts_dir="spark_tts"
40
+
41
+ cp -r ./model_repo/${spark_tts_dir} $model_repo
42
+ cp -r ./model_repo/audio_tokenizer $model_repo
43
+ cp -r ./model_repo/tensorrt_llm $model_repo
44
+ cp -r ./model_repo/vocoder $model_repo
45
+
46
+ ENGINE_PATH=$trt_engines_dir
47
+ MAX_QUEUE_DELAY_MICROSECONDS=0
48
+ MODEL_DIR=$huggingface_model_local_dir
49
+ LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
50
+ BLS_INSTANCE_NUM=4
51
+ TRITON_MAX_BATCH_SIZE=16
52
+ # streaming TTS parameters
53
+ AUDIO_CHUNK_DURATION=1.0
54
+ MAX_AUDIO_CHUNK_DURATION=30.0
55
+ AUDIO_CHUNK_SIZE_SCALE_FACTOR=8.0
56
+ AUDIO_CHUNK_OVERLAP_DURATION=0.1
57
+ python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
58
+ python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
59
+ if [ "$service_type" == "streaming" ]; then
60
+ DECOUPLED_MODE=True
61
+ else
62
+ DECOUPLED_MODE=False
63
+ fi
64
+ python3 scripts/fill_template.py -i ${model_repo}/${spark_tts_dir}/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},audio_chunk_duration:${AUDIO_CHUNK_DURATION},max_audio_chunk_duration:${MAX_AUDIO_CHUNK_DURATION},audio_chunk_size_scale_factor:${AUDIO_CHUNK_SIZE_SCALE_FACTOR},audio_chunk_overlap_duration:${AUDIO_CHUNK_OVERLAP_DURATION}
65
+ python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
66
+
67
+ fi
68
+
69
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
70
+ echo "Starting Triton server"
71
+ tritonserver --model-repository ${model_repo}
72
+ fi
73
+
74
+
75
+ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
76
+ echo "Running benchmark client"
77
+ num_task=2
78
+ if [ "$service_type" == "streaming" ]; then
79
+ mode="streaming"
80
+ else
81
+ mode="offline"
82
+ fi
83
+ python3 client_grpc.py \
84
+ --server-addr localhost \
85
+ --model-name spark_tts \
86
+ --num-tasks $num_task \
87
+ --mode $mode \
88
+ --log-dir ./log_concurrent_tasks_${num_task}_${mode}_new
89
+ fi
90
+
91
+ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
92
+ echo "Running single utterance client"
93
+ if [ "$service_type" == "streaming" ]; then
94
+ python client_grpc.py \
95
+ --server-addr localhost \
96
+ --reference-audio ../../example/prompt_audio.wav \
97
+ --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
98
+ --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
99
+ --model-name spark_tts \
100
+ --chunk-overlap-duration 0.1 \
101
+ --mode streaming
102
+ else
103
+ python client_http.py \
104
+ --reference-audio ../../example/prompt_audio.wav \
105
+ --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
106
+ --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
107
+ --model-name spark_tts
108
+ fi
109
+ fi
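The streaming parameters exported in stage 2 map onto the chunk schedule used by the spark_tts BLS: 50 semantic tokens per second, a chunk size that grows by AUDIO_CHUNK_SIZE_SCALE_FACTOR after each chunk, capped by MAX_AUDIO_CHUNK_DURATION, with a small token overlap between chunks. The sketch below just reproduces that arithmetic for the values above.

```python
# Chunk-size schedule implied by the stage-2 streaming parameters
# (mirrors the decoupled branch of spark_tts/1/model.py).
import math

frame_rate = 50                           # audio_tokenizer_frame_rate
chunk = math.ceil(1.0 * frame_rate)       # AUDIO_CHUNK_DURATION -> 50 tokens (~1 s)
max_chunk = math.ceil(30.0 * frame_rate)  # MAX_AUDIO_CHUNK_DURATION -> 1500 tokens
overlap = math.ceil(0.1 * frame_rate)     # AUDIO_CHUNK_OVERLAP_DURATION -> 5 tokens
scale = 8.0                               # AUDIO_CHUNK_SIZE_SCALE_FACTOR

schedule = [chunk]
for _ in range(3):
    chunk = min(max_chunk, int(chunk * scale))
    schedule.append(chunk)

print(schedule)  # [50, 400, 1500, 1500]: a short first chunk, then much larger ones
print(overlap)   # 5 tokens (~0.1 s) re-decoded at each chunk boundary
```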
runtime/triton_trtllm/scripts/convert_checkpoint.py ADDED
@@ -0,0 +1,335 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ import traceback
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+
7
+ from transformers import AutoConfig
8
+
9
+ import tensorrt_llm
10
+ from tensorrt_llm._utils import release_gc
11
+ from tensorrt_llm.logger import logger
12
+ from tensorrt_llm.mapping import Mapping
13
+ from tensorrt_llm.models import QWenForCausalLM
14
+ from tensorrt_llm.models.modeling_utils import QuantConfig
15
+ from tensorrt_llm.quantization import QuantAlgo
16
+
17
+
18
+ def parse_arguments():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('--model_dir', type=str, default=None, required=True)
21
+ parser.add_argument('--tp_size',
22
+ type=int,
23
+ default=1,
24
+ help='N-way tensor parallelism size')
25
+ parser.add_argument('--pp_size',
26
+ type=int,
27
+ default=1,
28
+ help='N-way pipeline parallelism size')
29
+ parser.add_argument(
30
+ '--dtype',
31
+ type=str,
32
+ default='auto',
33
+ choices=['auto', 'float16', 'bfloat16', 'float32'],
34
+ help=
35
+ "The data type for the model weights and activations if not quantized. "
36
+ "If 'auto', the data type is automatically inferred from the source model; "
37
+ "however, if the source dtype is float32, it is converted to float16.")
38
+ parser.add_argument(
39
+ '--use_weight_only',
40
+ default=False,
41
+ action="store_true",
42
+ help='Quantize weights for the various GEMMs to INT4/INT8.'
43
+ ' See --weight_only_precision to set the precision')
44
+ parser.add_argument(
45
+ '--disable_weight_only_quant_plugin',
46
+ default=False,
47
+ action="store_true",
48
+ help=
49
+ 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
50
+ ' You must also use --use_weight_only for that argument to have an impact.'
51
+ )
52
+ parser.add_argument(
53
+ '--weight_only_precision',
54
+ const='int8',
55
+ type=str,
56
+ nargs='?',
57
+ default='int8',
58
+ choices=['int8', 'int4', 'int4_gptq'],
59
+ help=
60
+ 'Define the precision for the weights when using weight-only quantization.'
61
+ ' You must also use --use_weight_only for that argument to have an impact.'
62
+ )
63
+ parser.add_argument(
64
+ '--calib_dataset',
65
+ type=str,
66
+ default='ccdv/cnn_dailymail',
67
+ help=
68
+ "The huggingface dataset name or the local directory of the dataset for calibration."
69
+ )
70
+ parser.add_argument(
71
+ "--smoothquant",
72
+ "-sq",
73
+ type=float,
74
+ default=None,
75
+ help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
76
+ " to Smoothquant the model, and output int8 weights."
77
+ " A good first try is 0.5. Must be in [0, 1]")
78
+ parser.add_argument(
79
+ '--per_channel',
80
+ action="store_true",
81
+ default=False,
82
+ help=
83
+ 'By default, we use a single static scaling factor for the GEMM\'s result. '
84
+ 'per_channel instead uses a different static scaling factor for each channel. '
85
+ 'The latter is usually more accurate, but a little slower.')
86
+ parser.add_argument(
87
+ '--per_token',
88
+ action="store_true",
89
+ default=False,
90
+ help=
91
+ 'By default, we use a single static scaling factor to scale activations in the int8 range. '
92
+ 'per_token chooses at run time, and for each token, a custom scaling factor. '
93
+ 'The latter is usually more accurate, but a little slower.')
94
+ parser.add_argument(
95
+ '--int8_kv_cache',
96
+ default=False,
97
+ action="store_true",
98
+ help=
99
+ 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
100
+ )
101
+ parser.add_argument(
102
+ '--per_group',
103
+ default=False,
104
+ action="store_true",
105
+ help=
106
+ 'By default, we use a single static scaling factor to scale weights in the int4 range. '
107
+ 'per_group chooses at run time, and for each group, a custom scaling factor. '
108
+ 'The flag is built for GPTQ/AWQ quantization.')
109
+
110
+ parser.add_argument('--group_size',
111
+ type=int,
112
+ default=128,
113
+ help='Group size used in GPTQ quantization.')
114
+
115
+ parser.add_argument("--load_model_on_cpu", action="store_true")
116
+ parser.add_argument(
117
+ '--use_parallel_embedding',
118
+ action="store_true",
119
+ default=False,
120
+ help=
121
+ 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
122
+ )
123
+ parser.add_argument(
124
+ '--embedding_sharding_dim',
125
+ type=int,
126
+ default=0,
127
+ choices=[0, 1],
128
+ help=
129
+ 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
130
+ 'To shard it along hidden dimension, set embedding_sharding_dim=1'
131
+ 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
132
+ )
133
+ parser.add_argument('--output_dir',
134
+ type=str,
135
+ default='tllm_checkpoint',
136
+ help='The path to save the TensorRT-LLM checkpoint')
137
+ parser.add_argument(
138
+ '--workers',
139
+ type=int,
140
+ default=1,
141
+ help='The number of workers for converting checkpoint in parallel')
142
+ parser.add_argument(
143
+ '--moe_tp_size',
144
+ type=int,
145
+ default=-1,
146
+ help=
147
+ 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
148
+ )
149
+ parser.add_argument(
150
+ '--moe_ep_size',
151
+ type=int,
152
+ default=-1,
153
+ help=
154
+ 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
155
+ )
156
+ args = parser.parse_args()
157
+ return args
158
+
159
+
160
+ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
161
+ '''return config dict with quantization info based on the command line args
162
+ '''
163
+ quant_config = QuantConfig()
164
+ if args.use_weight_only:
165
+ if args.weight_only_precision == 'int8':
166
+ quant_config.quant_algo = QuantAlgo.W8A16
167
+ elif args.weight_only_precision == 'int4':
168
+ quant_config.quant_algo = QuantAlgo.W4A16
169
+ elif args.smoothquant:
170
+ quant_config.smoothquant_val = args.smoothquant
171
+ if args.per_channel:
172
+ if args.per_token:
173
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
174
+ else:
175
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
176
+ else:
177
+ if args.per_token:
178
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
179
+ else:
180
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
181
+
182
+ if args.int8_kv_cache:
183
+ quant_config.kv_cache_quant_algo = QuantAlgo.INT8
184
+
185
+ if args.weight_only_precision == 'int4_gptq':
186
+ quant_config.group_size = args.group_size
187
+ quant_config.has_zero_point = True
188
+ quant_config.pre_quant_scale = False
189
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
190
+
191
+ return quant_config
192
+
193
+
194
+ def update_quant_config_from_hf(quant_config, hf_config,
195
+ override_fields) -> tuple[QuantConfig, dict]:
196
+ hf_config_dict = hf_config.to_dict()
197
+ if hf_config_dict.get('quantization_config'):
198
+ # update the quant_algo, and clamp_val.
199
+ if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
200
+ logger.info(
201
+ "Load quantization configs from huggingface model_config.")
202
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
203
+ quant_config.group_size = hf_config_dict['quantization_config'].get(
204
+ 'group_size', 128)
205
+ quant_config.has_zero_point = hf_config_dict[
206
+ 'quantization_config'].get('zero_point', False)
207
+ override_fields.update({"use_autoawq": True})
208
+ elif hf_config_dict['quantization_config'].get(
209
+ 'quant_method') == 'gptq':
210
+ logger.info(
211
+ "Load quantization configs from huggingface model_config.")
212
+ desc_act = hf_config_dict['quantization_config'].get(
213
+ 'desc_act', False)
214
+ if desc_act:
215
+ raise ValueError("GPTQ with desc_act=True is not implemented!")
216
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
217
+ quant_config.group_size = hf_config_dict['quantization_config'].get(
218
+ 'group_size', 128)
219
+ quant_config.has_zero_point = hf_config_dict[
220
+ 'quantization_config'].get('sym', False)
221
+ return quant_config, override_fields
222
+
223
+
224
+ def args_to_build_options(args):
225
+ return {
226
+ 'use_parallel_embedding': args.use_parallel_embedding,
227
+ 'embedding_sharding_dim': args.embedding_sharding_dim,
228
+ 'disable_weight_only_quant_plugin':
229
+ args.disable_weight_only_quant_plugin
230
+ }
231
+
232
+
233
+ def convert_and_save_hf(args):
234
+ model_dir = args.model_dir
235
+ world_size = args.tp_size * args.pp_size
236
+ # Need to convert the cli args to key-value pairs and override them in the generated config dict.
237
+ # Ideally these fields will be moved out of the config and passed into the build API; keep them here for compatibility for now,
238
+ # before the refactor is done.
239
+ override_fields = {}
240
+ override_fields.update(args_to_build_options(args))
241
+ quant_config = args_to_quant_config(args)
242
+
243
+ try:
244
+ hf_config = AutoConfig.from_pretrained(model_dir,
245
+ trust_remote_code=True)
246
+ quant_config, override_fields = update_quant_config_from_hf(
247
+ quant_config, hf_config, override_fields)
248
+ except Exception:
249
+ logger.warning("AutoConfig cannot load the huggingface config.")
250
+
251
+ if args.smoothquant is not None or args.int8_kv_cache:
252
+ mapping = Mapping(
253
+ world_size=world_size,
254
+ tp_size=args.tp_size,
255
+ pp_size=args.pp_size,
256
+ moe_tp_size=args.moe_tp_size,
257
+ moe_ep_size=args.moe_ep_size,
258
+ )
259
+ QWenForCausalLM.quantize(args.model_dir,
260
+ args.output_dir,
261
+ dtype=args.dtype,
262
+ mapping=mapping,
263
+ quant_config=quant_config,
264
+ calib_dataset=args.calib_dataset,
265
+ **override_fields)
266
+ else:
267
+
268
+ def convert_and_save_rank(args, rank):
269
+ mapping = Mapping(world_size=world_size,
270
+ rank=rank,
271
+ tp_size=args.tp_size,
272
+ pp_size=args.pp_size,
273
+ moe_tp_size=args.moe_tp_size,
274
+ moe_ep_size=args.moe_ep_size)
275
+ qwen = QWenForCausalLM.from_hugging_face(model_dir,
276
+ args.dtype,
277
+ mapping=mapping,
278
+ quant_config=quant_config,
279
+ **override_fields)
280
+ qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
281
+ del qwen
282
+
283
+ execute(args.workers, [convert_and_save_rank] * world_size, args)
284
+ release_gc()
285
+
286
+
287
+ def execute(workers, func, args):
288
+ if workers == 1:
289
+ for rank, f in enumerate(func):
290
+ f(args, rank)
291
+ else:
292
+ with ThreadPoolExecutor(max_workers=workers) as p:
293
+ futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
294
+ exceptions = []
295
+ for future in as_completed(futures):
296
+ try:
297
+ future.result()
298
+ except Exception as e:
299
+ traceback.print_exc()
300
+ exceptions.append(e)
301
+ assert len(
302
+ exceptions
303
+ ) == 0, "Checkpoint conversion failed, please check error log."
304
+
305
+
306
+ def main():
307
+ print(tensorrt_llm.__version__)
308
+ args = parse_arguments()
309
+
310
+ if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
311
+ # moe default to tp-only
312
+ args.moe_tp_size = args.tp_size
313
+ args.moe_ep_size = 1
314
+ elif (args.moe_tp_size == -1):
315
+ args.moe_tp_size = args.tp_size // args.moe_ep_size
316
+ elif (args.moe_ep_size == -1):
317
+ args.moe_ep_size = args.tp_size // args.moe_tp_size
318
+ assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
319
+ ), "moe_tp_size * moe_ep_size must equal tp_size"
320
+
321
+ tik = time.time()
322
+
323
+ if not os.path.exists(args.output_dir):
324
+ os.makedirs(args.output_dir)
325
+
326
+ assert args.model_dir is not None
327
+ convert_and_save_hf(args)
328
+
329
+ tok = time.time()
330
+ t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
331
+ print(f'Total time of converting checkpoints: {t}')
332
+
333
+
334
+ if __name__ == '__main__':
335
+ main()
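The quantization flags above are folded into a TensorRT-LLM QuantConfig by args_to_quant_config; for example, --use_weight_only with the default int8 precision selects W8A16. A hedged sketch of exercising that mapping directly is shown below; the import path and the Namespace fields mirror parse_arguments() but are assumptions, and tensorrt_llm must be installed for the import to succeed.

```python
# Hedged sketch: call args_to_quant_config() outside the CLI.
# Assumes it is run from runtime/triton_trtllm/ so that scripts/ is importable.
from argparse import Namespace

from scripts.convert_checkpoint import args_to_quant_config

args = Namespace(
    use_weight_only=True,
    weight_only_precision="int8",  # default precision -> QuantAlgo.W8A16
    smoothquant=None,
    per_channel=False,
    per_token=False,
    int8_kv_cache=False,
    per_group=False,
    group_size=128,
)
print(args_to_quant_config(args).quant_algo)  # expected: QuantAlgo.W8A16
```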
runtime/triton_trtllm/scripts/fill_template.py ADDED
@@ -0,0 +1,70 @@
1
+ #! /usr/bin/env python3
2
+ from argparse import ArgumentParser
3
+ from string import Template
4
+
5
+
6
+ def split(string, delimiter):
7
+ """Split a string using delimiter. Supports escaping.
8
+
9
+ Args:
10
+ string (str): The string to split.
11
+ delimiter (str): The delimiter to split the string with.
12
+
13
+ Returns:
14
+ list: A list of strings.
15
+ """
16
+ result = []
17
+ current = ""
18
+ escape = False
19
+ for char in string:
20
+ if escape:
21
+ current += char
22
+ escape = False
23
+ elif char == delimiter:
24
+ result.append(current)
25
+ current = ""
26
+ elif char == "\\":
27
+ escape = True
28
+ else:
29
+ current += char
30
+ result.append(current)
31
+ return result
32
+
33
+
34
+ def main(file_path, substitutions, in_place):
35
+ with open(file_path) as f:
36
+ pbtxt = Template(f.read())
37
+
38
+ sub_dict = {
39
+ "max_queue_size": 0,
40
+ 'max_queue_delay_microseconds': 0,
41
+ }
42
+ for sub in split(substitutions, ","):
43
+ key, value = split(sub, ":")
44
+ sub_dict[key] = value
45
+
46
+ assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
47
+
48
+ pbtxt = pbtxt.safe_substitute(sub_dict)
49
+
50
+ if in_place:
51
+ with open(file_path, "w") as f:
52
+ f.write(pbtxt)
53
+ else:
54
+ print(pbtxt)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ parser = ArgumentParser()
59
+ parser.add_argument("file_path", help="path of the .pbtxt to modify")
60
+ parser.add_argument(
61
+ "substitutions",
62
+ help=
63
+ "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
64
+ )
65
+ parser.add_argument("--in_place",
66
+ "-i",
67
+ action="store_true",
68
+ help="do the operation in-place")
69
+ args = parser.parse_args()
70
+ main(**vars(args))
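run.sh drives this script with comma-separated key:value pairs, and split() above allows backslash-escaping when a value itself contains a comma or colon. The sketch below replays the same substitution flow on an in-memory template instead of an on-disk config.pbtxt; the import path is an assumption.

```python
# Minimal illustration of the substitution performed by this script;
# assumes scripts/ is importable so fill_template.split can be reused.
from string import Template

from scripts.fill_template import split

pbtxt = Template('max_batch_size: ${triton_max_batch_size}\n'
                 'parameters [ { key: "model_dir", value: {string_value:"${model_dir}"} } ]')

substitutions = "triton_max_batch_size:16,model_dir:/models/Spark-TTS-0.5B"
sub_dict = {"max_queue_size": 0, "max_queue_delay_microseconds": 0}
for sub in split(substitutions, ","):
    key, value = split(sub, ":")
    sub_dict[key] = value

print(pbtxt.safe_substitute(sub_dict))
```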
sparktts/models/audio_tokenizer.py ADDED
@@ -0,0 +1,163 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import numpy as np
19
+
20
+ from pathlib import Path
21
+ from typing import Any, Dict, Tuple
22
+ from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
23
+
24
+ from sparktts.utils.file import load_config
25
+ from sparktts.utils.audio import load_audio
26
+ from sparktts.models.bicodec import BiCodec
27
+
28
+
29
+ class BiCodecTokenizer:
30
+ """BiCodec tokenizer for handling audio input and tokenization."""
31
+
32
+ def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
33
+ super().__init__()
34
+ """
35
+ Args:
36
+ model_dir: Path to the model directory.
37
+ device: Device to run the model on (default is GPU if available).
38
+ """
39
+ self.device = device
40
+ self.model_dir = model_dir
41
+ self.config = load_config(f"{model_dir}/config.yaml")
42
+ self._initialize_model()
43
+
44
+ def _initialize_model(self):
45
+ """Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
46
+ self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
47
+ self.device
48
+ )
49
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
50
+ f"{self.model_dir}/wav2vec2-large-xlsr-53"
51
+ )
52
+ self.feature_extractor = Wav2Vec2Model.from_pretrained(
53
+ f"{self.model_dir}/wav2vec2-large-xlsr-53"
54
+ ).to(self.device)
55
+ self.feature_extractor.config.output_hidden_states = True
56
+
57
+ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
58
+ """Get reference audio clip for speaker embedding."""
59
+ ref_segment_length = (
60
+ int(self.config["sample_rate"] * self.config["ref_segment_duration"])
61
+ // self.config["latent_hop_length"]
62
+ * self.config["latent_hop_length"]
63
+ )
64
+ wav_length = len(wav)
65
+
66
+ if ref_segment_length > wav_length:
67
+ # Repeat and truncate to handle insufficient length
68
+ wav = np.tile(wav, ref_segment_length // wav_length + 1)
69
+
70
+ return wav[:ref_segment_length]
71
+
72
+ def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
73
+ """load auido and get reference audio from wav path"""
74
+ wav = load_audio(
75
+ wav_path,
76
+ sampling_rate=self.config["sample_rate"],
77
+ volume_normalize=self.config["volume_normalize"],
78
+ )
79
+
80
+ wav_ref = self.get_ref_clip(wav)
81
+
82
+ wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
83
+ return wav, wav_ref
84
+
85
+ def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
86
+ """extract wav2vec2 features"""
87
+ inputs = self.processor(
88
+ wavs,
89
+ sampling_rate=16000,
90
+ return_tensors="pt",
91
+ padding=True,
92
+ output_hidden_states=True,
93
+ ).input_values
94
+ feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
95
+ feats_mix = (
96
+ feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
97
+ ) / 3
98
+
99
+ return feats_mix
100
+
101
+ def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
102
+ """tokenize the batch of audio
103
+
104
+ Args:
105
+ batch:
106
+ wavs (List[np.ndarray]): batch of audio
107
+ ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
108
+
109
+ Returns:
110
+ global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
111
+ semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
112
+ """
113
+ feats = self.extract_wav2vec2_features(batch["wav"])
114
+ batch["feat"] = feats
115
+ semantic_tokens, global_tokens = self.model.tokenize(batch)
116
+
117
+ return global_tokens, semantic_tokens
118
+
119
+ def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ """tokenize the audio"""
121
+ wav, ref_wav = self.process_audio(audio_path)
122
+ feat = self.extract_wav2vec2_features(wav)
123
+ batch = {
124
+ "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
125
+ "ref_wav": ref_wav.to(self.device),
126
+ "feat": feat.to(self.device),
127
+ }
128
+ semantic_tokens, global_tokens = self.model.tokenize(batch)
129
+
130
+ return global_tokens, semantic_tokens
131
+
132
+ def detokenize(
133
+ self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
134
+ ) -> np.ndarray:
135
+ """detokenize the tokens to waveform
136
+
137
+ Args:
138
+ global_tokens: global tokens. shape: (batch_size, global_dim)
139
+ semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
140
+
141
+ Returns:
142
+ wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
143
+ """
144
+ global_tokens = global_tokens.unsqueeze(1)
145
+ wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
146
+ return wav_rec.detach().squeeze().cpu().numpy()
147
+
148
+
149
+ # test
150
+ if __name__ == "__main__":
151
+ import soundfile as sf
152
+
153
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
+ tokenizer = BiCodecTokenizer(
155
+ model_dir="pretrained_models/Spark-TTS-0.5B",
156
+ device=device,
157
+ )
158
+ wav_path = "example/prompt_audio.wav"
159
+
160
+ global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
161
+
162
+ wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
163
+ sf.write("example/prompt_recon.wav", wav_rec, 16000)
sparktts/models/bicodec.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from pathlib import Path
19
+ from typing import Dict, Any
20
+ from omegaconf import DictConfig
21
+ from safetensors.torch import load_file
22
+
23
+ from sparktts.utils.file import load_config
24
+ from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
25
+ from sparktts.modules.encoder_decoder.feat_encoder import Encoder
26
+ from sparktts.modules.encoder_decoder.feat_decoder import Decoder
27
+ from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
28
+ from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
29
+
30
+
31
+ class BiCodec(nn.Module):
32
+ """
33
+ BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
34
+ quantizer, and wave generator.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ mel_params: Dict[str, Any],
40
+ encoder: nn.Module,
41
+ decoder: nn.Module,
42
+ quantizer: nn.Module,
43
+ speaker_encoder: nn.Module,
44
+ prenet: nn.Module,
45
+ postnet: nn.Module,
46
+ **kwargs
47
+ ) -> None:
48
+ """
49
+ Initializes the BiCodec model with the required components.
50
+
51
+ Args:
52
+ mel_params (dict): Parameters for the mel-spectrogram transformer.
53
+ encoder (nn.Module): Encoder module.
54
+ decoder (nn.Module): Decoder module.
55
+ quantizer (nn.Module): Quantizer module.
56
+ speaker_encoder (nn.Module): Speaker encoder module.
57
+ prenet (nn.Module): Prenet network.
58
+ postnet (nn.Module): Postnet network.
59
+ """
60
+ super().__init__()
61
+ self.encoder = encoder
62
+ self.decoder = decoder
63
+ self.quantizer = quantizer
64
+ self.speaker_encoder = speaker_encoder
65
+ self.prenet = prenet
66
+ self.postnet = postnet
67
+ self.init_mel_transformer(mel_params)
68
+
69
+ @classmethod
70
+ def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
71
+ """
72
+ Loads the model from a checkpoint.
73
+
74
+ Args:
75
+ model_dir (Path): Path to the model directory containing checkpoint and config.
76
+
77
+ Returns:
78
+ BiCodec: The initialized BiCodec model.
79
+ """
80
+ ckpt_path = f'{model_dir}/model.safetensors'
81
+ config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
82
+ mel_params = config["mel_params"]
83
+ encoder = Encoder(**config["encoder"])
84
+ quantizer = FactorizedVectorQuantize(**config["quantizer"])
85
+ prenet = Decoder(**config["prenet"])
86
+ postnet = Decoder(**config["postnet"])
87
+ decoder = WaveGenerator(**config["decoder"])
88
+ speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
89
+
90
+ model = cls(
91
+ mel_params=mel_params,
92
+ encoder=encoder,
93
+ decoder=decoder,
94
+ quantizer=quantizer,
95
+ speaker_encoder=speaker_encoder,
96
+ prenet=prenet,
97
+ postnet=postnet,
98
+ )
99
+
100
+ state_dict = load_file(ckpt_path)
101
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
102
+
103
+ for key in missing_keys:
104
+ print(f"Missing tensor: {key}")
105
+ for key in unexpected_keys:
106
+ print(f"Unexpected tensor: {key}")
107
+
108
+ model.eval()
109
+ model.remove_weight_norm()
110
+
111
+ return model
112
+
113
+ def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
114
+ """
115
+ Performs a forward pass through the model.
116
+
117
+ Args:
118
+ batch (dict): A dictionary containing features, reference waveform, and target waveform.
119
+
120
+ Returns:
121
+ dict: A dictionary containing the reconstruction, features, and other metrics.
122
+ """
123
+ feat = batch["feat"]
124
+ mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
125
+
126
+ z = self.encoder(feat.transpose(1, 2))
127
+ vq_outputs = self.quantizer(z)
128
+
129
+ x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
130
+
131
+ conditions = d_vector
132
+ with_speaker_loss = False
133
+
134
+ x = self.prenet(vq_outputs["z_q"], conditions)
135
+ pred_feat = self.postnet(x)
136
+ x = x + conditions.unsqueeze(-1)
137
+ wav_recon = self.decoder(x)
138
+
139
+ return {
140
+ "vq_loss": vq_outputs["vq_loss"],
141
+ "perplexity": vq_outputs["perplexity"],
142
+ "cluster_size": vq_outputs["active_num"],
143
+ "recons": wav_recon,
144
+ "pred_feat": pred_feat,
145
+ "x_vector": x_vector,
146
+ "d_vector": d_vector,
147
+ "audios": batch["wav"].unsqueeze(1),
148
+ "with_speaker_loss": with_speaker_loss,
149
+ }
150
+
151
+ @torch.no_grad()
152
+ def tokenize(self, batch: Dict[str, Any]):
153
+ """
154
+ Tokenizes the input audio into semantic and global tokens.
155
+
156
+ Args:
157
+ batch (dict): The input audio features and reference waveform.
158
+
159
+ Returns:
160
+ tuple: Semantic tokens and global tokens.
161
+ """
162
+ feat = batch["feat"]
163
+ mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
164
+
165
+ z = self.encoder(feat.transpose(1, 2))
166
+ semantic_tokens = self.quantizer.tokenize(z)
167
+ global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
168
+
169
+ return semantic_tokens, global_tokens
170
+
171
+ @torch.no_grad()
172
+ def detokenize(self, semantic_tokens, global_tokens):
173
+ """
174
+ Detokenizes the semantic and global tokens into a waveform.
175
+
176
+ Args:
177
+ semantic_tokens (tensor): Semantic tokens.
178
+ global_tokens (tensor): Global tokens.
179
+
180
+ Returns:
181
+ tensor: Reconstructed waveform.
182
+ """
183
+ z_q = self.quantizer.detokenize(semantic_tokens)
184
+ d_vector = self.speaker_encoder.detokenize(global_tokens)
185
+ x = self.prenet(z_q, d_vector)
186
+ x = x + d_vector.unsqueeze(-1)
187
+ wav_recon = self.decoder(x)
188
+
189
+ return wav_recon
190
+
191
+ def init_mel_transformer(self, config: Dict[str, Any]):
192
+ """
193
+ Initializes the MelSpectrogram transformer based on the provided configuration.
194
+
195
+ Args:
196
+ config (dict): Configuration parameters for MelSpectrogram.
197
+ """
198
+ import torchaudio.transforms as TT
199
+
200
+ self.mel_transformer = TT.MelSpectrogram(
201
+ config["sample_rate"],
202
+ config["n_fft"],
203
+ config["win_length"],
204
+ config["hop_length"],
205
+ config["mel_fmin"],
206
+ config["mel_fmax"],
207
+ n_mels=config["num_mels"],
208
+ power=1,
209
+ norm="slaney",
210
+ mel_scale="slaney",
211
+ )
212
+
213
+ def remove_weight_norm(self):
214
+ """Removes weight normalization from all layers."""
215
+ def _remove_weight_norm(m):
216
+ try:
217
+ torch.nn.utils.remove_weight_norm(m)
218
+ except ValueError:
219
+ pass # The module didn't have weight norm
220
+
221
+ self.apply(_remove_weight_norm)
222
+
223
+
224
+ # Test the model
225
+ if __name__ == "__main__":
226
+
227
+ config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
228
+ model = BiCodec.load_from_checkpoint(
229
+ model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
230
+ )
231
+
232
+ # Generate random inputs for testing
233
+ duration = 0.96
234
+ x = torch.randn(20, 1, int(duration * 16000))
235
+ feat = torch.randn(20, int(duration * 50), 1024)
236
+ inputs = {"feat": feat, "wav": x, "ref_wav": x}
237
+
238
+ # Forward pass
239
+ outputs = model(inputs)
240
+ semantic_tokens, global_tokens = model.tokenize(inputs)
241
+ wav_recon = model.detokenize(semantic_tokens, global_tokens)
242
+
243
+ # Verify if the reconstruction matches
244
+ if torch.allclose(outputs["recons"].detach(), wav_recon):
245
+ print("Test successful")
246
+ else:
247
+ print("Test failed")
sparktts/modules/blocks/layers.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
17
+
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.utils import weight_norm
22
+
23
+
24
+ def WNConv1d(*args, **kwargs):
25
+ return weight_norm(nn.Conv1d(*args, **kwargs))
26
+
27
+
28
+ def WNConvTranspose1d(*args, **kwargs):
29
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
30
+
31
+
32
+ # Scripting this brings model speed up 1.4x
33
+ @torch.jit.script
34
+ def snake(x, alpha):
35
+ shape = x.shape
36
+ x = x.reshape(shape[0], shape[1], -1)
37
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
38
+ x = x.reshape(shape)
39
+ return x
40
+
41
+
42
+ class Snake1d(nn.Module):
43
+ def __init__(self, channels):
44
+ super().__init__()
45
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
46
+
47
+ def forward(self, x):
48
+ return snake(x, self.alpha)
49
+
50
+
51
+ class ResidualUnit(nn.Module):
52
+ def __init__(self, dim: int = 16, dilation: int = 1):
53
+ super().__init__()
54
+ pad = ((7 - 1) * dilation) // 2
55
+ self.block = nn.Sequential(
56
+ Snake1d(dim),
57
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
58
+ Snake1d(dim),
59
+ WNConv1d(dim, dim, kernel_size=1),
60
+ )
61
+
62
+ def forward(self, x):
63
+ y = self.block(x)
64
+ pad = (x.shape[-1] - y.shape[-1]) // 2
65
+ if pad > 0:
66
+ x = x[..., pad:-pad]
67
+ return x + y
68
+
69
+
70
+ def init_weights(m):
71
+ if isinstance(m, nn.Conv1d):
72
+ nn.init.trunc_normal_(m.weight, std=0.02)
73
+ nn.init.constant_(m.bias, 0)
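
For reference, the `snake` activation above computes x + (1/alpha) * sin^2(alpha * x) with a learnable per-channel alpha, and `ResidualUnit` keeps the time dimension unchanged. A minimal shape-check sketch using only the classes defined in this file:

    import torch
    from sparktts.modules.blocks.layers import Snake1d, ResidualUnit

    x = torch.randn(2, 16, 100)                       # (batch, channels, time)
    print(Snake1d(16)(x).shape)                       # torch.Size([2, 16, 100]), shape preserved
    print(ResidualUnit(dim=16, dilation=3)(x).shape)  # torch.Size([2, 16, 100])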
sparktts/modules/blocks/samper.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+
22
+ class SamplingBlock(nn.Module):
23
+ """Sampling block for upsampling or downsampling"""
24
+
25
+ def __init__(
26
+ self,
27
+ dim: int,
28
+ groups: int = 1,
29
+ upsample_scale: int = 1,
30
+ downsample_scale: int = 1,
31
+ ) -> None:
32
+ """
33
+ Args:
34
+ dim: input dimension
35
+ groups: number of groups
36
+ upsample_scale: upsampling scale
37
+ downsample_scale: downsampling scale
38
+ """
39
+ super(SamplingBlock, self).__init__()
40
+
41
+ self.upsample_scale = upsample_scale
42
+ self.downsample_scale = downsample_scale
43
+
44
+ if self.upsample_scale > 1:
45
+ self.de_conv_upsampler = nn.Sequential(
46
+ nn.LeakyReLU(0.2),
47
+ nn.ConvTranspose1d(
48
+ dim,
49
+ dim,
50
+ kernel_size=upsample_scale * 2,
51
+ stride=upsample_scale,
52
+ padding=upsample_scale // 2 + upsample_scale % 2,
53
+ output_padding=upsample_scale % 2,
54
+ groups=groups,
55
+ ),
56
+ )
57
+
58
+ if self.downsample_scale > 1:
59
+ self.conv_downsampler = nn.Sequential(
60
+ nn.LeakyReLU(0.2),
61
+ nn.Conv1d(
62
+ dim,
63
+ dim,
64
+ kernel_size=2 * downsample_scale,
65
+ stride=downsample_scale,
66
+ padding=downsample_scale // 2 + downsample_scale % 2,
67
+ groups=groups,
68
+ ),
69
+ )
70
+
71
+ @staticmethod
72
+ def repeat_upsampler(x, upsample_scale):
73
+ return x.repeat_interleave(upsample_scale, dim=2)
74
+
75
+ @staticmethod
76
+ def skip_downsampler(x, downsample_scale):
77
+ return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
78
+
79
+ def forward(self, x):
80
+ x = x.transpose(1, 2)
81
+ if self.upsample_scale > 1:
82
+ repeat_res = self.repeat_upsampler(x, self.upsample_scale)
83
+ deconv_res = self.de_conv_upsampler(x)
84
+ upmerge_res = repeat_res + deconv_res
85
+ else:
86
+ upmerge_res = x
87
+ repeat_res = x
88
+
89
+ if self.downsample_scale > 1:
90
+ conv_res = self.conv_downsampler(upmerge_res)
91
+ skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
92
+ skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
93
+ else:
94
+ conv_res = upmerge_res
95
+ skip2_res = upmerge_res
96
+ skip1_res = repeat_res
97
+
98
+ final_res = conv_res + skip1_res + skip2_res
99
+
100
+ return final_res
101
+
102
+
103
+ # test
104
+ if __name__ == "__main__":
105
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
106
+ model = SamplingBlock(1024, 1024, upsample_scale=2)
107
+ model_down = SamplingBlock(1024, 1024, downsample_scale=2)
108
+ output = model(test_input)
109
+ output_down = model_down(test_input)
110
+ print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
111
+ print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
112
+ if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
113
+ [8, 1024, 25]
114
+ ):
115
+ print("test successful")
sparktts/modules/blocks/vocos.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import Tuple
21
+ from torch.nn.utils import weight_norm, remove_weight_norm
22
+
23
+ from typing import Optional
24
+
25
+
26
+ class ConvNeXtBlock(nn.Module):
27
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
28
+
29
+ Args:
30
+ dim (int): Number of input channels.
31
+ intermediate_dim (int): Dimensionality of the intermediate layer.
32
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
33
+ Defaults to None.
34
+ condition_dim (int, optional): Dimension of the condition embedding for AdaLayerNorm.
35
+ None means non-conditional LayerNorm. Defaults to None.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ dim: int,
41
+ intermediate_dim: int,
42
+ layer_scale_init_value: float,
43
+ condition_dim: Optional[int] = None,
44
+ ):
45
+ super().__init__()
46
+ self.dwconv = nn.Conv1d(
47
+ dim, dim, kernel_size=7, padding=3, groups=dim
48
+ ) # depthwise conv
49
+ self.adanorm = condition_dim is not None
50
+ if condition_dim:
51
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
52
+ else:
53
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
54
+ self.pwconv1 = nn.Linear(
55
+ dim, intermediate_dim
56
+ ) # pointwise/1x1 convs, implemented with linear layers
57
+ self.act = nn.GELU()
58
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
59
+ self.gamma = (
60
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
61
+ if layer_scale_init_value > 0
62
+ else None
63
+ )
64
+
65
+ def forward(
66
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
67
+ ) -> torch.Tensor:
68
+ residual = x
69
+ x = self.dwconv(x)
70
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
71
+ if self.adanorm:
72
+ assert cond_embedding_id is not None
73
+ x = self.norm(x, cond_embedding_id)
74
+ else:
75
+ x = self.norm(x)
76
+ x = self.pwconv1(x)
77
+ x = self.act(x)
78
+ x = self.pwconv2(x)
79
+ if self.gamma is not None:
80
+ x = self.gamma * x
81
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
82
+
83
+ x = residual + x
84
+ return x
85
+
86
+
87
+ class AdaLayerNorm(nn.Module):
88
+ """
89
+ Adaptive Layer Normalization module conditioned on an external embedding via learned scale and shift.
90
+
91
+ Args:
92
+ condition_dim (int): Dimension of the condition.
93
+ embedding_dim (int): Dimension of the embeddings.
94
+ """
95
+
96
+ def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
97
+ super().__init__()
98
+ self.eps = eps
99
+ self.dim = embedding_dim
100
+ self.scale = nn.Linear(condition_dim, embedding_dim)
101
+ self.shift = nn.Linear(condition_dim, embedding_dim)
102
+ torch.nn.init.ones_(self.scale.weight)
103
+ torch.nn.init.zeros_(self.shift.weight)
104
+
105
+ def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
106
+ scale = self.scale(cond_embedding)
107
+ shift = self.shift(cond_embedding)
108
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
109
+ x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
110
+ return x
111
+
112
+
113
+ class ResBlock1(nn.Module):
114
+ """
115
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
116
+ but without upsampling layers.
117
+
118
+ Args:
119
+ dim (int): Number of input channels.
120
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
121
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
122
+ Defaults to (1, 3, 5).
123
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
124
+ Defaults to 0.1.
125
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
126
+ Defaults to None.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ dim: int,
132
+ kernel_size: int = 3,
133
+ dilation: Tuple[int, int, int] = (1, 3, 5),
134
+ lrelu_slope: float = 0.1,
135
+ layer_scale_init_value: Optional[float] = None,
136
+ ):
137
+ super().__init__()
138
+ self.lrelu_slope = lrelu_slope
139
+ self.convs1 = nn.ModuleList(
140
+ [
141
+ weight_norm(
142
+ nn.Conv1d(
143
+ dim,
144
+ dim,
145
+ kernel_size,
146
+ 1,
147
+ dilation=dilation[0],
148
+ padding=self.get_padding(kernel_size, dilation[0]),
149
+ )
150
+ ),
151
+ weight_norm(
152
+ nn.Conv1d(
153
+ dim,
154
+ dim,
155
+ kernel_size,
156
+ 1,
157
+ dilation=dilation[1],
158
+ padding=self.get_padding(kernel_size, dilation[1]),
159
+ )
160
+ ),
161
+ weight_norm(
162
+ nn.Conv1d(
163
+ dim,
164
+ dim,
165
+ kernel_size,
166
+ 1,
167
+ dilation=dilation[2],
168
+ padding=self.get_padding(kernel_size, dilation[2]),
169
+ )
170
+ ),
171
+ ]
172
+ )
173
+
174
+ self.convs2 = nn.ModuleList(
175
+ [
176
+ weight_norm(
177
+ nn.Conv1d(
178
+ dim,
179
+ dim,
180
+ kernel_size,
181
+ 1,
182
+ dilation=1,
183
+ padding=self.get_padding(kernel_size, 1),
184
+ )
185
+ ),
186
+ weight_norm(
187
+ nn.Conv1d(
188
+ dim,
189
+ dim,
190
+ kernel_size,
191
+ 1,
192
+ dilation=1,
193
+ padding=self.get_padding(kernel_size, 1),
194
+ )
195
+ ),
196
+ weight_norm(
197
+ nn.Conv1d(
198
+ dim,
199
+ dim,
200
+ kernel_size,
201
+ 1,
202
+ dilation=1,
203
+ padding=self.get_padding(kernel_size, 1),
204
+ )
205
+ ),
206
+ ]
207
+ )
208
+
209
+ self.gamma = nn.ParameterList(
210
+ [
211
+ (
212
+ nn.Parameter(
213
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
214
+ )
215
+ if layer_scale_init_value is not None
216
+ else None
217
+ ),
218
+ (
219
+ nn.Parameter(
220
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
221
+ )
222
+ if layer_scale_init_value is not None
223
+ else None
224
+ ),
225
+ (
226
+ nn.Parameter(
227
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
228
+ )
229
+ if layer_scale_init_value is not None
230
+ else None
231
+ ),
232
+ ]
233
+ )
234
+
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
237
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
238
+ xt = c1(xt)
239
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
240
+ xt = c2(xt)
241
+ if gamma is not None:
242
+ xt = gamma * xt
243
+ x = xt + x
244
+ return x
245
+
246
+ def remove_weight_norm(self):
247
+ for l in self.convs1:
248
+ remove_weight_norm(l)
249
+ for l in self.convs2:
250
+ remove_weight_norm(l)
251
+
252
+ @staticmethod
253
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
254
+ return int((kernel_size * dilation - dilation) / 2)
255
+
256
+
257
+ class Backbone(nn.Module):
258
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
259
+
260
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
261
+ """
262
+ Args:
263
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
264
+ C denotes output features, and L is the sequence length.
265
+
266
+ Returns:
267
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
268
+ and H denotes the model dimension.
269
+ """
270
+ raise NotImplementedError("Subclasses must implement the forward method.")
271
+
272
+
273
+ class VocosBackbone(Backbone):
274
+ """
275
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
276
+
277
+ Args:
278
+ input_channels (int): Number of input features channels.
279
+ dim (int): Hidden dimension of the model.
280
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
281
+ num_layers (int): Number of ConvNeXtBlock layers.
282
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
283
+ condition_dim (int, optional): Dimension of the condition embedding for AdaLayerNorm.
284
+ None means non-conditional model. Defaults to None.
285
+ """
286
+
287
+ def __init__(
288
+ self,
289
+ input_channels: int,
290
+ dim: int,
291
+ intermediate_dim: int,
292
+ num_layers: int,
293
+ layer_scale_init_value: Optional[float] = None,
294
+ condition_dim: Optional[int] = None,
295
+ ):
296
+ super().__init__()
297
+ self.input_channels = input_channels
298
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
299
+ self.adanorm = condition_dim is not None
300
+ if condition_dim:
301
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
302
+ else:
303
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
304
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
305
+ self.convnext = nn.ModuleList(
306
+ [
307
+ ConvNeXtBlock(
308
+ dim=dim,
309
+ intermediate_dim=intermediate_dim,
310
+ layer_scale_init_value=layer_scale_init_value,
311
+ condition_dim=condition_dim,
312
+ )
313
+ for _ in range(num_layers)
314
+ ]
315
+ )
316
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
317
+ self.apply(self._init_weights)
318
+
319
+ def _init_weights(self, m):
320
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
321
+ nn.init.trunc_normal_(m.weight, std=0.02)
322
+ nn.init.constant_(m.bias, 0)
323
+
324
+ def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
325
+ x = self.embed(x)
326
+ if self.adanorm:
327
+ assert condition is not None
328
+ x = self.norm(x.transpose(1, 2), condition)
329
+ else:
330
+ x = self.norm(x.transpose(1, 2))
331
+ x = x.transpose(1, 2)
332
+ for conv_block in self.convnext:
333
+ x = conv_block(x, condition)
334
+ x = self.final_layer_norm(x.transpose(1, 2))
335
+ return x
336
+
337
+
338
+ class VocosResNetBackbone(Backbone):
339
+ """
340
+ Vocos backbone module built with ResBlocks.
341
+
342
+ Args:
343
+ input_channels (int): Number of input features channels.
344
+ dim (int): Hidden dimension of the model.
345
+ num_blocks (int): Number of ResBlock1 blocks.
346
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ input_channels,
352
+ dim,
353
+ num_blocks,
354
+ layer_scale_init_value=None,
355
+ ):
356
+ super().__init__()
357
+ self.input_channels = input_channels
358
+ self.embed = weight_norm(
359
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
360
+ )
361
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
362
+ self.resnet = nn.Sequential(
363
+ *[
364
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
365
+ for _ in range(num_blocks)
366
+ ]
367
+ )
368
+
369
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
370
+ x = self.embed(x)
371
+ x = self.resnet(x)
372
+ x = x.transpose(1, 2)
373
+ return x
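
A minimal shape sketch for `VocosBackbone` (the dimensions below are illustrative, not necessarily the released Spark-TTS configuration): it maps a channel-first feature map to a channel-last hidden sequence.

    import torch
    from sparktts.modules.blocks.vocos import VocosBackbone

    backbone = VocosBackbone(
        input_channels=1024, dim=384, intermediate_dim=2048, num_layers=12
    )
    feats = torch.randn(2, 1024, 50)   # (batch, input_channels, frames)
    print(backbone(feats).shape)       # torch.Size([2, 50, 384]) -> (batch, frames, dim)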
sparktts/modules/encoder_decoder/feat_decoder.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+
22
+ from sparktts.modules.blocks.vocos import VocosBackbone
23
+ from sparktts.modules.blocks.samper import SamplingBlock
24
+
25
+
26
+ class Decoder(nn.Module):
27
+ """Decoder module with convnext and upsampling blocks
28
+
29
+ Args:
30
+ sample_ratios (List[int]): sample ratios
31
+ example: [2, 2] means upsample by 2x and then by another 2x (4x overall)
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ input_channels: int,
37
+ vocos_dim: int,
38
+ vocos_intermediate_dim: int,
39
+ vocos_num_layers: int,
40
+ out_channels: int,
41
+ condition_dim: int = None,
42
+ sample_ratios: List[int] = [1, 1],
43
+ use_tanh_at_final: bool = False,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.linear_pre = nn.Linear(input_channels, vocos_dim)
48
+ modules = [
49
+ nn.Sequential(
50
+ SamplingBlock(
51
+ dim=vocos_dim,
52
+ groups=vocos_dim,
53
+ upsample_scale=ratio,
54
+ ),
55
+ VocosBackbone(
56
+ input_channels=vocos_dim,
57
+ dim=vocos_dim,
58
+ intermediate_dim=vocos_intermediate_dim,
59
+ num_layers=2,
60
+ condition_dim=None,
61
+ ),
62
+ )
63
+ for ratio in sample_ratios
64
+ ]
65
+
66
+ self.downsample = nn.Sequential(*modules)
67
+
68
+ self.vocos_backbone = VocosBackbone(
69
+ input_channels=vocos_dim,
70
+ dim=vocos_dim,
71
+ intermediate_dim=vocos_intermediate_dim,
72
+ num_layers=vocos_num_layers,
73
+ condition_dim=condition_dim,
74
+ )
75
+ self.linear = nn.Linear(vocos_dim, out_channels)
76
+ self.use_tanh_at_final = use_tanh_at_final
77
+
78
+ def forward(self, x: torch.Tensor, c: torch.Tensor = None):
79
+ """encoder forward.
80
+
81
+ Args:
82
+ x (torch.Tensor): (batch_size, input_channels, length)
83
+
84
+ Returns:
85
+ x (torch.Tensor): (batch_size, out_channels, length)
86
+ """
87
+ x = self.linear_pre(x.transpose(1, 2))
88
+ x = self.downsample(x).transpose(1, 2)
89
+ x = self.vocos_backbone(x, condition=c)
90
+ x = self.linear(x).transpose(1, 2)
91
+ if self.use_tanh_at_final:
92
+ x = torch.tanh(x)
93
+
94
+ return x
95
+
96
+
97
+ # test
98
+ if __name__ == "__main__":
99
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
100
+ condition = torch.randn(8, 256)
101
+ decoder = Decoder(
102
+ input_channels=1024,
103
+ vocos_dim=384,
104
+ vocos_intermediate_dim=2048,
105
+ vocos_num_layers=12,
106
+ out_channels=256,
107
+ condition_dim=256,
108
+ sample_ratios=[2, 2],
109
+ )
110
+ output = decoder(test_input, condition)
111
+ print(output.shape) # torch.Size([8, 256, 200])
112
+ if output.shape == torch.Size([8, 256, 200]):
113
+ print("Decoder test passed")
114
+ else:
115
+ print("Decoder test failed")
sparktts/modules/encoder_decoder/feat_encoder.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+
22
+ from sparktts.modules.blocks.vocos import VocosBackbone
23
+ from sparktts.modules.blocks.samper import SamplingBlock
24
+
25
+
26
+ class Encoder(nn.Module):
27
+ """Encoder module with convnext and downsampling blocks"""
28
+
29
+ def __init__(
30
+ self,
31
+ input_channels: int,
32
+ vocos_dim: int,
33
+ vocos_intermediate_dim: int,
34
+ vocos_num_layers: int,
35
+ out_channels: int,
36
+ sample_ratios: List[int] = [1, 1],
37
+ ):
38
+ super().__init__()
39
+ """
40
+ Encoder module with VocosBackbone and sampling blocks.
41
+
42
+ Args:
43
+ sample_ratios (List[int]): sample ratios
44
+ example: [2, 2] means downsample by 2x and then upsample by 2x
45
+ """
46
+ self.encoder = VocosBackbone(
47
+ input_channels=input_channels,
48
+ dim=vocos_dim,
49
+ intermediate_dim=vocos_intermediate_dim,
50
+ num_layers=vocos_num_layers,
51
+ condition_dim=None,
52
+ )
53
+
54
+ modules = [
55
+ nn.Sequential(
56
+ SamplingBlock(
57
+ dim=vocos_dim,
58
+ groups=vocos_dim,
59
+ downsample_scale=ratio,
60
+ ),
61
+ VocosBackbone(
62
+ input_channels=vocos_dim,
63
+ dim=vocos_dim,
64
+ intermediate_dim=vocos_intermediate_dim,
65
+ num_layers=2,
66
+ condition_dim=None,
67
+ ),
68
+ )
69
+ for ratio in sample_ratios
70
+ ]
71
+
72
+ self.downsample = nn.Sequential(*modules)
73
+
74
+ self.project = nn.Linear(vocos_dim, out_channels)
75
+
76
+ def forward(self, x: torch.Tensor, *args):
77
+ """
78
+ Args:
79
+ x (torch.Tensor): (batch_size, input_channels, length)
80
+
81
+ Returns:
82
+ x (torch.Tensor): (batch_size, out_channels, downsampled_length)
83
+ """
84
+ x = self.encoder(x)
85
+ x = self.downsample(x)
86
+ x = self.project(x)
87
+ return x.transpose(1, 2)
88
+
89
+
90
+ # test
91
+ if __name__ == "__main__":
92
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
93
+ encoder = Encoder(
94
+ input_channels=1024,
95
+ vocos_dim=384,
96
+ vocos_intermediate_dim=2048,
97
+ vocos_num_layers=12,
98
+ out_channels=256,
99
+ sample_ratios=[2, 2],
100
+ )
101
+
102
+ output = encoder(test_input)
103
+ print(output.shape) # torch.Size([8, 256, 12])
104
+ if output.shape == torch.Size([8, 256, 12]):
105
+ print("test successful")
sparktts/modules/encoder_decoder/wave_generator.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2024 Xinsheng Wang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
16
+
17
+
18
+ import torch.nn as nn
19
+
20
+ from sparktts.modules.blocks.layers import (
21
+ Snake1d,
22
+ WNConv1d,
23
+ ResidualUnit,
24
+ WNConvTranspose1d,
25
+ init_weights,
26
+ )
27
+
28
+
29
+ class DecoderBlock(nn.Module):
30
+ def __init__(
31
+ self,
32
+ input_dim: int = 16,
33
+ output_dim: int = 8,
34
+ kernel_size: int = 2,
35
+ stride: int = 1,
36
+ ):
37
+ super().__init__()
38
+ self.block = nn.Sequential(
39
+ Snake1d(input_dim),
40
+ WNConvTranspose1d(
41
+ input_dim,
42
+ output_dim,
43
+ kernel_size=kernel_size,
44
+ stride=stride,
45
+ padding=(kernel_size - stride) // 2,
46
+ ),
47
+ ResidualUnit(output_dim, dilation=1),
48
+ ResidualUnit(output_dim, dilation=3),
49
+ ResidualUnit(output_dim, dilation=9),
50
+ )
51
+
52
+ def forward(self, x):
53
+ return self.block(x)
54
+
55
+
56
+ class WaveGenerator(nn.Module):
57
+ def __init__(
58
+ self,
59
+ input_channel,
60
+ channels,
61
+ rates,
62
+ kernel_sizes,
63
+ d_out: int = 1,
64
+ ):
65
+ super().__init__()
66
+
67
+ # Add first conv layer
68
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
69
+
70
+ # Add upsampling + MRF blocks
71
+ for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
72
+ input_dim = channels // 2**i
73
+ output_dim = channels // 2 ** (i + 1)
74
+ layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
75
+
76
+ # Add final conv layer
77
+ layers += [
78
+ Snake1d(output_dim),
79
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
80
+ nn.Tanh(),
81
+ ]
82
+
83
+ self.model = nn.Sequential(*layers)
84
+
85
+ self.apply(init_weights)
86
+
87
+ def forward(self, x):
88
+ return self.model(x)
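
A minimal shape sketch for `WaveGenerator` (the hyperparameters are illustrative, not the released model config): each `DecoderBlock` upsamples by its stride, so the overall upsampling factor is the product of `rates`.

    import torch
    from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator

    gen = WaveGenerator(input_channel=1024, channels=512, rates=[4, 4, 2], kernel_sizes=[8, 8, 4])
    z = torch.randn(2, 1024, 50)       # (batch, input_channel, frames)
    print(gen(z).shape)                # torch.Size([2, 1, 1600]) -> 4 * 4 * 2 = 32x upsampling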
sparktts/modules/fsq/finite_scalar_quantization.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3
+ Code adapted from Jax version in Appendix A.1
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from functools import wraps, partial
8
+ from contextlib import nullcontext
9
+ from typing import List, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import Module
14
+ from torch import Tensor, int32
15
+ from torch.amp import autocast
16
+
17
+ from einops import rearrange, pack, unpack
18
+
19
+ # helper functions
20
+
21
+
22
+ def exists(v):
23
+ return v is not None
24
+
25
+
26
+ def default(*args):
27
+ for arg in args:
28
+ if exists(arg):
29
+ return arg
30
+ return None
31
+
32
+
33
+ def maybe(fn):
34
+ @wraps(fn)
35
+ def inner(x, *args, **kwargs):
36
+ if not exists(x):
37
+ return x
38
+ return fn(x, *args, **kwargs)
39
+
40
+ return inner
41
+
42
+
43
+ def pack_one(t, pattern):
44
+ return pack([t], pattern)
45
+
46
+
47
+ def unpack_one(t, ps, pattern):
48
+ return unpack(t, ps, pattern)[0]
49
+
50
+
51
+ # tensor helpers
52
+
53
+
54
+ def round_ste(z: Tensor) -> Tensor:
55
+ """Round with straight through gradients."""
56
+ zhat = z.round()
57
+ return z + (zhat - z).detach()
58
+
59
+
60
+ # main class
61
+
62
+
63
+ class FSQ(Module):
64
+ def __init__(
65
+ self,
66
+ levels: List[int],
67
+ dim: int | None = None,
68
+ num_codebooks=1,
69
+ keep_num_codebooks_dim: bool | None = None,
70
+ scale: float | None = None,
71
+ allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
72
+ channel_first: bool = False,
73
+ projection_has_bias: bool = True,
74
+ return_indices=True,
75
+ force_quantization_f32=True,
76
+ ):
77
+ super().__init__()
78
+ _levels = torch.tensor(levels, dtype=int32)
79
+ self.register_buffer("_levels", _levels, persistent=False)
80
+
81
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
82
+ self.register_buffer("_basis", _basis, persistent=False)
83
+
84
+ self.scale = scale
85
+
86
+ codebook_dim = len(levels)
87
+ self.codebook_dim = codebook_dim
88
+
89
+ effective_codebook_dim = codebook_dim * num_codebooks
90
+ self.num_codebooks = num_codebooks
91
+ self.effective_codebook_dim = effective_codebook_dim
92
+
93
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
94
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
95
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
96
+
97
+ self.dim = default(dim, len(_levels) * num_codebooks)
98
+
99
+ self.channel_first = channel_first
100
+
101
+ has_projections = self.dim != effective_codebook_dim
102
+ self.project_in = (
103
+ nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
104
+ if has_projections
105
+ else nn.Identity()
106
+ )
107
+ self.project_out = (
108
+ nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
109
+ if has_projections
110
+ else nn.Identity()
111
+ )
112
+
113
+ self.has_projections = has_projections
114
+
115
+ self.return_indices = return_indices
116
+ if return_indices:
117
+ self.codebook_size = self._levels.prod().item()
118
+ implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
119
+ self.register_buffer(
120
+ "implicit_codebook", implicit_codebook, persistent=False
121
+ )
122
+
123
+ self.allowed_dtypes = allowed_dtypes
124
+ self.force_quantization_f32 = force_quantization_f32
125
+
126
+ def bound(self, z, eps: float = 1e-3):
127
+ """Bound `z`, an array of shape (..., d)."""
128
+ half_l = (self._levels - 1) * (1 + eps) / 2
129
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
130
+ shift = (offset / half_l).atanh()
131
+ return (z + shift).tanh() * half_l - offset
132
+
133
+ def quantize(self, z):
134
+ """Quantizes z, returns quantized zhat, same shape as z."""
135
+ quantized = round_ste(self.bound(z))
136
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
137
+ return quantized / half_width
138
+
139
+ def _scale_and_shift(self, zhat_normalized):
140
+ half_width = self._levels // 2
141
+ return (zhat_normalized * half_width) + half_width
142
+
143
+ def _scale_and_shift_inverse(self, zhat):
144
+ half_width = self._levels // 2
145
+ return (zhat - half_width) / half_width
146
+
147
+ def _indices_to_codes(self, indices):
148
+ level_indices = self.indices_to_level_indices(indices)
149
+ codes = self._scale_and_shift_inverse(level_indices)
150
+ return codes
151
+
152
+ def codes_to_indices(self, zhat):
153
+ """Converts a `code` to an index in the codebook."""
154
+ assert zhat.shape[-1] == self.codebook_dim
155
+ zhat = self._scale_and_shift(zhat)
156
+ return (zhat * self._basis).sum(dim=-1).to(int32)
157
+
158
+ def indices_to_level_indices(self, indices):
159
+ """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
160
+ indices = rearrange(indices, "... -> ... 1")
161
+ codes_non_centered = (indices // self._basis) % self._levels
162
+ return codes_non_centered
163
+
164
+ def indices_to_codes(self, indices):
165
+ """Inverse of `codes_to_indices`."""
166
+ assert exists(indices)
167
+
168
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
169
+
170
+ codes = self._indices_to_codes(indices)
171
+
172
+ if self.keep_num_codebooks_dim:
173
+ codes = rearrange(codes, "... c d -> ... (c d)")
174
+
175
+ codes = self.project_out(codes)
176
+
177
+ if is_img_or_video or self.channel_first:
178
+ codes = rearrange(codes, "b ... d -> b d ...")
179
+
180
+ return codes
181
+
182
+ def forward(self, z):
183
+ """
184
+ einstein notation
185
+ b - batch
186
+ n - sequence (or flattened spatial dimensions)
187
+ d - feature dimension
188
+ c - number of codebook dim
189
+ """
190
+
191
+ is_img_or_video = z.ndim >= 4
192
+ need_move_channel_last = is_img_or_video or self.channel_first
193
+
194
+ # standardize image or video into (batch, seq, dimension)
195
+
196
+ if need_move_channel_last:
197
+ z = rearrange(z, "b d ... -> b ... d")
198
+ z, ps = pack_one(z, "b * d")
199
+
200
+ assert (
201
+ z.shape[-1] == self.dim
202
+ ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
203
+
204
+ z = self.project_in(z)
205
+
206
+ z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
207
+
208
+ # whether to force quantization step to be full precision or not
209
+
210
+ force_f32 = self.force_quantization_f32
211
+ quantization_context = (
212
+ partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
213
+ )
214
+
215
+ with quantization_context():
216
+ orig_dtype = z.dtype
217
+
218
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
219
+ z = z.float()
220
+
221
+ codes = self.quantize(z)
222
+
223
+ # returning indices could be optional
224
+
225
+ indices = None
226
+
227
+ if self.return_indices:
228
+ indices = self.codes_to_indices(codes)
229
+
230
+ codes = rearrange(codes, "b n c d -> b n (c d)")
231
+
232
+ codes = codes.type(orig_dtype)
233
+
234
+ # project out
235
+
236
+ out = self.project_out(codes)
237
+
238
+ # reconstitute image or video dimensions
239
+
240
+ if need_move_channel_last:
241
+ out = unpack_one(out, ps, "b * d")
242
+ out = rearrange(out, "b ... d -> b d ...")
243
+
244
+ indices = maybe(unpack_one)(indices, ps, "b * c")
245
+
246
+ if not self.keep_num_codebooks_dim and self.return_indices:
247
+ indices = maybe(rearrange)(indices, "... 1 -> ...")
248
+
249
+ # return quantized output and indices
250
+
251
+ return out, indices
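
A minimal sketch of the FSQ module above (levels [8, 5, 5, 5] follow the example in the FSQ paper; any list of levels works): the quantizer returns the straight-through quantized codes plus integer indices into an implicit codebook of prod(levels) entries.

    import torch
    from sparktts.modules.fsq.finite_scalar_quantization import FSQ

    fsq = FSQ(levels=[8, 5, 5, 5])         # implicit codebook of 8*5*5*5 = 1000 codes
    z = torch.randn(2, 16, 4)              # (batch, seq, dim); dim defaults to len(levels)
    quantized, indices = fsq(z)
    print(quantized.shape, indices.shape)  # torch.Size([2, 16, 4]) torch.Size([2, 16])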
sparktts/modules/fsq/residual_fsq.py ADDED
@@ -0,0 +1,355 @@
1
+ import random
+ from math import ceil  # needed by round_up_multiple below
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from typing import List
7
+ from torch import nn
8
+ from torch.nn import Module
9
+ from torch.amp import autocast
10
+ from einx import get_at
11
+ from einops import rearrange, reduce, pack, unpack
12
+
13
+ from sparktts.modules.fsq.finite_scalar_quantization import FSQ
14
+
15
+
16
+ def exists(val):
17
+ return val is not None
18
+
19
+
20
+ def first(l):
21
+ return l[0]
22
+
23
+
24
+ def default(val, d):
25
+ return val if exists(val) else d
26
+
27
+
28
+ def round_up_multiple(num, mult):
29
+ return ceil(num / mult) * mult
30
+
31
+
32
+ # distributed helpers
33
+
34
+
35
+ def is_distributed():
36
+ return dist.is_initialized() and dist.get_world_size() > 1
37
+
38
+
39
+ def get_maybe_sync_seed(device, max_size=10_000):
40
+ rand_int = torch.randint(0, max_size, (), device=device)
41
+
42
+ if is_distributed():
43
+ dist.all_reduce(rand_int)
44
+
45
+ return rand_int.item()
46
+
47
+
48
+ class ResidualFSQ(Module):
49
+ """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
50
+
51
+ def __init__(
52
+ self,
53
+ *,
54
+ levels: List[int],
55
+ num_quantizers,
56
+ dim=None,
57
+ is_channel_first=False,
58
+ quantize_dropout=False,
59
+ quantize_dropout_cutoff_index=0,
60
+ quantize_dropout_multiple_of=1,
61
+ **kwargs,
62
+ ):
63
+ super().__init__()
64
+ codebook_dim = len(levels)
65
+ dim = default(dim, codebook_dim)
66
+
67
+ requires_projection = codebook_dim != dim
68
+ self.project_in = (
69
+ nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
70
+ )
71
+ self.project_out = (
72
+ nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
73
+ )
74
+ self.has_projections = requires_projection
75
+
76
+ self.is_channel_first = is_channel_first
77
+ self.num_quantizers = num_quantizers
78
+
79
+ self.levels = levels
80
+ self.layers = nn.ModuleList([])
81
+
82
+ levels_tensor = torch.Tensor(levels)
83
+
84
+ scales = []
85
+
86
+ for ind in range(num_quantizers):
87
+ scales.append((levels_tensor - 1) ** -ind)
88
+
89
+ fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
90
+
91
+ self.layers.append(fsq)
92
+
93
+ assert all([not fsq.has_projections for fsq in self.layers])
94
+
95
+ self.codebook_size = self.layers[0].codebook_size
96
+
97
+ self.register_buffer("scales", torch.stack(scales), persistent=False)
98
+
99
+ self.quantize_dropout = quantize_dropout and num_quantizers > 1
100
+
101
+ assert quantize_dropout_cutoff_index >= 0
102
+
103
+ self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
104
+ self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
105
+
106
+ @property
107
+ def codebooks(self):
108
+ codebooks = [layer.implicit_codebook for layer in self.layers]
109
+ codebooks = torch.stack(codebooks, dim=0)
110
+ return codebooks
111
+
112
+ def get_codes_from_indices(self, indices):
113
+
114
+ batch, quantize_dim = indices.shape[0], indices.shape[-1]
115
+
116
+ # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
117
+
118
+ indices, ps = pack([indices], "b * q")
119
+
120
+ # because of quantize dropout, one can pass in indices that are coarse
121
+ # and the network should be able to reconstruct
122
+
123
+ if quantize_dim < self.num_quantizers:
124
+ assert (
125
+ self.quantize_dropout > 0.0
126
+ ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
127
+ indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
128
+
129
+ # take care of quantizer dropout
130
+
131
+ mask = indices == -1
132
+ indices = indices.masked_fill(
133
+ mask, 0
134
+ ) # have it fetch a dummy code to be masked out later
135
+
136
+ all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
137
+
138
+ # mask out any codes that were dropout-ed
139
+
140
+ all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
141
+
142
+ # scale the codes
143
+
144
+ scales = rearrange(self.scales, "q d -> q 1 1 d")
145
+ all_codes = all_codes * scales
146
+
147
+ # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
148
+
149
+ (all_codes,) = unpack(all_codes, ps, "q b * d")
150
+
151
+ return all_codes
152
+
153
+ def get_output_from_indices(self, indices):
154
+ codes = self.get_codes_from_indices(indices)
155
+ codes_summed = reduce(codes, "q ... -> ...", "sum")
156
+ return self.project_out(codes_summed)
157
+
158
+ def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
159
+ num_quant, quant_dropout_multiple_of, device = (
160
+ self.num_quantizers,
161
+ self.quantize_dropout_multiple_of,
162
+ x.device,
163
+ )
164
+
165
+ # handle channel first
166
+
167
+ if self.is_channel_first:
168
+ x = rearrange(x, "b d ... -> b ... d")
169
+ x, ps = pack([x], "b * d")
170
+
171
+ # maybe project in
172
+
173
+ x = self.project_in(x)
174
+
175
+ quantized_out = 0.0
176
+ residual = x
177
+
178
+ all_indices = []
179
+
180
+ should_quantize_dropout = self.training and self.quantize_dropout
181
+
182
+ # sample a layer index at which to dropout further residual quantization
183
+ # also prepare null indices
184
+
185
+ if should_quantize_dropout:
186
+
187
+ # check if seed is manually passed in
188
+
189
+ if not exists(rand_quantize_dropout_fixed_seed):
190
+ rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
191
+
192
+ rand = random.Random(rand_quantize_dropout_fixed_seed)
193
+
194
+ rand_quantize_dropout_index = rand.randrange(
195
+ self.quantize_dropout_cutoff_index, num_quant
196
+ )
197
+
198
+ if quant_dropout_multiple_of != 1:
199
+ rand_quantize_dropout_index = (
200
+ round_up_multiple(
201
+ rand_quantize_dropout_index + 1, quant_dropout_multiple_of
202
+ )
203
+ - 1
204
+ )
205
+
206
+ null_indices = torch.full(
207
+ x.shape[:2], -1, device=device, dtype=torch.long
208
+ )
209
+
210
+ # go through the layers
211
+
212
+ with autocast("cuda", enabled=False):
213
+ for quantizer_index, (layer, scale) in enumerate(
214
+ zip(self.layers, self.scales)
215
+ ):
216
+
217
+ if (
218
+ should_quantize_dropout
219
+ and quantizer_index > rand_quantize_dropout_index
220
+ ):
221
+ all_indices.append(null_indices)
222
+ continue
223
+
224
+ quantized, indices = layer(residual / scale)
225
+
226
+ quantized = quantized * scale
227
+
228
+ residual = residual - quantized.detach()
229
+ quantized_out = quantized_out + quantized
230
+
231
+ all_indices.append(indices)
232
+
233
+ # project out, if needed
234
+
235
+ quantized_out = self.project_out(quantized_out)
236
+
237
+ # stack all indices
238
+
239
+ all_indices = torch.stack(all_indices, dim=-1)
240
+
241
+ # channel first out
242
+
243
+ if self.is_channel_first:
244
+ (quantized_out,) = unpack(quantized_out, ps, "b * d")
245
+ (all_indices,) = unpack(all_indices, ps, "b * d")
246
+
247
+ quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
248
+ all_indices = rearrange(all_indices, "b ... d -> b d ...")
249
+
250
+ # return
251
+
252
+ ret = (quantized_out, all_indices)
253
+
254
+ if not return_all_codes:
255
+ return ret
256
+
257
+ # whether to return all codes from all codebooks across layers
258
+
259
+ all_codes = self.get_codes_from_indices(all_indices)
260
+
261
+ # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
262
+
263
+ return (*ret, all_codes)
264
+
265
+
266
+ # grouped residual fsq
267
+
268
+
269
+ class GroupedResidualFSQ(Module):
270
+ def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.groups = groups
274
+ assert (dim % groups) == 0
275
+ dim_per_group = dim // groups
276
+
277
+ self.accept_image_fmap = accept_image_fmap
278
+
279
+ self.rvqs = nn.ModuleList([])
280
+
281
+ for _ in range(groups):
282
+ self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
283
+
284
+ self.codebook_size = self.rvqs[0].codebook_size
285
+
286
+ @property
287
+ def codebooks(self):
288
+ return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
289
+
290
+ @property
291
+ def split_dim(self):
292
+ return 1 if self.accept_image_fmap else -1
293
+
294
+ def get_codes_from_indices(self, indices):
295
+ codes = tuple(
296
+ rvq.get_codes_from_indices(chunk_indices)
297
+ for rvq, chunk_indices in zip(self.rvqs, indices)
298
+ )
299
+ return torch.stack(codes)
300
+
301
+ def get_output_from_indices(self, indices):
302
+ outputs = tuple(
303
+ rvq.get_output_from_indices(chunk_indices)
304
+ for rvq, chunk_indices in zip(self.rvqs, indices)
305
+ )
306
+ return torch.cat(outputs, dim=self.split_dim)
307
+
308
+ def forward(self, x, return_all_codes=False):
309
+ shape, split_dim, device = x.shape, self.split_dim, x.device
310
+ assert shape[split_dim] == self.dim
311
+
312
+ # split the feature dimension into groups
313
+
314
+ x = x.chunk(self.groups, dim=split_dim)
315
+
316
+ forward_kwargs = dict(
317
+ return_all_codes=return_all_codes,
318
+ rand_quantize_dropout_fixed_seed=(
319
+ get_maybe_sync_seed(device) if self.training else None
320
+ ),
321
+ )
322
+
323
+ # invoke residual vq on each group
324
+
325
+ out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
326
+ out = tuple(zip(*out))
327
+
328
+ # get all the zipped outputs and combine them
329
+
330
+ quantized, all_indices, *maybe_all_codes = out
331
+
332
+ quantized = torch.cat(quantized, dim=split_dim)
333
+ all_indices = torch.stack(all_indices)
334
+
335
+ ret = (quantized, all_indices, *maybe_all_codes)
336
+ return ret
337
+
338
+
339
+ if __name__ == "__main__":
340
+ model = ResidualFSQ(
341
+ levels=[4, 4, 4, 4, 4, 4],
342
+ num_quantizers=1,
343
+ dim=30,
344
+ is_channel_first=True,
345
+ quantize_dropout=False,
346
+ )
347
+ x = torch.randn(2, 30, 10)
348
+ quantize, embed_ind = model(x)
349
+
350
+ emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
351
+
352
+ print(quantize == emb_from_ind.transpose(1, 2))
353
+
354
+ print("quantize shape", quantize.shape)
355
+ print("embed_ind", embed_ind)
sparktts/modules/speaker/ecapa_tdnn.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) 2021 Zhengyang Chen ([email protected])
2
+ # 2022 Hongji Wang ([email protected])
3
+ # 2023 Bing Han ([email protected])
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ This implementation is adapted from github repo:
18
+ https://github.com/lawlict/ECAPA-TDNN.
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ import sparktts.modules.speaker.pooling_layers as pooling_layers
26
+
27
+
28
+ class Res2Conv1dReluBn(nn.Module):
29
+ """
30
+ in_channels == out_channels == channels
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ channels,
36
+ kernel_size=1,
37
+ stride=1,
38
+ padding=0,
39
+ dilation=1,
40
+ bias=True,
41
+ scale=4,
42
+ ):
43
+ super().__init__()
44
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
45
+ self.scale = scale
46
+ self.width = channels // scale
47
+ self.nums = scale if scale == 1 else scale - 1
48
+
49
+ self.convs = []
50
+ self.bns = []
51
+ for i in range(self.nums):
52
+ self.convs.append(
53
+ nn.Conv1d(
54
+ self.width,
55
+ self.width,
56
+ kernel_size,
57
+ stride,
58
+ padding,
59
+ dilation,
60
+ bias=bias,
61
+ )
62
+ )
63
+ self.bns.append(nn.BatchNorm1d(self.width))
64
+ self.convs = nn.ModuleList(self.convs)
65
+ self.bns = nn.ModuleList(self.bns)
66
+
67
+ def forward(self, x):
68
+ out = []
69
+ spx = torch.split(x, self.width, 1)
70
+ sp = spx[0]
71
+ for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
72
+ # Order: conv -> relu -> bn
73
+ if i >= 1:
74
+ sp = sp + spx[i]
75
+ sp = conv(sp)
76
+ sp = bn(F.relu(sp))
77
+ out.append(sp)
78
+ if self.scale != 1:
79
+ out.append(spx[self.nums])
80
+ out = torch.cat(out, dim=1)
81
+
82
+ return out
83
+
84
+
85
+ """ Conv1d + BatchNorm1d + ReLU
86
+ """
87
+
88
+
89
+ class Conv1dReluBn(nn.Module):
90
+
91
+ def __init__(
92
+ self,
93
+ in_channels,
94
+ out_channels,
95
+ kernel_size=1,
96
+ stride=1,
97
+ padding=0,
98
+ dilation=1,
99
+ bias=True,
100
+ ):
101
+ super().__init__()
102
+ self.conv = nn.Conv1d(
103
+ in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
104
+ )
105
+ self.bn = nn.BatchNorm1d(out_channels)
106
+
107
+ def forward(self, x):
108
+ return self.bn(F.relu(self.conv(x)))
109
+
110
+
111
+ """ The SE connection of 1D case.
112
+ """
113
+
114
+
115
+ class SE_Connect(nn.Module):
116
+
117
+ def __init__(self, channels, se_bottleneck_dim=128):
118
+ super().__init__()
119
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
120
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
121
+
122
+ def forward(self, x):
123
+ out = x.mean(dim=2)
124
+ out = F.relu(self.linear1(out))
125
+ out = torch.sigmoid(self.linear2(out))
126
+ out = x * out.unsqueeze(2)
127
+
128
+ return out
129
+
130
+
131
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
132
+ """
133
+
134
+
135
+ class SE_Res2Block(nn.Module):
136
+
137
+ def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
138
+ super().__init__()
139
+ self.se_res2block = nn.Sequential(
140
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
141
+ Res2Conv1dReluBn(
142
+ channels, kernel_size, stride, padding, dilation, scale=scale
143
+ ),
144
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
145
+ SE_Connect(channels),
146
+ )
147
+
148
+ def forward(self, x):
149
+ return x + self.se_res2block(x)
150
+
151
+
152
+ class ECAPA_TDNN(nn.Module):
153
+
154
+ def __init__(
155
+ self,
156
+ channels=512,
157
+ feat_dim=80,
158
+ embed_dim=192,
159
+ pooling_func="ASTP",
160
+ global_context_att=False,
161
+ emb_bn=False,
162
+ ):
163
+ super().__init__()
164
+
165
+ self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
166
+ self.layer2 = SE_Res2Block(
167
+ channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
168
+ )
169
+ self.layer3 = SE_Res2Block(
170
+ channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
171
+ )
172
+ self.layer4 = SE_Res2Block(
173
+ channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
174
+ )
175
+
176
+ cat_channels = channels * 3
177
+ out_channels = 512 * 3
178
+ self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
179
+ self.pool = getattr(pooling_layers, pooling_func)(
180
+ in_dim=out_channels, global_context_att=global_context_att
181
+ )
182
+ self.pool_out_dim = self.pool.get_out_dim()
183
+ self.bn = nn.BatchNorm1d(self.pool_out_dim)
184
+ self.linear = nn.Linear(self.pool_out_dim, embed_dim)
185
+ self.emb_bn = emb_bn
186
+ if emb_bn: # better in SSL for SV
187
+ self.bn2 = nn.BatchNorm1d(embed_dim)
188
+ else:
189
+ self.bn2 = nn.Identity()
190
+
191
+ def forward(self, x, return_latent=False):
192
+ x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
193
+
194
+ out1 = self.layer1(x)
195
+ out2 = self.layer2(out1)
196
+ out3 = self.layer3(out2)
197
+ out4 = self.layer4(out3)
198
+
199
+ out = torch.cat([out2, out3, out4], dim=1)
200
+ latent = F.relu(self.conv(out))
201
+ out = self.bn(self.pool(latent))
202
+ out = self.linear(out)
203
+ if self.emb_bn:
204
+ out = self.bn2(out)
205
+
206
+ if return_latent:
207
+ return out, latent
208
+ return out
209
+
210
+
211
+ def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
212
+ return ECAPA_TDNN(
213
+ channels=1024,
214
+ feat_dim=feat_dim,
215
+ embed_dim=embed_dim,
216
+ pooling_func=pooling_func,
217
+ emb_bn=emb_bn,
218
+ )
219
+
220
+
221
+ def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
222
+ return ECAPA_TDNN(
223
+ channels=1024,
224
+ feat_dim=feat_dim,
225
+ embed_dim=embed_dim,
226
+ pooling_func=pooling_func,
227
+ global_context_att=True,
228
+ emb_bn=emb_bn,
229
+ )
230
+
231
+
232
+ def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
233
+ return ECAPA_TDNN(
234
+ channels=512,
235
+ feat_dim=feat_dim,
236
+ embed_dim=embed_dim,
237
+ pooling_func=pooling_func,
238
+ emb_bn=emb_bn,
239
+ )
240
+
241
+
242
+ def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
243
+ return ECAPA_TDNN(
244
+ channels=512,
245
+ feat_dim=feat_dim,
246
+ embed_dim=embed_dim,
247
+ pooling_func=pooling_func,
248
+ global_context_att=True,
249
+ emb_bn=emb_bn,
250
+ )
251
+
252
+
253
+ if __name__ == "__main__":
254
+ x = torch.zeros(1, 200, 100)
255
+ model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
256
+ model.eval()
257
+ out, latent = model(x, True)
258
+ print(out.shape)
259
+ print(latent.shape)
260
+
261
+ num_params = sum(param.numel() for param in model.parameters())
262
+ print("{} M".format(num_params / 1e6))
263
+
264
+ # from thop import profile
265
+ # x_np = torch.randn(1, 200, 80)
266
+ # flops, params = profile(model, inputs=(x_np, ))
267
+ # print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
sparktts/modules/speaker/perceiver_encoder.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
17
+
18
+ from collections import namedtuple
19
+ from functools import wraps
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange, repeat
24
+ from einops.layers.torch import Rearrange
25
+ from packaging import version
26
+ from torch import einsum, nn
27
+
28
+
29
+ def exists(val):
30
+ return val is not None
31
+
32
+
33
+ def once(fn):
34
+ called = False
35
+
36
+ @wraps(fn)
37
+ def inner(x):
38
+ nonlocal called
39
+ if called:
40
+ return
41
+ called = True
42
+ return fn(x)
43
+
44
+ return inner
45
+
46
+
47
+ print_once = once(print)
48
+
49
+ # main class
50
+
51
+
52
+ class Attend(nn.Module):
53
+ def __init__(self, dropout=0.0, causal=False, use_flash=False):
54
+ super().__init__()
55
+ self.dropout = dropout
56
+ self.attn_dropout = nn.Dropout(dropout)
57
+
58
+ self.causal = causal
59
+ self.register_buffer("mask", None, persistent=False)
60
+
61
+ self.use_flash = use_flash
62
+ assert not (
63
+ use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
64
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
65
+
66
+ # determine efficient attention configs for cuda and cpu
67
+ self.config = namedtuple(
68
+ "EfficientAttentionConfig",
69
+ ["enable_flash", "enable_math", "enable_mem_efficient"],
70
+ )
71
+ self.cpu_config = self.config(True, True, True)
72
+ self.cuda_config = None
73
+
74
+ if not torch.cuda.is_available() or not use_flash:
75
+ return
76
+
77
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
78
+
79
+ if device_properties.major == 8 and device_properties.minor == 0:
80
+ print_once(
81
+ "A100 GPU detected, using flash attention if input tensor is on cuda"
82
+ )
83
+ self.cuda_config = self.config(True, False, False)
84
+ else:
85
+ print_once(
86
+ "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
87
+ )
88
+ self.cuda_config = self.config(False, True, True)
89
+
90
+ def get_mask(self, n, device):
91
+ if exists(self.mask) and self.mask.shape[-1] >= n:
92
+ return self.mask[:n, :n]
93
+
94
+ mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
95
+ self.register_buffer("mask", mask, persistent=False)
96
+ return mask
97
+
98
+ def flash_attn(self, q, k, v, mask=None):
99
+ _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
100
+
101
+ # Recommended for multi-query single-key-value attention by Tri Dao
102
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
103
+
104
+ if k.ndim == 3:
105
+ k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
106
+
107
+ if v.ndim == 3:
108
+ v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
109
+
110
+ # Check if mask exists and expand to compatible shape
111
+ # The mask is B L, so it would have to be expanded to B H N L
112
+
113
+ if exists(mask):
114
+ mask = rearrange(mask, "b j -> b 1 1 j")
115
+ mask = mask.expand(-1, heads, q_len, -1)
116
+
117
+ # Check if there is a compatible device for flash attention
118
+
119
+ config = self.cuda_config if is_cuda else self.cpu_config
120
+
121
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
122
+
123
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
124
+ out = F.scaled_dot_product_attention(
125
+ q,
126
+ k,
127
+ v,
128
+ attn_mask=mask,
129
+ dropout_p=self.dropout if self.training else 0.0,
130
+ is_causal=self.causal,
131
+ )
132
+
133
+ return out
134
+
135
+ def forward(self, q, k, v, mask=None):
136
+ """
137
+ einstein notation
138
+ b - batch
139
+ h - heads
140
+ n, i, j - sequence length (base sequence length, source, target)
141
+ d - feature dimension
142
+ """
143
+
144
+ n, device = q.shape[-2], q.device
145
+
146
+ scale = q.shape[-1] ** -0.5
147
+
148
+ if self.use_flash:
149
+ return self.flash_attn(q, k, v, mask=mask)
150
+
151
+ kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
152
+
153
+ # similarity
154
+
155
+ sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
156
+
157
+ # key padding mask
158
+
159
+ if exists(mask):
160
+ mask = rearrange(mask, "b j -> b 1 1 j")
161
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
162
+
163
+ # causal mask
164
+
165
+ if self.causal:
166
+ causal_mask = self.get_mask(n, device)
167
+ sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
168
+
169
+ # attention
170
+
171
+ attn = sim.softmax(dim=-1)
172
+ attn = self.attn_dropout(attn)
173
+
174
+ # aggregate values
175
+
176
+ out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
177
+
178
+ return out
179
+
180
+
181
+ def Sequential(*mods):
182
+ return nn.Sequential(*filter(exists, mods))
183
+
184
+
185
+ def exists(x):
186
+ return x is not None
187
+
188
+
189
+ def default(val, d):
190
+ if exists(val):
191
+ return val
192
+ return d() if callable(d) else d
193
+
194
+
195
+ class RMSNorm(nn.Module):
196
+ def __init__(self, dim, scale=True, dim_cond=None):
197
+ super().__init__()
198
+ self.cond = exists(dim_cond)
199
+ self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
200
+
201
+ self.scale = dim**0.5
202
+ self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
203
+
204
+ def forward(self, x, cond=None):
205
+ gamma = default(self.gamma, 1)
206
+ out = F.normalize(x, dim=-1) * self.scale * gamma
207
+
208
+ if not self.cond:
209
+ return out
210
+
211
+ assert exists(cond)
212
+ gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
213
+ gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
214
+ return out * gamma + beta
215
+
216
+
217
+ class CausalConv1d(nn.Conv1d):
218
+ def __init__(self, *args, **kwargs):
219
+ super().__init__(*args, **kwargs)
220
+ (kernel_size,) = self.kernel_size
221
+ (dilation,) = self.dilation
222
+ (stride,) = self.stride
223
+
224
+ assert stride == 1
225
+ self.causal_padding = dilation * (kernel_size - 1)
226
+
227
+ def forward(self, x):
228
+ causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
229
+ return super().forward(causal_padded_x)
230
+
231
+
232
+ class GEGLU(nn.Module):
233
+ def forward(self, x):
234
+ x, gate = x.chunk(2, dim=-1)
235
+ return F.gelu(gate) * x
236
+
237
+
238
+ def FeedForward(dim, mult=4, causal_conv=False):
239
+ dim_inner = int(dim * mult * 2 / 3)
240
+
241
+ conv = None
242
+ if causal_conv:
243
+ conv = nn.Sequential(
244
+ Rearrange("b n d -> b d n"),
245
+ CausalConv1d(dim_inner, dim_inner, 3),
246
+ Rearrange("b d n -> b n d"),
247
+ )
248
+
249
+ return Sequential(
250
+ nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
251
+ )
252
+
253
+
254
+ class Attention(nn.Module):
255
+ def __init__(
256
+ self,
257
+ dim,
258
+ *,
259
+ dim_context=None,
260
+ causal=False,
261
+ dim_head=64,
262
+ heads=8,
263
+ dropout=0.0,
264
+ use_flash=False,
265
+ cross_attn_include_queries=False,
266
+ ):
267
+ super().__init__()
268
+ self.scale = dim_head**-0.5
269
+ self.heads = heads
270
+ self.cross_attn_include_queries = cross_attn_include_queries
271
+
272
+ dim_inner = dim_head * heads
273
+ dim_context = default(dim_context, dim)
274
+
275
+ self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
276
+ self.to_q = nn.Linear(dim, dim_inner, bias=False)
277
+ self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
278
+ self.to_out = nn.Linear(dim_inner, dim, bias=False)
279
+
280
+ def forward(self, x, context=None, mask=None):
281
+ h, has_context = self.heads, exists(context)
282
+
283
+ context = default(context, x)
284
+
285
+ if has_context and self.cross_attn_include_queries:
286
+ context = torch.cat((x, context), dim=-2)
287
+
288
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
289
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
290
+
291
+ out = self.attend(q, k, v, mask=mask)
292
+
293
+ out = rearrange(out, "b h n d -> b n (h d)")
294
+ return self.to_out(out)
295
+
296
+
297
+ class PerceiverResampler(nn.Module):
298
+ def __init__(
299
+ self,
300
+ *,
301
+ dim,
302
+ depth=2,
303
+ dim_context=None,
304
+ num_latents=32,
305
+ dim_head=64,
306
+ heads=8,
307
+ ff_mult=4,
308
+ use_flash_attn=False,
309
+ ):
310
+ super().__init__()
311
+ dim_context = default(dim_context, dim)
312
+
313
+ self.proj_context = (
314
+ nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
315
+ )
316
+
317
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
318
+ nn.init.normal_(self.latents, std=0.02)
319
+
320
+ self.layers = nn.ModuleList([])
321
+ for _ in range(depth):
322
+ self.layers.append(
323
+ nn.ModuleList(
324
+ [
325
+ Attention(
326
+ dim=dim,
327
+ dim_head=dim_head,
328
+ heads=heads,
329
+ use_flash=use_flash_attn,
330
+ cross_attn_include_queries=True,
331
+ ),
332
+ FeedForward(dim=dim, mult=ff_mult),
333
+ ]
334
+ )
335
+ )
336
+
337
+ self.norm = RMSNorm(dim)
338
+
339
+ def forward(self, x, mask=None):
340
+ batch = x.shape[0]
341
+
342
+ x = self.proj_context(x)
343
+
344
+ latents = repeat(self.latents, "n d -> b n d", b=batch)
345
+
346
+ for attn, ff in self.layers:
347
+ latents = attn(latents, x, mask=mask) + latents
348
+ latents = ff(latents) + latents
349
+
350
+ return self.norm(latents)
351
+
352
+
353
+ if __name__ == "__main__":
354
+ model = PerceiverResampler(dim=256, dim_context=80)
355
+ x = torch.randn(8, 200, 80)
356
+ out = model(x)
357
+ print(out.shape) # [8, 32, 256]
358
+
359
+ num_params = sum(param.numel() for param in model.parameters())
360
+ print("{} M".format(num_params / 1e6))
sparktts/modules/speaker/pooling_layers.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright (c) 2021 Shuai Wang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Pooling functions to aggregate frame-level deep features
16
+ into segment-level speaker embeddings
17
+
18
+ High-order statistics are surprisingly effective: on VoxCeleb, TSDP performs similarly to TSTP
19
+ even though the mean statistic is removed.
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+
27
+ class TAP(nn.Module):
28
+ """
29
+ Temporal average pooling, only first-order mean is considered
30
+ """
31
+
32
+ def __init__(self, in_dim=0, **kwargs):
33
+ super(TAP, self).__init__()
34
+ self.in_dim = in_dim
35
+
36
+ def forward(self, x):
37
+ pooling_mean = x.mean(dim=-1)
38
+ # To be compatible with 2D input
39
+ pooling_mean = pooling_mean.flatten(start_dim=1)
40
+ return pooling_mean
41
+
42
+ def get_out_dim(self):
43
+ self.out_dim = self.in_dim
44
+ return self.out_dim
45
+
46
+
47
+ class TSDP(nn.Module):
48
+ """
49
+ Temporal standard deviation pooling, only second-order std is considered
50
+ """
51
+
52
+ def __init__(self, in_dim=0, **kwargs):
53
+ super(TSDP, self).__init__()
54
+ self.in_dim = in_dim
55
+
56
+ def forward(self, x):
57
+ # The last dimension is the temporal axis
58
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
59
+ pooling_std = pooling_std.flatten(start_dim=1)
60
+ return pooling_std
61
+
62
+ def get_out_dim(self):
63
+ self.out_dim = self.in_dim
64
+ return self.out_dim
65
+
66
+
67
+ class TSTP(nn.Module):
68
+ """
69
+ Temporal statistics pooling, concatenate mean and std, which is used in
70
+ x-vector
71
+ Comment: simple concatenation can not make full use of both statistics
72
+ """
73
+
74
+ def __init__(self, in_dim=0, **kwargs):
75
+ super(TSTP, self).__init__()
76
+ self.in_dim = in_dim
77
+
78
+ def forward(self, x):
79
+ # The last dimension is the temporal axis
80
+ pooling_mean = x.mean(dim=-1)
81
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
82
+ pooling_mean = pooling_mean.flatten(start_dim=1)
83
+ pooling_std = pooling_std.flatten(start_dim=1)
84
+ stats = torch.cat((pooling_mean, pooling_std), 1)
85
+ return stats
86
+
87
+ def get_out_dim(self):
88
+ self.out_dim = self.in_dim * 2
89
+ return self.out_dim
90
+
91
+
92
+ class ASTP(nn.Module):
93
+ """ Attentive statistics pooling: Channel- and context-dependent
94
+ statistics pooling, first used in ECAPA_TDNN.
95
+ """
96
+
97
+ def __init__(self,
98
+ in_dim,
99
+ bottleneck_dim=128,
100
+ global_context_att=False,
101
+ **kwargs):
102
+ super(ASTP, self).__init__()
103
+ self.in_dim = in_dim
104
+ self.global_context_att = global_context_att
105
+
106
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
107
+ # need to transpose inputs.
108
+ if global_context_att:
109
+ self.linear1 = nn.Conv1d(
110
+ in_dim * 3, bottleneck_dim,
111
+ kernel_size=1) # equals W and b in the paper
112
+ else:
113
+ self.linear1 = nn.Conv1d(
114
+ in_dim, bottleneck_dim,
115
+ kernel_size=1) # equals W and b in the paper
116
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
117
+ kernel_size=1) # equals V and k in the paper
118
+
119
+ def forward(self, x):
120
+ """
121
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
122
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
123
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
124
+ """
125
+ if len(x.shape) == 4:
126
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
127
+ assert len(x.shape) == 3
128
+
129
+ if self.global_context_att:
130
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
131
+ context_std = torch.sqrt(
132
+ torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
133
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
134
+ else:
135
+ x_in = x
136
+
137
+ # DON'T use ReLU here! Training may fail to converge with ReLU.
138
+ alpha = torch.tanh(
139
+ self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
140
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
141
+ mean = torch.sum(alpha * x, dim=2)
142
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
143
+ std = torch.sqrt(var.clamp(min=1e-7))
144
+ return torch.cat([mean, std], dim=1)
145
+
146
+ def get_out_dim(self):
147
+ self.out_dim = 2 * self.in_dim
148
+ return self.out_dim
149
+
150
+
151
+ class MHASTP(torch.nn.Module):
152
+ """ Multi head attentive statistics pooling
153
+ Reference:
154
+ Self Multi-Head Attention for Speaker Recognition
155
+ https://arxiv.org/pdf/1906.09890.pdf
156
+ """
157
+
158
+ def __init__(self,
159
+ in_dim,
160
+ layer_num=2,
161
+ head_num=2,
162
+ d_s=1,
163
+ bottleneck_dim=64,
164
+ **kwargs):
165
+ super(MHASTP, self).__init__()
166
+ assert (in_dim % head_num
167
+ ) == 0 # make sure in_dim is divisible by head_num
168
+ self.in_dim = in_dim
169
+ self.head_num = head_num
170
+ d_model = int(in_dim / head_num)
171
+ channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
172
+ if d_s > 1:
173
+ d_s = d_model
174
+ else:
175
+ d_s = 1
176
+ self.d_s = d_s
177
+ channel_dims[0], channel_dims[-1] = d_model, d_s
178
+ heads_att_trans = []
179
+ for i in range(self.head_num):
180
+ att_trans = nn.Sequential()
181
+ for i in range(layer_num - 1):
182
+ att_trans.add_module(
183
+ 'att_' + str(i),
184
+ nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
185
+ att_trans.add_module('tanh' + str(i), nn.Tanh())
186
+ att_trans.add_module(
187
+ 'att_' + str(layer_num - 1),
188
+ nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
189
+ 1, 1))
190
+ heads_att_trans.append(att_trans)
191
+ self.heads_att_trans = nn.ModuleList(heads_att_trans)
192
+
193
+ def forward(self, input):
194
+ """
195
+ input: a 3-dimensional tensor in xvector architecture
196
+ or a 4-dimensional tensor in resnet architecture
197
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
198
+ """
199
+ if len(input.shape) == 4: # B x F x T
200
+ input = input.reshape(input.shape[0],
201
+ input.shape[1] * input.shape[2],
202
+ input.shape[3])
203
+ assert len(input.shape) == 3
204
+ bs, f_dim, t_dim = input.shape
205
+ chunks = torch.chunk(input, self.head_num, 1)
206
+ # split
207
+ chunks_out = []
208
+ # for i in range(self.head_num):
209
+ # att_score = self.heads_att_trans[i](chunks[i])
210
+ for i, layer in enumerate(self.heads_att_trans):
211
+ att_score = layer(chunks[i])
212
+ alpha = F.softmax(att_score, dim=-1)
213
+ mean = torch.sum(alpha * chunks[i], dim=2)
214
+ var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
215
+ std = torch.sqrt(var.clamp(min=1e-7))
216
+ chunks_out.append(torch.cat((mean, std), dim=1))
217
+ out = torch.cat(chunks_out, dim=1)
218
+ return out
219
+
220
+ def get_out_dim(self):
221
+ self.out_dim = 2 * self.in_dim
222
+ return self.out_dim
223
+
224
+
225
+ class MQMHASTP(torch.nn.Module):
226
+ """ An attentive pooling
227
+ Reference:
228
+ multi query multi head attentive statistics pooling
229
+ https://arxiv.org/pdf/2110.05042.pdf
230
+ Args:
231
+ in_dim: the feature dimension of input
232
+ layer_num: the number of layer in the pooling layer
233
+ query_num: the number of querys
234
+ head_num: the number of heads
235
+ bottleneck_dim: the bottleneck dimension
236
+
237
+ SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
238
+ https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
239
+ MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
240
+ https://arxiv.org/pdf/1906.09890.pdf
241
+ AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
242
+ https://arxiv.org/pdf/1803.10963.pdf
243
+ VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
244
+ http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
245
+ """
246
+
247
+ def __init__(self,
248
+ in_dim,
249
+ layer_num=2,
250
+ query_num=2,
251
+ head_num=8,
252
+ d_s=2,
253
+ bottleneck_dim=64,
254
+ **kwargs):
255
+ super(MQMHASTP, self).__init__()
256
+ self.n_query = nn.ModuleList([
257
+ MHASTP(in_dim,
258
+ layer_num=layer_num,
259
+ head_num=head_num,
260
+ d_s=d_s,
261
+ bottleneck_dim=bottleneck_dim) for i in range(query_num)
262
+ ])
263
+ self.query_num = query_num
264
+ self.in_dim = in_dim
265
+
266
+ def forward(self, input):
267
+ """
268
+ input: a 3-dimensional tensor in xvector architecture
269
+ or a 4-dimensional tensor in resnet architecture
270
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
271
+ """
272
+ if len(input.shape) == 4: # B x F x T
273
+ input = input.reshape(input.shape[0],
274
+ input.shape[1] * input.shape[2],
275
+ input.shape[3])
276
+ assert len(input.shape) == 3
277
+ res = []
278
+ for i, layer in enumerate(self.n_query):
279
+ res.append(layer(input))
280
+ out = torch.cat(res, dim=-1)
281
+ return out
282
+
283
+ def get_out_dim(self):
284
+ self.out_dim = self.in_dim * 2 * self.query_num
285
+ return self.out_dim
286
+
287
+
288
+ if __name__ == '__main__':
289
+ data = torch.randn(16, 512, 10, 35)
290
+ # model = StatisticsPooling()
291
+ model = MQMHASTP(512 * 10)
292
+ model = MHASTP(512 * 10)
293
+ model = MQMHASTP(512 * 10, context=False)
294
+ print(model)
295
+
296
+ out = model(data)
297
+ print(out.shape)
298
+ print(model.get_out_dim())
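The __main__ check above only covers the multi-head variants; ASTP with global context, which is the pooling layer ECAPA_TDNN (in ecapa_tdnn.py above) actually instantiates via getattr, can be sketched as follows. The input shape (4, 1536, 200) is an illustrative assumption matching the 512 * 3 channel latent of that model.
import torch
import sparktts.modules.speaker.pooling_layers as pooling_layers

feats = torch.randn(4, 1536, 200)   # (B, F, T) frame-level features, e.g. the ECAPA latent
pool = getattr(pooling_layers, "ASTP")(in_dim=1536, global_context_att=True)
stats = pool(feats)                 # attentive mean and std, concatenated along the feature dim
print(stats.shape)                  # (4, 3072)
print(pool.get_out_dim())           # 2 * in_dim = 3072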
sparktts/modules/speaker/speaker_encoder.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from typing import List, Tuple
20
+ from sparktts.modules.fsq.residual_fsq import ResidualFSQ
21
+ from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
22
+ from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
23
+
24
+ """
25
+ x-vector + d-vector
26
+ """
27
+
28
+
29
+ class SpeakerEncoder(nn.Module):
30
+ """
31
+
32
+ Args:
33
+ input_dim (int): acoustic feature dimension
34
+ out_dim (int): output dimension of x-vector and d-vector
35
+ latent_dim (int): latent dimension before quantization
36
+ token_num (int): sequence length of speaker tokens
37
+ fsq_levels (List[int]): number of levels for each quantizer
38
+ fsq_num_quantizers (int): number of quantizers
39
+
40
+ Return:
41
+ speaker_embs: (B, T2, out_dim)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ input_dim: int = 100,
47
+ out_dim: int = 512,
48
+ latent_dim: int = 128,
49
+ token_num: int = 32,
50
+ fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
51
+ fsq_num_quantizers: int = 1,
52
+ ):
53
+ super(SpeakerEncoder, self).__init__()
54
+
55
+ self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
56
+ feat_dim=input_dim, embed_dim=out_dim
57
+ )
58
+ self.perceiver_sampler = PerceiverResampler(
59
+ dim=latent_dim, dim_context=512 * 3, num_latents=token_num
60
+ )
61
+ self.quantizer = ResidualFSQ(
62
+ levels=fsq_levels,
63
+ num_quantizers=fsq_num_quantizers,
64
+ dim=latent_dim,
65
+ is_channel_first=True,
66
+ quantize_dropout=False,
67
+ )
68
+
69
+ self.project = nn.Linear(latent_dim * token_num, out_dim)
70
+
71
+ def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
72
+ zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
73
+ return zq.transpose(1, 2)
74
+
75
+ def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
76
+ mels = mels.transpose(1, 2)
77
+ x = self.perceiver_sampler(mels).transpose(1, 2)
78
+ zq, indices = self.quantizer(x)
79
+ return indices
80
+
81
+ def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
82
+ """
83
+ Args:
84
+ mels: (B, D_mel, T1)
85
+
86
+ Return:
87
+ x_vector: (B, out_dim)
88
+ d_vector: (B, out_dim)
89
+ """
90
+ # mels = mels.transpose(1,2)
91
+
92
+ x_vector, features = self.speaker_encoder(mels, True)
93
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
94
+ zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2)
95
+ x = zq.reshape(zq.shape[0], -1)
96
+ d_vector = self.project(x)
97
+
98
+ return x_vector, d_vector
99
+
100
+ def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
101
+ """tokenize the input mel spectrogram"""
102
+ _, features = self.speaker_encoder(mels, True)
103
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
104
+ zq, indices = self.quantizer(x)
105
+ return indices
106
+
107
+ def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
108
+ """detokenize the input indices to d-vector"""
109
+ zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
110
+ x = zq.reshape(zq.shape[0], -1)
111
+ d_vector = self.project(x)
112
+ return d_vector
113
+
114
+ if __name__ == "__main__":
115
+ model = SpeakerEncoder(
116
+ input_dim=100,
117
+ latent_dim=128,
118
+ token_num=32,
119
+ fsq_levels=[4, 4, 4, 4, 4, 4],
120
+ fsq_num_quantizers=1,
121
+ )
122
+ mel = torch.randn(8, 200, 100)
123
+ x_vector, d_vector = model(mel)
124
+ print("x-vector shape", x_vector.shape)
125
+ print("d-vector shape", d_vector.shape)
126
+
127
+ indices = model.tokenize(mel)
128
+ print("indices shape", indices.shape)
129
+ d_vector_post = model.detokenize(indices)
130
+ print("d-vector shape", d_vector_post.shape)
131
+ if torch.allclose(d_vector_post, d_vector):
132
+ print("d-vector post and d-vector are the same")
133
+ else:
134
+ print("d-vector post and d-vector are different")
135
+ num_params = sum(param.numel() for param in model.parameters())
136
+ print("{} M".format(num_params / 1e6))
sparktts/modules/vq/factorized_vector_quantize.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
17
+
18
+
19
+ from typing import Any, Dict
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from einops import rearrange
25
+ from torch.nn.utils import weight_norm
26
+
27
+
28
+ def WNConv1d(*args, **kwargs):
29
+ return weight_norm(nn.Conv1d(*args, **kwargs))
30
+
31
+
32
+ def ema_inplace(moving_avg, new, decay):
33
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
34
+
35
+
36
+ class FactorizedVectorQuantize(nn.Module):
37
+ def __init__(
38
+ self,
39
+ input_dim: int,
40
+ codebook_size: int,
41
+ codebook_dim: int,
42
+ commitment: float,
43
+ codebook_loss_weight: float = 1.0,
44
+ decay: float = 0.99,
45
+ threshold_ema_dead_code: float = 2,
46
+ momentum: float = 0.99,
47
+ **kwargs,
48
+ ):
49
+ super().__init__()
50
+ self.input_dim = input_dim
51
+ self.codebook_size = codebook_size
52
+ self.codebook_dim = codebook_dim
53
+ self.commitment = commitment
54
+ self.codebook_loss_weight = codebook_loss_weight
55
+ self.decay = decay
56
+ self.threshold_ema_dead_code = threshold_ema_dead_code
57
+ self.momentum = momentum
58
+
59
+ if input_dim != self.codebook_dim:
60
+ self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
61
+ self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
62
+
63
+ else:
64
+ self.in_project = nn.Identity()
65
+ self.out_project = nn.Identity()
66
+
67
+ self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
68
+ self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
69
+
70
+ def forward(self, z: torch.Tensor) -> Dict[str, Any]:
71
+ """Quantized the input tensor using a fixed codebook and returns
72
+ the corresponding codebook vectors
73
+
74
+ Parameters
75
+ ----------
76
+ z : Tensor[B x D x T]
77
+
78
+ Returns
79
+ -------
80
+ Tensor[B x D x T]
81
+ Quantized continuous representation of input
82
+ Tensor[1]
83
+ Commitment loss to train encoder to predict vectors closer to codebook
84
+ entries
85
+ Tensor[1]
86
+ Codebook loss to update the codebook
87
+ Tensor[B x T]
88
+ Codebook indices (quantized discrete representation of input)
89
+ Tensor[B x D x T]
90
+ Projected latents (continuous representation of input before quantization)
91
+ """
92
+ # transpose since we use linear
93
+
94
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
95
+ z_e = self.in_project(z)
96
+ z_q, indices, dists = self.decode_latents(z_e)
97
+
98
+ # track codebook usage statistics
99
+ embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
100
+ avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
101
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
102
+
103
+ active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
104
+ if self.training:
105
+ # We do the expiry of code at that point as buffers are in sync
106
+ # and all the workers will take the same decision.
107
+ ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
108
+ active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
109
+
110
+ if self.training:
111
+ commit_loss = (
112
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
113
+ * self.commitment
114
+ )
115
+
116
+ codebook_loss = (
117
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
118
+ * self.codebook_loss_weight
119
+ )
120
+
121
+ else:
122
+ commit_loss = torch.zeros(0, device=z.device)
123
+ codebook_loss = torch.zeros(0, device=z.device)
124
+
125
+ z_q = (
126
+ z_e + (z_q - z_e).detach()
127
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
128
+
129
+ z_q = self.out_project(z_q)
130
+
131
+ vq_loss = (commit_loss + codebook_loss).mean()
132
+
133
+ return {
134
+ "z_q": z_q,
135
+ "indices": indices,
136
+ "dists": dists,
137
+ "vq_loss": vq_loss,
138
+ "perplexity": perplexity,
139
+ "active_num": active_num.float(),
140
+ }
141
+
142
+ def vq2emb(self, vq, out_proj=True):
143
+ emb = self.embed_code(vq)
144
+ if out_proj:
145
+ emb = self.out_project(emb)
146
+ return emb
147
+
148
+ def tokenize(self, z: torch.Tensor) -> torch.Tensor:
149
+ """tokenize the input tensor"""
150
+ z_e = self.in_project(z)
151
+ _, indices, _ = self.decode_latents(z_e)
152
+ return indices
153
+
154
+ def detokenize(self, indices):
155
+ """detokenize the input indices"""
156
+ z_q = self.decode_code(indices)
157
+ z_q = self.out_project(z_q)
158
+ return z_q
159
+
160
+ def get_emb(self):
161
+ return self.codebook.weight
162
+
163
+ def embed_code(self, embed_id):
164
+ return F.embedding(embed_id, self.codebook.weight)
165
+
166
+ def decode_code(self, embed_id):
167
+ return self.embed_code(embed_id).transpose(1, 2)
168
+
169
+ def decode_latents(self, latents):
170
+ encodings = rearrange(latents, "b d t -> (b t) d")
171
+ codebook = self.codebook.weight
172
+
173
+ # L2 normalize encodings and codebook
174
+ encodings = F.normalize(encodings)
175
+ codebook = F.normalize(codebook)
176
+
177
+ # Compute euclidean distance between encodings and codebook,
178
+ # with L2 normalization, the distance is equal to cosine distance
179
+ dist = (
180
+ encodings.pow(2).sum(1, keepdim=True)
181
+ - 2 * encodings @ codebook.t()
182
+ + codebook.pow(2).sum(1, keepdim=True).t()
183
+ )
184
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
185
+ z_q = self.decode_code(indices)
186
+
187
+ return z_q, indices, dist
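Unlike the other quantizer modules in this commit, FactorizedVectorQuantize ships without a __main__ check; a minimal sketch of the intended call pattern, under assumed hyperparameters (input_dim=256, codebook_size=1024, codebook_dim=8, commitment=0.25), would be:
import torch
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize

vq = FactorizedVectorQuantize(
    input_dim=256, codebook_size=1024, codebook_dim=8, commitment=0.25
)
vq.eval()                                  # commitment/codebook losses are only computed in training
z = torch.randn(2, 256, 50)                # (B, D, T) latents
out = vq(z)
print(out["z_q"].shape, out["indices"].shape)   # (2, 256, 50) (2, 50)

tokens = vq.tokenize(z)                    # same indices as out["indices"]
recon = vq.detokenize(tokens)              # decode purely from the codebook
print(torch.allclose(recon, out["z_q"]))   # expected True in eval mode (straight-through is a no-op)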
sparktts/utils/__init__.py ADDED
File without changes
sparktts/utils/audio.py ADDED
@@ -0,0 +1,271 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Description:
17
+ This script contains a collection of functions designed to handle various
18
+ audio processing.
19
+ """
20
+
21
+ import random
22
+ import soxr
23
+ import soundfile
24
+ import torch
25
+ import torchaudio
26
+ import numpy as np
27
+
28
+ from pathlib import Path
29
+ from typing import Tuple
30
+ from numpy.lib.stride_tricks import sliding_window_view
31
+
32
+
33
+ def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
34
+ """
35
+ Normalize the volume of an audio signal.
36
+
37
+ Parameters:
38
+ audio (numpy array): Input audio signal array.
39
+ coeff (float): Target coefficient for normalization, default is 0.2.
40
+
41
+ Returns:
42
+ numpy array: The volume-normalized audio signal.
43
+ """
44
+ # Sort the absolute values of the audio signal
45
+ temp = np.sort(np.abs(audio))
46
+
47
+ # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
48
+ if temp[-1] < 0.1:
49
+ scaling_factor = max(
50
+ temp[-1], 1e-3
51
+ ) # Prevent division by zero with a small constant
52
+ audio = audio / scaling_factor * 0.1
53
+
54
+ # Filter out values less than 0.01 from temp
55
+ temp = temp[temp > 0.01]
56
+ L = temp.shape[0] # Length of the filtered array
57
+
58
+ # If there are fewer than or equal to 10 significant values, return the audio without further processing
59
+ if L <= 10:
60
+ return audio
61
+
62
+ # Compute the average of the top 10% to 1% of values in temp
63
+ volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
64
+
65
+ # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
66
+ audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
67
+
68
+ # Ensure the maximum absolute value in the audio does not exceed 1
69
+ max_value = np.max(np.abs(audio))
70
+ if max_value > 1:
71
+ audio = audio / max_value
72
+
73
+ return audio
74
+
75
+
76
+ def load_audio(
77
+ adfile: Path,
78
+ sampling_rate: int = None,
79
+ length: int = None,
80
+ volume_normalize: bool = False,
81
+ segment_duration: int = None,
82
+ ) -> np.ndarray:
83
+ r"""Load audio file with target sampling rate and lsength
84
+
85
+ Args:
86
+ adfile (Path): path to audio file.
87
+ sampling_rate (int, optional): target sampling rate. Defaults to None.
88
+ length (int, optional): target audio length. Defaults to None.
89
+ volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
90
+ segment_duration (int): randomly select a segment of {segment_duration} seconds.
92
+ Defaults to None, which means the whole audio will be used.
92
+
93
+ Returns:
94
+ audio (np.ndarray): audio
95
+ """
96
+
97
+ audio, sr = soundfile.read(adfile)
98
+ if len(audio.shape) > 1:
99
+ audio = audio[:, 0]
100
+
101
+ if sampling_rate is not None and sr != sampling_rate:
102
+ audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
103
+ sr = sampling_rate
104
+
105
+ if segment_duration is not None:
106
+ seg_length = int(sr * segment_duration)
107
+ audio = random_select_audio_segment(audio, seg_length)
108
+
109
+ # Audio volume normalize
110
+ if volume_normalize:
111
+ audio = audio_volume_normalize(audio)
112
+ # check the audio length
113
+ if length is not None:
114
+ assert abs(audio.shape[0] - length) < 1000
115
+ if audio.shape[0] > length:
116
+ audio = audio[:length]
117
+ else:
118
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
119
+ return audio
120
+
121
+
122
+ def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
123
+ """get an audio segment given the length
124
+
125
+ Args:
126
+ audio (np.ndarray):
127
+ length (int): audio length = sampling_rate * duration
128
+ """
129
+ if audio.shape[0] < length:
130
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
131
+ start_index = random.randint(0, audio.shape[0] - length)
132
+ end_index = int(start_index + length)
133
+
134
+ return audio[start_index:end_index]
135
+
136
+
137
+ def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
138
+ """apply highpass fileter to audio
139
+
140
+ Args:
141
+ audio (np.ndarray):
142
+ sample_rate (int):
143
+ highpass_cutoff_freq (int):
144
+ """
145
+
146
+ audio = torchaudio.functional.highpass_biquad(
147
+ torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
148
+ )
149
+ return audio.numpy()
150
+
151
+
152
+ def stft(
153
+ x: torch.Tensor,
154
+ fft_size: int,
155
+ hop_size: int,
156
+ win_length: int,
157
+ window: torch.Tensor,
158
+ use_complex: bool = False,
159
+ ) -> torch.Tensor:
160
+ """Perform STFT and convert to magnitude spectrogram.
161
+ Args:
162
+ x (Tensor): Input signal tensor (B, T).
163
+ fft_size (int): FFT size.
164
+ hop_size (int): Hop size.
165
+ win_length (int): Window length.
166
+ window (Tensor): Window tensor of length win_length.
167
+ Returns:
168
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
169
+ """
170
+
171
+ x_stft = torch.stft(
172
+ x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
173
+ )
174
+
175
+ # clamp is needed to avoid nan or inf
176
+ if not use_complex:
177
+ return torch.sqrt(
178
+ torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
179
+ ).transpose(2, 1)
180
+ else:
181
+ res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
182
+ res = res.transpose(2, 3) # [B, 2, T, F]
183
+ return res
184
+
185
+
186
+ def detect_speech_boundaries(
187
+ wav: np.ndarray,
188
+ sample_rate: int,
189
+ window_duration: float = 0.1,
190
+ energy_threshold: float = 0.01,
191
+ margin_factor: int = 2
192
+ ) -> Tuple[int, int]:
193
+ """Detect the start and end points of speech in an audio signal using RMS energy.
194
+
195
+ Args:
196
+ wav: Input audio signal array with values in [-1, 1]
197
+ sample_rate: Audio sample rate in Hz
198
+ window_duration: Duration of detection window in seconds
199
+ energy_threshold: RMS energy threshold for speech detection
200
+ margin_factor: Factor to determine extra margin around detected boundaries
201
+
202
+ Returns:
203
+ tuple: (start_index, end_index) of speech segment
204
+
205
+ Raises:
206
+ ValueError: If the audio contains only silence
207
+ """
208
+ window_size = int(window_duration * sample_rate)
209
+ margin = margin_factor * window_size
210
+ step_size = window_size // 10
211
+
212
+ # Create sliding windows using stride tricks to avoid loops
213
+ windows = sliding_window_view(wav, window_size)[::step_size]
214
+
215
+ # Calculate RMS energy for each window
216
+ energy = np.sqrt(np.mean(windows ** 2, axis=1))
217
+ speech_mask = energy >= energy_threshold
218
+
219
+ if not np.any(speech_mask):
220
+ raise ValueError("No speech detected in audio (only silence)")
221
+
222
+ start = max(0, np.argmax(speech_mask) * step_size - margin)
223
+ end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
224
+
225
+ return start, end
226
+
227
+
228
+ def remove_silence_on_both_ends(
229
+ wav: np.ndarray,
230
+ sample_rate: int,
231
+ window_duration: float = 0.1,
232
+ volume_threshold: float = 0.01
233
+ ) -> np.ndarray:
234
+ """Remove silence from both ends of an audio signal.
235
+
236
+ Args:
237
+ wav: Input audio signal array
238
+ sample_rate: Audio sample rate in Hz
239
+ window_duration: Duration of detection window in seconds
240
+ volume_threshold: Amplitude threshold for silence detection
241
+
242
+ Returns:
243
+ np.ndarray: Audio signal with silence removed from both ends
244
+
245
+ Raises:
246
+ ValueError: If the audio contains only silence
247
+ """
248
+ start, end = detect_speech_boundaries(
249
+ wav,
250
+ sample_rate,
251
+ window_duration,
252
+ volume_threshold
253
+ )
254
+ return wav[start:end]
255
+
256
+
257
+
258
+ def hertz_to_mel(pitch: float) -> float:
259
+ """
260
+ Converts a frequency from the Hertz scale to the Mel scale.
261
+
262
+ Parameters:
263
+ - pitch: float or ndarray
264
+ Frequency in Hertz.
265
+
266
+ Returns:
267
+ - mel: float or ndarray
268
+ Frequency in Mel scale.
269
+ """
270
+ mel = 2595 * np.log10(1 + pitch / 700)
271
+ return mel
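A short sketch of how these helpers compose for prompt preprocessing; "prompt.wav" is a placeholder path and 16 kHz an assumed target rate (the actual values come from the model config, not this file):
import numpy as np
from sparktts.utils.audio import load_audio, remove_silence_on_both_ends

wav = load_audio("prompt.wav", sampling_rate=16000, volume_normalize=True)   # mono float array
wav = remove_silence_on_both_ends(wav, sample_rate=16000)                    # raises ValueError on pure silence
print(wav.shape, float(np.abs(wav).max()))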
sparktts/utils/file.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Description:
17
+ This script contains a collection of functions designed to handle various
18
+ file reading and writing operations. It provides utilities to read from files,
19
+ write data to files, and perform file manipulation tasks.
20
+ """
21
+
22
+
23
+ import os
24
+ import json
25
+ import json
26
+ import csv
27
+
28
+ from tqdm import tqdm
29
+ from typing import List, Dict, Any, Set, Union
30
+ from pathlib import Path
31
+ from omegaconf import OmegaConf, DictConfig
32
+
33
+
34
+ def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
35
+ """
36
+ Resolves the absolute path of a symbolic link.
37
+
38
+ Args:
39
+ symbolic_link_path (Path): The path to the symbolic link.
40
+
41
+ Returns:
42
+ Path: The absolute path that the symbolic link points to.
43
+ """
44
+
45
+ link_directory = os.path.dirname(symbolic_link_path)
46
+ target_path_relative = os.readlink(symbolic_link_path)
47
+ return os.path.join(link_directory, target_path_relative)
48
+
49
+
50
+ def write_jsonl(metadata: List[dict], file_path: Path) -> None:
51
+ """Writes a list of dictionaries to a JSONL file.
52
+
53
+ Args:
54
+ metadata : List[dict]
55
+ A list of dictionaries, each representing a metadata record.
56
+ file_path : Path
57
+ The file path to save the JSONL file
58
+
59
+ This function writes each dictionary in the list to a new line in the specified file.
60
+ """
61
+ with open(file_path, "w", encoding="utf-8") as f:
62
+ for meta in tqdm(metadata, desc="writing jsonl"):
63
+ # Convert dictionary to JSON string and write it to the file with a newline
64
+ json_str = json.dumps(meta, ensure_ascii=False) + "\n"
65
+ f.write(json_str)
66
+ print(f"jsonl saved to {file_path}")
67
+
68
+
69
+ def read_jsonl(file_path: Path) -> List[dict]:
70
+ """
71
+ Reads a JSONL file and returns a list of dictionaries.
72
+
73
+ Args:
74
+ file_path : Path
75
+ The path to the JSONL file to be read.
76
+
77
+ Returns:
78
+ List[dict]
79
+ A list of dictionaries parsed from each line of the JSONL file.
80
+ """
81
+ metadata = []
82
+ # Open the file for reading
83
+ with open(file_path, "r", encoding="utf-8") as f:
84
+ # Split the file into lines
85
+ lines = f.read().splitlines()
86
+ # Process each line
87
+ for line in lines:
88
+ # Convert JSON string back to dictionary and append to list
89
+ meta = json.loads(line)
90
+ metadata.append(meta)
91
+ # Return the list of metadata
92
+ return metadata
93
+
94
+ def read_json_as_jsonl(file_path: Path) -> List[dict]:
+ """Read a JSON file whose top level is a dict of dicts and return a list of dicts, one per key in sorted order, with each key stored under 'index'."""
95
+ metadata = []
96
+ with open(file_path, 'r', encoding='utf-8') as infile:
97
+ data = json.load(infile)
98
+ for k in sorted(data.keys()):
99
+ meta = {'index': k}
100
+ meta.update(data[k])
101
+ metadata.append(meta)
102
+ return metadata
103
+
104
+
105
+
106
+ def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
107
+ processed_meta = {}
108
+ for k, v in meta.items():
109
+ if isinstance(v, str):
110
+ processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
111
+ else:
112
+ processed_meta[k] = v
113
+ return processed_meta
114
+
115
+
116
+ def load_config(config_path: Path) -> DictConfig:
117
+ """Loads a configuration file and optionally merges it with a base configuration.
118
+
119
+ Args:
120
+ config_path (Path): Path to the configuration file.
121
+ """
122
+ # Load the initial configuration from the given path
123
+ config = OmegaConf.load(config_path)
124
+
125
+ # Check if there is a base configuration specified and merge if necessary
126
+ if config.get("base_config", None) is not None:
127
+ base_config = OmegaConf.load(config["base_config"])
128
+ config = OmegaConf.merge(base_config, config)
129
+
130
+ return config
131
+
132
+
133
+
134
+ def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
135
+ """
136
+ Converts a JSONL file to a CSV file.
137
+
138
+ This function reads a JSONL file, determines all unique keys present in the file,
139
+ and writes the data to a CSV file with columns for all these keys.
140
+ """
141
+
142
+ all_keys = set()
143
+ data_rows = []
144
+
145
+ # Read the JSONL file once to extract keys and collect data
146
+ with open(jsonl_file_path, 'r') as file:
147
+ for line in file:
148
+ data = json.loads(line.strip())
149
+ data_rows.append(data)
150
+ all_keys.update(data.keys())
151
+
152
+ # Convert the set of keys to a sorted list for consistent column order
153
+ sorted_keys = sorted(all_keys)
154
+
155
+ # Write the data to a CSV file
156
+ with open(csv_file_path, 'w', newline='') as csvfile:
157
+ writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
158
+
159
+ # Write the header row
160
+ writer.writeheader()
161
+
162
+ # Write each row of data
163
+ for data in data_rows:
164
+ writer.writerow(data)
165
+
166
+ print(f"CSV file has been created at {csv_file_path}")
167
+
168
+
169
+ def save_metadata(data, filename, headers=None):
170
+ """
171
+ Save metadata to a file.
172
+
173
+ Args:
174
+ data (list of dict): Metadata to be saved.
175
+ filename (str): Name of the file to save the metadata.
176
+ headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
177
+ """
178
+ # Set headers to keys from the first dictionary in data if not explicitly provided
179
+ if headers is None:
180
+ headers = list(data[0].keys())
181
+
182
+ with open(filename, "w", encoding="utf-8") as file:
183
+ # Write the headers to the file
184
+ file.write("|".join(headers) + "\n")
185
+ for entry in data:
186
+ # Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
187
+ formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
188
+ # Write the formatted values to the file
189
+ file.write("|".join(formatted_values) + "\n")
190
+
191
+
192
+ def read_metadata(filename, headers=None):
193
+ """
194
+ Read metadata from a file.
195
+
196
+ Args:
197
+ filename (str): The file from which to read the metadata.
198
+
199
+ Returns:
200
+ list of dict: The metadata read from the file.
201
+ list of str: The headers used in the file.
202
+ """
203
+ with open(filename, "r", encoding="utf-8") as file:
204
+ lines = file.readlines()
205
+
206
+ data = []
207
+ # Set headers from the first line of the file if not provided
208
+ if headers is None:
209
+ headers = lines[0].strip().split("|")
210
+ lines = lines[1:]
211
+
212
+ for line in lines:
213
+ line = line.strip()
214
+ # Skip empty lines
215
+ if not line:
216
+ continue
217
+ # Split the line by '|' and pair with headers to form a dictionary
218
+ entry_data = dict(zip(headers, line.split("|")))
219
+ data.append(entry_data)
220
+
221
+ return data, headers
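
A minimal usage sketch for the JSONL and pipe-delimited metadata helpers above (file names are illustrative; the import path `sparktts.utils.file` is assumed from this repo layout):

from pathlib import Path
from sparktts.utils.file import write_jsonl, read_jsonl, save_metadata, read_metadata

records = [{"index": "utt_001", "text": "hello"}, {"index": "utt_002", "text": "world"}]

# JSONL round trip.
write_jsonl(records, Path("meta.jsonl"))
assert read_jsonl(Path("meta.jsonl")) == records

# Pipe-delimited round trip; note that every value comes back as a string.
save_metadata(records, "meta.txt")
rows, headers = read_metadata("meta.txt")
print(headers)   # ['index', 'text']
print(rows[0])   # {'index': 'utt_001', 'text': 'hello'}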
sparktts/utils/parse_options.sh ADDED
@@ -0,0 +1,97 @@
1
+ #!/bin/bash
2
+
3
+ # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4
+ # Arnab Ghoshal, Karel Vesely
5
+
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
+ # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14
+ # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15
+ # MERCHANTABILITY OR NON-INFRINGEMENT.
16
+ # See the Apache 2 License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ # Parse command-line options.
21
+ # To be sourced by another script (as in ". parse_options.sh").
22
+ # Option format is: --option-name arg
23
+ # and shell variable "option_name" gets set to value "arg."
24
+ # The exception is --help, which takes no arguments, but prints the
25
+ # $help_message variable (if defined).
26
+
27
+
28
+ ###
29
+ ### The --config file options have lower priority to command line
30
+ ### options, so we need to import them first...
31
+ ###
32
+
33
+ # Now import all the configs specified by command-line, in left-to-right order
34
+ # for ((argpos=1; argpos<$#; argpos++)); do
35
+ # if [ "${!argpos}" == "--config" ]; then
36
+ # argpos_plus1=$((argpos+1))
37
+ # config=${!argpos_plus1}
38
+ # [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39
+ # . $config # source the config file.
40
+ # fi
41
+ # done
42
+
43
+
44
+ ###
45
+ ### Now we process the command-line options
46
+ ###
47
+ while true; do
48
+ [ -z "${1:-}" ] && break; # break if there are no arguments
49
+ case "$1" in
50
+ # If the enclosing script is called with --help option, print the help
51
+ # message and exit. Scripts should put help messages in $help_message
52
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53
+ else printf "$help_message\n" 1>&2 ; fi;
54
+ exit 0 ;;
55
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56
+ exit 1 ;;
57
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
58
+ # then work out the variable name as $name, which will equal "foo_bar".
59
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60
+ # Next we test whether the variable in question is undefined; if so it's
61
+ # an invalid option and we die. Note: $0 evaluates to the name of the
62
+ # enclosing script.
63
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64
+ # is undefined. We then have to wrap this test inside "eval" because
65
+ # foo_bar is itself inside a variable ($name).
66
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67
+
68
+ oldval="`eval echo \\$$name`";
69
+ # Work out whether we seem to be expecting a Boolean argument.
70
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71
+ was_bool=true;
72
+ else
73
+ was_bool=false;
74
+ fi
75
+
76
+ # Set the variable to the right value-- the escaped quotes make it work if
77
+ # the option had spaces, like --cmd "queue.pl -sync y"
78
+ eval $name=\"$2\";
79
+
80
+ # Check that Boolean-valued arguments are really Boolean.
81
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83
+ exit 1;
84
+ fi
85
+ shift 2;
86
+ ;;
87
+ *) break;
88
+ esac
89
+ done
90
+
91
+
92
+ # Check for an empty argument to the --cmd option, which can easily occur as a
93
+ # result of scripting errors.
94
+ [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95
+
96
+
97
+ true; # so this script returns exit code 0.
sparktts/utils/token_parser.py ADDED
@@ -0,0 +1,187 @@
1
+ TASK_TOKEN_MAP = {
2
+ "vc": "<|task_vc|>",
3
+ "tts": "<|task_tts|>",
4
+ "asr": "<|task_asr|>",
5
+ "s2s": "<|task_s2s|>",
6
+ "t2s": "<|task_t2s|>",
7
+ "understand": "<|task_understand|>",
8
+ "caption": "<|task_cap|>",
9
+ "controllable_tts": "<|task_controllable_tts|>",
10
+ "prompt_tts": "<|task_prompt_tts|>",
11
+ "speech_edit": "<|task_edit|>",
12
+ }
13
+
14
+ LEVELS_MAP = {
15
+ "very_low": 0,
16
+ "low": 1,
17
+ "moderate": 2,
18
+ "high": 3,
19
+ "very_high": 4,
20
+ }
21
+
22
+ LEVELS_MAP_UI = {
23
+ 1: 'very_low',
24
+ 2: 'low',
25
+ 3: 'moderate',
26
+ 4: 'high',
27
+ 5: 'very_high'
28
+ }
29
+
30
+ GENDER_MAP = {
31
+ "female": 0,
32
+ "male": 1,
33
+ }
34
+
35
+ AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
36
+
37
+ EMO_MAP = {
38
+ "UNKNOWN": 0,
39
+ "NEUTRAL": 1,
40
+ "ANGRY": 2,
41
+ "HAPPY": 3,
42
+ "SAD": 4,
43
+ "FEARFUL": 5,
44
+ "DISGUSTED": 6,
45
+ "SURPRISED": 7,
46
+ "SARCASTIC": 8,
47
+ "EXCITED": 9,
48
+ "SLEEPY": 10,
49
+ "CONFUSED": 11,
50
+ "EMPHASIS": 12,
51
+ "LAUGHING": 13,
52
+ "SINGING": 14,
53
+ "WORRIED": 15,
54
+ "WHISPER": 16,
55
+ "ANXIOUS": 17,
56
+ "NO-AGREEMENT": 18,
57
+ "APOLOGETIC": 19,
58
+ "CONCERNED": 20,
59
+ "ENUNCIATED": 21,
60
+ "ASSERTIVE": 22,
61
+ "ENCOURAGING": 23,
62
+ "CONTEMPT": 24,
63
+ }
64
+
65
+
66
+ class TokenParser:
67
+ """Convert attribute labels (task, age, gender, pitch, loudness, speed, emotion) to special tokens."""
68
+
69
+ def __init__(self):
70
+ pass
76
+
77
+ @staticmethod
78
+ def age(age: str) -> str:
79
+ """Turn age token."""
80
+ age_id = AGE_MAP[age]
81
+ return f"<|age_{age_id}|>"
82
+
83
+ @staticmethod
84
+ def gender(gender: str) -> str:
85
+ """Turn gender token."""
86
+ gender_id = GENDER_MAP[gender]
87
+ return f"<|gender_{gender_id}|>"
88
+
89
+ @staticmethod
90
+ def mel_value(mel: int):
91
+ """Turn special token of mel scale pitch."""
92
+ mel = max(0, int(mel))
93
+ mel = min(1000, int(mel))
94
+ return f"<|pitch_value_{mel}|>"
95
+
96
+ @staticmethod
97
+ def mel_level(level: str):
98
+ """Turn special token of mel level."""
99
+ level_tag = LEVELS_MAP[level]
100
+ return f"<|pitch_label_{level_tag}|>"
101
+
102
+ @staticmethod
103
+ def pitch_var_value(pitch_std: int):
104
+ """Turn special token of pitch_std value."""
105
+ assert isinstance(pitch_std, int)
106
+ pitch_std = max(0, int(pitch_std))
107
+ pitch_std = min(10, int(pitch_std))
108
+ return f"<|pitch_var_value_{pitch_std}|>"
109
+
110
+ @staticmethod
111
+ def pitch_var_level(level: str):
112
+ """Turn special token of pitch std level."""
113
+ level_tag = LEVELS_MAP[level]
114
+ return f"<|pitch_var_label_{level_tag}|>"
115
+
116
+ @staticmethod
117
+ def loudness_value(loudness: int):
118
+ """Turn special token of loudness value [0, 30]"""
119
+ assert loudness >= 0
120
+ loudness = max(0, int(loudness))
121
+ loudness = min(30, int(loudness))
122
+ return f"<|loudness_value_{loudness}|>"
123
+
124
+ @staticmethod
125
+ def loudness_level(level: str):
126
+ """Turn special token of loudness level."""
127
+ level_tag = LEVELS_MAP[level]
128
+ return f"<|loudness_label_{level_tag}|>"
129
+
130
+ @staticmethod
131
+ def speed_value(speed: int):
132
+ """Turn special token of speed value."""
133
+ speed = max(0, int(speed))
134
+ speed = min(10, int(speed))
135
+ return f"<|speed_value_{speed}|>"
136
+
137
+ @staticmethod
138
+ def speed_level(level: str):
139
+ """Turn special token of speed level."""
140
+ level_tag = LEVELS_MAP[level]
141
+ return f"<|speed_label_{level_tag}|>"
142
+
143
+ @staticmethod
144
+ def task(task: str) -> str:
145
+ """Turn special token of task."""
146
+ assert task in TASK_TOKEN_MAP.keys()
147
+
148
+ return TASK_TOKEN_MAP[task]
149
+
150
+ @staticmethod
151
+ def emotion(emotion: str):
152
+ emo_id = EMO_MAP[emotion]
153
+
154
+ return f"<|emotion_{emo_id}|>"
155
+
156
+
157
+ # test
158
+ if __name__ == "__main__":
159
+ from transformers import AutoTokenizer
160
+
161
+ tokenizer = AutoTokenizer.from_pretrained(
162
+ "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
163
+ )
164
+
165
+ tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
166
+ ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
167
+ genders = ["female", "female", "female", "male", "male"]
168
+ mels = [100, 200, 300, 400, 500]
169
+ mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
170
+ loudnesses = [1, 10, 23, 19, 30]
171
+ loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
172
+ emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
173
+
174
+ for i in range(5):
175
+ task = TokenParser.task(tasks[i])
176
+ age = TokenParser.age(ages[i])
177
+ gender = TokenParser.gender(genders[i])
178
+ mel = TokenParser.mel_value(mels[i])
179
+ mel_level = TokenParser.mel_level(mel_levels[i])
180
+ loudness = TokenParser.loudness_value(loudnesses[i])
181
+ loudness_level = TokenParser.loudness_level(loudness_levels[i])
182
+ emotion = TokenParser.emotion(emotions[i])
183
+ inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
184
+ inputs = "".join(inputs)
185
+ ids = tokenizer.encode(inputs, add_special_tokens=False)
186
+ print(ids)
187
+ print("decode", tokenizer.decode(ids))
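
The built-in test above depends on an internal tokenizer path, but the token construction itself needs no model. A minimal sketch showing how a UI slider position (1-5 via LEVELS_MAP_UI) becomes a level token (ids 0-4 via LEVELS_MAP); whether the model expects exactly this token ordering is not established here:

from sparktts.utils.token_parser import TokenParser, LEVELS_MAP_UI

# Slider position 3 -> label 'moderate' -> level id 2.
pitch_token = TokenParser.mel_level(LEVELS_MAP_UI[3])
speed_token = TokenParser.speed_level(LEVELS_MAP_UI[5])
prompt = TokenParser.task("controllable_tts") + TokenParser.gender("female") + pitch_token + speed_token
print(prompt)
# <|task_controllable_tts|><|gender_0|><|pitch_label_2|><|speed_label_4|>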
src/demos/trump/trump_en.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64fe26aa5fb432be85a183b6e48f7e1045c5c9fd4b8eb4faeeb5d4df5934f80f
3
+ size 476204
src/demos/zhongli/zhongli_en.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b94df88681b9f8afe07d2081bcc354f8e97c83886a944f37b4979b83547d39f
3
+ size 389804
src/demos/余承东/yuchengdong_zh.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90bcbe608c5c483d34213e18b7778e11ab55998d52bbe1b5f8bdd80f0473e7c2
3
+ size 496044
src/demos/刘德华/dehua_zh.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe97bd7c679be4dfbbd30496bf9b192c43dfae4a497ae7d99e85841ea06e77f2
3
+ size 772524
src/demos/哪吒/nezha_zh.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e09b01a61e6965f42b2bbfd47a33730991276ad7e6892531ba844646f8c9601e
3
+ size 596524