Spaces · Runtime error
James Zhou committed · Commit 9867d34 · 1 Parent(s): 860b27a
[init]
Browse files. This view is limited to 50 files because the commit contains too many changes. See raw diff.
- .gitattributes +2 -0
- LICENSE +77 -0
- app.py +814 -0
- assets/data_pipeline.png +3 -0
- assets/model_arch.png +3 -0
- assets/pan_chart.png +3 -0
- configs/hunyuanvideo-foley-xxl.yaml +49 -0
- examples/1_result.mp4 +3 -0
- examples/1_video.mp4 +3 -0
- examples/2_result.mp4 +3 -0
- examples/2_video.mp4 +3 -0
- examples/3_result.mp4 +3 -0
- examples/3_video.mp4 +3 -0
- examples/4_result.mp4 +3 -0
- examples/4_video.mp4 +3 -0
- examples/5_result.mp4 +3 -0
- examples/5_video.mp4 +3 -0
- examples/6_result.mp4 +3 -0
- examples/6_video.mp4 +3 -0
- examples/7_result.mp4 +3 -0
- examples/7_video.mp4 +3 -0
- examples/8_result.mp4 +3 -0
- examples/8_video.mp4 +3 -0
- hunyuanvideo_foley/__init__.py +0 -0
- hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc +0 -0
- hunyuanvideo_foley/constants.py +57 -0
- hunyuanvideo_foley/models/__init__.py +0 -0
- hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc +0 -0
- hunyuanvideo_foley/models/dac_vae/__init__.py +16 -0
- hunyuanvideo_foley/models/dac_vae/__main__.py +36 -0
- hunyuanvideo_foley/models/dac_vae/model/__init__.py +4 -0
- hunyuanvideo_foley/models/dac_vae/model/base.py +301 -0
- hunyuanvideo_foley/models/dac_vae/model/dac.py +410 -0
- hunyuanvideo_foley/models/dac_vae/model/discriminator.py +228 -0
- hunyuanvideo_foley/models/dac_vae/nn/__init__.py +3 -0
- hunyuanvideo_foley/models/dac_vae/nn/layers.py +33 -0
- hunyuanvideo_foley/models/dac_vae/nn/loss.py +368 -0
- hunyuanvideo_foley/models/dac_vae/nn/quantize.py +262 -0
- hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py +91 -0
- hunyuanvideo_foley/models/dac_vae/utils/__init__.py +121 -0
- hunyuanvideo_foley/models/dac_vae/utils/decode.py +95 -0
- hunyuanvideo_foley/models/dac_vae/utils/encode.py +94 -0
- hunyuanvideo_foley/models/hifi_foley.py +794 -0
- hunyuanvideo_foley/models/nn/__init__.py +0 -0
- hunyuanvideo_foley/models/nn/activation_layers.py +44 -0
- hunyuanvideo_foley/models/nn/attn_layers.py +546 -0
- hunyuanvideo_foley/models/nn/embed_layers.py +136 -0
- hunyuanvideo_foley/models/nn/mlp_layers.py +149 -0
- hunyuanvideo_foley/models/nn/modulate_layers.py +49 -0
- hunyuanvideo_foley/models/nn/norm_layers.py +70 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,77 @@
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
Tencent HunyuanVideo-Foley Release Date: August 28, 2025
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
1. DEFINITIONS.
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo-Foley released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley].
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
n. “including” shall mean including but not limited to.
2. GRANT OF RIGHTS.
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
3. DISTRIBUTION.
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
b. You must cause any modified files to carry prominent notices stating that You changed the files;
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
4. ADDITIONAL COMMERCIAL TERMS.
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
5. RULES OF USE.
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
6. INTELLECTUAL PROPERTY.
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
8. SURVIVAL AND TERMINATION.
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
9. GOVERNING LAW AND JURISDICTION.
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.

EXHIBIT A
ACCEPTABLE USE POLICY

Tencent reserves the right to update this Acceptable Use Policy from time to time.
Last modified: November 5, 2024

Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
1. Outside the Territory;
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
3. To harm Yourself or others;
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
5. To override or circumvent the safety guardrails and safeguards We have put in place;
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
9. To intentionally defame, disparage or otherwise harass others;
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
11. To generate or disseminate personal identifiable information with the purpose of harming others;
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including through the use of bot-generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
13. To impersonate another individual without consent, authorization, or legal right;
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
19. For military purposes;
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
app.py
ADDED
@@ -0,0 +1,814 @@
import os
import tempfile
import gradio as gr
import torch
import torchaudio
from loguru import logger
from typing import Optional, Tuple
import random
import numpy as np

from hunyuanvideo_foley.utils.model_utils import load_model
from hunyuanvideo_foley.utils.feature_utils import feature_process
from hunyuanvideo_foley.utils.model_utils import denoise_process
from hunyuanvideo_foley.utils.media_utils import merge_audio_video

# Global variables for model storage
model_dict = None
cfg = None
device = None

# need to modify the model path
MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"

def setup_device(device_str: str = "auto", gpu_id: int = 0) -> torch.device:
    """Setup computing device"""
    if device_str == "auto":
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{gpu_id}")
            logger.info(f"Using CUDA device: {device}")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
            logger.info("Using MPS device")
        else:
            device = torch.device("cpu")
            logger.info("Using CPU device")
    else:
        if device_str == "cuda":
            device = torch.device(f"cuda:{gpu_id}")
        else:
            device = torch.device(device_str)
        logger.info(f"Using specified device: {device}")

    return device

def auto_load_models() -> str:
    """Automatically load preset models"""
    global model_dict, cfg, device

    try:
        if not os.path.exists(MODEL_PATH):
            return f"❌ Model file not found: {MODEL_PATH}"
        if not os.path.exists(CONFIG_PATH):
            return f"❌ Config file not found: {CONFIG_PATH}"

        # Use GPU by default
        device = setup_device("auto", 0)

        # Load model
        logger.info("Auto-loading model...")
        logger.info(f"Model path: {MODEL_PATH}")
        logger.info(f"Config path: {CONFIG_PATH}")

        model_dict, cfg = load_model(MODEL_PATH, CONFIG_PATH, device)

        logger.info("✅ Model loaded successfully!")
        return "✅ Model loaded successfully!"

    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        return f"❌ Model loading failed: {str(e)}"

def infer_single_video(
    video_file,
    text_prompt: str,
    guidance_scale: float = 4.5,
    num_inference_steps: int = 50,
    sample_nums: int = 1
) -> Tuple[list, str]:
    """Single video inference"""
    global model_dict, cfg, device

    if model_dict is None or cfg is None:
        return [], "❌ Please load the model first!"

    if video_file is None:
        return [], "❌ Please upload a video file!"

    # Allow empty text prompt, use empty string if no prompt provided
    if text_prompt is None:
        text_prompt = ""
    text_prompt = text_prompt.strip()

    try:
        logger.info(f"Processing video: {video_file}")
        logger.info(f"Text prompt: {text_prompt}")

        # Feature processing
        visual_feats, text_feats, audio_len_in_s = feature_process(
            video_file,
            text_prompt,
            model_dict,
            cfg
        )

        # Denoising process to generate multiple audio samples
        # Note: The model now generates sample_nums audio samples per inference
        # The denoise_process function returns audio with shape [batch_size, channels, samples]
        logger.info(f"Generating {sample_nums} audio samples...")
        audio, sample_rate = denoise_process(
            visual_feats,
            text_feats,
            audio_len_in_s,
            model_dict,
            cfg,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            batch_size=sample_nums
        )

        # Create temporary files to save results
        temp_dir = tempfile.mkdtemp()
        video_outputs = []

        # Process each generated audio sample
        for i in range(sample_nums):
            # Save audio file
            audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav")
            torchaudio.save(audio_output, audio[i], sample_rate)

            # Merge video and audio
            video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4")
            merge_audio_video(audio_output, video_file, video_output)
            video_outputs.append(video_output)

        logger.info(f"Inference completed! Generated {sample_nums} samples.")
        return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully!"

    except Exception as e:
        logger.error(f"Inference failed: {str(e)}")
        return [], f"❌ Inference failed: {str(e)}"

def update_video_outputs(video_list, status_msg):
    """Update video outputs based on the number of generated samples"""
    # Initialize all outputs as None
    outputs = [None] * 6

    # Set values based on generated videos
    for i, video_path in enumerate(video_list[:6]):  # Max 6 samples
        outputs[i] = video_path

    # Return all outputs plus status message
    return tuple(outputs + [status_msg])

def create_gradio_interface():
    """Create Gradio interface"""

    # Custom CSS for beautiful interface with better contrast
    css = """
    .gradio-container {
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
        background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
        min-height: 100vh;
    }

    .main-header {
        text-align: center;
        padding: 2rem 0;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border-radius: 20px;
        margin-bottom: 2rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.15);
    }

    .main-header h1 {
        color: white;
        font-size: 3rem;
        font-weight: 700;
        margin-bottom: 0.5rem;
        text-shadow: 0 2px 10px rgba(0,0,0,0.3);
    }

    .main-header p {
        color: rgba(255, 255, 255, 0.95);
        font-size: 1.2rem;
        font-weight: 300;
    }

    .status-card {
        background: white;
        border-radius: 15px;
        padding: 1rem;
        margin-bottom: 1.5rem;
        border: 1px solid #e1e5e9;
        box-shadow: 0 4px 20px rgba(0,0,0,0.08);
    }

    .status-card label {
        color: #2d3748 !important;
        font-weight: 600 !important;
    }

    .usage-guide h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 0.5rem !important;
    }

    .usage-guide p {
        color: #4a5568 !important;
        font-size: 1rem !important;
        line-height: 1.6 !important;
        margin: 0.5rem 0 !important;
    }

    .usage-guide strong {
        color: #1a202c !important;
        font-weight: 700 !important;
    }

    .usage-guide em {
        color: #1a202c !important;
        font-weight: 700 !important;
        font-style: normal !important;
    }

    .main-interface {
        margin-bottom: 2rem;
    }

    .input-section {
        background: white;
        border-radius: 20px;
        padding: 2rem;
        margin-right: 1rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
        border: 1px solid #e1e5e9;
    }

    .input-section h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 1rem !important;
    }

    .input-section label {
        color: #4a5568 !important;
        font-weight: 500 !important;
    }

    .output-section {
        background: white;
        border-radius: 20px;
        padding: 2rem;
        margin-left: 1rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
        border: 1px solid #e1e5e9;
    }

    .output-section h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 1rem !important;
    }

    .output-section label {
        color: #4a5568 !important;
        font-weight: 500 !important;
    }

    .examples-section h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 1.5rem !important;
    }

    .generate-btn {
        background: linear-gradient(45deg, #667eea, #764ba2) !important;
        border: none !important;
        color: white !important;
        font-weight: 600 !important;
        font-size: 1.1rem !important;
        padding: 12px 30px !important;
        border-radius: 25px !important;
        box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
        transition: all 0.3s ease !important;
    }

    .generate-btn:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important;
    }

    .examples-section {
        background: white;
        border-radius: 20px;
        padding: 2rem;
        margin-top: 2rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
        border: 1px solid #e1e5e9;
    }

    .examples-section p {
        color: #4a5568 !important;
        margin-bottom: 1rem !important;
    }

    .example-row {
        background: #f8fafc;
        border: 1px solid #e2e8f0;
        border-radius: 15px;
        padding: 1.5rem;
        margin: 1rem 0;
        transition: all 0.3s ease;
        align-items: center;
    }

    .example-row:hover {
        border-color: #667eea;
        transform: translateY(-2px);
        box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
    }

    .example-row .markdown {
        color: #2d3748 !important;
    }

    .example-row .markdown p {
        color: #2d3748 !important;
        margin: 0.5rem 0 !important;
        line-height: 1.5 !important;
    }

    .example-row .markdown strong {
        color: #1a202c !important;
        font-weight: 600 !important;
    }

    /* Example grid layout styles */
    .example-grid-row {
        margin: 1rem 0;
        gap: 1rem;
    }

    .example-item {
        background: #f8fafc;
        border: 1px solid #e2e8f0;
        border-radius: 15px;
        padding: 1rem;
        transition: all 0.3s ease;
        margin: 0.25rem;
        max-width: 250px;
        margin-left: auto;
        margin-right: auto;
    }

    .example-item:hover {
        border-color: #667eea;
        transform: translateY(-2px);
        box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
    }

    .example-caption {
        margin: 0.5rem 0 !important;
        min-height: 2.8rem !important;
        display: flex !important;
        align-items: flex-start !important;
    }

    .example-caption p {
        color: #2d3748 !important;
        font-size: 0.9rem !important;
        line-height: 1.4 !important;
        margin: 0.5rem 0 !important;
    }

    /* Multi-video gallery styles */
    .additional-samples {
        margin-top: 1rem;
        gap: 0.5rem;
    }

    .additional-samples .gradio-video {
        border-radius: 10px;
        overflow: hidden;
    }

    /* Video gallery responsive layout */
    .video-gallery {
        display: grid;
        gap: 1rem;
        margin-top: 1rem;
    }

    .video-gallery.single {
        grid-template-columns: 1fr;
    }

    .video-gallery.dual {
        grid-template-columns: 1fr 1fr;
    }

    .video-gallery.multi {
        grid-template-columns: repeat(2, 1fr);
        grid-template-rows: auto auto auto;
    }

    .footer-text {
        color: #718096 !important;
        text-align: center;
        padding: 2rem;
        font-size: 0.9rem;
    }

    /* Video component styling for consistent size */
    .input-section video,
    .output-section video,
    .example-row video {
        width: 100% !important;
        height: 300px !important;
        object-fit: contain !important;
        border-radius: 10px !important;
        background-color: #000 !important;
    }

    .example-row video {
        height: 150px !important;
    }

    /* Fix for additional samples video display */
    .additional-samples video {
        height: 150px !important;
        object-fit: contain !important;
        border-radius: 10px !important;
        background-color: #000 !important;
    }

    .additional-samples .gradio-video {
        border-radius: 10px !important;
        overflow: hidden !important;
        background-color: #000 !important;
    }

    .additional-samples .gradio-video > div {
        background-color: #000 !important;
        border-radius: 10px !important;
    }

    /* Video container styling */
    .input-section .video-container,
    .output-section .video-container,
    .example-row .video-container {
        background-color: #000 !important;
        border-radius: 10px !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
        overflow: hidden !important;
    }

    /* Ensure proper alignment */
    .example-row {
        display: flex !important;
        align-items: stretch !important;
    }

    .example-row > div {
        display: flex !important;
        flex-direction: column !important;
        justify-content: center !important;
    }

    /* Video wrapper for better control */
    .video-wrapper {
        position: relative !important;
        width: 100% !important;
        background: #000 !important;
        border-radius: 10px !important;
        overflow: hidden !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    """

    with gr.Blocks(css=css, title="HunyuanVideo-Foley") as app:

        # Main header
        with gr.Column(elem_classes=["main-header"]):
            gr.HTML("""
                <h1>🎵 HunyuanVideo-Foley</h1>
                <p>Text-Video-to-Audio Synthesis: Generate realistic audio from video and text descriptions</p>
            """)

        # Usage Guide
        with gr.Column(elem_classes=["status-card"]):
            gr.Markdown("""
            ### 📋 Quick Start Guide
            **1.** Upload your video file    **2.** Add optional text description    **3.** Adjust sample numbers (1-6)    **4.** Click Generate Audio

            💡 For quick start, you can load the prepared examples by clicking the button.
            """, elem_classes=["usage-guide"])

        # Main inference interface - Input and Results side by side
        with gr.Row(elem_classes=["main-interface"]):
            # Input section
            with gr.Column(scale=1, elem_classes=["input-section"]):
                gr.Markdown("### 📹 Video Input")

                video_input = gr.Video(
                    label="Upload Video",
                    info="Supported formats: MP4, AVI, MOV, etc.",
                    height=300
                )

                text_input = gr.Textbox(
                    label="🎯 Audio Description (English)",
                    placeholder="A person walks on frozen ice",
                    lines=3,
                    info="Describe the audio you want to generate (optional)"
                )

                with gr.Row():
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        value=4.5,
                        step=0.1,
                        label="🎚️ CFG Scale",
                    )

                    inference_steps = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        step=5,
                        label="⚡ Steps",
                    )

                    sample_nums = gr.Slider(
                        minimum=1,
                        maximum=6,
                        value=1,
                        step=1,
                        label="🎲 Sample Nums",
                    )

                generate_btn = gr.Button(
                    "🎵 Generate Audio",
                    variant="primary",
                    elem_classes=["generate-btn"]
                )

            # Results section
            with gr.Column(scale=1, elem_classes=["output-section"]):
                gr.Markdown("### 🎥 Generated Results")

                # Multi-video gallery for displaying multiple generated samples
                with gr.Column():
                    # Primary video (Sample 1)
                    video_output_1 = gr.Video(
                        label="Sample 1",
                        height=250,
                        visible=True
                    )

                    # Additional videos (Samples 2-6) - initially hidden
                    with gr.Row(elem_classes=["additional-samples"]):
                        with gr.Column(scale=1):
                            video_output_2 = gr.Video(
                                label="Sample 2",
                                height=150,
                                visible=False
                            )
                            video_output_3 = gr.Video(
                                label="Sample 3",
                                height=150,
                                visible=False
                            )
                        with gr.Column(scale=1):
                            video_output_4 = gr.Video(
                                label="Sample 4",
                                height=150,
                                visible=False
                            )
                            video_output_5 = gr.Video(
                                label="Sample 5",
                                height=150,
                                visible=False
                            )

                    # Sample 6 - full width
                    video_output_6 = gr.Video(
                        label="Sample 6",
                        height=150,
                        visible=False
                    )

                result_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=2
                )

        # Examples section at the bottom
        with gr.Column(elem_classes=["examples-section"]):
            gr.Markdown("### 🌟 Examples")
            gr.Markdown("Click on any example to load it into the interface above")

            # Define your custom examples here - 8 examples total
            examples_data = [
                # Example 1
                {
                    "caption": "A person walks on frozen ice",
                    "video_path": "examples/1_video.mp4",
                    "result_path": "examples/1_result.mp4"
                },
                # Example 2
                {
                    "caption": "With a faint sound as their hands parted, the two embraced, a soft 'mm' escaping between them.",
                    "video_path": "examples/2_video.mp4",
                    "result_path": "examples/2_result.mp4"
                },
                # Example 3
                {
                    "caption": "The sound of the number 3's bouncing footsteps is as light and clear as glass marbles hitting the ground. Each step carries a magical sound.",
                    "video_path": "examples/3_video.mp4",
                    "result_path": "examples/3_result.mp4"
                },
                # Example 4
                {
                    "caption": "gentle gurgling of the stream's current, and music plays in the background which is a beautiful and serene piano solo with a hint of classical charm, evoking a sense of peace and serenity in people's hearts.",
                    "video_path": "examples/4_video.mp4",
                    "result_path": "examples/4_result.mp4"
                },
                # Example 5 - Add your new examples here
                {
                    "caption": "snow crunching under the snowboard's edge.",
                    "video_path": "examples/5_video.mp4",
                    "result_path": "examples/5_result.mp4"
                },
                # Example 6
                {
                    "caption": "The crackling of the fire, the whooshing of the flames, and the occasional crisp popping of charred leaves filled the forest.",
                    "video_path": "examples/6_video.mp4",
                    "result_path": "examples/6_result.mp4"
                },
                # Example 7
                {
                    "caption": "humming of the scooter engine accelerates slowly.",
                    "video_path": "examples/7_video.mp4",
                    "result_path": "examples/7_result.mp4"
                },
                # Example 8
                {
                    "caption": "splash of water and loud thud as person hits the surface.",
                    "video_path": "examples/8_video.mp4",
                    "result_path": "examples/8_result.mp4"
                }
            ]

            # Create example grid - 4 examples per row, 2 rows total
            example_buttons = []
            for row in range(2):  # 2 rows
                with gr.Row(elem_classes=["example-grid-row"]):
                    for col in range(4):  # 4 columns
                        idx = row * 4 + col
                        if idx < len(examples_data):
                            example = examples_data[idx]

                            with gr.Column(scale=1, elem_classes=["example-item"]):
                                # Video thumbnail
                                if os.path.exists(example['video_path']):
                                    example_video = gr.Video(
                                        value=example['video_path'],
                                        label=f"Example {idx+1}",
                                        interactive=False,
                                        show_label=True,
                                        height=180
                                    )
                                else:
                                    example_video = gr.HTML(f"""
                                        <div style="background: #f0f0f0; padding: 15px; text-align: center; border-radius: 8px; height: 180px; display: flex; align-items: center; justify-content: center;">
                                            <div>
                                                <p style="color: #666; margin: 0; font-size: 12px;">📹 Video not found</p>
                                                <small style="color: #999; font-size: 10px;">{example['video_path']}</small>
                                            </div>
                                        </div>
                                    """)

                                # Caption (truncated for grid layout)
                                caption_preview = example['caption'][:60] + "..." if len(example['caption']) > 60 else example['caption']
                                gr.Markdown(f"{caption_preview}", elem_classes=["example-caption"])

                                # Load button
                                example_btn = gr.Button(
                                    f"Load Example {idx+1}",
                                    variant="secondary",
                                    size="sm"
                                )
                                example_buttons.append((example_btn, example))

        # Event handlers
        def process_inference(video_file, text_prompt, guidance_scale, inference_steps, sample_nums):
            # Generate videos
            video_list, status_msg = infer_single_video(
                video_file, text_prompt, guidance_scale, inference_steps, int(sample_nums)
            )
            # Update outputs with proper visibility
            return update_video_outputs(video_list, status_msg)

        # Add dynamic visibility control based on sample_nums
        def update_visibility(sample_nums):
            sample_nums = int(sample_nums)
            return [
                gr.update(visible=True),              # Sample 1 always visible
                gr.update(visible=sample_nums >= 2),  # Sample 2
                gr.update(visible=sample_nums >= 3),  # Sample 3
                gr.update(visible=sample_nums >= 4),  # Sample 4
                gr.update(visible=sample_nums >= 5),  # Sample 5
                gr.update(visible=sample_nums >= 6),  # Sample 6
            ]

        # Update visibility when sample_nums changes
        sample_nums.change(
            fn=update_visibility,
            inputs=[sample_nums],
            outputs=[video_output_1, video_output_2, video_output_3, video_output_4, video_output_5, video_output_6]
        )

        generate_btn.click(
            fn=process_inference,
            inputs=[video_input, text_input, guidance_scale, inference_steps, sample_nums],
            outputs=[
                video_output_1,  # Sample 1 value
                video_output_2,  # Sample 2 value
                video_output_3,  # Sample 3 value
                video_output_4,  # Sample 4 value
                video_output_5,  # Sample 5 value
                video_output_6,  # Sample 6 value
                result_text
            ]
        )

        # Add click handlers for example buttons
        for btn, example in example_buttons:
            def create_example_handler(ex):
                def handler():
                    # Check if files exist, if not, return placeholder message
                    if os.path.exists(ex['video_path']):
                        video_file = ex['video_path']
                    else:
                        video_file = None

                    if os.path.exists(ex['result_path']):
                        result_video = ex['result_path']
                    else:
                        result_video = None

                    status_msg = f"✅ Loaded example with caption: {ex['caption'][:50]}..."
                    if not video_file:
                        status_msg += f"\n⚠️ Video file not found: {ex['video_path']}"
                    if not result_video:
                        status_msg += f"\n⚠️ Result video not found: {ex['result_path']}"

                    return video_file, ex['caption'], result_video, status_msg
                return handler

            btn.click(
                fn=create_example_handler(example),
                outputs=[video_input, text_input, video_output_1, result_text]
            )

        # Footer
        gr.HTML("""
        <div class="footer-text">
            <p>🚀 Powered by HunyuanVideo-Foley | Generate high-quality audio from video and text descriptions</p>
        </div>
        """)

    return app

def set_manual_seed(global_seed):
    random.seed(global_seed)
    np.random.seed(global_seed)
    torch.manual_seed(global_seed)

if __name__ == "__main__":
    set_manual_seed(1)
    # Setup logging
    logger.remove()
    logger.add(lambda msg: print(msg, end=''), level="INFO")

    # Auto-load model
    logger.info("Starting application and loading model...")
    model_load_result = auto_load_models()
    logger.info(model_load_result)

    # Create and launch Gradio app
    app = create_gradio_interface()

    # Log completion status
    if "successfully" in model_load_result:
        logger.info("Application ready, model loaded")

    app.launch(
        server_name="0.0.0.0",
        server_port=8080,
        share=False,
        debug=False,
        show_error=True
    )
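
For reference, the same pipeline can be driven without the Gradio UI. The sketch below follows the exact call sequence of infer_single_video above; the example paths, the fixed CUDA device, and running from the repository root are assumptions, not part of this commit.

# Minimal headless sketch (assumptions: repo root as working directory,
# pretrained weights downloaded, a CUDA device available). It mirrors the
# feature_process -> denoise_process -> merge_audio_video sequence in app.py.
import os
import torch
import torchaudio

from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
from hunyuanvideo_foley.utils.feature_utils import feature_process
from hunyuanvideo_foley.utils.media_utils import merge_audio_video

model_path = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
device = torch.device("cuda:0")
model_dict, cfg = load_model(model_path, "configs/hunyuanvideo-foley-xxl.yaml", device)

# Extract visual and text features, then run the flow-matching denoiser (batch of 1).
visual_feats, text_feats, audio_len_in_s = feature_process(
    "examples/1_video.mp4", "A person walks on frozen ice", model_dict, cfg
)
audio, sample_rate = denoise_process(
    visual_feats, text_feats, audio_len_in_s, model_dict, cfg,
    guidance_scale=4.5, num_inference_steps=50, batch_size=1,
)

# Save the first sample and mux it back onto the source video.
torchaudio.save("generated_audio.wav", audio[0], sample_rate)
merge_audio_video("generated_audio.wav", "examples/1_video.mp4", "video_with_audio.mp4")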
assets/data_pipeline.png
ADDED
Binary image (stored via Git LFS)

assets/model_arch.png
ADDED
Binary image (stored via Git LFS)

assets/pan_chart.png
ADDED
Binary image (stored via Git LFS)
configs/hunyuanvideo-foley-xxl.yaml
ADDED
@@ -0,0 +1,49 @@
model_config:
  model_name: HunyuanVideo-Foley-XXL
  model_type: 1d
  model_precision: bf16
  model_kwargs:
    depth_triple_blocks: 18
    depth_single_blocks: 36
    hidden_size: 1536
    num_heads: 12
    mlp_ratio: 4
    mlp_act_type: "gelu_tanh"
    qkv_bias: True
    qk_norm: True
    qk_norm_type: "rms"
    attn_mode: "torch"
    embedder_type: "default"
    interleaved_audio_visual_rope: True
    enable_learnable_empty_visual_feat: True
    sync_modulation: False
    add_sync_feat_to_audio: True
    cross_attention: True
    use_attention_mask: False
    condition_projection: "linear"
    sync_feat_dim: 768 # syncformer 768 dim
    condition_dim: 768 # clap 768 text condition dim (clip-text)
    clip_dim: 768 # siglip2 visual dim
    audio_vae_latent_dim: 128
    audio_frame_rate: 50
    patch_size: 1
    rope_dim_list: null
    rope_theta: 10000
    text_length: 77
    clip_length: 64
    sync_length: 192
    use_mmaudio_singleblock: True
    depth_triple_ssl_encoder: null
    depth_single_ssl_encoder: 8
    use_repa_with_audiossl: True

diffusion_config:
  denoise_type: "flow"
  flow_path_type: "linear"
  flow_predict_type: "velocity"
  flow_reverse: True
  flow_solver: "euler"
  sample_flow_shift: 1.0
  sample_use_flux_shift: False
  flux_base_shift: 0.5
  flux_max_shift: 1.15
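
app.py hands this file to load_model, which is the supported entry point. As a sanity check, the raw YAML can also be inspected directly; the sketch below uses PyYAML and the nesting shown above, both of which are assumptions (the repo's internal config reader may differ).

# Hypothetical sanity check with PyYAML; not the repo's own config loader.
import yaml

with open("configs/hunyuanvideo-foley-xxl.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model_config"]["model_name"])                   # HunyuanVideo-Foley-XXL
print(cfg["model_config"]["model_kwargs"]["hidden_size"])  # 1536
print(cfg["diffusion_config"]["denoise_type"])             # flow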
examples/1_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f3f49d6130592f479b0aca5f02ba25960140ed8d9d17340ff7f6306b39096a8
size 11357340

examples/1_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54fc2de0b52f6969157b9caff212ffffddd4d34a75efb47ef7e7f8352d0a38db
size 11181543

examples/2_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cf4f324b158f6a6926e77bbd0791610d79f0c3a600a571ed8ec61b0b7e645e46
size 1720732

examples/2_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:512b185682ec8e60407dff65718443d7a28c75c79aec1e733b5abf7433af41a7
size 1636945

examples/3_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df90300335b7ab1fb2fc4c020976837c6b3781f796e211bcbbaa30d34353d3e5
size 1738462

examples/3_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e79f80cf939fcb507e3fe218f61146d4fd3949a84e802b0f5b67bb2e981931a7
size 1652180

examples/4_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c7d2b5b63f6756f719d53e8087f772aa6bb25f31fbcd9f1cbae9e075fc841a2c
size 45242387

examples/4_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f94cfd97634f3df085672ce2a91805697320507a716360bf85ba6eabb5a4b6f0
size 45066257

examples/5_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf1db36336822b54b4e0e0aa8c98334fc2b97b9288271ddc9d42a45417b1f1d9
size 40423834

examples/5_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:746954bde2d5e693beecd8e3661bcd66ff0e55a8143f4f3b37f0b6d3873a8fff
size 40248335

examples/6_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7778c26a677c04e93dc722cee44b896c5281bea8328a10800a08b98865419cd0
size 4005580

examples/6_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:866b2cff3441ddd686e551181c48ad7ca718626a489cf64198626b42bd732366
size 3872852

examples/7_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:491114177a4ceeb50a6edfa5fca14fc6ce4fdb61ee7dfb0e13983236c42ee10d
size 32307884

examples/7_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7bc8a31e867245f6a6a6fbfa9778ac4c12e816184dc70324dc92d4496a36f62b
size 32131367

examples/8_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:17053c5f8a373d656be2d2030619a71a5fd55db6365f8d54234402121e6030ce
size 29544164

examples/8_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:45a1f35974ad6e2d86304828fdeb230d9b008aae2f10cff8c87d71a8dcc6491e
size 29367637
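
Each of the three-line files above is a standard Git LFS pointer rather than the media itself; the actual bytes are fetched by git lfs pull, as configured in the .gitattributes change earlier in this commit. A small sketch of reading one pointer (the helper is illustrative, not part of the repo):

# Parse a Git LFS pointer file: three "key value" lines, as shown above.
def read_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = read_lfs_pointer("examples/1_video.mp4")
print(ptr["oid"], int(ptr["size"]))  # sha256:54fc2de0... 11181543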
hunyuanvideo_foley/__init__.py
ADDED
File without changes (empty file)

hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (173 Bytes)
hunyuanvideo_foley/constants.py
ADDED
@@ -0,0 +1,57 @@
+"""Constants used throughout the HunyuanVideo-Foley project."""
+
+from typing import Dict, List
+
+# Model configuration
+DEFAULT_AUDIO_SAMPLE_RATE = 48000
+DEFAULT_VIDEO_FPS = 25
+DEFAULT_AUDIO_CHANNELS = 2
+
+# Video processing
+MAX_VIDEO_DURATION_SECONDS = 15.0
+MIN_VIDEO_DURATION_SECONDS = 1.0
+
+# Audio processing
+AUDIO_VAE_LATENT_DIM = 128
+AUDIO_FRAME_RATE = 75  # frames per second in latent space
+
+# Visual features
+FPS_VISUAL: Dict[str, int] = {
+    "siglip2": 8,
+    "synchformer": 25
+}
+
+# Model paths (can be overridden by environment variables)
+DEFAULT_MODEL_PATH = "./pretrained_models/"
+DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
+
+# Inference parameters
+DEFAULT_GUIDANCE_SCALE = 4.5
+DEFAULT_NUM_INFERENCE_STEPS = 50
+MIN_GUIDANCE_SCALE = 1.0
+MAX_GUIDANCE_SCALE = 10.0
+MIN_INFERENCE_STEPS = 10
+MAX_INFERENCE_STEPS = 100
+
+# Text processing
+MAX_TEXT_LENGTH = 100
+DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"
+
+# File extensions
+SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
+SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]
+
+# Quality settings
+AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
+    "high": ["-b:a", "192k"],
+    "medium": ["-b:a", "128k"],
+    "low": ["-b:a", "96k"]
+}
+
+# Error messages
+ERROR_MESSAGES: Dict[str, str] = {
+    "model_not_loaded": "Model is not loaded. Please load the model first.",
+    "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
+    "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
+    "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
+}
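A minimal sketch of how these constants might be consumed; the two helpers below are hypothetical illustrations, not part of this commit:

    from hunyuanvideo_foley.constants import (
        ERROR_MESSAGES,
        MAX_GUIDANCE_SCALE,
        MIN_GUIDANCE_SCALE,
        SUPPORTED_VIDEO_EXTENSIONS,
    )

    def validate_video_path(path: str) -> None:
        # Hypothetical helper: reject unsupported containers up front.
        if not any(path.lower().endswith(ext) for ext in SUPPORTED_VIDEO_EXTENSIONS):
            raise ValueError(
                ERROR_MESSAGES["invalid_video_format"].format(
                    formats=", ".join(SUPPORTED_VIDEO_EXTENSIONS)
                )
            )

    def clamp_guidance(scale: float) -> float:
        # Keep a user-supplied CFG scale inside the supported range.
        return max(MIN_GUIDANCE_SCALE, min(MAX_GUIDANCE_SCALE, scale))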
hunyuanvideo_foley/models/__init__.py
ADDED
File without changes
hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc
ADDED
Binary file (12.1 kB)
hunyuanvideo_foley/models/dac_vae/__init__.py
ADDED
@@ -0,0 +1,16 @@
+__version__ = "1.0.0"
+
+# preserved here for legacy reasons
+__model_version__ = "latest"
+
+import audiotools
+
+audiotools.ml.BaseModel.INTERN += ["dac.**"]
+audiotools.ml.BaseModel.EXTERN += ["einops"]
+
+
+from . import nn
+from . import model
+from . import utils
+from .model import DAC
+from .model import DACFile
hunyuanvideo_foley/models/dac_vae/__main__.py
ADDED
@@ -0,0 +1,36 @@
+import sys
+
+import argbind
+
+from .utils import download
+from .utils.decode import decode
+from .utils.encode import encode
+
+STAGES = ["encode", "decode", "download"]
+
+
+def run(stage: str):
+    """Run stages.
+
+    Parameters
+    ----------
+    stage : str
+        Stage to run
+    """
+    if stage not in STAGES:
+        raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
+    stage_fn = globals()[stage]
+
+    if stage == "download":
+        stage_fn()
+        return
+
+    stage_fn()
+
+
+if __name__ == "__main__":
+    group = sys.argv.pop(1)
+    args = argbind.parse_args(group=group)
+
+    with argbind.scope(args):
+        run(group)
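The entry point dispatches on the second CLI argument, so, assuming the package is importable under the path shown above, an invocation would look like `python -m hunyuanvideo_foley.models.dac_vae encode ...`; argbind then scopes the remaining flags to the selected stage function.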
hunyuanvideo_foley/models/dac_vae/model/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .base import CodecMixin
+from .base import DACFile
+from .dac import DAC
+from .discriminator import Discriminator
hunyuanvideo_foley/models/dac_vae/model/base.py
ADDED
@@ -0,0 +1,301 @@
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import tqdm
+from audiotools import AudioSignal
+from torch import nn
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+@dataclass
+class DACFile:
+    codes: torch.Tensor
+
+    # Metadata
+    chunk_length: int
+    original_length: int
+    input_db: float
+    channels: int
+    sample_rate: int
+    padding: bool
+    dac_version: str
+
+    def save(self, path):
+        artifacts = {
+            "codes": self.codes.numpy().astype(np.uint16),
+            "metadata": {
+                "input_db": self.input_db.numpy().astype(np.float32),
+                "original_length": self.original_length,
+                "sample_rate": self.sample_rate,
+                "chunk_length": self.chunk_length,
+                "channels": self.channels,
+                "padding": self.padding,
+                "dac_version": SUPPORTED_VERSIONS[-1],
+            },
+        }
+        path = Path(path).with_suffix(".dac")
+        with open(path, "wb") as f:
+            np.save(f, artifacts)
+        return path
+
+    @classmethod
+    def load(cls, path):
+        artifacts = np.load(path, allow_pickle=True)[()]
+        codes = torch.from_numpy(artifacts["codes"].astype(int))
+        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
+            raise RuntimeError(
+                f"Given file {path} can't be loaded with this version of descript-audio-codec."
+            )
+        return cls(codes=codes, **artifacts["metadata"])
+
+
+class CodecMixin:
+    @property
+    def padding(self):
+        if not hasattr(self, "_padding"):
+            self._padding = True
+        return self._padding
+
+    @padding.setter
+    def padding(self, value):
+        assert isinstance(value, bool)
+
+        layers = [
+            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
+        ]
+
+        for layer in layers:
+            if value:
+                if hasattr(layer, "original_padding"):
+                    layer.padding = layer.original_padding
+            else:
+                layer.original_padding = layer.padding
+                layer.padding = tuple(0 for _ in range(len(layer.padding)))
+
+        self._padding = value
+
+    def get_delay(self):
+        # Any number works here, delay is invariant to input length
+        l_out = self.get_output_length(0)
+        L = l_out
+
+        layers = []
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                layers.append(layer)
+
+        for layer in reversed(layers):
+            d = layer.dilation[0]
+            k = layer.kernel_size[0]
+            s = layer.stride[0]
+
+            if isinstance(layer, nn.ConvTranspose1d):
+                L = ((L - d * (k - 1) - 1) / s) + 1
+            elif isinstance(layer, nn.Conv1d):
+                L = (L - 1) * s + d * (k - 1) + 1
+
+            L = math.ceil(L)
+
+        l_in = L
+
+        return (l_in - l_out) // 2
+
+    def get_output_length(self, input_length):
+        L = input_length
+        # Calculate output length
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                d = layer.dilation[0]
+                k = layer.kernel_size[0]
+                s = layer.stride[0]
+
+                if isinstance(layer, nn.Conv1d):
+                    L = ((L - d * (k - 1) - 1) / s) + 1
+                elif isinstance(layer, nn.ConvTranspose1d):
+                    L = (L - 1) * s + d * (k - 1) + 1
+
+                L = math.floor(L)
+        return L
+
+    @torch.no_grad()
+    def compress(
+        self,
+        audio_path_or_signal: Union[str, Path, AudioSignal],
+        win_duration: float = 1.0,
+        verbose: bool = False,
+        normalize_db: float = -16,
+        n_quantizers: int = None,
+    ) -> DACFile:
+        """Processes an audio signal from a file or AudioSignal object into
+        discrete codes. This function processes the signal in short windows,
+        using constant GPU memory.
+
+        Parameters
+        ----------
+        audio_path_or_signal : Union[str, Path, AudioSignal]
+            audio signal to compress
+        win_duration : float, optional
+            window duration in seconds, by default 1.0
+        verbose : bool, optional
+            by default False
+        normalize_db : float, optional
+            normalize db, by default -16
+
+        Returns
+        -------
+        DACFile
+            Object containing compressed codes and metadata
+            required for decompression
+        """
+        audio_signal = audio_path_or_signal
+        if isinstance(audio_signal, (str, Path)):
+            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+        self.eval()
+        original_padding = self.padding
+        original_device = audio_signal.device
+
+        audio_signal = audio_signal.clone()
+        audio_signal = audio_signal.to_mono()
+        original_sr = audio_signal.sample_rate
+
+        resample_fn = audio_signal.resample
+        loudness_fn = audio_signal.loudness
+
+        # If audio is >= 10 hours long, use the ffmpeg versions
+        if audio_signal.signal_duration >= 10 * 60 * 60:
+            resample_fn = audio_signal.ffmpeg_resample
+            loudness_fn = audio_signal.ffmpeg_loudness
+
+        original_length = audio_signal.signal_length
+        resample_fn(self.sample_rate)
+        input_db = loudness_fn()
+
+        if normalize_db is not None:
+            audio_signal.normalize(normalize_db)
+        audio_signal.ensure_max_of_audio()
+
+        nb, nac, nt = audio_signal.audio_data.shape
+        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+        win_duration = (
+            audio_signal.signal_duration if win_duration is None else win_duration
+        )
+
+        if audio_signal.signal_duration <= win_duration:
+            # Unchunked compression (used if signal length < win duration)
+            self.padding = True
+            n_samples = nt
+            hop = nt
+        else:
+            # Chunked inference
+            self.padding = False
+            # Zero-pad signal on either side by the delay
+            audio_signal.zero_pad(self.delay, self.delay)
+            n_samples = int(win_duration * self.sample_rate)
+            # Round n_samples to nearest hop length multiple
+            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+            hop = self.get_output_length(n_samples)
+
+        codes = []
+        range_fn = range if not verbose else tqdm.trange
+
+        for i in range_fn(0, nt, hop):
+            x = audio_signal[..., i : i + n_samples]
+            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+            audio_data = x.audio_data.to(self.device)
+            audio_data = self.preprocess(audio_data, self.sample_rate)
+            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+            codes.append(c.to(original_device))
+            chunk_length = c.shape[-1]
+
+        codes = torch.cat(codes, dim=-1)
+
+        dac_file = DACFile(
+            codes=codes,
+            chunk_length=chunk_length,
+            original_length=original_length,
+            input_db=input_db,
+            channels=nac,
+            sample_rate=original_sr,
+            padding=self.padding,
+            dac_version=SUPPORTED_VERSIONS[-1],
+        )
+
+        if n_quantizers is not None:
+            codes = codes[:, :n_quantizers, :]
+
+        self.padding = original_padding
+        return dac_file
+
+    @torch.no_grad()
+    def decompress(
+        self,
+        obj: Union[str, Path, DACFile],
+        verbose: bool = False,
+    ) -> AudioSignal:
+        """Reconstruct audio from a given .dac file
+
+        Parameters
+        ----------
+        obj : Union[str, Path, DACFile]
+            .dac file location or corresponding DACFile object.
+        verbose : bool, optional
+            Prints progress if True, by default False
+
+        Returns
+        -------
+        AudioSignal
+            Object with the reconstructed audio
+        """
+        self.eval()
+        if isinstance(obj, (str, Path)):
+            obj = DACFile.load(obj)
+
+        original_padding = self.padding
+        self.padding = obj.padding
+
+        range_fn = range if not verbose else tqdm.trange
+        codes = obj.codes
+        original_device = codes.device
+        chunk_length = obj.chunk_length
+        recons = []
+
+        for i in range_fn(0, codes.shape[-1], chunk_length):
+            c = codes[..., i : i + chunk_length].to(self.device)
+            z = self.quantizer.from_codes(c)[0]
+            r = self.decode(z)
+            recons.append(r.to(original_device))
+
+        recons = torch.cat(recons, dim=-1)
+        recons = AudioSignal(recons, self.sample_rate)
+
+        resample_fn = recons.resample
+        loudness_fn = recons.loudness
+
+        # If audio is >= 10 hours long, use the ffmpeg versions
+        if recons.signal_duration >= 10 * 60 * 60:
+            resample_fn = recons.ffmpeg_resample
+            loudness_fn = recons.ffmpeg_loudness
+
+        if obj.input_db is not None:
+            recons.normalize(obj.input_db)
+
+        resample_fn(obj.sample_rate)
+
+        if obj.original_length is not None:
+            recons = recons[..., : obj.original_length]
+            loudness_fn()
+            recons.audio_data = recons.audio_data.reshape(
+                -1, obj.channels, obj.original_length
+            )
+        else:
+            loudness_fn()
+
+        self.padding = original_padding
+        return recons
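A minimal sketch of the compress/decompress roundtrip that `CodecMixin` provides, assuming the discrete `DAC` model defined in the next file (a freshly constructed model is untrained, so this only demonstrates the API and shapes):

    import torch
    from audiotools import AudioSignal
    from hunyuanvideo_foley.models.dac_vae import DAC

    model = DAC()  # untrained; load released weights for meaningful audio
    signal = AudioSignal(torch.randn(1, 1, 44100), 44100)

    dac_file = model.compress(signal, win_duration=1.0)  # windowed, bounded memory
    path = dac_file.save("example")                      # written as example.dac
    recons = model.decompress(path)                      # AudioSignal at original length
    print(recons.audio_data.shape)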
hunyuanvideo_foley/models/dac_vae/model/dac.py
ADDED
@@ -0,0 +1,410 @@
+import math
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from torch import nn
+
+from .base import CodecMixin
+from ..nn.layers import Snake1d
+from ..nn.layers import WNConv1d
+from ..nn.layers import WNConvTranspose1d
+from ..nn.quantize import ResidualVectorQuantize
+from ..nn.vae_utils import DiagonalGaussianDistribution
+
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+    def __init__(self, dim: int = 16, dilation: int = 1):
+        super().__init__()
+        pad = ((7 - 1) * dilation) // 2
+        self.block = nn.Sequential(
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+
+    def forward(self, x):
+        y = self.block(x)
+        pad = (x.shape[-1] - y.shape[-1]) // 2
+        if pad > 0:
+            x = x[..., pad:-pad]
+        return x + y
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, dim: int = 16, stride: int = 1):
+        super().__init__()
+        self.block = nn.Sequential(
+            ResidualUnit(dim // 2, dilation=1),
+            ResidualUnit(dim // 2, dilation=3),
+            ResidualUnit(dim // 2, dilation=9),
+            Snake1d(dim // 2),
+            WNConv1d(
+                dim // 2,
+                dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            ),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        d_model: int = 64,
+        strides: list = [2, 4, 8, 8],
+        d_latent: int = 64,
+    ):
+        super().__init__()
+        # Create first convolution
+        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+        # Create EncoderBlocks that double channels as they downsample by `stride`
+        for stride in strides:
+            d_model *= 2
+            self.block += [EncoderBlock(d_model, stride=stride)]
+
+        # Create last convolution
+        self.block += [
+            Snake1d(d_model),
+            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
+        ]
+
+        # Wrap block into nn.Sequential
+        self.block = nn.Sequential(*self.block)
+        self.enc_dim = d_model
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+        super().__init__()
+        self.block = nn.Sequential(
+            Snake1d(input_dim),
+            WNConvTranspose1d(
+                input_dim,
+                output_dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+                output_padding=stride % 2,
+            ),
+            ResidualUnit(output_dim, dilation=1),
+            ResidualUnit(output_dim, dilation=3),
+            ResidualUnit(output_dim, dilation=9),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        input_channel,
+        channels,
+        rates,
+        d_out: int = 1,
+    ):
+        super().__init__()
+
+        # Add first conv layer
+        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+        # Add upsampling + MRF blocks
+        for i, stride in enumerate(rates):
+            input_dim = channels // 2**i
+            output_dim = channels // 2 ** (i + 1)
+            layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+        # Add final conv layer
+        layers += [
+            Snake1d(output_dim),
+            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class DAC(BaseModel, CodecMixin):
+    def __init__(
+        self,
+        encoder_dim: int = 64,
+        encoder_rates: List[int] = [2, 4, 8, 8],
+        latent_dim: int = None,
+        decoder_dim: int = 1536,
+        decoder_rates: List[int] = [8, 8, 4, 2],
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: bool = False,
+        sample_rate: int = 44100,
+        continuous: bool = False,
+    ):
+        super().__init__()
+
+        self.encoder_dim = encoder_dim
+        self.encoder_rates = encoder_rates
+        self.decoder_dim = decoder_dim
+        self.decoder_rates = decoder_rates
+        self.sample_rate = sample_rate
+        self.continuous = continuous
+
+        if latent_dim is None:
+            latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+        self.latent_dim = latent_dim
+
+        self.hop_length = np.prod(encoder_rates)
+        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
+
+        if not continuous:
+            self.n_codebooks = n_codebooks
+            self.codebook_size = codebook_size
+            self.codebook_dim = codebook_dim
+            self.quantizer = ResidualVectorQuantize(
+                input_dim=latent_dim,
+                n_codebooks=n_codebooks,
+                codebook_size=codebook_size,
+                codebook_dim=codebook_dim,
+                quantizer_dropout=quantizer_dropout,
+            )
+        else:
+            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
+            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
+
+        self.decoder = Decoder(
+            latent_dim,
+            decoder_dim,
+            decoder_rates,
+        )
+        self.sample_rate = sample_rate
+        self.apply(init_weights)
+
+        self.delay = self.get_delay()
+
+    @property
+    def dtype(self):
+        """Get the dtype of the model parameters."""
+        # Return the dtype of the first parameter found
+        for param in self.parameters():
+            return param.dtype
+        return torch.float32  # fallback
+
+    @property
+    def device(self):
+        """Get the device of the model parameters."""
+        # Return the device of the first parameter found
+        for param in self.parameters():
+            return param.device
+        return torch.device('cpu')  # fallback
+
+    def preprocess(self, audio_data, sample_rate):
+        if sample_rate is None:
+            sample_rate = self.sample_rate
+        assert sample_rate == self.sample_rate
+
+        length = audio_data.shape[-1]
+        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
+        audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+        return audio_data
+
+    def encode(
+        self,
+        audio_data: torch.Tensor,
+        n_quantizers: int = None,
+    ):
+        """Encode given audio data and return quantized latent codes
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        tuple
+            A tuple with the following elements:
+            "z" : Tensor[B x D x T]
+                Quantized continuous representation of input
+            "codes" : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            "latents" : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input before quantization)
+            "vq/commitment_loss" : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer to codebook
+                entries
+            "vq/codebook_loss" : Tensor[1]
+                Codebook loss to update the codebook
+        """
+        z = self.encoder(audio_data)  # [B x D x T]
+        if not self.continuous:
+            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
+        else:
+            z = self.quant_conv(z)  # [B x 2D x T]
+            z = DiagonalGaussianDistribution(z)
+            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
+
+        return z, codes, latents, commitment_loss, codebook_loss
+
+    def decode(self, z: torch.Tensor):
+        """Decode given latent codes and return audio data
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+            Quantized continuous representation of input
+
+        Returns
+        -------
+        Tensor[B x 1 x T]
+            Decoded audio data.
+        """
+        if not self.continuous:
+            audio = self.decoder(z)
+        else:
+            z = self.post_quant_conv(z)
+            audio = self.decoder(z)
+
+        return audio
+
+    def forward(
+        self,
+        audio_data: torch.Tensor,
+        sample_rate: int = None,
+        n_quantizers: int = None,
+    ):
+        """Model forward pass
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        sample_rate : int, optional
+            Sample rate of audio data in Hz, by default None
+            If None, defaults to `self.sample_rate`
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None.
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        dict
+            A dictionary with the following keys:
+            "z" : Tensor[B x D x T]
+                Quantized continuous representation of input
+            "codes" : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            "latents" : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input before quantization)
+            "vq/commitment_loss" : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer to codebook
+                entries
+            "vq/codebook_loss" : Tensor[1]
+                Codebook loss to update the codebook
+            "audio" : Tensor[B x 1 x length]
+                Decoded audio data.
+            When `continuous=True`, the dictionary instead contains
+            "audio", "z", and "kl_loss".
+        """
+        length = audio_data.shape[-1]
+        audio_data = self.preprocess(audio_data, sample_rate)
+        if not self.continuous:
+            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
+
+            x = self.decode(z)
+            return {
+                "audio": x[..., :length],
+                "z": z,
+                "codes": codes,
+                "latents": latents,
+                "vq/commitment_loss": commitment_loss,
+                "vq/codebook_loss": codebook_loss,
+            }
+        else:
+            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
+            z = posterior.sample()
+            x = self.decode(z)
+
+            kl_loss = posterior.kl()
+            kl_loss = kl_loss.mean()
+
+            return {
+                "audio": x[..., :length],
+                "z": z,
+                "kl_loss": kl_loss,
+            }
+
+
+if __name__ == "__main__":
+    import numpy as np
+    from functools import partial
+
+    model = DAC().to("cpu")
+
+    for n, m in model.named_modules():
+        o = m.extra_repr()
+        p = sum([np.prod(p.size()) for p in m.parameters()])
+        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
+        setattr(m, "extra_repr", partial(fn, o=o, p=p))
+    print(model)
+    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+    length = 88200 * 2
+    x = torch.randn(1, 1, length).to(model.device)
+    x.requires_grad_(True)
+    x.retain_grad()
+
+    # Make a forward pass
+    out = model(x)["audio"]
+    print("Input shape:", x.shape)
+    print("Output shape:", out.shape)
+
+    # Create gradient variable
+    grad = torch.zeros_like(out)
+    grad[:, :, grad.shape[-1] // 2] = 1
+
+    # Make a backward pass
+    out.backward(grad)
+
+    # Check non-zero values
+    gradmap = x.grad.squeeze(0)
+    gradmap = (gradmap != 0).sum(0)  # sum across features
+    rf = (gradmap != 0).sum()
+
+    print(f"Receptive field: {rf.item()}")
+
+    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
+    model.decompress(model.compress(x, verbose=True), verbose=True)
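The `continuous=True` branch turns the codec into a KL-regularized VAE: the quantizer is replaced by the 1x1 `quant_conv`/`post_quant_conv` pair around a `DiagonalGaussianDistribution`. A shapes-only sketch with untrained weights:

    import torch
    from hunyuanvideo_foley.models.dac_vae.model.dac import DAC

    vae = DAC(continuous=True)
    x = torch.randn(1, 1, 44100)
    out = vae(x)  # encode -> posterior -> sample -> decode

    print(out["audio"].shape)  # (1, 1, 44100), trimmed back to the input length
    print(out["kl_loss"])      # mean KL of the sampled posterior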
hunyuanvideo_foley/models/dac_vae/model/discriminator.py
ADDED
@@ -0,0 +1,228 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import ml
+from audiotools import STFTParams
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+    act = kwargs.pop("act", True)
+    conv = weight_norm(nn.Conv1d(*args, **kwargs))
+    if not act:
+        return conv
+    return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def WNConv2d(*args, **kwargs):
+    act = kwargs.pop("act", True)
+    conv = weight_norm(nn.Conv2d(*args, **kwargs))
+    if not act:
+        return conv
+    return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+class MPD(nn.Module):
+    def __init__(self, period):
+        super().__init__()
+        self.period = period
+        self.convs = nn.ModuleList(
+            [
+                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+            ]
+        )
+        self.conv_post = WNConv2d(
+            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+        )
+
+    def pad_to_period(self, x):
+        t = x.shape[-1]
+        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+        return x
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.pad_to_period(x)
+        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+        for layer in self.convs:
+            x = layer(x)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return fmap
+
+
+class MSD(nn.Module):
+    def __init__(self, rate: int = 1, sample_rate: int = 44100):
+        super().__init__()
+        self.convs = nn.ModuleList(
+            [
+                WNConv1d(1, 16, 15, 1, padding=7),
+                WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+                WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+                WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+                WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+                WNConv1d(1024, 1024, 5, 1, padding=2),
+            ]
+        )
+        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+        self.sample_rate = sample_rate
+        self.rate = rate
+
+    def forward(self, x):
+        x = AudioSignal(x, self.sample_rate)
+        x.resample(self.sample_rate // self.rate)
+        x = x.audio_data
+
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return fmap
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+class MRD(nn.Module):
+    def __init__(
+        self,
+        window_length: int,
+        hop_factor: float = 0.25,
+        sample_rate: int = 44100,
+        bands: list = BANDS,
+    ):
+        """Complex multi-band spectrogram discriminator.
+        Parameters
+        ----------
+        window_length : int
+            Window length of STFT.
+        hop_factor : float, optional
+            Hop factor of the STFT; hop_length = hop_factor * window_length, by default 0.25.
+        sample_rate : int, optional
+            Sampling rate of audio in Hz, by default 44100
+        bands : list, optional
+            Bands to run discriminator over.
+        """
+        super().__init__()
+
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.sample_rate = sample_rate
+        self.stft_params = STFTParams(
+            window_length=window_length,
+            hop_length=int(window_length * hop_factor),
+            match_stride=True,
+        )
+
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+
+        ch = 32
+        convs = lambda: nn.ModuleList(
+            [
+                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+    def spectrogram(self, x):
+        x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+        x = torch.view_as_real(x.stft())
+        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x):
+        x_bands = self.spectrogram(x)
+        fmap = []
+
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for layer in stack:
+                band = layer(band)
+                fmap.append(band)
+            x.append(band)
+
+        x = torch.cat(x, dim=-1)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return fmap
+
+
+class Discriminator(ml.BaseModel):
+    def __init__(
+        self,
+        rates: list = [],
+        periods: list = [2, 3, 5, 7, 11],
+        fft_sizes: list = [2048, 1024, 512],
+        sample_rate: int = 44100,
+        bands: list = BANDS,
+    ):
+        """Discriminator that combines multiple discriminators.
+
+        Parameters
+        ----------
+        rates : list, optional
+            sampling rates (in Hz) to run MSD at, by default []
+            If empty, MSD is not used.
+        periods : list, optional
+            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+        fft_sizes : list, optional
+            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+        sample_rate : int, optional
+            Sampling rate of audio in Hz, by default 44100
+        bands : list, optional
+            Bands to run MRD at, by default `BANDS`
+        """
+        super().__init__()
+        discs = []
+        discs += [MPD(p) for p in periods]
+        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+        discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+        self.discriminators = nn.ModuleList(discs)
+
+    def preprocess(self, y):
+        # Remove DC offset
+        y = y - y.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        return y
+
+    def forward(self, x):
+        x = self.preprocess(x)
+        fmaps = [d(x) for d in self.discriminators]
+        return fmaps
+
+
+if __name__ == "__main__":
+    disc = Discriminator()
+    x = torch.zeros(1, 1, 44100)
+    results = disc(x)
+    for i, result in enumerate(results):
+        print(f"disc{i}")
+        for i, r in enumerate(result):
+            print(r.shape, r.mean(), r.min(), r.max())
+        print()
hunyuanvideo_foley/models/dac_vae/nn/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from . import layers
+from . import loss
+from . import quantize
hunyuanvideo_foley/models/dac_vae/nn/layers.py
ADDED
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class Snake1d(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        return snake(x, self.alpha)
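The scripted `snake` kernel computes x + (1/(alpha + 1e-9)) * sin(alpha * x)^2 per channel, a periodic activation that helps waveform models fit oscillatory signals. A quick numerical check of that identity at the initial alpha = 1 (a sketch, not part of the commit):

    import torch
    from hunyuanvideo_foley.models.dac_vae.nn.layers import Snake1d

    x = torch.randn(2, 4, 16)
    act = Snake1d(channels=4)  # alpha is initialized to ones

    expected = x + torch.sin(x) ** 2 / (1.0 + 1e-9)
    print(torch.allclose(act(x), expected))  # True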
hunyuanvideo_foley/models/dac_vae/nn/loss.py
ADDED
@@ -0,0 +1,368 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+    """L1 Loss between AudioSignals. Defaults
+    to comparing ``audio_data``, but any
+    attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : int, optional
+        Whether to use scale-invariant (True) or
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : int, optional
+        Zero mean the references and estimates before
+        computing the loss, by default True
+    clip_min : int, optional
+        The minimum possible loss value. Helps network
+        to not focus on making already good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(
+        self,
+        scaling: int = True,
+        reduction: str = "mean",
+        zero_mean: int = True,
+        clip_min: int = None,
+        weight: float = 1.0,
+    ):
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        eps = 1e-8
+        # nb, nc, nt
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(dim=1, keepdim=True)
+            mean_estimate = estimates.mean(dim=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(dim=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(1)
+            if self.scaling
+            else 1
+        )
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(dim=1)
+        noise = (e_res**2).sum(dim=1)
+        sdr = -10 * torch.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = torch.clamp(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
+    """Computes the multi-scale STFT loss from [1].
+
+    Parameters
+    ----------
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Clamp on the log magnitude, below, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    References
+    ----------
+
+    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
+       "DDSP: Differentiable Digital Signal Processing."
+       International Conference on Learning Representations. 2019.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes multi-scale STFT between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Multi-scale STFT loss.
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+    """Compute distance between mel spectrograms. Can be used
+    in a multi-scale way.
+
+    Parameters
+    ----------
+    n_mels : List[int]
+        Number of mels per STFT, by default [150, 80],
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Clamp on the log magnitude, below, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        n_mels: List[int] = [150, 80],
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        mel_fmin: List[float] = [0.0, 0.0],
+        mel_fmax: List[float] = [None, None],
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.n_mels = n_mels
+        self.loss_fn = loss_fn
+        self.clamp_eps = clamp_eps
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.weight = weight
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes mel loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Mel loss.
+        """
+        loss = 0.0
+        for n_mels, fmin, fmax, s in zip(
+            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+        ):
+            kwargs = {
+                "window_length": s.window_length,
+                "hop_length": s.hop_length,
+                "window_type": s.window_type,
+            }
+            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+            loss += self.log_weight * self.loss_fn(
+                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+        return loss
+
+
+class GANLoss(nn.Module):
+    """
+    Computes a discriminator loss, given a discriminator on
+    generated waveforms/spectrograms compared to ground truth
+    waveforms/spectrograms. Computes the loss for both the
+    discriminator and the generator in separate functions.
+    """
+
+    def __init__(self, discriminator):
+        super().__init__()
+        self.discriminator = discriminator
+
+    def forward(self, fake, real):
+        d_fake = self.discriminator(fake.audio_data)
+        d_real = self.discriminator(real.audio_data)
+        return d_fake, d_real
+
+    def discriminator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+        loss_d = 0
+        for x_fake, x_real in zip(d_fake, d_real):
+            loss_d += torch.mean(x_fake[-1] ** 2)
+            loss_d += torch.mean((1 - x_real[-1]) ** 2)
+        return loss_d
+
+    def generator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_g = 0
+        for x_fake in d_fake:
+            loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+        loss_feature = 0
+
+        for i in range(len(d_fake)):
+            for j in range(len(d_fake[i]) - 1):
+                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+        return loss_g, loss_feature
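A minimal sketch of combining these reconstruction losses the way a codec training loop typically would; the relative weights are illustrative assumptions, not values taken from this repository:

    import torch
    from audiotools import AudioSignal
    from hunyuanvideo_foley.models.dac_vae.nn.loss import (
        L1Loss,
        MelSpectrogramLoss,
        MultiScaleSTFTLoss,
    )

    ref = AudioSignal(torch.randn(1, 1, 44100), 44100)
    est = AudioSignal(torch.randn(1, 1, 44100), 44100)

    waveform_l1 = L1Loss()(est, ref)
    stft = MultiScaleSTFTLoss()(est, ref)  # 2048- and 512-sample windows by default
    mel = MelSpectrogramLoss()(est, ref)   # 150 and 80 mel bins by default

    total = waveform_l1 + stft + 15.0 * mel  # illustrative weighting
    print(total)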
hunyuanvideo_foley/models/dac_vae/nn/quantize.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
from .layers import WNConv1d
|
11 |
+
|
12 |
+
|
13 |
+
class VectorQuantize(nn.Module):
|
14 |
+
"""
|
15 |
+
Implementation of VQ similar to Karpathy's repo:
|
16 |
+
https://github.com/karpathy/deep-vector-quantization
|
17 |
+
Additionally uses following tricks from Improved VQGAN
|
18 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
19 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
20 |
+
for improved codebook usage
|
21 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
22 |
+
improves training stability
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
26 |
+
super().__init__()
|
27 |
+
self.codebook_size = codebook_size
|
28 |
+
self.codebook_dim = codebook_dim
|
29 |
+
|
30 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
31 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
32 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
33 |
+
|
34 |
+
def forward(self, z):
|
35 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
36 |
+
the corresponding codebook vectors
|
37 |
+
|
38 |
+
Parameters
|
39 |
+
----------
|
40 |
+
z : Tensor[B x D x T]
|
41 |
+
|
42 |
+
Returns
|
43 |
+
-------
|
44 |
+
Tensor[B x D x T]
|
45 |
+
Quantized continuous representation of input
|
46 |
+
Tensor[1]
|
47 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
48 |
+
entries
|
49 |
+
Tensor[1]
|
50 |
+
Codebook loss to update the codebook
|
51 |
+
Tensor[B x T]
|
52 |
+
Codebook indices (quantized discrete representation of input)
|
53 |
+
Tensor[B x D x T]
|
54 |
+
Projected latents (continuous representation of input before quantization)
|
55 |
+
"""
|
56 |
+
|
57 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
58 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
59 |
+
z_q, indices = self.decode_latents(z_e)
|
60 |
+
|
61 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
62 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
63 |
+
|
64 |
+
z_q = (
|
65 |
+
z_e + (z_q - z_e).detach()
|
66 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
67 |
+
|
68 |
+
z_q = self.out_proj(z_q)
|
69 |
+
|
70 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
71 |
+
|
72 |
+
def embed_code(self, embed_id):
|
73 |
+
return F.embedding(embed_id, self.codebook.weight)
|
74 |
+
|
75 |
+
def decode_code(self, embed_id):
|
76 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
77 |
+
|
78 |
+
def decode_latents(self, latents):
|
79 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
80 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
81 |
+
|
82 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
83 |
+
encodings = F.normalize(encodings)
|
84 |
+
codebook = F.normalize(codebook)
|
85 |
+
|
86 |
+
# Compute euclidean distance with codebook
|
87 |
+
dist = (
|
88 |
+
encodings.pow(2).sum(1, keepdim=True)
|
89 |
+
- 2 * encodings @ codebook.t()
|
90 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
91 |
+
)
|
92 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
93 |
+
z_q = self.decode_code(indices)
|
94 |
+
return z_q, indices
|
95 |
+
|
96 |
+
|
97 |
+
class ResidualVectorQuantize(nn.Module):
|
98 |
+
"""
|
99 |
+
Introduced in SoundStream: An end2end neural audio codec
|
100 |
+
https://arxiv.org/abs/2107.03312
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
input_dim: int = 512,
|
106 |
+
n_codebooks: int = 9,
|
107 |
+
codebook_size: int = 1024,
|
108 |
+
codebook_dim: Union[int, list] = 8,
|
109 |
+
quantizer_dropout: float = 0.0,
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
if isinstance(codebook_dim, int):
|
113 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
114 |
+
|
115 |
+
self.n_codebooks = n_codebooks
|
116 |
+
self.codebook_dim = codebook_dim
|
117 |
+
self.codebook_size = codebook_size
|
118 |
+
|
119 |
+
self.quantizers = nn.ModuleList(
|
120 |
+
[
|
121 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
122 |
+
for i in range(n_codebooks)
|
123 |
+
]
|
124 |
+
)
|
125 |
+
self.quantizer_dropout = quantizer_dropout
|
126 |
+
|
127 |
+
def forward(self, z, n_quantizers: int = None):
|
128 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
129 |
+
the corresponding codebook vectors
|
130 |
+
Parameters
|
131 |
+
----------
|
132 |
+
z : Tensor[B x D x T]
|
133 |
+
n_quantizers : int, optional
|
134 |
+
No. of quantizers to use
|
135 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
136 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
137 |
+
when in training mode, and a random number of quantizers is used.
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
dict
|
141 |
+
A dictionary with the following keys:
|
142 |
+
|
143 |
+
"z" : Tensor[B x D x T]
|
144 |
+
Quantized continuous representation of input
|
145 |
+
"codes" : Tensor[B x N x T]
|
146 |
+
Codebook indices for each codebook
|
147 |
+
(quantized discrete representation of input)
|
148 |
+
"latents" : Tensor[B x N*D x T]
|
149 |
+
Projected latents (continuous representation of input before quantization)
|
150 |
+
"vq/commitment_loss" : Tensor[1]
|
151 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
152 |
+
entries
|
153 |
+
"vq/codebook_loss" : Tensor[1]
|
154 |
+
Codebook loss to update the codebook
|
155 |
+
"""
|
156 |
+
z_q = 0
|
157 |
+
residual = z
|
158 |
+
commitment_loss = 0
|
159 |
+
codebook_loss = 0
|
160 |
+
|
161 |
+
codebook_indices = []
|
162 |
+
latents = []
|
163 |
+
|
164 |
+
if n_quantizers is None:
|
165 |
+
n_quantizers = self.n_codebooks
|
166 |
+
if self.training:
|
167 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
168 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
169 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
170 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
171 |
+
n_quantizers = n_quantizers.to(z.device)
|
172 |
+
|
173 |
+
for i, quantizer in enumerate(self.quantizers):
|
174 |
+
if self.training is False and i >= n_quantizers:
|
175 |
+
break
|
176 |
+
|
177 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
178 |
+
residual
|
179 |
+
)
|
180 |
+
|
181 |
+
# Create mask to apply quantizer dropout
|
182 |
+
mask = (
|
183 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
184 |
+
)
|
185 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
186 |
+
residual = residual - z_q_i
|
187 |
+
|
188 |
+
# Sum losses
|
189 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
190 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
191 |
+
|
192 |
+
codebook_indices.append(indices_i)
|
193 |
+
latents.append(z_e_i)
|
194 |
+
|
195 |
+
codes = torch.stack(codebook_indices, dim=1)
|
196 |
+
latents = torch.cat(latents, dim=1)
|
197 |
+
|
198 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
199 |
+
|
200 |
+
def from_codes(self, codes: torch.Tensor):
|
201 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
202 |
+
Parameters
|
203 |
+
----------
|
204 |
+
codes : Tensor[B x N x T]
|
205 |
+
Quantized discrete representation of input
|
206 |
+
Returns
|
207 |
+
-------
|
208 |
+
Tensor[B x D x T]
|
209 |
+
Quantized continuous representation of input
|
210 |
+
"""
|
211 |
+
z_q = 0.0
|
212 |
+
z_p = []
|
213 |
+
n_codebooks = codes.shape[1]
|
214 |
+
for i in range(n_codebooks):
|
215 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
216 |
+
z_p.append(z_p_i)
|
217 |
+
|
218 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
219 |
+
z_q = z_q + z_q_i
|
220 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
221 |
+
|
222 |
+
def from_latents(self, latents: torch.Tensor):
|
223 |
+
"""Given the unquantized latents, reconstruct the
|
224 |
+
continuous representation after quantization.
|
225 |
+
|
226 |
+
Parameters
|
227 |
+
----------
|
228 |
+
latents : Tensor[B x N x T]
|
229 |
+
Continuous representation of input after projection
|
230 |
+
|
231 |
+
Returns
|
232 |
+
-------
|
233 |
+
Tensor[B x D x T]
|
234 |
+
Quantized representation of full-projected space
|
235 |
+
Tensor[B x D x T]
|
236 |
+
Quantized representation of latent space
|
237 |
+
"""
|
238 |
+
z_q = 0
|
239 |
+
z_p = []
|
240 |
+
codes = []
|
241 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
242 |
+
|
243 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
244 |
+
0
|
245 |
+
]
|
246 |
+
for i in range(n_codebooks):
|
247 |
+
j, k = dims[i], dims[i + 1]
|
248 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
249 |
+
z_p.append(z_p_i)
|
250 |
+
codes.append(codes_i)
|
251 |
+
|
252 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
253 |
+
z_q = z_q + z_q_i
|
254 |
+
|
255 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
256 |
+
|
257 |
+
|
258 |
+
if __name__ == "__main__":
|
259 |
+
rvq = ResidualVectorQuantize(quantizer_dropout=True)
|
260 |
+
x = torch.randn(16, 512, 80)
|
261 |
+
z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
|
262 |
+
print(y["latents"].shape)
|
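The `__main__` block above exercises a single forward pass; the sketch below (illustrative only, not part of the commit) shows the discrete/continuous round trip the codec relies on, using `from_codes` to rebuild a latent from the stacked codebook indices:

import torch

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024)
z = torch.randn(2, 512, 80)                                  # (B, D, T) encoder latents
z_q, codes, latents, commitment_loss, codebook_loss = rvq(z)
z_rec, z_proj, _ = rvq.from_codes(codes)                     # rebuild from discrete codes only
print(codes.shape, z_rec.shape)                              # torch.Size([2, 9, 80]) torch.Size([2, 512, 80])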
hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py
ADDED
@@ -0,0 +1,91 @@
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
class AbstractDistribution:
|
6 |
+
def sample(self):
|
7 |
+
raise NotImplementedError()
|
8 |
+
|
9 |
+
def mode(self):
|
10 |
+
raise NotImplementedError()
|
11 |
+
|
12 |
+
|
13 |
+
class DiracDistribution(AbstractDistribution):
|
14 |
+
def __init__(self, value):
|
15 |
+
self.value = value
|
16 |
+
|
17 |
+
def sample(self):
|
18 |
+
return self.value
|
19 |
+
|
20 |
+
def mode(self):
|
21 |
+
return self.value
|
22 |
+
|
23 |
+
|
24 |
+
class DiagonalGaussianDistribution(object):
|
25 |
+
def __init__(self, parameters, deterministic=False):
|
26 |
+
self.parameters = parameters
|
27 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
28 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
29 |
+
self.deterministic = deterministic
|
30 |
+
self.std = torch.exp(0.5 * self.logvar)
|
31 |
+
self.var = torch.exp(self.logvar)
|
32 |
+
if self.deterministic:
|
33 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
34 |
+
|
35 |
+
def sample(self):
|
36 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
37 |
+
return x
|
38 |
+
|
39 |
+
def kl(self, other=None):
|
40 |
+
if self.deterministic:
|
41 |
+
return torch.Tensor([0.0])
|
42 |
+
else:
|
43 |
+
if other is None:
|
44 |
+
return 0.5 * torch.mean(
|
45 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
46 |
+
dim=[1, 2],
|
47 |
+
)
|
48 |
+
else:
|
49 |
+
return 0.5 * torch.mean(
|
50 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
51 |
+
+ self.var / other.var
|
52 |
+
- 1.0
|
53 |
+
- self.logvar
|
54 |
+
+ other.logvar,
|
55 |
+
dim=[1, 2],
|
56 |
+
)
|
57 |
+
|
58 |
+
def nll(self, sample, dims=[1, 2]):
|
59 |
+
if self.deterministic:
|
60 |
+
return torch.Tensor([0.0])
|
61 |
+
logtwopi = np.log(2.0 * np.pi)
|
62 |
+
return 0.5 * torch.sum(
|
63 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
64 |
+
dim=dims,
|
65 |
+
)
|
66 |
+
|
67 |
+
def mode(self):
|
68 |
+
return self.mean
|
69 |
+
|
70 |
+
|
71 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
72 |
+
"""
|
73 |
+
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
74 |
+
Compute the KL divergence between two gaussians.
|
75 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
76 |
+
scalars, among other use cases.
|
77 |
+
"""
|
78 |
+
tensor = None
|
79 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
80 |
+
if isinstance(obj, torch.Tensor):
|
81 |
+
tensor = obj
|
82 |
+
break
|
83 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
84 |
+
|
85 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
86 |
+
# Tensors, but it does not work for torch.exp().
|
87 |
+
logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
|
88 |
+
|
89 |
+
return 0.5 * (
|
90 |
+
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
91 |
+
)
|
hunyuanvideo_foley/models/dac_vae/utils/__init__.py
ADDED
@@ -0,0 +1,121 @@
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import argbind
|
4 |
+
from audiotools import ml
|
5 |
+
|
6 |
+
from ..model import DAC
|
7 |
+
Accelerator = ml.Accelerator
|
8 |
+
|
9 |
+
__MODEL_LATEST_TAGS__ = {
|
10 |
+
("44khz", "8kbps"): "0.0.1",
|
11 |
+
("24khz", "8kbps"): "0.0.4",
|
12 |
+
("16khz", "8kbps"): "0.0.5",
|
13 |
+
("44khz", "16kbps"): "1.0.0",
|
14 |
+
}
|
15 |
+
|
16 |
+
__MODEL_URLS__ = {
|
17 |
+
(
|
18 |
+
"44khz",
|
19 |
+
"0.0.1",
|
20 |
+
"8kbps",
|
21 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
|
22 |
+
(
|
23 |
+
"24khz",
|
24 |
+
"0.0.4",
|
25 |
+
"8kbps",
|
26 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
|
27 |
+
(
|
28 |
+
"16khz",
|
29 |
+
"0.0.5",
|
30 |
+
"8kbps",
|
31 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
|
32 |
+
(
|
33 |
+
"44khz",
|
34 |
+
"1.0.0",
|
35 |
+
"16kbps",
|
36 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
@argbind.bind(group="download", positional=True, without_prefix=True)
|
41 |
+
def download(
|
42 |
+
model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
|
43 |
+
):
|
44 |
+
"""
|
45 |
+
Function that downloads the weights file from URL if a local cache is not found.
|
46 |
+
|
47 |
+
Parameters
|
48 |
+
----------
|
49 |
+
model_type : str
|
50 |
+
The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
|
51 |
+
model_bitrate: str
|
52 |
+
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
53 |
+
Only 44khz model supports 16kbps.
|
54 |
+
tag : str
|
55 |
+
The tag of the model to download. Defaults to "latest".
|
56 |
+
|
57 |
+
Returns
|
58 |
+
-------
|
59 |
+
Path
|
60 |
+
Directory path required to load model via audiotools.
|
61 |
+
"""
|
62 |
+
model_type = model_type.lower()
|
63 |
+
tag = tag.lower()
|
64 |
+
|
65 |
+
assert model_type in [
|
66 |
+
"44khz",
|
67 |
+
"24khz",
|
68 |
+
"16khz",
|
69 |
+
], "model_type must be one of '44khz', '24khz', or '16khz'"
|
70 |
+
|
71 |
+
assert model_bitrate in [
|
72 |
+
"8kbps",
|
73 |
+
"16kbps",
|
74 |
+
], "model_bitrate must be one of '8kbps', or '16kbps'"
|
75 |
+
|
76 |
+
if tag == "latest":
|
77 |
+
tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
|
78 |
+
|
79 |
+
download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
|
80 |
+
|
81 |
+
if download_link is None:
|
82 |
+
raise ValueError(
|
83 |
+
f"Could not find model with tag {tag} and model type {model_type}"
|
84 |
+
)
|
85 |
+
|
86 |
+
local_path = (
|
87 |
+
Path.home()
|
88 |
+
/ ".cache"
|
89 |
+
/ "descript"
|
90 |
+
/ "dac"
|
91 |
+
/ f"weights_{model_type}_{model_bitrate}_{tag}.pth"
|
92 |
+
)
|
93 |
+
if not local_path.exists():
|
94 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
95 |
+
|
96 |
+
# Download the model
|
97 |
+
import requests
|
98 |
+
|
99 |
+
response = requests.get(download_link)
|
100 |
+
|
101 |
+
if response.status_code != 200:
|
102 |
+
raise ValueError(
|
103 |
+
f"Could not download model. Received response code {response.status_code}"
|
104 |
+
)
|
105 |
+
local_path.write_bytes(response.content)
|
106 |
+
|
107 |
+
return local_path
|
108 |
+
|
109 |
+
|
110 |
+
def load_model(
|
111 |
+
model_type: str = "44khz",
|
112 |
+
model_bitrate: str = "8kbps",
|
113 |
+
tag: str = "latest",
|
114 |
+
load_path: str = None,
|
115 |
+
):
|
116 |
+
if not load_path:
|
117 |
+
load_path = download(
|
118 |
+
model_type=model_type, model_bitrate=model_bitrate, tag=tag
|
119 |
+
)
|
120 |
+
generator = DAC.load(load_path)
|
121 |
+
return generator
|
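A small sketch of pulling a pretrained generator through these helpers; the type/bitrate/tag values are the ones tabulated above, and the local-checkpoint path is a placeholder:

model = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
model.eval()

# Or skip the download and load a local checkpoint directly:
# model = load_model(load_path="/path/to/weights.pth")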
hunyuanvideo_foley/models/dac_vae/utils/decode.py
ADDED
@@ -0,0 +1,95 @@
|
1 |
+
import warnings
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import argbind
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from audiotools import AudioSignal
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from ..model import DACFile
|
11 |
+
from . import load_model
|
12 |
+
|
13 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
14 |
+
|
15 |
+
|
16 |
+
@argbind.bind(group="decode", positional=True, without_prefix=True)
|
17 |
+
@torch.inference_mode()
|
18 |
+
@torch.no_grad()
|
19 |
+
def decode(
|
20 |
+
input: str,
|
21 |
+
output: str = "",
|
22 |
+
weights_path: str = "",
|
23 |
+
model_tag: str = "latest",
|
24 |
+
model_bitrate: str = "8kbps",
|
25 |
+
device: str = "cuda",
|
26 |
+
model_type: str = "44khz",
|
27 |
+
verbose: bool = False,
|
28 |
+
):
|
29 |
+
"""Decode audio from codes.
|
30 |
+
|
31 |
+
Parameters
|
32 |
+
----------
|
33 |
+
input : str
|
34 |
+
Path to input directory or file
|
35 |
+
output : str, optional
|
36 |
+
Path to output directory, by default "".
|
37 |
+
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
|
38 |
+
weights_path : str, optional
|
39 |
+
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
|
40 |
+
model_tag and model_type.
|
41 |
+
model_tag : str, optional
|
42 |
+
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
|
43 |
+
model_bitrate: str
|
44 |
+
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
45 |
+
device : str, optional
|
46 |
+
Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
|
47 |
+
model_type : str, optional
|
48 |
+
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
|
49 |
+
"""
|
50 |
+
generator = load_model(
|
51 |
+
model_type=model_type,
|
52 |
+
model_bitrate=model_bitrate,
|
53 |
+
tag=model_tag,
|
54 |
+
load_path=weights_path,
|
55 |
+
)
|
56 |
+
generator.to(device)
|
57 |
+
generator.eval()
|
58 |
+
|
59 |
+
# Find all .dac files in input directory
|
60 |
+
_input = Path(input)
|
61 |
+
input_files = list(_input.glob("**/*.dac"))
|
62 |
+
|
63 |
+
# If input is a .dac file, add it to the list
|
64 |
+
if _input.suffix == ".dac":
|
65 |
+
input_files.append(_input)
|
66 |
+
|
67 |
+
# Create output directory
|
68 |
+
output = Path(output)
|
69 |
+
output.mkdir(parents=True, exist_ok=True)
|
70 |
+
|
71 |
+
for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
|
72 |
+
# Load file
|
73 |
+
artifact = DACFile.load(input_files[i])
|
74 |
+
|
75 |
+
# Reconstruct audio from codes
|
76 |
+
recons = generator.decompress(artifact, verbose=verbose)
|
77 |
+
|
78 |
+
# Compute output path
|
79 |
+
relative_path = input_files[i].relative_to(input)
|
80 |
+
output_dir = output / relative_path.parent
|
81 |
+
if not relative_path.name:
|
82 |
+
output_dir = output
|
83 |
+
relative_path = input_files[i]
|
84 |
+
output_name = relative_path.with_suffix(".wav").name
|
85 |
+
output_path = output_dir / output_name
|
86 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
87 |
+
|
88 |
+
# Write to file
|
89 |
+
recons.write(output_path)
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
args = argbind.parse_args()
|
94 |
+
with argbind.scope(args):
|
95 |
+
decode()
|
hunyuanvideo_foley/models/dac_vae/utils/encode.py
ADDED
@@ -0,0 +1,94 @@
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import argbind
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from audiotools import AudioSignal
|
9 |
+
from audiotools.core import util
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from . import load_model
|
13 |
+
|
14 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
15 |
+
|
16 |
+
|
17 |
+
@argbind.bind(group="encode", positional=True, without_prefix=True)
|
18 |
+
@torch.inference_mode()
|
19 |
+
@torch.no_grad()
|
20 |
+
def encode(
|
21 |
+
input: str,
|
22 |
+
output: str = "",
|
23 |
+
weights_path: str = "",
|
24 |
+
model_tag: str = "latest",
|
25 |
+
model_bitrate: str = "8kbps",
|
26 |
+
n_quantizers: int = None,
|
27 |
+
device: str = "cuda",
|
28 |
+
model_type: str = "44khz",
|
29 |
+
win_duration: float = 5.0,
|
30 |
+
verbose: bool = False,
|
31 |
+
):
|
32 |
+
"""Encode audio files in input path to .dac format.
|
33 |
+
|
34 |
+
Parameters
|
35 |
+
----------
|
36 |
+
input : str
|
37 |
+
Path to input audio file or directory
|
38 |
+
output : str, optional
|
39 |
+
Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
|
40 |
+
weights_path : str, optional
|
41 |
+
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
|
42 |
+
model_tag and model_type.
|
43 |
+
model_tag : str, optional
|
44 |
+
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
|
45 |
+
model_bitrate: str
|
46 |
+
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
47 |
+
n_quantizers : int, optional
|
48 |
+
Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
|
49 |
+
device : str, optional
|
50 |
+
Device to use, by default "cuda"
|
51 |
+
model_type : str, optional
|
52 |
+
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
|
53 |
+
"""
|
54 |
+
generator = load_model(
|
55 |
+
model_type=model_type,
|
56 |
+
model_bitrate=model_bitrate,
|
57 |
+
tag=model_tag,
|
58 |
+
load_path=weights_path,
|
59 |
+
)
|
60 |
+
generator.to(device)
|
61 |
+
generator.eval()
|
62 |
+
kwargs = {"n_quantizers": n_quantizers}
|
63 |
+
|
64 |
+
# Find all audio files in input path
|
65 |
+
input = Path(input)
|
66 |
+
audio_files = util.find_audio(input)
|
67 |
+
|
68 |
+
output = Path(output)
|
69 |
+
output.mkdir(parents=True, exist_ok=True)
|
70 |
+
|
71 |
+
for i in tqdm(range(len(audio_files)), desc="Encoding files"):
|
72 |
+
# Load file
|
73 |
+
signal = AudioSignal(audio_files[i])
|
74 |
+
|
75 |
+
# Encode audio to .dac format
|
76 |
+
artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
|
77 |
+
|
78 |
+
# Compute output path
|
79 |
+
relative_path = audio_files[i].relative_to(input)
|
80 |
+
output_dir = output / relative_path.parent
|
81 |
+
if not relative_path.name:
|
82 |
+
output_dir = output
|
83 |
+
relative_path = audio_files[i]
|
84 |
+
output_name = relative_path.with_suffix(".dac").name
|
85 |
+
output_path = output_dir / output_name
|
86 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
87 |
+
|
88 |
+
artifact.save(output_path)
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
args = argbind.parse_args()
|
93 |
+
with argbind.scope(args):
|
94 |
+
encode()
|
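Both entry points can also be called directly from Python rather than through the argbind CLI; a minimal sketch with placeholder paths (a CUDA device is assumed by the default `device` argument):

from hunyuanvideo_foley.models.dac_vae.utils.encode import encode
from hunyuanvideo_foley.models.dac_vae.utils.decode import decode

encode("input_audio/", output="encoded/", model_type="44khz", win_duration=5.0)   # writes .dac files
decode("encoded/", output="decoded/", model_type="44khz")                         # writes .wav files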
hunyuanvideo_foley/models/hifi_foley.py
ADDED
@@ -0,0 +1,794 @@
|
1 |
+
from typing import List, Tuple, Optional, Union, Dict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
from einops.layers.torch import Rearrange
|
8 |
+
from diffusers.models import ModelMixin
|
9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
10 |
+
|
11 |
+
from .nn.activation_layers import SwiGLU, get_activation_layer
|
12 |
+
from .nn.attn_layers import apply_rotary_emb, attention
|
13 |
+
from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
|
14 |
+
from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
|
15 |
+
from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
|
16 |
+
from .nn.norm_layers import get_norm_layer
|
17 |
+
from .nn.posemb_layers import get_nd_rotary_pos_embed
|
18 |
+
|
19 |
+
def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
|
20 |
+
# [B, N1, H, C] & [B, N2, H, C]
|
21 |
+
B, N1, H, C = x1.shape
|
22 |
+
B, N2, H, C = x2.shape
|
23 |
+
assert x1.ndim == x2.ndim == 4
|
24 |
+
|
25 |
+
if N1 != N2:
|
26 |
+
x2 = x2.view(B, N2, -1).transpose(1, 2)
|
27 |
+
x2 = F.interpolate(x2, size=(N1), mode="nearest-exact")
|
28 |
+
x2 = x2.transpose(1, 2).view(B, N1, H, C)
|
29 |
+
x = torch.stack((x1, x2), dim=2)
|
30 |
+
x = x.reshape(B, N1 * 2, H, C)
|
31 |
+
return x
|
32 |
+
|
33 |
+
def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
|
34 |
+
B, N, H, C = x.shape
|
35 |
+
assert N % 2 == 0 and N // 2 == len1
|
36 |
+
|
37 |
+
x = x.reshape(B, -1, 2, H, C)
|
38 |
+
x1 = x[:, :, 0]
|
39 |
+
x2 = x[:, :, 1]
|
40 |
+
if x2.shape[1] != len2:
|
41 |
+
x2 = x2.view(B, len1, H * C).transpose(1, 2)
|
42 |
+
x2 = F.interpolate(x2, size=(len2), mode="nearest-exact")
|
43 |
+
x2 = x2.transpose(1, 2).view(B, len2, H, C)
|
44 |
+
return x1, x2
|
45 |
+
|
46 |
+
class TwoStreamCABlock(nn.Module):
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
hidden_size: int,
|
50 |
+
num_heads: int,
|
51 |
+
mlp_ratio: float,
|
52 |
+
mlp_act_type: str = "gelu_tanh",
|
53 |
+
qk_norm: bool = True,
|
54 |
+
qk_norm_type: str = "rms",
|
55 |
+
qkv_bias: bool = False,
|
56 |
+
attn_mode: str = "torch",
|
57 |
+
reverse: bool = False,
|
58 |
+
interleaved_audio_visual_rope: bool = False,
|
59 |
+
dtype: Optional[torch.dtype] = None,
|
60 |
+
device: Optional[torch.device] = None,
|
61 |
+
):
|
62 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.deterministic = False
|
66 |
+
self.reverse = reverse
|
67 |
+
self.attn_mode = attn_mode
|
68 |
+
self.num_heads = num_heads
|
69 |
+
self.hidden_size = hidden_size
|
70 |
+
head_dim = hidden_size // num_heads
|
71 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
72 |
+
|
73 |
+
self.interleaved_audio_visual_rope = interleaved_audio_visual_rope
|
74 |
+
|
75 |
+
# Self attention for audio + visual
|
76 |
+
self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
|
77 |
+
self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
78 |
+
self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
79 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
80 |
+
self.audio_self_q_norm = (
|
81 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
82 |
+
)
|
83 |
+
self.audio_self_k_norm = (
|
84 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
85 |
+
)
|
86 |
+
self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
87 |
+
|
88 |
+
# visual cond
|
89 |
+
self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
|
90 |
+
self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
91 |
+
self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
92 |
+
self.v_cond_attn_q_norm = (
|
93 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
94 |
+
)
|
95 |
+
self.v_cond_attn_k_norm = (
|
96 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
97 |
+
)
|
98 |
+
self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
99 |
+
|
100 |
+
self.max_text_len = 100
|
101 |
+
self.rope_dim_list = None
|
102 |
+
|
103 |
+
# audio and video norm for cross attention with text
|
104 |
+
self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
105 |
+
self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
106 |
+
|
107 |
+
# Cross attention: (video_audio) as query, text as key/value
|
108 |
+
self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
109 |
+
self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
110 |
+
self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)
|
111 |
+
|
112 |
+
self.audio_cross_q_norm = (
|
113 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
114 |
+
)
|
115 |
+
self.v_cond_cross_q_norm = (
|
116 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
117 |
+
)
|
118 |
+
self.text_cross_k_norm = (
|
119 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
120 |
+
)
|
121 |
+
self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
122 |
+
self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
123 |
+
|
124 |
+
# MLPs
|
125 |
+
self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
126 |
+
self.audio_mlp = MLP(
|
127 |
+
hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
|
128 |
+
)
|
129 |
+
|
130 |
+
self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
131 |
+
self.v_cond_mlp = MLP(
|
132 |
+
hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
|
133 |
+
)
|
134 |
+
|
135 |
+
def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
|
136 |
+
target_ndim = 1 # n-d RoPE
|
137 |
+
rope_sizes = [text_len]
|
138 |
+
|
139 |
+
if rope_dim_list is None:
|
140 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
141 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
142 |
+
|
143 |
+
text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
|
144 |
+
rope_dim_list=rope_dim_list,
|
145 |
+
start=rope_sizes,
|
146 |
+
theta=10000,
|
147 |
+
use_real=True,
|
148 |
+
theta_rescale_factor=1.0,
|
149 |
+
)
|
150 |
+
return text_freqs_cos, text_freqs_sin
|
151 |
+
|
152 |
+
def set_attn_mode(self, new_mode):
|
153 |
+
if new_mode != "torch":
|
154 |
+
raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
|
155 |
+
self.attn_mode = new_mode
|
156 |
+
|
157 |
+
def enable_deterministic(self):
|
158 |
+
self.deterministic = True
|
159 |
+
|
160 |
+
def disable_deterministic(self):
|
161 |
+
self.deterministic = False
|
162 |
+
|
163 |
+
def forward(
|
164 |
+
self,
|
165 |
+
audio: torch.Tensor,
|
166 |
+
cond: torch.Tensor,
|
167 |
+
v_cond: torch.Tensor,
|
168 |
+
attn_mask: torch.Tensor,
|
169 |
+
vec: torch.Tensor,
|
170 |
+
freqs_cis: tuple = None,
|
171 |
+
v_freqs_cis: tuple = None,
|
172 |
+
sync_vec: torch.Tensor = None,
|
173 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
174 |
+
# Get modulation parameters
|
175 |
+
if sync_vec is not None:
|
176 |
+
assert sync_vec.ndim == 3
|
177 |
+
(audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
|
178 |
+
audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
|
179 |
+
audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
|
180 |
+
) = self.audio_mod(sync_vec).chunk(9, dim=-1)
|
181 |
+
else:
|
182 |
+
(audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
|
183 |
+
audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
|
184 |
+
audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
|
185 |
+
) = self.audio_mod(vec).chunk(9, dim=-1)
|
186 |
+
|
187 |
+
(
|
188 |
+
v_cond_mod1_shift,
|
189 |
+
v_cond_mod1_scale,
|
190 |
+
v_cond_mod1_gate,
|
191 |
+
v_cond_mod2_shift,
|
192 |
+
v_cond_mod2_scale,
|
193 |
+
v_cond_mod2_gate,
|
194 |
+
v_cond_mod3_shift,
|
195 |
+
v_cond_mod3_scale,
|
196 |
+
v_cond_mod3_gate,
|
197 |
+
) = self.v_cond_mod(vec).chunk(9, dim=-1)
|
198 |
+
|
199 |
+
# 1. Self Attention for audio + visual
|
200 |
+
audio_modulated = self.audio_norm1(audio)
|
201 |
+
audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
|
202 |
+
audio_qkv = self.audio_self_attn_qkv(audio_modulated)
|
203 |
+
audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
204 |
+
audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
|
205 |
+
audio_k = self.audio_self_k_norm(audio_k).to(audio_v)
|
206 |
+
|
207 |
+
# Prepare visual cond for attention
|
208 |
+
v_cond_modulated = self.v_cond_norm1(v_cond)
|
209 |
+
v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
|
210 |
+
v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
|
211 |
+
v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
212 |
+
v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
|
213 |
+
v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)
|
214 |
+
|
215 |
+
# Apply RoPE if needed for audio and visual
|
216 |
+
if freqs_cis is not None:
|
217 |
+
if not self.interleaved_audio_visual_rope:
|
218 |
+
audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
|
219 |
+
audio_q, audio_k = audio_qq, audio_kk
|
220 |
+
else:
|
221 |
+
ori_audio_len = audio_q.shape[1]
|
222 |
+
ori_v_con_len = v_cond_q.shape[1]
|
223 |
+
interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
|
224 |
+
interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
|
225 |
+
interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
|
226 |
+
interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
|
227 |
+
)
|
228 |
+
audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
|
229 |
+
interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
|
230 |
+
)
|
231 |
+
audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
|
232 |
+
interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
|
233 |
+
)
|
234 |
+
audio_q, audio_k = audio_qq, audio_kk
|
235 |
+
v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
|
236 |
+
|
237 |
+
# Apply RoPE to visual if needed and not interleaved
|
238 |
+
if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
|
239 |
+
v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
|
240 |
+
v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
|
241 |
+
|
242 |
+
# Concatenate for self-attention
|
243 |
+
q = torch.cat((v_cond_q, audio_q), dim=1)
|
244 |
+
k = torch.cat((v_cond_k, audio_k), dim=1)
|
245 |
+
v = torch.cat((v_cond_v, audio_v), dim=1)
|
246 |
+
|
247 |
+
# Run self-attention
|
248 |
+
attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
|
249 |
+
v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)
|
250 |
+
|
251 |
+
# Apply self-attention output to audio and v_cond
|
252 |
+
audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
|
253 |
+
v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)
|
254 |
+
|
255 |
+
# 2. Cross Attention: (v_cond, audio) as query, text as key/value
|
256 |
+
# audio, v_cond modulation
|
257 |
+
audio_modulated = self.audio_norm2(audio)
|
258 |
+
audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
|
259 |
+
v_cond_modulated = self.v_cond_norm2(v_cond)
|
260 |
+
v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)
|
261 |
+
|
262 |
+
# Prepare audio query
|
263 |
+
audio_q = self.audio_cross_q(audio_modulated)
|
264 |
+
audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
|
265 |
+
audio_q = self.audio_cross_q_norm(audio_q)
|
266 |
+
|
267 |
+
# Prepare v_cond query
|
268 |
+
v_cond_q = self.v_cond_cross_q(v_cond_modulated)
|
269 |
+
v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
|
270 |
+
v_cond_q = self.v_cond_cross_q_norm(v_cond_q)
|
271 |
+
|
272 |
+
# Prepare text key/value
|
273 |
+
text_kv = self.text_cross_kv(cond)
|
274 |
+
text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
|
275 |
+
text_k = self.text_cross_k_norm(text_k).to(text_v)
|
276 |
+
|
277 |
+
# Apply RoPE to (v_cond, audio) query and text key if needed
|
278 |
+
head_dim = self.hidden_size // self.num_heads
|
279 |
+
audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
|
280 |
+
audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
|
281 |
+
audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]
|
282 |
+
|
283 |
+
v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
|
284 |
+
v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
|
285 |
+
v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]
|
286 |
+
|
287 |
+
text_len = text_k.shape[1]
|
288 |
+
|
289 |
+
text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
|
290 |
+
rope_dim_list=self.rope_dim_list)
|
291 |
+
text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
|
292 |
+
text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]
|
293 |
+
|
294 |
+
# Concat v_cond and audio for cross-attention
|
295 |
+
v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)
|
296 |
+
|
297 |
+
# Run cross-attention
|
298 |
+
cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
|
299 |
+
v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)
|
300 |
+
|
301 |
+
# Apply cross-attention output
|
302 |
+
audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
|
303 |
+
v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)
|
304 |
+
|
305 |
+
# 3. Apply MLPs
|
306 |
+
audio = audio + apply_gate(
|
307 |
+
self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
|
308 |
+
gate=audio_mod3_gate,
|
309 |
+
)
|
310 |
+
|
311 |
+
# Apply visual MLP
|
312 |
+
v_cond = v_cond + apply_gate(
|
313 |
+
self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
|
314 |
+
gate=v_cond_mod3_gate,
|
315 |
+
)
|
316 |
+
|
317 |
+
return audio, cond, v_cond
|
318 |
+
|
319 |
+
class SingleStreamBlock(nn.Module):
|
320 |
+
|
321 |
+
def __init__(self, hidden_size: int,
|
322 |
+
num_heads: int,
|
323 |
+
mlp_ratio: float,
|
324 |
+
qk_norm_type: str = "rms",
|
325 |
+
dtype: Optional[torch.dtype] = None,
|
326 |
+
device: Optional[torch.device] = None,):
|
327 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
328 |
+
super().__init__()
|
329 |
+
|
330 |
+
self.hidden_size = hidden_size
|
331 |
+
self.num_heads = num_heads
|
332 |
+
|
333 |
+
self.modulation = ModulateDiT(
|
334 |
+
hidden_size=hidden_size,
|
335 |
+
factor=6,
|
336 |
+
act_layer=get_activation_layer("silu"),
|
337 |
+
**factory_kwargs,
|
338 |
+
)
|
339 |
+
self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
|
340 |
+
self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
|
341 |
+
self.linear2 = ConvMLP(hidden_size, int(hidden_size * mlp_ratio), kernel_size=3, padding=1, **factory_kwargs)
|
342 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
|
343 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
|
344 |
+
self.q_norm = nn.RMSNorm(hidden_size // num_heads)
|
345 |
+
self.k_norm = nn.RMSNorm(hidden_size // num_heads)
|
346 |
+
self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)
|
347 |
+
|
348 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor,freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
|
349 |
+
assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
|
350 |
+
modulation = self.modulation(cond)
|
351 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
|
352 |
+
x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
|
353 |
+
|
354 |
+
qkv = self.linear_qkv(x_norm1)
|
355 |
+
q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
|
356 |
+
q = q.squeeze(-1)
|
357 |
+
k = k.squeeze(-1)
|
358 |
+
v = v.squeeze(-1)
|
359 |
+
|
360 |
+
q = self.q_norm(q)
|
361 |
+
k = self.k_norm(k)
|
362 |
+
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)
|
363 |
+
|
364 |
+
q = q.contiguous()
|
365 |
+
k = k.contiguous()
|
366 |
+
v = v.contiguous()
|
367 |
+
out = F.scaled_dot_product_attention(q, k, v)
|
368 |
+
out = rearrange(out, 'b h n d -> b n (h d)').contiguous()
|
369 |
+
|
370 |
+
x = x + apply_gate(self.linear1(out),gate=gate_msa)
|
371 |
+
x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
|
372 |
+
x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)
|
373 |
+
|
374 |
+
return x
|
375 |
+
|
376 |
+
class HunyuanVideoFoley(ModelMixin, ConfigMixin):
|
377 |
+
@register_to_config
|
378 |
+
def __init__(
|
379 |
+
self,
|
380 |
+
model_config,
|
381 |
+
dtype: Optional[torch.dtype] = None,
|
382 |
+
device: Optional[torch.device] = None,
|
383 |
+
):
|
384 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
385 |
+
super().__init__()
|
386 |
+
|
387 |
+
model_args = model_config.model_config.model_kwargs
|
388 |
+
self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
|
389 |
+
self.depth_single_blocks = model_args.get("depth_single_blocks", 38)
|
390 |
+
# Gradient checkpoint.
|
391 |
+
self.gradient_checkpoint = False
|
392 |
+
self.gradient_checkpoint_layers = None
|
393 |
+
if self.gradient_checkpoint:
|
394 |
+
assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
|
395 |
+
f"Gradient checkpoint layers must be less or equal than the depth of the model. "
|
396 |
+
f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
|
397 |
+
)
|
398 |
+
|
399 |
+
self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)
|
400 |
+
|
401 |
+
# Condition projection. Default to linear projection.
|
402 |
+
self.condition_projection = model_args.get("condition_projection", "linear")
|
403 |
+
self.condition_dim = model_args.get("condition_dim", None)
|
404 |
+
self.use_attention_mask = model_args.get("use_attention_mask", False)
|
405 |
+
|
406 |
+
self.patch_size = model_args.get("patch_size", 1)
|
407 |
+
self.visual_in_channels = model_args.get("clip_dim", 768)
|
408 |
+
self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
|
409 |
+
self.out_channels = self.audio_vae_latent_dim
|
410 |
+
self.unpatchify_channels = self.out_channels
|
411 |
+
self.reverse = model_args.get("reverse", False)
|
412 |
+
|
413 |
+
self.num_heads = model_args.get("num_heads", 24)
|
414 |
+
self.hidden_size = model_args.get("hidden_size", 3072)
|
415 |
+
self.rope_dim_list = model_args.get("rope_dim_list", None)
|
416 |
+
self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
|
417 |
+
self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")
|
418 |
+
|
419 |
+
self.qkv_bias = model_args.get("qkv_bias", True)
|
420 |
+
self.qk_norm = model_args.get("qk_norm", True)
|
421 |
+
self.qk_norm_type = model_args.get("qk_norm_type", "rms")
|
422 |
+
self.attn_mode = model_args.get("attn_mode", "torch")
|
423 |
+
|
424 |
+
self.embedder_type = model_args.get("embedder_type", "default")
|
425 |
+
|
426 |
+
# sync condition things
|
427 |
+
self.sync_modulation = model_args.get("sync_modulation", False)
|
428 |
+
self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
|
429 |
+
self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
|
430 |
+
self.sync_in_ksz = model_args.get("sync_in_ksz", 1)
|
431 |
+
|
432 |
+
# condition tokens length
|
433 |
+
self.clip_len = model_args.get("clip_length", 64)
|
434 |
+
self.sync_len = model_args.get("sync_length", 192)
|
435 |
+
|
436 |
+
if self.hidden_size % self.num_heads != 0:
|
437 |
+
raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")
|
438 |
+
|
439 |
+
# Build audio patchify layer and visual gated linear projection
|
440 |
+
self.patch_size = 1
|
441 |
+
self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
|
442 |
+
self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)
|
443 |
+
|
444 |
+
# condition
|
445 |
+
if self.condition_projection == "linear":
|
446 |
+
self.cond_in = ConditionProjection(
|
447 |
+
self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
|
448 |
+
)
|
449 |
+
else:
|
450 |
+
raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")
|
451 |
+
|
452 |
+
# time modulation
|
453 |
+
self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
|
454 |
+
|
455 |
+
# visual sync embedder if needed
|
456 |
+
if self.sync_in_ksz == 1:
|
457 |
+
sync_in_padding = 0
|
458 |
+
elif self.sync_in_ksz == 3:
|
459 |
+
sync_in_padding = 1
|
460 |
+
else:
|
461 |
+
raise ValueError
|
462 |
+
if self.sync_modulation or self.add_sync_feat_to_audio:
|
463 |
+
self.sync_in = nn.Sequential(
|
464 |
+
nn.Linear(self.sync_feat_dim, self.hidden_size),
|
465 |
+
nn.SiLU(),
|
466 |
+
ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
|
467 |
+
)
|
468 |
+
self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))
|
469 |
+
|
470 |
+
self.triple_blocks = nn.ModuleList(
|
471 |
+
[
|
472 |
+
TwoStreamCABlock(
|
473 |
+
hidden_size=self.hidden_size,
|
474 |
+
num_heads=self.num_heads,
|
475 |
+
mlp_ratio=self.mlp_ratio,
|
476 |
+
mlp_act_type=self.mlp_act_type,
|
477 |
+
qk_norm=self.qk_norm,
|
478 |
+
qk_norm_type=self.qk_norm_type,
|
479 |
+
qkv_bias=self.qkv_bias,
|
480 |
+
attn_mode=self.attn_mode,
|
481 |
+
reverse=self.reverse,
|
482 |
+
interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
|
483 |
+
**factory_kwargs,
|
484 |
+
)
|
485 |
+
for _ in range(self.depth_triple_blocks)
|
486 |
+
]
|
487 |
+
)
|
488 |
+
|
489 |
+
|
490 |
+
self.single_blocks = nn.ModuleList(
|
491 |
+
[
|
492 |
+
SingleStreamBlock(
|
493 |
+
hidden_size=self.hidden_size,
|
494 |
+
num_heads=self.num_heads,
|
495 |
+
mlp_ratio=self.mlp_ratio,
|
496 |
+
qk_norm_type=self.qk_norm_type,
|
497 |
+
**factory_kwargs,
|
498 |
+
)
|
499 |
+
for _ in range(self.depth_single_blocks)
|
500 |
+
]
|
501 |
+
)
|
502 |
+
|
503 |
+
self.final_layer = FinalLayer1D(
|
504 |
+
self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
|
505 |
+
)
|
506 |
+
self.unpatchify_channels = self.out_channels
|
507 |
+
|
508 |
+
self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
|
509 |
+
self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
|
510 |
+
nn.init.constant_(self.empty_clip_feat, 0)
|
511 |
+
nn.init.constant_(self.empty_sync_feat, 0)
|
512 |
+
|
513 |
+
def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
|
514 |
+
if bs is None:
|
515 |
+
return self.empty_string_feat
|
516 |
+
else:
|
517 |
+
return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)
|
518 |
+
|
519 |
+
def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
|
520 |
+
len = len if len is not None else self.clip_len
|
521 |
+
if bs is None:
|
522 |
+
return self.empty_clip_feat.expand(len, -1) # 15s
|
523 |
+
else:
|
524 |
+
return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s
|
525 |
+
|
526 |
+
def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
|
527 |
+
len = len if len is not None else self.sync_len
|
528 |
+
if bs is None:
|
529 |
+
return self.empty_sync_feat.expand(len, -1)
|
530 |
+
else:
|
531 |
+
return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
|
532 |
+
|
533 |
+
def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
|
534 |
+
assert self.patch_size == 1
|
535 |
+
# ======================================== Build RoPE for audio tokens ======================================
|
536 |
+
target_ndim = 1 # n-d RoPE
|
537 |
+
rope_sizes = [audio_emb_len]
|
538 |
+
head_dim = self.hidden_size // self.num_heads
|
539 |
+
rope_dim_list = self.rope_dim_list
|
540 |
+
if rope_dim_list is None:
|
541 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
542 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
543 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
544 |
+
rope_dim_list=rope_dim_list,
|
545 |
+
start=rope_sizes,
|
546 |
+
theta=10000,
|
547 |
+
use_real=True,
|
548 |
+
theta_rescale_factor=1.0,
|
549 |
+
)
|
550 |
+
|
551 |
+
# ========================== Build RoPE for clip tokens =========================
|
552 |
+
target_ndim = 1 # n-d RoPE
|
553 |
+
rope_sizes = [visual_cond_len]
|
554 |
+
head_dim = self.hidden_size // self.num_heads
|
555 |
+
rope_dim_list = self.rope_dim_list
|
556 |
+
if rope_dim_list is None:
|
557 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
558 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
559 |
+
v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
|
560 |
+
rope_dim_list=rope_dim_list,
|
561 |
+
start=rope_sizes,
|
562 |
+
theta=10000,
|
563 |
+
use_real=True,
|
564 |
+
theta_rescale_factor=1.0,
|
565 |
+
freq_scaling=1.0 * audio_emb_len / visual_cond_len,
|
566 |
+
)
|
567 |
+
return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
|
568 |
+
|
569 |
+
def build_rope_for_interleaved_audio_visual(self, total_len):
|
570 |
+
assert self.patch_size == 1
|
571 |
+
# ========================== Build RoPE for audio tokens ========================
|
572 |
+
target_ndim = 1 # n-d RoPE
|
573 |
+
rope_sizes = [total_len]
|
574 |
+
head_dim = self.hidden_size // self.num_heads
|
575 |
+
rope_dim_list = self.rope_dim_list
|
576 |
+
if rope_dim_list is None:
|
577 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
578 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
579 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
580 |
+
rope_dim_list=rope_dim_list,
|
581 |
+
start=rope_sizes,
|
582 |
+
theta=10000,
|
583 |
+
use_real=True,
|
584 |
+
theta_rescale_factor=1.0,
|
585 |
+
)
|
586 |
+
return freqs_cos, freqs_sin
|
    def set_attn_mode(self, new_mode):
        for block in self.triple_blocks:
            block.set_attn_mode(new_mode)
        for block in self.single_blocks:
            block.set_attn_mode(new_mode)

    def enable_deterministic(self):
        for block in self.triple_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.triple_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        clip_feat: Optional[torch.Tensor] = None,
        cond: torch.Tensor = None,
        audio_mask: Optional[torch.Tensor] = None,
        cond_mask: torch.Tensor = None,
        sync_feat: Optional[torch.Tensor] = None,
        drop_visual: Optional[List[bool]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        out = {}
        audio = x
        bs, _, ol = x.shape
        tl = ol // self.patch_size

        # Prepare learnable empty conditions for the visual condition
        if drop_visual is not None:
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
            sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)

        # ========================= Prepare time & visual modulation =========================
        vec = self.time_in(t)
        sync_vec = None
        if self.sync_modulation:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_vec = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            sync_vec = (
                F.interpolate(sync_vec.transpose(1, 2), size=tl, mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c
            sync_vec = sync_vec + vec.unsqueeze(1)
        elif self.add_sync_feat_to_audio:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_feat = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            add_sync_feat_to_audio = (
                F.interpolate(sync_feat.transpose(1, 2), size=tl, mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c

        # ========================= Get text, audio and video clip embeddings =========================
        cond = self.cond_in(cond)
        cond_seq_len = cond.shape[1]

        audio = self.audio_embedder(x)
        audio_seq_len = audio.shape[1]
        v_cond = self.visual_proj(clip_feat)
        v_cond_seq_len = v_cond.shape[1]

        # ========================= Compute attention mask =========================
        attn_mask = None
        if self.use_attention_mask:
            assert cond_mask is not None
            batch_size = audio.shape[0]
            seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len

            # get default audio_mask and v_cond_mask
            audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
            v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)

            # batch_size x seq_len
            concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_2 = attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting over num_heads
            attn_mask = (attn_mask_1 & attn_mask_2).bool()
            # avoids self-attention weights being NaN for text padding tokens
            attn_mask[:, :, :, 0] = True

        # ========================= Build RoPE for audio and clip tokens =========================
        if self.interleaved_audio_visual_rope:
            freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
            v_freqs_cos = v_freqs_sin = None
        else:
            freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
                audio_seq_len, v_cond_seq_len
            )

        # ========================= Pass through DiT blocks =========================
        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None

        if self.add_sync_feat_to_audio:
            add_sync_layer = 0
            assert (
                add_sync_layer < self.depth_triple_blocks
            ), f"The layer that adds the mel-spectrogram and sync features should be in the triple_stream_blocks (n: {self.depth_triple_blocks})."
        # Triple-stream blocks
        for layer_num, block in enumerate(self.triple_blocks):
            if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
                audio = audio + add_sync_feat_to_audio
            triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
            if (
                self.training
                and self.gradient_checkpoint
                and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
            ):
                audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
                    ckpt_wrapper(block), *triple_block_args, use_reentrant=False
                )
            else:
                audio, cond, v_cond = block(*triple_block_args)

        x = audio
        if sync_vec is not None:
            vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
            vec = torch.cat((vec, sync_vec), dim=1)

        freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
        if self.add_sync_feat_to_audio:
            vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
        if len(self.single_blocks) > 0:
            for layer_num, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    (freqs_cos, freqs_sin),
                ]
                if (
                    self.training
                    and self.gradient_checkpoint
                    and (
                        self.gradient_checkpoint_layers == -1
                        or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
                    )
                ):
                    x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
                else:
                    x = block(*single_block_args)

        audio = x

        # ========================= Final layer =========================
        if sync_vec is not None:
            vec = sync_vec
        audio = self.final_layer(audio, vec)  # (N, T, patch_size * out_channels)
        audio = self.unpatchify1d(audio, tl)

        if return_dict:
            out["x"] = audio
            return out
        return audio

    def unpatchify1d(self, x, l):
        # x: (N, L, patch_size * C)
        # audio: (N, C, T), where T == L * patch_size
        c = self.unpatchify_channels
        p = self.patch_size
        assert l == x.shape[1]

        x = x.reshape(shape=(x.shape[0], l, p, c))
        x = torch.einsum("ntpc->nctp", x)
        audio = x.reshape(shape=(x.shape[0], c, l * p))
        return audio

    def params_count(self):
        counts = {
            "triple": sum(
                [
                    sum(p.numel() for p in block.audio_cross_q.parameters())
                    + sum(p.numel() for p in block.v_cond_cross_q.parameters())
                    + sum(p.numel() for p in block.text_cross_kv.parameters())
                    + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
                    + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
                    + sum(p.numel() for p in block.audio_mlp.parameters())
                    + sum(p.numel() for p in block.audio_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_mlp.parameters())
                    for block in self.triple_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }

        counts["attn+mlp"] = counts["triple"] + counts["single"]
        return counts
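As a sanity check on the attention-mask construction in `forward` above, here is a minimal standalone sketch (toy sizes, no model state) of how the pairwise AND over a concatenated padding mask builds the (batch, 1, seq, seq) boolean mask, with key column 0 forced to True so fully padded query rows never softmax over an all-False row:

import torch

batch_size, seq_len = 1, 4
concat_mask = torch.tensor([[True, True, False, False]])           # (batch, seq): last two tokens are padding
m1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
m2 = m1.transpose(2, 3)
attn_mask = m1 & m2                                                # valid query AND valid key
attn_mask[:, :, :, 0] = True                                       # padded rows keep one attendable key
print(attn_mask[0, 0].int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 0, 0, 0],
#         [1, 0, 0, 0]], dtype=torch.int32)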
hunyuanvideo_foley/models/nn/__init__.py
ADDED
File without changes
hunyuanvideo_foley/models/nn/activation_layers.py
ADDED
@@ -0,0 +1,44 @@
import torch.nn as nn
import torch.nn.functional as F

def get_activation_layer(act_type):
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        # Approximate `tanh` requires torch >= 1.13
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")

class SwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Initialize the SwiGLU FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            out_dim (int): Output dimension.

        Attributes:
            w1: Linear transformation for the first layer.
            w2: Linear transformation for the second layer.
            w3: Linear transformation for the third layer.
        """
        super().__init__()

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
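A quick usage sketch for the two helpers above (sizes are illustrative only):

import torch

act = get_activation_layer("gelu_tanh")()    # the factory returns a constructor
mlp = SwiGLU(dim=768, hidden_dim=2048, out_dim=768)
y = mlp(torch.randn(2, 16, 768))             # gated feed-forward: w2(silu(w1(x)) * w3(x))
print(type(act).__name__, y.shape)           # GELU torch.Size([2, 16, 768])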
hunyuanvideo_foley/models/nn/attn_layers.py
ADDED
@@ -0,0 +1,546 @@
import importlib.metadata
import math
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    from flash_attn import (
        flash_attn_qkvpacked_func,
        flash_attn_kvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
except ImportError:
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None
    flash_attn_varlen_qkvpacked_func = None  # also unset when flash_attn is missing
    index_first_axis = None
from packaging import version
from transformers.utils.import_utils import _is_package_available

from .norm_layers import get_norm_layer

def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert ndim > 1, "x must have at least two dimensions"

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponentials.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        # view_as_complex will pack [..., D/2, 2] (real) to [..., D/2] (complex)
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2] (complex) to [..., D/2, 2] (real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out

class BasicAttentionLayer(nn.Module):
    def __init__(self, attn_mode="flash", deterministic=False):
        super().__init__()
        self.attn_mode = attn_mode
        self.deterministic = deterministic

    def set_attn_mode(self, new_mode):
        self.attn_mode = new_mode

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

MEMORY_LAYOUT = {
    "self_flash": (
        lambda x: x,
        lambda x: x,
    ),
    "cross_flash": (
        lambda x: x,
        lambda x: x,
    ),
    "flash_torch_sp": (
        lambda x: x,
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}

# Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
+
|
201 |
+
# Copyed from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
|
202 |
+
def is_flash_attn_greater_or_equal(library_version: str):
|
203 |
+
if not _is_package_available("flash_attn"):
|
204 |
+
return False
|
205 |
+
|
206 |
+
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
|
207 |
+
|
208 |
+
|
209 |
+
def get_kv_seqlens_with_mask(attn_mask, k, v):
|
210 |
+
indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
|
211 |
+
b, s1, a, d = k.shape
|
212 |
+
k = index_first_axis(k.reshape(b * s1, a, d), indices_k)
|
213 |
+
v = index_first_axis(v.reshape(b * s1, a, d), indices_k)
|
214 |
+
kv = torch.stack([k, v], dim=1)
|
215 |
+
return cu_seqlens_k, max_seqlen_k, kv
|
216 |
+
|
217 |
+
|
218 |
+
def get_q_seqlens(q):
|
219 |
+
bs, s, a, d = q.shape
|
220 |
+
cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device)
|
221 |
+
q = q.reshape(bs * s, a, d)
|
222 |
+
return cu_seqlens_q, s, q
|
223 |
+
|
224 |
+
def flash_attn_no_pad(
|
225 |
+
qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None
|
226 |
+
):
|
227 |
+
# adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
|
228 |
+
batch_size = qkv.shape[0]
|
229 |
+
seqlen = qkv.shape[1]
|
230 |
+
nheads = qkv.shape[-2]
|
231 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
232 |
+
# x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch
|
233 |
+
# x_unpad, indices, cu_seqlens, max_s
|
234 |
+
unpad_results = unpad_input(
|
235 |
+
x, key_padding_mask
|
236 |
+
)
|
237 |
+
|
238 |
+
if len(unpad_results) == 4:
|
239 |
+
x_unpad, indices, cu_seqlens, max_s = unpad_results
|
240 |
+
elif len(unpad_results) == 5:
|
241 |
+
x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results
|
242 |
+
else:
|
243 |
+
raise ValueError
|
244 |
+
|
245 |
+
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
|
246 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
247 |
+
x_unpad,
|
248 |
+
cu_seqlens,
|
249 |
+
max_s,
|
250 |
+
dropout_p,
|
251 |
+
softmax_scale=softmax_scale,
|
252 |
+
causal=causal,
|
253 |
+
)
|
254 |
+
output = rearrange(
|
255 |
+
pad_input(
|
256 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
|
257 |
+
),
|
258 |
+
"b s (h d) -> b s h d",
|
259 |
+
h=nheads,
|
260 |
+
)
|
261 |
+
return output
|
262 |
+
|
263 |
+
|
264 |
+
def attention(
|
265 |
+
q,
|
266 |
+
k,
|
267 |
+
v,
|
268 |
+
mode,
|
269 |
+
drop_rate=0,
|
270 |
+
attn_mask=None,
|
271 |
+
cond_mask=None,
|
272 |
+
causal=False,
|
273 |
+
deterministic=False,
|
274 |
+
cu_seqlens=None,
|
275 |
+
max_seqlen=None,
|
276 |
+
cu_seqlens_k=None,
|
277 |
+
max_seqlen_k=None,
|
278 |
+
img_seq_len=None,
|
279 |
+
):
|
280 |
+
"""
|
281 |
+
Perform QKV self attention.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
285 |
+
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
286 |
+
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
287 |
+
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
288 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
289 |
+
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
290 |
+
(default: None)
|
291 |
+
causal (bool): Whether to use causal attention. (default: False)
|
292 |
+
deterministic (bool): Whether to use deterministic attention. (default: False)
|
293 |
+
cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
294 |
+
used to index into q.
|
295 |
+
max_seqlen (int): The maximum sequence length in the batch of q.
|
296 |
+
cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
297 |
+
used to index into kv.
|
298 |
+
max_seqlen_k (int): The maximum sequence length in the batch of k and v.
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
302 |
+
"""
|
303 |
+
if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
|
304 |
+
if isinstance(q, tuple):
|
305 |
+
q = torch.cat(q, dim=1)
|
306 |
+
if isinstance(k, tuple):
|
307 |
+
k = torch.cat(k, dim=1)
|
308 |
+
if isinstance(v, tuple):
|
309 |
+
v = torch.cat(v, dim=1)
|
310 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
311 |
+
q = pre_attn_layout(q)
|
312 |
+
k = pre_attn_layout(k)
|
313 |
+
v = pre_attn_layout(v)
|
314 |
+
|
315 |
+
if "flash" in mode:
|
316 |
+
assert (
|
317 |
+
flash_attn_qkvpacked_func is not None
|
318 |
+
), "Flash attention is not available. Please install flash_attn first."
|
319 |
+
flash_kwargs = dict(dropout_p=drop_rate, causal=causal)
|
320 |
+
if deterministic:
|
321 |
+
if not is_flash_attn_greater_or_equal("2.4.1"):
|
322 |
+
raise ValueError(
|
323 |
+
"Flash attention deterministic mode requires flash_attn>=2.4.1. " "Please upgrade flash_attn"
|
324 |
+
)
|
325 |
+
flash_kwargs["deterministic"] = deterministic
|
326 |
+
|
327 |
+
if mode == "self_flash":
|
328 |
+
qkv = torch.stack([q, k, v], dim=2)
|
329 |
+
if attn_mask is not None:
|
330 |
+
raise ValueError("Self attention does not support attention mask")
|
331 |
+
x = flash_attn_qkvpacked_func(qkv, **flash_kwargs)
|
332 |
+
|
333 |
+
elif mode == "cross_flash":
|
334 |
+
kv = torch.stack([k, v], dim=2)
|
335 |
+
if attn_mask is None:
|
336 |
+
x = flash_attn_kvpacked_func(q, kv, **flash_kwargs)
|
337 |
+
else:
|
338 |
+
b, s, a, h = q.shape
|
339 |
+
cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q)
|
340 |
+
cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v)
|
341 |
+
|
342 |
+
attn_output = flash_attn_varlen_kvpacked_func(
|
343 |
+
q,
|
344 |
+
kv,
|
345 |
+
cu_seqlens_q=cu_seqlens_q,
|
346 |
+
cu_seqlens_k=cu_seqlens_k,
|
347 |
+
max_seqlen_q=max_seqlen_q,
|
348 |
+
max_seqlen_k=max_seqlen_k,
|
349 |
+
**flash_kwargs,
|
350 |
+
)
|
351 |
+
x = attn_output.reshape(b, s, a, h)
|
352 |
+
elif mode == 'torch':
|
353 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
354 |
+
attn_mask = attn_mask.to(q.dtype)
|
355 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
356 |
+
|
357 |
+
elif mode == "vanilla":
|
358 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
359 |
+
|
360 |
+
b, a, s, _ = q.shape
|
361 |
+
s1 = k.size(2)
|
362 |
+
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
363 |
+
if causal:
|
364 |
+
# Only applied to self attention
|
365 |
+
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
366 |
+
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
367 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
368 |
+
attn_bias.to(q.dtype)
|
369 |
+
|
370 |
+
if attn_mask is not None:
|
371 |
+
if attn_mask.dtype == torch.bool:
|
372 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
373 |
+
else:
|
374 |
+
attn_bias += attn_mask
|
375 |
+
|
376 |
+
# TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow
|
377 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
378 |
+
attn += attn_bias
|
379 |
+
attn = attn.softmax(dim=-1)
|
380 |
+
attn = torch.dropout(attn, p=drop_rate, train=True)
|
381 |
+
x = attn @ v
|
382 |
+
else:
|
383 |
+
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
384 |
+
|
385 |
+
if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
|
386 |
+
x = post_attn_layout(x).contiguous()
|
387 |
+
b, s, a, d = x.shape
|
388 |
+
out = x.reshape(b, s, -1)
|
389 |
+
return out
|
390 |
+
|
391 |
+
|
392 |
+
class SelfAttentionLayer(BasicAttentionLayer):
|
393 |
+
def __init__(
|
394 |
+
self,
|
395 |
+
dim,
|
396 |
+
num_heads,
|
397 |
+
qkv_bias=True,
|
398 |
+
qk_norm=True,
|
399 |
+
attn_drop=0,
|
400 |
+
proj_drop=0,
|
401 |
+
dtype=None,
|
402 |
+
device=None,
|
403 |
+
norm_type="layer",
|
404 |
+
attn_mode="self_flash",
|
405 |
+
deterministic=False,
|
406 |
+
) -> None:
|
407 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
408 |
+
super().__init__(attn_mode, deterministic)
|
409 |
+
self.dim = dim
|
410 |
+
self.num_heads = num_heads
|
411 |
+
assert self.dim % num_heads == 0, "dim must be divisible by num_heads"
|
412 |
+
self.head_dim = self.dim // num_heads
|
413 |
+
self.attn_drop = attn_drop
|
414 |
+
|
415 |
+
# This assertion is aligned with flash attention
|
416 |
+
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
417 |
+
|
418 |
+
self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
|
419 |
+
|
420 |
+
norm_layer = get_norm_layer(norm_type)
|
421 |
+
self.q_norm = (
|
422 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
423 |
+
)
|
424 |
+
self.k_norm = (
|
425 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
426 |
+
)
|
427 |
+
|
428 |
+
self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
|
429 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
430 |
+
|
431 |
+
def forward(self, x, freqs_cis=None, attn_mask=None):
|
432 |
+
"""
|
433 |
+
Args:
|
434 |
+
x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
|
435 |
+
freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image
|
436 |
+
attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
|
437 |
+
"""
|
438 |
+
b, s, d = x.shape
|
439 |
+
|
440 |
+
# Apply QKV projection
|
441 |
+
qkv = self.Wqkv(x)
|
442 |
+
qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, a, d]
|
443 |
+
q, k, v = qkv.unbind(dim=2) # [b, s, a, d]
|
444 |
+
|
445 |
+
# Apply QK-Norm if needed
|
446 |
+
q = self.q_norm(q)
|
447 |
+
k = self.k_norm(k)
|
448 |
+
|
449 |
+
# Apply RoPE if needed
|
450 |
+
if freqs_cis is not None:
|
451 |
+
qq, kk = apply_rotary_emb(q, k, freqs_cis)
|
452 |
+
assert (
|
453 |
+
qq.shape == q.shape and kk.shape == k.shape
|
454 |
+
), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}"
|
455 |
+
q, k = qq, kk
|
456 |
+
|
457 |
+
# Apply self attention
|
458 |
+
context = attention(
|
459 |
+
q,
|
460 |
+
k,
|
461 |
+
v,
|
462 |
+
drop_rate=self.attn_drop if self.training else 0,
|
463 |
+
attn_mask=attn_mask,
|
464 |
+
mode=self.attn_mode,
|
465 |
+
deterministic=self.deterministic,
|
466 |
+
)
|
467 |
+
out = self.out_proj(context)
|
468 |
+
out = self.proj_drop(out)
|
469 |
+
|
470 |
+
return out
|
471 |
+
|
472 |
+
|
473 |
+
class CrossAttentionLayer(BasicAttentionLayer):
|
474 |
+
def __init__(
|
475 |
+
self,
|
476 |
+
qdim,
|
477 |
+
kdim,
|
478 |
+
num_heads,
|
479 |
+
qkv_bias=True,
|
480 |
+
qk_norm=True,
|
481 |
+
attn_drop=0,
|
482 |
+
proj_drop=0,
|
483 |
+
dtype=None,
|
484 |
+
device=None,
|
485 |
+
norm_type="layer",
|
486 |
+
attn_mode="cross_flash",
|
487 |
+
deterministic=False,
|
488 |
+
):
|
489 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
490 |
+
super().__init__(attn_mode, deterministic)
|
491 |
+
self.qdim = qdim
|
492 |
+
self.kdim = kdim
|
493 |
+
self.num_heads = num_heads
|
494 |
+
assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads"
|
495 |
+
self.head_dim = self.qdim // num_heads
|
496 |
+
self.attn_drop = attn_drop
|
497 |
+
|
498 |
+
# This assertion is aligned with flash attention
|
499 |
+
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
500 |
+
|
501 |
+
self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
502 |
+
self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
|
503 |
+
|
504 |
+
norm_layer = get_norm_layer(norm_type)
|
505 |
+
self.q_norm = (
|
506 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
507 |
+
)
|
508 |
+
self.k_norm = (
|
509 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
510 |
+
)
|
511 |
+
|
512 |
+
self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
513 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
514 |
+
|
515 |
+
def forward(self, x, y, attn_mask=None):
|
516 |
+
"""
|
517 |
+
Args:
|
518 |
+
x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
|
519 |
+
y (torch.Tensor): (batch, seq_len1, hidden_dim1)
|
520 |
+
attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
|
521 |
+
"""
|
522 |
+
b, s, d = x.shape
|
523 |
+
_, s1, d1 = y.shape
|
524 |
+
|
525 |
+
q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
|
526 |
+
kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim)
|
527 |
+
k, v = kv.unbind(dim=2)
|
528 |
+
|
529 |
+
# Apply QK-Norm if needed
|
530 |
+
q = self.q_norm(q)
|
531 |
+
k = self.k_norm(k)
|
532 |
+
|
533 |
+
# Apply cross attention
|
534 |
+
context = attention(
|
535 |
+
q,
|
536 |
+
k,
|
537 |
+
v,
|
538 |
+
attn_mask=attn_mask,
|
539 |
+
drop_rate=self.attn_drop if self.training else 0,
|
540 |
+
mode=self.attn_mode,
|
541 |
+
deterministic=self.deterministic,
|
542 |
+
)
|
543 |
+
out = self.out_proj(context)
|
544 |
+
out = self.proj_drop(out)
|
545 |
+
|
546 |
+
return out
|
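A minimal self-contained check of the two layers above in the PyTorch SDPA fallback mode, which needs no flash_attn install; shapes follow the [b, s, a, d] convention documented in `attention`:

import torch

layer = SelfAttentionLayer(dim=64, num_heads=8, attn_mode="torch")
x = torch.randn(2, 10, 64)
print(layer(x).shape)        # torch.Size([2, 10, 64])

cross = CrossAttentionLayer(qdim=64, kdim=32, num_heads=8, attn_mode="torch")
y = torch.randn(2, 7, 32)
print(cross(x, y).shape)     # torch.Size([2, 10, 64]) - queries keep their own length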
hunyuanvideo_foley/models/nn/embed_layers.py
ADDED
@@ -0,0 +1,136 @@
import math
import torch
import torch.nn as nn

from ...utils.helper import to_2tuple, to_1tuple

class PatchEmbed1D(nn.Module):
    """1D Audio to Patch Embedding

    A convolution based approach to patchifying 1D audio w/ embedding projection.

    Based on the impl in https://github.com/google-research/vision_transformer

    Hacked together by / Copyright 2020 Ross Wightman
    """

    def __init__(
        self,
        patch_size=1,
        in_chans=768,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        patch_size = to_1tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten

        self.proj = nn.Conv1d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
        )
        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
        if bias:
            nn.init.zeros_(self.proj.bias)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        assert (
            x.shape[2] % self.patch_size[0] == 0
        ), f"The token number ({x.shape[2]}) of x must be divisible by the patch_size ({self.patch_size[0]})."

        x = self.proj(x)
        if self.flatten:
            x = x.transpose(1, 2)  # BCN -> BNC
        x = self.norm(x)
        return x

class ConditionProjection(nn.Module):
    """
    Projects condition embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.normal_(self.mlp[2].weight, std=0.02)

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
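A quick numerical illustration of `timestep_embedding` and the `TimestepEmbedder` wrapper (toy sizes; the printed values follow directly from the code above):

import torch
import torch.nn as nn

t = torch.tensor([0.0, 500.0, 999.0])
emb = timestep_embedding(t, dim=8)
print(emb.shape)   # torch.Size([3, 8]) - [cos | sin] halves
print(emb[0])      # t=0 gives cos ones then sin zeros: tensor([1., 1., 1., 1., 0., 0., 0., 0.])

embedder = TimestepEmbedder(hidden_size=32, act_layer=nn.SiLU)
print(embedder(t).shape)  # torch.Size([3, 32])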
hunyuanvideo_foley/models/nn/mlp_layers.py
ADDED
@@ -0,0 +1,149 @@
# Modified from timm library:
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from .modulate_layers import modulate
from ...utils.helper import to_2tuple

class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
# only used when use_vanilla is True
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))

class LinearWarpforSingle(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)

    def forward(self, x, y):
        z = torch.cat([x, y], dim=2)
        return self.fc(z)

class FinalLayer1D(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Just use LayerNorm for the final layer
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # Here we don't distinguish between the modulate types. Just use the simple one.
        self.adaLN_modulation = nn.Sequential(
            act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift=shift, scale=scale)
        x = self.linear(x)
        return x

class ChannelLastConv1d(nn.Conv1d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 1)
        x = super().forward(x)
        x = x.permute(0, 2, 1)
        return x

class ConvMLP(nn.Module):

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
        device=None,
        dtype=None,
    ):
        """
        Convolutional MLP module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure the hidden dimension is a multiple of this value.
            kernel_size (int): Convolution kernel size along the time axis.
            padding (int): Convolution padding along the time axis.

        Attributes:
            w1: Convolutional transformation for the first layer.
            w2: Convolutional transformation for the second layer.
            w3: Convolutional transformation for the third layer.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
        self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
        self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
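A short sketch of how `ChannelLastConv1d` and `ConvMLP` operate on channel-last (batch, time, channels) tensors; kernel_size=3 with padding=1 preserves the sequence length, giving the gated feed-forward a local temporal receptive field:

import torch

mlp = ConvMLP(dim=64, hidden_dim=256)   # hidden width is rounded up to a multiple of 256
x = torch.randn(2, 100, 64)             # (batch, time, channels), channel-last layout
print(mlp(x).shape)                     # torch.Size([2, 100, 64])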
hunyuanvideo_foley/models/nn/modulate_layers.py
ADDED
@@ -0,0 +1,49 @@
from typing import Callable

import torch
import torch.nn as nn

class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        # Zero-initialize the modulation
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.act(x))

def modulate(x, shift=None, scale=None):
    if x.ndim == 3:
        # Per-sample (2-D) modulations broadcast over the sequence axis;
        # per-token (3-D) modulations are kept as-is (the original `else None`
        # would silently drop them, which reads like a bug).
        shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else shift
        scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else scale
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale)
    elif scale is None:
        return x + shift
    else:
        return x * (1 + scale) + shift

def apply_gate(x, gate=None, tanh=False):
    if gate is None:
        return x
    if gate.ndim == 2 and x.ndim == 3:
        gate = gate.unsqueeze(1)
    if tanh:
        return x * gate.tanh()
    else:
        return x * gate

def ckpt_wrapper(module):
    def ckpt_forward(*inputs):
        outputs = module(*inputs)
        return outputs

    return ckpt_forward
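A minimal illustration of the adaLN helpers above: `modulate` applies a per-sample shift/scale broadcast over the sequence axis, and `apply_gate` scales a branch output before the residual add:

import torch

x = torch.randn(2, 10, 64)          # (batch, seq, channels)
shift = torch.zeros(2, 64)
scale = torch.full((2, 64), 0.5)
gate = torch.ones(2, 64)

y = modulate(x, shift=shift, scale=scale)   # x * (1 + 0.5) + 0
print(torch.allclose(y, x * 1.5))           # True
print(apply_gate(y, gate).shape)            # torch.Size([2, 10, 64])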
hunyuanvideo_foley/models/nn/norm_layers.py
ADDED
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
                 device=None, dtype=None):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): Whether to learn a per-channel scaling weight. Default is True.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter (only created when elementwise_affine is True).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output

def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
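A small check that `RMSNorm` above matches the textbook formula x / sqrt(mean(x^2) + eps) * weight:

import torch

norm = get_norm_layer("rms")(dim=8)
x = torch.randn(4, 8)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
print(torch.allclose(norm(x), expected))  # True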