Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Commit 
							
							·
						
						6bcf114
	
1
								Parent(s):
							
							9d0c692
								
update
Browse filesThis view is limited to 50 files because it contains too many changes.  
							See raw diff
- LICENSE +0 -201
- README.md +4 -4
- README_CN.md +0 -654
- __init__.py +0 -0
- app.py +0 -124
- assets/Step-Audio.pdf +0 -3
- assets/architecture.png +0 -3
- assets/logo.png +0 -0
- assets/pipeline.png +0 -3
- assets/rlhf.png +0 -0
- assets/stepeval_radar_chart.png +0 -3
- assets/yuewen.jpeg +0 -0
- cosyvoice/__init__.py +0 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +0 -68
- cosyvoice/cli/frontend.py +0 -106
- cosyvoice/cli/model.py +0 -32
- cosyvoice/flow/decoder.py +0 -238
- cosyvoice/flow/flow.py +0 -196
- cosyvoice/flow/flow_matching.py +0 -315
- cosyvoice/flow/length_regulator.py +0 -65
- cosyvoice/hifigan/f0_predictor.py +0 -55
- cosyvoice/hifigan/generator.py +0 -566
- cosyvoice/matcha/audio.py +0 -90
- cosyvoice/matcha/decoder.py +0 -511
- cosyvoice/matcha/flow_matching.py +0 -141
- cosyvoice/matcha/transformer.py +0 -443
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +0 -87
- cosyvoice/transformer/attention.py +0 -322
- cosyvoice/transformer/convolution.py +0 -147
- cosyvoice/transformer/decoder.py +0 -418
- cosyvoice/transformer/decoder_layer.py +0 -132
- cosyvoice/transformer/embedding.py +0 -293
- cosyvoice/transformer/encoder.py +0 -633
- cosyvoice/transformer/encoder_layer.py +0 -237
- cosyvoice/transformer/label_smoothing_loss.py +0 -98
- cosyvoice/transformer/positionwise_feed_forward.py +0 -116
- cosyvoice/transformer/subsampling.py +0 -391
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/audio.py +0 -90
- cosyvoice/utils/class_utils.py +0 -78
- cosyvoice/utils/common.py +0 -169
- cosyvoice/utils/executor.py +0 -151
- cosyvoice/utils/file_utils.py +0 -49
- cosyvoice/utils/frontend_utils.py +0 -142
- cosyvoice/utils/mask.py +0 -226
- cosyvoice/utils/scheduler.py +0 -761
- cosyvoice/utils/train_utils.py +0 -350
- examples/clone_wav_lixueqin.wav +0 -3
    	
        LICENSE
    DELETED
    
    | @@ -1,201 +0,0 @@ | |
| 1 | 
            -
                                             Apache License
         | 
| 2 | 
            -
                                       Version 2.0, January 2004
         | 
| 3 | 
            -
                                    http://www.apache.org/licenses/
         | 
| 4 | 
            -
             | 
| 5 | 
            -
               TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
         | 
| 6 | 
            -
             | 
| 7 | 
            -
               1. Definitions.
         | 
| 8 | 
            -
             | 
| 9 | 
            -
                  "License" shall mean the terms and conditions for use, reproduction,
         | 
| 10 | 
            -
                  and distribution as defined by Sections 1 through 9 of this document.
         | 
| 11 | 
            -
             | 
| 12 | 
            -
                  "Licensor" shall mean the copyright owner or entity authorized by
         | 
| 13 | 
            -
                  the copyright owner that is granting the License.
         | 
| 14 | 
            -
             | 
| 15 | 
            -
                  "Legal Entity" shall mean the union of the acting entity and all
         | 
| 16 | 
            -
                  other entities that control, are controlled by, or are under common
         | 
| 17 | 
            -
                  control with that entity. For the purposes of this definition,
         | 
| 18 | 
            -
                  "control" means (i) the power, direct or indirect, to cause the
         | 
| 19 | 
            -
                  direction or management of such entity, whether by contract or
         | 
| 20 | 
            -
                  otherwise, or (ii) ownership of fifty percent (50%) or more of the
         | 
| 21 | 
            -
                  outstanding shares, or (iii) beneficial ownership of such entity.
         | 
| 22 | 
            -
             | 
| 23 | 
            -
                  "You" (or "Your") shall mean an individual or Legal Entity
         | 
| 24 | 
            -
                  exercising permissions granted by this License.
         | 
| 25 | 
            -
             | 
| 26 | 
            -
                  "Source" form shall mean the preferred form for making modifications,
         | 
| 27 | 
            -
                  including but not limited to software source code, documentation
         | 
| 28 | 
            -
                  source, and configuration files.
         | 
| 29 | 
            -
             | 
| 30 | 
            -
                  "Object" form shall mean any form resulting from mechanical
         | 
| 31 | 
            -
                  transformation or translation of a Source form, including but
         | 
| 32 | 
            -
                  not limited to compiled object code, generated documentation,
         | 
| 33 | 
            -
                  and conversions to other media types.
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                  "Work" shall mean the work of authorship, whether in Source or
         | 
| 36 | 
            -
                  Object form, made available under the License, as indicated by a
         | 
| 37 | 
            -
                  copyright notice that is included in or attached to the work
         | 
| 38 | 
            -
                  (an example is provided in the Appendix below).
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                  "Derivative Works" shall mean any work, whether in Source or Object
         | 
| 41 | 
            -
                  form, that is based on (or derived from) the Work and for which the
         | 
| 42 | 
            -
                  editorial revisions, annotations, elaborations, or other modifications
         | 
| 43 | 
            -
                  represent, as a whole, an original work of authorship. For the purposes
         | 
| 44 | 
            -
                  of this License, Derivative Works shall not include works that remain
         | 
| 45 | 
            -
                  separable from, or merely link (or bind by name) to the interfaces of,
         | 
| 46 | 
            -
                  the Work and Derivative Works thereof.
         | 
| 47 | 
            -
             | 
| 48 | 
            -
                  "Contribution" shall mean any work of authorship, including
         | 
| 49 | 
            -
                  the original version of the Work and any modifications or additions
         | 
| 50 | 
            -
                  to that Work or Derivative Works thereof, that is intentionally
         | 
| 51 | 
            -
                  submitted to Licensor for inclusion in the Work by the copyright owner
         | 
| 52 | 
            -
                  or by an individual or Legal Entity authorized to submit on behalf of
         | 
| 53 | 
            -
                  the copyright owner. For the purposes of this definition, "submitted"
         | 
| 54 | 
            -
                  means any form of electronic, verbal, or written communication sent
         | 
| 55 | 
            -
                  to the Licensor or its representatives, including but not limited to
         | 
| 56 | 
            -
                  communication on electronic mailing lists, source code control systems,
         | 
| 57 | 
            -
                  and issue tracking systems that are managed by, or on behalf of, the
         | 
| 58 | 
            -
                  Licensor for the purpose of discussing and improving the Work, but
         | 
| 59 | 
            -
                  excluding communication that is conspicuously marked or otherwise
         | 
| 60 | 
            -
                  designated in writing by the copyright owner as "Not a Contribution."
         | 
| 61 | 
            -
             | 
| 62 | 
            -
                  "Contributor" shall mean Licensor and any individual or Legal Entity
         | 
| 63 | 
            -
                  on behalf of whom a Contribution has been received by Licensor and
         | 
| 64 | 
            -
                  subsequently incorporated within the Work.
         | 
| 65 | 
            -
             | 
| 66 | 
            -
               2. Grant of Copyright License. Subject to the terms and conditions of
         | 
| 67 | 
            -
                  this License, each Contributor hereby grants to You a perpetual,
         | 
| 68 | 
            -
                  worldwide, non-exclusive, no-charge, royalty-free, irrevocable
         | 
| 69 | 
            -
                  copyright license to reproduce, prepare Derivative Works of,
         | 
| 70 | 
            -
                  publicly display, publicly perform, sublicense, and distribute the
         | 
| 71 | 
            -
                  Work and such Derivative Works in Source or Object form.
         | 
| 72 | 
            -
             | 
| 73 | 
            -
               3. Grant of Patent License. Subject to the terms and conditions of
         | 
| 74 | 
            -
                  this License, each Contributor hereby grants to You a perpetual,
         | 
| 75 | 
            -
                  worldwide, non-exclusive, no-charge, royalty-free, irrevocable
         | 
| 76 | 
            -
                  (except as stated in this section) patent license to make, have made,
         | 
| 77 | 
            -
                  use, offer to sell, sell, import, and otherwise transfer the Work,
         | 
| 78 | 
            -
                  where such license applies only to those patent claims licensable
         | 
| 79 | 
            -
                  by such Contributor that are necessarily infringed by their
         | 
| 80 | 
            -
                  Contribution(s) alone or by combination of their Contribution(s)
         | 
| 81 | 
            -
                  with the Work to which such Contribution(s) was submitted. If You
         | 
| 82 | 
            -
                  institute patent litigation against any entity (including a
         | 
| 83 | 
            -
                  cross-claim or counterclaim in a lawsuit) alleging that the Work
         | 
| 84 | 
            -
                  or a Contribution incorporated within the Work constitutes direct
         | 
| 85 | 
            -
                  or contributory patent infringement, then any patent licenses
         | 
| 86 | 
            -
                  granted to You under this License for that Work shall terminate
         | 
| 87 | 
            -
                  as of the date such litigation is filed.
         | 
| 88 | 
            -
             | 
| 89 | 
            -
               4. Redistribution. You may reproduce and distribute copies of the
         | 
| 90 | 
            -
                  Work or Derivative Works thereof in any medium, with or without
         | 
| 91 | 
            -
                  modifications, and in Source or Object form, provided that You
         | 
| 92 | 
            -
                  meet the following conditions:
         | 
| 93 | 
            -
             | 
| 94 | 
            -
                  (a) You must give any other recipients of the Work or
         | 
| 95 | 
            -
                      Derivative Works a copy of this License; and
         | 
| 96 | 
            -
             | 
| 97 | 
            -
                  (b) You must cause any modified files to carry prominent notices
         | 
| 98 | 
            -
                      stating that You changed the files; and
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                  (c) You must retain, in the Source form of any Derivative Works
         | 
| 101 | 
            -
                      that You distribute, all copyright, patent, trademark, and
         | 
| 102 | 
            -
                      attribution notices from the Source form of the Work,
         | 
| 103 | 
            -
                      excluding those notices that do not pertain to any part of
         | 
| 104 | 
            -
                      the Derivative Works; and
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                  (d) If the Work includes a "NOTICE" text file as part of its
         | 
| 107 | 
            -
                      distribution, then any Derivative Works that You distribute must
         | 
| 108 | 
            -
                      include a readable copy of the attribution notices contained
         | 
| 109 | 
            -
                      within such NOTICE file, excluding those notices that do not
         | 
| 110 | 
            -
                      pertain to any part of the Derivative Works, in at least one
         | 
| 111 | 
            -
                      of the following places: within a NOTICE text file distributed
         | 
| 112 | 
            -
                      as part of the Derivative Works; within the Source form or
         | 
| 113 | 
            -
                      documentation, if provided along with the Derivative Works; or,
         | 
| 114 | 
            -
                      within a display generated by the Derivative Works, if and
         | 
| 115 | 
            -
                      wherever such third-party notices normally appear. The contents
         | 
| 116 | 
            -
                      of the NOTICE file are for informational purposes only and
         | 
| 117 | 
            -
                      do not modify the License. You may add Your own attribution
         | 
| 118 | 
            -
                      notices within Derivative Works that You distribute, alongside
         | 
| 119 | 
            -
                      or as an addendum to the NOTICE text from the Work, provided
         | 
| 120 | 
            -
                      that such additional attribution notices cannot be construed
         | 
| 121 | 
            -
                      as modifying the License.
         | 
| 122 | 
            -
             | 
| 123 | 
            -
                  You may add Your own copyright statement to Your modifications and
         | 
| 124 | 
            -
                  may provide additional or different license terms and conditions
         | 
| 125 | 
            -
                  for use, reproduction, or distribution of Your modifications, or
         | 
| 126 | 
            -
                  for any such Derivative Works as a whole, provided Your use,
         | 
| 127 | 
            -
                  reproduction, and distribution of the Work otherwise complies with
         | 
| 128 | 
            -
                  the conditions stated in this License.
         | 
| 129 | 
            -
             | 
| 130 | 
            -
               5. Submission of Contributions. Unless You explicitly state otherwise,
         | 
| 131 | 
            -
                  any Contribution intentionally submitted for inclusion in the Work
         | 
| 132 | 
            -
                  by You to the Licensor shall be under the terms and conditions of
         | 
| 133 | 
            -
                  this License, without any additional terms or conditions.
         | 
| 134 | 
            -
                  Notwithstanding the above, nothing herein shall supersede or modify
         | 
| 135 | 
            -
                  the terms of any separate license agreement you may have executed
         | 
| 136 | 
            -
                  with Licensor regarding such Contributions.
         | 
| 137 | 
            -
             | 
| 138 | 
            -
               6. Trademarks. This License does not grant permission to use the trade
         | 
| 139 | 
            -
                  names, trademarks, service marks, or product names of the Licensor,
         | 
| 140 | 
            -
                  except as required for reasonable and customary use in describing the
         | 
| 141 | 
            -
                  origin of the Work and reproducing the content of the NOTICE file.
         | 
| 142 | 
            -
             | 
| 143 | 
            -
               7. Disclaimer of Warranty. Unless required by applicable law or
         | 
| 144 | 
            -
                  agreed to in writing, Licensor provides the Work (and each
         | 
| 145 | 
            -
                  Contributor provides its Contributions) on an "AS IS" BASIS,
         | 
| 146 | 
            -
                  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
         | 
| 147 | 
            -
                  implied, including, without limitation, any warranties or conditions
         | 
| 148 | 
            -
                  of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
         | 
| 149 | 
            -
                  PARTICULAR PURPOSE. You are solely responsible for determining the
         | 
| 150 | 
            -
                  appropriateness of using or redistributing the Work and assume any
         | 
| 151 | 
            -
                  risks associated with Your exercise of permissions under this License.
         | 
| 152 | 
            -
             | 
| 153 | 
            -
               8. Limitation of Liability. In no event and under no legal theory,
         | 
| 154 | 
            -
                  whether in tort (including negligence), contract, or otherwise,
         | 
| 155 | 
            -
                  unless required by applicable law (such as deliberate and grossly
         | 
| 156 | 
            -
                  negligent acts) or agreed to in writing, shall any Contributor be
         | 
| 157 | 
            -
                  liable to You for damages, including any direct, indirect, special,
         | 
| 158 | 
            -
                  incidental, or consequential damages of any character arising as a
         | 
| 159 | 
            -
                  result of this License or out of the use or inability to use the
         | 
| 160 | 
            -
                  Work (including but not limited to damages for loss of goodwill,
         | 
| 161 | 
            -
                  work stoppage, computer failure or malfunction, or any and all
         | 
| 162 | 
            -
                  other commercial damages or losses), even if such Contributor
         | 
| 163 | 
            -
                  has been advised of the possibility of such damages.
         | 
| 164 | 
            -
             | 
| 165 | 
            -
               9. Accepting Warranty or Additional Liability. While redistributing
         | 
| 166 | 
            -
                  the Work or Derivative Works thereof, You may choose to offer,
         | 
| 167 | 
            -
                  and charge a fee for, acceptance of support, warranty, indemnity,
         | 
| 168 | 
            -
                  or other liability obligations and/or rights consistent with this
         | 
| 169 | 
            -
                  License. However, in accepting such obligations, You may act only
         | 
| 170 | 
            -
                  on Your own behalf and on Your sole responsibility, not on behalf
         | 
| 171 | 
            -
                  of any other Contributor, and only if You agree to indemnify,
         | 
| 172 | 
            -
                  defend, and hold each Contributor harmless for any liability
         | 
| 173 | 
            -
                  incurred by, or claims asserted against, such Contributor by reason
         | 
| 174 | 
            -
                  of your accepting any such warranty or additional liability.
         | 
| 175 | 
            -
             | 
| 176 | 
            -
               END OF TERMS AND CONDITIONS
         | 
| 177 | 
            -
             | 
| 178 | 
            -
               APPENDIX: How to apply the Apache License to your work.
         | 
| 179 | 
            -
             | 
| 180 | 
            -
                  To apply the Apache License to your work, attach the following
         | 
| 181 | 
            -
                  boilerplate notice, with the fields enclosed by brackets "[]"
         | 
| 182 | 
            -
                  replaced with your own identifying information. (Don't include
         | 
| 183 | 
            -
                  the brackets!)  The text should be enclosed in the appropriate
         | 
| 184 | 
            -
                  comment syntax for the file format. We also recommend that a
         | 
| 185 | 
            -
                  file or class name and description of purpose be included on the
         | 
| 186 | 
            -
                  same "printed page" as the copyright notice for easier
         | 
| 187 | 
            -
                  identification within third-party archives.
         | 
| 188 | 
            -
             | 
| 189 | 
            -
               Copyright [yyyy] [name of copyright owner]
         | 
| 190 | 
            -
             | 
| 191 | 
            -
               Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 192 | 
            -
               you may not use this file except in compliance with the License.
         | 
| 193 | 
            -
               You may obtain a copy of the License at
         | 
| 194 | 
            -
             | 
| 195 | 
            -
                   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 196 | 
            -
             | 
| 197 | 
            -
               Unless required by applicable law or agreed to in writing, software
         | 
| 198 | 
            -
               distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 199 | 
            -
               WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 200 | 
            -
               See the License for the specific language governing permissions and
         | 
| 201 | 
            -
               limitations under the License.
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        README.md
    CHANGED
    
    | @@ -1,13 +1,13 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
            -
            title:  | 
| 3 | 
            -
            emoji:  | 
| 4 | 
             
            colorFrom: green
         | 
| 5 | 
            -
            colorTo:  | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
             
            sdk_version: 5.16.0
         | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: true
         | 
| 10 | 
            -
            short_description:  | 
| 11 | 
             
            ---
         | 
| 12 |  | 
| 13 | 
             
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
            +
            title: DMOSpeech 2
         | 
| 3 | 
            +
            emoji: 🔊
         | 
| 4 | 
             
            colorFrom: green
         | 
| 5 | 
            +
            colorTo: green
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
             
            sdk_version: 5.16.0
         | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: true
         | 
| 10 | 
            +
            short_description: DMOSpeech 2
         | 
| 11 | 
             
            ---
         | 
| 12 |  | 
| 13 | 
             
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         | 
    	
        README_CN.md
    DELETED
    
    | @@ -1,654 +0,0 @@ | |
| 1 | 
            -
            <p align="left">
         | 
| 2 | 
            -
                    中文</a>  |  <a href="README.md">English</a>
         | 
| 3 | 
            -
            </p>
         | 
| 4 | 
            -
            <br><br>
         | 
| 5 | 
            -
             | 
| 6 | 
            -
            # Step-Audio
         | 
| 7 | 
            -
            <p align="center">
         | 
| 8 | 
            -
              <img src="assets/logo.png"  height=100>
         | 
| 9 | 
            -
            </p>
         | 
| 10 | 
            -
            <div align="center">
         | 
| 11 | 
            -
              <a href="https://github.com/stepfun-ai/Step-Audio/blob/cn-readme/assets/Step-Audio.pdf"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red"></a>  
         | 
| 12 | 
            -
              <a href="https://x.com/StepFun_ai"><img src="https://img.shields.io/static/v1?label=X.com&message=Web&color=blue"></a>  
         | 
| 13 | 
            -
            </div>
         | 
| 14 | 
            -
             | 
| 15 | 
            -
            <div align="center">
         | 
| 16 | 
            -
              <a href="https://huggingface.co/stepfun-ai/Step-Audio-Chat"><img src="https://img.shields.io/static/v1?label=Step-Audio-Chat&message=HuggingFace&color=yellow"></a>  
         | 
| 17 | 
            -
              <a href="https://huggingface.co/stepfun-ai/Step-Audio-TTS-3B"><img src="https://img.shields.io/static/v1?label=Step-Audio-TTS-3B&message=HuggingFace&color=yellow"></a>  
         | 
| 18 | 
            -
            </div>
         | 
| 19 | 
            -
            <div align="center">
         | 
| 20 | 
            -
              <a href="https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer"><img src="https://img.shields.io/static/v1?label=Step-Audio-Tokenier&message=HuggingFace&color=yellow"></a>  
         | 
| 21 | 
            -
              <a href="https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360"><img src="https://img.shields.io/static/v1?label=StepEval-Audio-360&message=HuggingFace&color=yellow"></a>  
         | 
| 22 | 
            -
            </div>
         | 
| 23 | 
            -
             | 
| 24 | 
            -
            ## 🔥🔥🔥 News!!
         | 
| 25 | 
            -
            * 2025年2月17日: 👋 发布推理代码和模型权重,其中包含[Step-Audio-Chat](https://huggingface.co/stepfun-ai/Step-Audio-Chat), [Step-Audio-TTS-3B](https://huggingface.co/stepfun-ai/Step-Audio-TTS-3B) 和 [Step-Audio-Tokenizer](https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer)。
         | 
| 26 | 
            -
            * 2025年2月17日: 👋 发布多轮音频交互基准测试[StepEval-Audio-360](https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360)。
         | 
| 27 | 
            -
            * 2025年2月17日: 👋 发布了技术报告[Step-Audio-Report](./assets/Step-Audio.pdf)。
         | 
| 28 | 
            -
             | 
| 29 | 
            -
            ## Table of Contents
         | 
| 30 | 
            -
             | 
| 31 | 
            -
            1. [介绍](#1-介绍)
         | 
| 32 | 
            -
            2. [模型组成](#2-模型组成)
         | 
| 33 | 
            -
            3. [模型下载](#3-模型下载)
         | 
| 34 | 
            -
            4. [模型使用](#4-模型使用)
         | 
| 35 | 
            -
            5. [基准](#5-基准)
         | 
| 36 | 
            -
            6. [在线引擎](#6-在线引擎)
         | 
| 37 | 
            -
            7. [样例](#7-样例)
         | 
| 38 | 
            -
            8. [引文](#8-引文)
         | 
| 39 | 
            -
             | 
| 40 | 
            -
            ## 1. 介绍
         | 
| 41 | 
            -
             | 
| 42 | 
            -
            Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤),方言(如 粤语,四川话),可控制语速及韵律风格,支持RAP和哼唱等。其核心技术突破体现在以下四大技术亮点:
         | 
| 43 | 
            -
             | 
| 44 | 
            -
            - **1300亿多模态模型**: 单模型能实现理解生成一体化完成语音识别、语义理解、对话、语音克隆、语音生成等功能,开源千亿参数多模态模型 Step-Audio-Chat。
         | 
| 45 | 
            -
             | 
| 46 | 
            -
            - **高效数据生成链路**: 基于130B 突破传统 TTS 对人工采集数据的依赖,生成高质量的合成音频数据,并同步开源首个基于大规模合成数据训练,支持 RAP 和哼唱的指令加强版语音合成模型 Step-Audio-TTS-3B 。
         | 
| 47 | 
            -
             | 
| 48 | 
            -
            - **精细语音控制**: 支持多种情绪(如生气,高兴,悲伤)、方言(包括粤语、四川话等)和唱歌(包括 RAP、干声哼唱)的精准调控,满足用户对多样化语音生成的需求。
         | 
| 49 | 
            -
             | 
| 50 | 
            -
            - **扩展工具调用**: 通过 ToolCall 机制和角色扮演增强,进一步提升其在 Agents 和复杂任务中的表现。
         | 
| 51 | 
            -
             | 
| 52 | 
            -
            ## 2. 模型组成
         | 
| 53 | 
            -
             | 
| 54 | 
            -
            在Step-Audio系统中,音频流采用Linguistic tokenizer(码率16.7Hz,码本大小1024)与Semantice tokenizer(码率25Hz,码本大小4096)并行的双码本编码器方案,双码本在排列上使用了2:3时序交错策略。通过音频语境化持续预训练和任务定向微调强化了130B参数量的基础模型(Step-1),最终构建了强大的跨模态语音理解能力。为了实现实时音频生成,系统采用了混合语音解码器,结合流匹配(flow matching)与神经声码技术。
         | 
| 55 | 
            -
            
         | 
| 56 | 
            -
             | 
| 57 | 
            -
            ### 2.1 Tokenizer
         | 
| 58 | 
            -
             | 
| 59 | 
            -
            我们通过token级交错方法实现Linguistic token与Semantic token的有效整合。Linguistic tokenizer的码本大小是1024,码率16.7Hz;而Semantic tokenizer则使用4096的大容量码本来捕捉更精细的声学细节,码率25Hz。鉴于两者的码率差异,我们建立了2:3的时间对齐比例——每两个Linguistic token对应三个Linguistic token形成时序配对。
         | 
| 60 | 
            -
             | 
| 61 | 
            -
            ### 2.2 语言模型
         | 
| 62 | 
            -
             | 
| 63 | 
            -
            为了提升Step-Audio有效处理语音信息的能力,并实现精准的语音-文本对齐,我们在Step-1(一个拥有1300亿参数的基于文本的大型语言模型LLM)的基础上进行了音频持续预训练。
         | 
| 64 | 
            -
             | 
| 65 | 
            -
            ### 2.3 语音解码器
         | 
| 66 | 
            -
             | 
| 67 | 
            -
            Step-Audio语音解码器主要是将包含语义和声学信息的离散标记信息转换成连续的语音信号。该解码器架构结合了一个30亿参数的语言模型、流匹配模型(flow matching model)和梅尔频谱到波形的声码器(mel-to-wave vocoder)。为优化合成语音的清晰度(intelligibility)和自然度(naturalness),语音解码器采用双码交错训练方法(dual-code interleaving),确保生成过程中语义与声学特征的无缝融合。
         | 
| 68 | 
            -
             | 
| 69 | 
            -
            ### 2.4 实时推理管线
         | 
| 70 | 
            -
            为了实现实时的语音交互,我们对推理管线进行了一系列优化。其中最核心的是控制模块(Controller),该模块负责管理状态转换、协调响应生成,并确保关键子系统间的无缝协同。这些子系统包括:
         | 
| 71 | 
            -
             | 
| 72 | 
            -
            - **语音活动检测(VAD)**:实时检测用户语音起止
         | 
| 73 | 
            -
             | 
| 74 | 
            -
            - **流式音频分词器(Streaming Audio Tokenizer)**:实时音频流处理
         | 
| 75 | 
            -
             | 
| 76 | 
            -
            - **Step-Audio语言模型与语音解码器**:多模态回复生成
         | 
| 77 | 
            -
             | 
| 78 | 
            -
            - **上下文管理器(Context Manager)**:动态维护对话历史与状态
         | 
| 79 | 
            -
            
         | 
| 80 | 
            -
             | 
| 81 | 
            -
            ### 2.5 后训练细节
         | 
| 82 | 
            -
            在后训练阶段,我们针对自动语音识别(ASR)与文本转语音(TTS)任务进行了专项监督微调(Supervised Fine-Tuning, SFT)。对于音频输入-文本输出(Audio Question Text Answer, AQTA)任务,我们采用多样化高质量数据集进行SFT,并采用了基于人类反馈的强化学习(RLHF)以提升响应质量,从而实现对情感表达、语速、方言及韵律的细粒度控制。
         | 
| 83 | 
            -
            
         | 
| 84 | 
            -
             | 
| 85 | 
            -
             | 
| 86 | 
            -
            ## 3. 模型下载
         | 
| 87 | 
            -
            ### 3.1 Huggingface
         | 
| 88 | 
            -
            | 模型   | 链接   |
         | 
| 89 | 
            -
            |-------|-------|
         | 
| 90 | 
            -
            | Step-Audio-Tokenizer | [🤗huggingface](https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer) |
         | 
| 91 | 
            -
            | Step-Audio-Chat | [🤗huggingface](https://huggingface.co/stepfun-ai/Step-Audio-Chat) |
         | 
| 92 | 
            -
            | Step-Audio-TTS-3B | [🤗huggingface](https://huggingface.co/stepfun-ai/Step-Audio-TTS-3B) |
         | 
| 93 | 
            -
             | 
| 94 | 
            -
            ### 3.2 Modelscope
         | 
| 95 | 
            -
            | 模型   | 链接   |
         | 
| 96 | 
            -
            |-------|-------|
         | 
| 97 | 
            -
            | Step-Audio-Tokenizer | [modelscope](https://modelscope.cn/models/stepfun-ai/Step-Audio-Tokenizer) |
         | 
| 98 | 
            -
            | Step-Audio-Chat | [modelscope](https://modelscope.cn/models/stepfun-ai/Step-Audio-Chat) |
         | 
| 99 | 
            -
            | Step-Audio-TTS-3B | [modelscope](https://modelscope.cn/models/stepfun-ai/Step-Audio-TTS-3B) |
         | 
| 100 | 
            -
             | 
| 101 | 
            -
            ## 4. 模型使用
         | 
| 102 | 
            -
            ### 📜 4.1  要求
         | 
| 103 | 
            -
            下表列出了运行Step-Audio模型(batch size=1)所需的配置要求:
         | 
| 104 | 
            -
             | 
| 105 | 
            -
            |     模型    |  Setting<br/>(采样率) | GPU最低显存  |
         | 
| 106 | 
            -
            |------------|--------------------------------|----------------|
         | 
| 107 | 
            -
            | Step-Audio-Tokenizer   |        41.6Hz          |       1.5GB        |
         | 
| 108 | 
            -
            | Step-Audio-Chat   |        41.6Hz          |       265GB        |
         | 
| 109 | 
            -
            | Step-Audio-TTS-3B   |        41.6Hz          |       8GB        |
         | 
| 110 | 
            -
             | 
| 111 | 
            -
            * 需要支持CUDA的NVIDIA显卡.
         | 
| 112 | 
            -
              * 模型在4块显存为80GB的A800系列NVIDIA显卡上进行测试.
         | 
| 113 | 
            -
              * **推荐**: 为确保最佳生成质量,建议使用4块显存为80GB的A800/H800系列NVIDIA显卡.
         | 
| 114 | 
            -
            * 测试采用的操作系统: Linux
         | 
| 115 | 
            -
             | 
| 116 | 
            -
            ### 🔧 4.2 依赖项与安装
         | 
| 117 | 
            -
            - Python >= 3.10.0 (推荐使用 [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
         | 
| 118 | 
            -
            - [PyTorch >= 2.3-cu121](https://pytorch.org/)
         | 
| 119 | 
            -
            - [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads)
         | 
| 120 | 
            -
             | 
| 121 | 
            -
            ```bash
         | 
| 122 | 
            -
            git clone https://github.com/stepfun-ai/Step-Audio.git
         | 
| 123 | 
            -
            conda create -n stepaudio python=3.10
         | 
| 124 | 
            -
            conda activate stepaudio
         | 
| 125 | 
            -
             | 
| 126 | 
            -
            cd Step-Audio
         | 
| 127 | 
            -
            pip install -r requirements.txt
         | 
| 128 | 
            -
             | 
| 129 | 
            -
            git lfs install
         | 
| 130 | 
            -
            git clone https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer
         | 
| 131 | 
            -
            git clone https://huggingface.co/stepfun-ai/Step-Audio-Chat
         | 
| 132 | 
            -
            git clone https://huggingface.co/stepfun-ai/Step-Audio-TTS-3B
         | 
| 133 | 
            -
             | 
| 134 | 
            -
            ```
         | 
| 135 | 
            -
             | 
| 136 | 
            -
            下载模型后,where_you_download_dir应包含以下结构:
         | 
| 137 | 
            -
            ```
         | 
| 138 | 
            -
            where_you_download_dir
         | 
| 139 | 
            -
            ├── Step-Audio-Tokenizer
         | 
| 140 | 
            -
            ├── Step-Audio-Chat
         | 
| 141 | 
            -
            ├── Step-Audio-TTS-3B
         | 
| 142 | 
            -
            ```
         | 
| 143 | 
            -
             | 
| 144 | 
            -
            ###  🚀 4.3 推理脚本
         | 
| 145 | 
            -
            #### 离线推理
         | 
| 146 | 
            -
            支持端到端音频/文本输入与音频/文本输出的推理流程。
         | 
| 147 | 
            -
            ```bash
         | 
| 148 | 
            -
            python offline_inference.py --model-path where_you_download_dir
         | 
| 149 | 
            -
            ```
         | 
| 150 | 
            -
             | 
| 151 | 
            -
            #### 语音合成推理
         | 
| 152 | 
            -
            使用默认音色进行语音合成推理或使用新音色进行克隆
         | 
| 153 | 
            -
            ```bash
         | 
| 154 | 
            -
            python tts_inference.py --model-path where_you_download_dir --output-path where_you_save_audio_dir --synthesis-type use_tts_or_clone
         | 
| 155 | 
            -
            ```
         | 
| 156 | 
            -
            克隆模式需要音色信息字典,格式如下:
         | 
| 157 | 
            -
            ```bash
         | 
| 158 | 
            -
            {
         | 
| 159 | 
            -
                "speaker": "speaker id",
         | 
| 160 | 
            -
                "prompt_text": "content of prompt wav",
         | 
| 161 | 
            -
                "wav_path": "prompt wav path"
         | 
| 162 | 
            -
            }
         | 
| 163 | 
            -
            ```
         | 
| 164 | 
            -
             | 
| 165 | 
            -
            #### 启动网页演示
         | 
| 166 | 
            -
            启动本地服务器以进行在线推理。
         | 
| 167 | 
            -
            假设您已配备4块GPU且已完成所有模型的下载。
         | 
| 168 | 
            -
             | 
| 169 | 
            -
            ```bash
         | 
| 170 | 
            -
            python app.py --model-path where_you_download_dir
         | 
| 171 | 
            -
            ```
         | 
| 172 | 
            -
             | 
| 173 | 
            -
            ## 5. 基准
         | 
| 174 | 
            -
             | 
| 175 | 
            -
            ### 5.1 语音识别
         | 
| 176 | 
            -
             | 
| 177 | 
            -
            <table>
         | 
| 178 | 
            -
                <thead>
         | 
| 179 | 
            -
                    <tr>
         | 
| 180 | 
            -
                        <th style="text-align:center"></th>
         | 
| 181 | 
            -
                        <th colspan="4" style="text-align:center">隐层特征建模</th>
         | 
| 182 | 
            -
                        <th colspan="5" style="text-align:center">离散标记建模</th>
         | 
| 183 | 
            -
                    </tr>
         | 
| 184 | 
            -
                    <tr>
         | 
| 185 | 
            -
                        <th style="text-align:center"></th>
         | 
| 186 | 
            -
                        <th style="text-align:center">Whisper Large-v3</th>
         | 
| 187 | 
            -
                        <th style="text-align:center">Qwen2-Audio</th>
         | 
| 188 | 
            -
                        <th style="text-align:center">MinMo</th>
         | 
| 189 | 
            -
                        <th style="text-align:center">LUCY</th>
         | 
| 190 | 
            -
                        <th style="text-align:center">Moshi</th>
         | 
| 191 | 
            -
                        <th style="text-align:center">GLM-4-voice Base</th>
         | 
| 192 | 
            -
                        <th style="text-align:center">GLM-4-voice Chat</th>
         | 
| 193 | 
            -
                        <th style="text-align:center">Step-Audio Pretrain</th>
         | 
| 194 | 
            -
                        <th style="text-align:center">Step-Audio-Chat</th>
         | 
| 195 | 
            -
                    </tr>
         | 
| 196 | 
            -
                </thead>
         | 
| 197 | 
            -
                <tbody>
         | 
| 198 | 
            -
                    <tr>
         | 
| 199 | 
            -
                        <td>Aishell-1</td>
         | 
| 200 | 
            -
                        <td style="text-align:center">5.14</td>
         | 
| 201 | 
            -
                        <td style="text-align:center">1.53</td>
         | 
| 202 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 203 | 
            -
                        <td style="text-align:center">2.4</td>
         | 
| 204 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 205 | 
            -
                        <td style="text-align:center">2.46</td>
         | 
| 206 | 
            -
                        <td style="text-align:center">226.47</td>
         | 
| 207 | 
            -
                        <td style="text-align:center"><strong>0.87</strong></td>
         | 
| 208 | 
            -
                        <td style="text-align:center">1.95</td>
         | 
| 209 | 
            -
                    </tr>
         | 
| 210 | 
            -
                    <tr>
         | 
| 211 | 
            -
                        <td>Aishell-2 ios</td>
         | 
| 212 | 
            -
                        <td style="text-align:center">4.76</td>
         | 
| 213 | 
            -
                        <td style="text-align:center">3.06</td>
         | 
| 214 | 
            -
                        <td style="text-align:center"><strong>2.69</strong></td>
         | 
| 215 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 216 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 217 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 218 | 
            -
                        <td style="text-align:center">211.3</td>
         | 
| 219 | 
            -
                        <td style="text-align:center">2.91</td>
         | 
| 220 | 
            -
                        <td style="text-align:center">3.57</td>
         | 
| 221 | 
            -
                    </tr>
         | 
| 222 | 
            -
                    <tr>
         | 
| 223 | 
            -
                        <td>Wenetspeech test-net</td>
         | 
| 224 | 
            -
                        <td style="text-align:center">9.68</td>
         | 
| 225 | 
            -
                        <td style="text-align:center">7.72</td>
         | 
| 226 | 
            -
                        <td style="text-align:center"><strong>6.64</strong></td>
         | 
| 227 | 
            -
                        <td style="text-align:center">8.78</td>
         | 
| 228 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 229 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 230 | 
            -
                        <td style="text-align:center">146.05</td>
         | 
| 231 | 
            -
                        <td style="text-align:center">7.62</td>
         | 
| 232 | 
            -
                        <td style="text-align:center">8.75</td>
         | 
| 233 | 
            -
                    </tr>
         | 
| 234 | 
            -
                    <tr>
         | 
| 235 | 
            -
                        <td>Wenet test-meeting</td>
         | 
| 236 | 
            -
                        <td style="text-align:center">18.54</td>
         | 
| 237 | 
            -
                        <td style="text-align:center">8.4</td>
         | 
| 238 | 
            -
                        <td style="text-align:center"><strong>7.6</strong></td>
         | 
| 239 | 
            -
                        <td style="text-align:center">10.42</td>
         | 
| 240 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 241 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 242 | 
            -
                        <td style="text-align:center">140.82</td>
         | 
| 243 | 
            -
                        <td style="text-align:center">7.78</td>
         | 
| 244 | 
            -
                        <td style="text-align:center">9.52</td>
         | 
| 245 | 
            -
                    </tr>
         | 
| 246 | 
            -
                    <tr>
         | 
| 247 | 
            -
                        <td>Librispeech test-clean</td>
         | 
| 248 | 
            -
                        <td style="text-align:center">1.9</td>
         | 
| 249 | 
            -
                        <td style="text-align:center"><strong>1.6</strong></td>
         | 
| 250 | 
            -
                        <td style="text-align:center"><strong>1.6</strong></td>
         | 
| 251 | 
            -
                        <td style="text-align:center">3.36</td>
         | 
| 252 | 
            -
                        <td style="text-align:center">5.7</td>
         | 
| 253 | 
            -
                        <td style="text-align:center">2.82</td>
         | 
| 254 | 
            -
                        <td style="text-align:center">75.39</td>
         | 
| 255 | 
            -
                        <td style="text-align:center">2.36</td>
         | 
| 256 | 
            -
                        <td style="text-align:center">3.11</td>
         | 
| 257 | 
            -
                    </tr>
         | 
| 258 | 
            -
                    <tr>
         | 
| 259 | 
            -
                        <td>Librispeech test-other</td>
         | 
| 260 | 
            -
                        <td style="text-align:center">3.65</td>
         | 
| 261 | 
            -
                        <td style="text-align:center"><strong>3.6</strong></td>
         | 
| 262 | 
            -
                        <td style="text-align:center">3.82</td>
         | 
| 263 | 
            -
                        <td style="text-align:center">8.05</td>
         | 
| 264 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 265 | 
            -
                        <td style="text-align:center">7.66</td>
         | 
| 266 | 
            -
                        <td style="text-align:center">80.3</td>
         | 
| 267 | 
            -
                        <td style="text-align:center">6.32</td>
         | 
| 268 | 
            -
                        <td style="text-align:center">8.44</td>
         | 
| 269 | 
            -
                    </tr>
         | 
| 270 | 
            -
                    <tr>
         | 
| 271 | 
            -
                        <td>AVG</td>
         | 
| 272 | 
            -
                        <td style="text-align:center">7.28</td>
         | 
| 273 | 
            -
                        <td style="text-align:center"><strong>4.32</strong></td>
         | 
| 274 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 275 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 276 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 277 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 278 | 
            -
                        <td style="text-align:center">146.74</td>
         | 
| 279 | 
            -
                        <td style="text-align:center">4.64</td>
         | 
| 280 | 
            -
                        <td style="text-align:center">5.89</td>
         | 
| 281 | 
            -
                    </tr>
         | 
| 282 | 
            -
                </tbody>
         | 
| 283 | 
            -
            </table>
         | 
| 284 | 
            -
             | 
| 285 | 
            -
            ### 5.2 语音合成
         | 
| 286 | 
            -
            #### 5.2.1 与GLM-4-Voice与MinMo在内容一致性(CER/WER)上的性能对比。
         | 
| 287 | 
            -
             | 
| 288 | 
            -
            <table>
         | 
| 289 | 
            -
                <thead>
         | 
| 290 | 
            -
                    <tr>
         | 
| 291 | 
            -
                        <th rowspan="2">Model</th>
         | 
| 292 | 
            -
                        <th style="text-align:center" colspan="1">test-zh</th>
         | 
| 293 | 
            -
                        <th style="text-align:center" colspan="1">test-en</th>
         | 
| 294 | 
            -
                    </tr>
         | 
| 295 | 
            -
                    <tr>
         | 
| 296 | 
            -
                        <th style="text-align:center">CER (%) ↓</th>
         | 
| 297 | 
            -
                        <th style="text-align:center">WER (%) ↓</th>
         | 
| 298 | 
            -
                    </tr>
         | 
| 299 | 
            -
                </thead>
         | 
| 300 | 
            -
                <tbody>
         | 
| 301 | 
            -
                    <tr>
         | 
| 302 | 
            -
                        <td>GLM-4-Voice</td>
         | 
| 303 | 
            -
                        <td style="text-align:center">2.19</td>
         | 
| 304 | 
            -
                        <td style="text-align:center">2.91</td>
         | 
| 305 | 
            -
                    </tr>
         | 
| 306 | 
            -
                    <tr>
         | 
| 307 | 
            -
                        <td>MinMo</td>
         | 
| 308 | 
            -
                        <td style="text-align:center">2.48</td>
         | 
| 309 | 
            -
                        <td style="text-align:center">2.90</td>
         | 
| 310 | 
            -
                    </tr>
         | 
| 311 | 
            -
                    <tr>
         | 
| 312 | 
            -
                        <td><strong>Step-Audio</strong></td>
         | 
| 313 | 
            -
                        <td style="text-align:center"><strong>1.53</strong></td>
         | 
| 314 | 
            -
                        <td style="text-align:center"><strong>2.71</strong></td>
         | 
| 315 | 
            -
                    </tr>
         | 
| 316 | 
            -
                </tbody>
         | 
| 317 | 
            -
            </table>
         | 
| 318 | 
            -
             | 
| 319 | 
            -
            #### 5.2.2 语音合成模型在SEED测试集上的性能结果。
         | 
| 320 | 
            -
            * StepAudio-TTS-3B-Single 表示采用双码本主干网络与单码本声码器的组合架构。
         | 
| 321 | 
            -
             | 
| 322 | 
            -
            <table>
         | 
| 323 | 
            -
                <thead>
         | 
| 324 | 
            -
                    <tr>
         | 
| 325 | 
            -
                        <th rowspan="2">Model</th>
         | 
| 326 | 
            -
                        <th style="text-align:center" colspan="2">test-zh</th>
         | 
| 327 | 
            -
                        <th style="text-align:center" colspan="2">test-en</th>
         | 
| 328 | 
            -
                    </tr>
         | 
| 329 | 
            -
                    <tr>
         | 
| 330 | 
            -
                        <th style="text-align:center">CER (%) ↓</th>
         | 
| 331 | 
            -
                        <th style="text-align:center">SS ↑</th>
         | 
| 332 | 
            -
                        <th style="text-align:center">WER (%) ↓</th>
         | 
| 333 | 
            -
                        <th style="text-align:center">SS ↑</th>
         | 
| 334 | 
            -
                    </tr>
         | 
| 335 | 
            -
                </thead>
         | 
| 336 | 
            -
                <tbody>
         | 
| 337 | 
            -
                    <tr>
         | 
| 338 | 
            -
                        <td>FireRedTTS</td>
         | 
| 339 | 
            -
                        <td style="text-align:center">1.51</td>
         | 
| 340 | 
            -
                        <td style="text-align:center">0.630</td>
         | 
| 341 | 
            -
                        <td style="text-align:center">3.82</td>
         | 
| 342 | 
            -
                        <td style="text-align:center">0.460</td>
         | 
| 343 | 
            -
                    </tr>
         | 
| 344 | 
            -
                    <tr>
         | 
| 345 | 
            -
                        <td>MaskGCT</td>
         | 
| 346 | 
            -
                        <td style="text-align:center">2.27</td>
         | 
| 347 | 
            -
                        <td style="text-align:center">0.774</td>
         | 
| 348 | 
            -
                        <td style="text-align:center">2.62</td>
         | 
| 349 | 
            -
                        <td style="text-align:center">0.774</td>
         | 
| 350 | 
            -
                    </tr>
         | 
| 351 | 
            -
                    <tr>
         | 
| 352 | 
            -
                        <td>CosyVoice</td>
         | 
| 353 | 
            -
                        <td style="text-align:center">3.63</td>
         | 
| 354 | 
            -
                        <td style="text-align:center">0.775</td>
         | 
| 355 | 
            -
                        <td style="text-align:center">4.29</td>
         | 
| 356 | 
            -
                        <td style="text-align:center">0.699</td>
         | 
| 357 | 
            -
                    </tr>
         | 
| 358 | 
            -
                    <tr>
         | 
| 359 | 
            -
                        <td>CosyVoice 2</td>
         | 
| 360 | 
            -
                        <td style="text-align:center">1.45</td>
         | 
| 361 | 
            -
                        <td style="text-align:center">0.806</td>
         | 
| 362 | 
            -
                        <td style="text-align:center">2.57</td>
         | 
| 363 | 
            -
                        <td style="text-align:center">0.736</td>
         | 
| 364 | 
            -
                    </tr>
         | 
| 365 | 
            -
                    <tr>
         | 
| 366 | 
            -
                        <td>CosyVoice 2-S</td>
         | 
| 367 | 
            -
                        <td style="text-align:center">1.45</td>
         | 
| 368 | 
            -
                        <td style="text-align:center">0.812</td>
         | 
| 369 | 
            -
                        <td style="text-align:center">2.38</td>
         | 
| 370 | 
            -
                        <td style="text-align:center">0.743</td>
         | 
| 371 | 
            -
                    </tr>
         | 
| 372 | 
            -
                    <tr>
         | 
| 373 | 
            -
                        <td><strong>Step-Audio-TTS-3B-Single</strong></td>
         | 
| 374 | 
            -
                        <td style="text-align:center">1.37</td>
         | 
| 375 | 
            -
                        <td style="text-align:center">0.802</td>
         | 
| 376 | 
            -
                        <td style="text-align:center">2.52</td>
         | 
| 377 | 
            -
                        <td style="text-align:center">0.704</td>
         | 
| 378 | 
            -
                    </tr>
         | 
| 379 | 
            -
                    <tr>
         | 
| 380 | 
            -
                        <td><strong>Step-Audio-TTS-3B</strong></td>
         | 
| 381 | 
            -
                        <td style="text-align:center"><strong>1.31</strong></td>
         | 
| 382 | 
            -
                        <td style="text-align:center">0.733</td>
         | 
| 383 | 
            -
                        <td style="text-align:center"><strong>2.31</strong></td>
         | 
| 384 | 
            -
                        <td style="text-align:center">0.660</td>
         | 
| 385 | 
            -
                    </tr>
         | 
| 386 | 
            -
                    <tr>
         | 
| 387 | 
            -
                        <td><strong>Step-Audio-TTS</strong></td>
         | 
| 388 | 
            -
                        <td style="text-align:center"><strong>1.17</strong></td>
         | 
| 389 | 
            -
                        <td style="text-align:center">0.73</td>
         | 
| 390 | 
            -
                        <td style="text-align:center"><strong>2.0</strong></td>
         | 
| 391 | 
            -
                        <td style="text-align:center">0.660</td>
         | 
| 392 | 
            -
                    </tr>
         | 
| 393 | 
            -
                </tbody>
         | 
| 394 | 
            -
            </table>
         | 
| 395 | 
            -
             | 
| 396 | 
            -
            #### 5.2.3 双码本重合成与CosyVoice性能对比。
         | 
| 397 | 
            -
             | 
| 398 | 
            -
            <table>
         | 
| 399 | 
            -
                <thead>
         | 
| 400 | 
            -
                    <tr>
         | 
| 401 | 
            -
                        <th style="text-align:center" rowspan="2">Token</th>
         | 
| 402 | 
            -
                        <th style="text-align:center" colspan="2">test-zh</th>
         | 
| 403 | 
            -
                        <th style="text-align:center" colspan="2">test-en</th>
         | 
| 404 | 
            -
                    </tr>
         | 
| 405 | 
            -
                    <tr>
         | 
| 406 | 
            -
                        <th style="text-align:center">CER (%) ↓</th>
         | 
| 407 | 
            -
                        <th style="text-align:center">SS ↑</th>
         | 
| 408 | 
            -
                        <th style="text-align:center">WER (%) ↓</th>
         | 
| 409 | 
            -
                        <th style="text-align:center">SS ↑</th>
         | 
| 410 | 
            -
                    </tr>
         | 
| 411 | 
            -
                </thead>
         | 
| 412 | 
            -
                <tbody>
         | 
| 413 | 
            -
                    <tr>
         | 
| 414 | 
            -
                        <td style="text-align:center">Groundtruth</td>
         | 
| 415 | 
            -
                        <td style="text-align:center">0.972</td>
         | 
| 416 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 417 | 
            -
                        <td style="text-align:center">2.156</td>
         | 
| 418 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 419 | 
            -
                    </tr>
         | 
| 420 | 
            -
                    <tr>
         | 
| 421 | 
            -
                        <td style="text-align:center">CosyVoice</td>
         | 
| 422 | 
            -
                        <td style="text-align:center">2.857</td>
         | 
| 423 | 
            -
                        <td style="text-align:center"><strong>0.849</strong></td>
         | 
| 424 | 
            -
                        <td style="text-align:center">4.519</td>
         | 
| 425 | 
            -
                        <td style="text-align:center"><strong>0.807</strong></td>
         | 
| 426 | 
            -
                    </tr>
         | 
| 427 | 
            -
                    <tr>
         | 
| 428 | 
            -
                        <td style="text-align:center">Step-Audio-TTS-3B</td>
         | 
| 429 | 
            -
                        <td style="text-align:center"><strong>2.192</strong></td>
         | 
| 430 | 
            -
                        <td style="text-align:center">0.784</td>
         | 
| 431 | 
            -
                        <td style="text-align:center"><strong>3.585</strong></td>
         | 
| 432 | 
            -
                        <td style="text-align:center">0.742</td>
         | 
| 433 | 
            -
                    </tr>
         | 
| 434 | 
            -
                </tbody>
         | 
| 435 | 
            -
            </table>
         | 
| 436 | 
            -
             | 
| 437 | 
            -
            ### 5.3 语音对话
         | 
| 438 | 
            -
            我们发布全新基准测试[StepEval-Audio-360](https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360),该数据集包含100个源自真实用户的多轮中文提示,旨在系统性评估生成式语音交互系统在以下维度的表现:语音指令遵循、语音理解、逻辑推理、角色扮演、创作能力、唱歌、语言能力、语音情绪控制、游戏。
         | 
| 439 | 
            -
            #### 5.3.1 StepEval-Audio-360
         | 
| 440 | 
            -
             | 
| 441 | 
            -
            #### 大语言模型评估指标(GPT-4o)
         | 
| 442 | 
            -
            <table>
         | 
| 443 | 
            -
                <caption>Comparison of fundamental capabilities of voice chat on the StepEval-Audio-360.</caption>
         | 
| 444 | 
            -
                <thead>
         | 
| 445 | 
            -
                    <tr>
         | 
| 446 | 
            -
                        <th>Model</th>
         | 
| 447 | 
            -
                        <th style="text-align:center">Factuality (% ↑)</th>
         | 
| 448 | 
            -
                        <th style="text-align:center">Relevance (% ↑)</th>
         | 
| 449 | 
            -
                        <th style="text-align:center">Chat Score ↑</th>
         | 
| 450 | 
            -
                    </tr>
         | 
| 451 | 
            -
                </thead>
         | 
| 452 | 
            -
                <tbody>
         | 
| 453 | 
            -
                    <tr>
         | 
| 454 | 
            -
                        <td>GLM4-Voice</td>
         | 
| 455 | 
            -
                        <td style="text-align:center">54.7</td>
         | 
| 456 | 
            -
                        <td style="text-align:center">66.4</td>
         | 
| 457 | 
            -
                        <td style="text-align:center">3.49</td>
         | 
| 458 | 
            -
                    </tr>
         | 
| 459 | 
            -
                    <tr>
         | 
| 460 | 
            -
                        <td>Qwen2-Audio</td>
         | 
| 461 | 
            -
                        <td style="text-align:center">22.6</td>
         | 
| 462 | 
            -
                        <td style="text-align:center">26.3</td>
         | 
| 463 | 
            -
                        <td style="text-align:center">2.27</td>
         | 
| 464 | 
            -
                    </tr>
         | 
| 465 | 
            -
                    <tr>
         | 
| 466 | 
            -
                        <td>Moshi<sup>*</sup></td>
         | 
| 467 | 
            -
                        <td style="text-align:center">1.0</td>
         | 
| 468 | 
            -
                        <td style="text-align:center">0</td>
         | 
| 469 | 
            -
                        <td style="text-align:center">1.49</td>
         | 
| 470 | 
            -
                    </tr>
         | 
| 471 | 
            -
                    <tr>
         | 
| 472 | 
            -
                        <td><strong>Step-Audio-Chat</strong></td>
         | 
| 473 | 
            -
                        <td style="text-align:center"><strong>66.4</strong></td>
         | 
| 474 | 
            -
                        <td style="text-align:center"><strong>75.2</strong></td>
         | 
| 475 | 
            -
                        <td style="text-align:center"><strong>4.11</strong></td>
         | 
| 476 | 
            -
                    </tr>
         | 
| 477 | 
            -
                </tbody>
         | 
| 478 | 
            -
            </table>
         | 
| 479 | 
            -
             | 
| 480 | 
            -
            *注意:带有“\*”标记的内容仅供参考。
         | 
| 481 | 
            -
             | 
| 482 | 
            -
            #### 雷达图(人工测评)
         | 
| 483 | 
            -
            <img src="./assets/stepeval_radar_chart.png" width="600" alt="QR code">
         | 
| 484 | 
            -
             | 
| 485 | 
            -
            #### 5.3.2 公开测试集
         | 
| 486 | 
            -
             | 
| 487 | 
            -
            <table>
         | 
| 488 | 
            -
                <thead>
         | 
| 489 | 
            -
                    <tr>
         | 
| 490 | 
            -
                        <th>Model</th>
         | 
| 491 | 
            -
                        <th style="text-align:center">Llama Question</th>
         | 
| 492 | 
            -
                        <th style="text-align:center">Web Questions</th>
         | 
| 493 | 
            -
                        <th style="text-align:center">TriviaQA*</th>
         | 
| 494 | 
            -
                        <th style="text-align:center">ComplexBench</th>
         | 
| 495 | 
            -
                        <th style="text-align:center">HSK-6</th>
         | 
| 496 | 
            -
                    </tr>
         | 
| 497 | 
            -
                </thead>
         | 
| 498 | 
            -
                <tbody>
         | 
| 499 | 
            -
                    <tr>
         | 
| 500 | 
            -
                        <td>GLM4-Voice</td>
         | 
| 501 | 
            -
                        <td style="text-align:center">64.7</td>
         | 
| 502 | 
            -
                        <td style="text-align:center">32.2</td>
         | 
| 503 | 
            -
                        <td style="text-align:center">39.1</td>
         | 
| 504 | 
            -
                        <td style="text-align:center">66.0</td>
         | 
| 505 | 
            -
                        <td style="text-align:center">74.0</td>
         | 
| 506 | 
            -
                    </tr>
         | 
| 507 | 
            -
                    <tr>
         | 
| 508 | 
            -
                        <td>Moshi</td>
         | 
| 509 | 
            -
                        <td style="text-align:center">62.3</td>
         | 
| 510 | 
            -
                        <td style="text-align:center">26.6</td>
         | 
| 511 | 
            -
                        <td style="text-align:center">22.8</td>
         | 
| 512 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 513 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 514 | 
            -
                    </tr>
         | 
| 515 | 
            -
                    <tr>
         | 
| 516 | 
            -
                        <td>Freeze-Omni</td>
         | 
| 517 | 
            -
                        <td style="text-align:center">72.0</td>
         | 
| 518 | 
            -
                        <td style="text-align:center">44.7</td>
         | 
| 519 | 
            -
                        <td style="text-align:center">53.9</td>
         | 
| 520 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 521 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 522 | 
            -
                    </tr>
         | 
| 523 | 
            -
                    <tr>
         | 
| 524 | 
            -
                        <td>LUCY</td>
         | 
| 525 | 
            -
                        <td style="text-align:center">59.7</td>
         | 
| 526 | 
            -
                        <td style="text-align:center">29.3</td>
         | 
| 527 | 
            -
                        <td style="text-align:center">27.0</td>
         | 
| 528 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 529 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 530 | 
            -
                    </tr>
         | 
| 531 | 
            -
                    <tr>
         | 
| 532 | 
            -
                        <td>MinMo</td>
         | 
| 533 | 
            -
                        <td style="text-align:center">78.9</td>
         | 
| 534 | 
            -
                        <td style="text-align:center">55.0</td>
         | 
| 535 | 
            -
                        <td style="text-align:center">48.3</td>
         | 
| 536 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 537 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 538 | 
            -
                    </tr>
         | 
| 539 | 
            -
                    <tr>
         | 
| 540 | 
            -
                        <td>Qwen2-Audio</td>
         | 
| 541 | 
            -
                        <td style="text-align:center">52.0</td>
         | 
| 542 | 
            -
                        <td style="text-align:center">27.0</td>
         | 
| 543 | 
            -
                        <td style="text-align:center">37.3</td>
         | 
| 544 | 
            -
                        <td style="text-align:center">54.0</td>
         | 
| 545 | 
            -
                        <td style="text-align:center">-</td>
         | 
| 546 | 
            -
                    </tr>
         | 
| 547 | 
            -
                    <tr>
         | 
| 548 | 
            -
                        <td><strong>Step-Audio-Chat</strong></td>
         | 
| 549 | 
            -
                        <td style="text-align:center"><strong><i>81.0</i></strong></td>
         | 
| 550 | 
            -
                        <td style="text-align:center"><strong>75.1</strong></td>
         | 
| 551 | 
            -
                        <td style="text-align:center"><strong>58.0</strong></td>
         | 
| 552 | 
            -
                        <td style="text-align:center"><strong>74.0</strong></td>
         | 
| 553 | 
            -
                        <td style="text-align:center"><strong>86.0</strong></td>
         | 
| 554 | 
            -
                    </tr>
         | 
| 555 | 
            -
                </tbody>
         | 
| 556 | 
            -
            </table>
         | 
| 557 | 
            -
             | 
| 558 | 
            -
            * 注意:在 TriviaQA 数据集上,带有“\*”标记的结果仅供参考。
         | 
| 559 | 
            -
             | 
| 560 | 
            -
            * 在 TriviaQA 数据集中,带有“\*”标记的结果仅用于参考。
         | 
| 561 | 
            -
             | 
| 562 | 
            -
            #### 5.3.3 语音指令遵循
         | 
| 563 | 
            -
            <table>
         | 
| 564 | 
            -
                <thead>
         | 
| 565 | 
            -
                    <tr>
         | 
| 566 | 
            -
                        <th rowspan="2">Category</th>
         | 
| 567 | 
            -
                        <th colspan="2" style="text-align:center">Instruction Following</th>
         | 
| 568 | 
            -
                        <th colspan="2" style="text-align:center">Audio Quality</th>
         | 
| 569 | 
            -
                    </tr>
         | 
| 570 | 
            -
                    <tr>
         | 
| 571 | 
            -
                        <th style="text-align:center">GLM-4-Voice</th>
         | 
| 572 | 
            -
                        <th style="text-align:center">Step-Audio</th>
         | 
| 573 | 
            -
                        <th style="text-align:center">GLM-4-Voice</th>
         | 
| 574 | 
            -
                        <th style="text-align:center">Step-Audio</th>
         | 
| 575 | 
            -
                    </tr>
         | 
| 576 | 
            -
                </thead>
         | 
| 577 | 
            -
                <tbody>
         | 
| 578 | 
            -
                    <tr>
         | 
| 579 | 
            -
                        <td>Languages</td>
         | 
| 580 | 
            -
                        <td style="text-align:center">1.9</td>
         | 
| 581 | 
            -
                        <td style="text-align:center">3.8</td>
         | 
| 582 | 
            -
                        <td style="text-align:center">2.9</td>
         | 
| 583 | 
            -
                        <td style="text-align:center">3.3</td>
         | 
| 584 | 
            -
                    </tr>
         | 
| 585 | 
            -
                    <tr>
         | 
| 586 | 
            -
                        <td>Role-playing</td>
         | 
| 587 | 
            -
                        <td style="text-align:center">3.8</td>
         | 
| 588 | 
            -
                        <td style="text-align:center">4.2</td>
         | 
| 589 | 
            -
                        <td style="text-align:center">3.2</td>
         | 
| 590 | 
            -
                        <td style="text-align:center">3.6</td>
         | 
| 591 | 
            -
                    </tr>
         | 
| 592 | 
            -
                    <tr>
         | 
| 593 | 
            -
                        <td>Singing / RAP</td>
         | 
| 594 | 
            -
                        <td style="text-align:center">2.1</td>
         | 
| 595 | 
            -
                        <td style="text-align:center">2.4</td>
         | 
| 596 | 
            -
                        <td style="text-align:center">2.4</td>
         | 
| 597 | 
            -
                        <td style="text-align:center">4</td>
         | 
| 598 | 
            -
                    </tr>
         | 
| 599 | 
            -
                    <tr>
         | 
| 600 | 
            -
                        <td>Voice Control</td>
         | 
| 601 | 
            -
                        <td style="text-align:center">3.6</td>
         | 
| 602 | 
            -
                        <td style="text-align:center">4.4</td>
         | 
| 603 | 
            -
                        <td style="text-align:center">3.3</td>
         | 
| 604 | 
            -
                        <td style="text-align:center">4.1</td>
         | 
| 605 | 
            -
                    </tr>
         | 
| 606 | 
            -
                </tbody>
         | 
| 607 | 
            -
            </table>
         | 
| 608 | 
            -
             | 
| 609 | 
            -
            ## 6. 在线引擎
         | 
| 610 | 
            -
            Step-Audio 的在线版本可以通过[跃问](https://yuewen.cn) 的应用程序访问,其中还可以找到一些惊喜的示例。
         | 
| 611 | 
            -
             | 
| 612 | 
            -
            <img src="./assets/yuewen.jpeg" width="200" alt="QR code">
         | 
| 613 | 
            -
             | 
| 614 | 
            -
            ## 7. 样例
         | 
| 615 | 
            -
            ### 音频克隆
         | 
| 616 | 
            -
            | role   | prompt wav | clone wav |
         | 
| 617 | 
            -
            |:-------:|:-------:|:-------:|
         | 
| 618 | 
            -
            |于谦| [google drive](https://drive.google.com/file/d/1N9EJypafFwmeL0R152GoL_CVGbYn1_9A/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/prompt_wav_yuqian.wav)|[google drive](https://drive.google.com/file/d/1Zs_1QrCUuoSqtUSdn2ENIor-k5baQdDV/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/clone_wav_yuqian.wav)|
         | 
| 619 | 
            -
            |李雪琴| [google drive](https://drive.google.com/file/d/15SkZ29hksELYi1NDOxYOPu-kRTLSyke_/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/prompt_wav_lixueqin.wav)|[google drive](https://drive.google.com/file/d/11Le4qMqL2DmWpf7RFRpKUXERIR9TtKC0/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/clone_wav_lixueqin.wav)|
         | 
| 620 | 
            -
             | 
| 621 | 
            -
            ### 速度控制
         | 
| 622 | 
            -
            | prompt | response |
         | 
| 623 | 
            -
            |:-------:|:-------:|
         | 
| 624 | 
            -
            |Human: 说一个绕口令<br>Assistant: 吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮<br>Human: 哎,你能把这个绕口令说的再快一点吗?|[google drive](https://drive.google.com/file/d/1mAH-NRrOVZo4tv6gdAZkyJg8kRuTNNGC/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/speed_control1.wav)|
         | 
| 625 | 
            -
            |Human: 说一个绕口令<br>Assistant: 吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮<br>Human: 哎,你能把这个绕口令说的再快一点吗?<br>Assistant: 吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮<br>Human: 呃,你再用非常非常慢的速度说一遍的。|[google drive](https://drive.google.com/file/d/1FhRnKo8uGrtO-cWg4qkrg8iDoNRbtqSX/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/speed_control2.wav)|
         | 
| 626 | 
            -
             | 
| 627 | 
            -
            ### 高情商(情感控制 & 语调控制)
         | 
| 628 | 
            -
            | prompt | response |
         | 
| 629 | 
            -
            |:-------:|:-------:|
         | 
| 630 | 
            -
            |Human: 你这语气又不撒娇又不卖萌的,要不你撒个娇卖个萌吧。|[google drive](https://drive.google.com/file/d/19IROE6_6h2UQVNniCmDTnrhxKRMOFHq3/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/tone_control.wav)|
         | 
| 631 | 
            -
            |Human: 怎么办?我感觉我的人生很失败。|[google drive](https://drive.google.com/file/d/1JlLbOlzmdrokVdxtwy1S8eeWqsZR2Vmc/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/emotional_control1.wav)|
         | 
| 632 | 
            -
            |Human: 小跃。你真的是。特别厉害。|[google drive](https://drive.google.com/file/d/19ga1RpguDP5r0Xfl1r5GY1J-kzbmHvJb/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/emotional_control2.wav)|
         | 
| 633 | 
            -
             | 
| 634 | 
            -
             | 
| 635 | 
            -
            ### 多语言 (e.g., 中文, 英文, 日语)
         | 
| 636 | 
            -
            | prompt | response |
         | 
| 637 | 
            -
            |:-------:|:-------:|
         | 
| 638 | 
            -
            |Human: What did the speaker mean when they said, it's raining cats and dogs?<br>Assistant: When they say "It's raining cats and dogs," it just means it's raining really hard. The speaker isn't literally saying cats and dogs are falling from the sky! It's just a fun way to describe heavy rain.|[google drive](https://drive.google.com/file/d/1LEIvdR5ANMzWX8GOTqUPTNrynNS1xx--/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/multilingual2.wav)|
         | 
| 639 | 
            -
            |Human: こんにちは。(你好)<br>Assistant:こんにちは!何か手伝いましょうか?(您好!我可以帮你做点什么吗?)|[google drive](https://drive.google.com/file/d/1MjKUkkzcGzVcNVXRr_Ya5y2H44K_lybH/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/multilingual1.wav)|
         | 
| 640 | 
            -
             | 
| 641 | 
            -
            ### Rap & Vocal
         | 
| 642 | 
            -
            | prompt | response |
         | 
| 643 | 
            -
            |:-------:|:-------:|
         | 
| 644 | 
            -
            |human:唱一段rap|[google drive](https://drive.google.com/file/d/1F8CKmVbGZ7X7d1IkQPlmndSHeG40AXha/preview)<br>[audio file](https://github.com/stepfun-ai/Step-Audio/tree/main/examples/rap.wav)|
         | 
| 645 | 
            -
             | 
| 646 | 
            -
            ## 8. 引用
         | 
| 647 | 
            -
            [论文](https://github.com/stepfun-ai/Step-Audio/blob/cn-readme/assets/Step-Audio.pdf)已提交至arXiv,目前正在审核中。在审核完成后会提供官方预印本链接和论文引用。
         | 
| 648 | 
            -
            ```
         | 
| 649 | 
            -
            @misc{stepaudiotechnicalreport,
         | 
| 650 | 
            -
                  title={Step-Audio: Unified Understanding and Generation in Intelligent Speech Interaction},
         | 
| 651 | 
            -
                  author={Step-Audio Team},
         | 
| 652 | 
            -
                  year={2025},
         | 
| 653 | 
            -
            }
         | 
| 654 | 
            -
            ```
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        __init__.py
    DELETED
    
    | 
            File without changes
         | 
    	
        app.py
    DELETED
    
    | @@ -1,124 +0,0 @@ | |
| 1 | 
            -
            import gradio as gr
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import torchaudio
         | 
| 4 | 
            -
            from huggingface_hub import snapshot_download
         | 
| 5 | 
            -
            from tts import StepAudioTTS
         | 
| 6 | 
            -
            from tokenizer import StepAudioTokenizer
         | 
| 7 | 
            -
            import os
         | 
| 8 | 
            -
            import tempfile
         | 
| 9 | 
            -
            import spaces
         | 
| 10 | 
            -
             | 
| 11 | 
            -
            class StepAudioDemo:
         | 
| 12 | 
            -
                def __init__(self):
         | 
| 13 | 
            -
                    # Download models from HuggingFace
         | 
| 14 | 
            -
                    print("Downloading models from HuggingFace...")
         | 
| 15 | 
            -
                    self.model_path = snapshot_download(repo_id="stepfun-ai/Step-Audio-TTS-3B")
         | 
| 16 | 
            -
                    self.tokenizer_path = snapshot_download(repo_id="stepfun-ai/Step-Audio-Tokenizer")
         | 
| 17 | 
            -
                    
         | 
| 18 | 
            -
                    # Initialize models
         | 
| 19 | 
            -
                    print("Initializing models...")
         | 
| 20 | 
            -
                    self.encoder = StepAudioTokenizer(self.tokenizer_path)
         | 
| 21 | 
            -
                    self.tts_engine = StepAudioTTS(self.model_path, self.encoder)
         | 
| 22 | 
            -
                    
         | 
| 23 | 
            -
                    # Create temporary directory for outputs
         | 
| 24 | 
            -
                    self.temp_dir = tempfile.mkdtemp()
         | 
| 25 | 
            -
                    print("Models loaded and ready!")
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                @spaces.GPU
         | 
| 28 | 
            -
                def generate_tts(self, text, speaker_name):
         | 
| 29 | 
            -
                    """Generate TTS audio"""
         | 
| 30 | 
            -
                    try:
         | 
| 31 | 
            -
                        output_audio, sr = self.tts_engine(text, speaker_name)
         | 
| 32 | 
            -
                        output_path = os.path.join(self.temp_dir, "output_tts.wav")
         | 
| 33 | 
            -
                        torchaudio.save(output_path, output_audio, sr)
         | 
| 34 | 
            -
                        return output_path
         | 
| 35 | 
            -
                    except Exception as e:
         | 
| 36 | 
            -
                        return f"Error generating audio: {str(e)}"
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                @spaces.GPU
         | 
| 39 | 
            -
                def generate_clone(self, text, prompt_audio, prompt_text):
         | 
| 40 | 
            -
                    """Generate cloned voice audio"""
         | 
| 41 | 
            -
                    try:
         | 
| 42 | 
            -
                        clone_speaker = {
         | 
| 43 | 
            -
                            "speaker": "clone",
         | 
| 44 | 
            -
                            "prompt_text": prompt_text,
         | 
| 45 | 
            -
                            "wav_path": prompt_audio
         | 
| 46 | 
            -
                        }
         | 
| 47 | 
            -
                        output_audio, sr = self.tts_engine(text, "", clone_speaker)
         | 
| 48 | 
            -
                        output_path = os.path.join(self.temp_dir, "output_clone.wav")
         | 
| 49 | 
            -
                        torchaudio.save(output_path, output_audio, sr)
         | 
| 50 | 
            -
                        return output_path
         | 
| 51 | 
            -
                    except Exception as e:
         | 
| 52 | 
            -
                        return f"Error generating cloned audio: {str(e)}"
         | 
| 53 | 
            -
             | 
| 54 | 
            -
            def create_demo():
         | 
| 55 | 
            -
                demo = StepAudioDemo()
         | 
| 56 | 
            -
                
         | 
| 57 | 
            -
                with gr.Blocks() as interface:
         | 
| 58 | 
            -
                    gr.Markdown("# Step Audio TTS Demo")
         | 
| 59 | 
            -
                    
         | 
| 60 | 
            -
                    with gr.Tabs():
         | 
| 61 | 
            -
                        # TTS Tab
         | 
| 62 | 
            -
                        with gr.TabItem("Text-to-Speech"):
         | 
| 63 | 
            -
                            with gr.Row():
         | 
| 64 | 
            -
                                with gr.Column():
         | 
| 65 | 
            -
                                    tts_text = gr.Textbox(
         | 
| 66 | 
            -
                                        label="Input Text",
         | 
| 67 | 
            -
                                        placeholder="Enter text to synthesize...",
         | 
| 68 | 
            -
                                        lines=5
         | 
| 69 | 
            -
                                    )
         | 
| 70 | 
            -
                                    speaker_name = gr.Textbox(
         | 
| 71 | 
            -
                                        label="Speaker Name",
         | 
| 72 | 
            -
                                        placeholder="Enter speaker name (e.g., 闫雨婷)",
         | 
| 73 | 
            -
                                        value="闫雨婷"
         | 
| 74 | 
            -
                                    )
         | 
| 75 | 
            -
                                    tts_button = gr.Button("Generate Speech")
         | 
| 76 | 
            -
                                with gr.Column():
         | 
| 77 | 
            -
                                    tts_output = gr.Audio(label="Generated Audio")
         | 
| 78 | 
            -
                            
         | 
| 79 | 
            -
                            tts_button.click(
         | 
| 80 | 
            -
                                fn=demo.generate_tts,
         | 
| 81 | 
            -
                                inputs=[tts_text, speaker_name],
         | 
| 82 | 
            -
                                outputs=tts_output
         | 
| 83 | 
            -
                            )
         | 
| 84 | 
            -
                        
         | 
| 85 | 
            -
                        # Voice Cloning Tab
         | 
| 86 | 
            -
                        with gr.TabItem("Voice Cloning"):
         | 
| 87 | 
            -
                            with gr.Row():
         | 
| 88 | 
            -
                                with gr.Column():
         | 
| 89 | 
            -
                                    clone_text = gr.Textbox(
         | 
| 90 | 
            -
                                        label="Input Text",
         | 
| 91 | 
            -
                                        placeholder="Enter text to synthesize with cloned voice...",
         | 
| 92 | 
            -
                                        lines=5
         | 
| 93 | 
            -
                                    )
         | 
| 94 | 
            -
                                    prompt_text = gr.Textbox(
         | 
| 95 | 
            -
                                        label="Prompt Text",
         | 
| 96 | 
            -
                                        placeholder="Enter the transcript of your prompt audio...",
         | 
| 97 | 
            -
                                        lines=3
         | 
| 98 | 
            -
                                    )
         | 
| 99 | 
            -
                                    prompt_audio = gr.Audio(
         | 
| 100 | 
            -
                                        label="Upload Prompt Audio",
         | 
| 101 | 
            -
                                        type="filepath"
         | 
| 102 | 
            -
                                    )
         | 
| 103 | 
            -
                                    clone_button = gr.Button("Generate Cloned Speech")
         | 
| 104 | 
            -
                                with gr.Column():
         | 
| 105 | 
            -
                                    clone_output = gr.Audio(label="Generated Audio")
         | 
| 106 | 
            -
                            
         | 
| 107 | 
            -
                            clone_button.click(
         | 
| 108 | 
            -
                                fn=demo.generate_clone,
         | 
| 109 | 
            -
                                inputs=[clone_text, prompt_audio, prompt_text],
         | 
| 110 | 
            -
                                outputs=clone_output
         | 
| 111 | 
            -
                            )
         | 
| 112 | 
            -
             | 
| 113 | 
            -
                    gr.Markdown("""
         | 
| 114 | 
            -
                    ## Usage Notes:
         | 
| 115 | 
            -
                    - For basic TTS: Enter text and speaker name in the Text-to-Speech tab
         | 
| 116 | 
            -
                    - For voice cloning: Upload a prompt audio file, enter its transcript, and the text you want to synthesize
         | 
| 117 | 
            -
                    - Generation may take a few moments depending on text length
         | 
| 118 | 
            -
                    """)
         | 
| 119 | 
            -
             | 
| 120 | 
            -
                return interface
         | 
| 121 | 
            -
             | 
| 122 | 
            -
            if __name__ == "__main__":
         | 
| 123 | 
            -
                demo = create_demo()
         | 
| 124 | 
            -
                demo.queue().launch()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        assets/Step-Audio.pdf
    DELETED
    
    | @@ -1,3 +0,0 @@ | |
| 1 | 
            -
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            -
            oid sha256:06212c657ee762261f6593fec650911051c66220839b5e50bfdf82356998679e
         | 
| 3 | 
            -
            size 7031189
         | 
|  | |
|  | |
|  | |
|  | 
    	
        assets/architecture.png
    DELETED
    
    | Git LFS Details
 | 
    	
        assets/logo.png
    DELETED
    
    | Binary file (6.87 kB) | 
|  | 
    	
        assets/pipeline.png
    DELETED
    
    | Git LFS Details
 | 
    	
        assets/rlhf.png
    DELETED
    
    | Binary file (70 kB) | 
|  | 
    	
        assets/stepeval_radar_chart.png
    DELETED
    
    | Git LFS Details
 | 
    	
        assets/yuewen.jpeg
    DELETED
    
    | Binary file (57.7 kB) | 
|  | 
    	
        cosyvoice/__init__.py
    DELETED
    
    | 
            File without changes
         | 
    	
        cosyvoice/cli/__init__.py
    DELETED
    
    | 
            File without changes
         | 
    	
        cosyvoice/cli/cosyvoice.py
    DELETED
    
    | @@ -1,68 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            import os
         | 
| 15 | 
            -
            import uuid
         | 
| 16 | 
            -
            import time
         | 
| 17 | 
            -
            from tqdm import tqdm
         | 
| 18 | 
            -
            import torch
         | 
| 19 | 
            -
            import torchaudio
         | 
| 20 | 
            -
            from hyperpyyaml import load_hyperpyyaml
         | 
| 21 | 
            -
            from cosyvoice.cli.frontend import CosyVoiceFrontEnd
         | 
| 22 | 
            -
            from cosyvoice.cli.model import CosyVoiceModel
         | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
            class CosyVoice:
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                def __init__(
         | 
| 28 | 
            -
                    self,
         | 
| 29 | 
            -
                    model_dir,
         | 
| 30 | 
            -
                ):
         | 
| 31 | 
            -
                    self.model_dir = model_dir
         | 
| 32 | 
            -
                    with open("{}/cosyvoice.yaml".format(model_dir), "r") as f:
         | 
| 33 | 
            -
                        configs = load_hyperpyyaml(f)
         | 
| 34 | 
            -
                    self.frontend = CosyVoiceFrontEnd(
         | 
| 35 | 
            -
                        configs["feat_extractor"],
         | 
| 36 | 
            -
                        "{}/campplus.onnx".format(model_dir),
         | 
| 37 | 
            -
                        "{}/speech_tokenizer_v1.onnx".format(model_dir),
         | 
| 38 | 
            -
                    )
         | 
| 39 | 
            -
                    self.model = CosyVoiceModel(configs["flow"], configs["hift"])
         | 
| 40 | 
            -
                    self.model.load(
         | 
| 41 | 
            -
                        "{}/flow.pt".format(model_dir),
         | 
| 42 | 
            -
                        "{}/hift.pt".format(model_dir),
         | 
| 43 | 
            -
                    )
         | 
| 44 | 
            -
                    self.model.flow = self.model.flow.to(torch.bfloat16)
         | 
| 45 | 
            -
                    del configs
         | 
| 46 | 
            -
             | 
| 47 | 
            -
                def token_to_wav_offline(
         | 
| 48 | 
            -
                    self,
         | 
| 49 | 
            -
                    speech_token,
         | 
| 50 | 
            -
                    speech_feat,
         | 
| 51 | 
            -
                    speech_feat_len,
         | 
| 52 | 
            -
                    prompt_token,
         | 
| 53 | 
            -
                    prompt_token_len,
         | 
| 54 | 
            -
                    embedding,
         | 
| 55 | 
            -
                ):
         | 
| 56 | 
            -
                    tts_mel = self.model.flow.inference(
         | 
| 57 | 
            -
                        token=speech_token.to(self.model.device),
         | 
| 58 | 
            -
                        token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(
         | 
| 59 | 
            -
                            self.model.device
         | 
| 60 | 
            -
                        ),
         | 
| 61 | 
            -
                        prompt_token=prompt_token.to(self.model.device),
         | 
| 62 | 
            -
                        prompt_token_len=prompt_token_len.to(self.model.device),
         | 
| 63 | 
            -
                        prompt_feat=speech_feat.to(self.model.device),
         | 
| 64 | 
            -
                        prompt_feat_len=speech_feat_len.to(self.model.device),
         | 
| 65 | 
            -
                        embedding=embedding.to(self.model.device),
         | 
| 66 | 
            -
                    )
         | 
| 67 | 
            -
                    tts_speech = self.model.hift.inference(mel=tts_mel.float())[0].cpu()
         | 
| 68 | 
            -
                    return tts_speech
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/cli/frontend.py
    DELETED
    
    | @@ -1,106 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            import onnxruntime
         | 
| 15 | 
            -
            import torch
         | 
| 16 | 
            -
            import numpy as np
         | 
| 17 | 
            -
            import whisper
         | 
| 18 | 
            -
            from typing import Callable
         | 
| 19 | 
            -
            import torchaudio.compliance.kaldi as kaldi
         | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
            class CosyVoiceFrontEnd:
         | 
| 23 | 
            -
             | 
| 24 | 
            -
                def __init__(
         | 
| 25 | 
            -
                    self,
         | 
| 26 | 
            -
                    feat_extractor: Callable,
         | 
| 27 | 
            -
                    campplus_model: str,
         | 
| 28 | 
            -
                    speech_tokenizer_model: str,
         | 
| 29 | 
            -
                ):
         | 
| 30 | 
            -
                    self.feat_extractor = feat_extractor
         | 
| 31 | 
            -
                    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 32 | 
            -
                    option = onnxruntime.SessionOptions()
         | 
| 33 | 
            -
                    option.graph_optimization_level = (
         | 
| 34 | 
            -
                        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         | 
| 35 | 
            -
                    )
         | 
| 36 | 
            -
                    option.intra_op_num_threads = 1
         | 
| 37 | 
            -
                    self.campplus_session = onnxruntime.InferenceSession(
         | 
| 38 | 
            -
                        campplus_model, sess_options=option, providers=["CPUExecutionProvider"]
         | 
| 39 | 
            -
                    )
         | 
| 40 | 
            -
                    self.speech_tokenizer_session = onnxruntime.InferenceSession(
         | 
| 41 | 
            -
                        speech_tokenizer_model,
         | 
| 42 | 
            -
                        sess_options=option,
         | 
| 43 | 
            -
                        providers=[
         | 
| 44 | 
            -
                            (
         | 
| 45 | 
            -
                                "CUDAExecutionProvider"
         | 
| 46 | 
            -
                                if torch.cuda.is_available()
         | 
| 47 | 
            -
                                else "CPUExecutionProvider"
         | 
| 48 | 
            -
                            )
         | 
| 49 | 
            -
                        ],
         | 
| 50 | 
            -
                    )
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                def _extract_speech_token(self, speech):
         | 
| 53 | 
            -
                    assert (
         | 
| 54 | 
            -
                        speech.shape[1] / 16000 <= 30
         | 
| 55 | 
            -
                    ), "do not support extract speech token for audio longer than 30s"
         | 
| 56 | 
            -
                    feat = whisper.log_mel_spectrogram(speech, n_mels=128)
         | 
| 57 | 
            -
                    speech_token = (
         | 
| 58 | 
            -
                        self.speech_tokenizer_session.run(
         | 
| 59 | 
            -
                            None,
         | 
| 60 | 
            -
                            {
         | 
| 61 | 
            -
                                self.speech_tokenizer_session.get_inputs()[0]
         | 
| 62 | 
            -
                                .name: feat.detach()
         | 
| 63 | 
            -
                                .cpu()
         | 
| 64 | 
            -
                                .numpy(),
         | 
| 65 | 
            -
                                self.speech_tokenizer_session.get_inputs()[1].name: np.array(
         | 
| 66 | 
            -
                                    [feat.shape[2]], dtype=np.int32
         | 
| 67 | 
            -
                                ),
         | 
| 68 | 
            -
                            },
         | 
| 69 | 
            -
                        )[0]
         | 
| 70 | 
            -
                        .flatten()
         | 
| 71 | 
            -
                        .tolist()
         | 
| 72 | 
            -
                    )
         | 
| 73 | 
            -
                    speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
         | 
| 74 | 
            -
                    speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(
         | 
| 75 | 
            -
                        self.device
         | 
| 76 | 
            -
                    )
         | 
| 77 | 
            -
                    return speech_token, speech_token_len
         | 
| 78 | 
            -
             | 
| 79 | 
            -
                def _extract_spk_embedding(self, speech):
         | 
| 80 | 
            -
                    feat = kaldi.fbank(speech, num_mel_bins=80, dither=0, sample_frequency=16000)
         | 
| 81 | 
            -
                    feat = feat - feat.mean(dim=0, keepdim=True)
         | 
| 82 | 
            -
                    embedding = (
         | 
| 83 | 
            -
                        self.campplus_session.run(
         | 
| 84 | 
            -
                            None,
         | 
| 85 | 
            -
                            {
         | 
| 86 | 
            -
                                self.campplus_session.get_inputs()[0]
         | 
| 87 | 
            -
                                .name: feat.unsqueeze(dim=0)
         | 
| 88 | 
            -
                                .cpu()
         | 
| 89 | 
            -
                                .numpy()
         | 
| 90 | 
            -
                            },
         | 
| 91 | 
            -
                        )[0]
         | 
| 92 | 
            -
                        .flatten()
         | 
| 93 | 
            -
                        .tolist()
         | 
| 94 | 
            -
                    )
         | 
| 95 | 
            -
                    embedding = torch.tensor([embedding]).to(self.device)
         | 
| 96 | 
            -
                    return embedding
         | 
| 97 | 
            -
             | 
| 98 | 
            -
                def _extract_speech_feat(self, speech):
         | 
| 99 | 
            -
                    speech_feat = (
         | 
| 100 | 
            -
                        self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
         | 
| 101 | 
            -
                    )
         | 
| 102 | 
            -
                    speech_feat = speech_feat.unsqueeze(dim=0)
         | 
| 103 | 
            -
                    speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(
         | 
| 104 | 
            -
                        self.device
         | 
| 105 | 
            -
                    )
         | 
| 106 | 
            -
                    return speech_feat, speech_feat_len
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/cli/model.py
    DELETED
    
    | @@ -1,32 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            import torch
         | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
            class CosyVoiceModel:
         | 
| 18 | 
            -
             | 
| 19 | 
            -
                def __init__(
         | 
| 20 | 
            -
                    self,
         | 
| 21 | 
            -
                    flow: torch.nn.Module,
         | 
| 22 | 
            -
                    hift: torch.nn.Module,
         | 
| 23 | 
            -
                ):
         | 
| 24 | 
            -
                    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 25 | 
            -
                    self.flow = flow
         | 
| 26 | 
            -
                    self.hift = hift
         | 
| 27 | 
            -
             | 
| 28 | 
            -
                def load(self, flow_model, hift_model):
         | 
| 29 | 
            -
                    self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         | 
| 30 | 
            -
                    self.flow.to(self.device).eval()
         | 
| 31 | 
            -
                    self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         | 
| 32 | 
            -
                    self.hift.to(self.device).eval()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/flow/decoder.py
    DELETED
    
    | @@ -1,238 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            import torch
         | 
| 15 | 
            -
            import torch.nn as nn
         | 
| 16 | 
            -
            from einops import pack, rearrange, repeat
         | 
| 17 | 
            -
            from cosyvoice.matcha.decoder import (
         | 
| 18 | 
            -
                SinusoidalPosEmb,
         | 
| 19 | 
            -
                Block1D,
         | 
| 20 | 
            -
                ResnetBlock1D,
         | 
| 21 | 
            -
                Downsample1D,
         | 
| 22 | 
            -
                TimestepEmbedding,
         | 
| 23 | 
            -
                Upsample1D,
         | 
| 24 | 
            -
            )
         | 
| 25 | 
            -
            from cosyvoice.matcha.transformer import BasicTransformerBlock
         | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
            class ConditionalDecoder(nn.Module):
         | 
| 29 | 
            -
                def __init__(
         | 
| 30 | 
            -
                    self,
         | 
| 31 | 
            -
                    in_channels,
         | 
| 32 | 
            -
                    out_channels,
         | 
| 33 | 
            -
                    channels=(256, 256),
         | 
| 34 | 
            -
                    dropout=0.05,
         | 
| 35 | 
            -
                    attention_head_dim=64,
         | 
| 36 | 
            -
                    n_blocks=1,
         | 
| 37 | 
            -
                    num_mid_blocks=2,
         | 
| 38 | 
            -
                    num_heads=4,
         | 
| 39 | 
            -
                    act_fn="snake",
         | 
| 40 | 
            -
                ):
         | 
| 41 | 
            -
                    """
         | 
| 42 | 
            -
                    This decoder requires an input with the same shape of the target. So, if your text content
         | 
| 43 | 
            -
                    is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
         | 
| 44 | 
            -
                    """
         | 
| 45 | 
            -
                    super().__init__()
         | 
| 46 | 
            -
                    channels = tuple(channels)
         | 
| 47 | 
            -
                    self.in_channels = in_channels
         | 
| 48 | 
            -
                    self.out_channels = out_channels
         | 
| 49 | 
            -
             | 
| 50 | 
            -
                    self.time_embeddings = SinusoidalPosEmb(in_channels)
         | 
| 51 | 
            -
                    time_embed_dim = channels[0] * 4
         | 
| 52 | 
            -
                    self.time_mlp = TimestepEmbedding(
         | 
| 53 | 
            -
                        in_channels=in_channels,
         | 
| 54 | 
            -
                        time_embed_dim=time_embed_dim,
         | 
| 55 | 
            -
                        act_fn="silu",
         | 
| 56 | 
            -
                    )
         | 
| 57 | 
            -
                    self.down_blocks = nn.ModuleList([])
         | 
| 58 | 
            -
                    self.mid_blocks = nn.ModuleList([])
         | 
| 59 | 
            -
                    self.up_blocks = nn.ModuleList([])
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                    output_channel = in_channels
         | 
| 62 | 
            -
                    for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
         | 
| 63 | 
            -
                        input_channel = output_channel
         | 
| 64 | 
            -
                        output_channel = channels[i]
         | 
| 65 | 
            -
                        is_last = i == len(channels) - 1
         | 
| 66 | 
            -
                        resnet = ResnetBlock1D(
         | 
| 67 | 
            -
                            dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
         | 
| 68 | 
            -
                        )
         | 
| 69 | 
            -
                        transformer_blocks = nn.ModuleList(
         | 
| 70 | 
            -
                            [
         | 
| 71 | 
            -
                                BasicTransformerBlock(
         | 
| 72 | 
            -
                                    dim=output_channel,
         | 
| 73 | 
            -
                                    num_attention_heads=num_heads,
         | 
| 74 | 
            -
                                    attention_head_dim=attention_head_dim,
         | 
| 75 | 
            -
                                    dropout=dropout,
         | 
| 76 | 
            -
                                    activation_fn=act_fn,
         | 
| 77 | 
            -
                                )
         | 
| 78 | 
            -
                                for _ in range(n_blocks)
         | 
| 79 | 
            -
                            ]
         | 
| 80 | 
            -
                        )
         | 
| 81 | 
            -
                        downsample = (
         | 
| 82 | 
            -
                            Downsample1D(output_channel)
         | 
| 83 | 
            -
                            if not is_last
         | 
| 84 | 
            -
                            else nn.Conv1d(output_channel, output_channel, 3, padding=1)
         | 
| 85 | 
            -
                        )
         | 
| 86 | 
            -
                        self.down_blocks.append(
         | 
| 87 | 
            -
                            nn.ModuleList([resnet, transformer_blocks, downsample])
         | 
| 88 | 
            -
                        )
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                    for _ in range(num_mid_blocks):
         | 
| 91 | 
            -
                        input_channel = channels[-1]
         | 
| 92 | 
            -
                        out_channels = channels[-1]
         | 
| 93 | 
            -
                        resnet = ResnetBlock1D(
         | 
| 94 | 
            -
                            dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
         | 
| 95 | 
            -
                        )
         | 
| 96 | 
            -
             | 
| 97 | 
            -
                        transformer_blocks = nn.ModuleList(
         | 
| 98 | 
            -
                            [
         | 
| 99 | 
            -
                                BasicTransformerBlock(
         | 
| 100 | 
            -
                                    dim=output_channel,
         | 
| 101 | 
            -
                                    num_attention_heads=num_heads,
         | 
| 102 | 
            -
                                    attention_head_dim=attention_head_dim,
         | 
| 103 | 
            -
                                    dropout=dropout,
         | 
| 104 | 
            -
                                    activation_fn=act_fn,
         | 
| 105 | 
            -
                                )
         | 
| 106 | 
            -
                                for _ in range(n_blocks)
         | 
| 107 | 
            -
                            ]
         | 
| 108 | 
            -
                        )
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                        self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                    channels = channels[::-1] + (channels[0],)
         | 
| 113 | 
            -
                    for i in range(len(channels) - 1):
         | 
| 114 | 
            -
                        input_channel = channels[i] * 2
         | 
| 115 | 
            -
                        output_channel = channels[i + 1]
         | 
| 116 | 
            -
                        is_last = i == len(channels) - 2
         | 
| 117 | 
            -
                        resnet = ResnetBlock1D(
         | 
| 118 | 
            -
                            dim=input_channel,
         | 
| 119 | 
            -
                            dim_out=output_channel,
         | 
| 120 | 
            -
                            time_emb_dim=time_embed_dim,
         | 
| 121 | 
            -
                        )
         | 
| 122 | 
            -
                        transformer_blocks = nn.ModuleList(
         | 
| 123 | 
            -
                            [
         | 
| 124 | 
            -
                                BasicTransformerBlock(
         | 
| 125 | 
            -
                                    dim=output_channel,
         | 
| 126 | 
            -
                                    num_attention_heads=num_heads,
         | 
| 127 | 
            -
                                    attention_head_dim=attention_head_dim,
         | 
| 128 | 
            -
                                    dropout=dropout,
         | 
| 129 | 
            -
                                    activation_fn=act_fn,
         | 
| 130 | 
            -
                                )
         | 
| 131 | 
            -
                                for _ in range(n_blocks)
         | 
| 132 | 
            -
                            ]
         | 
| 133 | 
            -
                        )
         | 
| 134 | 
            -
                        upsample = (
         | 
| 135 | 
            -
                            Upsample1D(output_channel, use_conv_transpose=True)
         | 
| 136 | 
            -
                            if not is_last
         | 
| 137 | 
            -
                            else nn.Conv1d(output_channel, output_channel, 3, padding=1)
         | 
| 138 | 
            -
                        )
         | 
| 139 | 
            -
                        self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
         | 
| 140 | 
            -
                    self.final_block = Block1D(channels[-1], channels[-1])
         | 
| 141 | 
            -
                    self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         | 
| 142 | 
            -
                    self.initialize_weights()
         | 
| 143 | 
            -
             | 
| 144 | 
            -
                def initialize_weights(self):
         | 
| 145 | 
            -
                    for m in self.modules():
         | 
| 146 | 
            -
                        if isinstance(m, nn.Conv1d):
         | 
| 147 | 
            -
                            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
         | 
| 148 | 
            -
                            if m.bias is not None:
         | 
| 149 | 
            -
                                nn.init.constant_(m.bias, 0)
         | 
| 150 | 
            -
                        elif isinstance(m, nn.GroupNorm):
         | 
| 151 | 
            -
                            nn.init.constant_(m.weight, 1)
         | 
| 152 | 
            -
                            nn.init.constant_(m.bias, 0)
         | 
| 153 | 
            -
                        elif isinstance(m, nn.Linear):
         | 
| 154 | 
            -
                            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
         | 
| 155 | 
            -
                            if m.bias is not None:
         | 
| 156 | 
            -
                                nn.init.constant_(m.bias, 0)
         | 
| 157 | 
            -
             | 
| 158 | 
            -
                def forward(self, x, mask, mu, t, spks=None, cond=None):
         | 
| 159 | 
            -
                    """Forward pass of the UNet1DConditional model.
         | 
| 160 | 
            -
             | 
| 161 | 
            -
                    Args:
         | 
| 162 | 
            -
                        x (torch.Tensor): shape (batch_size, in_channels, time)
         | 
| 163 | 
            -
                        mask (_type_): shape (batch_size, 1, time)
         | 
| 164 | 
            -
                        t (_type_): shape (batch_size)
         | 
| 165 | 
            -
                        spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
         | 
| 166 | 
            -
                        cond (_type_, optional): placeholder for future use. Defaults to None.
         | 
| 167 | 
            -
             | 
| 168 | 
            -
                    Raises:
         | 
| 169 | 
            -
                        ValueError: _description_
         | 
| 170 | 
            -
                        ValueError: _description_
         | 
| 171 | 
            -
             | 
| 172 | 
            -
                    Returns:
         | 
| 173 | 
            -
                        _type_: _description_
         | 
| 174 | 
            -
                    """
         | 
| 175 | 
            -
             | 
| 176 | 
            -
                    t = self.time_embeddings(t).to(t.dtype)
         | 
| 177 | 
            -
                    t = self.time_mlp(t)
         | 
| 178 | 
            -
             | 
| 179 | 
            -
                    x = pack([x, mu], "b * t")[0]
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                    if spks is not None:
         | 
| 182 | 
            -
                        spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
         | 
| 183 | 
            -
                        x = pack([x, spks], "b * t")[0]
         | 
| 184 | 
            -
                    if cond is not None:
         | 
| 185 | 
            -
                        x = pack([x, cond], "b * t")[0]
         | 
| 186 | 
            -
             | 
| 187 | 
            -
                    hiddens = []
         | 
| 188 | 
            -
                    masks = [mask]
         | 
| 189 | 
            -
                    for resnet, transformer_blocks, downsample in self.down_blocks:
         | 
| 190 | 
            -
                        mask_down = masks[-1]
         | 
| 191 | 
            -
                        x = resnet(
         | 
| 192 | 
            -
                            x.to(torch.bfloat16), mask_down.to(torch.bfloat16), t.to(torch.bfloat16)
         | 
| 193 | 
            -
                        )
         | 
| 194 | 
            -
                        x = rearrange(x, "b c t -> b t c").contiguous()
         | 
| 195 | 
            -
                        # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
         | 
| 196 | 
            -
                        for transformer_block in transformer_blocks:
         | 
| 197 | 
            -
                            x = transformer_block(
         | 
| 198 | 
            -
                                hidden_states=x,
         | 
| 199 | 
            -
                                # attention_mask=attn_mask,
         | 
| 200 | 
            -
                                timestep=t,
         | 
| 201 | 
            -
                            )
         | 
| 202 | 
            -
                        x = rearrange(x, "b t c -> b c t").contiguous()
         | 
| 203 | 
            -
                        hiddens.append(x)  # Save hidden states for skip connections
         | 
| 204 | 
            -
                        x = downsample(x * mask_down)
         | 
| 205 | 
            -
                        masks.append(mask_down[:, :, ::2])
         | 
| 206 | 
            -
                    masks = masks[:-1]
         | 
| 207 | 
            -
                    mask_mid = masks[-1]
         | 
| 208 | 
            -
             | 
| 209 | 
            -
                    for resnet, transformer_blocks in self.mid_blocks:
         | 
| 210 | 
            -
                        x = resnet(x, mask_mid, t)
         | 
| 211 | 
            -
                        x = rearrange(x, "b c t -> b t c").contiguous()
         | 
| 212 | 
            -
                        # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
         | 
| 213 | 
            -
                        for transformer_block in transformer_blocks:
         | 
| 214 | 
            -
                            x = transformer_block(
         | 
| 215 | 
            -
                                hidden_states=x,
         | 
| 216 | 
            -
                                # attention_mask=attn_mask,
         | 
| 217 | 
            -
                                timestep=t,
         | 
| 218 | 
            -
                            )
         | 
| 219 | 
            -
                        x = rearrange(x, "b t c -> b c t").contiguous()
         | 
| 220 | 
            -
             | 
| 221 | 
            -
                    for resnet, transformer_blocks, upsample in self.up_blocks:
         | 
| 222 | 
            -
                        mask_up = masks.pop()
         | 
| 223 | 
            -
                        skip = hiddens.pop()
         | 
| 224 | 
            -
                        x = pack([x[:, :, : skip.shape[-1]], skip], "b * t")[0]
         | 
| 225 | 
            -
                        x = resnet(x, mask_up, t)
         | 
| 226 | 
            -
                        x = rearrange(x, "b c t -> b t c").contiguous()
         | 
| 227 | 
            -
                        # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
         | 
| 228 | 
            -
                        for transformer_block in transformer_blocks:
         | 
| 229 | 
            -
                            x = transformer_block(
         | 
| 230 | 
            -
                                hidden_states=x,
         | 
| 231 | 
            -
                                # attention_mask=attn_mask,
         | 
| 232 | 
            -
                                timestep=t,
         | 
| 233 | 
            -
                            )
         | 
| 234 | 
            -
                        x = rearrange(x, "b t c -> b c t").contiguous()
         | 
| 235 | 
            -
                        x = upsample(x * mask_up)
         | 
| 236 | 
            -
                    x = self.final_block(x, mask_up)
         | 
| 237 | 
            -
                    output = self.final_proj(x * mask_up)
         | 
| 238 | 
            -
                    return output * mask
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/flow/flow.py
    DELETED
    
    | @@ -1,196 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            import logging
         | 
| 15 | 
            -
            import random
         | 
| 16 | 
            -
            from typing import Dict, Optional
         | 
| 17 | 
            -
            import torch
         | 
| 18 | 
            -
            import torch.nn as nn
         | 
| 19 | 
            -
            from torch.nn import functional as F
         | 
| 20 | 
            -
            from omegaconf import DictConfig
         | 
| 21 | 
            -
            from cosyvoice.utils.mask import make_pad_mask
         | 
| 22 | 
            -
            import time
         | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
            class MaskedDiffWithXvec(torch.nn.Module):
         | 
| 26 | 
            -
                def __init__(
         | 
| 27 | 
            -
                    self,
         | 
| 28 | 
            -
                    input_size: int = 512,
         | 
| 29 | 
            -
                    output_size: int = 80,
         | 
| 30 | 
            -
                    spk_embed_dim: int = 192,
         | 
| 31 | 
            -
                    output_type: str = "mel",
         | 
| 32 | 
            -
                    vocab_size: int = 4096,
         | 
| 33 | 
            -
                    input_frame_rate: int = 50,
         | 
| 34 | 
            -
                    only_mask_loss: bool = True,
         | 
| 35 | 
            -
                    encoder: torch.nn.Module = None,
         | 
| 36 | 
            -
                    length_regulator: torch.nn.Module = None,
         | 
| 37 | 
            -
                    decoder: torch.nn.Module = None,
         | 
| 38 | 
            -
                    decoder_conf: Dict = {
         | 
| 39 | 
            -
                        "in_channels": 240,
         | 
| 40 | 
            -
                        "out_channel": 80,
         | 
| 41 | 
            -
                        "spk_emb_dim": 80,
         | 
| 42 | 
            -
                        "n_spks": 1,
         | 
| 43 | 
            -
                        "cfm_params": DictConfig(
         | 
| 44 | 
            -
                            {
         | 
| 45 | 
            -
                                "sigma_min": 1e-06,
         | 
| 46 | 
            -
                                "solver": "euler",
         | 
| 47 | 
            -
                                "t_scheduler": "cosine",
         | 
| 48 | 
            -
                                "training_cfg_rate": 0.2,
         | 
| 49 | 
            -
                                "inference_cfg_rate": 0.7,
         | 
| 50 | 
            -
                                "reg_loss_type": "l1",
         | 
| 51 | 
            -
                            }
         | 
| 52 | 
            -
                        ),
         | 
| 53 | 
            -
                        "decoder_params": {
         | 
| 54 | 
            -
                            "channels": [256, 256],
         | 
| 55 | 
            -
                            "dropout": 0.0,
         | 
| 56 | 
            -
                            "attention_head_dim": 64,
         | 
| 57 | 
            -
                            "n_blocks": 4,
         | 
| 58 | 
            -
                            "num_mid_blocks": 12,
         | 
| 59 | 
            -
                            "num_heads": 8,
         | 
| 60 | 
            -
                            "act_fn": "gelu",
         | 
| 61 | 
            -
                        },
         | 
| 62 | 
            -
                    },
         | 
| 63 | 
            -
                    mel_feat_conf: Dict = {
         | 
| 64 | 
            -
                        "n_fft": 1024,
         | 
| 65 | 
            -
                        "num_mels": 80,
         | 
| 66 | 
            -
                        "sampling_rate": 22050,
         | 
| 67 | 
            -
                        "hop_size": 256,
         | 
| 68 | 
            -
                        "win_size": 1024,
         | 
| 69 | 
            -
                        "fmin": 0,
         | 
| 70 | 
            -
                        "fmax": 8000,
         | 
| 71 | 
            -
                    },
         | 
| 72 | 
            -
                ):
         | 
| 73 | 
            -
                    super().__init__()
         | 
| 74 | 
            -
                    self.input_size = input_size
         | 
| 75 | 
            -
                    self.output_size = output_size
         | 
| 76 | 
            -
                    self.decoder_conf = decoder_conf
         | 
| 77 | 
            -
                    self.mel_feat_conf = mel_feat_conf
         | 
| 78 | 
            -
                    self.vocab_size = vocab_size
         | 
| 79 | 
            -
                    self.output_type = output_type
         | 
| 80 | 
            -
                    self.input_frame_rate = input_frame_rate
         | 
| 81 | 
            -
                    logging.info(f"input frame rate={self.input_frame_rate}")
         | 
| 82 | 
            -
                    self.input_embedding = nn.Embedding(vocab_size, input_size)
         | 
| 83 | 
            -
                    self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
         | 
| 84 | 
            -
                    self.encoder = encoder
         | 
| 85 | 
            -
                    self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
         | 
| 86 | 
            -
                    self.decoder = decoder
         | 
| 87 | 
            -
                    self.length_regulator = length_regulator
         | 
| 88 | 
            -
                    self.only_mask_loss = only_mask_loss
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                def forward(
         | 
| 91 | 
            -
                    self,
         | 
| 92 | 
            -
                    batch: dict,
         | 
| 93 | 
            -
                    device: torch.device,
         | 
| 94 | 
            -
                ) -> Dict[str, Optional[torch.Tensor]]:
         | 
| 95 | 
            -
                    token = batch["speech_token"].to(device)
         | 
| 96 | 
            -
                    token_len = batch["speech_token_len"].to(device)
         | 
| 97 | 
            -
                    feat = batch["speech_feat"].to(device)
         | 
| 98 | 
            -
                    feat_len = batch["speech_feat_len"].to(device)
         | 
| 99 | 
            -
                    embedding = batch["embedding"].to(device)
         | 
| 100 | 
            -
             | 
| 101 | 
            -
                    # xvec projection
         | 
| 102 | 
            -
                    embedding = F.normalize(embedding, dim=1)
         | 
| 103 | 
            -
                    embedding = self.spk_embed_affine_layer(embedding)
         | 
| 104 | 
            -
             | 
| 105 | 
            -
                    # concat text and prompt_text
         | 
| 106 | 
            -
                    mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
         | 
| 107 | 
            -
                    token = self.input_embedding(torch.clamp(token, min=0)) * mask
         | 
| 108 | 
            -
             | 
| 109 | 
            -
                    # text encode
         | 
| 110 | 
            -
                    h, h_lengths = self.encoder(token, token_len)
         | 
| 111 | 
            -
                    h = self.encoder_proj(h)
         | 
| 112 | 
            -
                    h, h_lengths = self.length_regulator(h, feat_len)
         | 
| 113 | 
            -
             | 
| 114 | 
            -
                    # get conditions
         | 
| 115 | 
            -
                    conds = torch.zeros(feat.shape, device=token.device)
         | 
| 116 | 
            -
                    for i, j in enumerate(feat_len):
         | 
| 117 | 
            -
                        if random.random() < 0.5:
         | 
| 118 | 
            -
                            continue
         | 
| 119 | 
            -
                        index = random.randint(0, int(0.3 * j))
         | 
| 120 | 
            -
                        conds[i, :index] = feat[i, :index]
         | 
| 121 | 
            -
                    conds = conds.transpose(1, 2)
         | 
| 122 | 
            -
             | 
| 123 | 
            -
                    mask = (~make_pad_mask(feat_len)).to(h)
         | 
| 124 | 
            -
                    feat = F.interpolate(
         | 
| 125 | 
            -
                        feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest"
         | 
| 126 | 
            -
                    ).squeeze(dim=1)
         | 
| 127 | 
            -
                    loss, _ = self.decoder.compute_loss(
         | 
| 128 | 
            -
                        feat.transpose(1, 2).contiguous(),
         | 
| 129 | 
            -
                        mask.unsqueeze(1),
         | 
| 130 | 
            -
                        h.transpose(1, 2).contiguous(),
         | 
| 131 | 
            -
                        embedding,
         | 
| 132 | 
            -
                        cond=conds,
         | 
| 133 | 
            -
                    )
         | 
| 134 | 
            -
                    return {"loss": loss}
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                @torch.inference_mode()
         | 
| 137 | 
            -
                def inference(
         | 
| 138 | 
            -
                    self,
         | 
| 139 | 
            -
                    token,
         | 
| 140 | 
            -
                    token_len,
         | 
| 141 | 
            -
                    prompt_token,
         | 
| 142 | 
            -
                    prompt_token_len,
         | 
| 143 | 
            -
                    prompt_feat,
         | 
| 144 | 
            -
                    prompt_feat_len,
         | 
| 145 | 
            -
                    embedding,
         | 
| 146 | 
            -
                ):
         | 
| 147 | 
            -
                    assert token.shape[0] == 1
         | 
| 148 | 
            -
                    # xvec projection
         | 
| 149 | 
            -
                    embedding = F.normalize(embedding, dim=1)
         | 
| 150 | 
            -
                    embedding = self.spk_embed_affine_layer(embedding)
         | 
| 151 | 
            -
             | 
| 152 | 
            -
                    # concat text and prompt_text
         | 
| 153 | 
            -
                    token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         | 
| 154 | 
            -
                    # text encode
         | 
| 155 | 
            -
                    token, token_len = (
         | 
| 156 | 
            -
                        torch.concat([prompt_token, token], dim=1),
         | 
| 157 | 
            -
                        prompt_token_len + token_len,
         | 
| 158 | 
            -
                    )
         | 
| 159 | 
            -
                    token = self.input_embedding(torch.clamp(token, min=0))
         | 
| 160 | 
            -
                    h, _ = self.encoder.inference(token, token_len)
         | 
| 161 | 
            -
                    h = self.encoder_proj(h)
         | 
| 162 | 
            -
                    mel_len1, mel_len2 = prompt_feat.shape[1], int(
         | 
| 163 | 
            -
                        token_len2
         | 
| 164 | 
            -
                        / self.input_frame_rate
         | 
| 165 | 
            -
                        * self.mel_feat_conf["sampling_rate"]
         | 
| 166 | 
            -
                        / self.mel_feat_conf["hop_size"]
         | 
| 167 | 
            -
                    )
         | 
| 168 | 
            -
             | 
| 169 | 
            -
                    h, _ = self.length_regulator.inference(
         | 
| 170 | 
            -
                        h[:, :token_len1],
         | 
| 171 | 
            -
                        h[:, token_len1:],
         | 
| 172 | 
            -
                        mel_len1,
         | 
| 173 | 
            -
                        mel_len2,
         | 
| 174 | 
            -
                    )
         | 
| 175 | 
            -
             | 
| 176 | 
            -
                    # get conditions
         | 
| 177 | 
            -
                    conds = torch.zeros(
         | 
| 178 | 
            -
                        [1, mel_len1 + mel_len2, self.output_size], device=token.device
         | 
| 179 | 
            -
                    )
         | 
| 180 | 
            -
                    conds[:, :mel_len1] = prompt_feat
         | 
| 181 | 
            -
                    conds = conds.transpose(1, 2)
         | 
| 182 | 
            -
             | 
| 183 | 
            -
                    # mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         | 
| 184 | 
            -
                    mask = torch.ones(
         | 
| 185 | 
            -
                        [1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16
         | 
| 186 | 
            -
                    )
         | 
| 187 | 
            -
                    feat = self.decoder(
         | 
| 188 | 
            -
                        mu=h.transpose(1, 2).contiguous(),
         | 
| 189 | 
            -
                        mask=mask.unsqueeze(1),
         | 
| 190 | 
            -
                        spks=embedding,
         | 
| 191 | 
            -
                        cond=conds,
         | 
| 192 | 
            -
                        n_timesteps=10,
         | 
| 193 | 
            -
                    )
         | 
| 194 | 
            -
                    feat = feat[:, :, mel_len1:]
         | 
| 195 | 
            -
                    assert feat.shape[2] == mel_len2
         | 
| 196 | 
            -
                    return feat
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/flow/flow_matching.py
    DELETED
    
    | @@ -1,315 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            import time
         | 
| 15 | 
            -
            import torch
         | 
| 16 | 
            -
            import torch.nn.functional as F
         | 
| 17 | 
            -
            from cosyvoice.matcha.flow_matching import BASECFM
         | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
            class ConditionalCFM(BASECFM):
         | 
| 21 | 
            -
                def __init__(
         | 
| 22 | 
            -
                    self,
         | 
| 23 | 
            -
                    in_channels,
         | 
| 24 | 
            -
                    cfm_params,
         | 
| 25 | 
            -
                    n_spks=1,
         | 
| 26 | 
            -
                    spk_emb_dim=64,
         | 
| 27 | 
            -
                    estimator: torch.nn.Module = None,
         | 
| 28 | 
            -
                ):
         | 
| 29 | 
            -
                    super().__init__(
         | 
| 30 | 
            -
                        n_feats=in_channels,
         | 
| 31 | 
            -
                        cfm_params=cfm_params,
         | 
| 32 | 
            -
                        n_spks=n_spks,
         | 
| 33 | 
            -
                        spk_emb_dim=spk_emb_dim,
         | 
| 34 | 
            -
                    )
         | 
| 35 | 
            -
                    self.t_scheduler = cfm_params.t_scheduler
         | 
| 36 | 
            -
                    self.training_cfg_rate = cfm_params.training_cfg_rate
         | 
| 37 | 
            -
                    self.inference_cfg_rate = cfm_params.inference_cfg_rate
         | 
| 38 | 
            -
                    in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         | 
| 39 | 
            -
                    # Just change the architecture of the estimator here
         | 
| 40 | 
            -
                    self.estimator = estimator
         | 
| 41 | 
            -
                    self.inference_graphs = {}
         | 
| 42 | 
            -
                    self.inference_buffers = {}
         | 
| 43 | 
            -
                    # self.capture_inference()
         | 
| 44 | 
            -
             | 
| 45 | 
            -
                @torch.inference_mode()
         | 
| 46 | 
            -
                def forward(
         | 
| 47 | 
            -
                    self,
         | 
| 48 | 
            -
                    mu,
         | 
| 49 | 
            -
                    mask,
         | 
| 50 | 
            -
                    n_timesteps,
         | 
| 51 | 
            -
                    temperature=1.0,
         | 
| 52 | 
            -
                    spks=None,
         | 
| 53 | 
            -
                    cond=None,
         | 
| 54 | 
            -
                ):
         | 
| 55 | 
            -
                    """Forward diffusion
         | 
| 56 | 
            -
             | 
| 57 | 
            -
                    Args:
         | 
| 58 | 
            -
                        mu (torch.Tensor): output of encoder
         | 
| 59 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 60 | 
            -
                        mask (torch.Tensor): output_mask
         | 
| 61 | 
            -
                            shape: (batch_size, 1, mel_timesteps)
         | 
| 62 | 
            -
                        n_timesteps (int): number of diffusion steps
         | 
| 63 | 
            -
                        temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
         | 
| 64 | 
            -
                        spks (torch.Tensor, optional): speaker ids. Defaults to None.
         | 
| 65 | 
            -
                            shape: (batch_size, spk_emb_dim)
         | 
| 66 | 
            -
                        cond: Not used but kept for future purposes
         | 
| 67 | 
            -
             | 
| 68 | 
            -
                    Returns:
         | 
| 69 | 
            -
                        sample: generated mel-spectrogram
         | 
| 70 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 71 | 
            -
                    """
         | 
| 72 | 
            -
                    z = torch.randn_like(mu) * temperature
         | 
| 73 | 
            -
                    t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         | 
| 74 | 
            -
                    if self.t_scheduler == "cosine":
         | 
| 75 | 
            -
                        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         | 
| 76 | 
            -
                    return self.solve_euler(
         | 
| 77 | 
            -
                        z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
         | 
| 78 | 
            -
                    )
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                @torch.inference_mode()
         | 
| 81 | 
            -
                def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))):
         | 
| 82 | 
            -
                    start_time = time.time()
         | 
| 83 | 
            -
                    print(
         | 
| 84 | 
            -
                        f"capture_inference for ConditionalCFM solve euler, seq_len_to_capture: {seq_len_to_capture}"
         | 
| 85 | 
            -
                    )
         | 
| 86 | 
            -
                    for seq_len in seq_len_to_capture:
         | 
| 87 | 
            -
                        static_z = torch.randn(
         | 
| 88 | 
            -
                            1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
         | 
| 89 | 
            -
                        )
         | 
| 90 | 
            -
                        static_t_span = torch.linspace(
         | 
| 91 | 
            -
                            0, 1, 11, device=torch.device("cuda"), dtype=torch.bfloat16
         | 
| 92 | 
            -
                        )  # only capture at 10 steps
         | 
| 93 | 
            -
                        static_mu = torch.randn(
         | 
| 94 | 
            -
                            1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
         | 
| 95 | 
            -
                        )
         | 
| 96 | 
            -
                        static_mask = torch.ones(
         | 
| 97 | 
            -
                            1, 1, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
         | 
| 98 | 
            -
                        )
         | 
| 99 | 
            -
                        static_spks = torch.randn(
         | 
| 100 | 
            -
                            1, 80, device=torch.device("cuda"), dtype=torch.bfloat16
         | 
| 101 | 
            -
                        )
         | 
| 102 | 
            -
                        static_cond = torch.randn(
         | 
| 103 | 
            -
                            1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32
         | 
| 104 | 
            -
                        )
         | 
| 105 | 
            -
                        static_out = torch.randn(
         | 
| 106 | 
            -
                            1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
         | 
| 107 | 
            -
                        )
         | 
| 108 | 
            -
             | 
| 109 | 
            -
                        self._solve_euler_impl(
         | 
| 110 | 
            -
                            static_z,
         | 
| 111 | 
            -
                            t_span=static_t_span,
         | 
| 112 | 
            -
                            mu=static_mu,
         | 
| 113 | 
            -
                            mask=static_mask,
         | 
| 114 | 
            -
                            spks=static_spks,
         | 
| 115 | 
            -
                            cond=static_cond,
         | 
| 116 | 
            -
                        )
         | 
| 117 | 
            -
                        torch.cuda.synchronize()
         | 
| 118 | 
            -
             | 
| 119 | 
            -
                        g = torch.cuda.CUDAGraph()
         | 
| 120 | 
            -
                        with torch.cuda.graph(g):
         | 
| 121 | 
            -
                            static_out = self._solve_euler_impl(
         | 
| 122 | 
            -
                                static_z,
         | 
| 123 | 
            -
                                t_span=static_t_span,
         | 
| 124 | 
            -
                                mu=static_mu,
         | 
| 125 | 
            -
                                mask=static_mask,
         | 
| 126 | 
            -
                                spks=static_spks,
         | 
| 127 | 
            -
                                cond=static_cond,
         | 
| 128 | 
            -
                            )
         | 
| 129 | 
            -
             | 
| 130 | 
            -
                    self.inference_buffers[seq_len] = {
         | 
| 131 | 
            -
                        "z": static_z,
         | 
| 132 | 
            -
                        "t_span": static_t_span,
         | 
| 133 | 
            -
                        "mu": static_mu,
         | 
| 134 | 
            -
                        "mask": static_mask,
         | 
| 135 | 
            -
                        "spks": static_spks,
         | 
| 136 | 
            -
                        "cond": static_cond,
         | 
| 137 | 
            -
                        "out": static_out,
         | 
| 138 | 
            -
                    }
         | 
| 139 | 
            -
                    self.inference_graphs[seq_len] = g
         | 
| 140 | 
            -
                    end_time = time.time()
         | 
| 141 | 
            -
                    print(
         | 
| 142 | 
            -
                        f"capture_inference for ConditionalCFM solve euler, time elapsed: {end_time - start_time}"
         | 
| 143 | 
            -
                    )
         | 
| 144 | 
            -
             | 
| 145 | 
            -
                def solve_euler(self, x, t_span, mu, mask, spks, cond):
         | 
| 146 | 
            -
                    if hasattr(self, "inference_graphs") and len(self.inference_graphs) > 0:
         | 
| 147 | 
            -
                        curr_seq_len = x.shape[2]
         | 
| 148 | 
            -
             | 
| 149 | 
            -
                        available_lengths = sorted(list(self.inference_graphs.keys()))
         | 
| 150 | 
            -
             | 
| 151 | 
            -
                        if curr_seq_len <= max(available_lengths):
         | 
| 152 | 
            -
                            target_len = min(available_lengths, key=lambda x: abs(x - curr_seq_len))
         | 
| 153 | 
            -
                            if target_len == curr_seq_len:
         | 
| 154 | 
            -
                                padded_x = x
         | 
| 155 | 
            -
                                padded_mu = mu
         | 
| 156 | 
            -
                                padded_mask = mask
         | 
| 157 | 
            -
                                if cond is not None:
         | 
| 158 | 
            -
                                    padded_cond = cond
         | 
| 159 | 
            -
                            else:
         | 
| 160 | 
            -
                                padded_x = torch.randn(
         | 
| 161 | 
            -
                                    (x.shape[0], x.shape[1], target_len),
         | 
| 162 | 
            -
                                    dtype=x.dtype,
         | 
| 163 | 
            -
                                    device=x.device,
         | 
| 164 | 
            -
                                )
         | 
| 165 | 
            -
                                padded_x[:, :, :curr_seq_len] = x
         | 
| 166 | 
            -
             | 
| 167 | 
            -
                                padded_mu = torch.randn(
         | 
| 168 | 
            -
                                    (mu.shape[0], mu.shape[1], target_len),
         | 
| 169 | 
            -
                                    dtype=mu.dtype,
         | 
| 170 | 
            -
                                    device=mu.device,
         | 
| 171 | 
            -
                                )
         | 
| 172 | 
            -
                                padded_mu[:, :, :curr_seq_len] = mu
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                                # FIXME(ys): uses zeros and maskgroupnorm
         | 
| 175 | 
            -
                                padded_mask = torch.ones(
         | 
| 176 | 
            -
                                    (mask.shape[0], mask.shape[1], target_len),
         | 
| 177 | 
            -
                                    dtype=mask.dtype,
         | 
| 178 | 
            -
                                    device=mask.device,
         | 
| 179 | 
            -
                                )
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                                if cond is not None:
         | 
| 182 | 
            -
                                    padded_cond = torch.randn(
         | 
| 183 | 
            -
                                        (cond.shape[0], cond.shape[1], target_len),
         | 
| 184 | 
            -
                                        dtype=cond.dtype,
         | 
| 185 | 
            -
                                        device=cond.device,
         | 
| 186 | 
            -
                                    )
         | 
| 187 | 
            -
                                    padded_cond[:, :, :curr_seq_len] = cond
         | 
| 188 | 
            -
             | 
| 189 | 
            -
                            buffer = self.inference_buffers[target_len]
         | 
| 190 | 
            -
                            buffer["z"].copy_(padded_x)
         | 
| 191 | 
            -
                            buffer["t_span"].copy_(t_span)
         | 
| 192 | 
            -
                            buffer["mu"].copy_(padded_mu)
         | 
| 193 | 
            -
                            buffer["mask"].copy_(padded_mask)
         | 
| 194 | 
            -
                            buffer["spks"].copy_(spks)
         | 
| 195 | 
            -
                            if cond is not None:
         | 
| 196 | 
            -
                                buffer["cond"].copy_(padded_cond)
         | 
| 197 | 
            -
             | 
| 198 | 
            -
                            self.inference_graphs[target_len].replay()
         | 
| 199 | 
            -
             | 
| 200 | 
            -
                            output = buffer["out"][:, :, :curr_seq_len]
         | 
| 201 | 
            -
                            return output
         | 
| 202 | 
            -
             | 
| 203 | 
            -
                    return self._solve_euler_impl(x, t_span, mu, mask, spks, cond)
         | 
| 204 | 
            -
             | 
| 205 | 
            -
                def _solve_euler_impl(self, x, t_span, mu, mask, spks, cond):
         | 
| 206 | 
            -
                    """
         | 
| 207 | 
            -
                    Fixed euler solver for ODEs.
         | 
| 208 | 
            -
                    Args:
         | 
| 209 | 
            -
                        x (torch.Tensor): random noise
         | 
| 210 | 
            -
                        t_span (torch.Tensor): n_timesteps interpolated
         | 
| 211 | 
            -
                            shape: (n_timesteps + 1,)
         | 
| 212 | 
            -
                        mu (torch.Tensor): output of encoder
         | 
| 213 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 214 | 
            -
                        mask (torch.Tensor): output_mask
         | 
| 215 | 
            -
                            shape: (batch_size, 1, mel_timesteps)
         | 
| 216 | 
            -
                        spks (torch.Tensor, optional): speaker ids. Defaults to None.
         | 
| 217 | 
            -
                            shape: (batch_size, spk_emb_dim)
         | 
| 218 | 
            -
                        cond: Not used but kept for future purposes
         | 
| 219 | 
            -
                    """
         | 
| 220 | 
            -
                    t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
         | 
| 221 | 
            -
                    t = t.unsqueeze(dim=0)
         | 
| 222 | 
            -
             | 
| 223 | 
            -
                    # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         | 
| 224 | 
            -
                    # Or in future might add like a return_all_steps flag
         | 
| 225 | 
            -
                    sol = []
         | 
| 226 | 
            -
             | 
| 227 | 
            -
                    for step in range(1, len(t_span)):
         | 
| 228 | 
            -
                        if self.inference_cfg_rate > 0:
         | 
| 229 | 
            -
                            x_double = torch.cat([x, x], dim=0)
         | 
| 230 | 
            -
                            mask_double = torch.cat([mask, mask], dim=0)
         | 
| 231 | 
            -
                            mu_double = torch.cat([mu, torch.zeros_like(mu)], dim=0)
         | 
| 232 | 
            -
                            t_double = torch.cat([t, t], dim=0)
         | 
| 233 | 
            -
                            spks_double = (
         | 
| 234 | 
            -
                                torch.cat([spks, torch.zeros_like(spks)], dim=0)
         | 
| 235 | 
            -
                                if spks is not None
         | 
| 236 | 
            -
                                else None
         | 
| 237 | 
            -
                            )
         | 
| 238 | 
            -
                            cond_double = torch.cat([cond, torch.zeros_like(cond)], dim=0)
         | 
| 239 | 
            -
             | 
| 240 | 
            -
                            dphi_dt_double = self.forward_estimator(
         | 
| 241 | 
            -
                                x_double, mask_double, mu_double, t_double, spks_double, cond_double
         | 
| 242 | 
            -
                            )
         | 
| 243 | 
            -
             | 
| 244 | 
            -
                            dphi_dt, cfg_dphi_dt = torch.chunk(dphi_dt_double, 2, dim=0)
         | 
| 245 | 
            -
                            dphi_dt = (
         | 
| 246 | 
            -
                                1.0 + self.inference_cfg_rate
         | 
| 247 | 
            -
                            ) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt
         | 
| 248 | 
            -
                        else:
         | 
| 249 | 
            -
                            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
         | 
| 250 | 
            -
             | 
| 251 | 
            -
                        x = x + dt * dphi_dt
         | 
| 252 | 
            -
                        t = t + dt
         | 
| 253 | 
            -
                        sol.append(x)
         | 
| 254 | 
            -
                        if step < len(t_span) - 1:
         | 
| 255 | 
            -
                            dt = t_span[step + 1] - t
         | 
| 256 | 
            -
             | 
| 257 | 
            -
                    return sol[-1]
         | 
| 258 | 
            -
             | 
| 259 | 
            -
                def forward_estimator(self, x, mask, mu, t, spks, cond):
         | 
| 260 | 
            -
                    if isinstance(self.estimator, torch.nn.Module):
         | 
| 261 | 
            -
                        return self.estimator.forward(x, mask, mu, t, spks, cond)
         | 
| 262 | 
            -
                    else:
         | 
| 263 | 
            -
                        ort_inputs = {
         | 
| 264 | 
            -
                            "x": x.cpu().numpy(),
         | 
| 265 | 
            -
                            "mask": mask.cpu().numpy(),
         | 
| 266 | 
            -
                            "mu": mu.cpu().numpy(),
         | 
| 267 | 
            -
                            "t": t.cpu().numpy(),
         | 
| 268 | 
            -
                            "spks": spks.cpu().numpy(),
         | 
| 269 | 
            -
                            "cond": cond.cpu().numpy(),
         | 
| 270 | 
            -
                        }
         | 
| 271 | 
            -
                        output = self.estimator.run(None, ort_inputs)[0]
         | 
| 272 | 
            -
                        return torch.tensor(output, dtype=x.dtype, device=x.device)
         | 
| 273 | 
            -
             | 
| 274 | 
            -
                def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         | 
| 275 | 
            -
                    """Computes diffusion loss
         | 
| 276 | 
            -
             | 
| 277 | 
            -
                    Args:
         | 
| 278 | 
            -
                        x1 (torch.Tensor): Target
         | 
| 279 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 280 | 
            -
                        mask (torch.Tensor): target mask
         | 
| 281 | 
            -
                            shape: (batch_size, 1, mel_timesteps)
         | 
| 282 | 
            -
                        mu (torch.Tensor): output of encoder
         | 
| 283 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 284 | 
            -
                        spks (torch.Tensor, optional): speaker embedding. Defaults to None.
         | 
| 285 | 
            -
                            shape: (batch_size, spk_emb_dim)
         | 
| 286 | 
            -
             | 
| 287 | 
            -
                    Returns:
         | 
| 288 | 
            -
                        loss: conditional flow matching loss
         | 
| 289 | 
            -
                        y: conditional flow
         | 
| 290 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 291 | 
            -
                    """
         | 
| 292 | 
            -
                    b, _, t = mu.shape
         | 
| 293 | 
            -
             | 
| 294 | 
            -
                    # random timestep
         | 
| 295 | 
            -
                    t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
         | 
| 296 | 
            -
                    if self.t_scheduler == "cosine":
         | 
| 297 | 
            -
                        t = 1 - torch.cos(t * 0.5 * torch.pi)
         | 
| 298 | 
            -
                    # sample noise p(x_0)
         | 
| 299 | 
            -
                    z = torch.randn_like(x1)
         | 
| 300 | 
            -
             | 
| 301 | 
            -
                    y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         | 
| 302 | 
            -
                    u = x1 - (1 - self.sigma_min) * z
         | 
| 303 | 
            -
             | 
| 304 | 
            -
                    # during training, we randomly drop condition to trade off mode coverage and sample fidelity
         | 
| 305 | 
            -
                    if self.training_cfg_rate > 0:
         | 
| 306 | 
            -
                        cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
         | 
| 307 | 
            -
                        mu = mu * cfg_mask.view(-1, 1, 1)
         | 
| 308 | 
            -
                        spks = spks * cfg_mask.view(-1, 1)
         | 
| 309 | 
            -
                        cond = cond * cfg_mask.view(-1, 1, 1)
         | 
| 310 | 
            -
             | 
| 311 | 
            -
                    pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         | 
| 312 | 
            -
                    loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (
         | 
| 313 | 
            -
                        torch.sum(mask) * u.shape[1]
         | 
| 314 | 
            -
                    )
         | 
| 315 | 
            -
                    return loss, y
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/flow/length_regulator.py
    DELETED
    
    | @@ -1,65 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            from typing import Tuple
         | 
| 15 | 
            -
            import torch.nn as nn
         | 
| 16 | 
            -
            import torch
         | 
| 17 | 
            -
            from torch.nn import functional as F
         | 
| 18 | 
            -
            from cosyvoice.utils.mask import make_pad_mask
         | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
            class InterpolateRegulator(nn.Module):
         | 
| 22 | 
            -
                def __init__(
         | 
| 23 | 
            -
                    self,
         | 
| 24 | 
            -
                    channels: int,
         | 
| 25 | 
            -
                    sampling_ratios: Tuple,
         | 
| 26 | 
            -
                    out_channels: int = None,
         | 
| 27 | 
            -
                    groups: int = 1,
         | 
| 28 | 
            -
                ):
         | 
| 29 | 
            -
                    super().__init__()
         | 
| 30 | 
            -
                    self.sampling_ratios = sampling_ratios
         | 
| 31 | 
            -
                    out_channels = out_channels or channels
         | 
| 32 | 
            -
                    model = nn.ModuleList([])
         | 
| 33 | 
            -
                    if len(sampling_ratios) > 0:
         | 
| 34 | 
            -
                        for _ in sampling_ratios:
         | 
| 35 | 
            -
                            module = nn.Conv1d(channels, channels, 3, 1, 1)
         | 
| 36 | 
            -
                            norm = nn.GroupNorm(groups, channels)
         | 
| 37 | 
            -
                            act = nn.Mish()
         | 
| 38 | 
            -
                            model.extend([module, norm, act])
         | 
| 39 | 
            -
                    model.append(nn.Conv1d(channels, out_channels, 1, 1))
         | 
| 40 | 
            -
                    self.model = nn.Sequential(*model)
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                def forward(self, x, ylens=None):
         | 
| 43 | 
            -
                    # x in (B, T, D)
         | 
| 44 | 
            -
                    mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
         | 
| 45 | 
            -
                    x = F.interpolate(
         | 
| 46 | 
            -
                        x.transpose(1, 2).contiguous(), size=ylens.max(), mode="linear"
         | 
| 47 | 
            -
                    )
         | 
| 48 | 
            -
                    out = self.model(x).transpose(1, 2).contiguous()
         | 
| 49 | 
            -
                    olens = ylens
         | 
| 50 | 
            -
                    return out * mask, olens
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                def inference(self, x1, x2, mel_len1, mel_len2):
         | 
| 53 | 
            -
                    # x in (B, T, D)
         | 
| 54 | 
            -
                    x2 = F.interpolate(
         | 
| 55 | 
            -
                        x2.transpose(1, 2).contiguous(), size=mel_len2, mode="linear"
         | 
| 56 | 
            -
                    )
         | 
| 57 | 
            -
                    if x1.shape[1] != 0:
         | 
| 58 | 
            -
                        x1 = F.interpolate(
         | 
| 59 | 
            -
                            x1.transpose(1, 2).contiguous(), size=mel_len1, mode="linear"
         | 
| 60 | 
            -
                        )
         | 
| 61 | 
            -
                        x = torch.concat([x1, x2], dim=2)
         | 
| 62 | 
            -
                    else:
         | 
| 63 | 
            -
                        x = x2
         | 
| 64 | 
            -
                    out = self.model(x).transpose(1, 2).contiguous()
         | 
| 65 | 
            -
                    return out, mel_len1 + mel_len2
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/hifigan/f0_predictor.py
    DELETED
    
    | @@ -1,55 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
            import torch
         | 
| 15 | 
            -
            import torch.nn as nn
         | 
| 16 | 
            -
            from torch.nn.utils import weight_norm
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
            class ConvRNNF0Predictor(nn.Module):
         | 
| 20 | 
            -
                def __init__(
         | 
| 21 | 
            -
                    self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512
         | 
| 22 | 
            -
                ):
         | 
| 23 | 
            -
                    super().__init__()
         | 
| 24 | 
            -
             | 
| 25 | 
            -
                    self.num_class = num_class
         | 
| 26 | 
            -
                    self.condnet = nn.Sequential(
         | 
| 27 | 
            -
                        weight_norm(
         | 
| 28 | 
            -
                            nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
         | 
| 29 | 
            -
                        ),
         | 
| 30 | 
            -
                        nn.ELU(),
         | 
| 31 | 
            -
                        weight_norm(
         | 
| 32 | 
            -
                            nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
         | 
| 33 | 
            -
                        ),
         | 
| 34 | 
            -
                        nn.ELU(),
         | 
| 35 | 
            -
                        weight_norm(
         | 
| 36 | 
            -
                            nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
         | 
| 37 | 
            -
                        ),
         | 
| 38 | 
            -
                        nn.ELU(),
         | 
| 39 | 
            -
                        weight_norm(
         | 
| 40 | 
            -
                            nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
         | 
| 41 | 
            -
                        ),
         | 
| 42 | 
            -
                        nn.ELU(),
         | 
| 43 | 
            -
                        weight_norm(
         | 
| 44 | 
            -
                            nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
         | 
| 45 | 
            -
                        ),
         | 
| 46 | 
            -
                        nn.ELU(),
         | 
| 47 | 
            -
                    )
         | 
| 48 | 
            -
                    self.classifier = nn.Linear(
         | 
| 49 | 
            -
                        in_features=cond_channels, out_features=self.num_class
         | 
| 50 | 
            -
                    )
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 53 | 
            -
                    x = self.condnet(x)
         | 
| 54 | 
            -
                    x = x.transpose(1, 2)
         | 
| 55 | 
            -
                    return torch.abs(self.classifier(x).squeeze(-1))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/hifigan/generator.py
    DELETED
    
    | @@ -1,566 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
             | 
| 15 | 
            -
            """HIFI-GAN"""
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            import typing as tp
         | 
| 18 | 
            -
            import time
         | 
| 19 | 
            -
            import numpy as np
         | 
| 20 | 
            -
            from scipy.signal import get_window
         | 
| 21 | 
            -
            import torch
         | 
| 22 | 
            -
            import torch.nn as nn
         | 
| 23 | 
            -
            import torch.nn.functional as F
         | 
| 24 | 
            -
            from torch.nn import Conv1d
         | 
| 25 | 
            -
            from torch.nn import ConvTranspose1d
         | 
| 26 | 
            -
            from torch.nn.utils import remove_weight_norm
         | 
| 27 | 
            -
            from torch.nn.utils import weight_norm
         | 
| 28 | 
            -
            from torch.distributions.uniform import Uniform
         | 
| 29 | 
            -
             | 
| 30 | 
            -
            from cosyvoice.transformer.activation import Snake
         | 
| 31 | 
            -
            from cosyvoice.utils.common import get_padding
         | 
| 32 | 
            -
            from cosyvoice.utils.common import init_weights
         | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
            """hifigan based generator implementation.
         | 
| 36 | 
            -
             | 
| 37 | 
            -
            This code is modified from https://github.com/jik876/hifi-gan
         | 
| 38 | 
            -
             ,https://github.com/kan-bayashi/ParallelWaveGAN and
         | 
| 39 | 
            -
             https://github.com/NVIDIA/BigVGAN
         | 
| 40 | 
            -
             | 
| 41 | 
            -
            """
         | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 | 
            -
            class ResBlock(torch.nn.Module):
         | 
| 45 | 
            -
                """Residual block module in HiFiGAN/BigVGAN."""
         | 
| 46 | 
            -
             | 
| 47 | 
            -
                def __init__(
         | 
| 48 | 
            -
                    self,
         | 
| 49 | 
            -
                    channels: int = 512,
         | 
| 50 | 
            -
                    kernel_size: int = 3,
         | 
| 51 | 
            -
                    dilations: tp.List[int] = [1, 3, 5],
         | 
| 52 | 
            -
                ):
         | 
| 53 | 
            -
                    super(ResBlock, self).__init__()
         | 
| 54 | 
            -
                    self.convs1 = nn.ModuleList()
         | 
| 55 | 
            -
                    self.convs2 = nn.ModuleList()
         | 
| 56 | 
            -
             | 
| 57 | 
            -
                    for dilation in dilations:
         | 
| 58 | 
            -
                        self.convs1.append(
         | 
| 59 | 
            -
                            weight_norm(
         | 
| 60 | 
            -
                                Conv1d(
         | 
| 61 | 
            -
                                    channels,
         | 
| 62 | 
            -
                                    channels,
         | 
| 63 | 
            -
                                    kernel_size,
         | 
| 64 | 
            -
                                    1,
         | 
| 65 | 
            -
                                    dilation=dilation,
         | 
| 66 | 
            -
                                    padding=get_padding(kernel_size, dilation),
         | 
| 67 | 
            -
                                )
         | 
| 68 | 
            -
                            )
         | 
| 69 | 
            -
                        )
         | 
| 70 | 
            -
                        self.convs2.append(
         | 
| 71 | 
            -
                            weight_norm(
         | 
| 72 | 
            -
                                Conv1d(
         | 
| 73 | 
            -
                                    channels,
         | 
| 74 | 
            -
                                    channels,
         | 
| 75 | 
            -
                                    kernel_size,
         | 
| 76 | 
            -
                                    1,
         | 
| 77 | 
            -
                                    dilation=1,
         | 
| 78 | 
            -
                                    padding=get_padding(kernel_size, 1),
         | 
| 79 | 
            -
                                )
         | 
| 80 | 
            -
                            )
         | 
| 81 | 
            -
                        )
         | 
| 82 | 
            -
                    self.convs1.apply(init_weights)
         | 
| 83 | 
            -
                    self.convs2.apply(init_weights)
         | 
| 84 | 
            -
                    self.activations1 = nn.ModuleList(
         | 
| 85 | 
            -
                        [Snake(channels, alpha_logscale=False) for _ in range(len(self.convs1))]
         | 
| 86 | 
            -
                    )
         | 
| 87 | 
            -
                    self.activations2 = nn.ModuleList(
         | 
| 88 | 
            -
                        [Snake(channels, alpha_logscale=False) for _ in range(len(self.convs2))]
         | 
| 89 | 
            -
                    )
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 92 | 
            -
                    for idx in range(len(self.convs1)):
         | 
| 93 | 
            -
                        xt = self.activations1[idx](x)
         | 
| 94 | 
            -
                        xt = self.convs1[idx](xt)
         | 
| 95 | 
            -
                        xt = self.activations2[idx](xt)
         | 
| 96 | 
            -
                        xt = self.convs2[idx](xt)
         | 
| 97 | 
            -
                        x = xt + x
         | 
| 98 | 
            -
                    return x
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                def remove_weight_norm(self):
         | 
| 101 | 
            -
                    for idx in range(len(self.convs1)):
         | 
| 102 | 
            -
                        remove_weight_norm(self.convs1[idx])
         | 
| 103 | 
            -
                        remove_weight_norm(self.convs2[idx])
         | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
            class SineGen(torch.nn.Module):
         | 
| 107 | 
            -
                """Definition of sine generator
         | 
| 108 | 
            -
                SineGen(samp_rate, harmonic_num = 0,
         | 
| 109 | 
            -
                        sine_amp = 0.1, noise_std = 0.003,
         | 
| 110 | 
            -
                        voiced_threshold = 0,
         | 
| 111 | 
            -
                        flag_for_pulse=False)
         | 
| 112 | 
            -
                samp_rate: sampling rate in Hz
         | 
| 113 | 
            -
                harmonic_num: number of harmonic overtones (default 0)
         | 
| 114 | 
            -
                sine_amp: amplitude of sine-wavefrom (default 0.1)
         | 
| 115 | 
            -
                noise_std: std of Gaussian noise (default 0.003)
         | 
| 116 | 
            -
                voiced_thoreshold: F0 threshold for U/V classification (default 0)
         | 
| 117 | 
            -
                flag_for_pulse: this SinGen is used inside PulseGen (default False)
         | 
| 118 | 
            -
                Note: when flag_for_pulse is True, the first time step of a voiced
         | 
| 119 | 
            -
                    segment is always sin(np.pi) or cos(0)
         | 
| 120 | 
            -
                """
         | 
| 121 | 
            -
             | 
| 122 | 
            -
                def __init__(
         | 
| 123 | 
            -
                    self,
         | 
| 124 | 
            -
                    samp_rate,
         | 
| 125 | 
            -
                    harmonic_num=0,
         | 
| 126 | 
            -
                    sine_amp=0.1,
         | 
| 127 | 
            -
                    noise_std=0.003,
         | 
| 128 | 
            -
                    voiced_threshold=0,
         | 
| 129 | 
            -
                ):
         | 
| 130 | 
            -
                    super(SineGen, self).__init__()
         | 
| 131 | 
            -
                    self.sine_amp = sine_amp
         | 
| 132 | 
            -
                    self.noise_std = noise_std
         | 
| 133 | 
            -
                    self.harmonic_num = harmonic_num
         | 
| 134 | 
            -
                    self.sampling_rate = samp_rate
         | 
| 135 | 
            -
                    self.voiced_threshold = voiced_threshold
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                def _f02uv(self, f0):
         | 
| 138 | 
            -
                    # generate uv signal
         | 
| 139 | 
            -
                    uv = (f0 > self.voiced_threshold).type(torch.float32)
         | 
| 140 | 
            -
                    return uv
         | 
| 141 | 
            -
             | 
| 142 | 
            -
                @torch.no_grad()
         | 
| 143 | 
            -
                def forward(self, f0):
         | 
| 144 | 
            -
                    """
         | 
| 145 | 
            -
                    :param f0: [B, 1, sample_len], Hz
         | 
| 146 | 
            -
                    :return: [B, 1, sample_len]
         | 
| 147 | 
            -
                    """
         | 
| 148 | 
            -
             | 
| 149 | 
            -
                    F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(
         | 
| 150 | 
            -
                        f0.device
         | 
| 151 | 
            -
                    )
         | 
| 152 | 
            -
                    for i in range(self.harmonic_num + 1):
         | 
| 153 | 
            -
                        F_mat[:, i : i + 1, :] = f0 * (i + 1) / self.sampling_rate
         | 
| 154 | 
            -
             | 
| 155 | 
            -
                    theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
         | 
| 156 | 
            -
                    u_dist = Uniform(low=-np.pi, high=np.pi)
         | 
| 157 | 
            -
                    phase_vec = u_dist.sample(
         | 
| 158 | 
            -
                        sample_shape=(f0.size(0), self.harmonic_num + 1, 1)
         | 
| 159 | 
            -
                    ).to(F_mat.device)
         | 
| 160 | 
            -
                    phase_vec[:, 0, :] = 0
         | 
| 161 | 
            -
             | 
| 162 | 
            -
                    # generate sine waveforms
         | 
| 163 | 
            -
                    sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
         | 
| 164 | 
            -
             | 
| 165 | 
            -
                    # generate uv signal
         | 
| 166 | 
            -
                    uv = self._f02uv(f0)
         | 
| 167 | 
            -
             | 
| 168 | 
            -
                    # noise: for unvoiced should be similar to sine_amp
         | 
| 169 | 
            -
                    #        std = self.sine_amp/3 -> max value ~ self.sine_amp
         | 
| 170 | 
            -
                    # .       for voiced regions is self.noise_std
         | 
| 171 | 
            -
                    noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
         | 
| 172 | 
            -
                    noise = noise_amp * torch.randn_like(sine_waves)
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                    # first: set the unvoiced part to 0 by uv
         | 
| 175 | 
            -
                    # then: additive noise
         | 
| 176 | 
            -
                    sine_waves = sine_waves * uv + noise
         | 
| 177 | 
            -
                    return sine_waves, uv, noise
         | 
| 178 | 
            -
             | 
| 179 | 
            -
             | 
| 180 | 
            -
            class SourceModuleHnNSF(torch.nn.Module):
         | 
| 181 | 
            -
                """SourceModule for hn-nsf
         | 
| 182 | 
            -
                SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
         | 
| 183 | 
            -
                             add_noise_std=0.003, voiced_threshod=0)
         | 
| 184 | 
            -
                sampling_rate: sampling_rate in Hz
         | 
| 185 | 
            -
                harmonic_num: number of harmonic above F0 (default: 0)
         | 
| 186 | 
            -
                sine_amp: amplitude of sine source signal (default: 0.1)
         | 
| 187 | 
            -
                add_noise_std: std of additive Gaussian noise (default: 0.003)
         | 
| 188 | 
            -
                    note that amplitude of noise in unvoiced is decided
         | 
| 189 | 
            -
                    by sine_amp
         | 
| 190 | 
            -
                voiced_threshold: threhold to set U/V given F0 (default: 0)
         | 
| 191 | 
            -
                Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
         | 
| 192 | 
            -
                F0_sampled (batchsize, length, 1)
         | 
| 193 | 
            -
                Sine_source (batchsize, length, 1)
         | 
| 194 | 
            -
                noise_source (batchsize, length 1)
         | 
| 195 | 
            -
                uv (batchsize, length, 1)
         | 
| 196 | 
            -
                """
         | 
| 197 | 
            -
             | 
| 198 | 
            -
                def __init__(
         | 
| 199 | 
            -
                    self,
         | 
| 200 | 
            -
                    sampling_rate,
         | 
| 201 | 
            -
                    upsample_scale,
         | 
| 202 | 
            -
                    harmonic_num=0,
         | 
| 203 | 
            -
                    sine_amp=0.1,
         | 
| 204 | 
            -
                    add_noise_std=0.003,
         | 
| 205 | 
            -
                    voiced_threshod=0,
         | 
| 206 | 
            -
                ):
         | 
| 207 | 
            -
                    super(SourceModuleHnNSF, self).__init__()
         | 
| 208 | 
            -
             | 
| 209 | 
            -
                    self.sine_amp = sine_amp
         | 
| 210 | 
            -
                    self.noise_std = add_noise_std
         | 
| 211 | 
            -
             | 
| 212 | 
            -
                    # to produce sine waveforms
         | 
| 213 | 
            -
                    self.l_sin_gen = SineGen(
         | 
| 214 | 
            -
                        sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
         | 
| 215 | 
            -
                    )
         | 
| 216 | 
            -
             | 
| 217 | 
            -
                    # to merge source harmonics into a single excitation
         | 
| 218 | 
            -
                    self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
         | 
| 219 | 
            -
                    self.l_tanh = torch.nn.Tanh()
         | 
| 220 | 
            -
             | 
| 221 | 
            -
                def forward(self, x):
         | 
| 222 | 
            -
                    """
         | 
| 223 | 
            -
                    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
         | 
| 224 | 
            -
                    F0_sampled (batchsize, length, 1)
         | 
| 225 | 
            -
                    Sine_source (batchsize, length, 1)
         | 
| 226 | 
            -
                    noise_source (batchsize, length 1)
         | 
| 227 | 
            -
                    """
         | 
| 228 | 
            -
                    # source for harmonic branch
         | 
| 229 | 
            -
                    with torch.no_grad():
         | 
| 230 | 
            -
                        sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
         | 
| 231 | 
            -
                        sine_wavs = sine_wavs.transpose(1, 2)
         | 
| 232 | 
            -
                        uv = uv.transpose(1, 2)
         | 
| 233 | 
            -
                    sine_merge = self.l_tanh(self.l_linear(sine_wavs))
         | 
| 234 | 
            -
             | 
| 235 | 
            -
                    # source for noise branch, in the same shape as uv
         | 
| 236 | 
            -
                    noise = torch.randn_like(uv) * self.sine_amp / 3
         | 
| 237 | 
            -
                    return sine_merge, noise, uv
         | 
| 238 | 
            -
             | 
| 239 | 
            -
             | 
| 240 | 
            -
            class HiFTGenerator(nn.Module):
         | 
| 241 | 
            -
                """
         | 
| 242 | 
            -
                HiFTNet Generator: Neural Source Filter + ISTFTNet
         | 
| 243 | 
            -
                https://arxiv.org/abs/2309.09493
         | 
| 244 | 
            -
                """
         | 
| 245 | 
            -
             | 
| 246 | 
            -
                def __init__(
         | 
| 247 | 
            -
                    self,
         | 
| 248 | 
            -
                    in_channels: int = 80,
         | 
| 249 | 
            -
                    base_channels: int = 512,
         | 
| 250 | 
            -
                    nb_harmonics: int = 8,
         | 
| 251 | 
            -
                    sampling_rate: int = 22050,
         | 
| 252 | 
            -
                    nsf_alpha: float = 0.1,
         | 
| 253 | 
            -
                    nsf_sigma: float = 0.003,
         | 
| 254 | 
            -
                    nsf_voiced_threshold: float = 10,
         | 
| 255 | 
            -
                    upsample_rates: tp.List[int] = [8, 8],
         | 
| 256 | 
            -
                    upsample_kernel_sizes: tp.List[int] = [16, 16],
         | 
| 257 | 
            -
                    istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
         | 
| 258 | 
            -
                    resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
         | 
| 259 | 
            -
                    resblock_dilation_sizes: tp.List[tp.List[int]] = [
         | 
| 260 | 
            -
                        [1, 3, 5],
         | 
| 261 | 
            -
                        [1, 3, 5],
         | 
| 262 | 
            -
                        [1, 3, 5],
         | 
| 263 | 
            -
                    ],
         | 
| 264 | 
            -
                    source_resblock_kernel_sizes: tp.List[int] = [7, 11],
         | 
| 265 | 
            -
                    source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
         | 
| 266 | 
            -
                    lrelu_slope: float = 0.1,
         | 
| 267 | 
            -
                    audio_limit: float = 0.99,
         | 
| 268 | 
            -
                    f0_predictor: torch.nn.Module = None,
         | 
| 269 | 
            -
                ):
         | 
| 270 | 
            -
                    super(HiFTGenerator, self).__init__()
         | 
| 271 | 
            -
             | 
| 272 | 
            -
                    self.out_channels = 1
         | 
| 273 | 
            -
                    self.nb_harmonics = nb_harmonics
         | 
| 274 | 
            -
                    self.sampling_rate = sampling_rate
         | 
| 275 | 
            -
                    self.istft_params = istft_params
         | 
| 276 | 
            -
                    self.lrelu_slope = lrelu_slope
         | 
| 277 | 
            -
                    self.audio_limit = audio_limit
         | 
| 278 | 
            -
             | 
| 279 | 
            -
                    self.num_kernels = len(resblock_kernel_sizes)
         | 
| 280 | 
            -
                    self.num_upsamples = len(upsample_rates)
         | 
| 281 | 
            -
                    self.upsample_rates = upsample_rates
         | 
| 282 | 
            -
                    self.m_source = SourceModuleHnNSF(
         | 
| 283 | 
            -
                        sampling_rate=sampling_rate,
         | 
| 284 | 
            -
                        upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
         | 
| 285 | 
            -
                        harmonic_num=nb_harmonics,
         | 
| 286 | 
            -
                        sine_amp=nsf_alpha,
         | 
| 287 | 
            -
                        add_noise_std=nsf_sigma,
         | 
| 288 | 
            -
                        voiced_threshod=nsf_voiced_threshold,
         | 
| 289 | 
            -
                    )
         | 
| 290 | 
            -
                    self.f0_upsamp = torch.nn.Upsample(
         | 
| 291 | 
            -
                        scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]
         | 
| 292 | 
            -
                    )
         | 
| 293 | 
            -
             | 
| 294 | 
            -
                    self.conv_pre = weight_norm(Conv1d(in_channels, base_channels, 7, 1, padding=3))
         | 
| 295 | 
            -
             | 
| 296 | 
            -
                    # Up
         | 
| 297 | 
            -
                    self.ups = nn.ModuleList()
         | 
| 298 | 
            -
                    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
         | 
| 299 | 
            -
                        self.ups.append(
         | 
| 300 | 
            -
                            weight_norm(
         | 
| 301 | 
            -
                                ConvTranspose1d(
         | 
| 302 | 
            -
                                    base_channels // (2**i),
         | 
| 303 | 
            -
                                    base_channels // (2 ** (i + 1)),
         | 
| 304 | 
            -
                                    k,
         | 
| 305 | 
            -
                                    u,
         | 
| 306 | 
            -
                                    padding=(k - u) // 2,
         | 
| 307 | 
            -
                                )
         | 
| 308 | 
            -
                            )
         | 
| 309 | 
            -
                        )
         | 
| 310 | 
            -
             | 
| 311 | 
            -
                    # Down
         | 
| 312 | 
            -
                    self.source_downs = nn.ModuleList()
         | 
| 313 | 
            -
                    self.source_resblocks = nn.ModuleList()
         | 
| 314 | 
            -
                    downsample_rates = [1] + upsample_rates[::-1][:-1]
         | 
| 315 | 
            -
                    downsample_cum_rates = np.cumprod(downsample_rates)
         | 
| 316 | 
            -
                    for i, (u, k, d) in enumerate(
         | 
| 317 | 
            -
                        zip(
         | 
| 318 | 
            -
                            downsample_cum_rates[::-1],
         | 
| 319 | 
            -
                            source_resblock_kernel_sizes,
         | 
| 320 | 
            -
                            source_resblock_dilation_sizes,
         | 
| 321 | 
            -
                        )
         | 
| 322 | 
            -
                    ):
         | 
| 323 | 
            -
                        if u == 1:
         | 
| 324 | 
            -
                            self.source_downs.append(
         | 
| 325 | 
            -
                                Conv1d(
         | 
| 326 | 
            -
                                    istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1
         | 
| 327 | 
            -
                                )
         | 
| 328 | 
            -
                            )
         | 
| 329 | 
            -
                        else:
         | 
| 330 | 
            -
                            self.source_downs.append(
         | 
| 331 | 
            -
                                Conv1d(
         | 
| 332 | 
            -
                                    istft_params["n_fft"] + 2,
         | 
| 333 | 
            -
                                    base_channels // (2 ** (i + 1)),
         | 
| 334 | 
            -
                                    u * 2,
         | 
| 335 | 
            -
                                    u,
         | 
| 336 | 
            -
                                    padding=(u // 2),
         | 
| 337 | 
            -
                                )
         | 
| 338 | 
            -
                            )
         | 
| 339 | 
            -
             | 
| 340 | 
            -
                        self.source_resblocks.append(
         | 
| 341 | 
            -
                            ResBlock(base_channels // (2 ** (i + 1)), k, d)
         | 
| 342 | 
            -
                        )
         | 
| 343 | 
            -
             | 
| 344 | 
            -
                    self.resblocks = nn.ModuleList()
         | 
| 345 | 
            -
                    for i in range(len(self.ups)):
         | 
| 346 | 
            -
                        ch = base_channels // (2 ** (i + 1))
         | 
| 347 | 
            -
                        for _, (k, d) in enumerate(
         | 
| 348 | 
            -
                            zip(resblock_kernel_sizes, resblock_dilation_sizes)
         | 
| 349 | 
            -
                        ):
         | 
| 350 | 
            -
                            self.resblocks.append(ResBlock(ch, k, d))
         | 
| 351 | 
            -
             | 
| 352 | 
            -
                    self.conv_post = weight_norm(
         | 
| 353 | 
            -
                        Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)
         | 
| 354 | 
            -
                    )
         | 
| 355 | 
            -
                    self.ups.apply(init_weights)
         | 
| 356 | 
            -
                    self.conv_post.apply(init_weights)
         | 
| 357 | 
            -
                    self.reflection_pad = nn.ReflectionPad1d((1, 0))
         | 
| 358 | 
            -
                    self.stft_window = torch.from_numpy(
         | 
| 359 | 
            -
                        get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)
         | 
| 360 | 
            -
                    ).cuda()
         | 
| 361 | 
            -
                    self.f0_predictor = f0_predictor
         | 
| 362 | 
            -
                    self.inference_buffers = {}
         | 
| 363 | 
            -
                    self.inference_graphs = {}
         | 
| 364 | 
            -
             | 
| 365 | 
            -
                def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
         | 
| 366 | 
            -
                    f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
         | 
| 367 | 
            -
             | 
| 368 | 
            -
                    har_source, _, _ = self.m_source(f0)
         | 
| 369 | 
            -
                    return har_source.transpose(1, 2)
         | 
| 370 | 
            -
             | 
| 371 | 
            -
                def _stft(self, x):
         | 
| 372 | 
            -
                    spec = torch.stft(
         | 
| 373 | 
            -
                        x,
         | 
| 374 | 
            -
                        self.istft_params["n_fft"],
         | 
| 375 | 
            -
                        self.istft_params["hop_len"],
         | 
| 376 | 
            -
                        self.istft_params["n_fft"],
         | 
| 377 | 
            -
                        window=self.stft_window,
         | 
| 378 | 
            -
                        return_complex=True,
         | 
| 379 | 
            -
                    )
         | 
| 380 | 
            -
                    spec = torch.view_as_real(spec)  # [B, F, TT, 2]
         | 
| 381 | 
            -
                    return spec[..., 0], spec[..., 1]
         | 
| 382 | 
            -
             | 
| 383 | 
            -
                def _istft(self, magnitude, phase):
         | 
| 384 | 
            -
                    magnitude = torch.clip(magnitude, max=1e2)
         | 
| 385 | 
            -
                    real = magnitude * torch.cos(phase)
         | 
| 386 | 
            -
                    img = magnitude * torch.sin(phase)
         | 
| 387 | 
            -
                    inverse_transform = torch.istft(
         | 
| 388 | 
            -
                        torch.complex(real, img),
         | 
| 389 | 
            -
                        self.istft_params["n_fft"],
         | 
| 390 | 
            -
                        self.istft_params["hop_len"],
         | 
| 391 | 
            -
                        self.istft_params["n_fft"],
         | 
| 392 | 
            -
                        window=self.stft_window,
         | 
| 393 | 
            -
                    )
         | 
| 394 | 
            -
                    return inverse_transform
         | 
| 395 | 
            -
             | 
| 396 | 
            -
                def forward(
         | 
| 397 | 
            -
                    self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)
         | 
| 398 | 
            -
                ) -> torch.Tensor:
         | 
| 399 | 
            -
                    f0 = self.f0_predictor(x)
         | 
| 400 | 
            -
                    s = self._f02source(f0)
         | 
| 401 | 
            -
             | 
| 402 | 
            -
                    # use cache_source to avoid glitch
         | 
| 403 | 
            -
                    if cache_source.shape[2] != 0:
         | 
| 404 | 
            -
                        s[:, :, : cache_source.shape[2]] = cache_source
         | 
| 405 | 
            -
             | 
| 406 | 
            -
                    s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         | 
| 407 | 
            -
                    s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
         | 
| 408 | 
            -
             | 
| 409 | 
            -
                    x = self.conv_pre(x)
         | 
| 410 | 
            -
                    for i in range(self.num_upsamples):
         | 
| 411 | 
            -
                        x = F.leaky_relu(x, self.lrelu_slope)
         | 
| 412 | 
            -
                        x = self.ups[i](x)
         | 
| 413 | 
            -
             | 
| 414 | 
            -
                        if i == self.num_upsamples - 1:
         | 
| 415 | 
            -
                            x = self.reflection_pad(x)
         | 
| 416 | 
            -
             | 
| 417 | 
            -
                        # fusion
         | 
| 418 | 
            -
                        si = self.source_downs[i](s_stft)
         | 
| 419 | 
            -
                        si = self.source_resblocks[i](si)
         | 
| 420 | 
            -
                        x = x + si
         | 
| 421 | 
            -
             | 
| 422 | 
            -
                        xs = None
         | 
| 423 | 
            -
                        for j in range(self.num_kernels):
         | 
| 424 | 
            -
                            if xs is None:
         | 
| 425 | 
            -
                                xs = self.resblocks[i * self.num_kernels + j](x)
         | 
| 426 | 
            -
                            else:
         | 
| 427 | 
            -
                                xs += self.resblocks[i * self.num_kernels + j](x)
         | 
| 428 | 
            -
                        x = xs / self.num_kernels
         | 
| 429 | 
            -
             | 
| 430 | 
            -
                    x = F.leaky_relu(x)
         | 
| 431 | 
            -
                    x = self.conv_post(x)
         | 
| 432 | 
            -
                    magnitude = torch.exp(x[:, : self.istft_params["n_fft"] // 2 + 1, :])
         | 
| 433 | 
            -
                    phase = torch.sin(
         | 
| 434 | 
            -
                        x[:, self.istft_params["n_fft"] // 2 + 1 :, :]
         | 
| 435 | 
            -
                    )  # actually, sin is redundancy
         | 
| 436 | 
            -
             | 
| 437 | 
            -
                    x = self._istft(magnitude, phase)
         | 
| 438 | 
            -
                    x = torch.clamp(x, -self.audio_limit, self.audio_limit)
         | 
| 439 | 
            -
                    return x, s
         | 
| 440 | 
            -
             | 
| 441 | 
            -
                def remove_weight_norm(self):
         | 
| 442 | 
            -
                    print("Removing weight norm...")
         | 
| 443 | 
            -
                    for l in self.ups:
         | 
| 444 | 
            -
                        remove_weight_norm(l)
         | 
| 445 | 
            -
                    for l in self.resblocks:
         | 
| 446 | 
            -
                        l.remove_weight_norm()
         | 
| 447 | 
            -
                    remove_weight_norm(self.conv_pre)
         | 
| 448 | 
            -
                    remove_weight_norm(self.conv_post)
         | 
| 449 | 
            -
                    self.source_module.remove_weight_norm()
         | 
| 450 | 
            -
                    for l in self.source_downs:
         | 
| 451 | 
            -
                        remove_weight_norm(l)
         | 
| 452 | 
            -
                    for l in self.source_resblocks:
         | 
| 453 | 
            -
                        l.remove_weight_norm()
         | 
| 454 | 
            -
             | 
| 455 | 
            -
                @torch.inference_mode()
         | 
| 456 | 
            -
                def _inference_impl(self, mel: torch.Tensor, s_stft: torch.Tensor) -> torch.Tensor:
         | 
| 457 | 
            -
                    x = self.conv_pre(mel)
         | 
| 458 | 
            -
                    for i in range(self.num_upsamples):
         | 
| 459 | 
            -
                        x = F.leaky_relu(x, self.lrelu_slope)
         | 
| 460 | 
            -
                        x = self.ups[i](x)
         | 
| 461 | 
            -
             | 
| 462 | 
            -
                        if i == self.num_upsamples - 1:
         | 
| 463 | 
            -
                            x = self.reflection_pad(x)
         | 
| 464 | 
            -
             | 
| 465 | 
            -
                        # fusion
         | 
| 466 | 
            -
                        si = self.source_downs[i](s_stft)
         | 
| 467 | 
            -
                        si = self.source_resblocks[i](si)
         | 
| 468 | 
            -
                        x = x + si
         | 
| 469 | 
            -
             | 
| 470 | 
            -
                        xs = None
         | 
| 471 | 
            -
                        for j in range(self.num_kernels):
         | 
| 472 | 
            -
                            if xs is None:
         | 
| 473 | 
            -
                                xs = self.resblocks[i * self.num_kernels + j](x)
         | 
| 474 | 
            -
                            else:
         | 
| 475 | 
            -
                                xs += self.resblocks[i * self.num_kernels + j](x)
         | 
| 476 | 
            -
                        x = xs / self.num_kernels
         | 
| 477 | 
            -
             | 
| 478 | 
            -
                    x = F.leaky_relu(x)
         | 
| 479 | 
            -
                    x = self.conv_post(x)
         | 
| 480 | 
            -
                    magnitude = torch.exp(x[:, : self.istft_params["n_fft"] // 2 + 1, :])
         | 
| 481 | 
            -
                    phase = torch.sin(
         | 
| 482 | 
            -
                        x[:, self.istft_params["n_fft"] // 2 + 1 :, :]
         | 
| 483 | 
            -
                    )  # actually, sin is redundancy
         | 
| 484 | 
            -
                    # print(f"mel: {mel.shape}, magnitude: {magnitude.shape}, phase: {phase.shape}")
         | 
| 485 | 
            -
                    return magnitude, phase
         | 
| 486 | 
            -
             | 
| 487 | 
            -
                @torch.inference_mode()
         | 
| 488 | 
            -
                def inference(
         | 
| 489 | 
            -
                    self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)
         | 
| 490 | 
            -
                ) -> torch.Tensor:
         | 
| 491 | 
            -
                    curr_seq_len = mel.shape[2]
         | 
| 492 | 
            -
                    f0 = self.f0_predictor(mel)
         | 
| 493 | 
            -
                    s = self._f02source(f0)
         | 
| 494 | 
            -
                    s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         | 
| 495 | 
            -
                    s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
         | 
| 496 | 
            -
             | 
| 497 | 
            -
                    target_len = None
         | 
| 498 | 
            -
                    for seq_len in sorted(self.inference_buffers.keys()):
         | 
| 499 | 
            -
                        if curr_seq_len <= seq_len:
         | 
| 500 | 
            -
                            target_len = seq_len
         | 
| 501 | 
            -
                            break
         | 
| 502 | 
            -
             | 
| 503 | 
            -
                    if target_len is not None:
         | 
| 504 | 
            -
                        buffer = self.inference_buffers[target_len]
         | 
| 505 | 
            -
             | 
| 506 | 
            -
                        if curr_seq_len < target_len:
         | 
| 507 | 
            -
                            padded_mel = torch.zeros_like(buffer["mel"])
         | 
| 508 | 
            -
                            padded_mel[:, :, :curr_seq_len] = mel
         | 
| 509 | 
            -
                            buffer["mel"].copy_(padded_mel)
         | 
| 510 | 
            -
                            padded_s_stft = torch.zeros_like(buffer["s_stft"])
         | 
| 511 | 
            -
                            cur_s_stft_len = s_stft.shape[2]
         | 
| 512 | 
            -
                            padded_s_stft[:, :, :cur_s_stft_len] = s_stft
         | 
| 513 | 
            -
                            buffer["s_stft"].copy_(padded_s_stft)
         | 
| 514 | 
            -
             | 
| 515 | 
            -
                        else:
         | 
| 516 | 
            -
                            buffer["mel"].copy_(mel)
         | 
| 517 | 
            -
                            buffer["s_stft"].copy_(s_stft)
         | 
| 518 | 
            -
                            cur_s_stft_len = s_stft.shape[2]
         | 
| 519 | 
            -
             | 
| 520 | 
            -
                        self.inference_graphs[target_len].replay()
         | 
| 521 | 
            -
             | 
| 522 | 
            -
                        magnitude, phase = (
         | 
| 523 | 
            -
                            buffer["magnitude"][:, :, :cur_s_stft_len],
         | 
| 524 | 
            -
                            buffer["phase"][:, :, :cur_s_stft_len],
         | 
| 525 | 
            -
                        )
         | 
| 526 | 
            -
                    else:
         | 
| 527 | 
            -
                        magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
         | 
| 528 | 
            -
             | 
| 529 | 
            -
                    x = self._istft(magnitude, phase)
         | 
| 530 | 
            -
                    x = torch.clamp(x, -self.audio_limit, self.audio_limit)
         | 
| 531 | 
            -
                    return x, s
         | 
| 532 | 
            -
             | 
| 533 | 
            -
                @torch.inference_mode()
         | 
| 534 | 
            -
                def capture_inference(self, seq_len_to_capture=[64, 128, 256, 512, 1024]):
         | 
| 535 | 
            -
                    start_time = time.time()
         | 
| 536 | 
            -
                    print(
         | 
| 537 | 
            -
                        f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture}"
         | 
| 538 | 
            -
                    )
         | 
| 539 | 
            -
                    for seq_len in seq_len_to_capture:
         | 
| 540 | 
            -
                        mel = torch.randn(
         | 
| 541 | 
            -
                            1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32
         | 
| 542 | 
            -
                        )
         | 
| 543 | 
            -
                        f0 = self.f0_predictor(mel)
         | 
| 544 | 
            -
                        s = self._f02source(f0)
         | 
| 545 | 
            -
                        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         | 
| 546 | 
            -
                        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
         | 
| 547 | 
            -
             | 
| 548 | 
            -
                        magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
         | 
| 549 | 
            -
                        torch.cuda.synchronize()
         | 
| 550 | 
            -
             | 
| 551 | 
            -
                        g = torch.cuda.CUDAGraph()
         | 
| 552 | 
            -
                        with torch.cuda.graph(g):
         | 
| 553 | 
            -
                            magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
         | 
| 554 | 
            -
                        inference_buffer = {
         | 
| 555 | 
            -
                            "mel": mel,
         | 
| 556 | 
            -
                            "s_stft": s_stft,
         | 
| 557 | 
            -
                            "magnitude": magnitude,
         | 
| 558 | 
            -
                            "phase": phase,
         | 
| 559 | 
            -
                        }
         | 
| 560 | 
            -
                        self.inference_buffers[seq_len] = inference_buffer
         | 
| 561 | 
            -
                        self.inference_graphs[seq_len] = g
         | 
| 562 | 
            -
             | 
| 563 | 
            -
                    end_time = time.time()
         | 
| 564 | 
            -
                    print(
         | 
| 565 | 
            -
                        f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture} takes {end_time - start_time} seconds"
         | 
| 566 | 
            -
                    )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/matcha/audio.py
    DELETED
    
    | @@ -1,90 +0,0 @@ | |
| 1 | 
            -
            import numpy as np
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import torch.utils.data
         | 
| 4 | 
            -
            from librosa.filters import mel as librosa_mel_fn
         | 
| 5 | 
            -
            from scipy.io.wavfile import read
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            MAX_WAV_VALUE = 32768.0
         | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
            def load_wav(full_path):
         | 
| 11 | 
            -
                sampling_rate, data = read(full_path)
         | 
| 12 | 
            -
                return data, sampling_rate
         | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
            def dynamic_range_compression(x, C=1, clip_val=1e-5):
         | 
| 16 | 
            -
                return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
            def dynamic_range_decompression(x, C=1):
         | 
| 20 | 
            -
                return np.exp(x) / C
         | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
            def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
         | 
| 24 | 
            -
                return torch.log(torch.clamp(x, min=clip_val) * C)
         | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
            def dynamic_range_decompression_torch(x, C=1):
         | 
| 28 | 
            -
                return torch.exp(x) / C
         | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
            def spectral_normalize_torch(magnitudes):
         | 
| 32 | 
            -
                output = dynamic_range_compression_torch(magnitudes)
         | 
| 33 | 
            -
                return output
         | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
            def spectral_de_normalize_torch(magnitudes):
         | 
| 37 | 
            -
                output = dynamic_range_decompression_torch(magnitudes)
         | 
| 38 | 
            -
                return output
         | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
            mel_basis = {}
         | 
| 42 | 
            -
            hann_window = {}
         | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
            def mel_spectrogram(
         | 
| 46 | 
            -
                y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
         | 
| 47 | 
            -
            ):
         | 
| 48 | 
            -
                if torch.min(y) < -1.0:
         | 
| 49 | 
            -
                    print("min value is ", torch.min(y))
         | 
| 50 | 
            -
                if torch.max(y) > 1.0:
         | 
| 51 | 
            -
                    print("max value is ", torch.max(y))
         | 
| 52 | 
            -
             | 
| 53 | 
            -
                global mel_basis, hann_window  # pylint: disable=global-statement
         | 
| 54 | 
            -
                if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
         | 
| 55 | 
            -
                    mel = librosa_mel_fn(
         | 
| 56 | 
            -
                        sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
         | 
| 57 | 
            -
                    )
         | 
| 58 | 
            -
                    mel_basis[str(fmax) + "_" + str(y.device)] = (
         | 
| 59 | 
            -
                        torch.from_numpy(mel).float().to(y.device)
         | 
| 60 | 
            -
                    )
         | 
| 61 | 
            -
                    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                y = torch.nn.functional.pad(
         | 
| 64 | 
            -
                    y.unsqueeze(1),
         | 
| 65 | 
            -
                    (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
         | 
| 66 | 
            -
                    mode="reflect",
         | 
| 67 | 
            -
                )
         | 
| 68 | 
            -
                y = y.squeeze(1)
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                spec = torch.view_as_real(
         | 
| 71 | 
            -
                    torch.stft(
         | 
| 72 | 
            -
                        y,
         | 
| 73 | 
            -
                        n_fft,
         | 
| 74 | 
            -
                        hop_length=hop_size,
         | 
| 75 | 
            -
                        win_length=win_size,
         | 
| 76 | 
            -
                        window=hann_window[str(y.device)],
         | 
| 77 | 
            -
                        center=center,
         | 
| 78 | 
            -
                        pad_mode="reflect",
         | 
| 79 | 
            -
                        normalized=False,
         | 
| 80 | 
            -
                        onesided=True,
         | 
| 81 | 
            -
                        return_complex=True,
         | 
| 82 | 
            -
                    )
         | 
| 83 | 
            -
                )
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
         | 
| 88 | 
            -
                spec = spectral_normalize_torch(spec)
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                return spec
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/matcha/decoder.py
    DELETED
    
    | @@ -1,511 +0,0 @@ | |
| 1 | 
            -
            import math
         | 
| 2 | 
            -
            from typing import Optional
         | 
| 3 | 
            -
             | 
| 4 | 
            -
            import torch
         | 
| 5 | 
            -
            import torch.nn as nn
         | 
| 6 | 
            -
            import torch.nn.functional as F
         | 
| 7 | 
            -
            from conformer import ConformerBlock
         | 
| 8 | 
            -
            from diffusers.models.activations import get_activation
         | 
| 9 | 
            -
            from einops import pack, rearrange, repeat
         | 
| 10 | 
            -
             | 
| 11 | 
            -
            from cosyvoice.matcha.transformer import BasicTransformerBlock
         | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
            class SinusoidalPosEmb(torch.nn.Module):
         | 
| 15 | 
            -
                def __init__(self, dim):
         | 
| 16 | 
            -
                    super().__init__()
         | 
| 17 | 
            -
                    self.dim = dim
         | 
| 18 | 
            -
                    assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
         | 
| 19 | 
            -
             | 
| 20 | 
            -
                def forward(self, x, scale=1000):
         | 
| 21 | 
            -
                    if x.ndim < 1:
         | 
| 22 | 
            -
                        x = x.unsqueeze(0)
         | 
| 23 | 
            -
                    device = x.device
         | 
| 24 | 
            -
                    half_dim = self.dim // 2
         | 
| 25 | 
            -
                    emb = math.log(10000) / (half_dim - 1)
         | 
| 26 | 
            -
                    emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
         | 
| 27 | 
            -
                    emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
         | 
| 28 | 
            -
                    emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         | 
| 29 | 
            -
                    return emb
         | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
            class MaskedGroupNorm(nn.GroupNorm):
         | 
| 33 | 
            -
                """
         | 
| 34 | 
            -
                Masked verstion of the Group normalization.
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                Based on: https://github.com/ptrblck/pytorch_misc/blob/20e8ea93bd458b88f921a87e2d4001a4eb753a02/batch_norm_manual.py
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                Receives a N-dim tensor of sequence lengths per batch element
         | 
| 39 | 
            -
                along with the regular input for masking.
         | 
| 40 | 
            -
             | 
| 41 | 
            -
                Check pytorch's GroupNorm implementation for argument details.
         | 
| 42 | 
            -
                """
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
         | 
| 45 | 
            -
                    super(MaskedGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
         | 
| 46 | 
            -
             | 
| 47 | 
            -
                def forward(self, inp, mask=None):
         | 
| 48 | 
            -
                    assert (
         | 
| 49 | 
            -
                        inp.shape[1] % self.num_groups == 0
         | 
| 50 | 
            -
                    ), "Feature size not divisible by groups"
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                    # 计算有效长度
         | 
| 53 | 
            -
                    seq_lengths = mask.sum(-1, keepdim=True)  # [batch_size, 1]
         | 
| 54 | 
            -
             | 
| 55 | 
            -
                    # 将输入reshape为groups
         | 
| 56 | 
            -
                    features_per_group = inp.shape[1] // self.num_groups
         | 
| 57 | 
            -
                    inp_r = inp.reshape(
         | 
| 58 | 
            -
                        inp.shape[0], self.num_groups, features_per_group, inp.shape[-1]
         | 
| 59 | 
            -
                    )
         | 
| 60 | 
            -
                    mask_r = mask.unsqueeze(1)  # [batch_size, 1, 1, length]
         | 
| 61 | 
            -
             | 
| 62 | 
            -
                    # 计算masked mean和variance
         | 
| 63 | 
            -
                    masked_inp = inp_r * mask_r
         | 
| 64 | 
            -
                    n = seq_lengths * features_per_group  # 每组的有效元素数量
         | 
| 65 | 
            -
                    mean = masked_inp.sum([2, 3], keepdim=True) / (n.view(-1, 1, 1, 1) + 1e-5)
         | 
| 66 | 
            -
                    var = ((masked_inp - mean * mask_r) ** 2).sum([2, 3], keepdim=True) / (
         | 
| 67 | 
            -
                        n.view(-1, 1, 1, 1) + 1e-5
         | 
| 68 | 
            -
                    )
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                    # 标准化
         | 
| 71 | 
            -
                    inp_r = (inp_r - mean) / (torch.sqrt(var + self.eps))
         | 
| 72 | 
            -
                    out = inp_r.reshape(inp.shape[0], self.num_channels, inp.shape[-1])
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                    # 应用仿射变换
         | 
| 75 | 
            -
                    if self.affine:
         | 
| 76 | 
            -
                        out = out * self.weight[None, :, None] + self.bias[None, :, None]
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                    return out
         | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
            class Block1D(torch.nn.Module):
         | 
| 82 | 
            -
                def __init__(self, dim, dim_out, groups=8):
         | 
| 83 | 
            -
                    super().__init__()
         | 
| 84 | 
            -
                    self.block = torch.nn.Sequential(
         | 
| 85 | 
            -
                        torch.nn.Conv1d(dim, dim_out, 3, padding=1),
         | 
| 86 | 
            -
                        torch.nn.GroupNorm(groups, dim_out),
         | 
| 87 | 
            -
                        # MaskedGroupNorm(groups, dim_out),
         | 
| 88 | 
            -
                        nn.Mish(),
         | 
| 89 | 
            -
                    )
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                def forward(self, x, mask):
         | 
| 92 | 
            -
                    output = self.block(x * mask)
         | 
| 93 | 
            -
                    return output * mask
         | 
| 94 | 
            -
                    return x * mask
         | 
| 95 | 
            -
             | 
| 96 | 
            -
             | 
| 97 | 
            -
            class ResnetBlock1D(torch.nn.Module):
         | 
| 98 | 
            -
                def __init__(self, dim, dim_out, time_emb_dim, groups=8):
         | 
| 99 | 
            -
                    super().__init__()
         | 
| 100 | 
            -
                    self.mlp = torch.nn.Sequential(
         | 
| 101 | 
            -
                        nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
         | 
| 102 | 
            -
                    )
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                    self.block1 = Block1D(dim, dim_out, groups=groups)
         | 
| 105 | 
            -
                    self.block2 = Block1D(dim_out, dim_out, groups=groups)
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                    self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
         | 
| 108 | 
            -
             | 
| 109 | 
            -
                def forward(self, x, mask, time_emb):
         | 
| 110 | 
            -
                    h = self.block1(x, mask)
         | 
| 111 | 
            -
                    h += self.mlp(time_emb).unsqueeze(-1)
         | 
| 112 | 
            -
                    h = self.block2(h, mask)
         | 
| 113 | 
            -
                    output = h + self.res_conv(x * mask)
         | 
| 114 | 
            -
                    return output
         | 
| 115 | 
            -
             | 
| 116 | 
            -
             | 
| 117 | 
            -
            class Downsample1D(nn.Module):
         | 
| 118 | 
            -
                def __init__(self, dim):
         | 
| 119 | 
            -
                    super().__init__()
         | 
| 120 | 
            -
                    self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
         | 
| 121 | 
            -
             | 
| 122 | 
            -
                def forward(self, x):
         | 
| 123 | 
            -
                    return self.conv(x)
         | 
| 124 | 
            -
             | 
| 125 | 
            -
             | 
| 126 | 
            -
            class TimestepEmbedding(nn.Module):
         | 
| 127 | 
            -
                def __init__(
         | 
| 128 | 
            -
                    self,
         | 
| 129 | 
            -
                    in_channels: int,
         | 
| 130 | 
            -
                    time_embed_dim: int,
         | 
| 131 | 
            -
                    act_fn: str = "silu",
         | 
| 132 | 
            -
                    out_dim: int = None,
         | 
| 133 | 
            -
                    post_act_fn: Optional[str] = None,
         | 
| 134 | 
            -
                    cond_proj_dim=None,
         | 
| 135 | 
            -
                ):
         | 
| 136 | 
            -
                    super().__init__()
         | 
| 137 | 
            -
             | 
| 138 | 
            -
                    self.linear_1 = nn.Linear(in_channels, time_embed_dim)
         | 
| 139 | 
            -
             | 
| 140 | 
            -
                    if cond_proj_dim is not None:
         | 
| 141 | 
            -
                        self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
         | 
| 142 | 
            -
                    else:
         | 
| 143 | 
            -
                        self.cond_proj = None
         | 
| 144 | 
            -
             | 
| 145 | 
            -
                    self.act = get_activation(act_fn)
         | 
| 146 | 
            -
             | 
| 147 | 
            -
                    if out_dim is not None:
         | 
| 148 | 
            -
                        time_embed_dim_out = out_dim
         | 
| 149 | 
            -
                    else:
         | 
| 150 | 
            -
                        time_embed_dim_out = time_embed_dim
         | 
| 151 | 
            -
                    self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                    if post_act_fn is None:
         | 
| 154 | 
            -
                        self.post_act = None
         | 
| 155 | 
            -
                    else:
         | 
| 156 | 
            -
                        self.post_act = get_activation(post_act_fn)
         | 
| 157 | 
            -
             | 
| 158 | 
            -
                def forward(self, sample, condition=None):
         | 
| 159 | 
            -
                    if condition is not None:
         | 
| 160 | 
            -
                        sample = sample + self.cond_proj(condition)
         | 
| 161 | 
            -
                    sample = self.linear_1(sample)
         | 
| 162 | 
            -
             | 
| 163 | 
            -
                    if self.act is not None:
         | 
| 164 | 
            -
                        sample = self.act(sample)
         | 
| 165 | 
            -
             | 
| 166 | 
            -
                    sample = self.linear_2(sample)
         | 
| 167 | 
            -
             | 
| 168 | 
            -
                    if self.post_act is not None:
         | 
| 169 | 
            -
                        sample = self.post_act(sample)
         | 
| 170 | 
            -
                    return sample
         | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
            class Upsample1D(nn.Module):
         | 
| 174 | 
            -
                """A 1D upsampling layer with an optional convolution.
         | 
| 175 | 
            -
             | 
| 176 | 
            -
                Parameters:
         | 
| 177 | 
            -
                    channels (`int`):
         | 
| 178 | 
            -
                        number of channels in the inputs and outputs.
         | 
| 179 | 
            -
                    use_conv (`bool`, default `False`):
         | 
| 180 | 
            -
                        option to use a convolution.
         | 
| 181 | 
            -
                    use_conv_transpose (`bool`, default `False`):
         | 
| 182 | 
            -
                        option to use a convolution transpose.
         | 
| 183 | 
            -
                    out_channels (`int`, optional):
         | 
| 184 | 
            -
                        number of output channels. Defaults to `channels`.
         | 
| 185 | 
            -
                """
         | 
| 186 | 
            -
             | 
| 187 | 
            -
                def __init__(
         | 
| 188 | 
            -
                    self,
         | 
| 189 | 
            -
                    channels,
         | 
| 190 | 
            -
                    use_conv=False,
         | 
| 191 | 
            -
                    use_conv_transpose=True,
         | 
| 192 | 
            -
                    out_channels=None,
         | 
| 193 | 
            -
                    name="conv",
         | 
| 194 | 
            -
                ):
         | 
| 195 | 
            -
                    super().__init__()
         | 
| 196 | 
            -
                    self.channels = channels
         | 
| 197 | 
            -
                    self.out_channels = out_channels or channels
         | 
| 198 | 
            -
                    self.use_conv = use_conv
         | 
| 199 | 
            -
                    self.use_conv_transpose = use_conv_transpose
         | 
| 200 | 
            -
                    self.name = name
         | 
| 201 | 
            -
             | 
| 202 | 
            -
                    self.conv = None
         | 
| 203 | 
            -
                    if use_conv_transpose:
         | 
| 204 | 
            -
                        self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
         | 
| 205 | 
            -
                    elif use_conv:
         | 
| 206 | 
            -
                        self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
         | 
| 207 | 
            -
             | 
| 208 | 
            -
                def forward(self, inputs):
         | 
| 209 | 
            -
                    assert inputs.shape[1] == self.channels
         | 
| 210 | 
            -
                    if self.use_conv_transpose:
         | 
| 211 | 
            -
                        return self.conv(inputs)
         | 
| 212 | 
            -
             | 
| 213 | 
            -
                    outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
         | 
| 214 | 
            -
             | 
| 215 | 
            -
                    if self.use_conv:
         | 
| 216 | 
            -
                        outputs = self.conv(outputs)
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                    return outputs
         | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
            class ConformerWrapper(ConformerBlock):
         | 
| 222 | 
            -
                def __init__(  # pylint: disable=useless-super-delegation
         | 
| 223 | 
            -
                    self,
         | 
| 224 | 
            -
                    *,
         | 
| 225 | 
            -
                    dim,
         | 
| 226 | 
            -
                    dim_head=64,
         | 
| 227 | 
            -
                    heads=8,
         | 
| 228 | 
            -
                    ff_mult=4,
         | 
| 229 | 
            -
                    conv_expansion_factor=2,
         | 
| 230 | 
            -
                    conv_kernel_size=31,
         | 
| 231 | 
            -
                    attn_dropout=0,
         | 
| 232 | 
            -
                    ff_dropout=0,
         | 
| 233 | 
            -
                    conv_dropout=0,
         | 
| 234 | 
            -
                    conv_causal=False,
         | 
| 235 | 
            -
                ):
         | 
| 236 | 
            -
                    super().__init__(
         | 
| 237 | 
            -
                        dim=dim,
         | 
| 238 | 
            -
                        dim_head=dim_head,
         | 
| 239 | 
            -
                        heads=heads,
         | 
| 240 | 
            -
                        ff_mult=ff_mult,
         | 
| 241 | 
            -
                        conv_expansion_factor=conv_expansion_factor,
         | 
| 242 | 
            -
                        conv_kernel_size=conv_kernel_size,
         | 
| 243 | 
            -
                        attn_dropout=attn_dropout,
         | 
| 244 | 
            -
                        ff_dropout=ff_dropout,
         | 
| 245 | 
            -
                        conv_dropout=conv_dropout,
         | 
| 246 | 
            -
                        conv_causal=conv_causal,
         | 
| 247 | 
            -
                    )
         | 
| 248 | 
            -
             | 
| 249 | 
            -
                def forward(
         | 
| 250 | 
            -
                    self,
         | 
| 251 | 
            -
                    hidden_states,
         | 
| 252 | 
            -
                    attention_mask,
         | 
| 253 | 
            -
                    encoder_hidden_states=None,
         | 
| 254 | 
            -
                    encoder_attention_mask=None,
         | 
| 255 | 
            -
                    timestep=None,
         | 
| 256 | 
            -
                ):
         | 
| 257 | 
            -
                    return super().forward(x=hidden_states, mask=attention_mask.bool())
         | 
| 258 | 
            -
             | 
| 259 | 
            -
             | 
| 260 | 
            -
            class Decoder(nn.Module):
         | 
| 261 | 
            -
                def __init__(
         | 
| 262 | 
            -
                    self,
         | 
| 263 | 
            -
                    in_channels,
         | 
| 264 | 
            -
                    out_channels,
         | 
| 265 | 
            -
                    channels=(256, 256),
         | 
| 266 | 
            -
                    dropout=0.05,
         | 
| 267 | 
            -
                    attention_head_dim=64,
         | 
| 268 | 
            -
                    n_blocks=1,
         | 
| 269 | 
            -
                    num_mid_blocks=2,
         | 
| 270 | 
            -
                    num_heads=4,
         | 
| 271 | 
            -
                    act_fn="snake",
         | 
| 272 | 
            -
                    down_block_type="transformer",
         | 
| 273 | 
            -
                    mid_block_type="transformer",
         | 
| 274 | 
            -
                    up_block_type="transformer",
         | 
| 275 | 
            -
                ):
         | 
| 276 | 
            -
                    super().__init__()
         | 
| 277 | 
            -
                    channels = tuple(channels)
         | 
| 278 | 
            -
                    self.in_channels = in_channels
         | 
| 279 | 
            -
                    self.out_channels = out_channels
         | 
| 280 | 
            -
             | 
| 281 | 
            -
                    self.time_embeddings = SinusoidalPosEmb(in_channels)
         | 
| 282 | 
            -
                    time_embed_dim = channels[0] * 4
         | 
| 283 | 
            -
                    self.time_mlp = TimestepEmbedding(
         | 
| 284 | 
            -
                        in_channels=in_channels,
         | 
| 285 | 
            -
                        time_embed_dim=time_embed_dim,
         | 
| 286 | 
            -
                        act_fn="silu",
         | 
| 287 | 
            -
                    )
         | 
| 288 | 
            -
             | 
| 289 | 
            -
                    self.down_blocks = nn.ModuleList([])
         | 
| 290 | 
            -
                    self.mid_blocks = nn.ModuleList([])
         | 
| 291 | 
            -
                    self.up_blocks = nn.ModuleList([])
         | 
| 292 | 
            -
             | 
| 293 | 
            -
                    output_channel = in_channels
         | 
| 294 | 
            -
                    for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
         | 
| 295 | 
            -
                        input_channel = output_channel
         | 
| 296 | 
            -
                        output_channel = channels[i]
         | 
| 297 | 
            -
                        is_last = i == len(channels) - 1
         | 
| 298 | 
            -
                        resnet = ResnetBlock1D(
         | 
| 299 | 
            -
                            dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
         | 
| 300 | 
            -
                        )
         | 
| 301 | 
            -
                        transformer_blocks = nn.ModuleList(
         | 
| 302 | 
            -
                            [
         | 
| 303 | 
            -
                                self.get_block(
         | 
| 304 | 
            -
                                    down_block_type,
         | 
| 305 | 
            -
                                    output_channel,
         | 
| 306 | 
            -
                                    attention_head_dim,
         | 
| 307 | 
            -
                                    num_heads,
         | 
| 308 | 
            -
                                    dropout,
         | 
| 309 | 
            -
                                    act_fn,
         | 
| 310 | 
            -
                                )
         | 
| 311 | 
            -
                                for _ in range(n_blocks)
         | 
| 312 | 
            -
                            ]
         | 
| 313 | 
            -
                        )
         | 
| 314 | 
            -
                        downsample = (
         | 
| 315 | 
            -
                            Downsample1D(output_channel)
         | 
| 316 | 
            -
                            if not is_last
         | 
| 317 | 
            -
                            else nn.Conv1d(output_channel, output_channel, 3, padding=1)
         | 
| 318 | 
            -
                        )
         | 
| 319 | 
            -
             | 
| 320 | 
            -
                        self.down_blocks.append(
         | 
| 321 | 
            -
                            nn.ModuleList([resnet, transformer_blocks, downsample])
         | 
| 322 | 
            -
                        )
         | 
| 323 | 
            -
             | 
| 324 | 
            -
                    for i in range(num_mid_blocks):
         | 
| 325 | 
            -
                        input_channel = channels[-1]
         | 
| 326 | 
            -
                        out_channels = channels[-1]
         | 
| 327 | 
            -
             | 
| 328 | 
            -
                        resnet = ResnetBlock1D(
         | 
| 329 | 
            -
                            dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
         | 
| 330 | 
            -
                        )
         | 
| 331 | 
            -
             | 
| 332 | 
            -
                        transformer_blocks = nn.ModuleList(
         | 
| 333 | 
            -
                            [
         | 
| 334 | 
            -
                                self.get_block(
         | 
| 335 | 
            -
                                    mid_block_type,
         | 
| 336 | 
            -
                                    output_channel,
         | 
| 337 | 
            -
                                    attention_head_dim,
         | 
| 338 | 
            -
                                    num_heads,
         | 
| 339 | 
            -
                                    dropout,
         | 
| 340 | 
            -
                                    act_fn,
         | 
| 341 | 
            -
                                )
         | 
| 342 | 
            -
                                for _ in range(n_blocks)
         | 
| 343 | 
            -
                            ]
         | 
| 344 | 
            -
                        )
         | 
| 345 | 
            -
             | 
| 346 | 
            -
                        self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
         | 
| 347 | 
            -
             | 
| 348 | 
            -
                    channels = channels[::-1] + (channels[0],)
         | 
| 349 | 
            -
                    for i in range(len(channels) - 1):
         | 
| 350 | 
            -
                        input_channel = channels[i]
         | 
| 351 | 
            -
                        output_channel = channels[i + 1]
         | 
| 352 | 
            -
                        is_last = i == len(channels) - 2
         | 
| 353 | 
            -
             | 
| 354 | 
            -
                        resnet = ResnetBlock1D(
         | 
| 355 | 
            -
                            dim=2 * input_channel,
         | 
| 356 | 
            -
                            dim_out=output_channel,
         | 
| 357 | 
            -
                            time_emb_dim=time_embed_dim,
         | 
| 358 | 
            -
                        )
         | 
| 359 | 
            -
                        transformer_blocks = nn.ModuleList(
         | 
| 360 | 
            -
                            [
         | 
| 361 | 
            -
                                self.get_block(
         | 
| 362 | 
            -
                                    up_block_type,
         | 
| 363 | 
            -
                                    output_channel,
         | 
| 364 | 
            -
                                    attention_head_dim,
         | 
| 365 | 
            -
                                    num_heads,
         | 
| 366 | 
            -
                                    dropout,
         | 
| 367 | 
            -
                                    act_fn,
         | 
| 368 | 
            -
                                )
         | 
| 369 | 
            -
                                for _ in range(n_blocks)
         | 
| 370 | 
            -
                            ]
         | 
| 371 | 
            -
                        )
         | 
| 372 | 
            -
                        upsample = (
         | 
| 373 | 
            -
                            Upsample1D(output_channel, use_conv_transpose=True)
         | 
| 374 | 
            -
                            if not is_last
         | 
| 375 | 
            -
                            else nn.Conv1d(output_channel, output_channel, 3, padding=1)
         | 
| 376 | 
            -
                        )
         | 
| 377 | 
            -
             | 
| 378 | 
            -
                        self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
         | 
| 379 | 
            -
             | 
| 380 | 
            -
                    self.final_block = Block1D(channels[-1], channels[-1])
         | 
| 381 | 
            -
                    self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         | 
| 382 | 
            -
             | 
| 383 | 
            -
                    self.initialize_weights()
         | 
| 384 | 
            -
                    # nn.init.normal_(self.final_proj.weight)
         | 
| 385 | 
            -
             | 
| 386 | 
            -
                @staticmethod
         | 
| 387 | 
            -
                def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
         | 
| 388 | 
            -
                    if block_type == "conformer":
         | 
| 389 | 
            -
                        block = ConformerWrapper(
         | 
| 390 | 
            -
                            dim=dim,
         | 
| 391 | 
            -
                            dim_head=attention_head_dim,
         | 
| 392 | 
            -
                            heads=num_heads,
         | 
| 393 | 
            -
                            ff_mult=1,
         | 
| 394 | 
            -
                            conv_expansion_factor=2,
         | 
| 395 | 
            -
                            ff_dropout=dropout,
         | 
| 396 | 
            -
                            attn_dropout=dropout,
         | 
| 397 | 
            -
                            conv_dropout=dropout,
         | 
| 398 | 
            -
                            conv_kernel_size=31,
         | 
| 399 | 
            -
                        )
         | 
| 400 | 
            -
                    elif block_type == "transformer":
         | 
| 401 | 
            -
                        block = BasicTransformerBlock(
         | 
| 402 | 
            -
                            dim=dim,
         | 
| 403 | 
            -
                            num_attention_heads=num_heads,
         | 
| 404 | 
            -
                            attention_head_dim=attention_head_dim,
         | 
| 405 | 
            -
                            dropout=dropout,
         | 
| 406 | 
            -
                            activation_fn=act_fn,
         | 
| 407 | 
            -
                        )
         | 
| 408 | 
            -
                    else:
         | 
| 409 | 
            -
                        raise ValueError(f"Unknown block type {block_type}")
         | 
| 410 | 
            -
             | 
| 411 | 
            -
                    return block
         | 
| 412 | 
            -
             | 
| 413 | 
            -
                def initialize_weights(self):
         | 
| 414 | 
            -
                    for m in self.modules():
         | 
| 415 | 
            -
                        if isinstance(m, nn.Conv1d):
         | 
| 416 | 
            -
                            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
         | 
| 417 | 
            -
             | 
| 418 | 
            -
                            if m.bias is not None:
         | 
| 419 | 
            -
                                nn.init.constant_(m.bias, 0)
         | 
| 420 | 
            -
             | 
| 421 | 
            -
                        elif isinstance(m, nn.GroupNorm):
         | 
| 422 | 
            -
                            nn.init.constant_(m.weight, 1)
         | 
| 423 | 
            -
                            nn.init.constant_(m.bias, 0)
         | 
| 424 | 
            -
             | 
| 425 | 
            -
                        elif isinstance(m, nn.Linear):
         | 
| 426 | 
            -
                            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
         | 
| 427 | 
            -
             | 
| 428 | 
            -
                            if m.bias is not None:
         | 
| 429 | 
            -
                                nn.init.constant_(m.bias, 0)
         | 
| 430 | 
            -
             | 
| 431 | 
            -
                def forward(self, x, mask, mu, t, spks=None, cond=None):
         | 
| 432 | 
            -
                    """Forward pass of the UNet1DConditional model.
         | 
| 433 | 
            -
             | 
| 434 | 
            -
                    Args:
         | 
| 435 | 
            -
                        x (torch.Tensor): shape (batch_size, in_channels, time)
         | 
| 436 | 
            -
                        mask (_type_): shape (batch_size, 1, time)
         | 
| 437 | 
            -
                        t (_type_): shape (batch_size)
         | 
| 438 | 
            -
                        spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
         | 
| 439 | 
            -
                        cond (_type_, optional): placeholder for future use. Defaults to None.
         | 
| 440 | 
            -
             | 
| 441 | 
            -
                    Raises:
         | 
| 442 | 
            -
                        ValueError: _description_
         | 
| 443 | 
            -
                        ValueError: _description_
         | 
| 444 | 
            -
             | 
| 445 | 
            -
                    Returns:
         | 
| 446 | 
            -
                        _type_: _description_
         | 
| 447 | 
            -
                    """
         | 
| 448 | 
            -
             | 
| 449 | 
            -
                    t = self.time_embeddings(t)
         | 
| 450 | 
            -
                    t = self.time_mlp(t)
         | 
| 451 | 
            -
             | 
| 452 | 
            -
                    x = pack([x, mu], "b * t")[0]
         | 
| 453 | 
            -
             | 
| 454 | 
            -
                    if spks is not None:
         | 
| 455 | 
            -
                        spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
         | 
| 456 | 
            -
                        x = pack([x, spks], "b * t")[0]
         | 
| 457 | 
            -
             | 
| 458 | 
            -
                    hiddens = []
         | 
| 459 | 
            -
                    masks = [mask]
         | 
| 460 | 
            -
                    for resnet, transformer_blocks, downsample in self.down_blocks:
         | 
| 461 | 
            -
                        mask_down = masks[-1]
         | 
| 462 | 
            -
                        x = resnet(x, mask_down, t)
         | 
| 463 | 
            -
                        x = rearrange(x, "b c t -> b t c")
         | 
| 464 | 
            -
                        mask_down = rearrange(mask_down, "b 1 t -> b t")
         | 
| 465 | 
            -
                        for transformer_block in transformer_blocks:
         | 
| 466 | 
            -
                            x = transformer_block(
         | 
| 467 | 
            -
                                hidden_states=x,
         | 
| 468 | 
            -
                                attention_mask=mask_down,
         | 
| 469 | 
            -
                                timestep=t,
         | 
| 470 | 
            -
                            )
         | 
| 471 | 
            -
                        x = rearrange(x, "b t c -> b c t")
         | 
| 472 | 
            -
                        mask_down = rearrange(mask_down, "b t -> b 1 t")
         | 
| 473 | 
            -
                        hiddens.append(x)  # Save hidden states for skip connections
         | 
| 474 | 
            -
                        x = downsample(x * mask_down)
         | 
| 475 | 
            -
                        masks.append(mask_down[:, :, ::2])
         | 
| 476 | 
            -
             | 
| 477 | 
            -
                    masks = masks[:-1]
         | 
| 478 | 
            -
                    mask_mid = masks[-1]
         | 
| 479 | 
            -
             | 
| 480 | 
            -
                    for resnet, transformer_blocks in self.mid_blocks:
         | 
| 481 | 
            -
                        x = resnet(x, mask_mid, t)
         | 
| 482 | 
            -
                        x = rearrange(x, "b c t -> b t c")
         | 
| 483 | 
            -
                        mask_mid = rearrange(mask_mid, "b 1 t -> b t")
         | 
| 484 | 
            -
                        for transformer_block in transformer_blocks:
         | 
| 485 | 
            -
                            x = transformer_block(
         | 
| 486 | 
            -
                                hidden_states=x,
         | 
| 487 | 
            -
                                attention_mask=mask_mid,
         | 
| 488 | 
            -
                                timestep=t,
         | 
| 489 | 
            -
                            )
         | 
| 490 | 
            -
                        x = rearrange(x, "b t c -> b c t")
         | 
| 491 | 
            -
                        mask_mid = rearrange(mask_mid, "b t -> b 1 t")
         | 
| 492 | 
            -
             | 
| 493 | 
            -
                    for resnet, transformer_blocks, upsample in self.up_blocks:
         | 
| 494 | 
            -
                        mask_up = masks.pop()
         | 
| 495 | 
            -
                        x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
         | 
| 496 | 
            -
                        x = rearrange(x, "b c t -> b t c")
         | 
| 497 | 
            -
                        mask_up = rearrange(mask_up, "b 1 t -> b t")
         | 
| 498 | 
            -
                        for transformer_block in transformer_blocks:
         | 
| 499 | 
            -
                            x = transformer_block(
         | 
| 500 | 
            -
                                hidden_states=x,
         | 
| 501 | 
            -
                                attention_mask=mask_up,
         | 
| 502 | 
            -
                                timestep=t,
         | 
| 503 | 
            -
                            )
         | 
| 504 | 
            -
                        x = rearrange(x, "b t c -> b c t")
         | 
| 505 | 
            -
                        mask_up = rearrange(mask_up, "b t -> b 1 t")
         | 
| 506 | 
            -
                        x = upsample(x * mask_up)
         | 
| 507 | 
            -
             | 
| 508 | 
            -
                    x = self.final_block(x, mask_up)
         | 
| 509 | 
            -
                    output = self.final_proj(x * mask_up)
         | 
| 510 | 
            -
             | 
| 511 | 
            -
                    return output * mask
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/matcha/flow_matching.py
    DELETED
    
    | @@ -1,141 +0,0 @@ | |
| 1 | 
            -
            from abc import ABC
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            import torch
         | 
| 4 | 
            -
            import torch.nn.functional as F
         | 
| 5 | 
            -
             | 
| 6 | 
            -
            from cosyvoice.matcha.decoder import Decoder
         | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
            class BASECFM(torch.nn.Module, ABC):
         | 
| 10 | 
            -
                def __init__(
         | 
| 11 | 
            -
                    self,
         | 
| 12 | 
            -
                    n_feats,
         | 
| 13 | 
            -
                    cfm_params,
         | 
| 14 | 
            -
                    n_spks=1,
         | 
| 15 | 
            -
                    spk_emb_dim=128,
         | 
| 16 | 
            -
                ):
         | 
| 17 | 
            -
                    super().__init__()
         | 
| 18 | 
            -
                    self.n_feats = n_feats
         | 
| 19 | 
            -
                    self.n_spks = n_spks
         | 
| 20 | 
            -
                    self.spk_emb_dim = spk_emb_dim
         | 
| 21 | 
            -
                    self.solver = cfm_params.solver
         | 
| 22 | 
            -
                    if hasattr(cfm_params, "sigma_min"):
         | 
| 23 | 
            -
                        self.sigma_min = cfm_params.sigma_min
         | 
| 24 | 
            -
                    else:
         | 
| 25 | 
            -
                        self.sigma_min = 1e-4
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                    self.estimator = None
         | 
| 28 | 
            -
             | 
| 29 | 
            -
                @torch.inference_mode()
         | 
| 30 | 
            -
                def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
         | 
| 31 | 
            -
                    """Forward diffusion
         | 
| 32 | 
            -
             | 
| 33 | 
            -
                    Args:
         | 
| 34 | 
            -
                        mu (torch.Tensor): output of encoder
         | 
| 35 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 36 | 
            -
                        mask (torch.Tensor): output_mask
         | 
| 37 | 
            -
                            shape: (batch_size, 1, mel_timesteps)
         | 
| 38 | 
            -
                        n_timesteps (int): number of diffusion steps
         | 
| 39 | 
            -
                        temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
         | 
| 40 | 
            -
                        spks (torch.Tensor, optional): speaker ids. Defaults to None.
         | 
| 41 | 
            -
                            shape: (batch_size, spk_emb_dim)
         | 
| 42 | 
            -
                        cond: Not used but kept for future purposes
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                    Returns:
         | 
| 45 | 
            -
                        sample: generated mel-spectrogram
         | 
| 46 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 47 | 
            -
                    """
         | 
| 48 | 
            -
                    z = torch.randn_like(mu) * temperature
         | 
| 49 | 
            -
                    t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
         | 
| 50 | 
            -
                    return self.solve_euler(
         | 
| 51 | 
            -
                        z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
         | 
| 52 | 
            -
                    )
         | 
| 53 | 
            -
             | 
| 54 | 
            -
                def solve_euler(self, x, t_span, mu, mask, spks, cond):
         | 
| 55 | 
            -
                    """
         | 
| 56 | 
            -
                    Fixed euler solver for ODEs.
         | 
| 57 | 
            -
                    Args:
         | 
| 58 | 
            -
                        x (torch.Tensor): random noise
         | 
| 59 | 
            -
                        t_span (torch.Tensor): n_timesteps interpolated
         | 
| 60 | 
            -
                            shape: (n_timesteps + 1,)
         | 
| 61 | 
            -
                        mu (torch.Tensor): output of encoder
         | 
| 62 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 63 | 
            -
                        mask (torch.Tensor): output_mask
         | 
| 64 | 
            -
                            shape: (batch_size, 1, mel_timesteps)
         | 
| 65 | 
            -
                        spks (torch.Tensor, optional): speaker ids. Defaults to None.
         | 
| 66 | 
            -
                            shape: (batch_size, spk_emb_dim)
         | 
| 67 | 
            -
                        cond: Not used but kept for future purposes
         | 
| 68 | 
            -
                    """
         | 
| 69 | 
            -
                    t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                    # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         | 
| 72 | 
            -
                    # Or in future might add like a return_all_steps flag
         | 
| 73 | 
            -
                    sol = []
         | 
| 74 | 
            -
             | 
| 75 | 
            -
                    for step in range(1, len(t_span)):
         | 
| 76 | 
            -
                        dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                        x = x + dt * dphi_dt
         | 
| 79 | 
            -
                        t = t + dt
         | 
| 80 | 
            -
                        sol.append(x)
         | 
| 81 | 
            -
                        if step < len(t_span) - 1:
         | 
| 82 | 
            -
                            dt = t_span[step + 1] - t
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                    return sol[-1]
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         | 
| 87 | 
            -
                    """Computes diffusion loss
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                    Args:
         | 
| 90 | 
            -
                        x1 (torch.Tensor): Target
         | 
| 91 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 92 | 
            -
                        mask (torch.Tensor): target mask
         | 
| 93 | 
            -
                            shape: (batch_size, 1, mel_timesteps)
         | 
| 94 | 
            -
                        mu (torch.Tensor): output of encoder
         | 
| 95 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 96 | 
            -
                        spks (torch.Tensor, optional): speaker embedding. Defaults to None.
         | 
| 97 | 
            -
                            shape: (batch_size, spk_emb_dim)
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                    Returns:
         | 
| 100 | 
            -
                        loss: conditional flow matching loss
         | 
| 101 | 
            -
                        y: conditional flow
         | 
| 102 | 
            -
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 103 | 
            -
                    """
         | 
| 104 | 
            -
                    b, _, t = mu.shape
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                    # random timestep
         | 
| 107 | 
            -
                    t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
         | 
| 108 | 
            -
                    # sample noise p(x_0)
         | 
| 109 | 
            -
                    z = torch.randn_like(x1)
         | 
| 110 | 
            -
             | 
| 111 | 
            -
                    y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         | 
| 112 | 
            -
                    u = x1 - (1 - self.sigma_min) * z
         | 
| 113 | 
            -
             | 
| 114 | 
            -
                    loss = F.mse_loss(
         | 
| 115 | 
            -
                        self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
         | 
| 116 | 
            -
                    ) / (torch.sum(mask) * u.shape[1])
         | 
| 117 | 
            -
                    return loss, y
         | 
| 118 | 
            -
             | 
| 119 | 
            -
             | 
| 120 | 
            -
            class CFM(BASECFM):
         | 
| 121 | 
            -
                def __init__(
         | 
| 122 | 
            -
                    self,
         | 
| 123 | 
            -
                    in_channels,
         | 
| 124 | 
            -
                    out_channel,
         | 
| 125 | 
            -
                    cfm_params,
         | 
| 126 | 
            -
                    decoder_params,
         | 
| 127 | 
            -
                    n_spks=1,
         | 
| 128 | 
            -
                    spk_emb_dim=64,
         | 
| 129 | 
            -
                ):
         | 
| 130 | 
            -
                    super().__init__(
         | 
| 131 | 
            -
                        n_feats=in_channels,
         | 
| 132 | 
            -
                        cfm_params=cfm_params,
         | 
| 133 | 
            -
                        n_spks=n_spks,
         | 
| 134 | 
            -
                        spk_emb_dim=spk_emb_dim,
         | 
| 135 | 
            -
                    )
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                    in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
         | 
| 138 | 
            -
                    # Just change the architecture of the estimator here
         | 
| 139 | 
            -
                    self.estimator = Decoder(
         | 
| 140 | 
            -
                        in_channels=in_channels, out_channels=out_channel, **decoder_params
         | 
| 141 | 
            -
                    )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/matcha/transformer.py
    DELETED
    
    | @@ -1,443 +0,0 @@ | |
| 1 | 
            -
            from typing import Any, Dict, Optional
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            import torch
         | 
| 4 | 
            -
            import torch.nn as nn
         | 
| 5 | 
            -
            from diffusers.models.attention import (
         | 
| 6 | 
            -
                GEGLU,
         | 
| 7 | 
            -
                GELU,
         | 
| 8 | 
            -
                AdaLayerNorm,
         | 
| 9 | 
            -
                AdaLayerNormZero,
         | 
| 10 | 
            -
                ApproximateGELU,
         | 
| 11 | 
            -
            )
         | 
| 12 | 
            -
            from diffusers.models.attention_processor import Attention
         | 
| 13 | 
            -
            from diffusers.models.lora import LoRACompatibleLinear
         | 
| 14 | 
            -
            from diffusers.utils.torch_utils import maybe_allow_in_graph
         | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
            class SnakeBeta(nn.Module):
         | 
| 18 | 
            -
                """
         | 
| 19 | 
            -
                A modified Snake function which uses separate parameters for the magnitude of the periodic components
         | 
| 20 | 
            -
                Shape:
         | 
| 21 | 
            -
                    - Input: (B, C, T)
         | 
| 22 | 
            -
                    - Output: (B, C, T), same shape as the input
         | 
| 23 | 
            -
                Parameters:
         | 
| 24 | 
            -
                    - alpha - trainable parameter that controls frequency
         | 
| 25 | 
            -
                    - beta - trainable parameter that controls magnitude
         | 
| 26 | 
            -
                References:
         | 
| 27 | 
            -
                    - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
         | 
| 28 | 
            -
                    https://arxiv.org/abs/2006.08195
         | 
| 29 | 
            -
                Examples:
         | 
| 30 | 
            -
                    >>> a1 = snakebeta(256)
         | 
| 31 | 
            -
                    >>> x = torch.randn(256)
         | 
| 32 | 
            -
                    >>> x = a1(x)
         | 
| 33 | 
            -
                """
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                def __init__(
         | 
| 36 | 
            -
                    self,
         | 
| 37 | 
            -
                    in_features,
         | 
| 38 | 
            -
                    out_features,
         | 
| 39 | 
            -
                    alpha=1.0,
         | 
| 40 | 
            -
                    alpha_trainable=True,
         | 
| 41 | 
            -
                    alpha_logscale=True,
         | 
| 42 | 
            -
                ):
         | 
| 43 | 
            -
                    """
         | 
| 44 | 
            -
                    Initialization.
         | 
| 45 | 
            -
                    INPUT:
         | 
| 46 | 
            -
                        - in_features: shape of the input
         | 
| 47 | 
            -
                        - alpha - trainable parameter that controls frequency
         | 
| 48 | 
            -
                        - beta - trainable parameter that controls magnitude
         | 
| 49 | 
            -
                        alpha is initialized to 1 by default, higher values = higher-frequency.
         | 
| 50 | 
            -
                        beta is initialized to 1 by default, higher values = higher-magnitude.
         | 
| 51 | 
            -
                        alpha will be trained along with the rest of your model.
         | 
| 52 | 
            -
                    """
         | 
| 53 | 
            -
                    super().__init__()
         | 
| 54 | 
            -
                    self.in_features = (
         | 
| 55 | 
            -
                        out_features if isinstance(out_features, list) else [out_features]
         | 
| 56 | 
            -
                    )
         | 
| 57 | 
            -
                    self.proj = LoRACompatibleLinear(in_features, out_features)
         | 
| 58 | 
            -
             | 
| 59 | 
            -
                    # initialize alpha
         | 
| 60 | 
            -
                    self.alpha_logscale = alpha_logscale
         | 
| 61 | 
            -
                    if self.alpha_logscale:  # log scale alphas initialized to zeros
         | 
| 62 | 
            -
                        self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
         | 
| 63 | 
            -
                        self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
         | 
| 64 | 
            -
                    else:  # linear scale alphas initialized to ones
         | 
| 65 | 
            -
                        self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
         | 
| 66 | 
            -
                        self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
         | 
| 67 | 
            -
             | 
| 68 | 
            -
                    self.alpha.requires_grad = alpha_trainable
         | 
| 69 | 
            -
                    self.beta.requires_grad = alpha_trainable
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                    self.no_div_by_zero = 0.000000001
         | 
| 72 | 
            -
             | 
| 73 | 
            -
                def forward(self, x):
         | 
| 74 | 
            -
                    """
         | 
| 75 | 
            -
                    Forward pass of the function.
         | 
| 76 | 
            -
                    Applies the function to the input elementwise.
         | 
| 77 | 
            -
                    SnakeBeta ∶= x + 1/b * sin^2 (xa)
         | 
| 78 | 
            -
                    """
         | 
| 79 | 
            -
                    x = self.proj(x)
         | 
| 80 | 
            -
                    if self.alpha_logscale:
         | 
| 81 | 
            -
                        alpha = torch.exp(self.alpha)
         | 
| 82 | 
            -
                        beta = torch.exp(self.beta)
         | 
| 83 | 
            -
                    else:
         | 
| 84 | 
            -
                        alpha = self.alpha
         | 
| 85 | 
            -
                        beta = self.beta
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                    x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
         | 
| 88 | 
            -
                        torch.sin(x * alpha), 2
         | 
| 89 | 
            -
                    )
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                    return x
         | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
            -
            class FeedForward(nn.Module):
         | 
| 95 | 
            -
                r"""
         | 
| 96 | 
            -
                A feed-forward layer.
         | 
| 97 | 
            -
             | 
| 98 | 
            -
                Parameters:
         | 
| 99 | 
            -
                    dim (`int`): The number of channels in the input.
         | 
| 100 | 
            -
                    dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
         | 
| 101 | 
            -
                    mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
         | 
| 102 | 
            -
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 103 | 
            -
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 104 | 
            -
                    final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
         | 
| 105 | 
            -
                """
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                def __init__(
         | 
| 108 | 
            -
                    self,
         | 
| 109 | 
            -
                    dim: int,
         | 
| 110 | 
            -
                    dim_out: Optional[int] = None,
         | 
| 111 | 
            -
                    mult: int = 4,
         | 
| 112 | 
            -
                    dropout: float = 0.0,
         | 
| 113 | 
            -
                    activation_fn: str = "geglu",
         | 
| 114 | 
            -
                    final_dropout: bool = False,
         | 
| 115 | 
            -
                ):
         | 
| 116 | 
            -
                    super().__init__()
         | 
| 117 | 
            -
                    inner_dim = int(dim * mult)
         | 
| 118 | 
            -
                    dim_out = dim_out if dim_out is not None else dim
         | 
| 119 | 
            -
             | 
| 120 | 
            -
                    if activation_fn == "gelu":
         | 
| 121 | 
            -
                        act_fn = GELU(dim, inner_dim)
         | 
| 122 | 
            -
                    if activation_fn == "gelu-approximate":
         | 
| 123 | 
            -
                        act_fn = GELU(dim, inner_dim, approximate="tanh")
         | 
| 124 | 
            -
                    elif activation_fn == "geglu":
         | 
| 125 | 
            -
                        act_fn = GEGLU(dim, inner_dim)
         | 
| 126 | 
            -
                    elif activation_fn == "geglu-approximate":
         | 
| 127 | 
            -
                        act_fn = ApproximateGELU(dim, inner_dim)
         | 
| 128 | 
            -
                    elif activation_fn == "snakebeta":
         | 
| 129 | 
            -
                        act_fn = SnakeBeta(dim, inner_dim)
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                    self.net = nn.ModuleList([])
         | 
| 132 | 
            -
                    # project in
         | 
| 133 | 
            -
                    self.net.append(act_fn)
         | 
| 134 | 
            -
                    # project dropout
         | 
| 135 | 
            -
                    self.net.append(nn.Dropout(dropout))
         | 
| 136 | 
            -
                    # project out
         | 
| 137 | 
            -
                    self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
         | 
| 138 | 
            -
                    # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         | 
| 139 | 
            -
                    if final_dropout:
         | 
| 140 | 
            -
                        self.net.append(nn.Dropout(dropout))
         | 
| 141 | 
            -
             | 
| 142 | 
            -
                def forward(self, hidden_states):
         | 
| 143 | 
            -
                    for module in self.net:
         | 
| 144 | 
            -
                        hidden_states = module(hidden_states)
         | 
| 145 | 
            -
                    return hidden_states
         | 
| 146 | 
            -
             | 
| 147 | 
            -
             | 
| 148 | 
            -
            @maybe_allow_in_graph
         | 
| 149 | 
            -
            class BasicTransformerBlock(nn.Module):
         | 
| 150 | 
            -
                r"""
         | 
| 151 | 
            -
                A basic Transformer block.
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                Parameters:
         | 
| 154 | 
            -
                    dim (`int`): The number of channels in the input and output.
         | 
| 155 | 
            -
                    num_attention_heads (`int`): The number of heads to use for multi-head attention.
         | 
| 156 | 
            -
                    attention_head_dim (`int`): The number of channels in each head.
         | 
| 157 | 
            -
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 158 | 
            -
                    cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
         | 
| 159 | 
            -
                    only_cross_attention (`bool`, *optional*):
         | 
| 160 | 
            -
                        Whether to use only cross-attention layers. In this case two cross attention layers are used.
         | 
| 161 | 
            -
                    double_self_attention (`bool`, *optional*):
         | 
| 162 | 
            -
                        Whether to use two self-attention layers. In this case no cross attention layers are used.
         | 
| 163 | 
            -
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 164 | 
            -
                    num_embeds_ada_norm (:
         | 
| 165 | 
            -
                        obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
         | 
| 166 | 
            -
                    attention_bias (:
         | 
| 167 | 
            -
                        obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
         | 
| 168 | 
            -
                """
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                def __init__(
         | 
| 171 | 
            -
                    self,
         | 
| 172 | 
            -
                    dim: int,
         | 
| 173 | 
            -
                    num_attention_heads: int,
         | 
| 174 | 
            -
                    attention_head_dim: int,
         | 
| 175 | 
            -
                    dropout=0.0,
         | 
| 176 | 
            -
                    cross_attention_dim: Optional[int] = None,
         | 
| 177 | 
            -
                    activation_fn: str = "geglu",
         | 
| 178 | 
            -
                    num_embeds_ada_norm: Optional[int] = None,
         | 
| 179 | 
            -
                    attention_bias: bool = False,
         | 
| 180 | 
            -
                    only_cross_attention: bool = False,
         | 
| 181 | 
            -
                    double_self_attention: bool = False,
         | 
| 182 | 
            -
                    upcast_attention: bool = False,
         | 
| 183 | 
            -
                    norm_elementwise_affine: bool = True,
         | 
| 184 | 
            -
                    norm_type: str = "layer_norm",
         | 
| 185 | 
            -
                    final_dropout: bool = False,
         | 
| 186 | 
            -
                ):
         | 
| 187 | 
            -
                    super().__init__()
         | 
| 188 | 
            -
                    self.only_cross_attention = only_cross_attention
         | 
| 189 | 
            -
             | 
| 190 | 
            -
                    self.use_ada_layer_norm_zero = (
         | 
| 191 | 
            -
                        num_embeds_ada_norm is not None
         | 
| 192 | 
            -
                    ) and norm_type == "ada_norm_zero"
         | 
| 193 | 
            -
                    self.use_ada_layer_norm = (
         | 
| 194 | 
            -
                        num_embeds_ada_norm is not None
         | 
| 195 | 
            -
                    ) and norm_type == "ada_norm"
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                    if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
         | 
| 198 | 
            -
                        raise ValueError(
         | 
| 199 | 
            -
                            f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
         | 
| 200 | 
            -
                            f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
         | 
| 201 | 
            -
                        )
         | 
| 202 | 
            -
             | 
| 203 | 
            -
                    # Define 3 blocks. Each block has its own normalization layer.
         | 
| 204 | 
            -
                    # 1. Self-Attn
         | 
| 205 | 
            -
                    if self.use_ada_layer_norm:
         | 
| 206 | 
            -
                        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
         | 
| 207 | 
            -
                    elif self.use_ada_layer_norm_zero:
         | 
| 208 | 
            -
                        self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         | 
| 209 | 
            -
                    else:
         | 
| 210 | 
            -
                        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         | 
| 211 | 
            -
                    self.attn1 = Attention(
         | 
| 212 | 
            -
                        query_dim=dim,
         | 
| 213 | 
            -
                        heads=num_attention_heads,
         | 
| 214 | 
            -
                        dim_head=attention_head_dim,
         | 
| 215 | 
            -
                        dropout=dropout,
         | 
| 216 | 
            -
                        bias=attention_bias,
         | 
| 217 | 
            -
                        cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         | 
| 218 | 
            -
                        upcast_attention=upcast_attention,
         | 
| 219 | 
            -
                    )
         | 
| 220 | 
            -
             | 
| 221 | 
            -
                    # 2. Cross-Attn
         | 
| 222 | 
            -
                    if cross_attention_dim is not None or double_self_attention:
         | 
| 223 | 
            -
                        # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
         | 
| 224 | 
            -
                        # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
         | 
| 225 | 
            -
                        # the second cross attention block.
         | 
| 226 | 
            -
                        self.norm2 = (
         | 
| 227 | 
            -
                            AdaLayerNorm(dim, num_embeds_ada_norm)
         | 
| 228 | 
            -
                            if self.use_ada_layer_norm
         | 
| 229 | 
            -
                            else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         | 
| 230 | 
            -
                        )
         | 
| 231 | 
            -
                        self.attn2 = Attention(
         | 
| 232 | 
            -
                            query_dim=dim,
         | 
| 233 | 
            -
                            cross_attention_dim=(
         | 
| 234 | 
            -
                                cross_attention_dim if not double_self_attention else None
         | 
| 235 | 
            -
                            ),
         | 
| 236 | 
            -
                            heads=num_attention_heads,
         | 
| 237 | 
            -
                            dim_head=attention_head_dim,
         | 
| 238 | 
            -
                            dropout=dropout,
         | 
| 239 | 
            -
                            bias=attention_bias,
         | 
| 240 | 
            -
                            upcast_attention=upcast_attention,
         | 
| 241 | 
            -
                            # scale_qk=False, # uncomment this to not to use flash attention
         | 
| 242 | 
            -
                        )  # is self-attn if encoder_hidden_states is none
         | 
| 243 | 
            -
                    else:
         | 
| 244 | 
            -
                        self.norm2 = None
         | 
| 245 | 
            -
                        self.attn2 = None
         | 
| 246 | 
            -
             | 
| 247 | 
            -
                    # 3. Feed-forward
         | 
| 248 | 
            -
                    self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         | 
| 249 | 
            -
                    self.ff = FeedForward(
         | 
| 250 | 
            -
                        dim,
         | 
| 251 | 
            -
                        dropout=dropout,
         | 
| 252 | 
            -
                        activation_fn=activation_fn,
         | 
| 253 | 
            -
                        final_dropout=final_dropout,
         | 
| 254 | 
            -
                    )
         | 
| 255 | 
            -
             | 
| 256 | 
            -
                    # let chunk size default to None
         | 
| 257 | 
            -
                    self._chunk_size = None
         | 
| 258 | 
            -
                    self._chunk_dim = 0
         | 
| 259 | 
            -
             | 
| 260 | 
            -
                def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
         | 
| 261 | 
            -
                    # Sets chunk feed-forward
         | 
| 262 | 
            -
                    self._chunk_size = chunk_size
         | 
| 263 | 
            -
                    self._chunk_dim = dim
         | 
| 264 | 
            -
             | 
| 265 | 
            -
                def forward_native(
         | 
| 266 | 
            -
                    self,
         | 
| 267 | 
            -
                    hidden_states: torch.FloatTensor,
         | 
| 268 | 
            -
                    attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 269 | 
            -
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
         | 
| 270 | 
            -
                    encoder_attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 271 | 
            -
                    timestep: Optional[torch.LongTensor] = None,
         | 
| 272 | 
            -
                    cross_attention_kwargs: Dict[str, Any] = None,
         | 
| 273 | 
            -
                    class_labels: Optional[torch.LongTensor] = None,
         | 
| 274 | 
            -
                ):
         | 
| 275 | 
            -
                    # Notice that normalization is always applied before the real computation in the following blocks.
         | 
| 276 | 
            -
                    # 1. Self-Attention
         | 
| 277 | 
            -
                    if self.use_ada_layer_norm:
         | 
| 278 | 
            -
                        norm_hidden_states = self.norm1(hidden_states, timestep)
         | 
| 279 | 
            -
                    elif self.use_ada_layer_norm_zero:
         | 
| 280 | 
            -
                        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
         | 
| 281 | 
            -
                            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
         | 
| 282 | 
            -
                        )
         | 
| 283 | 
            -
                    else:
         | 
| 284 | 
            -
                        norm_hidden_states = self.norm1(hidden_states)
         | 
| 285 | 
            -
             | 
| 286 | 
            -
                    cross_attention_kwargs = (
         | 
| 287 | 
            -
                        cross_attention_kwargs if cross_attention_kwargs is not None else {}
         | 
| 288 | 
            -
                    )
         | 
| 289 | 
            -
             | 
| 290 | 
            -
                    attn_output = self.attn1(
         | 
| 291 | 
            -
                        norm_hidden_states,
         | 
| 292 | 
            -
                        encoder_hidden_states=(
         | 
| 293 | 
            -
                            encoder_hidden_states if self.only_cross_attention else None
         | 
| 294 | 
            -
                        ),
         | 
| 295 | 
            -
                        attention_mask=(
         | 
| 296 | 
            -
                            encoder_attention_mask if self.only_cross_attention else attention_mask
         | 
| 297 | 
            -
                        ),
         | 
| 298 | 
            -
                        **cross_attention_kwargs,
         | 
| 299 | 
            -
                    )
         | 
| 300 | 
            -
                    if self.use_ada_layer_norm_zero:
         | 
| 301 | 
            -
                        attn_output = gate_msa.unsqueeze(1) * attn_output
         | 
| 302 | 
            -
                    hidden_states = attn_output + hidden_states
         | 
| 303 | 
            -
             | 
| 304 | 
            -
                    # 2. Cross-Attention
         | 
| 305 | 
            -
                    if self.attn2 is not None:
         | 
| 306 | 
            -
                        norm_hidden_states = (
         | 
| 307 | 
            -
                            self.norm2(hidden_states, timestep)
         | 
| 308 | 
            -
                            if self.use_ada_layer_norm
         | 
| 309 | 
            -
                            else self.norm2(hidden_states)
         | 
| 310 | 
            -
                        )
         | 
| 311 | 
            -
             | 
| 312 | 
            -
                        attn_output = self.attn2(
         | 
| 313 | 
            -
                            norm_hidden_states,
         | 
| 314 | 
            -
                            encoder_hidden_states=encoder_hidden_states,
         | 
| 315 | 
            -
                            attention_mask=encoder_attention_mask,
         | 
| 316 | 
            -
                            **cross_attention_kwargs,
         | 
| 317 | 
            -
                        )
         | 
| 318 | 
            -
                        hidden_states = attn_output + hidden_states
         | 
| 319 | 
            -
             | 
| 320 | 
            -
                    # 3. Feed-forward
         | 
| 321 | 
            -
                    norm_hidden_states = self.norm3(hidden_states)
         | 
| 322 | 
            -
             | 
| 323 | 
            -
                    if self.use_ada_layer_norm_zero:
         | 
| 324 | 
            -
                        norm_hidden_states = (
         | 
| 325 | 
            -
                            norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         | 
| 326 | 
            -
                        )
         | 
| 327 | 
            -
             | 
| 328 | 
            -
                    if self._chunk_size is not None:
         | 
| 329 | 
            -
                        # "feed_forward_chunk_size" can be used to save memory
         | 
| 330 | 
            -
                        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
         | 
| 331 | 
            -
                            raise ValueError(
         | 
| 332 | 
            -
                                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
         | 
| 333 | 
            -
                            )
         | 
| 334 | 
            -
             | 
| 335 | 
            -
                        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
         | 
| 336 | 
            -
                        ff_output = torch.cat(
         | 
| 337 | 
            -
                            [
         | 
| 338 | 
            -
                                self.ff(hid_slice)
         | 
| 339 | 
            -
                                for hid_slice in norm_hidden_states.chunk(
         | 
| 340 | 
            -
                                    num_chunks, dim=self._chunk_dim
         | 
| 341 | 
            -
                                )
         | 
| 342 | 
            -
                            ],
         | 
| 343 | 
            -
                            dim=self._chunk_dim,
         | 
| 344 | 
            -
                        )
         | 
| 345 | 
            -
                    else:
         | 
| 346 | 
            -
                        ff_output = self.ff(norm_hidden_states)
         | 
| 347 | 
            -
             | 
| 348 | 
            -
                    if self.use_ada_layer_norm_zero:
         | 
| 349 | 
            -
                        ff_output = gate_mlp.unsqueeze(1) * ff_output
         | 
| 350 | 
            -
             | 
| 351 | 
            -
                    hidden_states = ff_output + hidden_states
         | 
| 352 | 
            -
             | 
| 353 | 
            -
                    return hidden_states
         | 
| 354 | 
            -
             | 
| 355 | 
            -
                def forward(
         | 
| 356 | 
            -
                    self,
         | 
| 357 | 
            -
                    hidden_states: torch.FloatTensor,
         | 
| 358 | 
            -
                    attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 359 | 
            -
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
         | 
| 360 | 
            -
                    encoder_attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 361 | 
            -
                    timestep: Optional[torch.LongTensor] = None,
         | 
| 362 | 
            -
                    cross_attention_kwargs: Dict[str, Any] = None,
         | 
| 363 | 
            -
                    class_labels: Optional[torch.LongTensor] = None,
         | 
| 364 | 
            -
                ):
         | 
| 365 | 
            -
                    # Notice that normalization is always applied before the real computation in the following blocks.
         | 
| 366 | 
            -
                    # 1. Self-Attention
         | 
| 367 | 
            -
                    if self.use_ada_layer_norm:
         | 
| 368 | 
            -
                        norm_hidden_states = self.norm1(hidden_states, timestep)
         | 
| 369 | 
            -
                    elif self.use_ada_layer_norm_zero:
         | 
| 370 | 
            -
                        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
         | 
| 371 | 
            -
                            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
         | 
| 372 | 
            -
                        )
         | 
| 373 | 
            -
                    else:
         | 
| 374 | 
            -
                        norm_hidden_states = self.norm1(hidden_states)
         | 
| 375 | 
            -
             | 
| 376 | 
            -
                    cross_attention_kwargs = (
         | 
| 377 | 
            -
                        cross_attention_kwargs if cross_attention_kwargs is not None else {}
         | 
| 378 | 
            -
                    )
         | 
| 379 | 
            -
             | 
| 380 | 
            -
                    attn_output = self.attn1(
         | 
| 381 | 
            -
                        norm_hidden_states,
         | 
| 382 | 
            -
                        encoder_hidden_states=(
         | 
| 383 | 
            -
                            encoder_hidden_states if self.only_cross_attention else None
         | 
| 384 | 
            -
                        ),
         | 
| 385 | 
            -
                        attention_mask=(
         | 
| 386 | 
            -
                            encoder_attention_mask if self.only_cross_attention else attention_mask
         | 
| 387 | 
            -
                        ),
         | 
| 388 | 
            -
                        **cross_attention_kwargs,
         | 
| 389 | 
            -
                    )
         | 
| 390 | 
            -
                    if self.use_ada_layer_norm_zero:
         | 
| 391 | 
            -
                        attn_output = gate_msa.unsqueeze(1) * attn_output
         | 
| 392 | 
            -
                    hidden_states = attn_output + hidden_states
         | 
| 393 | 
            -
             | 
| 394 | 
            -
                    # 2. Cross-Attention
         | 
| 395 | 
            -
                    if self.attn2 is not None:
         | 
| 396 | 
            -
                        norm_hidden_states = (
         | 
| 397 | 
            -
                            self.norm2(hidden_states, timestep)
         | 
| 398 | 
            -
                            if self.use_ada_layer_norm
         | 
| 399 | 
            -
                            else self.norm2(hidden_states)
         | 
| 400 | 
            -
                        )
         | 
| 401 | 
            -
             | 
| 402 | 
            -
                        attn_output = self.attn2(
         | 
| 403 | 
            -
                            norm_hidden_states,
         | 
| 404 | 
            -
                            encoder_hidden_states=encoder_hidden_states,
         | 
| 405 | 
            -
                            attention_mask=encoder_attention_mask,
         | 
| 406 | 
            -
                            **cross_attention_kwargs,
         | 
| 407 | 
            -
                        )
         | 
| 408 | 
            -
                        hidden_states = attn_output + hidden_states
         | 
| 409 | 
            -
             | 
| 410 | 
            -
                    # 3. Feed-forward
         | 
| 411 | 
            -
                    norm_hidden_states = self.norm3(hidden_states)
         | 
| 412 | 
            -
             | 
| 413 | 
            -
                    if self.use_ada_layer_norm_zero:
         | 
| 414 | 
            -
                        norm_hidden_states = (
         | 
| 415 | 
            -
                            norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         | 
| 416 | 
            -
                        )
         | 
| 417 | 
            -
             | 
| 418 | 
            -
                    if self._chunk_size is not None:
         | 
| 419 | 
            -
                        # "feed_forward_chunk_size" can be used to save memory
         | 
| 420 | 
            -
                        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
         | 
| 421 | 
            -
                            raise ValueError(
         | 
| 422 | 
            -
                                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
         | 
| 423 | 
            -
                            )
         | 
| 424 | 
            -
             | 
| 425 | 
            -
                        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
         | 
| 426 | 
            -
                        ff_output = torch.cat(
         | 
| 427 | 
            -
                            [
         | 
| 428 | 
            -
                                self.ff(hid_slice)
         | 
| 429 | 
            -
                                for hid_slice in norm_hidden_states.chunk(
         | 
| 430 | 
            -
                                    num_chunks, dim=self._chunk_dim
         | 
| 431 | 
            -
                                )
         | 
| 432 | 
            -
                            ],
         | 
| 433 | 
            -
                            dim=self._chunk_dim,
         | 
| 434 | 
            -
                        )
         | 
| 435 | 
            -
                    else:
         | 
| 436 | 
            -
                        ff_output = self.ff(norm_hidden_states)
         | 
| 437 | 
            -
             | 
| 438 | 
            -
                    if self.use_ada_layer_norm_zero:
         | 
| 439 | 
            -
                        ff_output = gate_mlp.unsqueeze(1) * ff_output
         | 
| 440 | 
            -
             | 
| 441 | 
            -
                    hidden_states = ff_output + hidden_states
         | 
| 442 | 
            -
             | 
| 443 | 
            -
                    return hidden_states
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/__init__.py
    DELETED
    
    | 
            File without changes
         | 
    	
        cosyvoice/transformer/activation.py
    DELETED
    
    | @@ -1,87 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
         | 
| 2 | 
            -
            #               2020 Northwestern Polytechnical University (Pengcheng Guo)
         | 
| 3 | 
            -
            #               2020 Mobvoi Inc (Binbin Zhang)
         | 
| 4 | 
            -
            #               2024 Alibaba Inc (Xiang Lyu)
         | 
| 5 | 
            -
            #
         | 
| 6 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 7 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 8 | 
            -
            # You may obtain a copy of the License at
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 11 | 
            -
            #
         | 
| 12 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 13 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 14 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 15 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 16 | 
            -
            # limitations under the License.
         | 
| 17 | 
            -
            """Swish() activation function for Conformer."""
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            import torch
         | 
| 20 | 
            -
            from torch import nn, sin, pow
         | 
| 21 | 
            -
            from torch.nn import Parameter
         | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
            class Swish(torch.nn.Module):
         | 
| 25 | 
            -
                """Construct an Swish object."""
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 28 | 
            -
                    """Return Swish activation function."""
         | 
| 29 | 
            -
                    return x * torch.sigmoid(x)
         | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
            # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
         | 
| 33 | 
            -
            #   LICENSE is in incl_licenses directory.
         | 
| 34 | 
            -
            class Snake(nn.Module):
         | 
| 35 | 
            -
                """
         | 
| 36 | 
            -
                Implementation of a sine-based periodic activation function
         | 
| 37 | 
            -
                Shape:
         | 
| 38 | 
            -
                    - Input: (B, C, T)
         | 
| 39 | 
            -
                    - Output: (B, C, T), same shape as the input
         | 
| 40 | 
            -
                Parameters:
         | 
| 41 | 
            -
                    - alpha - trainable parameter
         | 
| 42 | 
            -
                References:
         | 
| 43 | 
            -
                    - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
         | 
| 44 | 
            -
                    https://arxiv.org/abs/2006.08195
         | 
| 45 | 
            -
                Examples:
         | 
| 46 | 
            -
                    >>> a1 = snake(256)
         | 
| 47 | 
            -
                    >>> x = torch.randn(256)
         | 
| 48 | 
            -
                    >>> x = a1(x)
         | 
| 49 | 
            -
                """
         | 
| 50 | 
            -
             | 
| 51 | 
            -
                def __init__(
         | 
| 52 | 
            -
                    self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
         | 
| 53 | 
            -
                ):
         | 
| 54 | 
            -
                    """
         | 
| 55 | 
            -
                    Initialization.
         | 
| 56 | 
            -
                    INPUT:
         | 
| 57 | 
            -
                        - in_features: shape of the input
         | 
| 58 | 
            -
                        - alpha: trainable parameter
         | 
| 59 | 
            -
                        alpha is initialized to 1 by default, higher values = higher-frequency.
         | 
| 60 | 
            -
                        alpha will be trained along with the rest of your model.
         | 
| 61 | 
            -
                    """
         | 
| 62 | 
            -
                    super(Snake, self).__init__()
         | 
| 63 | 
            -
                    self.in_features = in_features
         | 
| 64 | 
            -
             | 
| 65 | 
            -
                    # initialize alpha
         | 
| 66 | 
            -
                    self.alpha_logscale = alpha_logscale
         | 
| 67 | 
            -
                    if self.alpha_logscale:  # log scale alphas initialized to zeros
         | 
| 68 | 
            -
                        self.alpha = Parameter(torch.zeros(in_features) * alpha)
         | 
| 69 | 
            -
                    else:  # linear scale alphas initialized to ones
         | 
| 70 | 
            -
                        self.alpha = Parameter(torch.ones(in_features) * alpha)
         | 
| 71 | 
            -
             | 
| 72 | 
            -
                    self.alpha.requires_grad = alpha_trainable
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                    self.no_div_by_zero = 0.000000001
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                def forward(self, x):
         | 
| 77 | 
            -
                    """
         | 
| 78 | 
            -
                    Forward pass of the function.
         | 
| 79 | 
            -
                    Applies the function to the input elementwise.
         | 
| 80 | 
            -
                    Snake ∶= x + 1/a * sin^2 (xa)
         | 
| 81 | 
            -
                    """
         | 
| 82 | 
            -
                    alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
         | 
| 83 | 
            -
                    if self.alpha_logscale:
         | 
| 84 | 
            -
                        alpha = torch.exp(alpha)
         | 
| 85 | 
            -
                    x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                    return x
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/attention.py
    DELETED
    
    | @@ -1,322 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2019 Shigeki Karita
         | 
| 2 | 
            -
            #               2020 Mobvoi Inc (Binbin Zhang)
         | 
| 3 | 
            -
            #               2022 Xingchen Song ([email protected])
         | 
| 4 | 
            -
            #               2024 Alibaba Inc (Xiang Lyu)
         | 
| 5 | 
            -
            #
         | 
| 6 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 7 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 8 | 
            -
            # You may obtain a copy of the License at
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 11 | 
            -
            #
         | 
| 12 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 13 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 14 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 15 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 16 | 
            -
            # limitations under the License.
         | 
| 17 | 
            -
            """Multi-Head Attention layer definition."""
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            import math
         | 
| 20 | 
            -
            from typing import Tuple
         | 
| 21 | 
            -
             | 
| 22 | 
            -
            import torch
         | 
| 23 | 
            -
            from torch import nn
         | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
            class MultiHeadedAttention(nn.Module):
         | 
| 27 | 
            -
                """Multi-Head Attention layer.
         | 
| 28 | 
            -
             | 
| 29 | 
            -
                Args:
         | 
| 30 | 
            -
                    n_head (int): The number of heads.
         | 
| 31 | 
            -
                    n_feat (int): The number of features.
         | 
| 32 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 33 | 
            -
             | 
| 34 | 
            -
                """
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                def __init__(
         | 
| 37 | 
            -
                    self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
         | 
| 38 | 
            -
                ):
         | 
| 39 | 
            -
                    """Construct an MultiHeadedAttention object."""
         | 
| 40 | 
            -
                    super().__init__()
         | 
| 41 | 
            -
                    assert n_feat % n_head == 0
         | 
| 42 | 
            -
                    # We assume d_v always equals d_k
         | 
| 43 | 
            -
                    self.d_k = n_feat // n_head
         | 
| 44 | 
            -
                    self.h = n_head
         | 
| 45 | 
            -
                    self.linear_q = nn.Linear(n_feat, n_feat)
         | 
| 46 | 
            -
                    self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
         | 
| 47 | 
            -
                    self.linear_v = nn.Linear(n_feat, n_feat)
         | 
| 48 | 
            -
                    self.linear_out = nn.Linear(n_feat, n_feat)
         | 
| 49 | 
            -
                    self.dropout = nn.Dropout(p=dropout_rate)
         | 
| 50 | 
            -
             | 
| 51 | 
            -
                def forward_qkv(
         | 
| 52 | 
            -
                    self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
         | 
| 53 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 54 | 
            -
                    """Transform query, key and value.
         | 
| 55 | 
            -
             | 
| 56 | 
            -
                    Args:
         | 
| 57 | 
            -
                        query (torch.Tensor): Query tensor (#batch, time1, size).
         | 
| 58 | 
            -
                        key (torch.Tensor): Key tensor (#batch, time2, size).
         | 
| 59 | 
            -
                        value (torch.Tensor): Value tensor (#batch, time2, size).
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                    Returns:
         | 
| 62 | 
            -
                        torch.Tensor: Transformed query tensor, size
         | 
| 63 | 
            -
                            (#batch, n_head, time1, d_k).
         | 
| 64 | 
            -
                        torch.Tensor: Transformed key tensor, size
         | 
| 65 | 
            -
                            (#batch, n_head, time2, d_k).
         | 
| 66 | 
            -
                        torch.Tensor: Transformed value tensor, size
         | 
| 67 | 
            -
                            (#batch, n_head, time2, d_k).
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                    """
         | 
| 70 | 
            -
                    n_batch = query.size(0)
         | 
| 71 | 
            -
                    q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
         | 
| 72 | 
            -
                    k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
         | 
| 73 | 
            -
                    v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
         | 
| 74 | 
            -
                    q = q.transpose(1, 2)  # (batch, head, time1, d_k)
         | 
| 75 | 
            -
                    k = k.transpose(1, 2)  # (batch, head, time2, d_k)
         | 
| 76 | 
            -
                    v = v.transpose(1, 2)  # (batch, head, time2, d_k)
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                    return q, k, v
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                def forward_attention(
         | 
| 81 | 
            -
                    self,
         | 
| 82 | 
            -
                    value: torch.Tensor,
         | 
| 83 | 
            -
                    scores: torch.Tensor,
         | 
| 84 | 
            -
                    mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 85 | 
            -
                ) -> torch.Tensor:
         | 
| 86 | 
            -
                    """Compute attention context vector.
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                    Args:
         | 
| 89 | 
            -
                        value (torch.Tensor): Transformed value, size
         | 
| 90 | 
            -
                            (#batch, n_head, time2, d_k).
         | 
| 91 | 
            -
                        scores (torch.Tensor): Attention score, size
         | 
| 92 | 
            -
                            (#batch, n_head, time1, time2).
         | 
| 93 | 
            -
                        mask (torch.Tensor): Mask, size (#batch, 1, time2) or
         | 
| 94 | 
            -
                            (#batch, time1, time2), (0, 0, 0) means fake mask.
         | 
| 95 | 
            -
             | 
| 96 | 
            -
                    Returns:
         | 
| 97 | 
            -
                        torch.Tensor: Transformed value (#batch, time1, d_model)
         | 
| 98 | 
            -
                            weighted by the attention score (#batch, time1, time2).
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                    """
         | 
| 101 | 
            -
                    n_batch = value.size(0)
         | 
| 102 | 
            -
                    # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
         | 
| 103 | 
            -
                    #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
         | 
| 104 | 
            -
                    #           1st chunk to ease the onnx export.]
         | 
| 105 | 
            -
                    #   2. pytorch training
         | 
| 106 | 
            -
                    if mask.size(2) > 0:  # time2 > 0
         | 
| 107 | 
            -
                        mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
         | 
| 108 | 
            -
                        # For last chunk, time2 might be larger than scores.size(-1)
         | 
| 109 | 
            -
                        mask = mask[:, :, :, : scores.size(-1)]  # (batch, 1, *, time2)
         | 
| 110 | 
            -
                        scores = scores.masked_fill(mask, -float("inf"))
         | 
| 111 | 
            -
                        attn = torch.softmax(scores, dim=-1).masked_fill(
         | 
| 112 | 
            -
                            mask, 0.0
         | 
| 113 | 
            -
                        )  # (batch, head, time1, time2)
         | 
| 114 | 
            -
                    # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
         | 
| 115 | 
            -
                    #   1. onnx(16/-1, -1/-1, 16/0)
         | 
| 116 | 
            -
                    #   2. jit (16/-1, -1/-1, 16/0, 16/4)
         | 
| 117 | 
            -
                    else:
         | 
| 118 | 
            -
                        attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
         | 
| 119 | 
            -
             | 
| 120 | 
            -
                    p_attn = self.dropout(attn)
         | 
| 121 | 
            -
                    x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
         | 
| 122 | 
            -
                    x = (
         | 
| 123 | 
            -
                        x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
         | 
| 124 | 
            -
                    )  # (batch, time1, d_model)
         | 
| 125 | 
            -
             | 
| 126 | 
            -
                    return self.linear_out(x)  # (batch, time1, d_model)
         | 
| 127 | 
            -
             | 
| 128 | 
            -
                def forward(
         | 
| 129 | 
            -
                    self,
         | 
| 130 | 
            -
                    query: torch.Tensor,
         | 
| 131 | 
            -
                    key: torch.Tensor,
         | 
| 132 | 
            -
                    value: torch.Tensor,
         | 
| 133 | 
            -
                    mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 134 | 
            -
                    pos_emb: torch.Tensor = torch.empty(0),
         | 
| 135 | 
            -
                    cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 136 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 137 | 
            -
                    """Compute scaled dot product attention.
         | 
| 138 | 
            -
             | 
| 139 | 
            -
                    Args:
         | 
| 140 | 
            -
                        query (torch.Tensor): Query tensor (#batch, time1, size).
         | 
| 141 | 
            -
                        key (torch.Tensor): Key tensor (#batch, time2, size).
         | 
| 142 | 
            -
                        value (torch.Tensor): Value tensor (#batch, time2, size).
         | 
| 143 | 
            -
                        mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
         | 
| 144 | 
            -
                            (#batch, time1, time2).
         | 
| 145 | 
            -
                            1.When applying cross attention between decoder and encoder,
         | 
| 146 | 
            -
                            the batch padding mask for input is in (#batch, 1, T) shape.
         | 
| 147 | 
            -
                            2.When applying self attention of encoder,
         | 
| 148 | 
            -
                            the mask is in (#batch, T, T)  shape.
         | 
| 149 | 
            -
                            3.When applying self attention of decoder,
         | 
| 150 | 
            -
                            the mask is in (#batch, L, L)  shape.
         | 
| 151 | 
            -
                            4.If the different position in decoder see different block
         | 
| 152 | 
            -
                            of the encoder, such as Mocha, the passed in mask could be
         | 
| 153 | 
            -
                            in (#batch, L, T) shape. But there is no such case in current
         | 
| 154 | 
            -
                            CosyVoice.
         | 
| 155 | 
            -
                        cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
         | 
| 156 | 
            -
                            where `cache_t == chunk_size * num_decoding_left_chunks`
         | 
| 157 | 
            -
                            and `head * d_k == size`
         | 
| 158 | 
            -
             | 
| 159 | 
            -
             | 
| 160 | 
            -
                    Returns:
         | 
| 161 | 
            -
                        torch.Tensor: Output tensor (#batch, time1, d_model).
         | 
| 162 | 
            -
                        torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
         | 
| 163 | 
            -
                            where `cache_t == chunk_size * num_decoding_left_chunks`
         | 
| 164 | 
            -
                            and `head * d_k == size`
         | 
| 165 | 
            -
             | 
| 166 | 
            -
                    """
         | 
| 167 | 
            -
                    q, k, v = self.forward_qkv(query, key, value)
         | 
| 168 | 
            -
             | 
| 169 | 
            -
                    # NOTE(xcsong):
         | 
| 170 | 
            -
                    #   when export onnx model, for 1st chunk, we feed
         | 
| 171 | 
            -
                    #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
         | 
| 172 | 
            -
                    #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
         | 
| 173 | 
            -
                    #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
         | 
| 174 | 
            -
                    #       and we will always do splitting and
         | 
| 175 | 
            -
                    #       concatnation(this will simplify onnx export). Note that
         | 
| 176 | 
            -
                    #       it's OK to concat & split zero-shaped tensors(see code below).
         | 
| 177 | 
            -
                    #   when export jit  model, for 1st chunk, we always feed
         | 
| 178 | 
            -
                    #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
         | 
| 179 | 
            -
                    # >>> a = torch.ones((1, 2, 0, 4))
         | 
| 180 | 
            -
                    # >>> b = torch.ones((1, 2, 3, 4))
         | 
| 181 | 
            -
                    # >>> c = torch.cat((a, b), dim=2)
         | 
| 182 | 
            -
                    # >>> torch.equal(b, c)        # True
         | 
| 183 | 
            -
                    # >>> d = torch.split(a, 2, dim=-1)
         | 
| 184 | 
            -
                    # >>> torch.equal(d[0], d[1])  # True
         | 
| 185 | 
            -
                    if cache.size(0) > 0:
         | 
| 186 | 
            -
                        key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
         | 
| 187 | 
            -
                        k = torch.cat([key_cache, k], dim=2)
         | 
| 188 | 
            -
                        v = torch.cat([value_cache, v], dim=2)
         | 
| 189 | 
            -
                    # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
         | 
| 190 | 
            -
                    #   non-trivial to calculate `next_cache_start` here.
         | 
| 191 | 
            -
                    new_cache = torch.cat((k, v), dim=-1)
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
         | 
| 194 | 
            -
                    return self.forward_attention(v, scores, mask), new_cache
         | 
| 195 | 
            -
             | 
| 196 | 
            -
             | 
| 197 | 
            -
            class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         | 
| 198 | 
            -
                """Multi-Head Attention layer with relative position encoding.
         | 
| 199 | 
            -
                Paper: https://arxiv.org/abs/1901.02860
         | 
| 200 | 
            -
                Args:
         | 
| 201 | 
            -
                    n_head (int): The number of heads.
         | 
| 202 | 
            -
                    n_feat (int): The number of features.
         | 
| 203 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 204 | 
            -
                """
         | 
| 205 | 
            -
             | 
| 206 | 
            -
                def __init__(
         | 
| 207 | 
            -
                    self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
         | 
| 208 | 
            -
                ):
         | 
| 209 | 
            -
                    """Construct an RelPositionMultiHeadedAttention object."""
         | 
| 210 | 
            -
                    super().__init__(n_head, n_feat, dropout_rate, key_bias)
         | 
| 211 | 
            -
                    # linear transformation for positional encoding
         | 
| 212 | 
            -
                    self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         | 
| 213 | 
            -
                    # these two learnable bias are used in matrix c and matrix d
         | 
| 214 | 
            -
                    # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         | 
| 215 | 
            -
                    self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
         | 
| 216 | 
            -
                    self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
         | 
| 217 | 
            -
                    torch.nn.init.xavier_uniform_(self.pos_bias_u)
         | 
| 218 | 
            -
                    torch.nn.init.xavier_uniform_(self.pos_bias_v)
         | 
| 219 | 
            -
             | 
| 220 | 
            -
                def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 221 | 
            -
                    """Compute relative positional encoding.
         | 
| 222 | 
            -
             | 
| 223 | 
            -
                    Args:
         | 
| 224 | 
            -
                        x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
         | 
| 225 | 
            -
                        time1 means the length of query vector.
         | 
| 226 | 
            -
             | 
| 227 | 
            -
                    Returns:
         | 
| 228 | 
            -
                        torch.Tensor: Output tensor.
         | 
| 229 | 
            -
             | 
| 230 | 
            -
                    """
         | 
| 231 | 
            -
                    zero_pad = torch.zeros(
         | 
| 232 | 
            -
                        (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
         | 
| 233 | 
            -
                    )
         | 
| 234 | 
            -
                    x_padded = torch.cat([zero_pad, x], dim=-1)
         | 
| 235 | 
            -
             | 
| 236 | 
            -
                    x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
         | 
| 237 | 
            -
                    x = x_padded[:, :, 1:].view_as(x)[
         | 
| 238 | 
            -
                        :, :, :, : x.size(-1) // 2 + 1
         | 
| 239 | 
            -
                    ]  # only keep the positions from 0 to time2
         | 
| 240 | 
            -
                    return x
         | 
| 241 | 
            -
             | 
| 242 | 
            -
                def forward(
         | 
| 243 | 
            -
                    self,
         | 
| 244 | 
            -
                    query: torch.Tensor,
         | 
| 245 | 
            -
                    key: torch.Tensor,
         | 
| 246 | 
            -
                    value: torch.Tensor,
         | 
| 247 | 
            -
                    mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 248 | 
            -
                    pos_emb: torch.Tensor = torch.empty(0),
         | 
| 249 | 
            -
                    cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 250 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 251 | 
            -
                    """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         | 
| 252 | 
            -
                    Args:
         | 
| 253 | 
            -
                        query (torch.Tensor): Query tensor (#batch, time1, size).
         | 
| 254 | 
            -
                        key (torch.Tensor): Key tensor (#batch, time2, size).
         | 
| 255 | 
            -
                        value (torch.Tensor): Value tensor (#batch, time2, size).
         | 
| 256 | 
            -
                        mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
         | 
| 257 | 
            -
                            (#batch, time1, time2), (0, 0, 0) means fake mask.
         | 
| 258 | 
            -
                        pos_emb (torch.Tensor): Positional embedding tensor
         | 
| 259 | 
            -
                            (#batch, time2, size).
         | 
| 260 | 
            -
                        cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
         | 
| 261 | 
            -
                            where `cache_t == chunk_size * num_decoding_left_chunks`
         | 
| 262 | 
            -
                            and `head * d_k == size`
         | 
| 263 | 
            -
                    Returns:
         | 
| 264 | 
            -
                        torch.Tensor: Output tensor (#batch, time1, d_model).
         | 
| 265 | 
            -
                        torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
         | 
| 266 | 
            -
                            where `cache_t == chunk_size * num_decoding_left_chunks`
         | 
| 267 | 
            -
                            and `head * d_k == size`
         | 
| 268 | 
            -
                    """
         | 
| 269 | 
            -
                    q, k, v = self.forward_qkv(query, key, value)
         | 
| 270 | 
            -
                    q = q.transpose(1, 2)  # (batch, time1, head, d_k)
         | 
| 271 | 
            -
             | 
| 272 | 
            -
                    # NOTE(xcsong):
         | 
| 273 | 
            -
                    #   when export onnx model, for 1st chunk, we feed
         | 
| 274 | 
            -
                    #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
         | 
| 275 | 
            -
                    #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
         | 
| 276 | 
            -
                    #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
         | 
| 277 | 
            -
                    #       and we will always do splitting and
         | 
| 278 | 
            -
                    #       concatnation(this will simplify onnx export). Note that
         | 
| 279 | 
            -
                    #       it's OK to concat & split zero-shaped tensors(see code below).
         | 
| 280 | 
            -
                    #   when export jit  model, for 1st chunk, we always feed
         | 
| 281 | 
            -
                    #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
         | 
| 282 | 
            -
                    # >>> a = torch.ones((1, 2, 0, 4))
         | 
| 283 | 
            -
                    # >>> b = torch.ones((1, 2, 3, 4))
         | 
| 284 | 
            -
                    # >>> c = torch.cat((a, b), dim=2)
         | 
| 285 | 
            -
                    # >>> torch.equal(b, c)        # True
         | 
| 286 | 
            -
                    # >>> d = torch.split(a, 2, dim=-1)
         | 
| 287 | 
            -
                    # >>> torch.equal(d[0], d[1])  # True
         | 
| 288 | 
            -
                    if cache.size(0) > 0:
         | 
| 289 | 
            -
                        key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
         | 
| 290 | 
            -
                        k = torch.cat([key_cache, k], dim=2)
         | 
| 291 | 
            -
                        v = torch.cat([value_cache, v], dim=2)
         | 
| 292 | 
            -
                    # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
         | 
| 293 | 
            -
                    #   non-trivial to calculate `next_cache_start` here.
         | 
| 294 | 
            -
                    new_cache = torch.cat((k, v), dim=-1)
         | 
| 295 | 
            -
             | 
| 296 | 
            -
                    n_batch_pos = pos_emb.size(0)
         | 
| 297 | 
            -
                    p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         | 
| 298 | 
            -
                    p = p.transpose(1, 2)  # (batch, head, time1, d_k)
         | 
| 299 | 
            -
             | 
| 300 | 
            -
                    # (batch, head, time1, d_k)
         | 
| 301 | 
            -
                    q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
         | 
| 302 | 
            -
                    # (batch, head, time1, d_k)
         | 
| 303 | 
            -
                    q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
         | 
| 304 | 
            -
             | 
| 305 | 
            -
                    # compute attention score
         | 
| 306 | 
            -
                    # first compute matrix a and matrix c
         | 
| 307 | 
            -
                    # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         | 
| 308 | 
            -
                    # (batch, head, time1, time2)
         | 
| 309 | 
            -
                    matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
         | 
| 310 | 
            -
             | 
| 311 | 
            -
                    # compute matrix b and matrix d
         | 
| 312 | 
            -
                    # (batch, head, time1, time2)
         | 
| 313 | 
            -
                    matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
         | 
| 314 | 
            -
                    # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
         | 
| 315 | 
            -
                    if matrix_ac.shape != matrix_bd.shape:
         | 
| 316 | 
            -
                        matrix_bd = self.rel_shift(matrix_bd)
         | 
| 317 | 
            -
             | 
| 318 | 
            -
                    scores = (matrix_ac + matrix_bd) / math.sqrt(
         | 
| 319 | 
            -
                        self.d_k
         | 
| 320 | 
            -
                    )  # (batch, head, time1, time2)
         | 
| 321 | 
            -
             | 
| 322 | 
            -
                    return self.forward_attention(v, scores, mask), new_cache
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/convolution.py
    DELETED
    
    | @@ -1,147 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
         | 
| 2 | 
            -
            #               2024 Alibaba Inc (Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 16 | 
            -
            """ConvolutionModule definition."""
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            from typing import Tuple
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            import torch
         | 
| 21 | 
            -
            from torch import nn
         | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
            class ConvolutionModule(nn.Module):
         | 
| 25 | 
            -
                """ConvolutionModule in Conformer model."""
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                def __init__(
         | 
| 28 | 
            -
                    self,
         | 
| 29 | 
            -
                    channels: int,
         | 
| 30 | 
            -
                    kernel_size: int = 15,
         | 
| 31 | 
            -
                    activation: nn.Module = nn.ReLU(),
         | 
| 32 | 
            -
                    norm: str = "batch_norm",
         | 
| 33 | 
            -
                    causal: bool = False,
         | 
| 34 | 
            -
                    bias: bool = True,
         | 
| 35 | 
            -
                ):
         | 
| 36 | 
            -
                    """Construct an ConvolutionModule object.
         | 
| 37 | 
            -
                    Args:
         | 
| 38 | 
            -
                        channels (int): The number of channels of conv layers.
         | 
| 39 | 
            -
                        kernel_size (int): Kernel size of conv layers.
         | 
| 40 | 
            -
                        causal (int): Whether use causal convolution or not
         | 
| 41 | 
            -
                    """
         | 
| 42 | 
            -
                    super().__init__()
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                    self.pointwise_conv1 = nn.Conv1d(
         | 
| 45 | 
            -
                        channels,
         | 
| 46 | 
            -
                        2 * channels,
         | 
| 47 | 
            -
                        kernel_size=1,
         | 
| 48 | 
            -
                        stride=1,
         | 
| 49 | 
            -
                        padding=0,
         | 
| 50 | 
            -
                        bias=bias,
         | 
| 51 | 
            -
                    )
         | 
| 52 | 
            -
                    # self.lorder is used to distinguish if it's a causal convolution,
         | 
| 53 | 
            -
                    # if self.lorder > 0: it's a causal convolution, the input will be
         | 
| 54 | 
            -
                    #    padded with self.lorder frames on the left in forward.
         | 
| 55 | 
            -
                    # else: it's a symmetrical convolution
         | 
| 56 | 
            -
                    if causal:
         | 
| 57 | 
            -
                        padding = 0
         | 
| 58 | 
            -
                        self.lorder = kernel_size - 1
         | 
| 59 | 
            -
                    else:
         | 
| 60 | 
            -
                        # kernel_size should be an odd number for none causal convolution
         | 
| 61 | 
            -
                        assert (kernel_size - 1) % 2 == 0
         | 
| 62 | 
            -
                        padding = (kernel_size - 1) // 2
         | 
| 63 | 
            -
                        self.lorder = 0
         | 
| 64 | 
            -
                    self.depthwise_conv = nn.Conv1d(
         | 
| 65 | 
            -
                        channels,
         | 
| 66 | 
            -
                        channels,
         | 
| 67 | 
            -
                        kernel_size,
         | 
| 68 | 
            -
                        stride=1,
         | 
| 69 | 
            -
                        padding=padding,
         | 
| 70 | 
            -
                        groups=channels,
         | 
| 71 | 
            -
                        bias=bias,
         | 
| 72 | 
            -
                    )
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                    assert norm in ["batch_norm", "layer_norm"]
         | 
| 75 | 
            -
                    if norm == "batch_norm":
         | 
| 76 | 
            -
                        self.use_layer_norm = False
         | 
| 77 | 
            -
                        self.norm = nn.BatchNorm1d(channels)
         | 
| 78 | 
            -
                    else:
         | 
| 79 | 
            -
                        self.use_layer_norm = True
         | 
| 80 | 
            -
                        self.norm = nn.LayerNorm(channels)
         | 
| 81 | 
            -
             | 
| 82 | 
            -
                    self.pointwise_conv2 = nn.Conv1d(
         | 
| 83 | 
            -
                        channels,
         | 
| 84 | 
            -
                        channels,
         | 
| 85 | 
            -
                        kernel_size=1,
         | 
| 86 | 
            -
                        stride=1,
         | 
| 87 | 
            -
                        padding=0,
         | 
| 88 | 
            -
                        bias=bias,
         | 
| 89 | 
            -
                    )
         | 
| 90 | 
            -
                    self.activation = activation
         | 
| 91 | 
            -
             | 
| 92 | 
            -
                def forward(
         | 
| 93 | 
            -
                    self,
         | 
| 94 | 
            -
                    x: torch.Tensor,
         | 
| 95 | 
            -
                    mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 96 | 
            -
                    cache: torch.Tensor = torch.zeros((0, 0, 0)),
         | 
| 97 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 98 | 
            -
                    """Compute convolution module.
         | 
| 99 | 
            -
                    Args:
         | 
| 100 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, channels).
         | 
| 101 | 
            -
                        mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
         | 
| 102 | 
            -
                            (0, 0, 0) means fake mask.
         | 
| 103 | 
            -
                        cache (torch.Tensor): left context cache, it is only
         | 
| 104 | 
            -
                            used in causal convolution (#batch, channels, cache_t),
         | 
| 105 | 
            -
                            (0, 0, 0) meas fake cache.
         | 
| 106 | 
            -
                    Returns:
         | 
| 107 | 
            -
                        torch.Tensor: Output tensor (#batch, time, channels).
         | 
| 108 | 
            -
                    """
         | 
| 109 | 
            -
                    # exchange the temporal dimension and the feature dimension
         | 
| 110 | 
            -
                    x = x.transpose(1, 2)  # (#batch, channels, time)
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                    # mask batch padding
         | 
| 113 | 
            -
                    if mask_pad.size(2) > 0:  # time > 0
         | 
| 114 | 
            -
                        x.masked_fill_(~mask_pad, 0.0)
         | 
| 115 | 
            -
             | 
| 116 | 
            -
                    if self.lorder > 0:
         | 
| 117 | 
            -
                        if cache.size(2) == 0:  # cache_t == 0
         | 
| 118 | 
            -
                            x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
         | 
| 119 | 
            -
                        else:
         | 
| 120 | 
            -
                            assert cache.size(0) == x.size(0)  # equal batch
         | 
| 121 | 
            -
                            assert cache.size(1) == x.size(1)  # equal channel
         | 
| 122 | 
            -
                            x = torch.cat((cache, x), dim=2)
         | 
| 123 | 
            -
                        assert x.size(2) > self.lorder
         | 
| 124 | 
            -
                        new_cache = x[:, :, -self.lorder :]
         | 
| 125 | 
            -
                    else:
         | 
| 126 | 
            -
                        # It's better we just return None if no cache is required,
         | 
| 127 | 
            -
                        # However, for JIT export, here we just fake one tensor instead of
         | 
| 128 | 
            -
                        # None.
         | 
| 129 | 
            -
                        new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                    # GLU mechanism
         | 
| 132 | 
            -
                    x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
         | 
| 133 | 
            -
                    x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
         | 
| 134 | 
            -
             | 
| 135 | 
            -
                    # 1D Depthwise Conv
         | 
| 136 | 
            -
                    x = self.depthwise_conv(x)
         | 
| 137 | 
            -
                    if self.use_layer_norm:
         | 
| 138 | 
            -
                        x = x.transpose(1, 2)
         | 
| 139 | 
            -
                    x = self.activation(self.norm(x))
         | 
| 140 | 
            -
                    if self.use_layer_norm:
         | 
| 141 | 
            -
                        x = x.transpose(1, 2)
         | 
| 142 | 
            -
                    x = self.pointwise_conv2(x)
         | 
| 143 | 
            -
                    # mask batch padding
         | 
| 144 | 
            -
                    if mask_pad.size(2) > 0:  # time > 0
         | 
| 145 | 
            -
                        x.masked_fill_(~mask_pad, 0.0)
         | 
| 146 | 
            -
             | 
| 147 | 
            -
                    return x.transpose(1, 2), new_cache
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/decoder.py
    DELETED
    
    | @@ -1,418 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
         | 
| 2 | 
            -
            #               2024 Alibaba Inc (Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 16 | 
            -
            """Decoder definition."""
         | 
| 17 | 
            -
            from typing import Tuple, List, Optional
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            import torch
         | 
| 20 | 
            -
            import torch.utils.checkpoint as ckpt
         | 
| 21 | 
            -
            import logging
         | 
| 22 | 
            -
             | 
| 23 | 
            -
            from cosyvoice.transformer.decoder_layer import DecoderLayer
         | 
| 24 | 
            -
            from cosyvoice.transformer.positionwise_feed_forward import (
         | 
| 25 | 
            -
                PositionwiseFeedForward,
         | 
| 26 | 
            -
            )
         | 
| 27 | 
            -
            from cosyvoice.utils.class_utils import (
         | 
| 28 | 
            -
                COSYVOICE_EMB_CLASSES,
         | 
| 29 | 
            -
                COSYVOICE_ATTENTION_CLASSES,
         | 
| 30 | 
            -
                COSYVOICE_ACTIVATION_CLASSES,
         | 
| 31 | 
            -
            )
         | 
| 32 | 
            -
            from cosyvoice.utils.mask import subsequent_mask, make_pad_mask
         | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
            class TransformerDecoder(torch.nn.Module):
         | 
| 36 | 
            -
                """Base class of Transfomer decoder module.
         | 
| 37 | 
            -
                Args:
         | 
| 38 | 
            -
                    vocab_size: output dim
         | 
| 39 | 
            -
                    encoder_output_size: dimension of attention
         | 
| 40 | 
            -
                    attention_heads: the number of heads of multi head attention
         | 
| 41 | 
            -
                    linear_units: the hidden units number of position-wise feedforward
         | 
| 42 | 
            -
                    num_blocks: the number of decoder blocks
         | 
| 43 | 
            -
                    dropout_rate: dropout rate
         | 
| 44 | 
            -
                    self_attention_dropout_rate: dropout rate for attention
         | 
| 45 | 
            -
                    input_layer: input layer type
         | 
| 46 | 
            -
                    use_output_layer: whether to use output layer
         | 
| 47 | 
            -
                    pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
         | 
| 48 | 
            -
                    normalize_before:
         | 
| 49 | 
            -
                        True: use layer_norm before each sub-block of a layer.
         | 
| 50 | 
            -
                        False: use layer_norm after each sub-block of a layer.
         | 
| 51 | 
            -
                    src_attention: if false, encoder-decoder cross attention is not
         | 
| 52 | 
            -
                                   applied, such as CIF model
         | 
| 53 | 
            -
                    key_bias: whether use bias in attention.linear_k, False for whisper models.
         | 
| 54 | 
            -
                    gradient_checkpointing: rerunning a forward-pass segment for each
         | 
| 55 | 
            -
                        checkpointed segment during backward.
         | 
| 56 | 
            -
                    tie_word_embedding: Tie or clone module weights depending of whether we are
         | 
| 57 | 
            -
                        using TorchScript or not
         | 
| 58 | 
            -
                """
         | 
| 59 | 
            -
             | 
| 60 | 
            -
                def __init__(
         | 
| 61 | 
            -
                    self,
         | 
| 62 | 
            -
                    vocab_size: int,
         | 
| 63 | 
            -
                    encoder_output_size: int,
         | 
| 64 | 
            -
                    attention_heads: int = 4,
         | 
| 65 | 
            -
                    linear_units: int = 2048,
         | 
| 66 | 
            -
                    num_blocks: int = 6,
         | 
| 67 | 
            -
                    dropout_rate: float = 0.1,
         | 
| 68 | 
            -
                    positional_dropout_rate: float = 0.1,
         | 
| 69 | 
            -
                    self_attention_dropout_rate: float = 0.0,
         | 
| 70 | 
            -
                    src_attention_dropout_rate: float = 0.0,
         | 
| 71 | 
            -
                    input_layer: str = "embed",
         | 
| 72 | 
            -
                    use_output_layer: bool = True,
         | 
| 73 | 
            -
                    normalize_before: bool = True,
         | 
| 74 | 
            -
                    src_attention: bool = True,
         | 
| 75 | 
            -
                    key_bias: bool = True,
         | 
| 76 | 
            -
                    activation_type: str = "relu",
         | 
| 77 | 
            -
                    gradient_checkpointing: bool = False,
         | 
| 78 | 
            -
                    tie_word_embedding: bool = False,
         | 
| 79 | 
            -
                ):
         | 
| 80 | 
            -
                    super().__init__()
         | 
| 81 | 
            -
                    attention_dim = encoder_output_size
         | 
| 82 | 
            -
                    activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                    self.embed = torch.nn.Sequential(
         | 
| 85 | 
            -
                        (
         | 
| 86 | 
            -
                            torch.nn.Identity()
         | 
| 87 | 
            -
                            if input_layer == "no_pos"
         | 
| 88 | 
            -
                            else torch.nn.Embedding(vocab_size, attention_dim)
         | 
| 89 | 
            -
                        ),
         | 
| 90 | 
            -
                        COSYVOICE_EMB_CLASSES[input_layer](attention_dim, positional_dropout_rate),
         | 
| 91 | 
            -
                    )
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                    self.normalize_before = normalize_before
         | 
| 94 | 
            -
                    self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
         | 
| 95 | 
            -
                    self.use_output_layer = use_output_layer
         | 
| 96 | 
            -
                    if use_output_layer:
         | 
| 97 | 
            -
                        self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
         | 
| 98 | 
            -
                    else:
         | 
| 99 | 
            -
                        self.output_layer = torch.nn.Identity()
         | 
| 100 | 
            -
                    self.num_blocks = num_blocks
         | 
| 101 | 
            -
                    self.decoders = torch.nn.ModuleList(
         | 
| 102 | 
            -
                        [
         | 
| 103 | 
            -
                            DecoderLayer(
         | 
| 104 | 
            -
                                attention_dim,
         | 
| 105 | 
            -
                                COSYVOICE_ATTENTION_CLASSES["selfattn"](
         | 
| 106 | 
            -
                                    attention_heads,
         | 
| 107 | 
            -
                                    attention_dim,
         | 
| 108 | 
            -
                                    self_attention_dropout_rate,
         | 
| 109 | 
            -
                                    key_bias,
         | 
| 110 | 
            -
                                ),
         | 
| 111 | 
            -
                                (
         | 
| 112 | 
            -
                                    COSYVOICE_ATTENTION_CLASSES["selfattn"](
         | 
| 113 | 
            -
                                        attention_heads,
         | 
| 114 | 
            -
                                        attention_dim,
         | 
| 115 | 
            -
                                        src_attention_dropout_rate,
         | 
| 116 | 
            -
                                        key_bias,
         | 
| 117 | 
            -
                                    )
         | 
| 118 | 
            -
                                    if src_attention
         | 
| 119 | 
            -
                                    else None
         | 
| 120 | 
            -
                                ),
         | 
| 121 | 
            -
                                PositionwiseFeedForward(
         | 
| 122 | 
            -
                                    attention_dim, linear_units, dropout_rate, activation
         | 
| 123 | 
            -
                                ),
         | 
| 124 | 
            -
                                dropout_rate,
         | 
| 125 | 
            -
                                normalize_before,
         | 
| 126 | 
            -
                            )
         | 
| 127 | 
            -
                            for _ in range(self.num_blocks)
         | 
| 128 | 
            -
                        ]
         | 
| 129 | 
            -
                    )
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                    self.gradient_checkpointing = gradient_checkpointing
         | 
| 132 | 
            -
                    self.tie_word_embedding = tie_word_embedding
         | 
| 133 | 
            -
             | 
| 134 | 
            -
                def forward(
         | 
| 135 | 
            -
                    self,
         | 
| 136 | 
            -
                    memory: torch.Tensor,
         | 
| 137 | 
            -
                    memory_mask: torch.Tensor,
         | 
| 138 | 
            -
                    ys_in_pad: torch.Tensor,
         | 
| 139 | 
            -
                    ys_in_lens: torch.Tensor,
         | 
| 140 | 
            -
                    r_ys_in_pad: torch.Tensor = torch.empty(0),
         | 
| 141 | 
            -
                    reverse_weight: float = 0.0,
         | 
| 142 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 143 | 
            -
                    """Forward decoder.
         | 
| 144 | 
            -
                    Args:
         | 
| 145 | 
            -
                        memory: encoded memory, float32  (batch, maxlen_in, feat)
         | 
| 146 | 
            -
                        memory_mask: encoder memory mask, (batch, 1, maxlen_in)
         | 
| 147 | 
            -
                        ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
         | 
| 148 | 
            -
                        ys_in_lens: input lengths of this batch (batch)
         | 
| 149 | 
            -
                        r_ys_in_pad: not used in transformer decoder, in order to unify api
         | 
| 150 | 
            -
                            with bidirectional decoder
         | 
| 151 | 
            -
                        reverse_weight: not used in transformer decoder, in order to unify
         | 
| 152 | 
            -
                            api with bidirectional decode
         | 
| 153 | 
            -
                    Returns:
         | 
| 154 | 
            -
                        (tuple): tuple containing:
         | 
| 155 | 
            -
                            x: decoded token score before softmax (batch, maxlen_out,
         | 
| 156 | 
            -
                                vocab_size) if use_output_layer is True,
         | 
| 157 | 
            -
                            torch.tensor(0.0), in order to unify api with bidirectional decoder
         | 
| 158 | 
            -
                            olens: (batch, )
         | 
| 159 | 
            -
                    NOTE(xcsong):
         | 
| 160 | 
            -
                        We pass the `__call__` method of the modules instead of `forward` to the
         | 
| 161 | 
            -
                        checkpointing API because `__call__` attaches all the hooks of the module.
         | 
| 162 | 
            -
                        https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
         | 
| 163 | 
            -
                    """
         | 
| 164 | 
            -
                    tgt = ys_in_pad
         | 
| 165 | 
            -
                    maxlen = tgt.size(1)
         | 
| 166 | 
            -
                    # tgt_mask: (B, 1, L)
         | 
| 167 | 
            -
                    tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
         | 
| 168 | 
            -
                    tgt_mask = tgt_mask.to(tgt.device)
         | 
| 169 | 
            -
                    # m: (1, L, L)
         | 
| 170 | 
            -
                    m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
         | 
| 171 | 
            -
                    # tgt_mask: (B, L, L)
         | 
| 172 | 
            -
                    tgt_mask = tgt_mask & m
         | 
| 173 | 
            -
                    x, _ = self.embed(tgt)
         | 
| 174 | 
            -
                    if self.gradient_checkpointing and self.training:
         | 
| 175 | 
            -
                        x = self.forward_layers_checkpointed(x, tgt_mask, memory, memory_mask)
         | 
| 176 | 
            -
                    else:
         | 
| 177 | 
            -
                        x = self.forward_layers(x, tgt_mask, memory, memory_mask)
         | 
| 178 | 
            -
                    if self.normalize_before:
         | 
| 179 | 
            -
                        x = self.after_norm(x)
         | 
| 180 | 
            -
                    if self.use_output_layer:
         | 
| 181 | 
            -
                        x = self.output_layer(x)
         | 
| 182 | 
            -
                    olens = tgt_mask.sum(1)
         | 
| 183 | 
            -
                    return x, torch.tensor(0.0), olens
         | 
| 184 | 
            -
             | 
| 185 | 
            -
                def forward_layers(
         | 
| 186 | 
            -
                    self,
         | 
| 187 | 
            -
                    x: torch.Tensor,
         | 
| 188 | 
            -
                    tgt_mask: torch.Tensor,
         | 
| 189 | 
            -
                    memory: torch.Tensor,
         | 
| 190 | 
            -
                    memory_mask: torch.Tensor,
         | 
| 191 | 
            -
                ) -> torch.Tensor:
         | 
| 192 | 
            -
                    for layer in self.decoders:
         | 
| 193 | 
            -
                        x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, memory_mask)
         | 
| 194 | 
            -
                    return x
         | 
| 195 | 
            -
             | 
| 196 | 
            -
                @torch.jit.unused
         | 
| 197 | 
            -
                def forward_layers_checkpointed(
         | 
| 198 | 
            -
                    self,
         | 
| 199 | 
            -
                    x: torch.Tensor,
         | 
| 200 | 
            -
                    tgt_mask: torch.Tensor,
         | 
| 201 | 
            -
                    memory: torch.Tensor,
         | 
| 202 | 
            -
                    memory_mask: torch.Tensor,
         | 
| 203 | 
            -
                ) -> torch.Tensor:
         | 
| 204 | 
            -
                    for layer in self.decoders:
         | 
| 205 | 
            -
                        x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
         | 
| 206 | 
            -
                            layer.__call__, x, tgt_mask, memory, memory_mask
         | 
| 207 | 
            -
                        )
         | 
| 208 | 
            -
                    return x
         | 
| 209 | 
            -
             | 
| 210 | 
            -
                def forward_one_step(
         | 
| 211 | 
            -
                    self,
         | 
| 212 | 
            -
                    memory: torch.Tensor,
         | 
| 213 | 
            -
                    memory_mask: torch.Tensor,
         | 
| 214 | 
            -
                    tgt: torch.Tensor,
         | 
| 215 | 
            -
                    tgt_mask: torch.Tensor,
         | 
| 216 | 
            -
                    cache: Optional[List[torch.Tensor]] = None,
         | 
| 217 | 
            -
                ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         | 
| 218 | 
            -
                    """Forward one step.
         | 
| 219 | 
            -
                        This is only used for decoding.
         | 
| 220 | 
            -
                    Args:
         | 
| 221 | 
            -
                        memory: encoded memory, float32  (batch, maxlen_in, feat)
         | 
| 222 | 
            -
                        memory_mask: encoded memory mask, (batch, 1, maxlen_in)
         | 
| 223 | 
            -
                        tgt: input token ids, int64 (batch, maxlen_out)
         | 
| 224 | 
            -
                        tgt_mask: input token mask,  (batch, maxlen_out)
         | 
| 225 | 
            -
                                  dtype=torch.uint8 in PyTorch 1.2-
         | 
| 226 | 
            -
                                  dtype=torch.bool in PyTorch 1.2+ (include 1.2)
         | 
| 227 | 
            -
                        cache: cached output list of (batch, max_time_out-1, size)
         | 
| 228 | 
            -
                    Returns:
         | 
| 229 | 
            -
                        y, cache: NN output value and cache per `self.decoders`.
         | 
| 230 | 
            -
                        y.shape` is (batch, maxlen_out, token)
         | 
| 231 | 
            -
                    """
         | 
| 232 | 
            -
                    x, _ = self.embed(tgt)
         | 
| 233 | 
            -
                    new_cache = []
         | 
| 234 | 
            -
                    for i, decoder in enumerate(self.decoders):
         | 
| 235 | 
            -
                        if cache is None:
         | 
| 236 | 
            -
                            c = None
         | 
| 237 | 
            -
                        else:
         | 
| 238 | 
            -
                            c = cache[i]
         | 
| 239 | 
            -
                        x, tgt_mask, memory, memory_mask = decoder(
         | 
| 240 | 
            -
                            x, tgt_mask, memory, memory_mask, cache=c
         | 
| 241 | 
            -
                        )
         | 
| 242 | 
            -
                        new_cache.append(x)
         | 
| 243 | 
            -
                    if self.normalize_before:
         | 
| 244 | 
            -
                        y = self.after_norm(x[:, -1])
         | 
| 245 | 
            -
                    else:
         | 
| 246 | 
            -
                        y = x[:, -1]
         | 
| 247 | 
            -
                    if self.use_output_layer:
         | 
| 248 | 
            -
                        y = torch.log_softmax(self.output_layer(y), dim=-1)
         | 
| 249 | 
            -
                    return y, new_cache
         | 
| 250 | 
            -
             | 
| 251 | 
            -
                def tie_or_clone_weights(self, jit_mode: bool = True):
         | 
| 252 | 
            -
                    """Tie or clone module weights (between word_emb and output_layer)
         | 
| 253 | 
            -
                    depending of whether we are using TorchScript or not"""
         | 
| 254 | 
            -
                    if not self.use_output_layer:
         | 
| 255 | 
            -
                        return
         | 
| 256 | 
            -
                    if jit_mode:
         | 
| 257 | 
            -
                        logging.info("clone emb.weight to output.weight")
         | 
| 258 | 
            -
                        self.output_layer.weight = torch.nn.Parameter(self.embed[0].weight.clone())
         | 
| 259 | 
            -
                    else:
         | 
| 260 | 
            -
                        logging.info("tie emb.weight with output.weight")
         | 
| 261 | 
            -
                        self.output_layer.weight = self.embed[0].weight
         | 
| 262 | 
            -
             | 
| 263 | 
            -
                    if getattr(self.output_layer, "bias", None) is not None:
         | 
| 264 | 
            -
                        self.output_layer.bias.data = torch.nn.functional.pad(
         | 
| 265 | 
            -
                            self.output_layer.bias.data,
         | 
| 266 | 
            -
                            (
         | 
| 267 | 
            -
                                0,
         | 
| 268 | 
            -
                                self.output_layer.weight.shape[0] - self.output_layer.bias.shape[0],
         | 
| 269 | 
            -
                            ),
         | 
| 270 | 
            -
                            "constant",
         | 
| 271 | 
            -
                            0,
         | 
| 272 | 
            -
                        )
         | 
| 273 | 
            -
             | 
| 274 | 
            -
             | 
| 275 | 
            -
            class BiTransformerDecoder(torch.nn.Module):
         | 
| 276 | 
            -
                """Base class of Transfomer decoder module.
         | 
| 277 | 
            -
                Args:
         | 
| 278 | 
            -
                    vocab_size: output dim
         | 
| 279 | 
            -
                    encoder_output_size: dimension of attention
         | 
| 280 | 
            -
                    attention_heads: the number of heads of multi head attention
         | 
| 281 | 
            -
                    linear_units: the hidden units number of position-wise feedforward
         | 
| 282 | 
            -
                    num_blocks: the number of decoder blocks
         | 
| 283 | 
            -
                    r_num_blocks: the number of right to left decoder blocks
         | 
| 284 | 
            -
                    dropout_rate: dropout rate
         | 
| 285 | 
            -
                    self_attention_dropout_rate: dropout rate for attention
         | 
| 286 | 
            -
                    input_layer: input layer type
         | 
| 287 | 
            -
                    use_output_layer: whether to use output layer
         | 
| 288 | 
            -
                    pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
         | 
| 289 | 
            -
                    normalize_before:
         | 
| 290 | 
            -
                        True: use layer_norm before each sub-block of a layer.
         | 
| 291 | 
            -
                        False: use layer_norm after each sub-block of a layer.
         | 
| 292 | 
            -
                    key_bias: whether use bias in attention.linear_k, False for whisper models.
         | 
| 293 | 
            -
                """
         | 
| 294 | 
            -
             | 
| 295 | 
            -
                def __init__(
         | 
| 296 | 
            -
                    self,
         | 
| 297 | 
            -
                    vocab_size: int,
         | 
| 298 | 
            -
                    encoder_output_size: int,
         | 
| 299 | 
            -
                    attention_heads: int = 4,
         | 
| 300 | 
            -
                    linear_units: int = 2048,
         | 
| 301 | 
            -
                    num_blocks: int = 6,
         | 
| 302 | 
            -
                    r_num_blocks: int = 0,
         | 
| 303 | 
            -
                    dropout_rate: float = 0.1,
         | 
| 304 | 
            -
                    positional_dropout_rate: float = 0.1,
         | 
| 305 | 
            -
                    self_attention_dropout_rate: float = 0.0,
         | 
| 306 | 
            -
                    src_attention_dropout_rate: float = 0.0,
         | 
| 307 | 
            -
                    input_layer: str = "embed",
         | 
| 308 | 
            -
                    use_output_layer: bool = True,
         | 
| 309 | 
            -
                    normalize_before: bool = True,
         | 
| 310 | 
            -
                    key_bias: bool = True,
         | 
| 311 | 
            -
                    gradient_checkpointing: bool = False,
         | 
| 312 | 
            -
                    tie_word_embedding: bool = False,
         | 
| 313 | 
            -
                ):
         | 
| 314 | 
            -
             | 
| 315 | 
            -
                    super().__init__()
         | 
| 316 | 
            -
                    self.tie_word_embedding = tie_word_embedding
         | 
| 317 | 
            -
                    self.left_decoder = TransformerDecoder(
         | 
| 318 | 
            -
                        vocab_size,
         | 
| 319 | 
            -
                        encoder_output_size,
         | 
| 320 | 
            -
                        attention_heads,
         | 
| 321 | 
            -
                        linear_units,
         | 
| 322 | 
            -
                        num_blocks,
         | 
| 323 | 
            -
                        dropout_rate,
         | 
| 324 | 
            -
                        positional_dropout_rate,
         | 
| 325 | 
            -
                        self_attention_dropout_rate,
         | 
| 326 | 
            -
                        src_attention_dropout_rate,
         | 
| 327 | 
            -
                        input_layer,
         | 
| 328 | 
            -
                        use_output_layer,
         | 
| 329 | 
            -
                        normalize_before,
         | 
| 330 | 
            -
                        key_bias=key_bias,
         | 
| 331 | 
            -
                        gradient_checkpointing=gradient_checkpointing,
         | 
| 332 | 
            -
                        tie_word_embedding=tie_word_embedding,
         | 
| 333 | 
            -
                    )
         | 
| 334 | 
            -
             | 
| 335 | 
            -
                    self.right_decoder = TransformerDecoder(
         | 
| 336 | 
            -
                        vocab_size,
         | 
| 337 | 
            -
                        encoder_output_size,
         | 
| 338 | 
            -
                        attention_heads,
         | 
| 339 | 
            -
                        linear_units,
         | 
| 340 | 
            -
                        r_num_blocks,
         | 
| 341 | 
            -
                        dropout_rate,
         | 
| 342 | 
            -
                        positional_dropout_rate,
         | 
| 343 | 
            -
                        self_attention_dropout_rate,
         | 
| 344 | 
            -
                        src_attention_dropout_rate,
         | 
| 345 | 
            -
                        input_layer,
         | 
| 346 | 
            -
                        use_output_layer,
         | 
| 347 | 
            -
                        normalize_before,
         | 
| 348 | 
            -
                        key_bias=key_bias,
         | 
| 349 | 
            -
                        gradient_checkpointing=gradient_checkpointing,
         | 
| 350 | 
            -
                        tie_word_embedding=tie_word_embedding,
         | 
| 351 | 
            -
                    )
         | 
| 352 | 
            -
             | 
| 353 | 
            -
                def forward(
         | 
| 354 | 
            -
                    self,
         | 
| 355 | 
            -
                    memory: torch.Tensor,
         | 
| 356 | 
            -
                    memory_mask: torch.Tensor,
         | 
| 357 | 
            -
                    ys_in_pad: torch.Tensor,
         | 
| 358 | 
            -
                    ys_in_lens: torch.Tensor,
         | 
| 359 | 
            -
                    r_ys_in_pad: torch.Tensor,
         | 
| 360 | 
            -
                    reverse_weight: float = 0.0,
         | 
| 361 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 362 | 
            -
                    """Forward decoder.
         | 
| 363 | 
            -
                    Args:
         | 
| 364 | 
            -
                        memory: encoded memory, float32  (batch, maxlen_in, feat)
         | 
| 365 | 
            -
                        memory_mask: encoder memory mask, (batch, 1, maxlen_in)
         | 
| 366 | 
            -
                        ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
         | 
| 367 | 
            -
                        ys_in_lens: input lengths of this batch (batch)
         | 
| 368 | 
            -
                        r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
         | 
| 369 | 
            -
                            used for right to left decoder
         | 
| 370 | 
            -
                        reverse_weight: used for right to left decoder
         | 
| 371 | 
            -
                    Returns:
         | 
| 372 | 
            -
                        (tuple): tuple containing:
         | 
| 373 | 
            -
                            x: decoded token score before softmax (batch, maxlen_out,
         | 
| 374 | 
            -
                                vocab_size) if use_output_layer is True,
         | 
| 375 | 
            -
                            r_x: x: decoded token score (right to left decoder)
         | 
| 376 | 
            -
                                before softmax (batch, maxlen_out, vocab_size)
         | 
| 377 | 
            -
                                if use_output_layer is True,
         | 
| 378 | 
            -
                            olens: (batch, )
         | 
| 379 | 
            -
                    """
         | 
| 380 | 
            -
                    l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
         | 
| 381 | 
            -
                    r_x = torch.tensor(0.0)
         | 
| 382 | 
            -
                    if reverse_weight > 0.0:
         | 
| 383 | 
            -
                        r_x, _, olens = self.right_decoder(
         | 
| 384 | 
            -
                            memory, memory_mask, r_ys_in_pad, ys_in_lens
         | 
| 385 | 
            -
                        )
         | 
| 386 | 
            -
                    return l_x, r_x, olens
         | 
| 387 | 
            -
             | 
| 388 | 
            -
                def forward_one_step(
         | 
| 389 | 
            -
                    self,
         | 
| 390 | 
            -
                    memory: torch.Tensor,
         | 
| 391 | 
            -
                    memory_mask: torch.Tensor,
         | 
| 392 | 
            -
                    tgt: torch.Tensor,
         | 
| 393 | 
            -
                    tgt_mask: torch.Tensor,
         | 
| 394 | 
            -
                    cache: Optional[List[torch.Tensor]] = None,
         | 
| 395 | 
            -
                ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         | 
| 396 | 
            -
                    """Forward one step.
         | 
| 397 | 
            -
                        This is only used for decoding.
         | 
| 398 | 
            -
                    Args:
         | 
| 399 | 
            -
                        memory: encoded memory, float32  (batch, maxlen_in, feat)
         | 
| 400 | 
            -
                        memory_mask: encoded memory mask, (batch, 1, maxlen_in)
         | 
| 401 | 
            -
                        tgt: input token ids, int64 (batch, maxlen_out)
         | 
| 402 | 
            -
                        tgt_mask: input token mask,  (batch, maxlen_out)
         | 
| 403 | 
            -
                                  dtype=torch.uint8 in PyTorch 1.2-
         | 
| 404 | 
            -
                                  dtype=torch.bool in PyTorch 1.2+ (include 1.2)
         | 
| 405 | 
            -
                        cache: cached output list of (batch, max_time_out-1, size)
         | 
| 406 | 
            -
                    Returns:
         | 
| 407 | 
            -
                        y, cache: NN output value and cache per `self.decoders`.
         | 
| 408 | 
            -
                        y.shape` is (batch, maxlen_out, token)
         | 
| 409 | 
            -
                    """
         | 
| 410 | 
            -
                    return self.left_decoder.forward_one_step(
         | 
| 411 | 
            -
                        memory, memory_mask, tgt, tgt_mask, cache
         | 
| 412 | 
            -
                    )
         | 
| 413 | 
            -
             | 
| 414 | 
            -
                def tie_or_clone_weights(self, jit_mode: bool = True):
         | 
| 415 | 
            -
                    """Tie or clone module weights (between word_emb and output_layer)
         | 
| 416 | 
            -
                    depending of whether we are using TorchScript or not"""
         | 
| 417 | 
            -
                    self.left_decoder.tie_or_clone_weights(jit_mode)
         | 
| 418 | 
            -
                    self.right_decoder.tie_or_clone_weights(jit_mode)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/decoder_layer.py
    DELETED
    
    | @@ -1,132 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2019 Shigeki Karita
         | 
| 2 | 
            -
            #               2020 Mobvoi Inc (Binbin Zhang)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            """Decoder self-attention layer definition."""
         | 
| 16 | 
            -
            from typing import Optional, Tuple
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            import torch
         | 
| 19 | 
            -
            from torch import nn
         | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
            class DecoderLayer(nn.Module):
         | 
| 23 | 
            -
                """Single decoder layer module.
         | 
| 24 | 
            -
             | 
| 25 | 
            -
                Args:
         | 
| 26 | 
            -
                    size (int): Input dimension.
         | 
| 27 | 
            -
                    self_attn (torch.nn.Module): Self-attention module instance.
         | 
| 28 | 
            -
                        `MultiHeadedAttention` instance can be used as the argument.
         | 
| 29 | 
            -
                    src_attn (torch.nn.Module): Inter-attention module instance.
         | 
| 30 | 
            -
                        `MultiHeadedAttention` instance can be used as the argument.
         | 
| 31 | 
            -
                        If `None` is passed, Inter-attention is not used, such as
         | 
| 32 | 
            -
                        CIF, GPT, and other decoder only model.
         | 
| 33 | 
            -
                    feed_forward (torch.nn.Module): Feed-forward module instance.
         | 
| 34 | 
            -
                        `PositionwiseFeedForward` instance can be used as the argument.
         | 
| 35 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 36 | 
            -
                    normalize_before (bool):
         | 
| 37 | 
            -
                        True: use layer_norm before each sub-block.
         | 
| 38 | 
            -
                        False: to use layer_norm after each sub-block.
         | 
| 39 | 
            -
                """
         | 
| 40 | 
            -
             | 
| 41 | 
            -
                def __init__(
         | 
| 42 | 
            -
                    self,
         | 
| 43 | 
            -
                    size: int,
         | 
| 44 | 
            -
                    self_attn: nn.Module,
         | 
| 45 | 
            -
                    src_attn: Optional[nn.Module],
         | 
| 46 | 
            -
                    feed_forward: nn.Module,
         | 
| 47 | 
            -
                    dropout_rate: float,
         | 
| 48 | 
            -
                    normalize_before: bool = True,
         | 
| 49 | 
            -
                ):
         | 
| 50 | 
            -
                    """Construct an DecoderLayer object."""
         | 
| 51 | 
            -
                    super().__init__()
         | 
| 52 | 
            -
                    self.size = size
         | 
| 53 | 
            -
                    self.self_attn = self_attn
         | 
| 54 | 
            -
                    self.src_attn = src_attn
         | 
| 55 | 
            -
                    self.feed_forward = feed_forward
         | 
| 56 | 
            -
                    self.norm1 = nn.LayerNorm(size, eps=1e-5)
         | 
| 57 | 
            -
                    self.norm2 = nn.LayerNorm(size, eps=1e-5)
         | 
| 58 | 
            -
                    self.norm3 = nn.LayerNorm(size, eps=1e-5)
         | 
| 59 | 
            -
                    self.dropout = nn.Dropout(dropout_rate)
         | 
| 60 | 
            -
                    self.normalize_before = normalize_before
         | 
| 61 | 
            -
             | 
| 62 | 
            -
                def forward(
         | 
| 63 | 
            -
                    self,
         | 
| 64 | 
            -
                    tgt: torch.Tensor,
         | 
| 65 | 
            -
                    tgt_mask: torch.Tensor,
         | 
| 66 | 
            -
                    memory: torch.Tensor,
         | 
| 67 | 
            -
                    memory_mask: torch.Tensor,
         | 
| 68 | 
            -
                    cache: Optional[torch.Tensor] = None,
         | 
| 69 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 70 | 
            -
                    """Compute decoded features.
         | 
| 71 | 
            -
             | 
| 72 | 
            -
                    Args:
         | 
| 73 | 
            -
                        tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
         | 
| 74 | 
            -
                        tgt_mask (torch.Tensor): Mask for input tensor
         | 
| 75 | 
            -
                            (#batch, maxlen_out).
         | 
| 76 | 
            -
                        memory (torch.Tensor): Encoded memory
         | 
| 77 | 
            -
                            (#batch, maxlen_in, size).
         | 
| 78 | 
            -
                        memory_mask (torch.Tensor): Encoded memory mask
         | 
| 79 | 
            -
                            (#batch, maxlen_in).
         | 
| 80 | 
            -
                        cache (torch.Tensor): cached tensors.
         | 
| 81 | 
            -
                            (#batch, maxlen_out - 1, size).
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                    Returns:
         | 
| 84 | 
            -
                        torch.Tensor: Output tensor (#batch, maxlen_out, size).
         | 
| 85 | 
            -
                        torch.Tensor: Mask for output tensor (#batch, maxlen_out).
         | 
| 86 | 
            -
                        torch.Tensor: Encoded memory (#batch, maxlen_in, size).
         | 
| 87 | 
            -
                        torch.Tensor: Encoded memory mask (#batch, maxlen_in).
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                    """
         | 
| 90 | 
            -
                    residual = tgt
         | 
| 91 | 
            -
                    if self.normalize_before:
         | 
| 92 | 
            -
                        tgt = self.norm1(tgt)
         | 
| 93 | 
            -
             | 
| 94 | 
            -
                    if cache is None:
         | 
| 95 | 
            -
                        tgt_q = tgt
         | 
| 96 | 
            -
                        tgt_q_mask = tgt_mask
         | 
| 97 | 
            -
                    else:
         | 
| 98 | 
            -
                        # compute only the last frame query keeping dim: max_time_out -> 1
         | 
| 99 | 
            -
                        assert cache.shape == (
         | 
| 100 | 
            -
                            tgt.shape[0],
         | 
| 101 | 
            -
                            tgt.shape[1] - 1,
         | 
| 102 | 
            -
                            self.size,
         | 
| 103 | 
            -
                        ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
         | 
| 104 | 
            -
                        tgt_q = tgt[:, -1:, :]
         | 
| 105 | 
            -
                        residual = residual[:, -1:, :]
         | 
| 106 | 
            -
                        tgt_q_mask = tgt_mask[:, -1:, :]
         | 
| 107 | 
            -
             | 
| 108 | 
            -
                    x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
         | 
| 109 | 
            -
                    if not self.normalize_before:
         | 
| 110 | 
            -
                        x = self.norm1(x)
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                    if self.src_attn is not None:
         | 
| 113 | 
            -
                        residual = x
         | 
| 114 | 
            -
                        if self.normalize_before:
         | 
| 115 | 
            -
                            x = self.norm2(x)
         | 
| 116 | 
            -
                        x = residual + self.dropout(
         | 
| 117 | 
            -
                            self.src_attn(x, memory, memory, memory_mask)[0]
         | 
| 118 | 
            -
                        )
         | 
| 119 | 
            -
                        if not self.normalize_before:
         | 
| 120 | 
            -
                            x = self.norm2(x)
         | 
| 121 | 
            -
             | 
| 122 | 
            -
                    residual = x
         | 
| 123 | 
            -
                    if self.normalize_before:
         | 
| 124 | 
            -
                        x = self.norm3(x)
         | 
| 125 | 
            -
                    x = residual + self.dropout(self.feed_forward(x))
         | 
| 126 | 
            -
                    if not self.normalize_before:
         | 
| 127 | 
            -
                        x = self.norm3(x)
         | 
| 128 | 
            -
             | 
| 129 | 
            -
                    if cache is not None:
         | 
| 130 | 
            -
                        x = torch.cat([cache, x], dim=1)
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                    return x, tgt_mask, memory, memory_mask
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/embedding.py
    DELETED
    
    | @@ -1,293 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
         | 
| 2 | 
            -
            #               2024 Alibaba Inc (Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 16 | 
            -
            """Positonal Encoding Module."""
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            import math
         | 
| 19 | 
            -
            from typing import Tuple, Union
         | 
| 20 | 
            -
             | 
| 21 | 
            -
            import torch
         | 
| 22 | 
            -
            import torch.nn.functional as F
         | 
| 23 | 
            -
            import numpy as np
         | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
            class PositionalEncoding(torch.nn.Module):
         | 
| 27 | 
            -
                """Positional encoding.
         | 
| 28 | 
            -
             | 
| 29 | 
            -
                :param int d_model: embedding dim
         | 
| 30 | 
            -
                :param float dropout_rate: dropout rate
         | 
| 31 | 
            -
                :param int max_len: maximum input length
         | 
| 32 | 
            -
             | 
| 33 | 
            -
                PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
         | 
| 34 | 
            -
                PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
         | 
| 35 | 
            -
                """
         | 
| 36 | 
            -
             | 
| 37 | 
            -
                def __init__(
         | 
| 38 | 
            -
                    self,
         | 
| 39 | 
            -
                    d_model: int,
         | 
| 40 | 
            -
                    dropout_rate: float,
         | 
| 41 | 
            -
                    max_len: int = 5000,
         | 
| 42 | 
            -
                    reverse: bool = False,
         | 
| 43 | 
            -
                ):
         | 
| 44 | 
            -
                    """Construct an PositionalEncoding object."""
         | 
| 45 | 
            -
                    super().__init__()
         | 
| 46 | 
            -
                    self.d_model = d_model
         | 
| 47 | 
            -
                    self.xscale = math.sqrt(self.d_model)
         | 
| 48 | 
            -
                    self.dropout = torch.nn.Dropout(p=dropout_rate)
         | 
| 49 | 
            -
                    self.max_len = max_len
         | 
| 50 | 
            -
             | 
| 51 | 
            -
                    self.pe = torch.zeros(self.max_len, self.d_model)
         | 
| 52 | 
            -
                    position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
         | 
| 53 | 
            -
                    div_term = torch.exp(
         | 
| 54 | 
            -
                        torch.arange(0, self.d_model, 2, dtype=torch.float32)
         | 
| 55 | 
            -
                        * -(math.log(10000.0) / self.d_model)
         | 
| 56 | 
            -
                    )
         | 
| 57 | 
            -
                    self.pe[:, 0::2] = torch.sin(position * div_term)
         | 
| 58 | 
            -
                    self.pe[:, 1::2] = torch.cos(position * div_term)
         | 
| 59 | 
            -
                    self.pe = self.pe.unsqueeze(0)
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                def forward(
         | 
| 62 | 
            -
                    self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
         | 
| 63 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 64 | 
            -
                    """Add positional encoding.
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                    Args:
         | 
| 67 | 
            -
                        x (torch.Tensor): Input. Its shape is (batch, time, ...)
         | 
| 68 | 
            -
                        offset (int, torch.tensor): position offset
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                    Returns:
         | 
| 71 | 
            -
                        torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
         | 
| 72 | 
            -
                        torch.Tensor: for compatibility to RelPositionalEncoding
         | 
| 73 | 
            -
                    """
         | 
| 74 | 
            -
             | 
| 75 | 
            -
                    self.pe = self.pe.to(x.device)
         | 
| 76 | 
            -
                    pos_emb = self.position_encoding(offset, x.size(1), False)
         | 
| 77 | 
            -
                    x = x * self.xscale + pos_emb
         | 
| 78 | 
            -
                    return self.dropout(x), self.dropout(pos_emb)
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                def position_encoding(
         | 
| 81 | 
            -
                    self, offset: Union[int, torch.Tensor], size: int, apply_dropout: bool = True
         | 
| 82 | 
            -
                ) -> torch.Tensor:
         | 
| 83 | 
            -
                    """For getting encoding in a streaming fashion
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                    Attention!!!!!
         | 
| 86 | 
            -
                    we apply dropout only once at the whole utterance level in a none
         | 
| 87 | 
            -
                    streaming way, but will call this function several times with
         | 
| 88 | 
            -
                    increasing input size in a streaming scenario, so the dropout will
         | 
| 89 | 
            -
                    be applied several times.
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                    Args:
         | 
| 92 | 
            -
                        offset (int or torch.tensor): start offset
         | 
| 93 | 
            -
                        size (int): required size of position encoding
         | 
| 94 | 
            -
             | 
| 95 | 
            -
                    Returns:
         | 
| 96 | 
            -
                        torch.Tensor: Corresponding encoding
         | 
| 97 | 
            -
                    """
         | 
| 98 | 
            -
                    # How to subscript a Union type:
         | 
| 99 | 
            -
                    #   https://github.com/pytorch/pytorch/issues/69434
         | 
| 100 | 
            -
                    if isinstance(offset, int):
         | 
| 101 | 
            -
                        assert offset + size <= self.max_len
         | 
| 102 | 
            -
                        pos_emb = self.pe[:, offset : offset + size]
         | 
| 103 | 
            -
                    elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
         | 
| 104 | 
            -
                        assert offset + size <= self.max_len
         | 
| 105 | 
            -
                        pos_emb = self.pe[:, offset : offset + size]
         | 
| 106 | 
            -
                    else:  # for batched streaming decoding on GPU
         | 
| 107 | 
            -
                        assert torch.max(offset) + size <= self.max_len
         | 
| 108 | 
            -
                        index = offset.unsqueeze(1) + torch.arange(0, size).to(
         | 
| 109 | 
            -
                            offset.device
         | 
| 110 | 
            -
                        )  # B X T
         | 
| 111 | 
            -
                        flag = index > 0
         | 
| 112 | 
            -
                        # remove negative offset
         | 
| 113 | 
            -
                        index = index * flag
         | 
| 114 | 
            -
                        pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model
         | 
| 115 | 
            -
             | 
| 116 | 
            -
                    if apply_dropout:
         | 
| 117 | 
            -
                        pos_emb = self.dropout(pos_emb)
         | 
| 118 | 
            -
                    return pos_emb
         | 
| 119 | 
            -
             | 
| 120 | 
            -
             | 
| 121 | 
            -
            class RelPositionalEncoding(PositionalEncoding):
         | 
| 122 | 
            -
                """Relative positional encoding module.
         | 
| 123 | 
            -
                See : Appendix B in https://arxiv.org/abs/1901.02860
         | 
| 124 | 
            -
                Args:
         | 
| 125 | 
            -
                    d_model (int): Embedding dimension.
         | 
| 126 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 127 | 
            -
                    max_len (int): Maximum input length.
         | 
| 128 | 
            -
                """
         | 
| 129 | 
            -
             | 
| 130 | 
            -
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
         | 
| 131 | 
            -
                    """Initialize class."""
         | 
| 132 | 
            -
                    super().__init__(d_model, dropout_rate, max_len, reverse=True)
         | 
| 133 | 
            -
             | 
| 134 | 
            -
                def forward(
         | 
| 135 | 
            -
                    self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
         | 
| 136 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 137 | 
            -
                    """Compute positional encoding.
         | 
| 138 | 
            -
                    Args:
         | 
| 139 | 
            -
                        x (torch.Tensor): Input tensor (batch, time, `*`).
         | 
| 140 | 
            -
                    Returns:
         | 
| 141 | 
            -
                        torch.Tensor: Encoded tensor (batch, time, `*`).
         | 
| 142 | 
            -
                        torch.Tensor: Positional embedding tensor (1, time, `*`).
         | 
| 143 | 
            -
                    """
         | 
| 144 | 
            -
                    self.pe = self.pe.to(x.device)
         | 
| 145 | 
            -
                    x = x * self.xscale
         | 
| 146 | 
            -
                    pos_emb = self.position_encoding(offset, x.size(1), False)
         | 
| 147 | 
            -
                    return self.dropout(x), self.dropout(pos_emb)
         | 
| 148 | 
            -
             | 
| 149 | 
            -
             | 
| 150 | 
            -
            class WhisperPositionalEncoding(PositionalEncoding):
         | 
| 151 | 
            -
                """Sinusoids position encoding used in openai-whisper.encoder"""
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
         | 
| 154 | 
            -
                    super().__init__(d_model, dropout_rate, max_len)
         | 
| 155 | 
            -
                    self.xscale = 1.0
         | 
| 156 | 
            -
                    log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
         | 
| 157 | 
            -
                    inv_timescales = torch.exp(
         | 
| 158 | 
            -
                        -log_timescale_increment * torch.arange(d_model // 2)
         | 
| 159 | 
            -
                    )
         | 
| 160 | 
            -
                    scaled_time = (
         | 
| 161 | 
            -
                        torch.arange(max_len)[:, np.newaxis] * inv_timescales[np.newaxis, :]
         | 
| 162 | 
            -
                    )
         | 
| 163 | 
            -
                    pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
         | 
| 164 | 
            -
                    delattr(self, "pe")
         | 
| 165 | 
            -
                    self.register_buffer("pe", pe.unsqueeze(0))
         | 
| 166 | 
            -
             | 
| 167 | 
            -
             | 
| 168 | 
            -
            class LearnablePositionalEncoding(PositionalEncoding):
         | 
| 169 | 
            -
                """Learnable position encoding used in openai-whisper.decoder"""
         | 
| 170 | 
            -
             | 
| 171 | 
            -
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
         | 
| 172 | 
            -
                    super().__init__(d_model, dropout_rate, max_len)
         | 
| 173 | 
            -
                    # NOTE(xcsong): overwrite self.pe & self.xscale
         | 
| 174 | 
            -
                    self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
         | 
| 175 | 
            -
                    self.xscale = 1.0
         | 
| 176 | 
            -
             | 
| 177 | 
            -
             | 
| 178 | 
            -
            class NoPositionalEncoding(torch.nn.Module):
         | 
| 179 | 
            -
                """No position encoding"""
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                def __init__(self, d_model: int, dropout_rate: float):
         | 
| 182 | 
            -
                    super().__init__()
         | 
| 183 | 
            -
                    self.d_model = d_model
         | 
| 184 | 
            -
                    self.dropout = torch.nn.Dropout(p=dropout_rate)
         | 
| 185 | 
            -
             | 
| 186 | 
            -
                def forward(
         | 
| 187 | 
            -
                    self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
         | 
| 188 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 189 | 
            -
                    """Just return zero vector for interface compatibility"""
         | 
| 190 | 
            -
                    pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
         | 
| 191 | 
            -
                    return self.dropout(x), pos_emb
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                def position_encoding(
         | 
| 194 | 
            -
                    self, offset: Union[int, torch.Tensor], size: int
         | 
| 195 | 
            -
                ) -> torch.Tensor:
         | 
| 196 | 
            -
                    return torch.zeros(1, size, self.d_model)
         | 
| 197 | 
            -
             | 
| 198 | 
            -
             | 
| 199 | 
            -
            class EspnetRelPositionalEncoding(torch.nn.Module):
         | 
| 200 | 
            -
                """Relative positional encoding module (new implementation).
         | 
| 201 | 
            -
             | 
| 202 | 
            -
                Details can be found in https://github.com/espnet/espnet/pull/2816.
         | 
| 203 | 
            -
             | 
| 204 | 
            -
                See : Appendix B in https://arxiv.org/abs/1901.02860
         | 
| 205 | 
            -
             | 
| 206 | 
            -
                Args:
         | 
| 207 | 
            -
                    d_model (int): Embedding dimension.
         | 
| 208 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 209 | 
            -
                    max_len (int): Maximum input length.
         | 
| 210 | 
            -
             | 
| 211 | 
            -
                """
         | 
| 212 | 
            -
             | 
| 213 | 
            -
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
         | 
| 214 | 
            -
                    """Construct an PositionalEncoding object."""
         | 
| 215 | 
            -
                    super(EspnetRelPositionalEncoding, self).__init__()
         | 
| 216 | 
            -
                    self.d_model = d_model
         | 
| 217 | 
            -
                    self.xscale = math.sqrt(self.d_model)
         | 
| 218 | 
            -
                    self.dropout = torch.nn.Dropout(p=dropout_rate)
         | 
| 219 | 
            -
                    self.pe = None
         | 
| 220 | 
            -
                    self.extend_pe(torch.tensor(0.0).expand(1, max_len))
         | 
| 221 | 
            -
             | 
| 222 | 
            -
                def extend_pe(self, x: torch.Tensor):
         | 
| 223 | 
            -
                    """Reset the positional encodings."""
         | 
| 224 | 
            -
                    if self.pe is not None:
         | 
| 225 | 
            -
                        # self.pe contains both positive and negative parts
         | 
| 226 | 
            -
                        # the length of self.pe is 2 * input_len - 1
         | 
| 227 | 
            -
                        if self.pe.size(1) >= x.size(1) * 2 - 1:
         | 
| 228 | 
            -
                            if self.pe.dtype != x.dtype or self.pe.device != x.device:
         | 
| 229 | 
            -
                                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
         | 
| 230 | 
            -
                            return
         | 
| 231 | 
            -
                    # Suppose `i` means to the position of query vecotr and `j` means the
         | 
| 232 | 
            -
                    # position of key vector. We use position relative positions when keys
         | 
| 233 | 
            -
                    # are to the left (i>j) and negative relative positions otherwise (i<j).
         | 
| 234 | 
            -
                    pe_positive = torch.zeros(x.size(1), self.d_model)
         | 
| 235 | 
            -
                    pe_negative = torch.zeros(x.size(1), self.d_model)
         | 
| 236 | 
            -
                    position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
         | 
| 237 | 
            -
                    div_term = torch.exp(
         | 
| 238 | 
            -
                        torch.arange(0, self.d_model, 2, dtype=torch.float32)
         | 
| 239 | 
            -
                        * -(math.log(10000.0) / self.d_model)
         | 
| 240 | 
            -
                    )
         | 
| 241 | 
            -
                    pe_positive[:, 0::2] = torch.sin(position * div_term)
         | 
| 242 | 
            -
                    pe_positive[:, 1::2] = torch.cos(position * div_term)
         | 
| 243 | 
            -
                    pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
         | 
| 244 | 
            -
                    pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
         | 
| 245 | 
            -
             | 
| 246 | 
            -
                    # Reserve the order of positive indices and concat both positive and
         | 
| 247 | 
            -
                    # negative indices. This is used to support the shifting trick
         | 
| 248 | 
            -
                    # as in https://arxiv.org/abs/1901.02860
         | 
| 249 | 
            -
                    pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         | 
| 250 | 
            -
                    pe_negative = pe_negative[1:].unsqueeze(0)
         | 
| 251 | 
            -
                    pe = torch.cat([pe_positive, pe_negative], dim=1)
         | 
| 252 | 
            -
                    self.pe = pe.to(device=x.device, dtype=x.dtype)
         | 
| 253 | 
            -
             | 
| 254 | 
            -
                def forward(
         | 
| 255 | 
            -
                    self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
         | 
| 256 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 257 | 
            -
                    """Add positional encoding.
         | 
| 258 | 
            -
             | 
| 259 | 
            -
                    Args:
         | 
| 260 | 
            -
                        x (torch.Tensor): Input tensor (batch, time, `*`).
         | 
| 261 | 
            -
             | 
| 262 | 
            -
                    Returns:
         | 
| 263 | 
            -
                        torch.Tensor: Encoded tensor (batch, time, `*`).
         | 
| 264 | 
            -
             | 
| 265 | 
            -
                    """
         | 
| 266 | 
            -
                    self.extend_pe(x)
         | 
| 267 | 
            -
                    x = x * self.xscale
         | 
| 268 | 
            -
                    pos_emb = self.position_encoding(size=x.size(1), offset=offset)
         | 
| 269 | 
            -
                    return self.dropout(x), self.dropout(pos_emb)
         | 
| 270 | 
            -
             | 
| 271 | 
            -
                def position_encoding(
         | 
| 272 | 
            -
                    self, offset: Union[int, torch.Tensor], size: int
         | 
| 273 | 
            -
                ) -> torch.Tensor:
         | 
| 274 | 
            -
                    """For getting encoding in a streaming fashion
         | 
| 275 | 
            -
             | 
| 276 | 
            -
                    Attention!!!!!
         | 
| 277 | 
            -
                    we apply dropout only once at the whole utterance level in a none
         | 
| 278 | 
            -
                    streaming way, but will call this function several times with
         | 
| 279 | 
            -
                    increasing input size in a streaming scenario, so the dropout will
         | 
| 280 | 
            -
                    be applied several times.
         | 
| 281 | 
            -
             | 
| 282 | 
            -
                    Args:
         | 
| 283 | 
            -
                        offset (int or torch.tensor): start offset
         | 
| 284 | 
            -
                        size (int): required size of position encoding
         | 
| 285 | 
            -
             | 
| 286 | 
            -
                    Returns:
         | 
| 287 | 
            -
                        torch.Tensor: Corresponding encoding
         | 
| 288 | 
            -
                    """
         | 
| 289 | 
            -
                    pos_emb = self.pe[
         | 
| 290 | 
            -
                        :,
         | 
| 291 | 
            -
                        self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
         | 
| 292 | 
            -
                    ]
         | 
| 293 | 
            -
                    return pos_emb
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/encoder.py
    DELETED
    
    | @@ -1,633 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
         | 
| 2 | 
            -
            #               2022 Xingchen Song ([email protected])
         | 
| 3 | 
            -
            #               2024 Alibaba Inc (Xiang Lyu)
         | 
| 4 | 
            -
            #
         | 
| 5 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            -
            # You may obtain a copy of the License at
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            -
            #
         | 
| 11 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            -
            # limitations under the License.
         | 
| 16 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 17 | 
            -
            """Encoder definition."""
         | 
| 18 | 
            -
            from typing import Tuple
         | 
| 19 | 
            -
            import time
         | 
| 20 | 
            -
             | 
| 21 | 
            -
            import torch
         | 
| 22 | 
            -
            import torch.utils.checkpoint as ckpt
         | 
| 23 | 
            -
            import torch.nn.functional as F
         | 
| 24 | 
            -
             | 
| 25 | 
            -
            from cosyvoice.transformer.convolution import ConvolutionModule
         | 
| 26 | 
            -
            from cosyvoice.transformer.encoder_layer import (
         | 
| 27 | 
            -
                TransformerEncoderLayer,
         | 
| 28 | 
            -
            )
         | 
| 29 | 
            -
            from cosyvoice.transformer.encoder_layer import (
         | 
| 30 | 
            -
                ConformerEncoderLayer,
         | 
| 31 | 
            -
            )
         | 
| 32 | 
            -
            from cosyvoice.transformer.positionwise_feed_forward import (
         | 
| 33 | 
            -
                PositionwiseFeedForward,
         | 
| 34 | 
            -
            )
         | 
| 35 | 
            -
            from cosyvoice.utils.class_utils import (
         | 
| 36 | 
            -
                COSYVOICE_EMB_CLASSES,
         | 
| 37 | 
            -
                COSYVOICE_SUBSAMPLE_CLASSES,
         | 
| 38 | 
            -
                COSYVOICE_ATTENTION_CLASSES,
         | 
| 39 | 
            -
                COSYVOICE_ACTIVATION_CLASSES,
         | 
| 40 | 
            -
            )
         | 
| 41 | 
            -
            from cosyvoice.utils.mask import make_pad_mask
         | 
| 42 | 
            -
            from cosyvoice.utils.mask import add_optional_chunk_mask
         | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
            class BaseEncoder(torch.nn.Module):
         | 
| 46 | 
            -
             | 
| 47 | 
            -
                def __init__(
         | 
| 48 | 
            -
                    self,
         | 
| 49 | 
            -
                    input_size: int,
         | 
| 50 | 
            -
                    output_size: int = 256,
         | 
| 51 | 
            -
                    attention_heads: int = 4,
         | 
| 52 | 
            -
                    linear_units: int = 2048,
         | 
| 53 | 
            -
                    num_blocks: int = 6,
         | 
| 54 | 
            -
                    dropout_rate: float = 0.1,
         | 
| 55 | 
            -
                    positional_dropout_rate: float = 0.1,
         | 
| 56 | 
            -
                    attention_dropout_rate: float = 0.0,
         | 
| 57 | 
            -
                    input_layer: str = "conv2d",
         | 
| 58 | 
            -
                    pos_enc_layer_type: str = "abs_pos",
         | 
| 59 | 
            -
                    normalize_before: bool = True,
         | 
| 60 | 
            -
                    static_chunk_size: int = 0,
         | 
| 61 | 
            -
                    use_dynamic_chunk: bool = False,
         | 
| 62 | 
            -
                    global_cmvn: torch.nn.Module = None,
         | 
| 63 | 
            -
                    use_dynamic_left_chunk: bool = False,
         | 
| 64 | 
            -
                    gradient_checkpointing: bool = False,
         | 
| 65 | 
            -
                ):
         | 
| 66 | 
            -
                    """
         | 
| 67 | 
            -
                    Args:
         | 
| 68 | 
            -
                        input_size (int): input dim
         | 
| 69 | 
            -
                        output_size (int): dimension of attention
         | 
| 70 | 
            -
                        attention_heads (int): the number of heads of multi head attention
         | 
| 71 | 
            -
                        linear_units (int): the hidden units number of position-wise feed
         | 
| 72 | 
            -
                            forward
         | 
| 73 | 
            -
                        num_blocks (int): the number of decoder blocks
         | 
| 74 | 
            -
                        dropout_rate (float): dropout rate
         | 
| 75 | 
            -
                        attention_dropout_rate (float): dropout rate in attention
         | 
| 76 | 
            -
                        positional_dropout_rate (float): dropout rate after adding
         | 
| 77 | 
            -
                            positional encoding
         | 
| 78 | 
            -
                        input_layer (str): input layer type.
         | 
| 79 | 
            -
                            optional [linear, conv2d, conv2d6, conv2d8]
         | 
| 80 | 
            -
                        pos_enc_layer_type (str): Encoder positional encoding layer type.
         | 
| 81 | 
            -
                            opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
         | 
| 82 | 
            -
                        normalize_before (bool):
         | 
| 83 | 
            -
                            True: use layer_norm before each sub-block of a layer.
         | 
| 84 | 
            -
                            False: use layer_norm after each sub-block of a layer.
         | 
| 85 | 
            -
                        static_chunk_size (int): chunk size for static chunk training and
         | 
| 86 | 
            -
                            decoding
         | 
| 87 | 
            -
                        use_dynamic_chunk (bool): whether use dynamic chunk size for
         | 
| 88 | 
            -
                            training or not, You can only use fixed chunk(chunk_size > 0)
         | 
| 89 | 
            -
                            or dyanmic chunk size(use_dynamic_chunk = True)
         | 
| 90 | 
            -
                        global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
         | 
| 91 | 
            -
                        use_dynamic_left_chunk (bool): whether use dynamic left chunk in
         | 
| 92 | 
            -
                            dynamic chunk training
         | 
| 93 | 
            -
                        key_bias: whether use bias in attention.linear_k, False for whisper models.
         | 
| 94 | 
            -
                        gradient_checkpointing: rerunning a forward-pass segment for each
         | 
| 95 | 
            -
                            checkpointed segment during backward.
         | 
| 96 | 
            -
                    """
         | 
| 97 | 
            -
                    super().__init__()
         | 
| 98 | 
            -
                    self._output_size = output_size
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                    self.global_cmvn = global_cmvn
         | 
| 101 | 
            -
                    self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
         | 
| 102 | 
            -
                        input_size,
         | 
| 103 | 
            -
                        output_size,
         | 
| 104 | 
            -
                        dropout_rate,
         | 
| 105 | 
            -
                        COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
         | 
| 106 | 
            -
                            output_size, positional_dropout_rate
         | 
| 107 | 
            -
                        ),
         | 
| 108 | 
            -
                    )
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                    self.normalize_before = normalize_before
         | 
| 111 | 
            -
                    self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
         | 
| 112 | 
            -
                    self.static_chunk_size = static_chunk_size
         | 
| 113 | 
            -
                    self.use_dynamic_chunk = use_dynamic_chunk
         | 
| 114 | 
            -
                    self.use_dynamic_left_chunk = use_dynamic_left_chunk
         | 
| 115 | 
            -
                    self.gradient_checkpointing = gradient_checkpointing
         | 
| 116 | 
            -
             | 
| 117 | 
            -
                def output_size(self) -> int:
         | 
| 118 | 
            -
                    return self._output_size
         | 
| 119 | 
            -
             | 
| 120 | 
            -
                def forward(
         | 
| 121 | 
            -
                    self,
         | 
| 122 | 
            -
                    xs: torch.Tensor,
         | 
| 123 | 
            -
                    xs_lens: torch.Tensor,
         | 
| 124 | 
            -
                    decoding_chunk_size: int = 0,
         | 
| 125 | 
            -
                    num_decoding_left_chunks: int = -1,
         | 
| 126 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 127 | 
            -
                    """Embed positions in tensor.
         | 
| 128 | 
            -
             | 
| 129 | 
            -
                    Args:
         | 
| 130 | 
            -
                        xs: padded input tensor (B, T, D)
         | 
| 131 | 
            -
                        xs_lens: input length (B)
         | 
| 132 | 
            -
                        decoding_chunk_size: decoding chunk size for dynamic chunk
         | 
| 133 | 
            -
                            0: default for training, use random dynamic chunk.
         | 
| 134 | 
            -
                            <0: for decoding, use full chunk.
         | 
| 135 | 
            -
                            >0: for decoding, use fixed chunk size as set.
         | 
| 136 | 
            -
                        num_decoding_left_chunks: number of left chunks, this is for decoding,
         | 
| 137 | 
            -
                        the chunk size is decoding_chunk_size.
         | 
| 138 | 
            -
                            >=0: use num_decoding_left_chunks
         | 
| 139 | 
            -
                            <0: use all left chunks
         | 
| 140 | 
            -
                    Returns:
         | 
| 141 | 
            -
                        encoder output tensor xs, and subsampled masks
         | 
| 142 | 
            -
                        xs: padded output tensor (B, T' ~= T/subsample_rate, D)
         | 
| 143 | 
            -
                        masks: torch.Tensor batch padding mask after subsample
         | 
| 144 | 
            -
                            (B, 1, T' ~= T/subsample_rate)
         | 
| 145 | 
            -
                    NOTE(xcsong):
         | 
| 146 | 
            -
                        We pass the `__call__` method of the modules instead of `forward` to the
         | 
| 147 | 
            -
                        checkpointing API because `__call__` attaches all the hooks of the module.
         | 
| 148 | 
            -
                        https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
         | 
| 149 | 
            -
                    """
         | 
| 150 | 
            -
                    T = xs.size(1)
         | 
| 151 | 
            -
                    masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         | 
| 152 | 
            -
                    if self.global_cmvn is not None:
         | 
| 153 | 
            -
                        xs = self.global_cmvn(xs)
         | 
| 154 | 
            -
                    xs, pos_emb, masks = self.embed(xs, masks)
         | 
| 155 | 
            -
                    mask_pad = masks  # (B, 1, T/subsample_rate)
         | 
| 156 | 
            -
                    chunk_masks = add_optional_chunk_mask(
         | 
| 157 | 
            -
                        xs,
         | 
| 158 | 
            -
                        masks,
         | 
| 159 | 
            -
                        self.use_dynamic_chunk,
         | 
| 160 | 
            -
                        self.use_dynamic_left_chunk,
         | 
| 161 | 
            -
                        decoding_chunk_size,
         | 
| 162 | 
            -
                        self.static_chunk_size,
         | 
| 163 | 
            -
                        num_decoding_left_chunks,
         | 
| 164 | 
            -
                    )
         | 
| 165 | 
            -
                    print(f"chunk_masks shape: {chunk_masks.shape}")
         | 
| 166 | 
            -
                    if self.gradient_checkpointing and self.training:
         | 
| 167 | 
            -
                        xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
         | 
| 168 | 
            -
                    else:
         | 
| 169 | 
            -
                        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
         | 
| 170 | 
            -
                    if self.normalize_before:
         | 
| 171 | 
            -
                        xs = self.after_norm(xs)
         | 
| 172 | 
            -
                    # Here we assume the mask is not changed in encoder layers, so just
         | 
| 173 | 
            -
                    # return the masks before encoder layers, and the masks will be used
         | 
| 174 | 
            -
                    # for cross attention with decoder later
         | 
| 175 | 
            -
                    return xs, masks
         | 
| 176 | 
            -
             | 
| 177 | 
            -
                def forward_layers(
         | 
| 178 | 
            -
                    self,
         | 
| 179 | 
            -
                    xs: torch.Tensor,
         | 
| 180 | 
            -
                    chunk_masks: torch.Tensor,
         | 
| 181 | 
            -
                    pos_emb: torch.Tensor,
         | 
| 182 | 
            -
                    mask_pad: torch.Tensor,
         | 
| 183 | 
            -
                ) -> torch.Tensor:
         | 
| 184 | 
            -
                    for layer in self.encoders:
         | 
| 185 | 
            -
                        xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         | 
| 186 | 
            -
                    return xs
         | 
| 187 | 
            -
             | 
| 188 | 
            -
                @torch.jit.unused
         | 
| 189 | 
            -
                def forward_layers_checkpointed(
         | 
| 190 | 
            -
                    self,
         | 
| 191 | 
            -
                    xs: torch.Tensor,
         | 
| 192 | 
            -
                    chunk_masks: torch.Tensor,
         | 
| 193 | 
            -
                    pos_emb: torch.Tensor,
         | 
| 194 | 
            -
                    mask_pad: torch.Tensor,
         | 
| 195 | 
            -
                ) -> torch.Tensor:
         | 
| 196 | 
            -
                    for layer in self.encoders:
         | 
| 197 | 
            -
                        xs, chunk_masks, _, _ = ckpt.checkpoint(
         | 
| 198 | 
            -
                            layer.__call__, xs, chunk_masks, pos_emb, mask_pad
         | 
| 199 | 
            -
                        )
         | 
| 200 | 
            -
                    return xs
         | 
| 201 | 
            -
             | 
| 202 | 
            -
                @torch.jit.export
         | 
| 203 | 
            -
                def forward_chunk(
         | 
| 204 | 
            -
                    self,
         | 
| 205 | 
            -
                    xs: torch.Tensor,
         | 
| 206 | 
            -
                    offset: int,
         | 
| 207 | 
            -
                    required_cache_size: int,
         | 
| 208 | 
            -
                    att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
         | 
| 209 | 
            -
                    cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
         | 
| 210 | 
            -
                    att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 211 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 212 | 
            -
                    """ Forward just one chunk
         | 
| 213 | 
            -
             | 
| 214 | 
            -
                    Args:
         | 
| 215 | 
            -
                        xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
         | 
| 216 | 
            -
                            where `time == (chunk_size - 1) * subsample_rate + \
         | 
| 217 | 
            -
                                    subsample.right_context + 1`
         | 
| 218 | 
            -
                        offset (int): current offset in encoder output time stamp
         | 
| 219 | 
            -
                        required_cache_size (int): cache size required for next chunk
         | 
| 220 | 
            -
                            compuation
         | 
| 221 | 
            -
                            >=0: actual cache size
         | 
| 222 | 
            -
                            <0: means all history cache is required
         | 
| 223 | 
            -
                        att_cache (torch.Tensor): cache tensor for KEY & VALUE in
         | 
| 224 | 
            -
                            transformer/conformer attention, with shape
         | 
| 225 | 
            -
                            (elayers, head, cache_t1, d_k * 2), where
         | 
| 226 | 
            -
                            `head * d_k == hidden-dim` and
         | 
| 227 | 
            -
                            `cache_t1 == chunk_size * num_decoding_left_chunks`.
         | 
| 228 | 
            -
                        cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
         | 
| 229 | 
            -
                            (elayers, b=1, hidden-dim, cache_t2), where
         | 
| 230 | 
            -
                            `cache_t2 == cnn.lorder - 1`
         | 
| 231 | 
            -
             | 
| 232 | 
            -
                    Returns:
         | 
| 233 | 
            -
                        torch.Tensor: output of current input xs,
         | 
| 234 | 
            -
                            with shape (b=1, chunk_size, hidden-dim).
         | 
| 235 | 
            -
                        torch.Tensor: new attention cache required for next chunk, with
         | 
| 236 | 
            -
                            dynamic shape (elayers, head, ?, d_k * 2)
         | 
| 237 | 
            -
                            depending on required_cache_size.
         | 
| 238 | 
            -
                        torch.Tensor: new conformer cnn cache required for next chunk, with
         | 
| 239 | 
            -
                            same shape as the original cnn_cache.
         | 
| 240 | 
            -
             | 
| 241 | 
            -
                    """
         | 
| 242 | 
            -
                    assert xs.size(0) == 1
         | 
| 243 | 
            -
                    # tmp_masks is just for interface compatibility
         | 
| 244 | 
            -
                    tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
         | 
| 245 | 
            -
                    tmp_masks = tmp_masks.unsqueeze(1)
         | 
| 246 | 
            -
                    if self.global_cmvn is not None:
         | 
| 247 | 
            -
                        xs = self.global_cmvn(xs)
         | 
| 248 | 
            -
                    # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
         | 
| 249 | 
            -
                    xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
         | 
| 250 | 
            -
                    # NOTE(xcsong): After  embed, shape(xs) is (b=1, chunk_size, hidden-dim)
         | 
| 251 | 
            -
                    elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
         | 
| 252 | 
            -
                    chunk_size = xs.size(1)
         | 
| 253 | 
            -
                    attention_key_size = cache_t1 + chunk_size
         | 
| 254 | 
            -
                    pos_emb = self.embed.position_encoding(
         | 
| 255 | 
            -
                        offset=offset - cache_t1, size=attention_key_size
         | 
| 256 | 
            -
                    )
         | 
| 257 | 
            -
                    if required_cache_size < 0:
         | 
| 258 | 
            -
                        next_cache_start = 0
         | 
| 259 | 
            -
                    elif required_cache_size == 0:
         | 
| 260 | 
            -
                        next_cache_start = attention_key_size
         | 
| 261 | 
            -
                    else:
         | 
| 262 | 
            -
                        next_cache_start = max(attention_key_size - required_cache_size, 0)
         | 
| 263 | 
            -
                    r_att_cache = []
         | 
| 264 | 
            -
                    r_cnn_cache = []
         | 
| 265 | 
            -
                    for i, layer in enumerate(self.encoders):
         | 
| 266 | 
            -
                        # NOTE(xcsong): Before layer.forward
         | 
| 267 | 
            -
                        #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
         | 
| 268 | 
            -
                        #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
         | 
| 269 | 
            -
                        xs, _, new_att_cache, new_cnn_cache = layer(
         | 
| 270 | 
            -
                            xs,
         | 
| 271 | 
            -
                            att_mask,
         | 
| 272 | 
            -
                            pos_emb,
         | 
| 273 | 
            -
                            att_cache=att_cache[i : i + 1] if elayers > 0 else att_cache,
         | 
| 274 | 
            -
                            cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
         | 
| 275 | 
            -
                        )
         | 
| 276 | 
            -
                        # NOTE(xcsong): After layer.forward
         | 
| 277 | 
            -
                        #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
         | 
| 278 | 
            -
                        #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
         | 
| 279 | 
            -
                        r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
         | 
| 280 | 
            -
                        r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
         | 
| 281 | 
            -
                    if self.normalize_before:
         | 
| 282 | 
            -
                        xs = self.after_norm(xs)
         | 
| 283 | 
            -
             | 
| 284 | 
            -
                    # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
         | 
| 285 | 
            -
                    #   ? may be larger than cache_t1, it depends on required_cache_size
         | 
| 286 | 
            -
                    r_att_cache = torch.cat(r_att_cache, dim=0)
         | 
| 287 | 
            -
                    # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
         | 
| 288 | 
            -
                    r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
         | 
| 289 | 
            -
             | 
| 290 | 
            -
                    return (xs, r_att_cache, r_cnn_cache)
         | 
| 291 | 
            -
             | 
| 292 | 
            -
                @torch.jit.unused
         | 
| 293 | 
            -
                def forward_chunk_by_chunk(
         | 
| 294 | 
            -
                    self,
         | 
| 295 | 
            -
                    xs: torch.Tensor,
         | 
| 296 | 
            -
                    decoding_chunk_size: int,
         | 
| 297 | 
            -
                    num_decoding_left_chunks: int = -1,
         | 
| 298 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 299 | 
            -
                    """Forward input chunk by chunk with chunk_size like a streaming
         | 
| 300 | 
            -
                        fashion
         | 
| 301 | 
            -
             | 
| 302 | 
            -
                    Here we should pay special attention to computation cache in the
         | 
| 303 | 
            -
                    streaming style forward chunk by chunk. Three things should be taken
         | 
| 304 | 
            -
                    into account for computation in the current network:
         | 
| 305 | 
            -
                        1. transformer/conformer encoder layers output cache
         | 
| 306 | 
            -
                        2. convolution in conformer
         | 
| 307 | 
            -
                        3. convolution in subsampling
         | 
| 308 | 
            -
             | 
| 309 | 
            -
                    However, we don't implement subsampling cache for:
         | 
| 310 | 
            -
                        1. We can control subsampling module to output the right result by
         | 
| 311 | 
            -
                           overlapping input instead of cache left context, even though it
         | 
| 312 | 
            -
                           wastes some computation, but subsampling only takes a very
         | 
| 313 | 
            -
                           small fraction of computation in the whole model.
         | 
| 314 | 
            -
                        2. Typically, there are several covolution layers with subsampling
         | 
| 315 | 
            -
                           in subsampling module, it is tricky and complicated to do cache
         | 
| 316 | 
            -
                           with different convolution layers with different subsampling
         | 
| 317 | 
            -
                           rate.
         | 
| 318 | 
            -
                        3. Currently, nn.Sequential is used to stack all the convolution
         | 
| 319 | 
            -
                           layers in subsampling, we need to rewrite it to make it work
         | 
| 320 | 
            -
                           with cache, which is not preferred.
         | 
| 321 | 
            -
                    Args:
         | 
| 322 | 
            -
                        xs (torch.Tensor): (1, max_len, dim)
         | 
| 323 | 
            -
                        chunk_size (int): decoding chunk size
         | 
| 324 | 
            -
                    """
         | 
| 325 | 
            -
                    assert decoding_chunk_size > 0
         | 
| 326 | 
            -
                    # The model is trained by static or dynamic chunk
         | 
| 327 | 
            -
                    assert self.static_chunk_size > 0 or self.use_dynamic_chunk
         | 
| 328 | 
            -
                    subsampling = self.embed.subsampling_rate
         | 
| 329 | 
            -
                    context = self.embed.right_context + 1  # Add current frame
         | 
| 330 | 
            -
                    stride = subsampling * decoding_chunk_size
         | 
| 331 | 
            -
                    decoding_window = (decoding_chunk_size - 1) * subsampling + context
         | 
| 332 | 
            -
                    num_frames = xs.size(1)
         | 
| 333 | 
            -
                    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
         | 
| 334 | 
            -
                    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
         | 
| 335 | 
            -
                    outputs = []
         | 
| 336 | 
            -
                    offset = 0
         | 
| 337 | 
            -
                    required_cache_size = decoding_chunk_size * num_decoding_left_chunks
         | 
| 338 | 
            -
             | 
| 339 | 
            -
                    # Feed forward overlap input step by step
         | 
| 340 | 
            -
                    for cur in range(0, num_frames - context + 1, stride):
         | 
| 341 | 
            -
                        end = min(cur + decoding_window, num_frames)
         | 
| 342 | 
            -
                        chunk_xs = xs[:, cur:end, :]
         | 
| 343 | 
            -
                        (y, att_cache, cnn_cache) = self.forward_chunk(
         | 
| 344 | 
            -
                            chunk_xs, offset, required_cache_size, att_cache, cnn_cache
         | 
| 345 | 
            -
                        )
         | 
| 346 | 
            -
                        outputs.append(y)
         | 
| 347 | 
            -
                        offset += y.size(1)
         | 
| 348 | 
            -
                    ys = torch.cat(outputs, 1)
         | 
| 349 | 
            -
                    masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
         | 
| 350 | 
            -
                    return ys, masks
         | 
| 351 | 
            -
             | 
| 352 | 
            -
             | 
| 353 | 
            -
            class TransformerEncoder(BaseEncoder):
         | 
| 354 | 
            -
                """Transformer encoder module."""
         | 
| 355 | 
            -
             | 
| 356 | 
            -
                def __init__(
         | 
| 357 | 
            -
                    self,
         | 
| 358 | 
            -
                    input_size: int,
         | 
| 359 | 
            -
                    output_size: int = 256,
         | 
| 360 | 
            -
                    attention_heads: int = 4,
         | 
| 361 | 
            -
                    linear_units: int = 2048,
         | 
| 362 | 
            -
                    num_blocks: int = 6,
         | 
| 363 | 
            -
                    dropout_rate: float = 0.1,
         | 
| 364 | 
            -
                    positional_dropout_rate: float = 0.1,
         | 
| 365 | 
            -
                    attention_dropout_rate: float = 0.0,
         | 
| 366 | 
            -
                    input_layer: str = "conv2d",
         | 
| 367 | 
            -
                    pos_enc_layer_type: str = "abs_pos",
         | 
| 368 | 
            -
                    normalize_before: bool = True,
         | 
| 369 | 
            -
                    static_chunk_size: int = 0,
         | 
| 370 | 
            -
                    use_dynamic_chunk: bool = False,
         | 
| 371 | 
            -
                    global_cmvn: torch.nn.Module = None,
         | 
| 372 | 
            -
                    use_dynamic_left_chunk: bool = False,
         | 
| 373 | 
            -
                    key_bias: bool = True,
         | 
| 374 | 
            -
                    selfattention_layer_type: str = "selfattn",
         | 
| 375 | 
            -
                    activation_type: str = "relu",
         | 
| 376 | 
            -
                    gradient_checkpointing: bool = False,
         | 
| 377 | 
            -
                ):
         | 
| 378 | 
            -
                    """Construct TransformerEncoder
         | 
| 379 | 
            -
             | 
| 380 | 
            -
                    See Encoder for the meaning of each parameter.
         | 
| 381 | 
            -
                    """
         | 
| 382 | 
            -
                    super().__init__(
         | 
| 383 | 
            -
                        input_size,
         | 
| 384 | 
            -
                        output_size,
         | 
| 385 | 
            -
                        attention_heads,
         | 
| 386 | 
            -
                        linear_units,
         | 
| 387 | 
            -
                        num_blocks,
         | 
| 388 | 
            -
                        dropout_rate,
         | 
| 389 | 
            -
                        positional_dropout_rate,
         | 
| 390 | 
            -
                        attention_dropout_rate,
         | 
| 391 | 
            -
                        input_layer,
         | 
| 392 | 
            -
                        pos_enc_layer_type,
         | 
| 393 | 
            -
                        normalize_before,
         | 
| 394 | 
            -
                        static_chunk_size,
         | 
| 395 | 
            -
                        use_dynamic_chunk,
         | 
| 396 | 
            -
                        global_cmvn,
         | 
| 397 | 
            -
                        use_dynamic_left_chunk,
         | 
| 398 | 
            -
                        gradient_checkpointing,
         | 
| 399 | 
            -
                    )
         | 
| 400 | 
            -
                    activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
         | 
| 401 | 
            -
                    self.encoders = torch.nn.ModuleList(
         | 
| 402 | 
            -
                        [
         | 
| 403 | 
            -
                            TransformerEncoderLayer(
         | 
| 404 | 
            -
                                output_size,
         | 
| 405 | 
            -
                                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
         | 
| 406 | 
            -
                                    attention_heads, output_size, attention_dropout_rate, key_bias
         | 
| 407 | 
            -
                                ),
         | 
| 408 | 
            -
                                PositionwiseFeedForward(
         | 
| 409 | 
            -
                                    output_size, linear_units, dropout_rate, activation
         | 
| 410 | 
            -
                                ),
         | 
| 411 | 
            -
                                dropout_rate,
         | 
| 412 | 
            -
                                normalize_before,
         | 
| 413 | 
            -
                            )
         | 
| 414 | 
            -
                            for _ in range(num_blocks)
         | 
| 415 | 
            -
                        ]
         | 
| 416 | 
            -
                    )
         | 
| 417 | 
            -
             | 
| 418 | 
            -
             | 
| 419 | 
            -
            class ConformerEncoder(BaseEncoder):
         | 
| 420 | 
            -
                """Conformer encoder module."""
         | 
| 421 | 
            -
             | 
| 422 | 
            -
                def __init__(
         | 
| 423 | 
            -
                    self,
         | 
| 424 | 
            -
                    input_size: int,
         | 
| 425 | 
            -
                    output_size: int = 256,
         | 
| 426 | 
            -
                    attention_heads: int = 4,
         | 
| 427 | 
            -
                    linear_units: int = 2048,
         | 
| 428 | 
            -
                    num_blocks: int = 6,
         | 
| 429 | 
            -
                    dropout_rate: float = 0.1,
         | 
| 430 | 
            -
                    positional_dropout_rate: float = 0.1,
         | 
| 431 | 
            -
                    attention_dropout_rate: float = 0.0,
         | 
| 432 | 
            -
                    input_layer: str = "conv2d",
         | 
| 433 | 
            -
                    pos_enc_layer_type: str = "rel_pos",
         | 
| 434 | 
            -
                    normalize_before: bool = True,
         | 
| 435 | 
            -
                    static_chunk_size: int = 0,
         | 
| 436 | 
            -
                    use_dynamic_chunk: bool = False,
         | 
| 437 | 
            -
                    global_cmvn: torch.nn.Module = None,
         | 
| 438 | 
            -
                    use_dynamic_left_chunk: bool = False,
         | 
| 439 | 
            -
                    positionwise_conv_kernel_size: int = 1,
         | 
| 440 | 
            -
                    macaron_style: bool = True,
         | 
| 441 | 
            -
                    selfattention_layer_type: str = "rel_selfattn",
         | 
| 442 | 
            -
                    activation_type: str = "swish",
         | 
| 443 | 
            -
                    use_cnn_module: bool = True,
         | 
| 444 | 
            -
                    cnn_module_kernel: int = 15,
         | 
| 445 | 
            -
                    causal: bool = False,
         | 
| 446 | 
            -
                    cnn_module_norm: str = "batch_norm",
         | 
| 447 | 
            -
                    key_bias: bool = True,
         | 
| 448 | 
            -
                    gradient_checkpointing: bool = False,
         | 
| 449 | 
            -
                ):
         | 
| 450 | 
            -
                    """Construct ConformerEncoder
         | 
| 451 | 
            -
             | 
| 452 | 
            -
                    Args:
         | 
| 453 | 
            -
                        input_size to use_dynamic_chunk, see in BaseEncoder
         | 
| 454 | 
            -
                        positionwise_conv_kernel_size (int): Kernel size of positionwise
         | 
| 455 | 
            -
                            conv1d layer.
         | 
| 456 | 
            -
                        macaron_style (bool): Whether to use macaron style for
         | 
| 457 | 
            -
                            positionwise layer.
         | 
| 458 | 
            -
                        selfattention_layer_type (str): Encoder attention layer type,
         | 
| 459 | 
            -
                            the parameter has no effect now, it's just for configure
         | 
| 460 | 
            -
                            compatibility.
         | 
| 461 | 
            -
                        activation_type (str): Encoder activation function type.
         | 
| 462 | 
            -
                        use_cnn_module (bool): Whether to use convolution module.
         | 
| 463 | 
            -
                        cnn_module_kernel (int): Kernel size of convolution module.
         | 
| 464 | 
            -
                        causal (bool): whether to use causal convolution or not.
         | 
| 465 | 
            -
                        key_bias: whether use bias in attention.linear_k, False for whisper models.
         | 
| 466 | 
            -
                    """
         | 
| 467 | 
            -
                    super().__init__(
         | 
| 468 | 
            -
                        input_size,
         | 
| 469 | 
            -
                        output_size,
         | 
| 470 | 
            -
                        attention_heads,
         | 
| 471 | 
            -
                        linear_units,
         | 
| 472 | 
            -
                        num_blocks,
         | 
| 473 | 
            -
                        dropout_rate,
         | 
| 474 | 
            -
                        positional_dropout_rate,
         | 
| 475 | 
            -
                        attention_dropout_rate,
         | 
| 476 | 
            -
                        input_layer,
         | 
| 477 | 
            -
                        pos_enc_layer_type,
         | 
| 478 | 
            -
                        normalize_before,
         | 
| 479 | 
            -
                        static_chunk_size,
         | 
| 480 | 
            -
                        use_dynamic_chunk,
         | 
| 481 | 
            -
                        global_cmvn,
         | 
| 482 | 
            -
                        use_dynamic_left_chunk,
         | 
| 483 | 
            -
                        gradient_checkpointing,
         | 
| 484 | 
            -
                    )
         | 
| 485 | 
            -
                    activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
         | 
| 486 | 
            -
             | 
| 487 | 
            -
                    # self-attention module definition
         | 
| 488 | 
            -
                    encoder_selfattn_layer_args = (
         | 
| 489 | 
            -
                        attention_heads,
         | 
| 490 | 
            -
                        output_size,
         | 
| 491 | 
            -
                        attention_dropout_rate,
         | 
| 492 | 
            -
                        key_bias,
         | 
| 493 | 
            -
                    )
         | 
| 494 | 
            -
                    # feed-forward module definition
         | 
| 495 | 
            -
                    positionwise_layer_args = (
         | 
| 496 | 
            -
                        output_size,
         | 
| 497 | 
            -
                        linear_units,
         | 
| 498 | 
            -
                        dropout_rate,
         | 
| 499 | 
            -
                        activation,
         | 
| 500 | 
            -
                    )
         | 
| 501 | 
            -
                    # convolution module definition
         | 
| 502 | 
            -
                    convolution_layer_args = (
         | 
| 503 | 
            -
                        output_size,
         | 
| 504 | 
            -
                        cnn_module_kernel,
         | 
| 505 | 
            -
                        activation,
         | 
| 506 | 
            -
                        cnn_module_norm,
         | 
| 507 | 
            -
                        causal,
         | 
| 508 | 
            -
                    )
         | 
| 509 | 
            -
             | 
| 510 | 
            -
                    self.encoders = torch.nn.ModuleList(
         | 
| 511 | 
            -
                        [
         | 
| 512 | 
            -
                            ConformerEncoderLayer(
         | 
| 513 | 
            -
                                output_size,
         | 
| 514 | 
            -
                                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
         | 
| 515 | 
            -
                                    *encoder_selfattn_layer_args
         | 
| 516 | 
            -
                                ),
         | 
| 517 | 
            -
                                PositionwiseFeedForward(*positionwise_layer_args),
         | 
| 518 | 
            -
                                (
         | 
| 519 | 
            -
                                    PositionwiseFeedForward(*positionwise_layer_args)
         | 
| 520 | 
            -
                                    if macaron_style
         | 
| 521 | 
            -
                                    else None
         | 
| 522 | 
            -
                                ),
         | 
| 523 | 
            -
                                (
         | 
| 524 | 
            -
                                    ConvolutionModule(*convolution_layer_args)
         | 
| 525 | 
            -
                                    if use_cnn_module
         | 
| 526 | 
            -
                                    else None
         | 
| 527 | 
            -
                                ),
         | 
| 528 | 
            -
                                dropout_rate,
         | 
| 529 | 
            -
                                normalize_before,
         | 
| 530 | 
            -
                            )
         | 
| 531 | 
            -
                            for _ in range(num_blocks)
         | 
| 532 | 
            -
                        ]
         | 
| 533 | 
            -
                    )
         | 
| 534 | 
            -
                    self.inference_buffers = {}
         | 
| 535 | 
            -
                    self.inference_graphs = {}
         | 
| 536 | 
            -
             | 
| 537 | 
            -
                @torch.inference_mode()
         | 
| 538 | 
            -
                def capture_inference(self, seq_len_to_capture=[128, 256, 512, 1024]):
         | 
| 539 | 
            -
                    device = next(self.parameters()).device
         | 
| 540 | 
            -
                    start_time = time.time()
         | 
| 541 | 
            -
                    print(
         | 
| 542 | 
            -
                        f"Start capture_inference for ConformerEncoder, seq_len_to_capture: {seq_len_to_capture}"
         | 
| 543 | 
            -
                    )
         | 
| 544 | 
            -
             | 
| 545 | 
            -
                    for seq_len in seq_len_to_capture:
         | 
| 546 | 
            -
                        xs = torch.randn(
         | 
| 547 | 
            -
                            1, seq_len, self._output_size, device=device, dtype=torch.bfloat16
         | 
| 548 | 
            -
                        )
         | 
| 549 | 
            -
                        xs_lens = torch.tensor([seq_len], device=device, dtype=torch.int32)
         | 
| 550 | 
            -
                        decoding_chunk_size = 0
         | 
| 551 | 
            -
                        num_decoding_left_chunks = -1
         | 
| 552 | 
            -
             | 
| 553 | 
            -
                        T = xs.size(1)
         | 
| 554 | 
            -
                        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         | 
| 555 | 
            -
                        if self.global_cmvn is not None:
         | 
| 556 | 
            -
                            xs = self.global_cmvn(xs)
         | 
| 557 | 
            -
                        xs, pos_emb, masks = self.embed(xs, masks)
         | 
| 558 | 
            -
                        mask_pad = masks  # (B, 1, T/subsample_rate)
         | 
| 559 | 
            -
                        chunk_masks = add_optional_chunk_mask(
         | 
| 560 | 
            -
                            xs,
         | 
| 561 | 
            -
                            masks,
         | 
| 562 | 
            -
                            self.use_dynamic_chunk,
         | 
| 563 | 
            -
                            self.use_dynamic_left_chunk,
         | 
| 564 | 
            -
                            decoding_chunk_size,
         | 
| 565 | 
            -
                            self.static_chunk_size,
         | 
| 566 | 
            -
                            num_decoding_left_chunks,
         | 
| 567 | 
            -
                        )
         | 
| 568 | 
            -
             | 
| 569 | 
            -
                        g = torch.cuda.CUDAGraph()
         | 
| 570 | 
            -
                        with torch.cuda.graph(g):
         | 
| 571 | 
            -
                            out = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
         | 
| 572 | 
            -
             | 
| 573 | 
            -
                        self.inference_graphs[seq_len] = g
         | 
| 574 | 
            -
                        self.inference_buffers[seq_len] = {
         | 
| 575 | 
            -
                            "xs": xs,
         | 
| 576 | 
            -
                            "chunk_masks": chunk_masks,
         | 
| 577 | 
            -
                            "pos_emb": pos_emb,
         | 
| 578 | 
            -
                            "mask_pad": mask_pad,
         | 
| 579 | 
            -
                            "out": out,
         | 
| 580 | 
            -
                        }
         | 
| 581 | 
            -
                    end_time = time.time()
         | 
| 582 | 
            -
                    print(
         | 
| 583 | 
            -
                        f"Finish capture_inference for ConformerEncoder, time elapsed: {end_time - start_time}"
         | 
| 584 | 
            -
                    )
         | 
| 585 | 
            -
             | 
| 586 | 
            -
                @torch.inference_mode()
         | 
| 587 | 
            -
                def inference(self, xs: torch.Tensor, xs_lens: torch.Tensor):
         | 
| 588 | 
            -
                    curr_seq_len = xs.shape[1]
         | 
| 589 | 
            -
                    target_len = None
         | 
| 590 | 
            -
             | 
| 591 | 
            -
                    for seq_len in sorted(self.inference_graphs.keys()):
         | 
| 592 | 
            -
                        if seq_len >= curr_seq_len:
         | 
| 593 | 
            -
                            target_len = seq_len
         | 
| 594 | 
            -
                            break
         | 
| 595 | 
            -
             | 
| 596 | 
            -
                    if target_len is not None:
         | 
| 597 | 
            -
                        xs = F.pad(xs, (0, 0, 0, target_len - curr_seq_len), "constant", 0)
         | 
| 598 | 
            -
             | 
| 599 | 
            -
                    decoding_chunk_size = 0
         | 
| 600 | 
            -
                    num_decoding_left_chunks = -1
         | 
| 601 | 
            -
             | 
| 602 | 
            -
                    T = xs.size(1)
         | 
| 603 | 
            -
                    masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         | 
| 604 | 
            -
                    if self.global_cmvn is not None:
         | 
| 605 | 
            -
                        xs = self.global_cmvn(xs)
         | 
| 606 | 
            -
                    xs, pos_emb, masks = self.embed(xs, masks)
         | 
| 607 | 
            -
                    mask_pad = masks  # (B, 1, T/subsample_rate)
         | 
| 608 | 
            -
                    chunk_masks = add_optional_chunk_mask(
         | 
| 609 | 
            -
                        xs,
         | 
| 610 | 
            -
                        masks,
         | 
| 611 | 
            -
                        self.use_dynamic_chunk,
         | 
| 612 | 
            -
                        self.use_dynamic_left_chunk,
         | 
| 613 | 
            -
                        decoding_chunk_size,
         | 
| 614 | 
            -
                        self.static_chunk_size,
         | 
| 615 | 
            -
                        num_decoding_left_chunks,
         | 
| 616 | 
            -
                    )
         | 
| 617 | 
            -
             | 
| 618 | 
            -
                    if target_len is not None:
         | 
| 619 | 
            -
                        buffer = self.inference_buffers[target_len]
         | 
| 620 | 
            -
                        buffer["xs"].copy_(xs)
         | 
| 621 | 
            -
                        buffer["chunk_masks"].copy_(chunk_masks)
         | 
| 622 | 
            -
                        buffer["pos_emb"].copy_(pos_emb)
         | 
| 623 | 
            -
                        buffer["mask_pad"].copy_(mask_pad)
         | 
| 624 | 
            -
             | 
| 625 | 
            -
                        self.inference_graphs[target_len].replay()
         | 
| 626 | 
            -
             | 
| 627 | 
            -
                        out = buffer["out"][:, :curr_seq_len, :]
         | 
| 628 | 
            -
                    else:
         | 
| 629 | 
            -
                        out = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
         | 
| 630 | 
            -
             | 
| 631 | 
            -
                    if self.normalize_before:
         | 
| 632 | 
            -
                        out = self.after_norm(out)
         | 
| 633 | 
            -
                    return out, masks
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/encoder_layer.py
    DELETED
    
    | @@ -1,237 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
         | 
| 2 | 
            -
            #               2022 Xingchen Song ([email protected])
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 16 | 
            -
            """Encoder self-attention layer definition."""
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            from typing import Optional, Tuple
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            import torch
         | 
| 21 | 
            -
            from torch import nn
         | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
            class TransformerEncoderLayer(nn.Module):
         | 
| 25 | 
            -
                """Encoder layer module.
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                Args:
         | 
| 28 | 
            -
                    size (int): Input dimension.
         | 
| 29 | 
            -
                    self_attn (torch.nn.Module): Self-attention module instance.
         | 
| 30 | 
            -
                        `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
         | 
| 31 | 
            -
                        instance can be used as the argument.
         | 
| 32 | 
            -
                    feed_forward (torch.nn.Module): Feed-forward module instance.
         | 
| 33 | 
            -
                        `PositionwiseFeedForward`, instance can be used as the argument.
         | 
| 34 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 35 | 
            -
                    normalize_before (bool):
         | 
| 36 | 
            -
                        True: use layer_norm before each sub-block.
         | 
| 37 | 
            -
                        False: to use layer_norm after each sub-block.
         | 
| 38 | 
            -
                """
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                def __init__(
         | 
| 41 | 
            -
                    self,
         | 
| 42 | 
            -
                    size: int,
         | 
| 43 | 
            -
                    self_attn: torch.nn.Module,
         | 
| 44 | 
            -
                    feed_forward: torch.nn.Module,
         | 
| 45 | 
            -
                    dropout_rate: float,
         | 
| 46 | 
            -
                    normalize_before: bool = True,
         | 
| 47 | 
            -
                ):
         | 
| 48 | 
            -
                    """Construct an EncoderLayer object."""
         | 
| 49 | 
            -
                    super().__init__()
         | 
| 50 | 
            -
                    self.self_attn = self_attn
         | 
| 51 | 
            -
                    self.feed_forward = feed_forward
         | 
| 52 | 
            -
                    self.norm1 = nn.LayerNorm(size, eps=1e-5)
         | 
| 53 | 
            -
                    self.norm2 = nn.LayerNorm(size, eps=1e-5)
         | 
| 54 | 
            -
                    self.dropout = nn.Dropout(dropout_rate)
         | 
| 55 | 
            -
                    self.size = size
         | 
| 56 | 
            -
                    self.normalize_before = normalize_before
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                def forward(
         | 
| 59 | 
            -
                    self,
         | 
| 60 | 
            -
                    x: torch.Tensor,
         | 
| 61 | 
            -
                    mask: torch.Tensor,
         | 
| 62 | 
            -
                    pos_emb: torch.Tensor,
         | 
| 63 | 
            -
                    mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 64 | 
            -
                    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 65 | 
            -
                    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 66 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 67 | 
            -
                    """Compute encoded features.
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                    Args:
         | 
| 70 | 
            -
                        x (torch.Tensor): (#batch, time, size)
         | 
| 71 | 
            -
                        mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
         | 
| 72 | 
            -
                            (0, 0, 0) means fake mask.
         | 
| 73 | 
            -
                        pos_emb (torch.Tensor): just for interface compatibility
         | 
| 74 | 
            -
                            to ConformerEncoderLayer
         | 
| 75 | 
            -
                        mask_pad (torch.Tensor): does not used in transformer layer,
         | 
| 76 | 
            -
                            just for unified api with conformer.
         | 
| 77 | 
            -
                        att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
         | 
| 78 | 
            -
                            (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
         | 
| 79 | 
            -
                        cnn_cache (torch.Tensor): Convolution cache in conformer layer
         | 
| 80 | 
            -
                            (#batch=1, size, cache_t2), not used here, it's for interface
         | 
| 81 | 
            -
                            compatibility to ConformerEncoderLayer.
         | 
| 82 | 
            -
                    Returns:
         | 
| 83 | 
            -
                        torch.Tensor: Output tensor (#batch, time, size).
         | 
| 84 | 
            -
                        torch.Tensor: Mask tensor (#batch, time, time).
         | 
| 85 | 
            -
                        torch.Tensor: att_cache tensor,
         | 
| 86 | 
            -
                            (#batch=1, head, cache_t1 + time, d_k * 2).
         | 
| 87 | 
            -
                        torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                    """
         | 
| 90 | 
            -
                    residual = x
         | 
| 91 | 
            -
                    if self.normalize_before:
         | 
| 92 | 
            -
                        x = self.norm1(x)
         | 
| 93 | 
            -
                    x_att, new_att_cache = self.self_attn(
         | 
| 94 | 
            -
                        x, x, x, mask, pos_emb=pos_emb, cache=att_cache
         | 
| 95 | 
            -
                    )
         | 
| 96 | 
            -
                    x = residual + self.dropout(x_att)
         | 
| 97 | 
            -
                    if not self.normalize_before:
         | 
| 98 | 
            -
                        x = self.norm1(x)
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                    residual = x
         | 
| 101 | 
            -
                    if self.normalize_before:
         | 
| 102 | 
            -
                        x = self.norm2(x)
         | 
| 103 | 
            -
                    x = residual + self.dropout(self.feed_forward(x))
         | 
| 104 | 
            -
                    if not self.normalize_before:
         | 
| 105 | 
            -
                        x = self.norm2(x)
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                    fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
         | 
| 108 | 
            -
                    return x, mask, new_att_cache, fake_cnn_cache
         | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
            class ConformerEncoderLayer(nn.Module):
         | 
| 112 | 
            -
                """Encoder layer module.
         | 
| 113 | 
            -
                Args:
         | 
| 114 | 
            -
                    size (int): Input dimension.
         | 
| 115 | 
            -
                    self_attn (torch.nn.Module): Self-attention module instance.
         | 
| 116 | 
            -
                        `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
         | 
| 117 | 
            -
                        instance can be used as the argument.
         | 
| 118 | 
            -
                    feed_forward (torch.nn.Module): Feed-forward module instance.
         | 
| 119 | 
            -
                        `PositionwiseFeedForward` instance can be used as the argument.
         | 
| 120 | 
            -
                    feed_forward_macaron (torch.nn.Module): Additional feed-forward module
         | 
| 121 | 
            -
                         instance.
         | 
| 122 | 
            -
                        `PositionwiseFeedForward` instance can be used as the argument.
         | 
| 123 | 
            -
                    conv_module (torch.nn.Module): Convolution module instance.
         | 
| 124 | 
            -
                        `ConvlutionModule` instance can be used as the argument.
         | 
| 125 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 126 | 
            -
                    normalize_before (bool):
         | 
| 127 | 
            -
                        True: use layer_norm before each sub-block.
         | 
| 128 | 
            -
                        False: use layer_norm after each sub-block.
         | 
| 129 | 
            -
                """
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                def __init__(
         | 
| 132 | 
            -
                    self,
         | 
| 133 | 
            -
                    size: int,
         | 
| 134 | 
            -
                    self_attn: torch.nn.Module,
         | 
| 135 | 
            -
                    feed_forward: Optional[nn.Module] = None,
         | 
| 136 | 
            -
                    feed_forward_macaron: Optional[nn.Module] = None,
         | 
| 137 | 
            -
                    conv_module: Optional[nn.Module] = None,
         | 
| 138 | 
            -
                    dropout_rate: float = 0.1,
         | 
| 139 | 
            -
                    normalize_before: bool = True,
         | 
| 140 | 
            -
                ):
         | 
| 141 | 
            -
                    """Construct an EncoderLayer object."""
         | 
| 142 | 
            -
                    super().__init__()
         | 
| 143 | 
            -
                    self.self_attn = self_attn
         | 
| 144 | 
            -
                    self.feed_forward = feed_forward
         | 
| 145 | 
            -
                    self.feed_forward_macaron = feed_forward_macaron
         | 
| 146 | 
            -
                    self.conv_module = conv_module
         | 
| 147 | 
            -
                    self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
         | 
| 148 | 
            -
                    self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
         | 
| 149 | 
            -
                    if feed_forward_macaron is not None:
         | 
| 150 | 
            -
                        self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
         | 
| 151 | 
            -
                        self.ff_scale = 0.5
         | 
| 152 | 
            -
                    else:
         | 
| 153 | 
            -
                        self.ff_scale = 1.0
         | 
| 154 | 
            -
                    if self.conv_module is not None:
         | 
| 155 | 
            -
                        self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
         | 
| 156 | 
            -
                        self.norm_final = nn.LayerNorm(
         | 
| 157 | 
            -
                            size, eps=1e-5
         | 
| 158 | 
            -
                        )  # for the final output of the block
         | 
| 159 | 
            -
                    self.dropout = nn.Dropout(dropout_rate)
         | 
| 160 | 
            -
                    self.size = size
         | 
| 161 | 
            -
                    self.normalize_before = normalize_before
         | 
| 162 | 
            -
             | 
| 163 | 
            -
                def forward(
         | 
| 164 | 
            -
                    self,
         | 
| 165 | 
            -
                    x: torch.Tensor,
         | 
| 166 | 
            -
                    mask: torch.Tensor,
         | 
| 167 | 
            -
                    pos_emb: torch.Tensor,
         | 
| 168 | 
            -
                    mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 169 | 
            -
                    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 170 | 
            -
                    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 171 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 172 | 
            -
                    """Compute encoded features.
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                    Args:
         | 
| 175 | 
            -
                        x (torch.Tensor): (#batch, time, size)
         | 
| 176 | 
            -
                        mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
         | 
| 177 | 
            -
                            (0, 0, 0) means fake mask.
         | 
| 178 | 
            -
                        pos_emb (torch.Tensor): positional encoding, must not be None
         | 
| 179 | 
            -
                            for ConformerEncoderLayer.
         | 
| 180 | 
            -
                        mask_pad (torch.Tensor): batch padding mask used for conv module.
         | 
| 181 | 
            -
                            (#batch, 1,time), (0, 0, 0) means fake mask.
         | 
| 182 | 
            -
                        att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
         | 
| 183 | 
            -
                            (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
         | 
| 184 | 
            -
                        cnn_cache (torch.Tensor): Convolution cache in conformer layer
         | 
| 185 | 
            -
                            (#batch=1, size, cache_t2)
         | 
| 186 | 
            -
                    Returns:
         | 
| 187 | 
            -
                        torch.Tensor: Output tensor (#batch, time, size).
         | 
| 188 | 
            -
                        torch.Tensor: Mask tensor (#batch, time, time).
         | 
| 189 | 
            -
                        torch.Tensor: att_cache tensor,
         | 
| 190 | 
            -
                            (#batch=1, head, cache_t1 + time, d_k * 2).
         | 
| 191 | 
            -
                        torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
         | 
| 192 | 
            -
                    """
         | 
| 193 | 
            -
             | 
| 194 | 
            -
                    # whether to use macaron style
         | 
| 195 | 
            -
                    if self.feed_forward_macaron is not None:
         | 
| 196 | 
            -
                        residual = x
         | 
| 197 | 
            -
                        if self.normalize_before:
         | 
| 198 | 
            -
                            x = self.norm_ff_macaron(x)
         | 
| 199 | 
            -
                        x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
         | 
| 200 | 
            -
                        if not self.normalize_before:
         | 
| 201 | 
            -
                            x = self.norm_ff_macaron(x)
         | 
| 202 | 
            -
             | 
| 203 | 
            -
                    # multi-headed self-attention module
         | 
| 204 | 
            -
                    residual = x
         | 
| 205 | 
            -
                    if self.normalize_before:
         | 
| 206 | 
            -
                        x = self.norm_mha(x)
         | 
| 207 | 
            -
                    x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
         | 
| 208 | 
            -
                    x = residual + self.dropout(x_att)
         | 
| 209 | 
            -
                    if not self.normalize_before:
         | 
| 210 | 
            -
                        x = self.norm_mha(x)
         | 
| 211 | 
            -
             | 
| 212 | 
            -
                    # convolution module
         | 
| 213 | 
            -
                    # Fake new cnn cache here, and then change it in conv_module
         | 
| 214 | 
            -
                    new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
         | 
| 215 | 
            -
                    if self.conv_module is not None:
         | 
| 216 | 
            -
                        residual = x
         | 
| 217 | 
            -
                        if self.normalize_before:
         | 
| 218 | 
            -
                            x = self.norm_conv(x)
         | 
| 219 | 
            -
                        x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
         | 
| 220 | 
            -
                        x = residual + self.dropout(x)
         | 
| 221 | 
            -
             | 
| 222 | 
            -
                        if not self.normalize_before:
         | 
| 223 | 
            -
                            x = self.norm_conv(x)
         | 
| 224 | 
            -
             | 
| 225 | 
            -
                    # feed forward module
         | 
| 226 | 
            -
                    residual = x
         | 
| 227 | 
            -
                    if self.normalize_before:
         | 
| 228 | 
            -
                        x = self.norm_ff(x)
         | 
| 229 | 
            -
             | 
| 230 | 
            -
                    x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
         | 
| 231 | 
            -
                    if not self.normalize_before:
         | 
| 232 | 
            -
                        x = self.norm_ff(x)
         | 
| 233 | 
            -
             | 
| 234 | 
            -
                    if self.conv_module is not None:
         | 
| 235 | 
            -
                        x = self.norm_final(x)
         | 
| 236 | 
            -
             | 
| 237 | 
            -
                    return x, mask, new_att_cache, new_cnn_cache
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/label_smoothing_loss.py
    DELETED
    
    | @@ -1,98 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2019 Shigeki Karita
         | 
| 2 | 
            -
            #               2020 Mobvoi Inc (Binbin Zhang)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            """Label smoothing module."""
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            import torch
         | 
| 18 | 
            -
            from torch import nn
         | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
            class LabelSmoothingLoss(nn.Module):
         | 
| 22 | 
            -
                """Label-smoothing loss.
         | 
| 23 | 
            -
             | 
| 24 | 
            -
                In a standard CE loss, the label's data distribution is:
         | 
| 25 | 
            -
                [0,1,2] ->
         | 
| 26 | 
            -
                [
         | 
| 27 | 
            -
                    [1.0, 0.0, 0.0],
         | 
| 28 | 
            -
                    [0.0, 1.0, 0.0],
         | 
| 29 | 
            -
                    [0.0, 0.0, 1.0],
         | 
| 30 | 
            -
                ]
         | 
| 31 | 
            -
             | 
| 32 | 
            -
                In the smoothing version CE Loss,some probabilities
         | 
| 33 | 
            -
                are taken from the true label prob (1.0) and are divided
         | 
| 34 | 
            -
                among other labels.
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                e.g.
         | 
| 37 | 
            -
                smoothing=0.1
         | 
| 38 | 
            -
                [0,1,2] ->
         | 
| 39 | 
            -
                [
         | 
| 40 | 
            -
                    [0.9, 0.05, 0.05],
         | 
| 41 | 
            -
                    [0.05, 0.9, 0.05],
         | 
| 42 | 
            -
                    [0.05, 0.05, 0.9],
         | 
| 43 | 
            -
                ]
         | 
| 44 | 
            -
             | 
| 45 | 
            -
                Args:
         | 
| 46 | 
            -
                    size (int): the number of class
         | 
| 47 | 
            -
                    padding_idx (int): padding class id which will be ignored for loss
         | 
| 48 | 
            -
                    smoothing (float): smoothing rate (0.0 means the conventional CE)
         | 
| 49 | 
            -
                    normalize_length (bool):
         | 
| 50 | 
            -
                        normalize loss by sequence length if True
         | 
| 51 | 
            -
                        normalize loss by batch size if False
         | 
| 52 | 
            -
                """
         | 
| 53 | 
            -
             | 
| 54 | 
            -
                def __init__(
         | 
| 55 | 
            -
                    self,
         | 
| 56 | 
            -
                    size: int,
         | 
| 57 | 
            -
                    padding_idx: int,
         | 
| 58 | 
            -
                    smoothing: float,
         | 
| 59 | 
            -
                    normalize_length: bool = False,
         | 
| 60 | 
            -
                ):
         | 
| 61 | 
            -
                    """Construct an LabelSmoothingLoss object."""
         | 
| 62 | 
            -
                    super(LabelSmoothingLoss, self).__init__()
         | 
| 63 | 
            -
                    self.criterion = nn.KLDivLoss(reduction="none")
         | 
| 64 | 
            -
                    self.padding_idx = padding_idx
         | 
| 65 | 
            -
                    self.confidence = 1.0 - smoothing
         | 
| 66 | 
            -
                    self.smoothing = smoothing
         | 
| 67 | 
            -
                    self.size = size
         | 
| 68 | 
            -
                    self.normalize_length = normalize_length
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         | 
| 71 | 
            -
                    """Compute loss between x and target.
         | 
| 72 | 
            -
             | 
| 73 | 
            -
                    The model outputs and data labels tensors are flatten to
         | 
| 74 | 
            -
                    (batch*seqlen, class) shape and a mask is applied to the
         | 
| 75 | 
            -
                    padding part which should not be calculated for loss.
         | 
| 76 | 
            -
             | 
| 77 | 
            -
                    Args:
         | 
| 78 | 
            -
                        x (torch.Tensor): prediction (batch, seqlen, class)
         | 
| 79 | 
            -
                        target (torch.Tensor):
         | 
| 80 | 
            -
                            target signal masked with self.padding_id (batch, seqlen)
         | 
| 81 | 
            -
                    Returns:
         | 
| 82 | 
            -
                        loss (torch.Tensor) : The KL loss, scalar float value
         | 
| 83 | 
            -
                    """
         | 
| 84 | 
            -
                    assert x.size(2) == self.size
         | 
| 85 | 
            -
                    batch_size = x.size(0)
         | 
| 86 | 
            -
                    x = x.view(-1, self.size)
         | 
| 87 | 
            -
                    target = target.view(-1)
         | 
| 88 | 
            -
                    # use zeros_like instead of torch.no_grad() for true_dist,
         | 
| 89 | 
            -
                    # since no_grad() can not be exported by JIT
         | 
| 90 | 
            -
                    true_dist = torch.zeros_like(x)
         | 
| 91 | 
            -
                    true_dist.fill_(self.smoothing / (self.size - 1))
         | 
| 92 | 
            -
                    ignore = target == self.padding_idx  # (B,)
         | 
| 93 | 
            -
                    total = len(target) - ignore.sum().item()
         | 
| 94 | 
            -
                    target = target.masked_fill(ignore, 0)  # avoid -1 index
         | 
| 95 | 
            -
                    true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
         | 
| 96 | 
            -
                    kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
         | 
| 97 | 
            -
                    denom = total if self.normalize_length else batch_size
         | 
| 98 | 
            -
                    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/positionwise_feed_forward.py
    DELETED
    
    | @@ -1,116 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2019 Shigeki Karita
         | 
| 2 | 
            -
            #               2020 Mobvoi Inc (Binbin Zhang)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            """Positionwise feed forward layer definition."""
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            import torch
         | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
            class PositionwiseFeedForward(torch.nn.Module):
         | 
| 21 | 
            -
                """Positionwise feed forward layer.
         | 
| 22 | 
            -
             | 
| 23 | 
            -
                FeedForward are appied on each position of the sequence.
         | 
| 24 | 
            -
                The output dim is same with the input dim.
         | 
| 25 | 
            -
             | 
| 26 | 
            -
                Args:
         | 
| 27 | 
            -
                    idim (int): Input dimenstion.
         | 
| 28 | 
            -
                    hidden_units (int): The number of hidden units.
         | 
| 29 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 30 | 
            -
                    activation (torch.nn.Module): Activation function
         | 
| 31 | 
            -
                """
         | 
| 32 | 
            -
             | 
| 33 | 
            -
                def __init__(
         | 
| 34 | 
            -
                    self,
         | 
| 35 | 
            -
                    idim: int,
         | 
| 36 | 
            -
                    hidden_units: int,
         | 
| 37 | 
            -
                    dropout_rate: float,
         | 
| 38 | 
            -
                    activation: torch.nn.Module = torch.nn.ReLU(),
         | 
| 39 | 
            -
                ):
         | 
| 40 | 
            -
                    """Construct a PositionwiseFeedForward object."""
         | 
| 41 | 
            -
                    super(PositionwiseFeedForward, self).__init__()
         | 
| 42 | 
            -
                    self.w_1 = torch.nn.Linear(idim, hidden_units)
         | 
| 43 | 
            -
                    self.activation = activation
         | 
| 44 | 
            -
                    self.dropout = torch.nn.Dropout(dropout_rate)
         | 
| 45 | 
            -
                    self.w_2 = torch.nn.Linear(hidden_units, idim)
         | 
| 46 | 
            -
             | 
| 47 | 
            -
                def forward(self, xs: torch.Tensor) -> torch.Tensor:
         | 
| 48 | 
            -
                    """Forward function.
         | 
| 49 | 
            -
             | 
| 50 | 
            -
                    Args:
         | 
| 51 | 
            -
                        xs: input tensor (B, L, D)
         | 
| 52 | 
            -
                    Returns:
         | 
| 53 | 
            -
                        output tensor, (B, L, D)
         | 
| 54 | 
            -
                    """
         | 
| 55 | 
            -
                    return self.w_2(self.dropout(self.activation(self.w_1(xs))))
         | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 58 | 
            -
            class MoEFFNLayer(torch.nn.Module):
         | 
| 59 | 
            -
                """
         | 
| 60 | 
            -
                Mixture of expert with Positionwise feed forward layer
         | 
| 61 | 
            -
                See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
         | 
| 62 | 
            -
                The output dim is same with the input dim.
         | 
| 63 | 
            -
             | 
| 64 | 
            -
                Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
         | 
| 65 | 
            -
                              https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
         | 
| 66 | 
            -
                Args:
         | 
| 67 | 
            -
                    n_expert: number of expert.
         | 
| 68 | 
            -
                    n_expert_per_token: The actual number of experts used for each frame
         | 
| 69 | 
            -
                    idim (int): Input dimenstion.
         | 
| 70 | 
            -
                    hidden_units (int): The number of hidden units.
         | 
| 71 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 72 | 
            -
                    activation (torch.nn.Module): Activation function
         | 
| 73 | 
            -
                """
         | 
| 74 | 
            -
             | 
| 75 | 
            -
                def __init__(
         | 
| 76 | 
            -
                    self,
         | 
| 77 | 
            -
                    n_expert: int,
         | 
| 78 | 
            -
                    n_expert_per_token: int,
         | 
| 79 | 
            -
                    idim: int,
         | 
| 80 | 
            -
                    hidden_units: int,
         | 
| 81 | 
            -
                    dropout_rate: float,
         | 
| 82 | 
            -
                    activation: torch.nn.Module = torch.nn.ReLU(),
         | 
| 83 | 
            -
                ):
         | 
| 84 | 
            -
                    super(MoEFFNLayer, self).__init__()
         | 
| 85 | 
            -
                    self.gate = torch.nn.Linear(idim, n_expert, bias=False)
         | 
| 86 | 
            -
                    self.experts = torch.nn.ModuleList(
         | 
| 87 | 
            -
                        PositionwiseFeedForward(idim, hidden_units, dropout_rate, activation)
         | 
| 88 | 
            -
                        for _ in range(n_expert)
         | 
| 89 | 
            -
                    )
         | 
| 90 | 
            -
                    self.n_expert_per_token = n_expert_per_token
         | 
| 91 | 
            -
             | 
| 92 | 
            -
                def forward(self, xs: torch.Tensor) -> torch.Tensor:
         | 
| 93 | 
            -
                    """Foward function.
         | 
| 94 | 
            -
                    Args:
         | 
| 95 | 
            -
                        xs: input tensor (B, L, D)
         | 
| 96 | 
            -
                    Returns:
         | 
| 97 | 
            -
                        output tensor, (B, L, D)
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                    """
         | 
| 100 | 
            -
                    B, L, D = xs.size()  # batch size, sequence length, embedding dimension (idim)
         | 
| 101 | 
            -
                    xs = xs.view(-1, D)  # (B*L, D)
         | 
| 102 | 
            -
                    router = self.gate(xs)  # (B*L, n_expert)
         | 
| 103 | 
            -
                    logits, indices = torch.topk(
         | 
| 104 | 
            -
                        router, self.n_expert_per_token
         | 
| 105 | 
            -
                    )  # probs:(B*L, n_expert), indices: (B*L, n_expert)
         | 
| 106 | 
            -
                    weights = torch.nn.functional.softmax(logits, dim=1, dtype=torch.float).to(
         | 
| 107 | 
            -
                        dtype=xs.dtype
         | 
| 108 | 
            -
                    )  # (B*L, n_expert_per_token)
         | 
| 109 | 
            -
                    output = torch.zeros_like(xs)  # (B*L, D)
         | 
| 110 | 
            -
                    for i, expert in enumerate(self.experts):
         | 
| 111 | 
            -
                        mask = indices == i
         | 
| 112 | 
            -
                        batch_idx, ith_expert = torch.where(mask)
         | 
| 113 | 
            -
                        output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
         | 
| 114 | 
            -
                            xs[batch_idx]
         | 
| 115 | 
            -
                        )
         | 
| 116 | 
            -
                    return output.view(B, L, D)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/transformer/subsampling.py
    DELETED
    
    | @@ -1,391 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
         | 
| 2 | 
            -
            #               2024 Alibaba Inc (Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 16 | 
            -
            """Subsampling layer definition."""
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            from typing import Tuple, Union
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            import torch
         | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
            class BaseSubsampling(torch.nn.Module):
         | 
| 24 | 
            -
             | 
| 25 | 
            -
                def __init__(self):
         | 
| 26 | 
            -
                    super().__init__()
         | 
| 27 | 
            -
                    self.right_context = 0
         | 
| 28 | 
            -
                    self.subsampling_rate = 1
         | 
| 29 | 
            -
             | 
| 30 | 
            -
                def position_encoding(
         | 
| 31 | 
            -
                    self, offset: Union[int, torch.Tensor], size: int
         | 
| 32 | 
            -
                ) -> torch.Tensor:
         | 
| 33 | 
            -
                    return self.pos_enc.position_encoding(offset, size)
         | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
            class EmbedinigNoSubsampling(BaseSubsampling):
         | 
| 37 | 
            -
                """Embedding input without subsampling"""
         | 
| 38 | 
            -
             | 
| 39 | 
            -
                def __init__(
         | 
| 40 | 
            -
                    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
         | 
| 41 | 
            -
                ):
         | 
| 42 | 
            -
                    super().__init__()
         | 
| 43 | 
            -
                    self.embed = torch.nn.Embedding(idim, odim)
         | 
| 44 | 
            -
                    self.pos_enc = pos_enc_class
         | 
| 45 | 
            -
             | 
| 46 | 
            -
                def forward(
         | 
| 47 | 
            -
                    self,
         | 
| 48 | 
            -
                    x: torch.Tensor,
         | 
| 49 | 
            -
                    x_mask: torch.Tensor,
         | 
| 50 | 
            -
                    offset: Union[int, torch.Tensor] = 0,
         | 
| 51 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 52 | 
            -
                    """Input x.
         | 
| 53 | 
            -
             | 
| 54 | 
            -
                    Args:
         | 
| 55 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, idim).
         | 
| 56 | 
            -
                        x_mask (torch.Tensor): Input mask (#batch, 1, time).
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                    Returns:
         | 
| 59 | 
            -
                        torch.Tensor: linear input tensor (#batch, time', odim),
         | 
| 60 | 
            -
                            where time' = time .
         | 
| 61 | 
            -
                        torch.Tensor: linear input mask (#batch, 1, time'),
         | 
| 62 | 
            -
                            where time' = time .
         | 
| 63 | 
            -
             | 
| 64 | 
            -
                    """
         | 
| 65 | 
            -
                    x = self.embed(x)
         | 
| 66 | 
            -
                    x, pos_emb = self.pos_enc(x, offset)
         | 
| 67 | 
            -
                    return x, pos_emb, x_mask
         | 
| 68 | 
            -
             | 
| 69 | 
            -
             | 
| 70 | 
            -
            class LinearNoSubsampling(BaseSubsampling):
         | 
| 71 | 
            -
                """Linear transform the input without subsampling
         | 
| 72 | 
            -
             | 
| 73 | 
            -
                Args:
         | 
| 74 | 
            -
                    idim (int): Input dimension.
         | 
| 75 | 
            -
                    odim (int): Output dimension.
         | 
| 76 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                """
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                def __init__(
         | 
| 81 | 
            -
                    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
         | 
| 82 | 
            -
                ):
         | 
| 83 | 
            -
                    """Construct an linear object."""
         | 
| 84 | 
            -
                    super().__init__()
         | 
| 85 | 
            -
                    self.out = torch.nn.Sequential(
         | 
| 86 | 
            -
                        torch.nn.Linear(idim, odim),
         | 
| 87 | 
            -
                        torch.nn.LayerNorm(odim, eps=1e-5),
         | 
| 88 | 
            -
                        torch.nn.Dropout(dropout_rate),
         | 
| 89 | 
            -
                    )
         | 
| 90 | 
            -
                    self.pos_enc = pos_enc_class
         | 
| 91 | 
            -
                    self.right_context = 0
         | 
| 92 | 
            -
                    self.subsampling_rate = 1
         | 
| 93 | 
            -
             | 
| 94 | 
            -
                def forward(
         | 
| 95 | 
            -
                    self,
         | 
| 96 | 
            -
                    x: torch.Tensor,
         | 
| 97 | 
            -
                    x_mask: torch.Tensor,
         | 
| 98 | 
            -
                    offset: Union[int, torch.Tensor] = 0,
         | 
| 99 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 100 | 
            -
                    """Input x.
         | 
| 101 | 
            -
             | 
| 102 | 
            -
                    Args:
         | 
| 103 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, idim).
         | 
| 104 | 
            -
                        x_mask (torch.Tensor): Input mask (#batch, 1, time).
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                    Returns:
         | 
| 107 | 
            -
                        torch.Tensor: linear input tensor (#batch, time', odim),
         | 
| 108 | 
            -
                            where time' = time .
         | 
| 109 | 
            -
                        torch.Tensor: linear input mask (#batch, 1, time'),
         | 
| 110 | 
            -
                            where time' = time .
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                    """
         | 
| 113 | 
            -
                    x = self.out(x)
         | 
| 114 | 
            -
                    x, pos_emb = self.pos_enc(x, offset)
         | 
| 115 | 
            -
                    return x, pos_emb, x_mask
         | 
| 116 | 
            -
             | 
| 117 | 
            -
             | 
| 118 | 
            -
            class Conv1dSubsampling2(BaseSubsampling):
         | 
| 119 | 
            -
                """Convolutional 1D subsampling (to 1/2 length).
         | 
| 120 | 
            -
                   It is designed for Whisper, ref:
         | 
| 121 | 
            -
                   https://github.com/openai/whisper/blob/main/whisper/model.py
         | 
| 122 | 
            -
             | 
| 123 | 
            -
                Args:
         | 
| 124 | 
            -
                    idim (int): Input dimension.
         | 
| 125 | 
            -
                    odim (int): Output dimension.
         | 
| 126 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 127 | 
            -
             | 
| 128 | 
            -
                """
         | 
| 129 | 
            -
             | 
| 130 | 
            -
                def __init__(
         | 
| 131 | 
            -
                    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
         | 
| 132 | 
            -
                ):
         | 
| 133 | 
            -
                    """Construct an Conv1dSubsampling2 object."""
         | 
| 134 | 
            -
                    super().__init__()
         | 
| 135 | 
            -
                    self.conv = torch.nn.Sequential(
         | 
| 136 | 
            -
                        torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
         | 
| 137 | 
            -
                        torch.nn.GELU(),
         | 
| 138 | 
            -
                        torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
         | 
| 139 | 
            -
                        torch.nn.GELU(),
         | 
| 140 | 
            -
                    )
         | 
| 141 | 
            -
                    self.pos_enc = pos_enc_class
         | 
| 142 | 
            -
                    # The right context for every conv layer is computed by:
         | 
| 143 | 
            -
                    # (kernel_size - 1) * frame_rate_of_this_layer
         | 
| 144 | 
            -
                    self.subsampling_rate = 2
         | 
| 145 | 
            -
                    # 4 = (3 - 1) * 1 + (3 - 1) * 1
         | 
| 146 | 
            -
                    self.right_context = 4
         | 
| 147 | 
            -
             | 
| 148 | 
            -
                def forward(
         | 
| 149 | 
            -
                    self,
         | 
| 150 | 
            -
                    x: torch.Tensor,
         | 
| 151 | 
            -
                    x_mask: torch.Tensor,
         | 
| 152 | 
            -
                    offset: Union[int, torch.Tensor] = 0,
         | 
| 153 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 154 | 
            -
                    """Subsample x.
         | 
| 155 | 
            -
             | 
| 156 | 
            -
                    Args:
         | 
| 157 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, idim).
         | 
| 158 | 
            -
                        x_mask (torch.Tensor): Input mask (#batch, 1, time).
         | 
| 159 | 
            -
             | 
| 160 | 
            -
                    Returns:
         | 
| 161 | 
            -
                        torch.Tensor: Subsampled tensor (#batch, time', odim),
         | 
| 162 | 
            -
                            where time' = time // 2.
         | 
| 163 | 
            -
                        torch.Tensor: Subsampled mask (#batch, 1, time'),
         | 
| 164 | 
            -
                            where time' = time // 2.
         | 
| 165 | 
            -
                        torch.Tensor: positional encoding
         | 
| 166 | 
            -
             | 
| 167 | 
            -
                    """
         | 
| 168 | 
            -
                    time = x.size(1)
         | 
| 169 | 
            -
                    x = x.transpose(1, 2)  # (b, f, t)
         | 
| 170 | 
            -
                    x = self.conv(x)
         | 
| 171 | 
            -
                    x = x.transpose(1, 2)  # (b, t, f)
         | 
| 172 | 
            -
                    x, pos_emb = self.pos_enc(x, offset)
         | 
| 173 | 
            -
                    return x, pos_emb, x_mask[:, :, (time + 1) % 2 :: 2]
         | 
| 174 | 
            -
             | 
| 175 | 
            -
             | 
| 176 | 
            -
            class Conv2dSubsampling4(BaseSubsampling):
         | 
| 177 | 
            -
                """Convolutional 2D subsampling (to 1/4 length).
         | 
| 178 | 
            -
             | 
| 179 | 
            -
                Args:
         | 
| 180 | 
            -
                    idim (int): Input dimension.
         | 
| 181 | 
            -
                    odim (int): Output dimension.
         | 
| 182 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 183 | 
            -
             | 
| 184 | 
            -
                """
         | 
| 185 | 
            -
             | 
| 186 | 
            -
                def __init__(
         | 
| 187 | 
            -
                    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
         | 
| 188 | 
            -
                ):
         | 
| 189 | 
            -
                    """Construct an Conv2dSubsampling4 object."""
         | 
| 190 | 
            -
                    super().__init__()
         | 
| 191 | 
            -
                    self.conv = torch.nn.Sequential(
         | 
| 192 | 
            -
                        torch.nn.Conv2d(1, odim, 3, 2),
         | 
| 193 | 
            -
                        torch.nn.ReLU(),
         | 
| 194 | 
            -
                        torch.nn.Conv2d(odim, odim, 3, 2),
         | 
| 195 | 
            -
                        torch.nn.ReLU(),
         | 
| 196 | 
            -
                    )
         | 
| 197 | 
            -
                    self.out = torch.nn.Sequential(
         | 
| 198 | 
            -
                        torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
         | 
| 199 | 
            -
                    )
         | 
| 200 | 
            -
                    self.pos_enc = pos_enc_class
         | 
| 201 | 
            -
                    # The right context for every conv layer is computed by:
         | 
| 202 | 
            -
                    # (kernel_size - 1) * frame_rate_of_this_layer
         | 
| 203 | 
            -
                    self.subsampling_rate = 4
         | 
| 204 | 
            -
                    # 6 = (3 - 1) * 1 + (3 - 1) * 2
         | 
| 205 | 
            -
                    self.right_context = 6
         | 
| 206 | 
            -
             | 
| 207 | 
            -
                def forward(
         | 
| 208 | 
            -
                    self,
         | 
| 209 | 
            -
                    x: torch.Tensor,
         | 
| 210 | 
            -
                    x_mask: torch.Tensor,
         | 
| 211 | 
            -
                    offset: Union[int, torch.Tensor] = 0,
         | 
| 212 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 213 | 
            -
                    """Subsample x.
         | 
| 214 | 
            -
             | 
| 215 | 
            -
                    Args:
         | 
| 216 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, idim).
         | 
| 217 | 
            -
                        x_mask (torch.Tensor): Input mask (#batch, 1, time).
         | 
| 218 | 
            -
             | 
| 219 | 
            -
                    Returns:
         | 
| 220 | 
            -
                        torch.Tensor: Subsampled tensor (#batch, time', odim),
         | 
| 221 | 
            -
                            where time' = time // 4.
         | 
| 222 | 
            -
                        torch.Tensor: Subsampled mask (#batch, 1, time'),
         | 
| 223 | 
            -
                            where time' = time // 4.
         | 
| 224 | 
            -
                        torch.Tensor: positional encoding
         | 
| 225 | 
            -
             | 
| 226 | 
            -
                    """
         | 
| 227 | 
            -
                    x = x.unsqueeze(1)  # (b, c=1, t, f)
         | 
| 228 | 
            -
                    x = self.conv(x)
         | 
| 229 | 
            -
                    b, c, t, f = x.size()
         | 
| 230 | 
            -
                    x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         | 
| 231 | 
            -
                    x, pos_emb = self.pos_enc(x, offset)
         | 
| 232 | 
            -
                    return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
         | 
| 233 | 
            -
             | 
| 234 | 
            -
             | 
| 235 | 
            -
            class Conv2dSubsampling6(BaseSubsampling):
         | 
| 236 | 
            -
                """Convolutional 2D subsampling (to 1/6 length).
         | 
| 237 | 
            -
                Args:
         | 
| 238 | 
            -
                    idim (int): Input dimension.
         | 
| 239 | 
            -
                    odim (int): Output dimension.
         | 
| 240 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 241 | 
            -
                    pos_enc (torch.nn.Module): Custom position encoding layer.
         | 
| 242 | 
            -
                """
         | 
| 243 | 
            -
             | 
| 244 | 
            -
                def __init__(
         | 
| 245 | 
            -
                    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
         | 
| 246 | 
            -
                ):
         | 
| 247 | 
            -
                    """Construct an Conv2dSubsampling6 object."""
         | 
| 248 | 
            -
                    super().__init__()
         | 
| 249 | 
            -
                    self.conv = torch.nn.Sequential(
         | 
| 250 | 
            -
                        torch.nn.Conv2d(1, odim, 3, 2),
         | 
| 251 | 
            -
                        torch.nn.ReLU(),
         | 
| 252 | 
            -
                        torch.nn.Conv2d(odim, odim, 5, 3),
         | 
| 253 | 
            -
                        torch.nn.ReLU(),
         | 
| 254 | 
            -
                    )
         | 
| 255 | 
            -
                    self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
         | 
| 256 | 
            -
                    self.pos_enc = pos_enc_class
         | 
| 257 | 
            -
                    # 10 = (3 - 1) * 1 + (5 - 1) * 2
         | 
| 258 | 
            -
                    self.subsampling_rate = 6
         | 
| 259 | 
            -
                    self.right_context = 10
         | 
| 260 | 
            -
             | 
| 261 | 
            -
                def forward(
         | 
| 262 | 
            -
                    self,
         | 
| 263 | 
            -
                    x: torch.Tensor,
         | 
| 264 | 
            -
                    x_mask: torch.Tensor,
         | 
| 265 | 
            -
                    offset: Union[int, torch.Tensor] = 0,
         | 
| 266 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 267 | 
            -
                    """Subsample x.
         | 
| 268 | 
            -
                    Args:
         | 
| 269 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, idim).
         | 
| 270 | 
            -
                        x_mask (torch.Tensor): Input mask (#batch, 1, time).
         | 
| 271 | 
            -
             | 
| 272 | 
            -
                    Returns:
         | 
| 273 | 
            -
                        torch.Tensor: Subsampled tensor (#batch, time', odim),
         | 
| 274 | 
            -
                            where time' = time // 6.
         | 
| 275 | 
            -
                        torch.Tensor: Subsampled mask (#batch, 1, time'),
         | 
| 276 | 
            -
                            where time' = time // 6.
         | 
| 277 | 
            -
                        torch.Tensor: positional encoding
         | 
| 278 | 
            -
                    """
         | 
| 279 | 
            -
                    x = x.unsqueeze(1)  # (b, c, t, f)
         | 
| 280 | 
            -
                    x = self.conv(x)
         | 
| 281 | 
            -
                    b, c, t, f = x.size()
         | 
| 282 | 
            -
                    x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
         | 
| 283 | 
            -
                    x, pos_emb = self.pos_enc(x, offset)
         | 
| 284 | 
            -
                    return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
         | 
| 285 | 
            -
             | 
| 286 | 
            -
             | 
| 287 | 
            -
            class Conv2dSubsampling8(BaseSubsampling):
         | 
| 288 | 
            -
                """Convolutional 2D subsampling (to 1/8 length).
         | 
| 289 | 
            -
             | 
| 290 | 
            -
                Args:
         | 
| 291 | 
            -
                    idim (int): Input dimension.
         | 
| 292 | 
            -
                    odim (int): Output dimension.
         | 
| 293 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 294 | 
            -
             | 
| 295 | 
            -
                """
         | 
| 296 | 
            -
             | 
| 297 | 
            -
                def __init__(
         | 
| 298 | 
            -
                    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
         | 
| 299 | 
            -
                ):
         | 
| 300 | 
            -
                    """Construct an Conv2dSubsampling8 object."""
         | 
| 301 | 
            -
                    super().__init__()
         | 
| 302 | 
            -
                    self.conv = torch.nn.Sequential(
         | 
| 303 | 
            -
                        torch.nn.Conv2d(1, odim, 3, 2),
         | 
| 304 | 
            -
                        torch.nn.ReLU(),
         | 
| 305 | 
            -
                        torch.nn.Conv2d(odim, odim, 3, 2),
         | 
| 306 | 
            -
                        torch.nn.ReLU(),
         | 
| 307 | 
            -
                        torch.nn.Conv2d(odim, odim, 3, 2),
         | 
| 308 | 
            -
                        torch.nn.ReLU(),
         | 
| 309 | 
            -
                    )
         | 
| 310 | 
            -
                    self.linear = torch.nn.Linear(
         | 
| 311 | 
            -
                        odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim
         | 
| 312 | 
            -
                    )
         | 
| 313 | 
            -
                    self.pos_enc = pos_enc_class
         | 
| 314 | 
            -
                    self.subsampling_rate = 8
         | 
| 315 | 
            -
                    # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
         | 
| 316 | 
            -
                    self.right_context = 14
         | 
| 317 | 
            -
             | 
| 318 | 
            -
                def forward(
         | 
| 319 | 
            -
                    self,
         | 
| 320 | 
            -
                    x: torch.Tensor,
         | 
| 321 | 
            -
                    x_mask: torch.Tensor,
         | 
| 322 | 
            -
                    offset: Union[int, torch.Tensor] = 0,
         | 
| 323 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 324 | 
            -
                    """Subsample x.
         | 
| 325 | 
            -
             | 
| 326 | 
            -
                    Args:
         | 
| 327 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, idim).
         | 
| 328 | 
            -
                        x_mask (torch.Tensor): Input mask (#batch, 1, time).
         | 
| 329 | 
            -
             | 
| 330 | 
            -
                    Returns:
         | 
| 331 | 
            -
                        torch.Tensor: Subsampled tensor (#batch, time', odim),
         | 
| 332 | 
            -
                            where time' = time // 8.
         | 
| 333 | 
            -
                        torch.Tensor: Subsampled mask (#batch, 1, time'),
         | 
| 334 | 
            -
                            where time' = time // 8.
         | 
| 335 | 
            -
                        torch.Tensor: positional encoding
         | 
| 336 | 
            -
                    """
         | 
| 337 | 
            -
                    x = x.unsqueeze(1)  # (b, c, t, f)
         | 
| 338 | 
            -
                    x = self.conv(x)
         | 
| 339 | 
            -
                    b, c, t, f = x.size()
         | 
| 340 | 
            -
                    x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
         | 
| 341 | 
            -
                    x, pos_emb = self.pos_enc(x, offset)
         | 
| 342 | 
            -
                    return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
         | 
| 343 | 
            -
             | 
| 344 | 
            -
             | 
| 345 | 
            -
            class LegacyLinearNoSubsampling(BaseSubsampling):
         | 
| 346 | 
            -
                """Linear transform the input without subsampling
         | 
| 347 | 
            -
             | 
| 348 | 
            -
                Args:
         | 
| 349 | 
            -
                    idim (int): Input dimension.
         | 
| 350 | 
            -
                    odim (int): Output dimension.
         | 
| 351 | 
            -
                    dropout_rate (float): Dropout rate.
         | 
| 352 | 
            -
             | 
| 353 | 
            -
                """
         | 
| 354 | 
            -
             | 
| 355 | 
            -
                def __init__(
         | 
| 356 | 
            -
                    self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
         | 
| 357 | 
            -
                ):
         | 
| 358 | 
            -
                    """Construct an linear object."""
         | 
| 359 | 
            -
                    super().__init__()
         | 
| 360 | 
            -
                    self.out = torch.nn.Sequential(
         | 
| 361 | 
            -
                        torch.nn.Linear(idim, odim),
         | 
| 362 | 
            -
                        torch.nn.LayerNorm(odim, eps=1e-5),
         | 
| 363 | 
            -
                        torch.nn.Dropout(dropout_rate),
         | 
| 364 | 
            -
                        torch.nn.ReLU(),
         | 
| 365 | 
            -
                    )
         | 
| 366 | 
            -
                    self.pos_enc = pos_enc_class
         | 
| 367 | 
            -
                    self.right_context = 0
         | 
| 368 | 
            -
                    self.subsampling_rate = 1
         | 
| 369 | 
            -
             | 
| 370 | 
            -
                def forward(
         | 
| 371 | 
            -
                    self,
         | 
| 372 | 
            -
                    x: torch.Tensor,
         | 
| 373 | 
            -
                    x_mask: torch.Tensor,
         | 
| 374 | 
            -
                    offset: Union[int, torch.Tensor] = 0,
         | 
| 375 | 
            -
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 376 | 
            -
                    """Input x.
         | 
| 377 | 
            -
             | 
| 378 | 
            -
                    Args:
         | 
| 379 | 
            -
                        x (torch.Tensor): Input tensor (#batch, time, idim).
         | 
| 380 | 
            -
                        x_mask (torch.Tensor): Input mask (#batch, 1, time).
         | 
| 381 | 
            -
             | 
| 382 | 
            -
                    Returns:
         | 
| 383 | 
            -
                        torch.Tensor: linear input tensor (#batch, time', odim),
         | 
| 384 | 
            -
                            where time' = time .
         | 
| 385 | 
            -
                        torch.Tensor: linear input mask (#batch, 1, time'),
         | 
| 386 | 
            -
                            where time' = time .
         | 
| 387 | 
            -
             | 
| 388 | 
            -
                    """
         | 
| 389 | 
            -
                    x = self.out(x)
         | 
| 390 | 
            -
                    x, pos_emb = self.pos_enc(x, offset)
         | 
| 391 | 
            -
                    return x, pos_emb, x_mask
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/__init__.py
    DELETED
    
    | 
            File without changes
         | 
    	
        cosyvoice/utils/audio.py
    DELETED
    
    | @@ -1,90 +0,0 @@ | |
| 1 | 
            -
            import numpy as np
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import torch.utils.data
         | 
| 4 | 
            -
            from librosa.filters import mel as librosa_mel_fn
         | 
| 5 | 
            -
            from scipy.io.wavfile import read
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            MAX_WAV_VALUE = 32768.0
         | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
            def load_wav(full_path):
         | 
| 11 | 
            -
                sampling_rate, data = read(full_path)
         | 
| 12 | 
            -
                return data, sampling_rate
         | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
            def dynamic_range_compression(x, C=1, clip_val=1e-5):
         | 
| 16 | 
            -
                return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
            def dynamic_range_decompression(x, C=1):
         | 
| 20 | 
            -
                return np.exp(x) / C
         | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
            def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
         | 
| 24 | 
            -
                return torch.log(torch.clamp(x, min=clip_val) * C)
         | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
            def dynamic_range_decompression_torch(x, C=1):
         | 
| 28 | 
            -
                return torch.exp(x) / C
         | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
            def spectral_normalize_torch(magnitudes):
         | 
| 32 | 
            -
                output = dynamic_range_compression_torch(magnitudes)
         | 
| 33 | 
            -
                return output
         | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
            def spectral_de_normalize_torch(magnitudes):
         | 
| 37 | 
            -
                output = dynamic_range_decompression_torch(magnitudes)
         | 
| 38 | 
            -
                return output
         | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
            mel_basis = {}
         | 
| 42 | 
            -
            hann_window = {}
         | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
            def mel_spectrogram(
         | 
| 46 | 
            -
                y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
         | 
| 47 | 
            -
            ):
         | 
| 48 | 
            -
                # if torch.min(y) < -1.0:
         | 
| 49 | 
            -
                #     print("min value is ", torch.min(y))
         | 
| 50 | 
            -
                # if torch.max(y) > 1.0:
         | 
| 51 | 
            -
                #     print("max value is ", torch.max(y))
         | 
| 52 | 
            -
             | 
| 53 | 
            -
                global mel_basis, hann_window  # pylint: disable=global-statement
         | 
| 54 | 
            -
                if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
         | 
| 55 | 
            -
                    mel = librosa_mel_fn(
         | 
| 56 | 
            -
                        sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
         | 
| 57 | 
            -
                    )
         | 
| 58 | 
            -
                    mel_basis[str(fmax) + "_" + str(y.device)] = (
         | 
| 59 | 
            -
                        torch.from_numpy(mel).float().to(y.device)
         | 
| 60 | 
            -
                    )
         | 
| 61 | 
            -
                    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                y = torch.nn.functional.pad(
         | 
| 64 | 
            -
                    y.unsqueeze(1),
         | 
| 65 | 
            -
                    (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
         | 
| 66 | 
            -
                    mode="reflect",
         | 
| 67 | 
            -
                )
         | 
| 68 | 
            -
                y = y.squeeze(1)
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                spec = torch.view_as_real(
         | 
| 71 | 
            -
                    torch.stft(
         | 
| 72 | 
            -
                        y,
         | 
| 73 | 
            -
                        n_fft,
         | 
| 74 | 
            -
                        hop_length=hop_size,
         | 
| 75 | 
            -
                        win_length=win_size,
         | 
| 76 | 
            -
                        window=hann_window[str(y.device)],
         | 
| 77 | 
            -
                        center=center,
         | 
| 78 | 
            -
                        pad_mode="reflect",
         | 
| 79 | 
            -
                        normalized=False,
         | 
| 80 | 
            -
                        onesided=True,
         | 
| 81 | 
            -
                        return_complex=True,
         | 
| 82 | 
            -
                    )
         | 
| 83 | 
            -
                )
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
         | 
| 88 | 
            -
                spec = spectral_normalize_torch(spec)
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                return spec
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/class_utils.py
    DELETED
    
    | @@ -1,78 +0,0 @@ | |
| 1 | 
            -
            # Copyright [2023-11-28] <[email protected], Xingchen Song>
         | 
| 2 | 
            -
            #            2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            import torch
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            from cosyvoice.transformer.activation import Swish
         | 
| 18 | 
            -
            from cosyvoice.transformer.subsampling import (
         | 
| 19 | 
            -
                LinearNoSubsampling,
         | 
| 20 | 
            -
                EmbedinigNoSubsampling,
         | 
| 21 | 
            -
                Conv1dSubsampling2,
         | 
| 22 | 
            -
                Conv2dSubsampling4,
         | 
| 23 | 
            -
                Conv2dSubsampling6,
         | 
| 24 | 
            -
                Conv2dSubsampling8,
         | 
| 25 | 
            -
            )
         | 
| 26 | 
            -
            from cosyvoice.transformer.embedding import (
         | 
| 27 | 
            -
                PositionalEncoding,
         | 
| 28 | 
            -
                RelPositionalEncoding,
         | 
| 29 | 
            -
                WhisperPositionalEncoding,
         | 
| 30 | 
            -
                LearnablePositionalEncoding,
         | 
| 31 | 
            -
                NoPositionalEncoding,
         | 
| 32 | 
            -
            )
         | 
| 33 | 
            -
            from cosyvoice.transformer.attention import (
         | 
| 34 | 
            -
                MultiHeadedAttention,
         | 
| 35 | 
            -
                RelPositionMultiHeadedAttention,
         | 
| 36 | 
            -
            )
         | 
| 37 | 
            -
            from cosyvoice.transformer.embedding import (
         | 
| 38 | 
            -
                EspnetRelPositionalEncoding,
         | 
| 39 | 
            -
            )
         | 
| 40 | 
            -
            from cosyvoice.transformer.subsampling import (
         | 
| 41 | 
            -
                LegacyLinearNoSubsampling,
         | 
| 42 | 
            -
            )
         | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
            COSYVOICE_ACTIVATION_CLASSES = {
         | 
| 46 | 
            -
                "hardtanh": torch.nn.Hardtanh,
         | 
| 47 | 
            -
                "tanh": torch.nn.Tanh,
         | 
| 48 | 
            -
                "relu": torch.nn.ReLU,
         | 
| 49 | 
            -
                "selu": torch.nn.SELU,
         | 
| 50 | 
            -
                "swish": getattr(torch.nn, "SiLU", Swish),
         | 
| 51 | 
            -
                "gelu": torch.nn.GELU,
         | 
| 52 | 
            -
            }
         | 
| 53 | 
            -
             | 
| 54 | 
            -
            COSYVOICE_SUBSAMPLE_CLASSES = {
         | 
| 55 | 
            -
                "linear": LinearNoSubsampling,
         | 
| 56 | 
            -
                "linear_legacy": LegacyLinearNoSubsampling,
         | 
| 57 | 
            -
                "embed": EmbedinigNoSubsampling,
         | 
| 58 | 
            -
                "conv1d2": Conv1dSubsampling2,
         | 
| 59 | 
            -
                "conv2d": Conv2dSubsampling4,
         | 
| 60 | 
            -
                "conv2d6": Conv2dSubsampling6,
         | 
| 61 | 
            -
                "conv2d8": Conv2dSubsampling8,
         | 
| 62 | 
            -
                "paraformer_dummy": torch.nn.Identity,
         | 
| 63 | 
            -
            }
         | 
| 64 | 
            -
             | 
| 65 | 
            -
            COSYVOICE_EMB_CLASSES = {
         | 
| 66 | 
            -
                "embed": PositionalEncoding,
         | 
| 67 | 
            -
                "abs_pos": PositionalEncoding,
         | 
| 68 | 
            -
                "rel_pos": RelPositionalEncoding,
         | 
| 69 | 
            -
                "rel_pos_espnet": EspnetRelPositionalEncoding,
         | 
| 70 | 
            -
                "no_pos": NoPositionalEncoding,
         | 
| 71 | 
            -
                "abs_pos_whisper": WhisperPositionalEncoding,
         | 
| 72 | 
            -
                "embed_learnable_pe": LearnablePositionalEncoding,
         | 
| 73 | 
            -
            }
         | 
| 74 | 
            -
             | 
| 75 | 
            -
            COSYVOICE_ATTENTION_CLASSES = {
         | 
| 76 | 
            -
                "selfattn": MultiHeadedAttention,
         | 
| 77 | 
            -
                "rel_selfattn": RelPositionMultiHeadedAttention,
         | 
| 78 | 
            -
            }
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/common.py
    DELETED
    
    | @@ -1,169 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
         | 
| 2 | 
            -
            #               2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 16 | 
            -
            """Unility functions for Transformer."""
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            import random
         | 
| 19 | 
            -
            from typing import List
         | 
| 20 | 
            -
             | 
| 21 | 
            -
            import numpy as np
         | 
| 22 | 
            -
            import torch
         | 
| 23 | 
            -
             | 
| 24 | 
            -
            IGNORE_ID = -1
         | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
            def pad_list(xs: List[torch.Tensor], pad_value: int):
         | 
| 28 | 
            -
                """Perform padding for the list of tensors.
         | 
| 29 | 
            -
             | 
| 30 | 
            -
                Args:
         | 
| 31 | 
            -
                    xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
         | 
| 32 | 
            -
                    pad_value (float): Value for padding.
         | 
| 33 | 
            -
             | 
| 34 | 
            -
                Returns:
         | 
| 35 | 
            -
                    Tensor: Padded tensor (B, Tmax, `*`).
         | 
| 36 | 
            -
             | 
| 37 | 
            -
                Examples:
         | 
| 38 | 
            -
                    >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
         | 
| 39 | 
            -
                    >>> x
         | 
| 40 | 
            -
                    [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
         | 
| 41 | 
            -
                    >>> pad_list(x, 0)
         | 
| 42 | 
            -
                    tensor([[1., 1., 1., 1.],
         | 
| 43 | 
            -
                            [1., 1., 0., 0.],
         | 
| 44 | 
            -
                            [1., 0., 0., 0.]])
         | 
| 45 | 
            -
             | 
| 46 | 
            -
                """
         | 
| 47 | 
            -
                max_len = max([len(item) for item in xs])
         | 
| 48 | 
            -
                batchs = len(xs)
         | 
| 49 | 
            -
                ndim = xs[0].ndim
         | 
| 50 | 
            -
                if ndim == 1:
         | 
| 51 | 
            -
                    pad_res = torch.zeros(batchs, max_len, dtype=xs[0].dtype, device=xs[0].device)
         | 
| 52 | 
            -
                elif ndim == 2:
         | 
| 53 | 
            -
                    pad_res = torch.zeros(
         | 
| 54 | 
            -
                        batchs, max_len, xs[0].shape[1], dtype=xs[0].dtype, device=xs[0].device
         | 
| 55 | 
            -
                    )
         | 
| 56 | 
            -
                elif ndim == 3:
         | 
| 57 | 
            -
                    pad_res = torch.zeros(
         | 
| 58 | 
            -
                        batchs,
         | 
| 59 | 
            -
                        max_len,
         | 
| 60 | 
            -
                        xs[0].shape[1],
         | 
| 61 | 
            -
                        xs[0].shape[2],
         | 
| 62 | 
            -
                        dtype=xs[0].dtype,
         | 
| 63 | 
            -
                        device=xs[0].device,
         | 
| 64 | 
            -
                    )
         | 
| 65 | 
            -
                else:
         | 
| 66 | 
            -
                    raise ValueError(f"Unsupported ndim: {ndim}")
         | 
| 67 | 
            -
                pad_res.fill_(pad_value)
         | 
| 68 | 
            -
                for i in range(batchs):
         | 
| 69 | 
            -
                    pad_res[i, : len(xs[i])] = xs[i]
         | 
| 70 | 
            -
                return pad_res
         | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
            def th_accuracy(
         | 
| 74 | 
            -
                pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
         | 
| 75 | 
            -
            ) -> torch.Tensor:
         | 
| 76 | 
            -
                """Calculate accuracy.
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                Args:
         | 
| 79 | 
            -
                    pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
         | 
| 80 | 
            -
                    pad_targets (LongTensor): Target label tensors (B, Lmax).
         | 
| 81 | 
            -
                    ignore_label (int): Ignore label id.
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                Returns:
         | 
| 84 | 
            -
                    torch.Tensor: Accuracy value (0.0 - 1.0).
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                """
         | 
| 87 | 
            -
                pad_pred = pad_outputs.view(
         | 
| 88 | 
            -
                    pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
         | 
| 89 | 
            -
                ).argmax(2)
         | 
| 90 | 
            -
                mask = pad_targets != ignore_label
         | 
| 91 | 
            -
                numerator = torch.sum(
         | 
| 92 | 
            -
                    pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
         | 
| 93 | 
            -
                )
         | 
| 94 | 
            -
                denominator = torch.sum(mask)
         | 
| 95 | 
            -
                return (numerator / denominator).detach()
         | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
            def get_padding(kernel_size, dilation=1):
         | 
| 99 | 
            -
                return int((kernel_size * dilation - dilation) / 2)
         | 
| 100 | 
            -
             | 
| 101 | 
            -
             | 
| 102 | 
            -
            def init_weights(m, mean=0.0, std=0.01):
         | 
| 103 | 
            -
                classname = m.__class__.__name__
         | 
| 104 | 
            -
                if classname.find("Conv") != -1:
         | 
| 105 | 
            -
                    m.weight.data.normal_(mean, std)
         | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
            # Repetition Aware Sampling in VALL-E 2
         | 
| 109 | 
            -
            def ras_sampling(
         | 
| 110 | 
            -
                weighted_scores,
         | 
| 111 | 
            -
                decoded_tokens,
         | 
| 112 | 
            -
                sampling,
         | 
| 113 | 
            -
                top_p=0.8,
         | 
| 114 | 
            -
                top_k=25,
         | 
| 115 | 
            -
                win_size=10,
         | 
| 116 | 
            -
                tau_r=0.1,
         | 
| 117 | 
            -
            ):
         | 
| 118 | 
            -
                top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
         | 
| 119 | 
            -
                rep_num = (
         | 
| 120 | 
            -
                    (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids)
         | 
| 121 | 
            -
                    .sum()
         | 
| 122 | 
            -
                    .item()
         | 
| 123 | 
            -
                )
         | 
| 124 | 
            -
                if rep_num >= win_size * tau_r:
         | 
| 125 | 
            -
                    top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
         | 
| 126 | 
            -
                return top_ids
         | 
| 127 | 
            -
             | 
| 128 | 
            -
             | 
| 129 | 
            -
            def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
         | 
| 130 | 
            -
                prob, indices = [], []
         | 
| 131 | 
            -
                cum_prob = 0.0
         | 
| 132 | 
            -
                sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(
         | 
| 133 | 
            -
                    descending=True, stable=True
         | 
| 134 | 
            -
                )
         | 
| 135 | 
            -
                for i in range(len(sorted_idx)):
         | 
| 136 | 
            -
                    # sampling both top-p and numbers.
         | 
| 137 | 
            -
                    if cum_prob < top_p and len(prob) < top_k:
         | 
| 138 | 
            -
                        cum_prob += sorted_value[i]
         | 
| 139 | 
            -
                        prob.append(sorted_value[i])
         | 
| 140 | 
            -
                        indices.append(sorted_idx[i])
         | 
| 141 | 
            -
                    else:
         | 
| 142 | 
            -
                        break
         | 
| 143 | 
            -
                prob = torch.tensor(prob).to(weighted_scores)
         | 
| 144 | 
            -
                indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
         | 
| 145 | 
            -
                top_ids = indices[prob.multinomial(1, replacement=True)]
         | 
| 146 | 
            -
                return top_ids
         | 
| 147 | 
            -
             | 
| 148 | 
            -
             | 
| 149 | 
            -
            def random_sampling(weighted_scores, decoded_tokens, sampling):
         | 
| 150 | 
            -
                top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
         | 
| 151 | 
            -
                return top_ids
         | 
| 152 | 
            -
             | 
| 153 | 
            -
             | 
| 154 | 
            -
            def fade_in_out(fade_in_mel, fade_out_mel, window):
         | 
| 155 | 
            -
                device = fade_in_mel.device
         | 
| 156 | 
            -
                fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
         | 
| 157 | 
            -
                mel_overlap_len = int(window.shape[0] / 2)
         | 
| 158 | 
            -
                fade_in_mel[..., :mel_overlap_len] = (
         | 
| 159 | 
            -
                    fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len]
         | 
| 160 | 
            -
                    + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
         | 
| 161 | 
            -
                )
         | 
| 162 | 
            -
                return fade_in_mel.to(device)
         | 
| 163 | 
            -
             | 
| 164 | 
            -
             | 
| 165 | 
            -
            def set_all_random_seed(seed):
         | 
| 166 | 
            -
                random.seed(seed)
         | 
| 167 | 
            -
                np.random.seed(seed)
         | 
| 168 | 
            -
                torch.manual_seed(seed)
         | 
| 169 | 
            -
                torch.cuda.manual_seed_all(seed)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/executor.py
    DELETED
    
    | @@ -1,151 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
         | 
| 2 | 
            -
            #               2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
             | 
| 16 | 
            -
            import logging
         | 
| 17 | 
            -
            from contextlib import nullcontext
         | 
| 18 | 
            -
            import os
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            import torch
         | 
| 21 | 
            -
            import torch.distributed as dist
         | 
| 22 | 
            -
             | 
| 23 | 
            -
            from cosyvoice.utils.train_utils import (
         | 
| 24 | 
            -
                update_parameter_and_lr,
         | 
| 25 | 
            -
                log_per_step,
         | 
| 26 | 
            -
                log_per_save,
         | 
| 27 | 
            -
                batch_forward,
         | 
| 28 | 
            -
                batch_backward,
         | 
| 29 | 
            -
                save_model,
         | 
| 30 | 
            -
                cosyvoice_join,
         | 
| 31 | 
            -
            )
         | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
            class Executor:
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                def __init__(self):
         | 
| 37 | 
            -
                    self.step = 0
         | 
| 38 | 
            -
                    self.epoch = 0
         | 
| 39 | 
            -
                    self.rank = int(os.environ.get("RANK", 0))
         | 
| 40 | 
            -
                    self.device = torch.device("cuda:{}".format(self.rank))
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                def train_one_epoc(
         | 
| 43 | 
            -
                    self,
         | 
| 44 | 
            -
                    model,
         | 
| 45 | 
            -
                    optimizer,
         | 
| 46 | 
            -
                    scheduler,
         | 
| 47 | 
            -
                    train_data_loader,
         | 
| 48 | 
            -
                    cv_data_loader,
         | 
| 49 | 
            -
                    writer,
         | 
| 50 | 
            -
                    info_dict,
         | 
| 51 | 
            -
                    group_join,
         | 
| 52 | 
            -
                ):
         | 
| 53 | 
            -
                    """Train one epoch"""
         | 
| 54 | 
            -
             | 
| 55 | 
            -
                    lr = optimizer.param_groups[0]["lr"]
         | 
| 56 | 
            -
                    logging.info(
         | 
| 57 | 
            -
                        "Epoch {} TRAIN info lr {} rank {}".format(self.epoch, lr, self.rank)
         | 
| 58 | 
            -
                    )
         | 
| 59 | 
            -
                    logging.info(
         | 
| 60 | 
            -
                        "using accumulate grad, new batch size is {} times"
         | 
| 61 | 
            -
                        " larger than before".format(info_dict["accum_grad"])
         | 
| 62 | 
            -
                    )
         | 
| 63 | 
            -
                    # A context manager to be used in conjunction with an instance of
         | 
| 64 | 
            -
                    # torch.nn.parallel.DistributedDataParallel to be able to train
         | 
| 65 | 
            -
                    # with uneven inputs across participating processes.
         | 
| 66 | 
            -
                    model.train()
         | 
| 67 | 
            -
                    model_context = (
         | 
| 68 | 
            -
                        model.join if info_dict["train_engine"] == "torch_ddp" else nullcontext
         | 
| 69 | 
            -
                    )
         | 
| 70 | 
            -
                    with model_context():
         | 
| 71 | 
            -
                        for batch_idx, batch_dict in enumerate(train_data_loader):
         | 
| 72 | 
            -
                            info_dict["tag"] = "TRAIN"
         | 
| 73 | 
            -
                            info_dict["step"] = self.step
         | 
| 74 | 
            -
                            info_dict["epoch"] = self.epoch
         | 
| 75 | 
            -
                            info_dict["batch_idx"] = batch_idx
         | 
| 76 | 
            -
                            if cosyvoice_join(group_join, info_dict):
         | 
| 77 | 
            -
                                break
         | 
| 78 | 
            -
             | 
| 79 | 
            -
                            # Disable gradient synchronizations across DDP processes.
         | 
| 80 | 
            -
                            # Within this context, gradients will be accumulated on module
         | 
| 81 | 
            -
                            # variables, which will later be synchronized.
         | 
| 82 | 
            -
                            if (
         | 
| 83 | 
            -
                                info_dict["train_engine"] == "torch_ddp"
         | 
| 84 | 
            -
                                and (batch_idx + 1) % info_dict["accum_grad"] != 0
         | 
| 85 | 
            -
                            ):
         | 
| 86 | 
            -
                                context = model.no_sync
         | 
| 87 | 
            -
                            # Used for single gpu training and DDP gradient synchronization
         | 
| 88 | 
            -
                            # processes.
         | 
| 89 | 
            -
                            else:
         | 
| 90 | 
            -
                                context = nullcontext
         | 
| 91 | 
            -
             | 
| 92 | 
            -
                            with context():
         | 
| 93 | 
            -
                                info_dict = batch_forward(model, batch_dict, info_dict)
         | 
| 94 | 
            -
                                info_dict = batch_backward(model, info_dict)
         | 
| 95 | 
            -
             | 
| 96 | 
            -
                            info_dict = update_parameter_and_lr(
         | 
| 97 | 
            -
                                model, optimizer, scheduler, info_dict
         | 
| 98 | 
            -
                            )
         | 
| 99 | 
            -
                            log_per_step(writer, info_dict)
         | 
| 100 | 
            -
                            # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
         | 
| 101 | 
            -
                            if (
         | 
| 102 | 
            -
                                info_dict["save_per_step"] > 0
         | 
| 103 | 
            -
                                and (self.step + 1) % info_dict["save_per_step"] == 0
         | 
| 104 | 
            -
                                and (batch_idx + 1) % info_dict["accum_grad"] == 0
         | 
| 105 | 
            -
                            ):
         | 
| 106 | 
            -
                                dist.barrier()
         | 
| 107 | 
            -
                                self.cv(
         | 
| 108 | 
            -
                                    model, cv_data_loader, writer, info_dict, on_batch_end=False
         | 
| 109 | 
            -
                                )
         | 
| 110 | 
            -
                                model.train()
         | 
| 111 | 
            -
                            if (batch_idx + 1) % info_dict["accum_grad"] == 0:
         | 
| 112 | 
            -
                                self.step += 1
         | 
| 113 | 
            -
                    dist.barrier()
         | 
| 114 | 
            -
                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
         | 
| 115 | 
            -
             | 
| 116 | 
            -
                @torch.inference_mode()
         | 
| 117 | 
            -
                def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
         | 
| 118 | 
            -
                    """Cross validation on"""
         | 
| 119 | 
            -
                    logging.info(
         | 
| 120 | 
            -
                        "Epoch {} Step {} on_batch_end {} CV rank {}".format(
         | 
| 121 | 
            -
                            self.epoch, self.step + 1, on_batch_end, self.rank
         | 
| 122 | 
            -
                        )
         | 
| 123 | 
            -
                    )
         | 
| 124 | 
            -
                    model.eval()
         | 
| 125 | 
            -
                    total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
         | 
| 126 | 
            -
                    for batch_idx, batch_dict in enumerate(cv_data_loader):
         | 
| 127 | 
            -
                        info_dict["tag"] = "CV"
         | 
| 128 | 
            -
                        info_dict["step"] = self.step
         | 
| 129 | 
            -
                        info_dict["epoch"] = self.epoch
         | 
| 130 | 
            -
                        info_dict["batch_idx"] = batch_idx
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                        num_utts = len(batch_dict["utts"])
         | 
| 133 | 
            -
                        total_num_utts += num_utts
         | 
| 134 | 
            -
             | 
| 135 | 
            -
                        info_dict = batch_forward(model, batch_dict, info_dict)
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                        for k, v in info_dict["loss_dict"].items():
         | 
| 138 | 
            -
                            if k not in total_loss_dict:
         | 
| 139 | 
            -
                                total_loss_dict[k] = []
         | 
| 140 | 
            -
                            total_loss_dict[k].append(v.item() * num_utts)
         | 
| 141 | 
            -
                        log_per_step(None, info_dict)
         | 
| 142 | 
            -
                    for k, v in total_loss_dict.items():
         | 
| 143 | 
            -
                        total_loss_dict[k] = sum(v) / total_num_utts
         | 
| 144 | 
            -
                    info_dict["loss_dict"] = total_loss_dict
         | 
| 145 | 
            -
                    log_per_save(writer, info_dict)
         | 
| 146 | 
            -
                    model_name = (
         | 
| 147 | 
            -
                        "epoch_{}_whole".format(self.epoch)
         | 
| 148 | 
            -
                        if on_batch_end
         | 
| 149 | 
            -
                        else "epoch_{}_step_{}".format(self.epoch, self.step + 1)
         | 
| 150 | 
            -
                    )
         | 
| 151 | 
            -
                    save_model(model, model_name, info_dict)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/file_utils.py
    DELETED
    
    | @@ -1,49 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
         | 
| 2 | 
            -
            #               2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 3 | 
            -
            #
         | 
| 4 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            -
            # You may obtain a copy of the License at
         | 
| 7 | 
            -
            #
         | 
| 8 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            -
            #
         | 
| 10 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            -
            # limitations under the License.
         | 
| 15 | 
            -
             | 
| 16 | 
            -
            import json
         | 
| 17 | 
            -
            import torchaudio
         | 
| 18 | 
            -
            import logging
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            logging.getLogger("matplotlib").setLevel(logging.WARNING)
         | 
| 21 | 
            -
            logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")
         | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
            def read_lists(list_file):
         | 
| 25 | 
            -
                lists = []
         | 
| 26 | 
            -
                with open(list_file, "r", encoding="utf8") as fin:
         | 
| 27 | 
            -
                    for line in fin:
         | 
| 28 | 
            -
                        lists.append(line.strip())
         | 
| 29 | 
            -
                return lists
         | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
            def read_json_lists(list_file):
         | 
| 33 | 
            -
                lists = read_lists(list_file)
         | 
| 34 | 
            -
                results = {}
         | 
| 35 | 
            -
                for fn in lists:
         | 
| 36 | 
            -
                    with open(fn, "r", encoding="utf8") as fin:
         | 
| 37 | 
            -
                        results.update(json.load(fin))
         | 
| 38 | 
            -
                return results
         | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
            def load_wav(wav, target_sr):
         | 
| 42 | 
            -
                speech, sample_rate = torchaudio.load(wav)
         | 
| 43 | 
            -
                speech = speech.mean(dim=0, keepdim=True)
         | 
| 44 | 
            -
                if sample_rate != target_sr:
         | 
| 45 | 
            -
                    # assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
         | 
| 46 | 
            -
                    speech = torchaudio.transforms.Resample(
         | 
| 47 | 
            -
                        orig_freq=sample_rate, new_freq=target_sr
         | 
| 48 | 
            -
                    )(speech)
         | 
| 49 | 
            -
                return speech
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/frontend_utils.py
    DELETED
    
    | @@ -1,142 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
         | 
| 2 | 
            -
            #
         | 
| 3 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            -
            # You may obtain a copy of the License at
         | 
| 6 | 
            -
            #
         | 
| 7 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            -
            # limitations under the License.
         | 
| 14 | 
            -
             | 
| 15 | 
            -
            import re
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
         | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
            # whether contain chinese character
         | 
| 21 | 
            -
            def contains_chinese(text):
         | 
| 22 | 
            -
                return bool(chinese_char_pattern.search(text))
         | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
            # replace special symbol
         | 
| 26 | 
            -
            def replace_corner_mark(text):
         | 
| 27 | 
            -
                text = text.replace("²", "平方")
         | 
| 28 | 
            -
                text = text.replace("³", "立方")
         | 
| 29 | 
            -
                return text
         | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
            # remove meaningless symbol
         | 
| 33 | 
            -
            def remove_bracket(text):
         | 
| 34 | 
            -
                text = text.replace("(", "").replace(")", "")
         | 
| 35 | 
            -
                text = text.replace("【", "").replace("】", "")
         | 
| 36 | 
            -
                text = text.replace("`", "").replace("`", "")
         | 
| 37 | 
            -
                text = text.replace("——", " ")
         | 
| 38 | 
            -
                return text
         | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
            # spell Arabic numerals
         | 
| 42 | 
            -
            def spell_out_number(text: str, inflect_parser):
         | 
| 43 | 
            -
                new_text = []
         | 
| 44 | 
            -
                st = None
         | 
| 45 | 
            -
                for i, c in enumerate(text):
         | 
| 46 | 
            -
                    if not c.isdigit():
         | 
| 47 | 
            -
                        if st is not None:
         | 
| 48 | 
            -
                            num_str = inflect_parser.number_to_words(text[st:i])
         | 
| 49 | 
            -
                            new_text.append(num_str)
         | 
| 50 | 
            -
                            st = None
         | 
| 51 | 
            -
                        new_text.append(c)
         | 
| 52 | 
            -
                    else:
         | 
| 53 | 
            -
                        if st is None:
         | 
| 54 | 
            -
                            st = i
         | 
| 55 | 
            -
                if st is not None and st < len(text):
         | 
| 56 | 
            -
                    num_str = inflect_parser.number_to_words(text[st:])
         | 
| 57 | 
            -
                    new_text.append(num_str)
         | 
| 58 | 
            -
                return "".join(new_text)
         | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
            -
            # split paragrah logic:
         | 
| 62 | 
            -
            # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
         | 
| 63 | 
            -
            # 2. cal sentence len according to lang
         | 
| 64 | 
            -
            # 3. split sentence according to puncatation
         | 
| 65 | 
            -
            def split_paragraph(
         | 
| 66 | 
            -
                text: str,
         | 
| 67 | 
            -
                tokenize,
         | 
| 68 | 
            -
                lang="zh",
         | 
| 69 | 
            -
                token_max_n=80,
         | 
| 70 | 
            -
                token_min_n=60,
         | 
| 71 | 
            -
                merge_len=20,
         | 
| 72 | 
            -
                comma_split=False,
         | 
| 73 | 
            -
            ):
         | 
| 74 | 
            -
                def calc_utt_length(_text: str):
         | 
| 75 | 
            -
                    if lang == "zh":
         | 
| 76 | 
            -
                        return len(_text)
         | 
| 77 | 
            -
                    else:
         | 
| 78 | 
            -
                        return len(tokenize(_text))
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                def should_merge(_text: str):
         | 
| 81 | 
            -
                    if lang == "zh":
         | 
| 82 | 
            -
                        return len(_text) < merge_len
         | 
| 83 | 
            -
                    else:
         | 
| 84 | 
            -
                        return len(tokenize(_text)) < merge_len
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                if lang == "zh":
         | 
| 87 | 
            -
                    pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
         | 
| 88 | 
            -
                else:
         | 
| 89 | 
            -
                    pounc = [".", "?", "!", ";", ":"]
         | 
| 90 | 
            -
                if comma_split:
         | 
| 91 | 
            -
                    pounc.extend([",", ","])
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                if text[-1] not in pounc:
         | 
| 94 | 
            -
                    if lang == "zh":
         | 
| 95 | 
            -
                        text += "。"
         | 
| 96 | 
            -
                    else:
         | 
| 97 | 
            -
                        text += "."
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                st = 0
         | 
| 100 | 
            -
                utts = []
         | 
| 101 | 
            -
                for i, c in enumerate(text):
         | 
| 102 | 
            -
                    if c in pounc:
         | 
| 103 | 
            -
                        if len(text[st:i]) > 0:
         | 
| 104 | 
            -
                            utts.append(text[st:i] + c)
         | 
| 105 | 
            -
                        if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
         | 
| 106 | 
            -
                            tmp = utts.pop(-1)
         | 
| 107 | 
            -
                            utts.append(tmp + text[i + 1])
         | 
| 108 | 
            -
                            st = i + 2
         | 
| 109 | 
            -
                        else:
         | 
| 110 | 
            -
                            st = i + 1
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                final_utts = []
         | 
| 113 | 
            -
                cur_utt = ""
         | 
| 114 | 
            -
                for utt in utts:
         | 
| 115 | 
            -
                    if (
         | 
| 116 | 
            -
                        calc_utt_length(cur_utt + utt) > token_max_n
         | 
| 117 | 
            -
                        and calc_utt_length(cur_utt) > token_min_n
         | 
| 118 | 
            -
                    ):
         | 
| 119 | 
            -
                        final_utts.append(cur_utt)
         | 
| 120 | 
            -
                        cur_utt = ""
         | 
| 121 | 
            -
                    cur_utt = cur_utt + utt
         | 
| 122 | 
            -
                if len(cur_utt) > 0:
         | 
| 123 | 
            -
                    if should_merge(cur_utt) and len(final_utts) != 0:
         | 
| 124 | 
            -
                        final_utts[-1] = final_utts[-1] + cur_utt
         | 
| 125 | 
            -
                    else:
         | 
| 126 | 
            -
                        final_utts.append(cur_utt)
         | 
| 127 | 
            -
             | 
| 128 | 
            -
                return final_utts
         | 
| 129 | 
            -
             | 
| 130 | 
            -
             | 
| 131 | 
            -
            # remove blank between chinese character
         | 
| 132 | 
            -
            def replace_blank(text: str):
         | 
| 133 | 
            -
                out_str = []
         | 
| 134 | 
            -
                for i, c in enumerate(text):
         | 
| 135 | 
            -
                    if c == " ":
         | 
| 136 | 
            -
                        if (text[i + 1].isascii() and text[i + 1] != " ") and (
         | 
| 137 | 
            -
                            text[i - 1].isascii() and text[i - 1] != " "
         | 
| 138 | 
            -
                        ):
         | 
| 139 | 
            -
                            out_str.append(c)
         | 
| 140 | 
            -
                    else:
         | 
| 141 | 
            -
                        out_str.append(c)
         | 
| 142 | 
            -
                return "".join(out_str)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/mask.py
    DELETED
    
    | @@ -1,226 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2019 Shigeki Karita
         | 
| 2 | 
            -
            #               2020 Mobvoi Inc (Binbin Zhang)
         | 
| 3 | 
            -
            #               2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 4 | 
            -
            #
         | 
| 5 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            -
            # You may obtain a copy of the License at
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            -
            #
         | 
| 11 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            -
            # limitations under the License.
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            import torch
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            '''
         | 
| 20 | 
            -
            def subsequent_mask(
         | 
| 21 | 
            -
                    size: int,
         | 
| 22 | 
            -
                    device: torch.device = torch.device("cpu"),
         | 
| 23 | 
            -
            ) -> torch.Tensor:
         | 
| 24 | 
            -
                """Create mask for subsequent steps (size, size).
         | 
| 25 | 
            -
             | 
| 26 | 
            -
                This mask is used only in decoder which works in an auto-regressive mode.
         | 
| 27 | 
            -
                This means the current step could only do attention with its left steps.
         | 
| 28 | 
            -
             | 
| 29 | 
            -
                In encoder, fully attention is used when streaming is not necessary and
         | 
| 30 | 
            -
                the sequence is not long. In this  case, no attention mask is needed.
         | 
| 31 | 
            -
             | 
| 32 | 
            -
                When streaming is need, chunk-based attention is used in encoder. See
         | 
| 33 | 
            -
                subsequent_chunk_mask for the chunk-based attention mask.
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                Args:
         | 
| 36 | 
            -
                    size (int): size of mask
         | 
| 37 | 
            -
                    str device (str): "cpu" or "cuda" or torch.Tensor.device
         | 
| 38 | 
            -
                    dtype (torch.device): result dtype
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                Returns:
         | 
| 41 | 
            -
                    torch.Tensor: mask
         | 
| 42 | 
            -
             | 
| 43 | 
            -
                Examples:
         | 
| 44 | 
            -
                    >>> subsequent_mask(3)
         | 
| 45 | 
            -
                    [[1, 0, 0],
         | 
| 46 | 
            -
                     [1, 1, 0],
         | 
| 47 | 
            -
                     [1, 1, 1]]
         | 
| 48 | 
            -
                """
         | 
| 49 | 
            -
                ret = torch.ones(size, size, device=device, dtype=torch.bool)
         | 
| 50 | 
            -
                return torch.tril(ret)
         | 
| 51 | 
            -
            '''
         | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
            def subsequent_mask(
         | 
| 55 | 
            -
                size: int,
         | 
| 56 | 
            -
                device: torch.device = torch.device("cpu"),
         | 
| 57 | 
            -
            ) -> torch.Tensor:
         | 
| 58 | 
            -
                """Create mask for subsequent steps (size, size).
         | 
| 59 | 
            -
             | 
| 60 | 
            -
                This mask is used only in decoder which works in an auto-regressive mode.
         | 
| 61 | 
            -
                This means the current step could only do attention with its left steps.
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                In encoder, fully attention is used when streaming is not necessary and
         | 
| 64 | 
            -
                the sequence is not long. In this  case, no attention mask is needed.
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                When streaming is need, chunk-based attention is used in encoder. See
         | 
| 67 | 
            -
                subsequent_chunk_mask for the chunk-based attention mask.
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                Args:
         | 
| 70 | 
            -
                    size (int): size of mask
         | 
| 71 | 
            -
                    str device (str): "cpu" or "cuda" or torch.Tensor.device
         | 
| 72 | 
            -
                    dtype (torch.device): result dtype
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                Returns:
         | 
| 75 | 
            -
                    torch.Tensor: mask
         | 
| 76 | 
            -
             | 
| 77 | 
            -
                Examples:
         | 
| 78 | 
            -
                    >>> subsequent_mask(3)
         | 
| 79 | 
            -
                    [[1, 0, 0],
         | 
| 80 | 
            -
                     [1, 1, 0],
         | 
| 81 | 
            -
                     [1, 1, 1]]
         | 
| 82 | 
            -
                """
         | 
| 83 | 
            -
                arange = torch.arange(size, device=device)
         | 
| 84 | 
            -
                mask = arange.expand(size, size)
         | 
| 85 | 
            -
                arange = arange.unsqueeze(-1)
         | 
| 86 | 
            -
                mask = mask <= arange
         | 
| 87 | 
            -
                return mask
         | 
| 88 | 
            -
             | 
| 89 | 
            -
             | 
| 90 | 
            -
            def subsequent_chunk_mask(
         | 
| 91 | 
            -
                size: int,
         | 
| 92 | 
            -
                chunk_size: int,
         | 
| 93 | 
            -
                num_left_chunks: int = -1,
         | 
| 94 | 
            -
                device: torch.device = torch.device("cpu"),
         | 
| 95 | 
            -
            ) -> torch.Tensor:
         | 
| 96 | 
            -
                """Create mask for subsequent steps (size, size) with chunk size,
         | 
| 97 | 
            -
                   this is for streaming encoder
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                Args:
         | 
| 100 | 
            -
                    size (int): size of mask
         | 
| 101 | 
            -
                    chunk_size (int): size of chunk
         | 
| 102 | 
            -
                    num_left_chunks (int): number of left chunks
         | 
| 103 | 
            -
                        <0: use full chunk
         | 
| 104 | 
            -
                        >=0: use num_left_chunks
         | 
| 105 | 
            -
                    device (torch.device): "cpu" or "cuda" or torch.Tensor.device
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                Returns:
         | 
| 108 | 
            -
                    torch.Tensor: mask
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                Examples:
         | 
| 111 | 
            -
                    >>> subsequent_chunk_mask(4, 2)
         | 
| 112 | 
            -
                    [[1, 1, 0, 0],
         | 
| 113 | 
            -
                     [1, 1, 0, 0],
         | 
| 114 | 
            -
                     [1, 1, 1, 1],
         | 
| 115 | 
            -
                     [1, 1, 1, 1]]
         | 
| 116 | 
            -
                """
         | 
| 117 | 
            -
                ret = torch.zeros(size, size, device=device, dtype=torch.bool)
         | 
| 118 | 
            -
                for i in range(size):
         | 
| 119 | 
            -
                    if num_left_chunks < 0:
         | 
| 120 | 
            -
                        start = 0
         | 
| 121 | 
            -
                    else:
         | 
| 122 | 
            -
                        start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
         | 
| 123 | 
            -
                    ending = min((i // chunk_size + 1) * chunk_size, size)
         | 
| 124 | 
            -
                    ret[i, start:ending] = True
         | 
| 125 | 
            -
                return ret
         | 
| 126 | 
            -
             | 
| 127 | 
            -
             | 
| 128 | 
            -
            def add_optional_chunk_mask(
         | 
| 129 | 
            -
                xs: torch.Tensor,
         | 
| 130 | 
            -
                masks: torch.Tensor,
         | 
| 131 | 
            -
                use_dynamic_chunk: bool,
         | 
| 132 | 
            -
                use_dynamic_left_chunk: bool,
         | 
| 133 | 
            -
                decoding_chunk_size: int,
         | 
| 134 | 
            -
                static_chunk_size: int,
         | 
| 135 | 
            -
                num_decoding_left_chunks: int,
         | 
| 136 | 
            -
                enable_full_context: bool = True,
         | 
| 137 | 
            -
            ):
         | 
| 138 | 
            -
                """Apply optional mask for encoder.
         | 
| 139 | 
            -
             | 
| 140 | 
            -
                Args:
         | 
| 141 | 
            -
                    xs (torch.Tensor): padded input, (B, L, D), L for max length
         | 
| 142 | 
            -
                    mask (torch.Tensor): mask for xs, (B, 1, L)
         | 
| 143 | 
            -
                    use_dynamic_chunk (bool): whether to use dynamic chunk or not
         | 
| 144 | 
            -
                    use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
         | 
| 145 | 
            -
                        training.
         | 
| 146 | 
            -
                    decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
         | 
| 147 | 
            -
                        0: default for training, use random dynamic chunk.
         | 
| 148 | 
            -
                        <0: for decoding, use full chunk.
         | 
| 149 | 
            -
                        >0: for decoding, use fixed chunk size as set.
         | 
| 150 | 
            -
                    static_chunk_size (int): chunk size for static chunk training/decoding
         | 
| 151 | 
            -
                        if it's greater than 0, if use_dynamic_chunk is true,
         | 
| 152 | 
            -
                        this parameter will be ignored
         | 
| 153 | 
            -
                    num_decoding_left_chunks: number of left chunks, this is for decoding,
         | 
| 154 | 
            -
                        the chunk size is decoding_chunk_size.
         | 
| 155 | 
            -
                        >=0: use num_decoding_left_chunks
         | 
| 156 | 
            -
                        <0: use all left chunks
         | 
| 157 | 
            -
                    enable_full_context (bool):
         | 
| 158 | 
            -
                        True: chunk size is either [1, 25] or full context(max_len)
         | 
| 159 | 
            -
                        False: chunk size ~ U[1, 25]
         | 
| 160 | 
            -
             | 
| 161 | 
            -
                Returns:
         | 
| 162 | 
            -
                    torch.Tensor: chunk mask of the input xs.
         | 
| 163 | 
            -
                """
         | 
| 164 | 
            -
                # Whether to use chunk mask or not
         | 
| 165 | 
            -
                if use_dynamic_chunk:
         | 
| 166 | 
            -
                    max_len = xs.size(1)
         | 
| 167 | 
            -
                    if decoding_chunk_size < 0:
         | 
| 168 | 
            -
                        chunk_size = max_len
         | 
| 169 | 
            -
                        num_left_chunks = -1
         | 
| 170 | 
            -
                    elif decoding_chunk_size > 0:
         | 
| 171 | 
            -
                        chunk_size = decoding_chunk_size
         | 
| 172 | 
            -
                        num_left_chunks = num_decoding_left_chunks
         | 
| 173 | 
            -
                    else:
         | 
| 174 | 
            -
                        # chunk size is either [1, 25] or full context(max_len).
         | 
| 175 | 
            -
                        # Since we use 4 times subsampling and allow up to 1s(100 frames)
         | 
| 176 | 
            -
                        # delay, the maximum frame is 100 / 4 = 25.
         | 
| 177 | 
            -
                        chunk_size = torch.randint(1, max_len, (1,)).item()
         | 
| 178 | 
            -
                        num_left_chunks = -1
         | 
| 179 | 
            -
                        if chunk_size > max_len // 2 and enable_full_context:
         | 
| 180 | 
            -
                            chunk_size = max_len
         | 
| 181 | 
            -
                        else:
         | 
| 182 | 
            -
                            chunk_size = chunk_size % 25 + 1
         | 
| 183 | 
            -
                            if use_dynamic_left_chunk:
         | 
| 184 | 
            -
                                max_left_chunks = (max_len - 1) // chunk_size
         | 
| 185 | 
            -
                                num_left_chunks = torch.randint(0, max_left_chunks, (1,)).item()
         | 
| 186 | 
            -
                    chunk_masks = subsequent_chunk_mask(
         | 
| 187 | 
            -
                        xs.size(1), chunk_size, num_left_chunks, xs.device
         | 
| 188 | 
            -
                    )  # (L, L)
         | 
| 189 | 
            -
                    chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
         | 
| 190 | 
            -
                    chunk_masks = masks & chunk_masks  # (B, L, L)
         | 
| 191 | 
            -
                elif static_chunk_size > 0:
         | 
| 192 | 
            -
                    num_left_chunks = num_decoding_left_chunks
         | 
| 193 | 
            -
                    chunk_masks = subsequent_chunk_mask(
         | 
| 194 | 
            -
                        xs.size(1), static_chunk_size, num_left_chunks, xs.device
         | 
| 195 | 
            -
                    )  # (L, L)
         | 
| 196 | 
            -
                    chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
         | 
| 197 | 
            -
                    chunk_masks = masks & chunk_masks  # (B, L, L)
         | 
| 198 | 
            -
                else:
         | 
| 199 | 
            -
                    chunk_masks = masks
         | 
| 200 | 
            -
                return chunk_masks
         | 
| 201 | 
            -
             | 
| 202 | 
            -
             | 
| 203 | 
            -
            def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
         | 
| 204 | 
            -
                """Make mask tensor containing indices of padded part.
         | 
| 205 | 
            -
             | 
| 206 | 
            -
                See description of make_non_pad_mask.
         | 
| 207 | 
            -
             | 
| 208 | 
            -
                Args:
         | 
| 209 | 
            -
                    lengths (torch.Tensor): Batch of lengths (B,).
         | 
| 210 | 
            -
                Returns:
         | 
| 211 | 
            -
                    torch.Tensor: Mask tensor containing indices of padded part.
         | 
| 212 | 
            -
             | 
| 213 | 
            -
                Examples:
         | 
| 214 | 
            -
                    >>> lengths = [5, 3, 2]
         | 
| 215 | 
            -
                    >>> make_pad_mask(lengths)
         | 
| 216 | 
            -
                    masks = [[0, 0, 0, 0 ,0],
         | 
| 217 | 
            -
                             [0, 0, 0, 1, 1],
         | 
| 218 | 
            -
                             [0, 0, 1, 1, 1]]
         | 
| 219 | 
            -
                """
         | 
| 220 | 
            -
                batch_size = lengths.size(0)
         | 
| 221 | 
            -
                max_len = max_len if max_len > 0 else lengths.max().item()
         | 
| 222 | 
            -
                seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
         | 
| 223 | 
            -
                seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
         | 
| 224 | 
            -
                seq_length_expand = lengths.unsqueeze(-1)
         | 
| 225 | 
            -
                mask = seq_range_expand >= seq_length_expand
         | 
| 226 | 
            -
                return mask
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/scheduler.py
    DELETED
    
    | @@ -1,761 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
         | 
| 2 | 
            -
            #               2022 Ximalaya Inc (Yuguang Yang)
         | 
| 3 | 
            -
            #               2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 4 | 
            -
            #
         | 
| 5 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            -
            # You may obtain a copy of the License at
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            -
            #
         | 
| 11 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            -
            # limitations under the License.
         | 
| 16 | 
            -
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 17 | 
            -
            #               NeMo(https://github.com/NVIDIA/NeMo)
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            from typing import Union
         | 
| 20 | 
            -
             | 
| 21 | 
            -
            import math
         | 
| 22 | 
            -
            import warnings
         | 
| 23 | 
            -
            import torch
         | 
| 24 | 
            -
            from torch.optim.lr_scheduler import _LRScheduler
         | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
            class WarmupLR(_LRScheduler):
         | 
| 28 | 
            -
                """The WarmupLR scheduler
         | 
| 29 | 
            -
             | 
| 30 | 
            -
                This scheduler is almost same as NoamLR Scheduler except for following
         | 
| 31 | 
            -
                difference:
         | 
| 32 | 
            -
             | 
| 33 | 
            -
                NoamLR:
         | 
| 34 | 
            -
                    lr = optimizer.lr * model_size ** -0.5
         | 
| 35 | 
            -
                         * min(step ** -0.5, step * warmup_step ** -1.5)
         | 
| 36 | 
            -
                WarmupLR:
         | 
| 37 | 
            -
                    lr = optimizer.lr * warmup_step ** 0.5
         | 
| 38 | 
            -
                         * min(step ** -0.5, step * warmup_step ** -1.5)
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                Note that the maximum lr equals to optimizer.lr in this scheduler.
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                """
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                def __init__(
         | 
| 45 | 
            -
                    self,
         | 
| 46 | 
            -
                    optimizer: torch.optim.Optimizer,
         | 
| 47 | 
            -
                    warmup_steps: Union[int, float] = 25000,
         | 
| 48 | 
            -
                    last_epoch: int = -1,
         | 
| 49 | 
            -
                ):
         | 
| 50 | 
            -
                    self.warmup_steps = warmup_steps
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                    # __init__() must be invoked before setting field
         | 
| 53 | 
            -
                    # because step() is also invoked in __init__()
         | 
| 54 | 
            -
                    super().__init__(optimizer, last_epoch)
         | 
| 55 | 
            -
             | 
| 56 | 
            -
                def __repr__(self):
         | 
| 57 | 
            -
                    return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
         | 
| 58 | 
            -
             | 
| 59 | 
            -
                def get_lr(self):
         | 
| 60 | 
            -
                    step_num = self.last_epoch + 1
         | 
| 61 | 
            -
                    if self.warmup_steps == 0:
         | 
| 62 | 
            -
                        return [lr * step_num**-0.5 for lr in self.base_lrs]
         | 
| 63 | 
            -
                    else:
         | 
| 64 | 
            -
                        return [
         | 
| 65 | 
            -
                            lr
         | 
| 66 | 
            -
                            * self.warmup_steps**0.5
         | 
| 67 | 
            -
                            * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
         | 
| 68 | 
            -
                            for lr in self.base_lrs
         | 
| 69 | 
            -
                        ]
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                def set_step(self, step: int):
         | 
| 72 | 
            -
                    self.last_epoch = step
         | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 75 | 
            -
            class WarmupPolicy(_LRScheduler):
         | 
| 76 | 
            -
                """Adds warmup kwargs and warmup logic to lr policy.
         | 
| 77 | 
            -
                All arguments should be passed as kwargs for clarity,
         | 
| 78 | 
            -
                Args:
         | 
| 79 | 
            -
                    warmup_steps: Number of training steps in warmup stage
         | 
| 80 | 
            -
                    warmup_ratio: Ratio of warmup steps to total steps
         | 
| 81 | 
            -
                    max_steps: Total number of steps while training or `None` for
         | 
| 82 | 
            -
                        infinite training
         | 
| 83 | 
            -
                """
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                def __init__(
         | 
| 86 | 
            -
                    self,
         | 
| 87 | 
            -
                    optimizer,
         | 
| 88 | 
            -
                    *,
         | 
| 89 | 
            -
                    warmup_steps=None,
         | 
| 90 | 
            -
                    warmup_ratio=None,
         | 
| 91 | 
            -
                    max_steps=None,
         | 
| 92 | 
            -
                    min_lr=0.0,
         | 
| 93 | 
            -
                    last_epoch=-1,
         | 
| 94 | 
            -
                ):
         | 
| 95 | 
            -
                    assert not (
         | 
| 96 | 
            -
                        warmup_steps is not None and warmup_ratio is not None
         | 
| 97 | 
            -
                    ), "Either use particular number of step or ratio"
         | 
| 98 | 
            -
                    assert (
         | 
| 99 | 
            -
                        warmup_ratio is None or max_steps is not None
         | 
| 100 | 
            -
                    ), "If there is a ratio, there should be a total steps"
         | 
| 101 | 
            -
             | 
| 102 | 
            -
                    # It is necessary to assign all attributes *before* __init__,
         | 
| 103 | 
            -
                    # as class is wrapped by an inner class.
         | 
| 104 | 
            -
                    self.max_steps = max_steps
         | 
| 105 | 
            -
                    if warmup_steps is not None:
         | 
| 106 | 
            -
                        self.warmup_steps = warmup_steps
         | 
| 107 | 
            -
                    elif warmup_ratio is not None:
         | 
| 108 | 
            -
                        self.warmup_steps = int(warmup_ratio * max_steps)
         | 
| 109 | 
            -
                    else:
         | 
| 110 | 
            -
                        self.warmup_steps = 0
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                    self.min_lr = min_lr
         | 
| 113 | 
            -
                    super().__init__(optimizer, last_epoch)
         | 
| 114 | 
            -
             | 
| 115 | 
            -
                def get_lr(self):
         | 
| 116 | 
            -
                    if not self._get_lr_called_within_step:
         | 
| 117 | 
            -
                        warnings.warn(
         | 
| 118 | 
            -
                            "To get the last learning rate computed "
         | 
| 119 | 
            -
                            "by the scheduler, please use `get_last_lr()`.",
         | 
| 120 | 
            -
                            UserWarning,
         | 
| 121 | 
            -
                            stacklevel=2,
         | 
| 122 | 
            -
                        )
         | 
| 123 | 
            -
             | 
| 124 | 
            -
                    step = self.last_epoch
         | 
| 125 | 
            -
             | 
| 126 | 
            -
                    if step <= self.warmup_steps and self.warmup_steps > 0:
         | 
| 127 | 
            -
                        return self._get_warmup_lr(step)
         | 
| 128 | 
            -
             | 
| 129 | 
            -
                    if step > self.max_steps:
         | 
| 130 | 
            -
                        return [self.min_lr for _ in self.base_lrs]
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                    return self._get_lr(step)
         | 
| 133 | 
            -
             | 
| 134 | 
            -
                def _get_warmup_lr(self, step):
         | 
| 135 | 
            -
                    lr_val = (step + 1) / (self.warmup_steps + 1)
         | 
| 136 | 
            -
                    return [initial_lr * lr_val for initial_lr in self.base_lrs]
         | 
| 137 | 
            -
             | 
| 138 | 
            -
                def _get_lr(self, step):
         | 
| 139 | 
            -
                    """Simple const lr policy"""
         | 
| 140 | 
            -
                    return self.base_lrs
         | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
            class SquareRootConstantPolicy(_LRScheduler):
         | 
| 144 | 
            -
                """Adds warmup kwargs and warmup logic to lr policy.
         | 
| 145 | 
            -
                All arguments should be passed as kwargs for clarity,
         | 
| 146 | 
            -
                Args:
         | 
| 147 | 
            -
                    warmup_steps: Number of training steps in warmup stage
         | 
| 148 | 
            -
                    warmup_ratio: Ratio of warmup steps to total steps
         | 
| 149 | 
            -
                    max_steps: Total number of steps while training or `None` for
         | 
| 150 | 
            -
                        infinite training
         | 
| 151 | 
            -
                """
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                def __init__(
         | 
| 154 | 
            -
                    self,
         | 
| 155 | 
            -
                    optimizer,
         | 
| 156 | 
            -
                    *,
         | 
| 157 | 
            -
                    constant_steps=None,
         | 
| 158 | 
            -
                    constant_ratio=None,
         | 
| 159 | 
            -
                    max_steps=None,
         | 
| 160 | 
            -
                    min_lr=0.0,
         | 
| 161 | 
            -
                    last_epoch=-1,
         | 
| 162 | 
            -
                ):
         | 
| 163 | 
            -
                    assert not (
         | 
| 164 | 
            -
                        constant_steps is not None and constant_ratio is not None
         | 
| 165 | 
            -
                    ), "Either use particular number of step or ratio"
         | 
| 166 | 
            -
                    assert (
         | 
| 167 | 
            -
                        constant_ratio is None or max_steps is not None
         | 
| 168 | 
            -
                    ), "If there is a ratio, there should be a total steps"
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                    # It is necessary to assign all attributes *before* __init__,
         | 
| 171 | 
            -
                    # as class is wrapped by an inner class.
         | 
| 172 | 
            -
                    self.max_steps = max_steps
         | 
| 173 | 
            -
                    if constant_steps is not None:
         | 
| 174 | 
            -
                        self.constant_steps = constant_steps
         | 
| 175 | 
            -
                    elif constant_ratio is not None:
         | 
| 176 | 
            -
                        self.constant_steps = int(constant_ratio * max_steps)
         | 
| 177 | 
            -
                    else:
         | 
| 178 | 
            -
                        self.constant_steps = 0
         | 
| 179 | 
            -
             | 
| 180 | 
            -
                    self.constant_lr = 1 / (constant_steps**0.5)
         | 
| 181 | 
            -
                    self.min_lr = min_lr
         | 
| 182 | 
            -
                    super().__init__(optimizer, last_epoch)
         | 
| 183 | 
            -
             | 
| 184 | 
            -
                def get_lr(self):
         | 
| 185 | 
            -
                    if not self._get_lr_called_within_step:
         | 
| 186 | 
            -
                        warnings.warn(
         | 
| 187 | 
            -
                            "To get the last learning rate computed "
         | 
| 188 | 
            -
                            "by the scheduler, please use `get_last_lr()`.",
         | 
| 189 | 
            -
                            UserWarning,
         | 
| 190 | 
            -
                            stacklevel=2,
         | 
| 191 | 
            -
                        )
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                    step = self.last_epoch
         | 
| 194 | 
            -
             | 
| 195 | 
            -
                    if step <= self.constant_steps:
         | 
| 196 | 
            -
                        return [self.constant_lr for _ in self.base_lrs]
         | 
| 197 | 
            -
             | 
| 198 | 
            -
                    if step > self.max_steps:
         | 
| 199 | 
            -
                        return [self.min_lr for _ in self.base_lrs]
         | 
| 200 | 
            -
             | 
| 201 | 
            -
                    return self._get_lr(step)
         | 
| 202 | 
            -
             | 
| 203 | 
            -
                def _get_lr(self, step):
         | 
| 204 | 
            -
                    """Simple const lr policy"""
         | 
| 205 | 
            -
                    return self.base_lrs
         | 
| 206 | 
            -
             | 
| 207 | 
            -
             | 
| 208 | 
            -
            class WarmupHoldPolicy(WarmupPolicy):
         | 
| 209 | 
            -
                """Variant of WarmupPolicy which maintains high
         | 
| 210 | 
            -
                   learning rate for a defined number of steps.
         | 
| 211 | 
            -
                All arguments should be passed as kwargs for clarity,
         | 
| 212 | 
            -
                Args:
         | 
| 213 | 
            -
                    warmup_steps: Number of training steps in warmup stage
         | 
| 214 | 
            -
                    warmup_ratio: Ratio of warmup steps to total steps
         | 
| 215 | 
            -
                    hold_steps: Number of training steps to
         | 
| 216 | 
            -
                                hold the learning rate after warm up
         | 
| 217 | 
            -
                    hold_ratio: Ratio of hold steps to total steps
         | 
| 218 | 
            -
                    max_steps: Total number of steps while training or `None` for
         | 
| 219 | 
            -
                        infinite training
         | 
| 220 | 
            -
                """
         | 
| 221 | 
            -
             | 
| 222 | 
            -
                def __init__(
         | 
| 223 | 
            -
                    self,
         | 
| 224 | 
            -
                    optimizer,
         | 
| 225 | 
            -
                    *,
         | 
| 226 | 
            -
                    warmup_steps=None,
         | 
| 227 | 
            -
                    warmup_ratio=None,
         | 
| 228 | 
            -
                    hold_steps=None,
         | 
| 229 | 
            -
                    hold_ratio=None,
         | 
| 230 | 
            -
                    max_steps=None,
         | 
| 231 | 
            -
                    min_lr=0.0,
         | 
| 232 | 
            -
                    last_epoch=-1,
         | 
| 233 | 
            -
                ):
         | 
| 234 | 
            -
                    assert not (
         | 
| 235 | 
            -
                        hold_steps is not None and hold_ratio is not None
         | 
| 236 | 
            -
                    ), "Either use particular number of step or ratio"
         | 
| 237 | 
            -
                    assert (
         | 
| 238 | 
            -
                        hold_ratio is None or max_steps is not None
         | 
| 239 | 
            -
                    ), "If there is a ratio, there should be a total steps"
         | 
| 240 | 
            -
             | 
| 241 | 
            -
                    self.min_lr = min_lr
         | 
| 242 | 
            -
                    self._last_warmup_lr = 0.0
         | 
| 243 | 
            -
             | 
| 244 | 
            -
                    # Necessary to duplicate as class attributes are hidden in inner class
         | 
| 245 | 
            -
                    self.max_steps = max_steps
         | 
| 246 | 
            -
                    if warmup_steps is not None:
         | 
| 247 | 
            -
                        self.warmup_steps = warmup_steps
         | 
| 248 | 
            -
                    elif warmup_ratio is not None:
         | 
| 249 | 
            -
                        self.warmup_steps = int(warmup_ratio * max_steps)
         | 
| 250 | 
            -
                    else:
         | 
| 251 | 
            -
                        self.warmup_steps = 0
         | 
| 252 | 
            -
             | 
| 253 | 
            -
                    if hold_steps is not None:
         | 
| 254 | 
            -
                        self.hold_steps = hold_steps + self.warmup_steps
         | 
| 255 | 
            -
                    elif hold_ratio is not None:
         | 
| 256 | 
            -
                        self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
         | 
| 257 | 
            -
                    else:
         | 
| 258 | 
            -
                        self.hold_steps = 0
         | 
| 259 | 
            -
             | 
| 260 | 
            -
                    super().__init__(
         | 
| 261 | 
            -
                        optimizer,
         | 
| 262 | 
            -
                        warmup_steps=warmup_steps,
         | 
| 263 | 
            -
                        warmup_ratio=warmup_ratio,
         | 
| 264 | 
            -
                        max_steps=max_steps,
         | 
| 265 | 
            -
                        last_epoch=last_epoch,
         | 
| 266 | 
            -
                        min_lr=min_lr,
         | 
| 267 | 
            -
                    )
         | 
| 268 | 
            -
             | 
| 269 | 
            -
                def get_lr(self):
         | 
| 270 | 
            -
                    if not self._get_lr_called_within_step:
         | 
| 271 | 
            -
                        warnings.warn(
         | 
| 272 | 
            -
                            "To get the last learning rate computed by the scheduler,"
         | 
| 273 | 
            -
                            " "
         | 
| 274 | 
            -
                            "please use `get_last_lr()`.",
         | 
| 275 | 
            -
                            UserWarning,
         | 
| 276 | 
            -
                            stacklevel=2,
         | 
| 277 | 
            -
                        )
         | 
| 278 | 
            -
             | 
| 279 | 
            -
                    step = self.last_epoch
         | 
| 280 | 
            -
             | 
| 281 | 
            -
                    # Warmup phase
         | 
| 282 | 
            -
                    if step <= self.warmup_steps and self.warmup_steps > 0:
         | 
| 283 | 
            -
                        return self._get_warmup_lr(step)
         | 
| 284 | 
            -
             | 
| 285 | 
            -
                    # Hold phase
         | 
| 286 | 
            -
                    if (step >= self.warmup_steps) and (step < self.hold_steps):
         | 
| 287 | 
            -
                        return self.base_lrs
         | 
| 288 | 
            -
             | 
| 289 | 
            -
                    if step > self.max_steps:
         | 
| 290 | 
            -
                        return [self.min_lr for _ in self.base_lrs]
         | 
| 291 | 
            -
             | 
| 292 | 
            -
                    return self._get_lr(step)
         | 
| 293 | 
            -
             | 
| 294 | 
            -
             | 
| 295 | 
            -
            class WarmupAnnealHoldPolicy(_LRScheduler):
         | 
| 296 | 
            -
                """Adds warmup kwargs and warmup logic to lr policy.
         | 
| 297 | 
            -
                All arguments should be passed as kwargs for clarity,
         | 
| 298 | 
            -
                Args:
         | 
| 299 | 
            -
                    warmup_steps: Number of training steps in warmup stage
         | 
| 300 | 
            -
                    warmup_ratio: Ratio of warmup steps to total steps
         | 
| 301 | 
            -
                    max_steps: Total number of steps while training or `None` for
         | 
| 302 | 
            -
                        infinite training
         | 
| 303 | 
            -
                    min_lr: Minimum lr to hold the learning rate after decay at.
         | 
| 304 | 
            -
                    constant_steps: Number of steps to keep lr constant at.
         | 
| 305 | 
            -
                    constant_ratio: Ratio of steps to keep lr constant.
         | 
| 306 | 
            -
                """
         | 
| 307 | 
            -
             | 
| 308 | 
            -
                def __init__(
         | 
| 309 | 
            -
                    self,
         | 
| 310 | 
            -
                    optimizer,
         | 
| 311 | 
            -
                    *,
         | 
| 312 | 
            -
                    warmup_steps=None,
         | 
| 313 | 
            -
                    warmup_ratio=None,
         | 
| 314 | 
            -
                    constant_steps=None,
         | 
| 315 | 
            -
                    constant_ratio=None,
         | 
| 316 | 
            -
                    max_steps=None,
         | 
| 317 | 
            -
                    min_lr=0.0,
         | 
| 318 | 
            -
                    last_epoch=-1,
         | 
| 319 | 
            -
                ):
         | 
| 320 | 
            -
                    assert not (
         | 
| 321 | 
            -
                        warmup_steps is not None and warmup_ratio is not None
         | 
| 322 | 
            -
                    ), "Either use particular number of step or ratio"
         | 
| 323 | 
            -
                    assert not (
         | 
| 324 | 
            -
                        constant_steps is not None and constant_ratio is not None
         | 
| 325 | 
            -
                    ), "Either use constant_steps or constant_ratio"
         | 
| 326 | 
            -
                    assert (
         | 
| 327 | 
            -
                        warmup_ratio is None or max_steps is not None
         | 
| 328 | 
            -
                    ), "If there is a ratio, there should be a total steps"
         | 
| 329 | 
            -
             | 
| 330 | 
            -
                    # It is necessary to assign all attributes *before* __init__,
         | 
| 331 | 
            -
                    # as class is wrapped by an inner class.
         | 
| 332 | 
            -
                    self.max_steps = max_steps
         | 
| 333 | 
            -
             | 
| 334 | 
            -
                    if warmup_steps is not None:
         | 
| 335 | 
            -
                        self.warmup_steps = warmup_steps
         | 
| 336 | 
            -
                    elif warmup_ratio is not None:
         | 
| 337 | 
            -
                        self.warmup_steps = int(warmup_ratio * max_steps)
         | 
| 338 | 
            -
                    else:
         | 
| 339 | 
            -
                        self.warmup_steps = 0
         | 
| 340 | 
            -
             | 
| 341 | 
            -
                    if constant_steps is not None:
         | 
| 342 | 
            -
                        self.constant_steps = constant_steps
         | 
| 343 | 
            -
                    elif constant_ratio is not None:
         | 
| 344 | 
            -
                        self.constant_steps = int(constant_ratio * max_steps)
         | 
| 345 | 
            -
                    else:
         | 
| 346 | 
            -
                        self.constant_steps = 0
         | 
| 347 | 
            -
             | 
| 348 | 
            -
                    self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
         | 
| 349 | 
            -
             | 
| 350 | 
            -
                    self.min_lr = min_lr
         | 
| 351 | 
            -
                    super().__init__(optimizer, last_epoch)
         | 
| 352 | 
            -
             | 
| 353 | 
            -
                def get_lr(self):
         | 
| 354 | 
            -
                    if not self._get_lr_called_within_step:
         | 
| 355 | 
            -
                        warnings.warn(
         | 
| 356 | 
            -
                            "To get the last learning rate computed "
         | 
| 357 | 
            -
                            "by the scheduler, please use `get_last_lr()`.",
         | 
| 358 | 
            -
                            UserWarning,
         | 
| 359 | 
            -
                            stacklevel=2,
         | 
| 360 | 
            -
                        )
         | 
| 361 | 
            -
             | 
| 362 | 
            -
                    step = self.last_epoch
         | 
| 363 | 
            -
             | 
| 364 | 
            -
                    # Warmup steps
         | 
| 365 | 
            -
                    if self.warmup_steps > 0 and step <= self.warmup_steps:
         | 
| 366 | 
            -
                        return self._get_warmup_lr(step)
         | 
| 367 | 
            -
             | 
| 368 | 
            -
                    # Constant steps after warmup and decay
         | 
| 369 | 
            -
                    if (
         | 
| 370 | 
            -
                        self.constant_steps > 0
         | 
| 371 | 
            -
                        and (self.warmup_steps + self.decay_steps) < step <= self.max_steps
         | 
| 372 | 
            -
                    ):
         | 
| 373 | 
            -
                        return self._get_constant_lr(step)
         | 
| 374 | 
            -
             | 
| 375 | 
            -
                    # Min lr after max steps of updates
         | 
| 376 | 
            -
                    if step > self.max_steps:
         | 
| 377 | 
            -
                        return [self.min_lr for _ in self.base_lrs]
         | 
| 378 | 
            -
             | 
| 379 | 
            -
                    return self._get_lr(step)
         | 
| 380 | 
            -
             | 
| 381 | 
            -
                def _get_warmup_lr(self, step):
         | 
| 382 | 
            -
                    lr_val = (step + 1) / (self.warmup_steps + 1)
         | 
| 383 | 
            -
                    return [initial_lr * lr_val for initial_lr in self.base_lrs]
         | 
| 384 | 
            -
             | 
| 385 | 
            -
                def _get_constant_lr(self, step):
         | 
| 386 | 
            -
                    return [self.min_lr for _ in self.base_lrs]
         | 
| 387 | 
            -
             | 
| 388 | 
            -
                def _get_lr(self, step):
         | 
| 389 | 
            -
                    """Simple const lr policy"""
         | 
| 390 | 
            -
                    return self.base_lrs
         | 
| 391 | 
            -
             | 
| 392 | 
            -
             | 
| 393 | 
            -
            def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
         | 
| 394 | 
            -
                mult = ((max_steps - step) / max_steps) ** 0.5
         | 
| 395 | 
            -
                out_lr = initial_lr * mult
         | 
| 396 | 
            -
                out_lr = max(out_lr, min_lr)
         | 
| 397 | 
            -
                return out_lr
         | 
| 398 | 
            -
             | 
| 399 | 
            -
             | 
| 400 | 
            -
            def _square_annealing(initial_lr, step, max_steps, min_lr):
         | 
| 401 | 
            -
                mult = ((max_steps - step) / max_steps) ** 2
         | 
| 402 | 
            -
                out_lr = initial_lr * mult
         | 
| 403 | 
            -
                out_lr = max(out_lr, min_lr)
         | 
| 404 | 
            -
                return out_lr
         | 
| 405 | 
            -
             | 
| 406 | 
            -
             | 
| 407 | 
            -
            def _cosine_annealing(initial_lr, step, max_steps, min_lr):
         | 
| 408 | 
            -
                mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
         | 
| 409 | 
            -
                out_lr = (initial_lr - min_lr) * mult + min_lr
         | 
| 410 | 
            -
                return out_lr
         | 
| 411 | 
            -
             | 
| 412 | 
            -
             | 
| 413 | 
            -
            def _linear_warmup_with_cosine_annealing(
         | 
| 414 | 
            -
                max_lr, warmup_steps, step, decay_steps, min_lr
         | 
| 415 | 
            -
            ):
         | 
| 416 | 
            -
                assert max_lr > min_lr
         | 
| 417 | 
            -
                # Use linear warmup for the initial part.
         | 
| 418 | 
            -
                if warmup_steps > 0 and step <= warmup_steps:
         | 
| 419 | 
            -
                    return max_lr * float(step) / float(warmup_steps)
         | 
| 420 | 
            -
             | 
| 421 | 
            -
                # For any steps larger than `decay_steps`, use `min_lr`.
         | 
| 422 | 
            -
                if step > warmup_steps + decay_steps:
         | 
| 423 | 
            -
                    return min_lr
         | 
| 424 | 
            -
             | 
| 425 | 
            -
                # If we are done with the warmup period, use the decay style.
         | 
| 426 | 
            -
                num_steps_ = step - warmup_steps
         | 
| 427 | 
            -
                decay_steps_ = decay_steps
         | 
| 428 | 
            -
                decay_ratio = float(num_steps_) / float(decay_steps_)
         | 
| 429 | 
            -
                assert decay_ratio >= 0.0
         | 
| 430 | 
            -
                assert decay_ratio <= 1.0
         | 
| 431 | 
            -
                delta_lr = max_lr - min_lr
         | 
| 432 | 
            -
             | 
| 433 | 
            -
                coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
         | 
| 434 | 
            -
             | 
| 435 | 
            -
                return min_lr + coeff * delta_lr
         | 
| 436 | 
            -
             | 
| 437 | 
            -
             | 
| 438 | 
            -
            def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
         | 
| 439 | 
            -
                if cycle:
         | 
| 440 | 
            -
                    multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
         | 
| 441 | 
            -
                    decay_steps *= multiplier
         | 
| 442 | 
            -
                else:
         | 
| 443 | 
            -
                    step = min(step, decay_steps)
         | 
| 444 | 
            -
                p = step / decay_steps
         | 
| 445 | 
            -
                lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
         | 
| 446 | 
            -
                lr += min_lr
         | 
| 447 | 
            -
                return lr
         | 
| 448 | 
            -
             | 
| 449 | 
            -
             | 
| 450 | 
            -
            def _noam_hold_annealing(
         | 
| 451 | 
            -
                initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr
         | 
| 452 | 
            -
            ):
         | 
| 453 | 
            -
                # hold_steps = total number of steps
         | 
| 454 | 
            -
                # to hold the LR, not the warmup + hold steps.
         | 
| 455 | 
            -
                T_warmup_decay = max(1, warmup_steps**decay_rate)
         | 
| 456 | 
            -
                T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
         | 
| 457 | 
            -
                lr = (initial_lr * T_warmup_decay) / T_hold_decay
         | 
| 458 | 
            -
                lr = max(lr, min_lr)
         | 
| 459 | 
            -
                return lr
         | 
| 460 | 
            -
             | 
| 461 | 
            -
             | 
| 462 | 
            -
            class SquareAnnealing(WarmupPolicy):
         | 
| 463 | 
            -
             | 
| 464 | 
            -
                def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs):
         | 
| 465 | 
            -
                    super().__init__(
         | 
| 466 | 
            -
                        optimizer=optimizer,
         | 
| 467 | 
            -
                        max_steps=max_steps,
         | 
| 468 | 
            -
                        last_epoch=last_epoch,
         | 
| 469 | 
            -
                        min_lr=min_lr,
         | 
| 470 | 
            -
                        **kwargs,
         | 
| 471 | 
            -
                    )
         | 
| 472 | 
            -
             | 
| 473 | 
            -
                def _get_lr(self, step):
         | 
| 474 | 
            -
                    new_lrs = [
         | 
| 475 | 
            -
                        _square_annealing(
         | 
| 476 | 
            -
                            initial_lr=initial_lr,
         | 
| 477 | 
            -
                            step=step - self.warmup_steps,
         | 
| 478 | 
            -
                            max_steps=self.max_steps - self.warmup_steps,
         | 
| 479 | 
            -
                            min_lr=self.min_lr,
         | 
| 480 | 
            -
                        )
         | 
| 481 | 
            -
                        for initial_lr in self.base_lrs
         | 
| 482 | 
            -
                    ]
         | 
| 483 | 
            -
                    return new_lrs
         | 
| 484 | 
            -
             | 
| 485 | 
            -
             | 
| 486 | 
            -
            class SquareRootAnnealing(WarmupPolicy):
         | 
| 487 | 
            -
             | 
| 488 | 
            -
                def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
         | 
| 489 | 
            -
                    super().__init__(
         | 
| 490 | 
            -
                        optimizer=optimizer,
         | 
| 491 | 
            -
                        max_steps=max_steps,
         | 
| 492 | 
            -
                        last_epoch=last_epoch,
         | 
| 493 | 
            -
                        min_lr=min_lr,
         | 
| 494 | 
            -
                        **kwargs,
         | 
| 495 | 
            -
                    )
         | 
| 496 | 
            -
             | 
| 497 | 
            -
                def _get_lr(self, step):
         | 
| 498 | 
            -
                    new_lrs = [
         | 
| 499 | 
            -
                        _squareroot_annealing(
         | 
| 500 | 
            -
                            initial_lr=initial_lr,
         | 
| 501 | 
            -
                            step=step,
         | 
| 502 | 
            -
                            max_steps=self.max_steps,
         | 
| 503 | 
            -
                            min_lr=self.min_lr,
         | 
| 504 | 
            -
                        )
         | 
| 505 | 
            -
                        for initial_lr in self.base_lrs
         | 
| 506 | 
            -
                    ]
         | 
| 507 | 
            -
                    return new_lrs
         | 
| 508 | 
            -
             | 
| 509 | 
            -
             | 
| 510 | 
            -
            class CosineAnnealing(WarmupAnnealHoldPolicy):
         | 
| 511 | 
            -
             | 
| 512 | 
            -
                def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
         | 
| 513 | 
            -
                    super().__init__(
         | 
| 514 | 
            -
                        optimizer=optimizer,
         | 
| 515 | 
            -
                        max_steps=max_steps,
         | 
| 516 | 
            -
                        last_epoch=last_epoch,
         | 
| 517 | 
            -
                        min_lr=min_lr,
         | 
| 518 | 
            -
                        **kwargs,
         | 
| 519 | 
            -
                    )
         | 
| 520 | 
            -
             | 
| 521 | 
            -
                def _get_lr(self, step):
         | 
| 522 | 
            -
                    for initial_lr in self.base_lrs:
         | 
| 523 | 
            -
                        if initial_lr < self.min_lr:
         | 
| 524 | 
            -
                            raise ValueError(
         | 
| 525 | 
            -
                                f"{self} received an initial learning rate "
         | 
| 526 | 
            -
                                f"that was lower than the minimum learning rate."
         | 
| 527 | 
            -
                            )
         | 
| 528 | 
            -
             | 
| 529 | 
            -
                    if self.constant_steps is None or self.constant_steps == 0:
         | 
| 530 | 
            -
                        new_lrs = [
         | 
| 531 | 
            -
                            _cosine_annealing(
         | 
| 532 | 
            -
                                initial_lr=initial_lr,
         | 
| 533 | 
            -
                                step=step - self.warmup_steps,
         | 
| 534 | 
            -
                                max_steps=self.max_steps - self.warmup_steps,
         | 
| 535 | 
            -
                                min_lr=self.min_lr,
         | 
| 536 | 
            -
                            )
         | 
| 537 | 
            -
                            for initial_lr in self.base_lrs
         | 
| 538 | 
            -
                        ]
         | 
| 539 | 
            -
                    else:
         | 
| 540 | 
            -
                        new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
         | 
| 541 | 
            -
                    return new_lrs
         | 
| 542 | 
            -
             | 
| 543 | 
            -
                def _get_warmup_lr(self, step):
         | 
| 544 | 
            -
                    if self.constant_steps is None or self.constant_steps == 0:
         | 
| 545 | 
            -
                        return super()._get_warmup_lr(step)
         | 
| 546 | 
            -
                    else:
         | 
| 547 | 
            -
                        # Use linear warmup for the initial part.
         | 
| 548 | 
            -
                        return self._get_linear_warmup_with_cosine_annealing_lr(step)
         | 
| 549 | 
            -
             | 
| 550 | 
            -
                def _get_constant_lr(self, step):
         | 
| 551 | 
            -
                    # Only called when `constant_steps` > 0.
         | 
| 552 | 
            -
                    return self._get_linear_warmup_with_cosine_annealing_lr(step)
         | 
| 553 | 
            -
             | 
| 554 | 
            -
                def _get_linear_warmup_with_cosine_annealing_lr(self, step):
         | 
| 555 | 
            -
                    # Cosine Schedule for Megatron LM,
         | 
| 556 | 
            -
                    # slightly different warmup schedule + constant LR at the end.
         | 
| 557 | 
            -
                    new_lrs = [
         | 
| 558 | 
            -
                        _linear_warmup_with_cosine_annealing(
         | 
| 559 | 
            -
                            max_lr=self.base_lrs[0],
         | 
| 560 | 
            -
                            warmup_steps=self.warmup_steps,
         | 
| 561 | 
            -
                            step=step,
         | 
| 562 | 
            -
                            decay_steps=self.decay_steps,
         | 
| 563 | 
            -
                            min_lr=self.min_lr,
         | 
| 564 | 
            -
                        )
         | 
| 565 | 
            -
                        for _ in self.base_lrs
         | 
| 566 | 
            -
                    ]
         | 
| 567 | 
            -
                    return new_lrs
         | 
| 568 | 
            -
             | 
| 569 | 
            -
             | 
| 570 | 
            -
            class NoamAnnealing(_LRScheduler):
         | 
| 571 | 
            -
             | 
| 572 | 
            -
                def __init__(
         | 
| 573 | 
            -
                    self,
         | 
| 574 | 
            -
                    optimizer,
         | 
| 575 | 
            -
                    *,
         | 
| 576 | 
            -
                    d_model,
         | 
| 577 | 
            -
                    warmup_steps=None,
         | 
| 578 | 
            -
                    warmup_ratio=None,
         | 
| 579 | 
            -
                    max_steps=None,
         | 
| 580 | 
            -
                    min_lr=0.0,
         | 
| 581 | 
            -
                    last_epoch=-1,
         | 
| 582 | 
            -
                ):
         | 
| 583 | 
            -
                    self._normalize = d_model ** (-0.5)
         | 
| 584 | 
            -
                    assert not (
         | 
| 585 | 
            -
                        warmup_steps is not None and warmup_ratio is not None
         | 
| 586 | 
            -
                    ), "Either use particular number of step or ratio"
         | 
| 587 | 
            -
                    assert (
         | 
| 588 | 
            -
                        warmup_ratio is None or max_steps is not None
         | 
| 589 | 
            -
                    ), "If there is a ratio, there should be a total steps"
         | 
| 590 | 
            -
             | 
| 591 | 
            -
                    # It is necessary to assign all attributes *before* __init__,
         | 
| 592 | 
            -
                    # as class is wrapped by an inner class.
         | 
| 593 | 
            -
                    self.max_steps = max_steps
         | 
| 594 | 
            -
                    if warmup_steps is not None:
         | 
| 595 | 
            -
                        self.warmup_steps = warmup_steps
         | 
| 596 | 
            -
                    elif warmup_ratio is not None:
         | 
| 597 | 
            -
                        self.warmup_steps = int(warmup_ratio * max_steps)
         | 
| 598 | 
            -
                    else:
         | 
| 599 | 
            -
                        self.warmup_steps = 0
         | 
| 600 | 
            -
             | 
| 601 | 
            -
                    self.min_lr = min_lr
         | 
| 602 | 
            -
                    super().__init__(optimizer, last_epoch)
         | 
| 603 | 
            -
             | 
| 604 | 
            -
                def get_lr(self):
         | 
| 605 | 
            -
                    if not self._get_lr_called_within_step:
         | 
| 606 | 
            -
                        warnings.warn(
         | 
| 607 | 
            -
                            "To get the last learning rate computed "
         | 
| 608 | 
            -
                            "by the scheduler, please use `get_last_lr()`.",
         | 
| 609 | 
            -
                            UserWarning,
         | 
| 610 | 
            -
                            stacklevel=2,
         | 
| 611 | 
            -
                        )
         | 
| 612 | 
            -
             | 
| 613 | 
            -
                    step = max(1, self.last_epoch)
         | 
| 614 | 
            -
             | 
| 615 | 
            -
                    for initial_lr in self.base_lrs:
         | 
| 616 | 
            -
                        if initial_lr < self.min_lr:
         | 
| 617 | 
            -
                            raise ValueError(
         | 
| 618 | 
            -
                                f"{self} received an initial learning rate "
         | 
| 619 | 
            -
                                f"that was lower than the minimum learning rate."
         | 
| 620 | 
            -
                            )
         | 
| 621 | 
            -
             | 
| 622 | 
            -
                    new_lrs = [
         | 
| 623 | 
            -
                        self._noam_annealing(initial_lr=initial_lr, step=step)
         | 
| 624 | 
            -
                        for initial_lr in self.base_lrs
         | 
| 625 | 
            -
                    ]
         | 
| 626 | 
            -
                    return new_lrs
         | 
| 627 | 
            -
             | 
| 628 | 
            -
                def _noam_annealing(self, initial_lr, step):
         | 
| 629 | 
            -
                    if self.warmup_steps > 0:
         | 
| 630 | 
            -
                        mult = self._normalize * min(
         | 
| 631 | 
            -
                            step ** (-0.5), step * (self.warmup_steps ** (-1.5))
         | 
| 632 | 
            -
                        )
         | 
| 633 | 
            -
                    else:
         | 
| 634 | 
            -
                        mult = self._normalize * step ** (-0.5)
         | 
| 635 | 
            -
             | 
| 636 | 
            -
                    out_lr = initial_lr * mult
         | 
| 637 | 
            -
                    if step > self.warmup_steps:
         | 
| 638 | 
            -
                        out_lr = max(out_lr, self.min_lr)
         | 
| 639 | 
            -
                    return out_lr
         | 
| 640 | 
            -
             | 
| 641 | 
            -
             | 
| 642 | 
            -
            class NoamHoldAnnealing(WarmupHoldPolicy):
         | 
| 643 | 
            -
             | 
| 644 | 
            -
                def __init__(
         | 
| 645 | 
            -
                    self,
         | 
| 646 | 
            -
                    optimizer,
         | 
| 647 | 
            -
                    *,
         | 
| 648 | 
            -
                    max_steps,
         | 
| 649 | 
            -
                    decay_rate=0.5,
         | 
| 650 | 
            -
                    min_lr=0.0,
         | 
| 651 | 
            -
                    last_epoch=-1,
         | 
| 652 | 
            -
                    **kwargs,
         | 
| 653 | 
            -
                ):
         | 
| 654 | 
            -
                    """
         | 
| 655 | 
            -
                    From Nemo:
         | 
| 656 | 
            -
                    Implementation of the Noam Hold Annealing policy
         | 
| 657 | 
            -
                    from the SqueezeFormer paper.
         | 
| 658 | 
            -
             | 
| 659 | 
            -
                    Unlike NoamAnnealing, the peak learning rate
         | 
| 660 | 
            -
                    can be explicitly set for this scheduler.
         | 
| 661 | 
            -
                    The schedule first performs linear warmup,
         | 
| 662 | 
            -
                    then holds the peak LR, then decays with some schedule for
         | 
| 663 | 
            -
                    the remainder of the steps.
         | 
| 664 | 
            -
                    Therefore the min-lr is still dependent
         | 
| 665 | 
            -
                    on the hyper parameters selected.
         | 
| 666 | 
            -
             | 
| 667 | 
            -
                    It's schedule is determined by three factors-
         | 
| 668 | 
            -
             | 
| 669 | 
            -
                    Warmup Steps: Initial stage, where linear warmup
         | 
| 670 | 
            -
                        occurs uptil the peak LR is reached. Unlike NoamAnnealing,
         | 
| 671 | 
            -
                        the peak LR is explicitly stated here instead of a scaling factor.
         | 
| 672 | 
            -
             | 
| 673 | 
            -
                    Hold Steps: Intermediate stage, where the peak LR
         | 
| 674 | 
            -
                        is maintained for some number of steps. In this region,
         | 
| 675 | 
            -
                        the high peak LR allows the model to converge faster
         | 
| 676 | 
            -
                        if training is stable. However the high LR
         | 
| 677 | 
            -
                        may also cause instability during training.
         | 
| 678 | 
            -
                        Should usually be a significant fraction of training
         | 
| 679 | 
            -
                        steps (around 30-40% of the entire training steps).
         | 
| 680 | 
            -
             | 
| 681 | 
            -
                    Decay Steps: Final stage, where the LR rapidly decays
         | 
| 682 | 
            -
                        with some scaling rate (set by decay rate).
         | 
| 683 | 
            -
                        To attain Noam decay, use 0.5,
         | 
| 684 | 
            -
                        for Squeezeformer recommended decay, use 1.0.
         | 
| 685 | 
            -
                        The fast decay after prolonged high LR during
         | 
| 686 | 
            -
                        hold phase allows for rapid convergence.
         | 
| 687 | 
            -
             | 
| 688 | 
            -
                    References:
         | 
| 689 | 
            -
                        - [Squeezeformer:
         | 
| 690 | 
            -
                        An Efficient Transformer for Automatic Speech Recognition]
         | 
| 691 | 
            -
                        (https://arxiv.org/abs/2206.00888)
         | 
| 692 | 
            -
             | 
| 693 | 
            -
                    Args:
         | 
| 694 | 
            -
                        optimizer: Pytorch compatible Optimizer object.
         | 
| 695 | 
            -
                        warmup_steps: Number of training steps in warmup stage
         | 
| 696 | 
            -
                        warmup_ratio: Ratio of warmup steps to total steps
         | 
| 697 | 
            -
                        hold_steps: Number of training steps to
         | 
| 698 | 
            -
                                    hold the learning rate after warm up
         | 
| 699 | 
            -
                        hold_ratio: Ratio of hold steps to total steps
         | 
| 700 | 
            -
                        max_steps: Total number of steps while training or `None` for
         | 
| 701 | 
            -
                            infinite training
         | 
| 702 | 
            -
                        decay_rate: Float value describing the polynomial decay
         | 
| 703 | 
            -
                                    after the hold period. Default value
         | 
| 704 | 
            -
                                    of 0.5 corresponds to Noam decay.
         | 
| 705 | 
            -
                        min_lr: Minimum learning rate.
         | 
| 706 | 
            -
                    """
         | 
| 707 | 
            -
                    self.decay_rate = decay_rate
         | 
| 708 | 
            -
                    super().__init__(
         | 
| 709 | 
            -
                        optimizer=optimizer,
         | 
| 710 | 
            -
                        max_steps=max_steps,
         | 
| 711 | 
            -
                        last_epoch=last_epoch,
         | 
| 712 | 
            -
                        min_lr=min_lr,
         | 
| 713 | 
            -
                        **kwargs,
         | 
| 714 | 
            -
                    )
         | 
| 715 | 
            -
             | 
| 716 | 
            -
                def _get_lr(self, step):
         | 
| 717 | 
            -
                    if self.warmup_steps is None or self.warmup_steps == 0:
         | 
| 718 | 
            -
                        raise ValueError("Noam scheduler cannot be used without warmup steps")
         | 
| 719 | 
            -
             | 
| 720 | 
            -
                    if self.hold_steps > 0:
         | 
| 721 | 
            -
                        hold_steps = self.hold_steps - self.warmup_steps
         | 
| 722 | 
            -
                    else:
         | 
| 723 | 
            -
                        hold_steps = 0
         | 
| 724 | 
            -
             | 
| 725 | 
            -
                    new_lrs = [
         | 
| 726 | 
            -
                        _noam_hold_annealing(
         | 
| 727 | 
            -
                            initial_lr,
         | 
| 728 | 
            -
                            step=step,
         | 
| 729 | 
            -
                            warmup_steps=self.warmup_steps,
         | 
| 730 | 
            -
                            hold_steps=hold_steps,
         | 
| 731 | 
            -
                            decay_rate=self.decay_rate,
         | 
| 732 | 
            -
                            min_lr=self.min_lr,
         | 
| 733 | 
            -
                        )
         | 
| 734 | 
            -
                        for initial_lr in self.base_lrs
         | 
| 735 | 
            -
                    ]
         | 
| 736 | 
            -
                    return new_lrs
         | 
| 737 | 
            -
             | 
| 738 | 
            -
                def set_step(self, step: int):
         | 
| 739 | 
            -
                    self.last_epoch = step
         | 
| 740 | 
            -
             | 
| 741 | 
            -
             | 
| 742 | 
            -
            class ConstantLR(_LRScheduler):
         | 
| 743 | 
            -
                """The ConstantLR scheduler
         | 
| 744 | 
            -
             | 
| 745 | 
            -
                This scheduler keeps a constant lr
         | 
| 746 | 
            -
             | 
| 747 | 
            -
                """
         | 
| 748 | 
            -
             | 
| 749 | 
            -
                def __init__(
         | 
| 750 | 
            -
                    self,
         | 
| 751 | 
            -
                    optimizer: torch.optim.Optimizer,
         | 
| 752 | 
            -
                ):
         | 
| 753 | 
            -
                    # __init__() must be invoked before setting field
         | 
| 754 | 
            -
                    # because step() is also invoked in __init__()
         | 
| 755 | 
            -
                    super().__init__(optimizer)
         | 
| 756 | 
            -
             | 
| 757 | 
            -
                def get_lr(self):
         | 
| 758 | 
            -
                    return self.base_lrs
         | 
| 759 | 
            -
             | 
| 760 | 
            -
                def set_step(self, step: int):
         | 
| 761 | 
            -
                    self.last_epoch = step
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        cosyvoice/utils/train_utils.py
    DELETED
    
    | @@ -1,350 +0,0 @@ | |
| 1 | 
            -
            # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
         | 
| 2 | 
            -
            #               2023 Horizon Inc. (authors: Xingchen Song)
         | 
| 3 | 
            -
            #               2024 Alibaba Inc (authors: Xiang Lyu)
         | 
| 4 | 
            -
            #
         | 
| 5 | 
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            -
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            -
            # You may obtain a copy of the License at
         | 
| 8 | 
            -
            #
         | 
| 9 | 
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            -
            #
         | 
| 11 | 
            -
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            -
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            -
            # limitations under the License.
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            from contextlib import nullcontext
         | 
| 18 | 
            -
            import logging
         | 
| 19 | 
            -
            import os
         | 
| 20 | 
            -
            import torch
         | 
| 21 | 
            -
            import json
         | 
| 22 | 
            -
            import re
         | 
| 23 | 
            -
            import datetime
         | 
| 24 | 
            -
            import yaml
         | 
| 25 | 
            -
             | 
| 26 | 
            -
            import deepspeed
         | 
| 27 | 
            -
            import torch.optim as optim
         | 
| 28 | 
            -
            import torch.distributed as dist
         | 
| 29 | 
            -
             | 
| 30 | 
            -
            from torch.utils.tensorboard import SummaryWriter
         | 
| 31 | 
            -
            from torch.utils.data import DataLoader
         | 
| 32 | 
            -
            from torch.nn.utils import clip_grad_norm_
         | 
| 33 | 
            -
             | 
| 34 | 
            -
            from deepspeed.runtime.zero.stage_1_and_2 import (
         | 
| 35 | 
            -
                estimate_zero2_model_states_mem_needs_all_live,
         | 
| 36 | 
            -
            )
         | 
| 37 | 
            -
             | 
| 38 | 
            -
            from cosyvoice.dataset.dataset import Dataset
         | 
| 39 | 
            -
            from cosyvoice.utils.scheduler import (
         | 
| 40 | 
            -
                WarmupLR,
         | 
| 41 | 
            -
                NoamHoldAnnealing,
         | 
| 42 | 
            -
                ConstantLR,
         | 
| 43 | 
            -
            )
         | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
            def init_distributed(args):
         | 
| 47 | 
            -
                world_size = int(os.environ.get("WORLD_SIZE", 1))
         | 
| 48 | 
            -
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
         | 
| 49 | 
            -
                rank = int(os.environ.get("RANK", 0))
         | 
| 50 | 
            -
                logging.info(
         | 
| 51 | 
            -
                    "training on multiple gpus, this gpu {}".format(local_rank)
         | 
| 52 | 
            -
                    + ", rank {}, world_size {}".format(rank, world_size)
         | 
| 53 | 
            -
                )
         | 
| 54 | 
            -
                if args.train_engine == "torch_ddp":
         | 
| 55 | 
            -
                    torch.cuda.set_device(local_rank)
         | 
| 56 | 
            -
                    dist.init_process_group(args.dist_backend)
         | 
| 57 | 
            -
                else:
         | 
| 58 | 
            -
                    deepspeed.init_distributed(dist_backend=args.dist_backend)
         | 
| 59 | 
            -
                return world_size, local_rank, rank
         | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
            def init_dataset_and_dataloader(args, configs):
         | 
| 63 | 
            -
                train_dataset = Dataset(
         | 
| 64 | 
            -
                    args.train_data,
         | 
| 65 | 
            -
                    data_pipeline=configs["data_pipeline"],
         | 
| 66 | 
            -
                    mode="train",
         | 
| 67 | 
            -
                    shuffle=True,
         | 
| 68 | 
            -
                    partition=True,
         | 
| 69 | 
            -
                )
         | 
| 70 | 
            -
                cv_dataset = Dataset(
         | 
| 71 | 
            -
                    args.cv_data,
         | 
| 72 | 
            -
                    data_pipeline=configs["data_pipeline"],
         | 
| 73 | 
            -
                    mode="train",
         | 
| 74 | 
            -
                    shuffle=False,
         | 
| 75 | 
            -
                    partition=False,
         | 
| 76 | 
            -
                )
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
         | 
| 79 | 
            -
                train_data_loader = DataLoader(
         | 
| 80 | 
            -
                    train_dataset,
         | 
| 81 | 
            -
                    batch_size=None,
         | 
| 82 | 
            -
                    pin_memory=args.pin_memory,
         | 
| 83 | 
            -
                    num_workers=args.num_workers,
         | 
| 84 | 
            -
                    prefetch_factor=args.prefetch,
         | 
| 85 | 
            -
                )
         | 
| 86 | 
            -
                cv_data_loader = DataLoader(
         | 
| 87 | 
            -
                    cv_dataset,
         | 
| 88 | 
            -
                    batch_size=None,
         | 
| 89 | 
            -
                    pin_memory=args.pin_memory,
         | 
| 90 | 
            -
                    num_workers=args.num_workers,
         | 
| 91 | 
            -
                    prefetch_factor=args.prefetch,
         | 
| 92 | 
            -
                )
         | 
| 93 | 
            -
                return train_dataset, cv_dataset, train_data_loader, cv_data_loader
         | 
| 94 | 
            -
             | 
| 95 | 
            -
             | 
| 96 | 
            -
            def check_modify_and_save_config(args, configs):
         | 
| 97 | 
            -
                if args.train_engine == "torch_ddp":
         | 
| 98 | 
            -
                    configs["train_conf"]["dtype"] = "fp32"
         | 
| 99 | 
            -
                else:
         | 
| 100 | 
            -
                    with open(args.deepspeed_config, "r") as fin:
         | 
| 101 | 
            -
                        ds_configs = json.load(fin)
         | 
| 102 | 
            -
                    if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
         | 
| 103 | 
            -
                        configs["train_conf"]["dtype"] = "fp16"
         | 
| 104 | 
            -
                    elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
         | 
| 105 | 
            -
                        configs["train_conf"]["dtype"] = "bf16"
         | 
| 106 | 
            -
                    else:
         | 
| 107 | 
            -
                        configs["train_conf"]["dtype"] = "fp32"
         | 
| 108 | 
            -
                    assert ds_configs["train_micro_batch_size_per_gpu"] == 1
         | 
| 109 | 
            -
                    # if use deepspeed, override ddp config
         | 
| 110 | 
            -
                    configs["train_conf"]["save_per_step"] = int(
         | 
| 111 | 
            -
                        configs["train_conf"]["save_per_step"]
         | 
| 112 | 
            -
                        * configs["train_conf"]["accum_grad"]
         | 
| 113 | 
            -
                        / ds_configs["gradient_accumulation_steps"]
         | 
| 114 | 
            -
                    )
         | 
| 115 | 
            -
                    configs["train_conf"]["accum_grad"] = ds_configs["gradient_accumulation_steps"]
         | 
| 116 | 
            -
                    configs["train_conf"]["grad_clip"] = ds_configs["gradient_clipping"]
         | 
| 117 | 
            -
                    configs["train_conf"]["log_interval"] = ds_configs["steps_per_print"]
         | 
| 118 | 
            -
                return configs
         | 
| 119 | 
            -
             | 
| 120 | 
            -
             | 
| 121 | 
            -
            def wrap_cuda_model(args, model):
         | 
| 122 | 
            -
                local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
         | 
| 123 | 
            -
                world_size = int(os.environ.get("WORLD_SIZE", 1))
         | 
| 124 | 
            -
                if args.train_engine == "torch_ddp":  # native pytorch ddp
         | 
| 125 | 
            -
                    assert torch.cuda.is_available()
         | 
| 126 | 
            -
                    model.cuda()
         | 
| 127 | 
            -
                    model = torch.nn.parallel.DistributedDataParallel(
         | 
| 128 | 
            -
                        model, find_unused_parameters=True
         | 
| 129 | 
            -
                    )
         | 
| 130 | 
            -
                else:
         | 
| 131 | 
            -
                    if int(os.environ.get("RANK", 0)) == 0:
         | 
| 132 | 
            -
                        logging.info("Estimating model states memory needs (zero2)...")
         | 
| 133 | 
            -
                        estimate_zero2_model_states_mem_needs_all_live(
         | 
| 134 | 
            -
                            model,
         | 
| 135 | 
            -
                            num_gpus_per_node=local_world_size,
         | 
| 136 | 
            -
                            num_nodes=world_size // local_world_size,
         | 
| 137 | 
            -
                        )
         | 
| 138 | 
            -
                return model
         | 
| 139 | 
            -
             | 
| 140 | 
            -
             | 
| 141 | 
            -
            def init_optimizer_and_scheduler(args, configs, model):
         | 
| 142 | 
            -
                if configs["train_conf"]["optim"] == "adam":
         | 
| 143 | 
            -
                    optimizer = optim.Adam(
         | 
| 144 | 
            -
                        model.parameters(), **configs["train_conf"]["optim_conf"]
         | 
| 145 | 
            -
                    )
         | 
| 146 | 
            -
                elif configs["train_conf"]["optim"] == "adamw":
         | 
| 147 | 
            -
                    optimizer = optim.AdamW(
         | 
| 148 | 
            -
                        model.parameters(), **configs["train_conf"]["optim_conf"]
         | 
| 149 | 
            -
                    )
         | 
| 150 | 
            -
                else:
         | 
| 151 | 
            -
                    raise ValueError("unknown optimizer: " + configs["train_conf"])
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                if configs["train_conf"]["scheduler"] == "warmuplr":
         | 
| 154 | 
            -
                    scheduler_type = WarmupLR
         | 
| 155 | 
            -
                    scheduler = WarmupLR(optimizer, **configs["train_conf"]["scheduler_conf"])
         | 
| 156 | 
            -
                elif configs["train_conf"]["scheduler"] == "NoamHoldAnnealing":
         | 
| 157 | 
            -
                    scheduler_type = NoamHoldAnnealing
         | 
| 158 | 
            -
                    scheduler = NoamHoldAnnealing(
         | 
| 159 | 
            -
                        optimizer, **configs["train_conf"]["scheduler_conf"]
         | 
| 160 | 
            -
                    )
         | 
| 161 | 
            -
                elif configs["train_conf"]["scheduler"] == "constantlr":
         | 
| 162 | 
            -
                    scheduler_type = ConstantLR
         | 
| 163 | 
            -
                    scheduler = ConstantLR(optimizer)
         | 
| 164 | 
            -
                else:
         | 
| 165 | 
            -
                    raise ValueError("unknown scheduler: " + configs["train_conf"])
         | 
| 166 | 
            -
             | 
| 167 | 
            -
                # use deepspeed optimizer for speedup
         | 
| 168 | 
            -
                if args.train_engine == "deepspeed":
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                    def scheduler(opt):
         | 
| 171 | 
            -
                        return scheduler_type(opt, **configs["train_conf"]["scheduler_conf"])
         | 
| 172 | 
            -
             | 
| 173 | 
            -
                    model, optimizer, _, scheduler = deepspeed.initialize(
         | 
| 174 | 
            -
                        args=args,
         | 
| 175 | 
            -
                        model=model,
         | 
| 176 | 
            -
                        optimizer=None,
         | 
| 177 | 
            -
                        lr_scheduler=scheduler,
         | 
| 178 | 
            -
                        model_parameters=model.parameters(),
         | 
| 179 | 
            -
                    )
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                return model, optimizer, scheduler
         | 
| 182 | 
            -
             | 
| 183 | 
            -
             | 
| 184 | 
            -
            def init_summarywriter(args):
         | 
| 185 | 
            -
                writer = None
         | 
| 186 | 
            -
                if int(os.environ.get("RANK", 0)) == 0:
         | 
| 187 | 
            -
                    os.makedirs(args.model_dir, exist_ok=True)
         | 
| 188 | 
            -
                    writer = SummaryWriter(args.tensorboard_dir)
         | 
| 189 | 
            -
                return writer
         | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
            def save_model(model, model_name, info_dict):
         | 
| 193 | 
            -
                rank = int(os.environ.get("RANK", 0))
         | 
| 194 | 
            -
                model_dir = info_dict["model_dir"]
         | 
| 195 | 
            -
                save_model_path = os.path.join(model_dir, "{}.pt".format(model_name))
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                if info_dict["train_engine"] == "torch_ddp":
         | 
| 198 | 
            -
                    if rank == 0:
         | 
| 199 | 
            -
                        torch.save(model.module.state_dict(), save_model_path)
         | 
| 200 | 
            -
                else:
         | 
| 201 | 
            -
                    with torch.no_grad():
         | 
| 202 | 
            -
                        model.save_checkpoint(
         | 
| 203 | 
            -
                            save_dir=model_dir, tag=model_name, client_state=info_dict
         | 
| 204 | 
            -
                        )
         | 
| 205 | 
            -
                if rank == 0:
         | 
| 206 | 
            -
                    info_path = re.sub(".pt$", ".yaml", save_model_path)
         | 
| 207 | 
            -
                    info_dict["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
         | 
| 208 | 
            -
                    with open(info_path, "w") as fout:
         | 
| 209 | 
            -
                        data = yaml.dump(info_dict)
         | 
| 210 | 
            -
                        fout.write(data)
         | 
| 211 | 
            -
                    logging.info(
         | 
| 212 | 
            -
                        "[Rank {}] Checkpoint: save to checkpoint {}".format(rank, save_model_path)
         | 
| 213 | 
            -
                    )
         | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 216 | 
            -
            def cosyvoice_join(group_join, info_dict):
         | 
| 217 | 
            -
                world_size = int(os.environ.get("WORLD_SIZE", 1))
         | 
| 218 | 
            -
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
         | 
| 219 | 
            -
                rank = int(os.environ.get("RANK", 0))
         | 
| 220 | 
            -
             | 
| 221 | 
            -
                if info_dict["batch_idx"] != 0:
         | 
| 222 | 
            -
                    # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
         | 
| 223 | 
            -
                    try:
         | 
| 224 | 
            -
                        dist.monitored_barrier(
         | 
| 225 | 
            -
                            group=group_join, timeout=group_join.options._timeout
         | 
| 226 | 
            -
                        )
         | 
| 227 | 
            -
                        return False
         | 
| 228 | 
            -
                    except RuntimeError as e:
         | 
| 229 | 
            -
                        logging.info(
         | 
| 230 | 
            -
                            "Detected uneven workload distribution: {}\n".format(e)
         | 
| 231 | 
            -
                            + "Break current worker to manually join all workers, "
         | 
| 232 | 
            -
                            + "world_size {}, current rank {}, current local_rank {}\n".format(
         | 
| 233 | 
            -
                                world_size, rank, local_rank
         | 
| 234 | 
            -
                            )
         | 
| 235 | 
            -
                        )
         | 
| 236 | 
            -
                        return True
         | 
| 237 | 
            -
                else:
         | 
| 238 | 
            -
                    return False
         | 
| 239 | 
            -
             | 
| 240 | 
            -
             | 
| 241 | 
            -
            def batch_forward(model, batch, info_dict):
         | 
| 242 | 
            -
                device = int(os.environ.get("LOCAL_RANK", 0))
         | 
| 243 | 
            -
             | 
| 244 | 
            -
                dtype = info_dict["dtype"]
         | 
| 245 | 
            -
                if dtype == "fp16":
         | 
| 246 | 
            -
                    dtype = torch.float16
         | 
| 247 | 
            -
                elif dtype == "bf16":
         | 
| 248 | 
            -
                    dtype = torch.bfloat16
         | 
| 249 | 
            -
                else:  # fp32
         | 
| 250 | 
            -
                    dtype = torch.float32
         | 
| 251 | 
            -
             | 
| 252 | 
            -
                if info_dict["train_engine"] == "torch_ddp":
         | 
| 253 | 
            -
                    autocast = nullcontext()
         | 
| 254 | 
            -
                else:
         | 
| 255 | 
            -
                    autocast = torch.cuda.amp.autocast(
         | 
| 256 | 
            -
                        enabled=True, dtype=dtype, cache_enabled=False
         | 
| 257 | 
            -
                    )
         | 
| 258 | 
            -
             | 
| 259 | 
            -
                with autocast:
         | 
| 260 | 
            -
                    info_dict["loss_dict"] = model(batch, device)
         | 
| 261 | 
            -
                return info_dict
         | 
| 262 | 
            -
             | 
| 263 | 
            -
             | 
| 264 | 
            -
            def batch_backward(model, info_dict):
         | 
| 265 | 
            -
                if info_dict["train_engine"] == "deepspeed":
         | 
| 266 | 
            -
                    scaled_loss = model.backward(info_dict["loss_dict"]["loss"])
         | 
| 267 | 
            -
                else:
         | 
| 268 | 
            -
                    scaled_loss = info_dict["loss_dict"]["loss"] / info_dict["accum_grad"]
         | 
| 269 | 
            -
                    scaled_loss.backward()
         | 
| 270 | 
            -
             | 
| 271 | 
            -
                info_dict["loss_dict"]["loss"] = scaled_loss
         | 
| 272 | 
            -
                return info_dict
         | 
| 273 | 
            -
             | 
| 274 | 
            -
             | 
| 275 | 
            -
            def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
         | 
| 276 | 
            -
                grad_norm = 0.0
         | 
| 277 | 
            -
                if info_dict["train_engine"] == "deepspeed":
         | 
| 278 | 
            -
                    info_dict["is_gradient_accumulation_boundary"] = (
         | 
| 279 | 
            -
                        model.is_gradient_accumulation_boundary()
         | 
| 280 | 
            -
                    )
         | 
| 281 | 
            -
                    model.step()
         | 
| 282 | 
            -
                    grad_norm = model.get_global_grad_norm()
         | 
| 283 | 
            -
                elif (info_dict["batch_idx"] + 1) % info_dict["accum_grad"] == 0:
         | 
| 284 | 
            -
                    grad_norm = clip_grad_norm_(model.parameters(), info_dict["grad_clip"])
         | 
| 285 | 
            -
                    if torch.isfinite(grad_norm):
         | 
| 286 | 
            -
                        optimizer.step()
         | 
| 287 | 
            -
                    optimizer.zero_grad()
         | 
| 288 | 
            -
                    scheduler.step()
         | 
| 289 | 
            -
                info_dict["lr"] = optimizer.param_groups[0]["lr"]
         | 
| 290 | 
            -
                info_dict["grad_norm"] = grad_norm
         | 
| 291 | 
            -
                return info_dict
         | 
| 292 | 
            -
             | 
| 293 | 
            -
             | 
| 294 | 
            -
            def log_per_step(writer, info_dict):
         | 
| 295 | 
            -
                tag = info_dict["tag"]
         | 
| 296 | 
            -
                epoch = info_dict.get("epoch", 0)
         | 
| 297 | 
            -
                step = info_dict["step"]
         | 
| 298 | 
            -
                batch_idx = info_dict["batch_idx"]
         | 
| 299 | 
            -
                loss_dict = info_dict["loss_dict"]
         | 
| 300 | 
            -
                rank = int(os.environ.get("RANK", 0))
         | 
| 301 | 
            -
             | 
| 302 | 
            -
                # only rank 0 write to tensorboard to avoid multi-process write
         | 
| 303 | 
            -
                if writer is not None:
         | 
| 304 | 
            -
                    if (
         | 
| 305 | 
            -
                        info_dict["train_engine"] == "deepspeed"
         | 
| 306 | 
            -
                        and info_dict["is_gradient_accumulation_boundary"] is True
         | 
| 307 | 
            -
                    ) or (
         | 
| 308 | 
            -
                        info_dict["train_engine"] == "torch_ddp"
         | 
| 309 | 
            -
                        and (info_dict["batch_idx"] + 1) % info_dict["accum_grad"] == 0
         | 
| 310 | 
            -
                    ):
         | 
| 311 | 
            -
                        for k in ["epoch", "lr", "grad_norm"]:
         | 
| 312 | 
            -
                            writer.add_scalar("{}/{}".format(tag, k), info_dict[k], step + 1)
         | 
| 313 | 
            -
                        for k, v in loss_dict.items():
         | 
| 314 | 
            -
                            writer.add_scalar("{}/{}".format(tag, k), v, step + 1)
         | 
| 315 | 
            -
             | 
| 316 | 
            -
                # TRAIN & CV, Shell log (stdout)
         | 
| 317 | 
            -
                if (info_dict["batch_idx"] + 1) % info_dict["log_interval"] == 0:
         | 
| 318 | 
            -
                    log_str = "{} Batch {}/{} ".format(tag, epoch, batch_idx + 1)
         | 
| 319 | 
            -
                    for name, value in loss_dict.items():
         | 
| 320 | 
            -
                        log_str += "{} {:.6f} ".format(name, value)
         | 
| 321 | 
            -
                    if tag == "TRAIN":
         | 
| 322 | 
            -
                        log_str += "lr {:.8f} grad_norm {:.6f}".format(
         | 
| 323 | 
            -
                            info_dict["lr"], info_dict["grad_norm"]
         | 
| 324 | 
            -
                        )
         | 
| 325 | 
            -
                    log_str += " rank {}".format(rank)
         | 
| 326 | 
            -
                    logging.debug(log_str)
         | 
| 327 | 
            -
             | 
| 328 | 
            -
             | 
| 329 | 
            -
            def log_per_save(writer, info_dict):
         | 
| 330 | 
            -
                tag = info_dict["tag"]
         | 
| 331 | 
            -
                epoch = info_dict["epoch"]
         | 
| 332 | 
            -
                step = info_dict["step"]
         | 
| 333 | 
            -
                loss_dict = info_dict["loss_dict"]
         | 
| 334 | 
            -
                lr = info_dict["lr"]
         | 
| 335 | 
            -
                rank = int(os.environ.get("RANK", 0))
         | 
| 336 | 
            -
                logging.info(
         | 
| 337 | 
            -
                    "Epoch {} Step {} CV info lr {} {} rank {}".format(
         | 
| 338 | 
            -
                        epoch,
         | 
| 339 | 
            -
                        step + 1,
         | 
| 340 | 
            -
                        lr,
         | 
| 341 | 
            -
                        rank,
         | 
| 342 | 
            -
                        " ".join(["{}_{}".format(k, v) for k, v in loss_dict.items()]),
         | 
| 343 | 
            -
                    )
         | 
| 344 | 
            -
                )
         | 
| 345 | 
            -
             | 
| 346 | 
            -
                if writer is not None:
         | 
| 347 | 
            -
                    for k in ["epoch", "lr"]:
         | 
| 348 | 
            -
                        writer.add_scalar("{}/{}".format(tag, k), info_dict[k], step + 1)
         | 
| 349 | 
            -
                    for k, v in loss_dict.items():
         | 
| 350 | 
            -
                        writer.add_scalar("{}/{}".format(tag, k), v, step + 1)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        examples/clone_wav_lixueqin.wav
    DELETED
    
    | @@ -1,3 +0,0 @@ | |
| 1 | 
            -
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            -
            oid sha256:4e9de6d9c0e98c466fec26a806f64a60d56faa99f641d389f9de8c259201bad5
         | 
| 3 | 
            -
            size 285774
         | 
|  | |
|  | |
|  | |
|  | 
