James Zhou commited on
Commit
9867d34
·
1 Parent(s): 860b27a
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. LICENSE +77 -0
  3. app.py +814 -0
  4. assets/data_pipeline.png +3 -0
  5. assets/model_arch.png +3 -0
  6. assets/pan_chart.png +3 -0
  7. configs/hunyuanvideo-foley-xxl.yaml +49 -0
  8. examples/1_result.mp4 +3 -0
  9. examples/1_video.mp4 +3 -0
  10. examples/2_result.mp4 +3 -0
  11. examples/2_video.mp4 +3 -0
  12. examples/3_result.mp4 +3 -0
  13. examples/3_video.mp4 +3 -0
  14. examples/4_result.mp4 +3 -0
  15. examples/4_video.mp4 +3 -0
  16. examples/5_result.mp4 +3 -0
  17. examples/5_video.mp4 +3 -0
  18. examples/6_result.mp4 +3 -0
  19. examples/6_video.mp4 +3 -0
  20. examples/7_result.mp4 +3 -0
  21. examples/7_video.mp4 +3 -0
  22. examples/8_result.mp4 +3 -0
  23. examples/8_video.mp4 +3 -0
  24. hunyuanvideo_foley/__init__.py +0 -0
  25. hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc +0 -0
  26. hunyuanvideo_foley/constants.py +57 -0
  27. hunyuanvideo_foley/models/__init__.py +0 -0
  28. hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc +0 -0
  29. hunyuanvideo_foley/models/dac_vae/__init__.py +16 -0
  30. hunyuanvideo_foley/models/dac_vae/__main__.py +36 -0
  31. hunyuanvideo_foley/models/dac_vae/model/__init__.py +4 -0
  32. hunyuanvideo_foley/models/dac_vae/model/base.py +301 -0
  33. hunyuanvideo_foley/models/dac_vae/model/dac.py +410 -0
  34. hunyuanvideo_foley/models/dac_vae/model/discriminator.py +228 -0
  35. hunyuanvideo_foley/models/dac_vae/nn/__init__.py +3 -0
  36. hunyuanvideo_foley/models/dac_vae/nn/layers.py +33 -0
  37. hunyuanvideo_foley/models/dac_vae/nn/loss.py +368 -0
  38. hunyuanvideo_foley/models/dac_vae/nn/quantize.py +262 -0
  39. hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py +91 -0
  40. hunyuanvideo_foley/models/dac_vae/utils/__init__.py +121 -0
  41. hunyuanvideo_foley/models/dac_vae/utils/decode.py +95 -0
  42. hunyuanvideo_foley/models/dac_vae/utils/encode.py +94 -0
  43. hunyuanvideo_foley/models/hifi_foley.py +794 -0
  44. hunyuanvideo_foley/models/nn/__init__.py +0 -0
  45. hunyuanvideo_foley/models/nn/activation_layers.py +44 -0
  46. hunyuanvideo_foley/models/nn/attn_layers.py +546 -0
  47. hunyuanvideo_foley/models/nn/embed_layers.py +136 -0
  48. hunyuanvideo_foley/models/nn/mlp_layers.py +149 -0
  49. hunyuanvideo_foley/models/nn/modulate_layers.py +49 -0
  50. hunyuanvideo_foley/models/nn/norm_layers.py +70 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
2
+ Tencent HunyuanVideo-Foley Release Date: August 28, 2025
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5
+ 1. DEFINITIONS.
6
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
8
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
9
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
11
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
14
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
15
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo-Foley released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley].
16
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
17
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
18
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
19
+ n. “including” shall mean including but not limited to.
20
+ 2. GRANT OF RIGHTS.
21
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
22
+ 3. DISTRIBUTION.
23
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
24
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
25
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
26
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
27
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
28
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
29
+ 4. ADDITIONAL COMMERCIAL TERMS.
30
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
31
+ 5. RULES OF USE.
32
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
33
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
34
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
35
+ 6. INTELLECTUAL PROPERTY.
36
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
37
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
38
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
39
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
40
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
41
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
42
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
43
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
44
+ 8. SURVIVAL AND TERMINATION.
45
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
46
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
47
+ 9. GOVERNING LAW AND JURISDICTION.
48
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
49
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
50
+
51
+ EXHIBIT A
52
+ ACCEPTABLE USE POLICY
53
+
54
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
55
+ Last modified: November 5, 2024
56
+
57
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
58
+ 1. Outside the Territory;
59
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
60
+ 3. To harm Yourself or others;
61
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
62
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
63
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
64
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
65
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
66
+ 9. To intentionally defame, disparage or otherwise harass others;
67
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
68
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
69
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
70
+ 13. To impersonate another individual without consent, authorization, or legal right;
71
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
72
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
73
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
74
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
75
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
76
+ 19. For military purposes;
77
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
app.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import torch
5
+ import torchaudio
6
+ from loguru import logger
7
+ from typing import Optional, Tuple
8
+ import random
9
+ import numpy as np
10
+
11
+ from hunyuanvideo_foley.utils.model_utils import load_model
12
+ from hunyuanvideo_foley.utils.feature_utils import feature_process
13
+ from hunyuanvideo_foley.utils.model_utils import denoise_process
14
+ from hunyuanvideo_foley.utils.media_utils import merge_audio_video
15
+
16
+ # Global variables for model storage
17
+ model_dict = None
18
+ cfg = None
19
+ device = None
20
+
21
+ # need to modify the model path
22
+ MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
23
+ CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
24
+
25
+ def setup_device(device_str: str = "auto", gpu_id: int = 0) -> torch.device:
26
+ """Setup computing device"""
27
+ if device_str == "auto":
28
+ if torch.cuda.is_available():
29
+ device = torch.device(f"cuda:{gpu_id}")
30
+ logger.info(f"Using CUDA device: {device}")
31
+ elif torch.backends.mps.is_available():
32
+ device = torch.device("mps")
33
+ logger.info("Using MPS device")
34
+ else:
35
+ device = torch.device("cpu")
36
+ logger.info("Using CPU device")
37
+ else:
38
+ if device_str == "cuda":
39
+ device = torch.device(f"cuda:{gpu_id}")
40
+ else:
41
+ device = torch.device(device_str)
42
+ logger.info(f"Using specified device: {device}")
43
+
44
+ return device
45
+
46
+ def auto_load_models() -> str:
47
+ """Automatically load preset models"""
48
+ global model_dict, cfg, device
49
+
50
+ try:
51
+ if not os.path.exists(MODEL_PATH):
52
+ return f"❌ Model file not found: {MODEL_PATH}"
53
+ if not os.path.exists(CONFIG_PATH):
54
+ return f"❌ Config file not found: {CONFIG_PATH}"
55
+
56
+ # Use GPU by default
57
+ device = setup_device("auto", 0)
58
+
59
+ # Load model
60
+ logger.info("Auto-loading model...")
61
+ logger.info(f"Model path: {MODEL_PATH}")
62
+ logger.info(f"Config path: {CONFIG_PATH}")
63
+
64
+ model_dict, cfg = load_model(MODEL_PATH, CONFIG_PATH, device)
65
+
66
+ logger.info("✅ Model loaded successfully!")
67
+ return "✅ Model loaded successfully!"
68
+
69
+ except Exception as e:
70
+ logger.error(f"Model loading failed: {str(e)}")
71
+ return f"❌ Model loading failed: {str(e)}"
72
+
73
+ def infer_single_video(
74
+ video_file,
75
+ text_prompt: str,
76
+ guidance_scale: float = 4.5,
77
+ num_inference_steps: int = 50,
78
+ sample_nums: int = 1
79
+ ) -> Tuple[list, str]:
80
+ """Single video inference"""
81
+ global model_dict, cfg, device
82
+
83
+ if model_dict is None or cfg is None:
84
+ return [], "❌ Please load the model first!"
85
+
86
+ if video_file is None:
87
+ return [], "❌ Please upload a video file!"
88
+
89
+ # Allow empty text prompt, use empty string if no prompt provided
90
+ if text_prompt is None:
91
+ text_prompt = ""
92
+ text_prompt = text_prompt.strip()
93
+
94
+ try:
95
+ logger.info(f"Processing video: {video_file}")
96
+ logger.info(f"Text prompt: {text_prompt}")
97
+
98
+ # Feature processing
99
+ visual_feats, text_feats, audio_len_in_s = feature_process(
100
+ video_file,
101
+ text_prompt,
102
+ model_dict,
103
+ cfg
104
+ )
105
+
106
+ # Denoising process to generate multiple audio samples
107
+ # Note: The model now generates sample_nums audio samples per inference
108
+ # The denoise_process function returns audio with shape [batch_size, channels, samples]
109
+ logger.info(f"Generating {sample_nums} audio samples...")
110
+ audio, sample_rate = denoise_process(
111
+ visual_feats,
112
+ text_feats,
113
+ audio_len_in_s,
114
+ model_dict,
115
+ cfg,
116
+ guidance_scale=guidance_scale,
117
+ num_inference_steps=num_inference_steps,
118
+ batch_size=sample_nums
119
+ )
120
+
121
+ # Create temporary files to save results
122
+ temp_dir = tempfile.mkdtemp()
123
+ video_outputs = []
124
+
125
+ # Process each generated audio sample
126
+ for i in range(sample_nums):
127
+ # Save audio file
128
+ audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav")
129
+ torchaudio.save(audio_output, audio[i], sample_rate)
130
+
131
+ # Merge video and audio
132
+ video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4")
133
+ merge_audio_video(audio_output, video_file, video_output)
134
+ video_outputs.append(video_output)
135
+
136
+ logger.info(f"Inference completed! Generated {sample_nums} samples.")
137
+ return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully!"
138
+
139
+ except Exception as e:
140
+ logger.error(f"Inference failed: {str(e)}")
141
+ return [], f"❌ Inference failed: {str(e)}"
142
+
143
+ def update_video_outputs(video_list, status_msg):
144
+ """Update video outputs based on the number of generated samples"""
145
+ # Initialize all outputs as None
146
+ outputs = [None] * 6
147
+
148
+ # Set values based on generated videos
149
+ for i, video_path in enumerate(video_list[:6]): # Max 6 samples
150
+ outputs[i] = video_path
151
+
152
+ # Return all outputs plus status message
153
+ return tuple(outputs + [status_msg])
154
+
155
+ def create_gradio_interface():
156
+ """Create Gradio interface"""
157
+
158
+ # Custom CSS for beautiful interface with better contrast
159
+ css = """
160
+ .gradio-container {
161
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
162
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
163
+ min-height: 100vh;
164
+ }
165
+
166
+ .main-header {
167
+ text-align: center;
168
+ padding: 2rem 0;
169
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
170
+ border-radius: 20px;
171
+ margin-bottom: 2rem;
172
+ box-shadow: 0 8px 32px rgba(0,0,0,0.15);
173
+ }
174
+
175
+ .main-header h1 {
176
+ color: white;
177
+ font-size: 3rem;
178
+ font-weight: 700;
179
+ margin-bottom: 0.5rem;
180
+ text-shadow: 0 2px 10px rgba(0,0,0,0.3);
181
+ }
182
+
183
+ .main-header p {
184
+ color: rgba(255, 255, 255, 0.95);
185
+ font-size: 1.2rem;
186
+ font-weight: 300;
187
+ }
188
+
189
+ .status-card {
190
+ background: white;
191
+ border-radius: 15px;
192
+ padding: 1rem;
193
+ margin-bottom: 1.5rem;
194
+ border: 1px solid #e1e5e9;
195
+ box-shadow: 0 4px 20px rgba(0,0,0,0.08);
196
+ }
197
+
198
+ .status-card label {
199
+ color: #2d3748 !important;
200
+ font-weight: 600 !important;
201
+ }
202
+
203
+ .usage-guide h3 {
204
+ color: #2d3748 !important;
205
+ font-weight: 600 !important;
206
+ margin-bottom: 0.5rem !important;
207
+ }
208
+
209
+ .usage-guide p {
210
+ color: #4a5568 !important;
211
+ font-size: 1rem !important;
212
+ line-height: 1.6 !important;
213
+ margin: 0.5rem 0 !important;
214
+ }
215
+
216
+ .usage-guide strong {
217
+ color: #1a202c !important;
218
+ font-weight: 700 !important;
219
+ }
220
+
221
+ .usage-guide em {
222
+ color: #1a202c !important;
223
+ font-weight: 700 !important;
224
+ font-style: normal !important;
225
+ }
226
+
227
+ .main-interface {
228
+ margin-bottom: 2rem;
229
+ }
230
+
231
+ .input-section {
232
+ background: white;
233
+ border-radius: 20px;
234
+ padding: 2rem;
235
+ margin-right: 1rem;
236
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
237
+ border: 1px solid #e1e5e9;
238
+ }
239
+
240
+ .input-section h3 {
241
+ color: #2d3748 !important;
242
+ font-weight: 600 !important;
243
+ margin-bottom: 1rem !important;
244
+ }
245
+
246
+ .input-section label {
247
+ color: #4a5568 !important;
248
+ font-weight: 500 !important;
249
+ }
250
+
251
+ .output-section {
252
+ background: white;
253
+ border-radius: 20px;
254
+ padding: 2rem;
255
+ margin-left: 1rem;
256
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
257
+ border: 1px solid #e1e5e9;
258
+ }
259
+
260
+ .output-section h3 {
261
+ color: #2d3748 !important;
262
+ font-weight: 600 !important;
263
+ margin-bottom: 1rem !important;
264
+ }
265
+
266
+ .output-section label {
267
+ color: #4a5568 !important;
268
+ font-weight: 500 !important;
269
+ }
270
+
271
+ .examples-section h3 {
272
+ color: #2d3748 !important;
273
+ font-weight: 600 !important;
274
+ margin-bottom: 1.5rem !important;
275
+ }
276
+
277
+ .generate-btn {
278
+ background: linear-gradient(45deg, #667eea, #764ba2) !important;
279
+ border: none !important;
280
+ color: white !important;
281
+ font-weight: 600 !important;
282
+ font-size: 1.1rem !important;
283
+ padding: 12px 30px !important;
284
+ border-radius: 25px !important;
285
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
286
+ transition: all 0.3s ease !important;
287
+ }
288
+
289
+ .generate-btn:hover {
290
+ transform: translateY(-2px) !important;
291
+ box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important;
292
+ }
293
+
294
+
295
+
296
+ .examples-section {
297
+ background: white;
298
+ border-radius: 20px;
299
+ padding: 2rem;
300
+ margin-top: 2rem;
301
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
302
+ border: 1px solid #e1e5e9;
303
+ }
304
+
305
+ .examples-section p {
306
+ color: #4a5568 !important;
307
+ margin-bottom: 1rem !important;
308
+ }
309
+
310
+ .example-row {
311
+ background: #f8fafc;
312
+ border: 1px solid #e2e8f0;
313
+ border-radius: 15px;
314
+ padding: 1.5rem;
315
+ margin: 1rem 0;
316
+ transition: all 0.3s ease;
317
+ align-items: center;
318
+ }
319
+
320
+ .example-row:hover {
321
+ border-color: #667eea;
322
+ transform: translateY(-2px);
323
+ box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
324
+ }
325
+
326
+ .example-row .markdown {
327
+ color: #2d3748 !important;
328
+ }
329
+
330
+ .example-row .markdown p {
331
+ color: #2d3748 !important;
332
+ margin: 0.5rem 0 !important;
333
+ line-height: 1.5 !important;
334
+ }
335
+
336
+ .example-row .markdown strong {
337
+ color: #1a202c !important;
338
+ font-weight: 600 !important;
339
+ }
340
+
341
+ /* Example grid layout styles */
342
+ .example-grid-row {
343
+ margin: 1rem 0;
344
+ gap: 1rem;
345
+ }
346
+
347
+ .example-item {
348
+ background: #f8fafc;
349
+ border: 1px solid #e2e8f0;
350
+ border-radius: 15px;
351
+ padding: 1rem;
352
+ transition: all 0.3s ease;
353
+ margin: 0.25rem;
354
+ max-width: 250px;
355
+ margin-left: auto;
356
+ margin-right: auto;
357
+ }
358
+
359
+ .example-item:hover {
360
+ border-color: #667eea;
361
+ transform: translateY(-2px);
362
+ box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
363
+ }
364
+
365
+ .example-caption {
366
+ margin: 0.5rem 0 !important;
367
+ min-height: 2.8rem !important;
368
+ display: flex !important;
369
+ align-items: flex-start !important;
370
+ }
371
+
372
+ .example-caption p {
373
+ color: #2d3748 !important;
374
+ font-size: 0.9rem !important;
375
+ line-height: 1.4 !important;
376
+ margin: 0.5rem 0 !important;
377
+ }
378
+
379
+ /* Multi-video gallery styles */
380
+ .additional-samples {
381
+ margin-top: 1rem;
382
+ gap: 0.5rem;
383
+ }
384
+
385
+ .additional-samples .gradio-video {
386
+ border-radius: 10px;
387
+ overflow: hidden;
388
+ }
389
+
390
+ /* Video gallery responsive layout */
391
+ .video-gallery {
392
+ display: grid;
393
+ gap: 1rem;
394
+ margin-top: 1rem;
395
+ }
396
+
397
+ .video-gallery.single {
398
+ grid-template-columns: 1fr;
399
+ }
400
+
401
+ .video-gallery.dual {
402
+ grid-template-columns: 1fr 1fr;
403
+ }
404
+
405
+ .video-gallery.multi {
406
+ grid-template-columns: repeat(2, 1fr);
407
+ grid-template-rows: auto auto auto;
408
+ }
409
+
410
+ .footer-text {
411
+ color: #718096 !important;
412
+ text-align: center;
413
+ padding: 2rem;
414
+ font-size: 0.9rem;
415
+ }
416
+
417
+ /* Video component styling for consistent size */
418
+ .input-section video,
419
+ .output-section video,
420
+ .example-row video {
421
+ width: 100% !important;
422
+ height: 300px !important;
423
+ object-fit: contain !important;
424
+ border-radius: 10px !important;
425
+ background-color: #000 !important;
426
+ }
427
+
428
+ .example-row video {
429
+ height: 150px !important;
430
+ }
431
+
432
+ /* Fix for additional samples video display */
433
+ .additional-samples video {
434
+ height: 150px !important;
435
+ object-fit: contain !important;
436
+ border-radius: 10px !important;
437
+ background-color: #000 !important;
438
+ }
439
+
440
+ .additional-samples .gradio-video {
441
+ border-radius: 10px !important;
442
+ overflow: hidden !important;
443
+ background-color: #000 !important;
444
+ }
445
+
446
+ .additional-samples .gradio-video > div {
447
+ background-color: #000 !important;
448
+ border-radius: 10px !important;
449
+ }
450
+
451
+ /* Video container styling */
452
+ .input-section .video-container,
453
+ .output-section .video-container,
454
+ .example-row .video-container {
455
+ background-color: #000 !important;
456
+ border-radius: 10px !important;
457
+ display: flex !important;
458
+ align-items: center !important;
459
+ justify-content: center !important;
460
+ overflow: hidden !important;
461
+ }
462
+
463
+ /* Ensure proper alignment */
464
+ .example-row {
465
+ display: flex !important;
466
+ align-items: stretch !important;
467
+ }
468
+
469
+ .example-row > div {
470
+ display: flex !important;
471
+ flex-direction: column !important;
472
+ justify-content: center !important;
473
+ }
474
+
475
+ /* Video wrapper for better control */
476
+ .video-wrapper {
477
+ position: relative !important;
478
+ width: 100% !important;
479
+ background: #000 !important;
480
+ border-radius: 10px !important;
481
+ overflow: hidden !important;
482
+ display: flex !important;
483
+ align-items: center !important;
484
+ justify-content: center !important;
485
+ }
486
+ """
487
+
488
+ with gr.Blocks(css=css, title="HunyuanVideo-Foley") as app:
489
+
490
+ # Main header
491
+ with gr.Column(elem_classes=["main-header"]):
492
+ gr.HTML("""
493
+ <h1>🎵 HunyuanVideo-Foley</h1>
494
+ <p>Text-Video-to-Audio Synthesis: Generate realistic audio from video and text descriptions</p>
495
+ """)
496
+
497
+ # Usage Guide
498
+ with gr.Column(elem_classes=["status-card"]):
499
+ gr.Markdown("""
500
+ ### 📋 Quick Start Guide
501
+ **1.** Upload your video file\t**2.** Add optional text description\t**3.** Adjust sample numbers (1-6)\t**4.** Click Generate Audio
502
+
503
+ 💡 For quick start, you can load the prepared examples by clicking the button.
504
+ """, elem_classes=["usage-guide"])
505
+
506
+ # Main inference interface - Input and Results side by side
507
+ with gr.Row(elem_classes=["main-interface"]):
508
+ # Input section
509
+ with gr.Column(scale=1, elem_classes=["input-section"]):
510
+ gr.Markdown("### 📹 Video Input")
511
+
512
+ video_input = gr.Video(
513
+ label="Upload Video",
514
+ info="Supported formats: MP4, AVI, MOV, etc.",
515
+ height=300
516
+ )
517
+
518
+ text_input = gr.Textbox(
519
+ label="🎯 Audio Description (English)",
520
+ placeholder="A person walks on frozen ice",
521
+ lines=3,
522
+ info="Describe the audio you want to generate (optional)"
523
+ )
524
+
525
+ with gr.Row():
526
+ guidance_scale = gr.Slider(
527
+ minimum=1.0,
528
+ maximum=10.0,
529
+ value=4.5,
530
+ step=0.1,
531
+ label="🎚️ CFG Scale",
532
+ )
533
+
534
+ inference_steps = gr.Slider(
535
+ minimum=10,
536
+ maximum=100,
537
+ value=50,
538
+ step=5,
539
+ label="⚡ Steps",
540
+ )
541
+
542
+ sample_nums = gr.Slider(
543
+ minimum=1,
544
+ maximum=6,
545
+ value=1,
546
+ step=1,
547
+ label="🎲 Sample Nums",
548
+ )
549
+
550
+ generate_btn = gr.Button(
551
+ "🎵 Generate Audio",
552
+ variant="primary",
553
+ elem_classes=["generate-btn"]
554
+ )
555
+
556
+ # Results section
557
+ with gr.Column(scale=1, elem_classes=["output-section"]):
558
+ gr.Markdown("### 🎥 Generated Results")
559
+
560
+ # Multi-video gallery for displaying multiple generated samples
561
+ with gr.Column():
562
+ # Primary video (Sample 1)
563
+ video_output_1 = gr.Video(
564
+ label="Sample 1",
565
+ height=250,
566
+ visible=True
567
+ )
568
+
569
+ # Additional videos (Samples 2-6) - initially hidden
570
+ with gr.Row(elem_classes=["additional-samples"]):
571
+ with gr.Column(scale=1):
572
+ video_output_2 = gr.Video(
573
+ label="Sample 2",
574
+ height=150,
575
+ visible=False
576
+ )
577
+ video_output_3 = gr.Video(
578
+ label="Sample 3",
579
+ height=150,
580
+ visible=False
581
+ )
582
+ with gr.Column(scale=1):
583
+ video_output_4 = gr.Video(
584
+ label="Sample 4",
585
+ height=150,
586
+ visible=False
587
+ )
588
+ video_output_5 = gr.Video(
589
+ label="Sample 5",
590
+ height=150,
591
+ visible=False
592
+ )
593
+
594
+ # Sample 6 - full width
595
+ video_output_6 = gr.Video(
596
+ label="Sample 6",
597
+ height=150,
598
+ visible=False
599
+ )
600
+
601
+ result_text = gr.Textbox(
602
+ label="Status",
603
+ interactive=False,
604
+ lines=2
605
+ )
606
+
607
+ # Examples section at the bottom
608
+ with gr.Column(elem_classes=["examples-section"]):
609
+ gr.Markdown("### 🌟 Examples")
610
+ gr.Markdown("Click on any example to load it into the interface above")
611
+
612
+ # Define your custom examples here - 8 examples total
613
+ examples_data = [
614
+ # Example 1
615
+ {
616
+ "caption": "A person walks on frozen ice",
617
+ "video_path": "examples/1_video.mp4",
618
+ "result_path": "examples/1_result.mp4"
619
+ },
620
+ # Example 2
621
+ {
622
+ "caption": "With a faint sound as their hands parted, the two embraced, a soft 'mm' escaping between them.",
623
+ "video_path": "examples/2_video.mp4",
624
+ "result_path": "examples/2_result.mp4"
625
+ },
626
+ # Example 3
627
+ {
628
+ "caption": "The sound of the number 3's bouncing footsteps is as light and clear as glass marbles hitting the ground. Each step carries a magical sound.",
629
+ "video_path": "examples/3_video.mp4",
630
+ "result_path": "examples/3_result.mp4"
631
+ },
632
+ # Example 4
633
+ {
634
+ "caption": "gentle gurgling of the stream's current, and music plays in the background which is a beautiful and serene piano solo with a hint of classical charm, evoking a sense of peace and serenity in people's hearts.",
635
+ "video_path": "examples/4_video.mp4",
636
+ "result_path": "examples/4_result.mp4"
637
+ },
638
+ # Example 5 - Add your new examples here
639
+ {
640
+ "caption": "snow crunching under the snowboard's edge.",
641
+ "video_path": "examples/5_video.mp4",
642
+ "result_path": "examples/5_result.mp4"
643
+ },
644
+ # Example 6
645
+ {
646
+ "caption": "The crackling of the fire, the whooshing of the flames, and the occasional crisp popping of charred leaves filled the forest.",
647
+ "video_path": "examples/6_video.mp4",
648
+ "result_path": "examples/6_result.mp4"
649
+ },
650
+ # Example 7
651
+ {
652
+ "caption": "humming of the scooter engine accelerates slowly.",
653
+ "video_path": "examples/7_video.mp4",
654
+ "result_path": "examples/7_result.mp4"
655
+ },
656
+ # Example 8
657
+ {
658
+ "caption": "splash of water and loud thud as person hits the surface.",
659
+ "video_path": "examples/8_video.mp4",
660
+ "result_path": "examples/8_result.mp4"
661
+ }
662
+ ]
663
+
664
+ # Create example grid - 4 examples per row, 2 rows total
665
+ example_buttons = []
666
+ for row in range(2): # 2 rows
667
+ with gr.Row(elem_classes=["example-grid-row"]):
668
+ for col in range(4): # 4 columns
669
+ idx = row * 4 + col
670
+ if idx < len(examples_data):
671
+ example = examples_data[idx]
672
+
673
+ with gr.Column(scale=1, elem_classes=["example-item"]):
674
+ # Video thumbnail
675
+ if os.path.exists(example['video_path']):
676
+ example_video = gr.Video(
677
+ value=example['video_path'],
678
+ label=f"Example {idx+1}",
679
+ interactive=False,
680
+ show_label=True,
681
+ height=180
682
+ )
683
+ else:
684
+ example_video = gr.HTML(f"""
685
+ <div style="background: #f0f0f0; padding: 15px; text-align: center; border-radius: 8px; height: 180px; display: flex; align-items: center; justify-content: center;">
686
+ <div>
687
+ <p style="color: #666; margin: 0; font-size: 12px;">📹 Video not found</p>
688
+ <small style="color: #999; font-size: 10px;">{example['video_path']}</small>
689
+ </div>
690
+ </div>
691
+ """)
692
+
693
+ # Caption (truncated for grid layout)
694
+ caption_preview = example['caption'][:60] + "..." if len(example['caption']) > 60 else example['caption']
695
+ gr.Markdown(f"{caption_preview}", elem_classes=["example-caption"])
696
+
697
+ # Load button
698
+ example_btn = gr.Button(
699
+ f"Load Example {idx+1}",
700
+ variant="secondary",
701
+ size="sm"
702
+ )
703
+ example_buttons.append((example_btn, example))
704
+
705
+ # Event handlers
706
+ def process_inference(video_file, text_prompt, guidance_scale, inference_steps, sample_nums):
707
+ # Generate videos
708
+ video_list, status_msg = infer_single_video(
709
+ video_file, text_prompt, guidance_scale, inference_steps, int(sample_nums)
710
+ )
711
+ # Update outputs with proper visibility
712
+ return update_video_outputs(video_list, status_msg)
713
+
714
+ # Add dynamic visibility control based on sample_nums
715
+ def update_visibility(sample_nums):
716
+ sample_nums = int(sample_nums)
717
+ return [
718
+ gr.update(visible=True), # Sample 1 always visible
719
+ gr.update(visible=sample_nums >= 2), # Sample 2
720
+ gr.update(visible=sample_nums >= 3), # Sample 3
721
+ gr.update(visible=sample_nums >= 4), # Sample 4
722
+ gr.update(visible=sample_nums >= 5), # Sample 5
723
+ gr.update(visible=sample_nums >= 6), # Sample 6
724
+ ]
725
+
726
+ # Update visibility when sample_nums changes
727
+ sample_nums.change(
728
+ fn=update_visibility,
729
+ inputs=[sample_nums],
730
+ outputs=[video_output_1, video_output_2, video_output_3, video_output_4, video_output_5, video_output_6]
731
+ )
732
+
733
+ generate_btn.click(
734
+ fn=process_inference,
735
+ inputs=[video_input, text_input, guidance_scale, inference_steps, sample_nums],
736
+ outputs=[
737
+ video_output_1, # Sample 1 value
738
+ video_output_2, # Sample 2 value
739
+ video_output_3, # Sample 3 value
740
+ video_output_4, # Sample 4 value
741
+ video_output_5, # Sample 5 value
742
+ video_output_6, # Sample 6 value
743
+ result_text
744
+ ]
745
+ )
746
+
747
+ # Add click handlers for example buttons
748
+ for btn, example in example_buttons:
749
+ def create_example_handler(ex):
750
+ def handler():
751
+ # Check if files exist, if not, return placeholder message
752
+ if os.path.exists(ex['video_path']):
753
+ video_file = ex['video_path']
754
+ else:
755
+ video_file = None
756
+
757
+ if os.path.exists(ex['result_path']):
758
+ result_video = ex['result_path']
759
+ else:
760
+ result_video = None
761
+
762
+ status_msg = f"✅ Loaded example with caption: {ex['caption'][:50]}..."
763
+ if not video_file:
764
+ status_msg += f"\n⚠️ Video file not found: {ex['video_path']}"
765
+ if not result_video:
766
+ status_msg += f"\n⚠️ Result video not found: {ex['result_path']}"
767
+
768
+ return video_file, ex['caption'], result_video, status_msg
769
+ return handler
770
+
771
+ btn.click(
772
+ fn=create_example_handler(example),
773
+ outputs=[video_input, text_input, video_output_1, result_text]
774
+ )
775
+
776
+ # Footer
777
+ gr.HTML("""
778
+ <div class="footer-text">
779
+ <p>🚀 Powered by HunyuanVideo-Foley | Generate high-quality audio from video and text descriptions</p>
780
+ </div>
781
+ """)
782
+
783
+ return app
784
+
785
+ def set_manual_seed(global_seed):
786
+ random.seed(global_seed)
787
+ np.random.seed(global_seed)
788
+ torch.manual_seed(global_seed)
789
+
790
+ if __name__ == "__main__":
791
+ set_manual_seed(1)
792
+ # Setup logging
793
+ logger.remove()
794
+ logger.add(lambda msg: print(msg, end=''), level="INFO")
795
+
796
+ # Auto-load model
797
+ logger.info("Starting application and loading model...")
798
+ model_load_result = auto_load_models()
799
+ logger.info(model_load_result)
800
+
801
+ # Create and launch Gradio app
802
+ app = create_gradio_interface()
803
+
804
+ # Log completion status
805
+ if "successfully" in model_load_result:
806
+ logger.info("Application ready, model loaded")
807
+
808
+ app.launch(
809
+ server_name="0.0.0.0",
810
+ server_port=8080,
811
+ share=False,
812
+ debug=False,
813
+ show_error=True
814
+ )
assets/data_pipeline.png ADDED

Git LFS Details

  • SHA256: d5c9e5cd92a7ac24d1e8f39db09e0eaa9ee84bedade8ff08bd1d50141fc7867c
  • Pointer size: 131 Bytes
  • Size of remote file: 385 kB
assets/model_arch.png ADDED

Git LFS Details

  • SHA256: 4709a32df5b115e7806e0eb102aaf2e396a0978e12b31fba338730068d6454d7
  • Pointer size: 131 Bytes
  • Size of remote file: 542 kB
assets/pan_chart.png ADDED

Git LFS Details

  • SHA256: 16019d3355051f5b470532809a0cf9046d22170d30c860dd01929f6921d29ead
  • Pointer size: 131 Bytes
  • Size of remote file: 304 kB
configs/hunyuanvideo-foley-xxl.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ model_name: HunyuanVideo-Foley-XXL
3
+ model_type: 1d
4
+ model_precision: bf16
5
+ model_kwargs:
6
+ depth_triple_blocks: 18
7
+ depth_single_blocks: 36
8
+ hidden_size: 1536
9
+ num_heads: 12
10
+ mlp_ratio: 4
11
+ mlp_act_type: "gelu_tanh"
12
+ qkv_bias: True
13
+ qk_norm: True
14
+ qk_norm_type: "rms"
15
+ attn_mode: "torch"
16
+ embedder_type: "default"
17
+ interleaved_audio_visual_rope: True
18
+ enable_learnable_empty_visual_feat: True
19
+ sync_modulation: False
20
+ add_sync_feat_to_audio: True
21
+ cross_attention: True
22
+ use_attention_mask: False
23
+ condition_projection: "linear"
24
+ sync_feat_dim: 768 # syncformer 768 dim
25
+ condition_dim: 768 # clap 768 text condition dim (clip-text)
26
+ clip_dim: 768 # siglip2 visual dim
27
+ audio_vae_latent_dim: 128
28
+ audio_frame_rate: 50
29
+ patch_size: 1
30
+ rope_dim_list: null
31
+ rope_theta: 10000
32
+ text_length: 77
33
+ clip_length: 64
34
+ sync_length: 192
35
+ use_mmaudio_singleblock: True
36
+ depth_triple_ssl_encoder: null
37
+ depth_single_ssl_encoder: 8
38
+ use_repa_with_audiossl: True
39
+
40
+ diffusion_config:
41
+ denoise_type: "flow"
42
+ flow_path_type: "linear"
43
+ flow_predict_type: "velocity"
44
+ flow_reverse: True
45
+ flow_solver: "euler"
46
+ sample_flow_shift: 1.0
47
+ sample_use_flux_shift: False
48
+ flux_base_shift: 0.5
49
+ flux_max_shift: 1.15
examples/1_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f3f49d6130592f479b0aca5f02ba25960140ed8d9d17340ff7f6306b39096a8
3
+ size 11357340
examples/1_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54fc2de0b52f6969157b9caff212ffffddd4d34a75efb47ef7e7f8352d0a38db
3
+ size 11181543
examples/2_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf4f324b158f6a6926e77bbd0791610d79f0c3a600a571ed8ec61b0b7e645e46
3
+ size 1720732
examples/2_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:512b185682ec8e60407dff65718443d7a28c75c79aec1e733b5abf7433af41a7
3
+ size 1636945
examples/3_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df90300335b7ab1fb2fc4c020976837c6b3781f796e211bcbbaa30d34353d3e5
3
+ size 1738462
examples/3_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e79f80cf939fcb507e3fe218f61146d4fd3949a84e802b0f5b67bb2e981931a7
3
+ size 1652180
examples/4_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7d2b5b63f6756f719d53e8087f772aa6bb25f31fbcd9f1cbae9e075fc841a2c
3
+ size 45242387
examples/4_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f94cfd97634f3df085672ce2a91805697320507a716360bf85ba6eabb5a4b6f0
3
+ size 45066257
examples/5_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf1db36336822b54b4e0e0aa8c98334fc2b97b9288271ddc9d42a45417b1f1d9
3
+ size 40423834
examples/5_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:746954bde2d5e693beecd8e3661bcd66ff0e55a8143f4f3b37f0b6d3873a8fff
3
+ size 40248335
examples/6_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7778c26a677c04e93dc722cee44b896c5281bea8328a10800a08b98865419cd0
3
+ size 4005580
examples/6_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:866b2cff3441ddd686e551181c48ad7ca718626a489cf64198626b42bd732366
3
+ size 3872852
examples/7_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:491114177a4ceeb50a6edfa5fca14fc6ce4fdb61ee7dfb0e13983236c42ee10d
3
+ size 32307884
examples/7_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bc8a31e867245f6a6a6fbfa9778ac4c12e816184dc70324dc92d4496a36f62b
3
+ size 32131367
examples/8_result.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17053c5f8a373d656be2d2030619a71a5fd55db6365f8d54234402121e6030ce
3
+ size 29544164
examples/8_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45a1f35974ad6e2d86304828fdeb230d9b008aae2f10cff8c87d71a8dcc6491e
3
+ size 29367637
hunyuanvideo_foley/__init__.py ADDED
File without changes
hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (173 Bytes). View file
 
hunyuanvideo_foley/constants.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Constants used throughout the HunyuanVideo-Foley project."""
2
+
3
+ from typing import Dict, List
4
+
5
+ # Model configuration
6
+ DEFAULT_AUDIO_SAMPLE_RATE = 48000
7
+ DEFAULT_VIDEO_FPS = 25
8
+ DEFAULT_AUDIO_CHANNELS = 2
9
+
10
+ # Video processing
11
+ MAX_VIDEO_DURATION_SECONDS = 15.0
12
+ MIN_VIDEO_DURATION_SECONDS = 1.0
13
+
14
+ # Audio processing
15
+ AUDIO_VAE_LATENT_DIM = 128
16
+ AUDIO_FRAME_RATE = 75 # frames per second in latent space
17
+
18
+ # Visual features
19
+ FPS_VISUAL: Dict[str, int] = {
20
+ "siglip2": 8,
21
+ "synchformer": 25
22
+ }
23
+
24
+ # Model paths (can be overridden by environment variables)
25
+ DEFAULT_MODEL_PATH = "./pretrained_models/"
26
+ DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
27
+
28
+ # Inference parameters
29
+ DEFAULT_GUIDANCE_SCALE = 4.5
30
+ DEFAULT_NUM_INFERENCE_STEPS = 50
31
+ MIN_GUIDANCE_SCALE = 1.0
32
+ MAX_GUIDANCE_SCALE = 10.0
33
+ MIN_INFERENCE_STEPS = 10
34
+ MAX_INFERENCE_STEPS = 100
35
+
36
+ # Text processing
37
+ MAX_TEXT_LENGTH = 100
38
+ DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"
39
+
40
+ # File extensions
41
+ SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
42
+ SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]
43
+
44
+ # Quality settings
45
+ AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
46
+ "high": ["-b:a", "192k"],
47
+ "medium": ["-b:a", "128k"],
48
+ "low": ["-b:a", "96k"]
49
+ }
50
+
51
+ # Error messages
52
+ ERROR_MESSAGES: Dict[str, str] = {
53
+ "model_not_loaded": "Model is not loaded. Please load the model first.",
54
+ "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
55
+ "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
56
+ "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
57
+ }
hunyuanvideo_foley/models/__init__.py ADDED
File without changes
hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc ADDED
Binary file (12.1 kB). View file
 
hunyuanvideo_foley/models/dac_vae/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "1.0.0"
2
+
3
+ # preserved here for legacy reasons
4
+ __model_version__ = "latest"
5
+
6
+ import audiotools
7
+
8
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
10
+
11
+
12
+ from . import nn
13
+ from . import model
14
+ from . import utils
15
+ from .model import DAC
16
+ from .model import DACFile
hunyuanvideo_foley/models/dac_vae/__main__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from .utils import download
6
+ from .utils.decode import decode
7
+ from .utils.encode import encode
8
+
9
+ STAGES = ["encode", "decode", "download"]
10
+
11
+
12
+ def run(stage: str):
13
+ """Run stages.
14
+
15
+ Parameters
16
+ ----------
17
+ stage : str
18
+ Stage to run
19
+ """
20
+ if stage not in STAGES:
21
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
+ stage_fn = globals()[stage]
23
+
24
+ if stage == "download":
25
+ stage_fn()
26
+ return
27
+
28
+ stage_fn()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ group = sys.argv.pop(1)
33
+ args = argbind.parse_args(group=group)
34
+
35
+ with argbind.scope(args):
36
+ run(group)
hunyuanvideo_foley/models/dac_vae/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
hunyuanvideo_foley/models/dac_vae/model/base.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(
52
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
+ )
54
+ return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
+ class CodecMixin:
58
+ @property
59
+ def padding(self):
60
+ if not hasattr(self, "_padding"):
61
+ self._padding = True
62
+ return self._padding
63
+
64
+ @padding.setter
65
+ def padding(self, value):
66
+ assert isinstance(value, bool)
67
+
68
+ layers = [
69
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
+ ]
71
+
72
+ for layer in layers:
73
+ if value:
74
+ if hasattr(layer, "original_padding"):
75
+ layer.padding = layer.original_padding
76
+ else:
77
+ layer.original_padding = layer.padding
78
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
+
80
+ self._padding = value
81
+
82
+ def get_delay(self):
83
+ # Any number works here, delay is invariant to input length
84
+ l_out = self.get_output_length(0)
85
+ L = l_out
86
+
87
+ layers = []
88
+ for layer in self.modules():
89
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
+ layers.append(layer)
91
+
92
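+ # Walk the layers in reverse, inverting each layer's output-length formula to recover
+ # the input length that yields l_out; the delay is half of the extra context consumed.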
+ for layer in reversed(layers):
93
+ d = layer.dilation[0]
94
+ k = layer.kernel_size[0]
95
+ s = layer.stride[0]
96
+
97
+ if isinstance(layer, nn.ConvTranspose1d):
98
+ L = ((L - d * (k - 1) - 1) / s) + 1
99
+ elif isinstance(layer, nn.Conv1d):
100
+ L = (L - 1) * s + d * (k - 1) + 1
101
+
102
+ L = math.ceil(L)
103
+
104
+ l_in = L
105
+
106
+ return (l_in - l_out) // 2
107
+
108
+ def get_output_length(self, input_length):
109
+ L = input_length
110
+ # Calculate output length
111
+ for layer in self.modules():
112
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
+ d = layer.dilation[0]
114
+ k = layer.kernel_size[0]
115
+ s = layer.stride[0]
116
+
117
+ if isinstance(layer, nn.Conv1d):
118
+ L = ((L - d * (k - 1) - 1) / s) + 1
119
+ elif isinstance(layer, nn.ConvTranspose1d):
120
+ L = (L - 1) * s + d * (k - 1) + 1
121
+
122
+ L = math.floor(L)
123
+ return L
124
+
125
+ @torch.no_grad()
126
+ def compress(
127
+ self,
128
+ audio_path_or_signal: Union[str, Path, AudioSignal],
129
+ win_duration: float = 1.0,
130
+ verbose: bool = False,
131
+ normalize_db: float = -16,
132
+ n_quantizers: int = None,
133
+ ) -> DACFile:
134
+ """Processes an audio signal from a file or AudioSignal object into
135
+ discrete codes. This function processes the signal in short windows,
136
+ using constant GPU memory.
137
+
138
+ Parameters
139
+ ----------
140
+ audio_path_or_signal : Union[str, Path, AudioSignal]
141
+ audio signal to compress
142
+ win_duration : float, optional
143
+ window duration in seconds, by default 1.0
144
+ verbose : bool, optional
145
+ Prints progress if True, by default False
146
+ normalize_db : float, optional
147
+ Target loudness (in dB) to normalize the input to before encoding, by default -16
148
+
149
+ Returns
150
+ -------
151
+ DACFile
152
+ Object containing compressed codes and metadata
153
+ required for decompression
154
+ """
155
+ audio_signal = audio_path_or_signal
156
+ if isinstance(audio_signal, (str, Path)):
157
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
+
159
+ self.eval()
160
+ original_padding = self.padding
161
+ original_device = audio_signal.device
162
+
163
+ audio_signal = audio_signal.clone()
164
+ audio_signal = audio_signal.to_mono()
165
+ original_sr = audio_signal.sample_rate
166
+
167
+ resample_fn = audio_signal.resample
168
+ loudness_fn = audio_signal.loudness
169
+
170
+ # If audio is >= 10 hours long, use the ffmpeg versions
171
+ if audio_signal.signal_duration >= 10 * 60 * 60:
172
+ resample_fn = audio_signal.ffmpeg_resample
173
+ loudness_fn = audio_signal.ffmpeg_loudness
174
+
175
+ original_length = audio_signal.signal_length
176
+ resample_fn(self.sample_rate)
177
+ input_db = loudness_fn()
178
+
179
+ if normalize_db is not None:
180
+ audio_signal.normalize(normalize_db)
181
+ audio_signal.ensure_max_of_audio()
182
+
183
+ nb, nac, nt = audio_signal.audio_data.shape
184
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
185
+ win_duration = (
186
+ audio_signal.signal_duration if win_duration is None else win_duration
187
+ )
188
+
189
+ if audio_signal.signal_duration <= win_duration:
190
+ # Unchunked compression (used if signal length < win duration)
191
+ self.padding = True
192
+ n_samples = nt
193
+ hop = nt
194
+ else:
195
+ # Chunked inference
196
+ self.padding = False
197
+ # Zero-pad signal on either side by the delay
198
+ audio_signal.zero_pad(self.delay, self.delay)
199
+ n_samples = int(win_duration * self.sample_rate)
200
+ # Round n_samples to nearest hop length multiple
201
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
202
+ hop = self.get_output_length(n_samples)
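+ # With padding disabled, each window of n_samples yields `hop` valid output samples,
+ # so stepping the input by `hop` lets consecutive chunks tile without seams.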
203
+
204
+ codes = []
205
+ range_fn = range if not verbose else tqdm.trange
206
+
207
+ for i in range_fn(0, nt, hop):
208
+ x = audio_signal[..., i : i + n_samples]
209
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
210
+
211
+ audio_data = x.audio_data.to(self.device)
212
+ audio_data = self.preprocess(audio_data, self.sample_rate)
213
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
214
+ codes.append(c.to(original_device))
215
+ chunk_length = c.shape[-1]
216
+
217
+ codes = torch.cat(codes, dim=-1)
218
+
219
+ dac_file = DACFile(
220
+ codes=codes,
221
+ chunk_length=chunk_length,
222
+ original_length=original_length,
223
+ input_db=input_db,
224
+ channels=nac,
225
+ sample_rate=original_sr,
226
+ padding=self.padding,
227
+ dac_version=SUPPORTED_VERSIONS[-1],
228
+ )
229
+
230
+ if n_quantizers is not None:
231
+ codes = codes[:, :n_quantizers, :]
232
+
233
+ self.padding = original_padding
234
+ return dac_file
235
+
236
+ @torch.no_grad()
237
+ def decompress(
238
+ self,
239
+ obj: Union[str, Path, DACFile],
240
+ verbose: bool = False,
241
+ ) -> AudioSignal:
242
+ """Reconstruct audio from a given .dac file
243
+
244
+ Parameters
245
+ ----------
246
+ obj : Union[str, Path, DACFile]
247
+ .dac file location or corresponding DACFile object.
248
+ verbose : bool, optional
249
+ Prints progress if True, by default False
250
+
251
+ Returns
252
+ -------
253
+ AudioSignal
254
+ Object with the reconstructed audio
255
+ """
256
+ self.eval()
257
+ if isinstance(obj, (str, Path)):
258
+ obj = DACFile.load(obj)
259
+
260
+ original_padding = self.padding
261
+ self.padding = obj.padding
262
+
263
+ range_fn = range if not verbose else tqdm.trange
264
+ codes = obj.codes
265
+ original_device = codes.device
266
+ chunk_length = obj.chunk_length
267
+ recons = []
268
+
269
+ for i in range_fn(0, codes.shape[-1], chunk_length):
270
+ c = codes[..., i : i + chunk_length].to(self.device)
271
+ z = self.quantizer.from_codes(c)[0]
272
+ r = self.decode(z)
273
+ recons.append(r.to(original_device))
274
+
275
+ recons = torch.cat(recons, dim=-1)
276
+ recons = AudioSignal(recons, self.sample_rate)
277
+
278
+ resample_fn = recons.resample
279
+ loudness_fn = recons.loudness
280
+
281
+ # If audio is >= 10 hours long, use the ffmpeg versions
282
+ if recons.signal_duration >= 10 * 60 * 60:
283
+ resample_fn = recons.ffmpeg_resample
284
+ loudness_fn = recons.ffmpeg_loudness
285
+
286
+ if obj.input_db is not None:
287
+ recons.normalize(obj.input_db)
288
+
289
+ resample_fn(obj.sample_rate)
290
+
291
+ if obj.original_length is not None:
292
+ recons = recons[..., : obj.original_length]
293
+ loudness_fn()
294
+ recons.audio_data = recons.audio_data.reshape(
295
+ -1, obj.channels, obj.original_length
296
+ )
297
+ else:
298
+ loudness_fn()
299
+
300
+ self.padding = original_padding
301
+ return recons
hunyuanvideo_foley/models/dac_vae/model/dac.py ADDED
@@ -0,0 +1,410 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from ..nn.layers import Snake1d
13
+ from ..nn.layers import WNConv1d
14
+ from ..nn.layers import WNConvTranspose1d
15
+ from ..nn.quantize import ResidualVectorQuantize
16
+ from ..nn.vae_utils import DiagonalGaussianDistribution
17
+
18
+
19
+ def init_weights(m):
20
+ if isinstance(m, nn.Conv1d):
21
+ nn.init.trunc_normal_(m.weight, std=0.02)
22
+ nn.init.constant_(m.bias, 0)
23
+
24
+
25
+ class ResidualUnit(nn.Module):
26
+ def __init__(self, dim: int = 16, dilation: int = 1):
27
+ super().__init__()
28
+ pad = ((7 - 1) * dilation) // 2
29
+ self.block = nn.Sequential(
30
+ Snake1d(dim),
31
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
32
+ Snake1d(dim),
33
+ WNConv1d(dim, dim, kernel_size=1),
34
+ )
35
+
36
+ def forward(self, x):
37
+ y = self.block(x)
38
+ pad = (x.shape[-1] - y.shape[-1]) // 2
39
+ if pad > 0:
40
+ x = x[..., pad:-pad]
41
+ return x + y
42
+
43
+
44
+ class EncoderBlock(nn.Module):
45
+ def __init__(self, dim: int = 16, stride: int = 1):
46
+ super().__init__()
47
+ self.block = nn.Sequential(
48
+ ResidualUnit(dim // 2, dilation=1),
49
+ ResidualUnit(dim // 2, dilation=3),
50
+ ResidualUnit(dim // 2, dilation=9),
51
+ Snake1d(dim // 2),
52
+ WNConv1d(
53
+ dim // 2,
54
+ dim,
55
+ kernel_size=2 * stride,
56
+ stride=stride,
57
+ padding=math.ceil(stride / 2),
58
+ ),
59
+ )
60
+
61
+ def forward(self, x):
62
+ return self.block(x)
63
+
64
+
65
+ class Encoder(nn.Module):
66
+ def __init__(
67
+ self,
68
+ d_model: int = 64,
69
+ strides: list = [2, 4, 8, 8],
70
+ d_latent: int = 64,
71
+ ):
72
+ super().__init__()
73
+ # Create first convolution
74
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
75
+
76
+ # Create EncoderBlocks that double channels as they downsample by `stride`
77
+ for stride in strides:
78
+ d_model *= 2
79
+ self.block += [EncoderBlock(d_model, stride=stride)]
80
+
81
+ # Create last convolution
82
+ self.block += [
83
+ Snake1d(d_model),
84
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
85
+ ]
86
+
87
+ # Wrap block into nn.Sequential
88
+ self.block = nn.Sequential(*self.block)
89
+ self.enc_dim = d_model
90
+
91
+ def forward(self, x):
92
+ return self.block(x)
93
+
94
+
95
+ class DecoderBlock(nn.Module):
96
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
97
+ super().__init__()
98
+ self.block = nn.Sequential(
99
+ Snake1d(input_dim),
100
+ WNConvTranspose1d(
101
+ input_dim,
102
+ output_dim,
103
+ kernel_size=2 * stride,
104
+ stride=stride,
105
+ padding=math.ceil(stride / 2),
106
+ output_padding=stride % 2,
107
+ ),
108
+ ResidualUnit(output_dim, dilation=1),
109
+ ResidualUnit(output_dim, dilation=3),
110
+ ResidualUnit(output_dim, dilation=9),
111
+ )
112
+
113
+ def forward(self, x):
114
+ return self.block(x)
115
+
116
+
117
+ class Decoder(nn.Module):
118
+ def __init__(
119
+ self,
120
+ input_channel,
121
+ channels,
122
+ rates,
123
+ d_out: int = 1,
124
+ ):
125
+ super().__init__()
126
+
127
+ # Add first conv layer
128
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
129
+
130
+ # Add upsampling + MRF blocks
131
+ for i, stride in enumerate(rates):
132
+ input_dim = channels // 2**i
133
+ output_dim = channels // 2 ** (i + 1)
134
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
135
+
136
+ # Add final conv layer
137
+ layers += [
138
+ Snake1d(output_dim),
139
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
140
+ nn.Tanh(),
141
+ ]
142
+
143
+ self.model = nn.Sequential(*layers)
144
+
145
+ def forward(self, x):
146
+ return self.model(x)
147
+
148
+
149
+ class DAC(BaseModel, CodecMixin):
150
+ def __init__(
151
+ self,
152
+ encoder_dim: int = 64,
153
+ encoder_rates: List[int] = [2, 4, 8, 8],
154
+ latent_dim: int = None,
155
+ decoder_dim: int = 1536,
156
+ decoder_rates: List[int] = [8, 8, 4, 2],
157
+ n_codebooks: int = 9,
158
+ codebook_size: int = 1024,
159
+ codebook_dim: Union[int, list] = 8,
160
+ quantizer_dropout: bool = False,
161
+ sample_rate: int = 44100,
162
+ continuous: bool = False,
163
+ ):
164
+ super().__init__()
165
+
166
+ self.encoder_dim = encoder_dim
167
+ self.encoder_rates = encoder_rates
168
+ self.decoder_dim = decoder_dim
169
+ self.decoder_rates = decoder_rates
170
+ self.sample_rate = sample_rate
171
+ self.continuous = continuous
172
+
173
+ if latent_dim is None:
174
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
175
+
176
+ self.latent_dim = latent_dim
177
+
178
+ self.hop_length = np.prod(encoder_rates)
179
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
180
+
181
+ if not continuous:
182
+ self.n_codebooks = n_codebooks
183
+ self.codebook_size = codebook_size
184
+ self.codebook_dim = codebook_dim
185
+ self.quantizer = ResidualVectorQuantize(
186
+ input_dim=latent_dim,
187
+ n_codebooks=n_codebooks,
188
+ codebook_size=codebook_size,
189
+ codebook_dim=codebook_dim,
190
+ quantizer_dropout=quantizer_dropout,
191
+ )
192
+ else:
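+ # Continuous (VAE) mode: instead of residual VQ, project to 2 * latent_dim channels
+ # (mean and log-variance) and sample a diagonal Gaussian latent at encode time.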
193
+ self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
194
+ self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
195
+
196
+ self.decoder = Decoder(
197
+ latent_dim,
198
+ decoder_dim,
199
+ decoder_rates,
200
+ )
201
+ self.sample_rate = sample_rate
202
+ self.apply(init_weights)
203
+
204
+ self.delay = self.get_delay()
205
+
206
+ @property
207
+ def dtype(self):
208
+ """Get the dtype of the model parameters."""
209
+ # Return the dtype of the first parameter found
210
+ for param in self.parameters():
211
+ return param.dtype
212
+ return torch.float32 # fallback
213
+
214
+ @property
215
+ def device(self):
216
+ """Get the device of the model parameters."""
217
+ # Return the device of the first parameter found
218
+ for param in self.parameters():
219
+ return param.device
220
+ return torch.device('cpu') # fallback
221
+
222
+ def preprocess(self, audio_data, sample_rate):
223
+ if sample_rate is None:
224
+ sample_rate = self.sample_rate
225
+ assert sample_rate == self.sample_rate
226
+
227
+ length = audio_data.shape[-1]
228
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
229
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
230
+
231
+ return audio_data
232
+
233
+ def encode(
234
+ self,
235
+ audio_data: torch.Tensor,
236
+ n_quantizers: int = None,
237
+ ):
238
+ """Encode given audio data and return quantized latent codes
239
+
240
+ Parameters
241
+ ----------
242
+ audio_data : Tensor[B x 1 x T]
243
+ Audio data to encode
244
+ n_quantizers : int, optional
245
+ Number of quantizers to use, by default None
246
+ If None, all quantizers are used.
247
+
248
+ Returns
249
+ -------
250
+ dict
251
+ A dictionary with the following keys:
252
+ "z" : Tensor[B x D x T]
253
+ Quantized continuous representation of input
254
+ "codes" : Tensor[B x N x T]
255
+ Codebook indices for each codebook
256
+ (quantized discrete representation of input)
257
+ "latents" : Tensor[B x N*D x T]
258
+ Projected latents (continuous representation of input before quantization)
259
+ "vq/commitment_loss" : Tensor[1]
260
+ Commitment loss to train encoder to predict vectors closer to codebook
261
+ entries
262
+ "vq/codebook_loss" : Tensor[1]
263
+ Codebook loss to update the codebook
264
+ "length" : int
265
+ Number of samples in input audio
266
+ """
267
+ z = self.encoder(audio_data) # [B x D x T]
268
+ if not self.continuous:
269
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
270
+ else:
271
+ z = self.quant_conv(z) # [B x 2D x T]
272
+ z = DiagonalGaussianDistribution(z)
273
+ codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
274
+
275
+ return z, codes, latents, commitment_loss, codebook_loss
276
+
277
+ def decode(self, z: torch.Tensor):
278
+ """Decode given latent codes and return audio data
279
+
280
+ Parameters
281
+ ----------
282
+ z : Tensor[B x D x T]
283
+ Quantized continuous representation of input
284
+ length : int, optional
285
+ Number of samples in output audio, by default None
286
+
287
+ Returns
288
+ -------
289
+ dict
290
+ A dictionary with the following keys:
291
+ "audio" : Tensor[B x 1 x length]
292
+ Decoded audio data.
293
+ """
294
+ if not self.continuous:
295
+ audio = self.decoder(z)
296
+ else:
297
+ z = self.post_quant_conv(z)
298
+ audio = self.decoder(z)
299
+
300
+ return audio
301
+
302
+ def forward(
303
+ self,
304
+ audio_data: torch.Tensor,
305
+ sample_rate: int = None,
306
+ n_quantizers: int = None,
307
+ ):
308
+ """Model forward pass
309
+
310
+ Parameters
311
+ ----------
312
+ audio_data : Tensor[B x 1 x T]
313
+ Audio data to encode
314
+ sample_rate : int, optional
315
+ Sample rate of audio data in Hz, by default None
316
+ If None, defaults to `self.sample_rate`
317
+ n_quantizers : int, optional
318
+ Number of quantizers to use, by default None.
319
+ If None, all quantizers are used.
320
+
321
+ Returns
322
+ -------
323
+ dict
324
+ A dictionary with the following keys:
325
+ "z" : Tensor[B x D x T]
326
+ Quantized continuous representation of input
327
+ "codes" : Tensor[B x N x T]
328
+ Codebook indices for each codebook
329
+ (quantized discrete representation of input)
330
+ "latents" : Tensor[B x N*D x T]
331
+ Projected latents (continuous representation of input before quantization)
332
+ "vq/commitment_loss" : Tensor[1]
333
+ Commitment loss to train encoder to predict vectors closer to codebook
334
+ entries
335
+ "vq/codebook_loss" : Tensor[1]
336
+ Codebook loss to update the codebook
337
+ "length" : int
338
+ Number of samples in input audio
339
+ "audio" : Tensor[B x 1 x length]
340
+ Decoded audio data.
341
+ """
342
+ length = audio_data.shape[-1]
343
+ audio_data = self.preprocess(audio_data, sample_rate)
344
+ if not self.continuous:
345
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
346
+
347
+ x = self.decode(z)
348
+ return {
349
+ "audio": x[..., :length],
350
+ "z": z,
351
+ "codes": codes,
352
+ "latents": latents,
353
+ "vq/commitment_loss": commitment_loss,
354
+ "vq/codebook_loss": codebook_loss,
355
+ }
356
+ else:
357
+ posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
358
+ z = posterior.sample()
359
+ x = self.decode(z)
360
+
361
+ kl_loss = posterior.kl()
362
+ kl_loss = kl_loss.mean()
363
+
364
+ return {
365
+ "audio": x[..., :length],
366
+ "z": z,
367
+ "kl_loss": kl_loss,
368
+ }
369
+
370
+
371
+ if __name__ == "__main__":
372
+ import numpy as np
373
+ from functools import partial
374
+
375
+ model = DAC().to("cpu")
376
+
377
+ for n, m in model.named_modules():
378
+ o = m.extra_repr()
379
+ p = sum([np.prod(p.size()) for p in m.parameters()])
380
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
381
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
382
+ print(model)
383
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
384
+
385
+ length = 88200 * 2
386
+ x = torch.randn(1, 1, length).to(model.device)
387
+ x.requires_grad_(True)
388
+ x.retain_grad()
389
+
390
+ # Make a forward pass
391
+ out = model(x)["audio"]
392
+ print("Input shape:", x.shape)
393
+ print("Output shape:", out.shape)
394
+
395
+ # Create gradient variable
396
+ grad = torch.zeros_like(out)
397
+ grad[:, :, grad.shape[-1] // 2] = 1
398
+
399
+ # Make a backward pass
400
+ out.backward(grad)
401
+
402
+ # Check non-zero values
403
+ gradmap = x.grad.squeeze(0)
404
+ gradmap = (gradmap != 0).sum(0) # sum across features
405
+ rf = (gradmap != 0).sum()
406
+
407
+ print(f"Receptive field: {rf.item()}")
408
+
409
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
410
+ model.decompress(model.compress(x, verbose=True), verbose=True)
hunyuanvideo_foley/models/dac_vae/model/discriminator.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from audiotools import AudioSignal
5
+ from audiotools import ml
6
+ from audiotools import STFTParams
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ act = kwargs.pop("act", True)
13
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
+ if not act:
15
+ return conv
16
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
+
18
+
19
+ def WNConv2d(*args, **kwargs):
20
+ act = kwargs.pop("act", True)
21
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
+ if not act:
23
+ return conv
24
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
+
26
+
27
+ class MPD(nn.Module):
28
+ def __init__(self, period):
29
+ super().__init__()
30
+ self.period = period
31
+ self.convs = nn.ModuleList(
32
+ [
33
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
+ ]
39
+ )
40
+ self.conv_post = WNConv2d(
41
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
+ )
43
+
44
+ def pad_to_period(self, x):
45
+ t = x.shape[-1]
46
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
+ return x
48
+
49
+ def forward(self, x):
50
+ fmap = []
51
+
52
+ x = self.pad_to_period(x)
53
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
+
55
+ for layer in self.convs:
56
+ x = layer(x)
57
+ fmap.append(x)
58
+
59
+ x = self.conv_post(x)
60
+ fmap.append(x)
61
+
62
+ return fmap
63
+
64
+
65
+ class MSD(nn.Module):
66
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
+ super().__init__()
68
+ self.convs = nn.ModuleList(
69
+ [
70
+ WNConv1d(1, 16, 15, 1, padding=7),
71
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
+ WNConv1d(1024, 1024, 5, 1, padding=2),
76
+ ]
77
+ )
78
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
+ self.sample_rate = sample_rate
80
+ self.rate = rate
81
+
82
+ def forward(self, x):
83
+ x = AudioSignal(x, self.sample_rate)
84
+ x.resample(self.sample_rate // self.rate)
85
+ x = x.audio_data
86
+
87
+ fmap = []
88
+
89
+ for l in self.convs:
90
+ x = l(x)
91
+ fmap.append(x)
92
+ x = self.conv_post(x)
93
+ fmap.append(x)
94
+
95
+ return fmap
96
+
97
+
98
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
+
100
+
101
+ class MRD(nn.Module):
102
+ def __init__(
103
+ self,
104
+ window_length: int,
105
+ hop_factor: float = 0.25,
106
+ sample_rate: int = 44100,
107
+ bands: list = BANDS,
108
+ ):
109
+ """Complex multi-band spectrogram discriminator.
110
+ Parameters
111
+ ----------
112
+ window_length : int
113
+ Window length of STFT.
114
+ hop_factor : float, optional
115
+ Hop factor of the STFT (hop_length = hop_factor * window_length), by default 0.25.
116
+ sample_rate : int, optional
117
+ Sampling rate of audio in Hz, by default 44100
118
+ bands : list, optional
119
+ Bands to run discriminator over.
120
+ """
121
+ super().__init__()
122
+
123
+ self.window_length = window_length
124
+ self.hop_factor = hop_factor
125
+ self.sample_rate = sample_rate
126
+ self.stft_params = STFTParams(
127
+ window_length=window_length,
128
+ hop_length=int(window_length * hop_factor),
129
+ match_stride=True,
130
+ )
131
+
132
+ n_fft = window_length // 2 + 1
133
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
+ self.bands = bands
135
+
136
+ ch = 32
137
+ convs = lambda: nn.ModuleList(
138
+ [
139
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
+ ]
145
+ )
146
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
+
149
+ def spectrogram(self, x):
150
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
+ x = torch.view_as_real(x.stft())
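+ # view_as_real splits the complex STFT into 2 channels (real, imag), matching the
+ # 2 input channels of the first band convolution.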
152
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
+ # Split into bands
154
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
+ return x_bands
156
+
157
+ def forward(self, x):
158
+ x_bands = self.spectrogram(x)
159
+ fmap = []
160
+
161
+ x = []
162
+ for band, stack in zip(x_bands, self.band_convs):
163
+ for layer in stack:
164
+ band = layer(band)
165
+ fmap.append(band)
166
+ x.append(band)
167
+
168
+ x = torch.cat(x, dim=-1)
169
+ x = self.conv_post(x)
170
+ fmap.append(x)
171
+
172
+ return fmap
173
+
174
+
175
+ class Discriminator(ml.BaseModel):
176
+ def __init__(
177
+ self,
178
+ rates: list = [],
179
+ periods: list = [2, 3, 5, 7, 11],
180
+ fft_sizes: list = [2048, 1024, 512],
181
+ sample_rate: int = 44100,
182
+ bands: list = BANDS,
183
+ ):
184
+ """Discriminator that combines multiple discriminators.
185
+
186
+ Parameters
187
+ ----------
188
+ rates : list, optional
189
+ sampling rates (in Hz) to run MSD at, by default []
190
+ If empty, MSD is not used.
191
+ periods : list, optional
192
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
+ fft_sizes : list, optional
194
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
+ sample_rate : int, optional
196
+ Sampling rate of audio in Hz, by default 44100
197
+ bands : list, optional
198
+ Bands to run MRD at, by default `BANDS`
199
+ """
200
+ super().__init__()
201
+ discs = []
202
+ discs += [MPD(p) for p in periods]
203
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
+ self.discriminators = nn.ModuleList(discs)
206
+
207
+ def preprocess(self, y):
208
+ # Remove DC offset
209
+ y = y - y.mean(dim=-1, keepdims=True)
210
+ # Peak normalize the volume of input audio
211
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
+ return y
213
+
214
+ def forward(self, x):
215
+ x = self.preprocess(x)
216
+ fmaps = [d(x) for d in self.discriminators]
217
+ return fmaps
218
+
219
+
220
+ if __name__ == "__main__":
221
+ disc = Discriminator()
222
+ x = torch.zeros(1, 1, 44100)
223
+ results = disc(x)
224
+ for i, result in enumerate(results):
225
+ print(f"disc{i}")
226
+ for i, r in enumerate(result):
227
+ print(r.shape, r.mean(), r.min(), r.max())
228
+ print()
hunyuanvideo_foley/models/dac_vae/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
hunyuanvideo_foley/models/dac_vae/nn/layers.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
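+ # Snake activation: x + (1 / alpha) * sin^2(alpha * x); the 1e-9 keeps the
+ # reciprocal finite when alpha is very small.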
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
hunyuanvideo_foley/models/dac_vae/nn/loss.py ADDED
@@ -0,0 +1,368 @@
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
+ class L1Loss(nn.L1Loss):
12
+ """L1 Loss between AudioSignals. Defaults
13
+ to comparing ``audio_data``, but any
14
+ attribute of an AudioSignal can be used.
15
+
16
+ Parameters
17
+ ----------
18
+ attribute : str, optional
19
+ Attribute of signal to compare, defaults to ``audio_data``.
20
+ weight : float, optional
21
+ Weight of this loss, defaults to 1.0.
22
+
23
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
+ """
25
+
26
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
+ self.attribute = attribute
28
+ self.weight = weight
29
+ super().__init__(**kwargs)
30
+
31
+ def forward(self, x: AudioSignal, y: AudioSignal):
32
+ """
33
+ Parameters
34
+ ----------
35
+ x : AudioSignal
36
+ Estimate AudioSignal
37
+ y : AudioSignal
38
+ Reference AudioSignal
39
+
40
+ Returns
41
+ -------
42
+ torch.Tensor
43
+ L1 loss between AudioSignal attributes.
44
+ """
45
+ if isinstance(x, AudioSignal):
46
+ x = getattr(x, self.attribute)
47
+ y = getattr(y, self.attribute)
48
+ return super().forward(x, y)
49
+
50
+
51
+ class SISDRLoss(nn.Module):
52
+ """
53
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
+ of estimated and reference audio signals or aligned features.
55
+
56
+ Parameters
57
+ ----------
58
+ scaling : int, optional
59
+ Whether to use scale-invariant (True) or
60
+ signal-to-noise ratio (False), by default True
61
+ reduction : str, optional
62
+ How to reduce across the batch (either 'mean',
63
+ 'sum', or none).], by default ' mean'
64
+ zero_mean : int, optional
65
+ Zero mean the references and estimates before
66
+ computing the loss, by default True
67
+ clip_min : int, optional
68
+ The minimum possible loss value. Helps network
69
+ to not focus on making already good examples better, by default None
70
+ weight : float, optional
71
+ Weight of this loss, defaults to 1.0.
72
+
73
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ scaling: int = True,
79
+ reduction: str = "mean",
80
+ zero_mean: int = True,
81
+ clip_min: int = None,
82
+ weight: float = 1.0,
83
+ ):
84
+ self.scaling = scaling
85
+ self.reduction = reduction
86
+ self.zero_mean = zero_mean
87
+ self.clip_min = clip_min
88
+ self.weight = weight
89
+ super().__init__()
90
+
91
+ def forward(self, x: AudioSignal, y: AudioSignal):
92
+ eps = 1e-8
93
+ # nb, nc, nt
94
+ if isinstance(x, AudioSignal):
95
+ references = x.audio_data
96
+ estimates = y.audio_data
97
+ else:
98
+ references = x
99
+ estimates = y
100
+
101
+ nb = references.shape[0]
102
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
+
105
+ # samples now on axis 1
106
+ if self.zero_mean:
107
+ mean_reference = references.mean(dim=1, keepdim=True)
108
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
109
+ else:
110
+ mean_reference = 0
111
+ mean_estimate = 0
112
+
113
+ _references = references - mean_reference
114
+ _estimates = estimates - mean_estimate
115
+
116
+ references_projection = (_references**2).sum(dim=-2) + eps
117
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
+
119
+ scale = (
120
+ (references_on_estimates / references_projection).unsqueeze(1)
121
+ if self.scaling
122
+ else 1
123
+ )
124
+
125
+ e_true = scale * _references
126
+ e_res = _estimates - e_true
127
+
128
+ signal = (e_true**2).sum(dim=1)
129
+ noise = (e_res**2).sum(dim=1)
130
+ sdr = -10 * torch.log10(signal / noise + eps)
131
+
132
+ if self.clip_min is not None:
133
+ sdr = torch.clamp(sdr, min=self.clip_min)
134
+
135
+ if self.reduction == "mean":
136
+ sdr = sdr.mean()
137
+ elif self.reduction == "sum":
138
+ sdr = sdr.sum()
139
+ return sdr
140
+
141
+
142
+ class MultiScaleSTFTLoss(nn.Module):
143
+ """Computes the multi-scale STFT loss from [1].
144
+
145
+ Parameters
146
+ ----------
147
+ window_lengths : List[int], optional
148
+ Length of each window of each STFT, by default [2048, 512]
149
+ loss_fn : typing.Callable, optional
150
+ How to compare each loss, by default nn.L1Loss()
151
+ clamp_eps : float, optional
152
+ Clamp on the log magnitude, below, by default 1e-5
153
+ mag_weight : float, optional
154
+ Weight of raw magnitude portion of loss, by default 1.0
155
+ log_weight : float, optional
156
+ Weight of log magnitude portion of loss, by default 1.0
157
+ pow : float, optional
158
+ Power to raise magnitude to before taking log, by default 2.0
159
+ weight : float, optional
160
+ Weight of this loss, by default 1.0
161
+ match_stride : bool, optional
162
+ Whether to match the stride of convolutional layers, by default False
163
+
164
+ References
165
+ ----------
166
+
167
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
+ "DDSP: Differentiable Digital Signal Processing."
169
+ International Conference on Learning Representations. 2019.
170
+
171
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ window_lengths: List[int] = [2048, 512],
177
+ loss_fn: typing.Callable = nn.L1Loss(),
178
+ clamp_eps: float = 1e-5,
179
+ mag_weight: float = 1.0,
180
+ log_weight: float = 1.0,
181
+ pow: float = 2.0,
182
+ weight: float = 1.0,
183
+ match_stride: bool = False,
184
+ window_type: str = None,
185
+ ):
186
+ super().__init__()
187
+ self.stft_params = [
188
+ STFTParams(
189
+ window_length=w,
190
+ hop_length=w // 4,
191
+ match_stride=match_stride,
192
+ window_type=window_type,
193
+ )
194
+ for w in window_lengths
195
+ ]
196
+ self.loss_fn = loss_fn
197
+ self.log_weight = log_weight
198
+ self.mag_weight = mag_weight
199
+ self.clamp_eps = clamp_eps
200
+ self.weight = weight
201
+ self.pow = pow
202
+
203
+ def forward(self, x: AudioSignal, y: AudioSignal):
204
+ """Computes multi-scale STFT between an estimate and a reference
205
+ signal.
206
+
207
+ Parameters
208
+ ----------
209
+ x : AudioSignal
210
+ Estimate signal
211
+ y : AudioSignal
212
+ Reference signal
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Multi-scale STFT loss.
218
+ """
219
+ loss = 0.0
220
+ for s in self.stft_params:
221
+ x.stft(s.window_length, s.hop_length, s.window_type)
222
+ y.stft(s.window_length, s.hop_length, s.window_type)
223
+ loss += self.log_weight * self.loss_fn(
224
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
+ )
227
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
+ return loss
229
+
230
+
231
+ class MelSpectrogramLoss(nn.Module):
232
+ """Compute distance between mel spectrograms. Can be used
233
+ in a multi-scale way.
234
+
235
+ Parameters
236
+ ----------
237
+ n_mels : List[int]
238
+ Number of mels per STFT, by default [150, 80],
239
+ window_lengths : List[int], optional
240
+ Length of each window of each STFT, by default [2048, 512]
241
+ loss_fn : typing.Callable, optional
242
+ How to compare each loss, by default nn.L1Loss()
243
+ clamp_eps : float, optional
244
+ Clamp on the log magnitude, below, by default 1e-5
245
+ mag_weight : float, optional
246
+ Weight of raw magnitude portion of loss, by default 1.0
247
+ log_weight : float, optional
248
+ Weight of log magnitude portion of loss, by default 1.0
249
+ pow : float, optional
250
+ Power to raise magnitude to before taking log, by default 2.0
251
+ weight : float, optional
252
+ Weight of this loss, by default 1.0
253
+ match_stride : bool, optional
254
+ Whether to match the stride of convolutional layers, by default False
255
+
256
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_mels: List[int] = [150, 80],
262
+ window_lengths: List[int] = [2048, 512],
263
+ loss_fn: typing.Callable = nn.L1Loss(),
264
+ clamp_eps: float = 1e-5,
265
+ mag_weight: float = 1.0,
266
+ log_weight: float = 1.0,
267
+ pow: float = 2.0,
268
+ weight: float = 1.0,
269
+ match_stride: bool = False,
270
+ mel_fmin: List[float] = [0.0, 0.0],
271
+ mel_fmax: List[float] = [None, None],
272
+ window_type: str = None,
273
+ ):
274
+ super().__init__()
275
+ self.stft_params = [
276
+ STFTParams(
277
+ window_length=w,
278
+ hop_length=w // 4,
279
+ match_stride=match_stride,
280
+ window_type=window_type,
281
+ )
282
+ for w in window_lengths
283
+ ]
284
+ self.n_mels = n_mels
285
+ self.loss_fn = loss_fn
286
+ self.clamp_eps = clamp_eps
287
+ self.log_weight = log_weight
288
+ self.mag_weight = mag_weight
289
+ self.weight = weight
290
+ self.mel_fmin = mel_fmin
291
+ self.mel_fmax = mel_fmax
292
+ self.pow = pow
293
+
294
+ def forward(self, x: AudioSignal, y: AudioSignal):
295
+ """Computes mel loss between an estimate and a reference
296
+ signal.
297
+
298
+ Parameters
299
+ ----------
300
+ x : AudioSignal
301
+ Estimate signal
302
+ y : AudioSignal
303
+ Reference signal
304
+
305
+ Returns
306
+ -------
307
+ torch.Tensor
308
+ Mel loss.
309
+ """
310
+ loss = 0.0
311
+ for n_mels, fmin, fmax, s in zip(
312
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
+ ):
314
+ kwargs = {
315
+ "window_length": s.window_length,
316
+ "hop_length": s.hop_length,
317
+ "window_type": s.window_type,
318
+ }
319
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
+
322
+ loss += self.log_weight * self.loss_fn(
323
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
+ )
326
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
+ return loss
328
+
329
+
330
+ class GANLoss(nn.Module):
331
+ """
332
+ Computes a discriminator loss, given a discriminator on
333
+ generated waveforms/spectrograms compared to ground truth
334
+ waveforms/spectrograms. Computes the loss for both the
335
+ discriminator and the generator in separate functions.
336
+ """
337
+
338
+ def __init__(self, discriminator):
339
+ super().__init__()
340
+ self.discriminator = discriminator
341
+
342
+ def forward(self, fake, real):
343
+ d_fake = self.discriminator(fake.audio_data)
344
+ d_real = self.discriminator(real.audio_data)
345
+ return d_fake, d_real
346
+
347
+ def discriminator_loss(self, fake, real):
348
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
349
+
350
+ loss_d = 0
351
+ for x_fake, x_real in zip(d_fake, d_real):
352
+ loss_d += torch.mean(x_fake[-1] ** 2)
353
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
+ return loss_d
355
+
356
+ def generator_loss(self, fake, real):
357
+ d_fake, d_real = self.forward(fake, real)
358
+
359
+ loss_g = 0
360
+ for x_fake in d_fake:
361
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
+
363
+ loss_feature = 0
364
+
365
+ for i in range(len(d_fake)):
366
+ for j in range(len(d_fake[i]) - 1):
367
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
+ return loss_g, loss_feature
hunyuanvideo_foley/models/dac_vae/nn/quantize.py ADDED
@@ -0,0 +1,262 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from .layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = (
65
+ z_e + (z_q - z_e).detach()
66
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
67
+
68
+ z_q = self.out_proj(z_q)
69
+
70
+ return z_q, commitment_loss, codebook_loss, indices, z_e
71
+
72
+ def embed_code(self, embed_id):
73
+ return F.embedding(embed_id, self.codebook.weight)
74
+
75
+ def decode_code(self, embed_id):
76
+ return self.embed_code(embed_id).transpose(1, 2)
77
+
78
+ def decode_latents(self, latents):
79
+ encodings = rearrange(latents, "b d t -> (b t) d")
80
+ codebook = self.codebook.weight # codebook: (N x D)
81
+
82
+ # L2 normalize encodings and codebook (ViT-VQGAN)
83
+ encodings = F.normalize(encodings)
84
+ codebook = F.normalize(codebook)
85
+
86
+ # Compute euclidean distance with codebook
87
+ dist = (
88
+ encodings.pow(2).sum(1, keepdim=True)
89
+ - 2 * encodings @ codebook.t()
90
+ + codebook.pow(2).sum(1, keepdim=True).t()
91
+ )
92
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
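+ # argmax of -dist == argmin of dist: pick the nearest (L2-normalized) codebook entry per time step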
93
+ z_q = self.decode_code(indices)
94
+ return z_q, indices
95
+
96
+
97
+ class ResidualVectorQuantize(nn.Module):
98
+ """
99
+ Introduced in SoundStream: An end2end neural audio codec
100
+ https://arxiv.org/abs/2107.03312
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ input_dim: int = 512,
106
+ n_codebooks: int = 9,
107
+ codebook_size: int = 1024,
108
+ codebook_dim: Union[int, list] = 8,
109
+ quantizer_dropout: float = 0.0,
110
+ ):
111
+ super().__init__()
112
+ if isinstance(codebook_dim, int):
113
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114
+
115
+ self.n_codebooks = n_codebooks
116
+ self.codebook_dim = codebook_dim
117
+ self.codebook_size = codebook_size
118
+
119
+ self.quantizers = nn.ModuleList(
120
+ [
121
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122
+ for i in range(n_codebooks)
123
+ ]
124
+ )
125
+ self.quantizer_dropout = quantizer_dropout
126
+
127
+ def forward(self, z, n_quantizers: int = None):
128
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
129
+ the corresponding codebook vectors
130
+ Parameters
131
+ ----------
132
+ z : Tensor[B x D x T]
133
+ n_quantizers : int, optional
134
+ No. of quantizers to use
135
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
137
+ when in training mode, and a random number of quantizers is used.
138
+ Returns
139
+ -------
140
+ dict
141
+ A dictionary with the following keys:
142
+
143
+ "z" : Tensor[B x D x T]
144
+ Quantized continuous representation of input
145
+ "codes" : Tensor[B x N x T]
146
+ Codebook indices for each codebook
147
+ (quantized discrete representation of input)
148
+ "latents" : Tensor[B x N*D x T]
149
+ Projected latents (continuous representation of input before quantization)
150
+ "vq/commitment_loss" : Tensor[1]
151
+ Commitment loss to train encoder to predict vectors closer to codebook
152
+ entries
153
+ "vq/codebook_loss" : Tensor[1]
154
+ Codebook loss to update the codebook
155
+ """
156
+ z_q = 0
157
+ residual = z
158
+ commitment_loss = 0
159
+ codebook_loss = 0
160
+
161
+ codebook_indices = []
162
+ latents = []
163
+
164
+ if n_quantizers is None:
165
+ n_quantizers = self.n_codebooks
166
+ if self.training:
167
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
170
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
171
+ n_quantizers = n_quantizers.to(z.device)
172
+
173
+ for i, quantizer in enumerate(self.quantizers):
174
+ if self.training is False and i >= n_quantizers:
175
+ break
176
+
177
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178
+ residual
179
+ )
180
+
181
+ # Create mask to apply quantizer dropout
182
+ mask = (
183
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184
+ )
185
+ z_q = z_q + z_q_i * mask[:, None, None]
186
+ residual = residual - z_q_i
187
+
188
+ # Sum losses
189
+ commitment_loss += (commitment_loss_i * mask).mean()
190
+ codebook_loss += (codebook_loss_i * mask).mean()
191
+
192
+ codebook_indices.append(indices_i)
193
+ latents.append(z_e_i)
194
+
195
+ codes = torch.stack(codebook_indices, dim=1)
196
+ latents = torch.cat(latents, dim=1)
197
+
198
+ return z_q, codes, latents, commitment_loss, codebook_loss
199
+
200
+ def from_codes(self, codes: torch.Tensor):
201
+ """Given the quantized codes, reconstruct the continuous representation
202
+ Parameters
203
+ ----------
204
+ codes : Tensor[B x N x T]
205
+ Quantized discrete representation of input
206
+ Returns
207
+ -------
208
+ Tensor[B x D x T]
209
+ Quantized continuous representation of input
210
+ """
211
+ z_q = 0.0
212
+ z_p = []
213
+ n_codebooks = codes.shape[1]
214
+ for i in range(n_codebooks):
215
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
216
+ z_p.append(z_p_i)
217
+
218
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
219
+ z_q = z_q + z_q_i
220
+ return z_q, torch.cat(z_p, dim=1), codes
221
+
222
+ def from_latents(self, latents: torch.Tensor):
223
+ """Given the unquantized latents, reconstruct the
224
+ continuous representation after quantization.
225
+
226
+ Parameters
227
+ ----------
228
+ latents : Tensor[B x N x T]
229
+ Continuous representation of input after projection
230
+
231
+ Returns
232
+ -------
233
+ Tensor[B x D x T]
234
+ Quantized representation of full-projected space
235
+ Tensor[B x D x T]
236
+ Quantized representation of latent space
237
+ """
238
+ z_q = 0
239
+ z_p = []
240
+ codes = []
241
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
242
+
243
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
244
+ 0
245
+ ]
246
+ for i in range(n_codebooks):
247
+ j, k = dims[i], dims[i + 1]
248
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
249
+ z_p.append(z_p_i)
250
+ codes.append(codes_i)
251
+
252
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
253
+ z_q = z_q + z_q_i
254
+
255
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
260
+ x = torch.randn(16, 512, 80)
261
+ y = rvq(x)
262
+ print(y["latents"].shape)
hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
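+ # With other=None this is KL(q || N(0, I)); otherwise the KL between the two diagonal
+ # Gaussians. Both are averaged over the channel and time dimensions (one value per batch item).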
40
+ if self.deterministic:
41
+ return torch.Tensor([0.0])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.mean(
45
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2],
47
+ )
48
+ else:
49
+ return 0.5 * torch.mean(
50
+ torch.pow(self.mean - other.mean, 2) / other.var
51
+ + self.var / other.var
52
+ - 1.0
53
+ - self.logvar
54
+ + other.logvar,
55
+ dim=[1, 2],
56
+ )
57
+
58
+ def nll(self, sample, dims=[1, 2]):
59
+ if self.deterministic:
60
+ return torch.Tensor([0.0])
61
+ logtwopi = np.log(2.0 * np.pi)
62
+ return 0.5 * torch.sum(
63
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
64
+ dim=dims,
65
+ )
66
+
67
+ def mode(self):
68
+ return self.mean
69
+
70
+
71
+ def normal_kl(mean1, logvar1, mean2, logvar2):
72
+ """
73
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
74
+ Compute the KL divergence between two gaussians.
75
+ Shapes are automatically broadcasted, so batches can be compared to
76
+ scalars, among other use cases.
77
+ """
78
+ tensor = None
79
+ for obj in (mean1, logvar1, mean2, logvar2):
80
+ if isinstance(obj, torch.Tensor):
81
+ tensor = obj
82
+ break
83
+ assert tensor is not None, "at least one argument must be a Tensor"
84
+
85
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
86
+ # Tensors, but it does not work for torch.exp().
87
+ logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
88
+
89
+ return 0.5 * (
90
+ -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
91
+ )
hunyuanvideo_foley/models/dac_vae/utils/__init__.py ADDED
@@ -0,0 +1,121 @@
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from audiotools import ml
5
+
6
+ from ..model import DAC
7
+ Accelerator = ml.Accelerator
8
+
9
+ __MODEL_LATEST_TAGS__ = {
10
+ ("44khz", "8kbps"): "0.0.1",
11
+ ("24khz", "8kbps"): "0.0.4",
12
+ ("16khz", "8kbps"): "0.0.5",
13
+ ("44khz", "16kbps"): "1.0.0",
14
+ }
15
+
16
+ __MODEL_URLS__ = {
17
+ (
18
+ "44khz",
19
+ "0.0.1",
20
+ "8kbps",
21
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
22
+ (
23
+ "24khz",
24
+ "0.0.4",
25
+ "8kbps",
26
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
27
+ (
28
+ "16khz",
29
+ "0.0.5",
30
+ "8kbps",
31
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
32
+ (
33
+ "44khz",
34
+ "1.0.0",
35
+ "16kbps",
36
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
37
+ }
38
+
39
+
40
+ @argbind.bind(group="download", positional=True, without_prefix=True)
41
+ def download(
42
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
43
+ ):
44
+ """
45
+ Function that downloads the weights file from URL if a local cache is not found.
46
+
47
+ Parameters
48
+ ----------
49
+ model_type : str
50
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
51
+ model_bitrate: str
52
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
53
+ Only 44khz model supports 16kbps.
54
+ tag : str
55
+ The tag of the model to download. Defaults to "latest".
56
+
57
+ Returns
58
+ -------
59
+ Path
60
+ Directory path required to load model via audiotools.
61
+ """
62
+ model_type = model_type.lower()
63
+ tag = tag.lower()
64
+
65
+ assert model_type in [
66
+ "44khz",
67
+ "24khz",
68
+ "16khz",
69
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
70
+
71
+ assert model_bitrate in [
72
+ "8kbps",
73
+ "16kbps",
74
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
75
+
76
+ if tag == "latest":
77
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
78
+
79
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
80
+
81
+ if download_link is None:
82
+ raise ValueError(
83
+ f"Could not find model with tag {tag} and model type {model_type}"
84
+ )
85
+
86
+ local_path = (
87
+ Path.home()
88
+ / ".cache"
89
+ / "descript"
90
+ / "dac"
91
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
92
+ )
93
+ if not local_path.exists():
94
+ local_path.parent.mkdir(parents=True, exist_ok=True)
95
+
96
+ # Download the model
97
+ import requests
98
+
99
+ response = requests.get(download_link)
100
+
101
+ if response.status_code != 200:
102
+ raise ValueError(
103
+ f"Could not download model. Received response code {response.status_code}"
104
+ )
105
+ local_path.write_bytes(response.content)
106
+
107
+ return local_path
108
+
109
+
110
+ def load_model(
111
+ model_type: str = "44khz",
112
+ model_bitrate: str = "8kbps",
113
+ tag: str = "latest",
114
+ load_path: str = None,
115
+ ):
116
+ if not load_path:
117
+ load_path = download(
118
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
119
+ )
120
+ generator = DAC.load(load_path)
121
+ return generator
hunyuanvideo_foley/models/dac_vae/utils/decode.py ADDED
@@ -0,0 +1,95 @@
1
+ import warnings
2
+ from pathlib import Path
3
+
4
+ import argbind
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from tqdm import tqdm
9
+
10
+ from ..model import DACFile
11
+ from . import load_model
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
+ @argbind.bind(group="decode", positional=True, without_prefix=True)
17
+ @torch.inference_mode()
18
+ @torch.no_grad()
19
+ def decode(
20
+ input: str,
21
+ output: str = "",
22
+ weights_path: str = "",
23
+ model_tag: str = "latest",
24
+ model_bitrate: str = "8kbps",
25
+ device: str = "cuda",
26
+ model_type: str = "44khz",
27
+ verbose: bool = False,
28
+ ):
29
+ """Decode audio from codes.
30
+
31
+ Parameters
32
+ ----------
33
+ input : str
34
+ Path to input directory or file
35
+ output : str, optional
36
+ Path to output directory, by default "".
37
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38
+ weights_path : str, optional
39
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40
+ model_tag and model_type.
41
+ model_tag : str, optional
42
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43
+ model_bitrate: str
44
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45
+ device : str, optional
46
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47
+ model_type : str, optional
48
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49
+ """
50
+ generator = load_model(
51
+ model_type=model_type,
52
+ model_bitrate=model_bitrate,
53
+ tag=model_tag,
54
+ load_path=weights_path,
55
+ )
56
+ generator.to(device)
57
+ generator.eval()
58
+
59
+ # Find all .dac files in input directory
60
+ _input = Path(input)
61
+ input_files = list(_input.glob("**/*.dac"))
62
+
63
+ # If input is a .dac file, add it to the list
64
+ if _input.suffix == ".dac":
65
+ input_files.append(_input)
66
+
67
+ # Create output directory
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(input_files)), desc="Decoding files"):
72
+ # Load file
73
+ artifact = DACFile.load(input_files[i])
74
+
75
+ # Reconstruct audio from codes
76
+ recons = generator.decompress(artifact, verbose=verbose)
77
+
78
+ # Compute output path
79
+ relative_path = input_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = input_files[i]
84
+ output_name = relative_path.with_suffix(".wav").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ # Write to file
89
+ recons.write(output_path)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ args = argbind.parse_args()
94
+ with argbind.scope(args):
95
+ decode()
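A hedged sketch of calling decode() directly from Python rather than through the argbind CLI; the codes/ and recon/ paths are placeholders:

# Decode every .dac artifact under codes/ back to .wav files under recon/.
# Weights are downloaded on first use; pass device="cpu" when no GPU is available.
from hunyuanvideo_foley.models.dac_vae.utils.decode import decode

decode(input="codes", output="recon", model_type="44khz", model_bitrate="8kbps", device="cpu")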
hunyuanvideo_foley/models/dac_vae/utils/encode.py ADDED
@@ -0,0 +1,94 @@
1
+ import math
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import numpy as np
7
+ import torch
8
+ from audiotools import AudioSignal
9
+ from audiotools.core import util
10
+ from tqdm import tqdm
11
+
12
+ from . import load_model
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+
16
+
17
+ @argbind.bind(group="encode", positional=True, without_prefix=True)
18
+ @torch.inference_mode()
19
+ @torch.no_grad()
20
+ def encode(
21
+ input: str,
22
+ output: str = "",
23
+ weights_path: str = "",
24
+ model_tag: str = "latest",
25
+ model_bitrate: str = "8kbps",
26
+ n_quantizers: int = None,
27
+ device: str = "cuda",
28
+ model_type: str = "44khz",
29
+ win_duration: float = 5.0,
30
+ verbose: bool = False,
31
+ ):
32
+ """Encode audio files in input path to .dac format.
33
+
34
+ Parameters
35
+ ----------
36
+ input : str
37
+ Path to input audio file or directory
38
+ output : str, optional
39
+ Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40
+ weights_path : str, optional
41
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42
+ model_tag and model_type.
43
+ model_tag : str, optional
44
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45
+ model_bitrate: str
46
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47
+ n_quantizers : int, optional
48
+ Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49
+ device : str, optional
50
+ Device to use, by default "cuda"
51
+ model_type : str, optional
52
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53
+ """
54
+ generator = load_model(
55
+ model_type=model_type,
56
+ model_bitrate=model_bitrate,
57
+ tag=model_tag,
58
+ load_path=weights_path,
59
+ )
60
+ generator.to(device)
61
+ generator.eval()
62
+ kwargs = {"n_quantizers": n_quantizers}
63
+
64
+ # Find all audio files in input path
65
+ input = Path(input)
66
+ audio_files = util.find_audio(input)
67
+
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72
+ # Load file
73
+ signal = AudioSignal(audio_files[i])
74
+
75
+ # Encode audio to .dac format
76
+ artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77
+
78
+ # Compute output path
79
+ relative_path = audio_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = audio_files[i]
84
+ output_name = relative_path.with_suffix(".dac").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ artifact.save(output_path)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ args = argbind.parse_args()
93
+ with argbind.scope(args):
94
+ encode()
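The same direct-call pattern works for encode(); audio/ and codes/ below are placeholder paths:

# Compress every audio file under audio/ into .dac artifacts under codes/,
# windowing the signal in 5 s chunks (the default win_duration).
from hunyuanvideo_foley.models.dac_vae.utils.encode import encode

encode(input="audio", output="codes", model_type="44khz", model_bitrate="8kbps", device="cpu", win_duration=5.0)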
hunyuanvideo_foley/models/hifi_foley.py ADDED
@@ -0,0 +1,794 @@
1
+ from typing import List, Tuple, Optional, Union, Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from einops.layers.torch import Rearrange
8
+ from diffusers.models import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+
11
+ from .nn.activation_layers import SwiGLU, get_activation_layer
12
+ from .nn.attn_layers import apply_rotary_emb, attention
13
+ from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
14
+ from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
15
+ from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
16
+ from .nn.norm_layers import get_norm_layer
17
+ from .nn.posemb_layers import get_nd_rotary_pos_embed
18
+
19
+ def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
20
+ # [B, N1, H, C] & [B, N2, H, C]
21
+ B, N1, H, C = x1.shape
22
+ B, N2, H, C = x2.shape
23
+ assert x1.ndim == x2.ndim == 4
24
+
25
+ if N1 != N2:
26
+ x2 = x2.view(B, N2, -1).transpose(1, 2)
27
+ x2 = F.interpolate(x2, size=(N1), mode="nearest-exact")
28
+ x2 = x2.transpose(1, 2).view(B, N1, H, C)
29
+ x = torch.stack((x1, x2), dim=2)
30
+ x = x.reshape(B, N1 * 2, H, C)
31
+ return x
32
+
33
+ def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
34
+ B, N, H, C = x.shape
35
+ assert N % 2 == 0 and N // 2 == len1
36
+
37
+ x = x.reshape(B, -1, 2, H, C)
38
+ x1 = x[:, :, 0]
39
+ x2 = x[:, :, 1]
40
+ if x2.shape[1] != len2:
41
+ x2 = x2.view(B, len1, H * C).transpose(1, 2)
42
+ x2 = F.interpolate(x2, size=(len2), mode="nearest-exact")
43
+ x2 = x2.transpose(1, 2).view(B, len2, H, C)
44
+ return x1, x2
45
+
46
+ class TwoStreamCABlock(nn.Module):
47
+ def __init__(
48
+ self,
49
+ hidden_size: int,
50
+ num_heads: int,
51
+ mlp_ratio: float,
52
+ mlp_act_type: str = "gelu_tanh",
53
+ qk_norm: bool = True,
54
+ qk_norm_type: str = "rms",
55
+ qkv_bias: bool = False,
56
+ attn_mode: str = "torch",
57
+ reverse: bool = False,
58
+ interleaved_audio_visual_rope: bool = False,
59
+ dtype: Optional[torch.dtype] = None,
60
+ device: Optional[torch.device] = None,
61
+ ):
62
+ factory_kwargs = {"device": device, "dtype": dtype}
63
+ super().__init__()
64
+
65
+ self.deterministic = False
66
+ self.reverse = reverse
67
+ self.attn_mode = attn_mode
68
+ self.num_heads = num_heads
69
+ self.hidden_size = hidden_size
70
+ head_dim = hidden_size // num_heads
71
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
72
+
73
+ self.interleaved_audio_visual_rope = interleaved_audio_visual_rope
74
+
75
+ # Self attention for audio + visual
76
+ self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
77
+ self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
78
+ self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
79
+ qk_norm_layer = get_norm_layer(qk_norm_type)
80
+ self.audio_self_q_norm = (
81
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
82
+ )
83
+ self.audio_self_k_norm = (
84
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
85
+ )
86
+ self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
87
+
88
+ # visual cond
89
+ self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
90
+ self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
91
+ self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
92
+ self.v_cond_attn_q_norm = (
93
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
94
+ )
95
+ self.v_cond_attn_k_norm = (
96
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
97
+ )
98
+ self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
99
+
100
+ self.max_text_len = 100
101
+ self.rope_dim_list = None
102
+
103
+ # audio and video norm for cross attention with text
104
+ self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
105
+ self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
106
+
107
+ # Cross attention: (video_audio) as query, text as key/value
108
+ self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
109
+ self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
110
+ self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)
111
+
112
+ self.audio_cross_q_norm = (
113
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
114
+ )
115
+ self.v_cond_cross_q_norm = (
116
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
117
+ )
118
+ self.text_cross_k_norm = (
119
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
120
+ )
121
+ self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
122
+ self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
123
+
124
+ # MLPs
125
+ self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
126
+ self.audio_mlp = MLP(
127
+ hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
128
+ )
129
+
130
+ self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
131
+ self.v_cond_mlp = MLP(
132
+ hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
133
+ )
134
+
135
+ def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
136
+ target_ndim = 1 # n-d RoPE
137
+ rope_sizes = [text_len]
138
+
139
+ if rope_dim_list is None:
140
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
141
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim of the attention layer"
142
+
143
+ text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
144
+ rope_dim_list=rope_dim_list,
145
+ start=rope_sizes,
146
+ theta=10000,
147
+ use_real=True,
148
+ theta_rescale_factor=1.0,
149
+ )
150
+ return text_freqs_cos, text_freqs_sin
151
+
152
+ def set_attn_mode(self, new_mode):
153
+ if new_mode != "torch":
154
+ raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
155
+ self.attn_mode = new_mode
156
+
157
+ def enable_deterministic(self):
158
+ self.deterministic = True
159
+
160
+ def disable_deterministic(self):
161
+ self.deterministic = False
162
+
163
+ def forward(
164
+ self,
165
+ audio: torch.Tensor,
166
+ cond: torch.Tensor,
167
+ v_cond: torch.Tensor,
168
+ attn_mask: torch.Tensor,
169
+ vec: torch.Tensor,
170
+ freqs_cis: tuple = None,
171
+ v_freqs_cis: tuple = None,
172
+ sync_vec: torch.Tensor = None,
173
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
174
+ # Get modulation parameters
175
+ if sync_vec is not None:
176
+ assert sync_vec.ndim == 3
177
+ (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
178
+ audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
179
+ audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
180
+ ) = self.audio_mod(sync_vec).chunk(9, dim=-1)
181
+ else:
182
+ (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
183
+ audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
184
+ audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
185
+ ) = self.audio_mod(vec).chunk(9, dim=-1)
186
+
187
+ (
188
+ v_cond_mod1_shift,
189
+ v_cond_mod1_scale,
190
+ v_cond_mod1_gate,
191
+ v_cond_mod2_shift,
192
+ v_cond_mod2_scale,
193
+ v_cond_mod2_gate,
194
+ v_cond_mod3_shift,
195
+ v_cond_mod3_scale,
196
+ v_cond_mod3_gate,
197
+ ) = self.v_cond_mod(vec).chunk(9, dim=-1)
198
+
199
+ # 1. Self Attention for audio + visual
200
+ audio_modulated = self.audio_norm1(audio)
201
+ audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
202
+ audio_qkv = self.audio_self_attn_qkv(audio_modulated)
203
+ audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
204
+ audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
205
+ audio_k = self.audio_self_k_norm(audio_k).to(audio_v)
206
+
207
+ # Prepare visual cond for attention
208
+ v_cond_modulated = self.v_cond_norm1(v_cond)
209
+ v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
210
+ v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
211
+ v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
212
+ v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
213
+ v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)
214
+
215
+ # Apply RoPE if needed for audio and visual
216
+ if freqs_cis is not None:
217
+ if not self.interleaved_audio_visual_rope:
218
+ audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
219
+ audio_q, audio_k = audio_qq, audio_kk
220
+ else:
221
+ ori_audio_len = audio_q.shape[1]
222
+ ori_v_con_len = v_cond_q.shape[1]
223
+ interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
224
+ interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
225
+ interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
226
+ interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
227
+ )
228
+ audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
229
+ interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
230
+ )
231
+ audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
232
+ interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
233
+ )
234
+ audio_q, audio_k = audio_qq, audio_kk
235
+ v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
236
+
237
+ # Apply RoPE to visual if needed and not interleaved
238
+ if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
239
+ v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
240
+ v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
241
+
242
+ # Concatenate for self-attention
243
+ q = torch.cat((v_cond_q, audio_q), dim=1)
244
+ k = torch.cat((v_cond_k, audio_k), dim=1)
245
+ v = torch.cat((v_cond_v, audio_v), dim=1)
246
+
247
+ # Run self-attention
248
+ attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
249
+ v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)
250
+
251
+ # Apply self-attention output to audio and v_cond
252
+ audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
253
+ v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)
254
+
255
+ # 2. Cross Attention: (v_cond, audio) as query, text as key/value
256
+ # audio, v_cond modulation
257
+ audio_modulated = self.audio_norm2(audio)
258
+ audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
259
+ v_cond_modulated = self.v_cond_norm2(v_cond)
260
+ v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)
261
+
262
+ # Prepare audio query
263
+ audio_q = self.audio_cross_q(audio_modulated)
264
+ audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
265
+ audio_q = self.audio_cross_q_norm(audio_q)
266
+
267
+ # Prepare v_cond query
268
+ v_cond_q = self.v_cond_cross_q(v_cond_modulated)
269
+ v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
270
+ v_cond_q = self.v_cond_cross_q_norm(v_cond_q)
271
+
272
+ # Prepare text key/value
273
+ text_kv = self.text_cross_kv(cond)
274
+ text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
275
+ text_k = self.text_cross_k_norm(text_k).to(text_v)
276
+
277
+ # Apply RoPE to (v_cond, audio) query and text key if needed
278
+ head_dim = self.hidden_size // self.num_heads
279
+ audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
280
+ audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
281
+ audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]
282
+
283
+ v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
284
+ v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
285
+ v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]
286
+
287
+ text_len = text_k.shape[1]
288
+
289
+ text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
290
+ rope_dim_list=self.rope_dim_list)
291
+ text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
292
+ text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]
293
+
294
+ # Concat v_cond and audio for cross-attention
295
+ v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)
296
+
297
+ # Run cross-attention
298
+ cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
299
+ v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)
300
+
301
+ # Apply cross-attention output
302
+ audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
303
+ v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)
304
+
305
+ # 3. Apply MLPs
306
+ audio = audio + apply_gate(
307
+ self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
308
+ gate=audio_mod3_gate,
309
+ )
310
+
311
+ # Apply visual MLP
312
+ v_cond = v_cond + apply_gate(
313
+ self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
314
+ gate=v_cond_mod3_gate,
315
+ )
316
+
317
+ return audio, cond, v_cond
318
+
319
+ class SingleStreamBlock(nn.Module):
320
+
321
+ def __init__(self, hidden_size: int,
322
+ num_heads: int,
323
+ mlp_ratio: float,
324
+ qk_norm_type: str = "rms",
325
+ dtype: Optional[torch.dtype] = None,
326
+ device: Optional[torch.device] = None,):
327
+ factory_kwargs = {"device": device, "dtype": dtype}
328
+ super().__init__()
329
+
330
+ self.hidden_size = hidden_size
331
+ self.num_heads = num_heads
332
+
333
+ self.modulation = ModulateDiT(
334
+ hidden_size=hidden_size,
335
+ factor=6,
336
+ act_layer=get_activation_layer("silu"),
337
+ **factory_kwargs,
338
+ )
339
+ self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
340
+ self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
341
+ self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, **factory_kwargs)
342
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
343
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
344
+ self.q_norm = nn.RMSNorm(hidden_size // num_heads)
345
+ self.k_norm = nn.RMSNorm(hidden_size // num_heads)
346
+ self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)
347
+
348
+ def forward(self, x: torch.Tensor, cond: torch.Tensor,freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
349
+ assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
350
+ modulation = self.modulation(cond)
351
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
352
+ x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
353
+
354
+ qkv = self.linear_qkv(x_norm1)
355
+ q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
356
+ q = q.squeeze(-1)
357
+ k = k.squeeze(-1)
358
+ v = v.squeeze(-1)
359
+
360
+ q = self.q_norm(q)
361
+ k = self.k_norm(k)
362
+ q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)
363
+
364
+ q = q.contiguous()
365
+ k = k.contiguous()
366
+ v = v.contiguous()
367
+ out = F.scaled_dot_product_attention(q, k, v)
368
+ out = rearrange(out, 'b h n d -> b n (h d)').contiguous()
369
+
370
+ x = x + apply_gate(self.linear1(out),gate=gate_msa)
371
+ x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
372
+ x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)
373
+
374
+ return x
375
+
376
+ class HunyuanVideoFoley(ModelMixin, ConfigMixin):
377
+ @register_to_config
378
+ def __init__(
379
+ self,
380
+ model_config,
381
+ dtype: Optional[torch.dtype] = None,
382
+ device: Optional[torch.device] = None,
383
+ ):
384
+ factory_kwargs = {"device": device, "dtype": dtype}
385
+ super().__init__()
386
+
387
+ model_args = model_config.model_config.model_kwargs
388
+ self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
389
+ self.depth_single_blocks = model_args.get("depth_single_blocks", 38)
390
+ # Gradient checkpoint.
391
+ self.gradient_checkpoint = False
392
+ self.gradient_checkpoint_layers = None
393
+ if self.gradient_checkpoint:
394
+ assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
395
+ f"Gradient checkpoint layers must be less or equal than the depth of the model. "
396
+ f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
397
+ )
398
+
399
+ self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)
400
+
401
+ # Condition projection. Default to linear projection.
402
+ self.condition_projection = model_args.get("condition_projection", "linear")
403
+ self.condition_dim = model_args.get("condition_dim", None)
404
+ self.use_attention_mask = model_args.get("use_attention_mask", False)
405
+
406
+ self.patch_size = model_args.get("patch_size", 1)
407
+ self.visual_in_channels = model_args.get("clip_dim", 768)
408
+ self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
409
+ self.out_channels = self.audio_vae_latent_dim
410
+ self.unpatchify_channels = self.out_channels
411
+ self.reverse = model_args.get("reverse", False)
412
+
413
+ self.num_heads = model_args.get("num_heads", 24)
414
+ self.hidden_size = model_args.get("hidden_size", 3072)
415
+ self.rope_dim_list = model_args.get("rope_dim_list", None)
416
+ self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
417
+ self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")
418
+
419
+ self.qkv_bias = model_args.get("qkv_bias", True)
420
+ self.qk_norm = model_args.get("qk_norm", True)
421
+ self.qk_norm_type = model_args.get("qk_norm_type", "rms")
422
+ self.attn_mode = model_args.get("attn_mode", "torch")
423
+
424
+ self.embedder_type = model_args.get("embedder_type", "default")
425
+
426
+ # sync condition things
427
+ self.sync_modulation = model_args.get("sync_modulation", False)
428
+ self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
429
+ self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
430
+ self.sync_in_ksz = model_args.get("sync_in_ksz", 1)
431
+
432
+ # condition tokens length
433
+ self.clip_len = model_args.get("clip_length", 64)
434
+ self.sync_len = model_args.get("sync_length", 192)
435
+
436
+ if self.hidden_size % self.num_heads != 0:
437
+ raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")
438
+
439
+ # Build audio patchify layer and visual gated linear projection
440
+ self.patch_size = 1
441
+ self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
442
+ self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)
443
+
444
+ # condition
445
+ if self.condition_projection == "linear":
446
+ self.cond_in = ConditionProjection(
447
+ self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
448
+ )
449
+ else:
450
+ raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")
451
+
452
+ # time modulation
453
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
454
+
455
+ # visual sync embedder if needed
456
+ if self.sync_in_ksz == 1:
457
+ sync_in_padding = 0
458
+ elif self.sync_in_ksz == 3:
459
+ sync_in_padding = 1
460
+ else:
461
+ raise ValueError
462
+ if self.sync_modulation or self.add_sync_feat_to_audio:
463
+ self.sync_in = nn.Sequential(
464
+ nn.Linear(self.sync_feat_dim, self.hidden_size),
465
+ nn.SiLU(),
466
+ ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
467
+ )
468
+ self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))
469
+
470
+ self.triple_blocks = nn.ModuleList(
471
+ [
472
+ TwoStreamCABlock(
473
+ hidden_size=self.hidden_size,
474
+ num_heads=self.num_heads,
475
+ mlp_ratio=self.mlp_ratio,
476
+ mlp_act_type=self.mlp_act_type,
477
+ qk_norm=self.qk_norm,
478
+ qk_norm_type=self.qk_norm_type,
479
+ qkv_bias=self.qkv_bias,
480
+ attn_mode=self.attn_mode,
481
+ reverse=self.reverse,
482
+ interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
483
+ **factory_kwargs,
484
+ )
485
+ for _ in range(self.depth_triple_blocks)
486
+ ]
487
+ )
488
+
489
+
490
+ self.single_blocks = nn.ModuleList(
491
+ [
492
+ SingleStreamBlock(
493
+ hidden_size=self.hidden_size,
494
+ num_heads=self.num_heads,
495
+ mlp_ratio=self.mlp_ratio,
496
+ qk_norm_type=self.qk_norm_type,
497
+ **factory_kwargs,
498
+ )
499
+ for _ in range(self.depth_single_blocks)
500
+ ]
501
+ )
502
+
503
+ self.final_layer = FinalLayer1D(
504
+ self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
505
+ )
506
+ self.unpatchify_channels = self.out_channels
507
+
508
+ self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
509
+ self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
510
+ nn.init.constant_(self.empty_clip_feat, 0)
511
+ nn.init.constant_(self.empty_sync_feat, 0)
512
+
513
+ def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
514
+ if bs is None:
515
+ return self.empty_string_feat
516
+ else:
517
+ return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)
518
+
519
+ def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
520
+ len = len if len is not None else self.clip_len
521
+ if bs is None:
522
+ return self.empty_clip_feat.expand(len, -1) # 15s
523
+ else:
524
+ return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s
525
+
526
+ def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
527
+ len = len if len is not None else self.sync_len
528
+ if bs is None:
529
+ return self.empty_sync_feat.expand(len, -1)
530
+ else:
531
+ return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
532
+
533
+ def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
534
+ assert self.patch_size == 1
535
+ # ======================================== Build RoPE for audio tokens ======================================
536
+ target_ndim = 1 # n-d RoPE
537
+ rope_sizes = [audio_emb_len]
538
+ head_dim = self.hidden_size // self.num_heads
539
+ rope_dim_list = self.rope_dim_list
540
+ if rope_dim_list is None:
541
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
542
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim of the attention layer"
543
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
544
+ rope_dim_list=rope_dim_list,
545
+ start=rope_sizes,
546
+ theta=10000,
547
+ use_real=True,
548
+ theta_rescale_factor=1.0,
549
+ )
550
+
551
+ # ========================== Build RoPE for clip tokens =========================
552
+ target_ndim = 1 # n-d RoPE
553
+ rope_sizes = [visual_cond_len]
554
+ head_dim = self.hidden_size // self.num_heads
555
+ rope_dim_list = self.rope_dim_list
556
+ if rope_dim_list is None:
557
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
558
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim of the attention layer"
559
+ v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
560
+ rope_dim_list=rope_dim_list,
561
+ start=rope_sizes,
562
+ theta=10000,
563
+ use_real=True,
564
+ theta_rescale_factor=1.0,
565
+ freq_scaling=1.0 * audio_emb_len / visual_cond_len,
566
+ )
567
+ return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
568
+
569
+ def build_rope_for_interleaved_audio_visual(self, total_len):
570
+ assert self.patch_size == 1
571
+ # ========================== Build RoPE for audio tokens ========================
572
+ target_ndim = 1 # n-d RoPE
573
+ rope_sizes = [total_len]
574
+ head_dim = self.hidden_size // self.num_heads
575
+ rope_dim_list = self.rope_dim_list
576
+ if rope_dim_list is None:
577
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
578
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim of the attention layer"
579
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
580
+ rope_dim_list=rope_dim_list,
581
+ start=rope_sizes,
582
+ theta=10000,
583
+ use_real=True,
584
+ theta_rescale_factor=1.0,
585
+ )
586
+ return freqs_cos, freqs_sin
587
+
588
+ def set_attn_mode(self, new_mode):
589
+ for block in self.triple_blocks:
590
+ block.set_attn_mode(new_mode)
591
+ for block in self.single_blocks:
592
+ block.set_attn_mode(new_mode)
593
+
594
+ def enable_deterministic(self):
595
+ for block in self.triple_blocks:
596
+ block.enable_deterministic()
597
+ for block in self.single_blocks:
598
+ block.enable_deterministic()
599
+
600
+ def disable_deterministic(self):
601
+ for block in self.triple_blocks:
602
+ block.disable_deterministic()
603
+ for block in self.single_blocks:
604
+ block.disable_deterministic()
605
+
606
+ def forward(
607
+ self,
608
+ x: torch.Tensor,
609
+ t: torch.Tensor, # Should be in range(0, 1000).
610
+ clip_feat: Optional[torch.Tensor] = None,
611
+ cond: torch.Tensor = None,
612
+ audio_mask: Optional[torch.Tensor] = None,
613
+ cond_mask: torch.Tensor = None,
614
+ sync_feat: Optional[torch.Tensor] = None,
615
+ drop_visual: Optional[List[bool]] = None,
616
+ return_dict: bool = True,
617
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
618
+ out = {}
619
+ audio = x
620
+ bs, _, ol = x.shape
621
+ tl = ol // self.patch_size
622
+
623
+ # Prepare learnable empty conditions for visual condition
624
+ if drop_visual is not None:
625
+ clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
626
+ sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)
627
+
628
+ # ========================= Prepare time & visual modulation =========================
629
+ vec = self.time_in(t)
630
+ sync_vec = None
631
+ if self.sync_modulation:
632
+ assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
633
+ sync_feat = sync_feat.view(bs, int(sync_feat.shape[1] / 8), 8, self.sync_feat_dim) + self.sync_pos_emb
634
+ sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels
635
+ sync_vec = self.sync_in(sync_feat) # bs, num_segments * 8, c
636
+ sync_vec = (
637
+ F.interpolate(sync_vec.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
638
+ ) # bs, tl, c
639
+ sync_vec = sync_vec + vec.unsqueeze(1)
640
+ elif self.add_sync_feat_to_audio:
641
+ assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
642
+ sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
643
+ sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels
644
+ sync_feat = self.sync_in(sync_feat) # bs, num_segments * 8, c
645
+ add_sync_feat_to_audio = (
646
+ F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
647
+ ) # bs, tl, c
648
+
649
+ # ========================= Get text, audio and video clip embedding =========================
650
+ cond = self.cond_in(cond)
651
+ cond_seq_len = cond.shape[1]
652
+
653
+ audio = self.audio_embedder(x)
654
+ audio_seq_len = audio.shape[1]
655
+ v_cond = self.visual_proj(clip_feat)
656
+ v_cond_seq_len = v_cond.shape[1]
657
+
658
+ # ========================= Compute attention mask =========================
659
+ attn_mask = None
660
+ if self.use_attention_mask:
661
+ assert cond_mask is not None
662
+ batch_size = audio.shape[0]
663
+ seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len
664
+
665
+ # get default audio_mask and v_cond_mask
666
+ audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
667
+ v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)
668
+
669
+ # batch_size x seq_len
670
+ concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
671
+ # batch_size x 1 x seq_len x seq_len
672
+ attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
673
+ # batch_size x 1 x seq_len x seq_len
674
+ attn_mask_2 = attn_mask_1.transpose(2, 3)
675
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
676
+ attn_mask = (attn_mask_1 & attn_mask_2).bool()
677
+ # avoids self-attention weight being NaN for text padding tokens
678
+ attn_mask[:, :, :, 0] = True
679
+
680
+
681
+ # ========================= Build rope for audio and clip tokens =========================
682
+ if self.interleaved_audio_visual_rope:
683
+ freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
684
+ v_freqs_cos = v_freqs_sin = None
685
+ else:
686
+ freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
687
+ audio_seq_len, v_cond_seq_len
688
+ )
689
+
690
+ # ========================= Pass through DiT blocks =========================
691
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
692
+ v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None
693
+
694
+ if self.add_sync_feat_to_audio:
695
+ add_sync_layer = 0
696
+ assert (
697
+ add_sync_layer < self.depth_triple_blocks
698
+ ), f"The layer to add mel_spectrogram feature and sync feature should in the triple_stream_blocks (n: {self.depth_triple_blocks})."
699
+ # Triple-stream blocks
700
+ for layer_num, block in enumerate(self.triple_blocks):
701
+ if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
702
+ audio = audio + add_sync_feat_to_audio
703
+ triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
704
+ if (
705
+ self.training
706
+ and self.gradient_checkpoint
707
+ and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
708
+ ):
709
+ audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
710
+ ckpt_wrapper(block), *triple_block_args, use_reentrant=False
711
+ )
712
+ else:
713
+ audio, cond, v_cond = block(*triple_block_args)
714
+
715
+ x = audio
716
+ if sync_vec is not None:
717
+ vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
718
+ vec = torch.cat((vec, sync_vec), dim=1)
719
+
720
+ freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
721
+ if self.add_sync_feat_to_audio:
722
+ vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
723
+ if len(self.single_blocks) > 0:
724
+ for layer_num, block in enumerate(self.single_blocks):
725
+ single_block_args = [
726
+ x,
727
+ vec,
728
+ (freqs_cos, freqs_sin),
729
+ ]
730
+ if (
731
+ self.training
732
+ and self.gradient_checkpoint
733
+ and (
734
+ self.gradient_checkpoint_layers == -1
735
+ or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
736
+ )
737
+ ):
738
+ x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
739
+ else:
740
+ x = block(*single_block_args)
741
+
742
+ audio = x
743
+
744
+ # ========================= Final layer =========================
745
+ if sync_vec is not None:
746
+ vec = sync_vec
747
+ audio = self.final_layer(audio, vec) # (N, T, patch_size * out_channels)
748
+ audio = self.unpatchify1d(audio, tl)
749
+
750
+ if return_dict:
751
+ out["x"] = audio
752
+ return out
753
+ return audio
754
+
755
+ def unpatchify1d(self, x, l):
756
+ # x: (N, L, patch_size * C)
757
+ # audio: (N, C, T), T == L * patch_size
758
+ c = self.unpatchify_channels
759
+ p = self.patch_size
760
+ assert l == x.shape[1]
761
+
762
+ x = x.reshape(shape=(x.shape[0], l, p, c))
763
+ x = torch.einsum("ntpc->nctp", x)
764
+ audio = x.reshape(shape=(x.shape[0], c, l * p))
765
+ return audio
766
+
767
+ def params_count(self):
768
+ counts = {
769
+ "triple": sum(
770
+ [
771
+ sum(p.numel() for p in block.audio_cross_q.parameters())
772
+ + sum(p.numel() for p in block.v_cond_cross_q.parameters())
773
+ + sum(p.numel() for p in block.text_cross_kv.parameters())
774
+ + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
775
+ + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
776
+ + sum(p.numel() for p in block.audio_mlp.parameters())
777
+ + sum(p.numel() for p in block.audio_self_proj.parameters())
778
+ + sum(p.numel() for p in block.v_cond_self_proj.parameters())
779
+ + sum(p.numel() for p in block.v_cond_mlp.parameters())
780
+ for block in self.triple_blocks
781
+ ]
782
+ ),
783
+ "single": sum(
784
+ [
785
+ sum(p.numel() for p in block.linear1.parameters())
786
+ + sum(p.numel() for p in block.linear2.parameters())
787
+ for block in self.single_blocks
788
+ ]
789
+ ),
790
+ "total": sum(p.numel() for p in self.parameters()),
791
+ }
792
+
793
+ counts["attn+mlp"] = counts["triple"] + counts["single"]
794
+ return counts
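A small shape sanity-check for the interleaved audio/visual RoPE helpers defined at the top of this file, assuming the package and its torch/einops/diffusers dependencies are installed; all tensor sizes below are illustrative:

import torch
from hunyuanvideo_foley.models.hifi_foley import (
    interleave_two_sequences,
    decouple_interleaved_two_sequences,
)

# Visual tokens are resampled to the audio length, interleaved token-by-token,
# then recovered back to their original lengths.
audio = torch.randn(2, 128, 8, 64)    # [B, N1, H, C]
visual = torch.randn(2, 64, 8, 64)    # [B, N2, H, C], N2 may differ from N1
mixed = interleave_two_sequences(audio, visual)            # [2, 256, 8, 64]
a, v = decouple_interleaved_two_sequences(mixed, 128, 64)  # [2, 128, 8, 64], [2, 64, 8, 64]
print(mixed.shape, a.shape, v.shape)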
hunyuanvideo_foley/models/nn/__init__.py ADDED
File without changes
hunyuanvideo_foley/models/nn/activation_layers.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+ def get_activation_layer(act_type):
5
+ if act_type == "gelu":
6
+ return lambda: nn.GELU()
7
+ elif act_type == "gelu_tanh":
8
+ # Approximate `tanh` requires torch >= 1.13
9
+ return lambda: nn.GELU(approximate="tanh")
10
+ elif act_type == "relu":
11
+ return nn.ReLU
12
+ elif act_type == "silu":
13
+ return nn.SiLU
14
+ else:
15
+ raise ValueError(f"Unknown activation type: {act_type}")
16
+
17
+ class SwiGLU(nn.Module):
18
+ def __init__(
19
+ self,
20
+ dim: int,
21
+ hidden_dim: int,
22
+ out_dim: int,
23
+ ):
24
+ """
25
+ Initialize the SwiGLU FeedForward module.
26
+
27
+ Args:
28
+ dim (int): Input dimension.
29
+ hidden_dim (int): Hidden dimension of the feedforward layer.
30
+
31
+ Attributes:
32
+ w1: Linear transformation for the first layer.
33
+ w2: Linear transformation for the second layer.
34
+ w3: Linear transformation for the third layer.
35
+
36
+ """
37
+ super().__init__()
38
+
39
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
40
+ self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
41
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
42
+
43
+ def forward(self, x):
44
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
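SwiGLU here is the gated feed-forward w2(silu(w1(x)) * w3(x)) that hifi_foley.py uses as its visual projection; a quick shape check with illustrative sizes:

import torch
from hunyuanvideo_foley.models.nn.activation_layers import SwiGLU

proj = SwiGLU(dim=768, hidden_dim=3072, out_dim=3072)   # e.g. CLIP features -> hidden_size
clip_tokens = torch.randn(2, 64, 768)                   # [B, L, dim]
print(proj(clip_tokens).shape)                          # torch.Size([2, 64, 3072])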
hunyuanvideo_foley/models/nn/attn_layers.py ADDED
@@ -0,0 +1,546 @@
1
+ import importlib.metadata
2
+ import math
3
+ from typing import Tuple, Union
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+ try:
10
+ from flash_attn import (
11
+ flash_attn_qkvpacked_func,
12
+ flash_attn_kvpacked_func,
13
+ flash_attn_varlen_kvpacked_func,
14
+ flash_attn_varlen_qkvpacked_func,
15
+ )
16
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
17
+ except ImportError:
18
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None
19
+ index_first_axis = None
20
+ from packaging import version
21
+ from transformers.utils.import_utils import _is_package_available
22
+
23
+ from .norm_layers import get_norm_layer
24
+
25
+ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
26
+ """
27
+ Reshape frequency tensor for broadcasting it with another tensor.
28
+
29
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
30
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
31
+
32
+ Notes:
33
+ When using FlashMHAModified, head_first should be False.
34
+ When using Attention, head_first should be True.
35
+
36
+ Args:
37
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
38
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
39
+ head_first (bool): head dimension first (except batch dim) or not.
40
+
41
+ Returns:
42
+ torch.Tensor: Reshaped frequency tensor.
43
+
44
+ Raises:
45
+ AssertionError: If the frequency tensor doesn't match the expected shape.
46
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
47
+ """
48
+ ndim = x.ndim
49
+ assert 0 <= 1 < ndim
50
+
51
+ if isinstance(freqs_cis, tuple):
52
+ # freqs_cis: (cos, sin) in real space
53
+ if head_first:
54
+ assert freqs_cis[0].shape == (
55
+ x.shape[-2],
56
+ x.shape[-1],
57
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
58
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
59
+ else:
60
+ assert freqs_cis[0].shape == (
61
+ x.shape[1],
62
+ x.shape[-1],
63
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
64
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
65
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
66
+ else:
67
+ # freqs_cis: values in complex space
68
+ if head_first:
69
+ assert freqs_cis.shape == (
70
+ x.shape[-2],
71
+ x.shape[-1],
72
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
73
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
74
+ else:
75
+ assert freqs_cis.shape == (
76
+ x.shape[1],
77
+ x.shape[-1],
78
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
79
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
80
+ return freqs_cis.view(*shape)
81
+
82
+
83
+ def rotate_half(x):
84
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
85
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
86
+
87
+
88
+ def apply_rotary_emb(
89
+ xq: torch.Tensor,
90
+ xk: torch.Tensor,
91
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
92
+ head_first: bool = False,
93
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
94
+ """
95
+ Apply rotary embeddings to input tensors using the given frequency tensor.
96
+
97
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
98
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
99
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
100
+ returned as real tensors.
101
+
102
+ Args:
103
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
104
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
105
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
106
+ head_first (bool): head dimension first (except batch dim) or not.
107
+
108
+ Returns:
109
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
110
+
111
+ """
112
+ xk_out = None
113
+ if isinstance(freqs_cis, tuple):
114
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
115
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
116
+ # real * cos - imag * sin
117
+ # imag * cos + real * sin
118
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
119
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
120
+ else:
121
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
122
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
123
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
124
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
125
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
126
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
127
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
128
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
129
+
130
+ return xq_out, xk_out
131
+
132
+
133
+ class BasicAttentionLayer(nn.Module):
134
+ def __init__(self, attn_mode="flash", deterministic=False):
135
+ super().__init__()
136
+ self.attn_mode = attn_mode
137
+ self.deterministic = deterministic
138
+
139
+ def set_attn_mode(self, new_mode):
140
+ self.attn_mode = new_mode
141
+
142
+ def enable_deterministic(self):
143
+ self.deterministic = True
144
+
145
+ def disable_deterministic(self):
146
+ self.deterministic = False
147
+
148
+
149
+ MEMORY_LAYOUT = {
150
+ "self_flash": (
151
+ lambda x: x,
152
+ lambda x: x,
153
+ ),
154
+ "cross_flash": (
155
+ lambda x: x,
156
+ lambda x: x,
157
+ ),
158
+ "flash_torch_sp": (
159
+ lambda x: x,
160
+ lambda x: x,
161
+ ),
162
+ "torch": (
163
+ lambda x: x.transpose(1, 2),
164
+ lambda x: x.transpose(1, 2),
165
+ ),
166
+ "vanilla": (
167
+ lambda x: x.transpose(1, 2),
168
+ lambda x: x.transpose(1, 2),
169
+ ),
170
+ }
171
+
172
+
173
+ # Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
174
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
175
+ """
176
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
177
+
178
+ Arguments:
179
+ attention_mask (`torch.Tensor`):
180
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
181
+
182
+ Return:
183
+ indices (`torch.Tensor):
184
+ The indices of non-masked tokens from the flattened input sequence.
185
+ cu_seqlens (`torch.Tensor`):
186
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
187
+ max_seqlen_in_batch (`int`):
188
+ Maximum sequence length in batch.
189
+ """
190
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
191
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
192
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
193
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
194
+ return (
195
+ indices,
196
+ cu_seqlens,
197
+ max_seqlen_in_batch,
198
+ )
199
+
200
+
201
+ # Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
202
+ def is_flash_attn_greater_or_equal(library_version: str):
203
+ if not _is_package_available("flash_attn"):
204
+ return False
205
+
206
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
207
+
208
+
209
+ def get_kv_seqlens_with_mask(attn_mask, k, v):
210
+ indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
211
+ b, s1, a, d = k.shape
212
+ k = index_first_axis(k.reshape(b * s1, a, d), indices_k)
213
+ v = index_first_axis(v.reshape(b * s1, a, d), indices_k)
214
+ kv = torch.stack([k, v], dim=1)
215
+ return cu_seqlens_k, max_seqlen_k, kv
216
+
217
+
218
+ def get_q_seqlens(q):
219
+ bs, s, a, d = q.shape
220
+ cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device)
221
+ q = q.reshape(bs * s, a, d)
222
+ return cu_seqlens_q, s, q
223
+
224
+ def flash_attn_no_pad(
225
+ qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None
226
+ ):
227
+ # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
228
+ batch_size = qkv.shape[0]
229
+ seqlen = qkv.shape[1]
230
+ nheads = qkv.shape[-2]
231
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
232
+ # x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch
233
+ # x_unpad, indices, cu_seqlens, max_s
234
+ unpad_results = unpad_input(
235
+ x, key_padding_mask
236
+ )
237
+
238
+ if len(unpad_results) == 4:
239
+ x_unpad, indices, cu_seqlens, max_s = unpad_results
240
+ elif len(unpad_results) == 5:
241
+ x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results
242
+ else:
243
+ raise ValueError
244
+
245
+ x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
246
+ output_unpad = flash_attn_varlen_qkvpacked_func(
247
+ x_unpad,
248
+ cu_seqlens,
249
+ max_s,
250
+ dropout_p,
251
+ softmax_scale=softmax_scale,
252
+ causal=causal,
253
+ )
254
+ output = rearrange(
255
+ pad_input(
256
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
257
+ ),
258
+ "b s (h d) -> b s h d",
259
+ h=nheads,
260
+ )
261
+ return output
262
+
263
+
264
+ def attention(
265
+ q,
266
+ k,
267
+ v,
268
+ mode,
269
+ drop_rate=0,
270
+ attn_mask=None,
271
+ cond_mask=None,
272
+ causal=False,
273
+ deterministic=False,
274
+ cu_seqlens=None,
275
+ max_seqlen=None,
276
+ cu_seqlens_k=None,
277
+ max_seqlen_k=None,
278
+ img_seq_len=None,
279
+ ):
280
+ """
281
+ Perform QKV self attention.
282
+
283
+ Args:
284
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
285
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
286
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
287
+ mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
288
+ drop_rate (float): Dropout rate in attention map. (default: 0)
289
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
290
+ (default: None)
291
+ causal (bool): Whether to use causal attention. (default: False)
292
+ deterministic (bool): Whether to use deterministic attention. (default: False)
293
+ cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
294
+ used to index into q.
295
+ max_seqlen (int): The maximum sequence length in the batch of q.
296
+ cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
297
+ used to index into kv.
298
+ max_seqlen_k (int): The maximum sequence length in the batch of k and v.
299
+
300
+ Returns:
301
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
302
+ """
303
+ if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
304
+ if isinstance(q, tuple):
305
+ q = torch.cat(q, dim=1)
306
+ if isinstance(k, tuple):
307
+ k = torch.cat(k, dim=1)
308
+ if isinstance(v, tuple):
309
+ v = torch.cat(v, dim=1)
310
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
311
+ q = pre_attn_layout(q)
312
+ k = pre_attn_layout(k)
313
+ v = pre_attn_layout(v)
314
+
315
+ if "flash" in mode:
316
+ assert (
317
+ flash_attn_qkvpacked_func is not None
318
+ ), "Flash attention is not available. Please install flash_attn first."
319
+ flash_kwargs = dict(dropout_p=drop_rate, causal=causal)
320
+ if deterministic:
321
+ if not is_flash_attn_greater_or_equal("2.4.1"):
322
+ raise ValueError(
323
+ "Flash attention deterministic mode requires flash_attn>=2.4.1. " "Please upgrade flash_attn"
324
+ )
325
+ flash_kwargs["deterministic"] = deterministic
326
+
327
+ if mode == "self_flash":
328
+ qkv = torch.stack([q, k, v], dim=2)
329
+ if attn_mask is not None:
330
+ raise ValueError("Self attention does not support attention mask")
331
+ x = flash_attn_qkvpacked_func(qkv, **flash_kwargs)
332
+
333
+ elif mode == "cross_flash":
334
+ kv = torch.stack([k, v], dim=2)
335
+ if attn_mask is None:
336
+ x = flash_attn_kvpacked_func(q, kv, **flash_kwargs)
337
+ else:
338
+ b, s, a, h = q.shape
339
+ cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q)
340
+ cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v)
341
+
342
+ attn_output = flash_attn_varlen_kvpacked_func(
343
+ q,
344
+ kv,
345
+ cu_seqlens_q=cu_seqlens_q,
346
+ cu_seqlens_k=cu_seqlens_k,
347
+ max_seqlen_q=max_seqlen_q,
348
+ max_seqlen_k=max_seqlen_k,
349
+ **flash_kwargs,
350
+ )
351
+ x = attn_output.reshape(b, s, a, h)
352
+ elif mode == 'torch':
353
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
354
+ attn_mask = attn_mask.to(q.dtype)
355
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
356
+
357
+ elif mode == "vanilla":
358
+ scale_factor = 1 / math.sqrt(q.size(-1))
359
+
360
+ b, a, s, _ = q.shape
361
+ s1 = k.size(2)
362
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
363
+ if causal:
364
+ # Only applied to self attention
365
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
366
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
367
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
368
+ attn_bias = attn_bias.to(q.dtype)  # assign back: Tensor.to() is not in-place
369
+
370
+ if attn_mask is not None:
371
+ if attn_mask.dtype == torch.bool:
372
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
373
+ else:
374
+ attn_bias += attn_mask
375
+
376
+ # TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow
377
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
378
+ attn += attn_bias
379
+ attn = attn.softmax(dim=-1)
380
+ attn = torch.dropout(attn, p=drop_rate, train=True)
381
+ x = attn @ v
382
+ else:
383
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
384
+
385
+ if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
386
+ x = post_attn_layout(x).contiguous()
387
+ b, s, a, d = x.shape
388
+ out = x.reshape(b, s, -1)
389
+ return out
390
+
391
+
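A self-contained sketch of the `mode == "torch"` branch above (assuming MEMORY_LAYOUT transposes between [b, s, a, d] and [b, a, s, d], which is the layout SDPA expects); it is not the module's code:

```python
import torch
import torch.nn.functional as F

# Mirrors the 'torch' path with random tensors.
b, s, a, d = 2, 16, 4, 32
q, k, v = (torch.randn(b, s, a, d) for _ in range(3))

x = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # [b, a, s, d]
)
out = x.transpose(1, 2).contiguous().reshape(b, s, -1)        # [b, s, a * d]
assert out.shape == (b, s, a * d)
```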
392
+ class SelfAttentionLayer(BasicAttentionLayer):
393
+ def __init__(
394
+ self,
395
+ dim,
396
+ num_heads,
397
+ qkv_bias=True,
398
+ qk_norm=True,
399
+ attn_drop=0,
400
+ proj_drop=0,
401
+ dtype=None,
402
+ device=None,
403
+ norm_type="layer",
404
+ attn_mode="self_flash",
405
+ deterministic=False,
406
+ ) -> None:
407
+ factory_kwargs = {"device": device, "dtype": dtype}
408
+ super().__init__(attn_mode, deterministic)
409
+ self.dim = dim
410
+ self.num_heads = num_heads
411
+ assert self.dim % num_heads == 0, "dim must be divisible by num_heads"
412
+ self.head_dim = self.dim // num_heads
413
+ self.attn_drop = attn_drop
414
+
415
+ # This assertion is aligned with flash attention
416
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only head_dim <= 128 and divisible by 8 is supported"
417
+
418
+ self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
419
+
420
+ norm_layer = get_norm_layer(norm_type)
421
+ self.q_norm = (
422
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
423
+ )
424
+ self.k_norm = (
425
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
426
+ )
427
+
428
+ self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
429
+ self.proj_drop = nn.Dropout(proj_drop)
430
+
431
+ def forward(self, x, freqs_cis=None, attn_mask=None):
432
+ """
433
+ Args:
434
+ x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
435
+ freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image
436
+ attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
437
+ """
438
+ b, s, d = x.shape
439
+
440
+ # Apply QKV projection
441
+ qkv = self.Wqkv(x)
442
+ qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, a, d]
443
+ q, k, v = qkv.unbind(dim=2) # [b, s, a, d]
444
+
445
+ # Apply QK-Norm if needed
446
+ q = self.q_norm(q)
447
+ k = self.k_norm(k)
448
+
449
+ # Apply RoPE if needed
450
+ if freqs_cis is not None:
451
+ qq, kk = apply_rotary_emb(q, k, freqs_cis)
452
+ assert (
453
+ qq.shape == q.shape and kk.shape == k.shape
454
+ ), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}"
455
+ q, k = qq, kk
456
+
457
+ # Apply self attention
458
+ context = attention(
459
+ q,
460
+ k,
461
+ v,
462
+ drop_rate=self.attn_drop if self.training else 0,
463
+ attn_mask=attn_mask,
464
+ mode=self.attn_mode,
465
+ deterministic=self.deterministic,
466
+ )
467
+ out = self.out_proj(context)
468
+ out = self.proj_drop(out)
469
+
470
+ return out
471
+
472
+
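A hypothetical usage sketch (assuming the layer is importable and that `attn_mode="torch"` is accepted by `BasicAttentionLayer`, which avoids the flash-attn dependency):

```python
import torch

# Hypothetical usage; head_dim = 256 / 4 = 64 satisfies the assertions above.
layer = SelfAttentionLayer(dim=256, num_heads=4, attn_mode="torch")
x = torch.randn(2, 16, 256)   # (batch, seq_len, hidden_dim)
out = layer(x)                # (2, 16, 256)
```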
473
+ class CrossAttentionLayer(BasicAttentionLayer):
474
+ def __init__(
475
+ self,
476
+ qdim,
477
+ kdim,
478
+ num_heads,
479
+ qkv_bias=True,
480
+ qk_norm=True,
481
+ attn_drop=0,
482
+ proj_drop=0,
483
+ dtype=None,
484
+ device=None,
485
+ norm_type="layer",
486
+ attn_mode="cross_flash",
487
+ deterministic=False,
488
+ ):
489
+ factory_kwargs = {"device": device, "dtype": dtype}
490
+ super().__init__(attn_mode, deterministic)
491
+ self.qdim = qdim
492
+ self.kdim = kdim
493
+ self.num_heads = num_heads
494
+ assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads"
495
+ self.head_dim = self.qdim // num_heads
496
+ self.attn_drop = attn_drop
497
+
498
+ # This assertion is aligned with flash attention
499
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only head_dim <= 128 and divisible by 8 is supported"
500
+
501
+ self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
502
+ self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
503
+
504
+ norm_layer = get_norm_layer(norm_type)
505
+ self.q_norm = (
506
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
507
+ )
508
+ self.k_norm = (
509
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
510
+ )
511
+
512
+ self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
513
+ self.proj_drop = nn.Dropout(proj_drop)
514
+
515
+ def forward(self, x, y, attn_mask=None):
516
+ """
517
+ Args:
518
+ x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
519
+ y (torch.Tensor): (batch, seq_len1, hidden_dim1)
520
+ attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
521
+ """
522
+ b, s, d = x.shape
523
+ _, s1, d1 = y.shape
524
+
525
+ q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
526
+ kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim)
527
+ k, v = kv.unbind(dim=2)
528
+
529
+ # Apply QK-Norm if needed
530
+ q = self.q_norm(q)
531
+ k = self.k_norm(k)
532
+
533
+ # Apply cross attention
534
+ context = attention(
535
+ q,
536
+ k,
537
+ v,
538
+ attn_mask=attn_mask,
539
+ drop_rate=self.attn_drop if self.training else 0,
540
+ mode=self.attn_mode,
541
+ deterministic=self.deterministic,
542
+ )
543
+ out = self.out_proj(context)
544
+ out = self.proj_drop(out)
545
+
546
+ return out
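And a matching hypothetical sketch for cross attention, where a second modality provides the keys and values (again with `attn_mode="torch"` assumed accepted):

```python
import torch

# Hypothetical usage: query tokens attend to condition tokens.
xattn = CrossAttentionLayer(qdim=256, kdim=512, num_heads=4, attn_mode="torch")
x = torch.randn(2, 16, 256)   # queries, e.g. audio latents
y = torch.randn(2, 8, 512)    # keys/values, e.g. text or video features
out = xattn(x, y)             # (2, 16, 256)
```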
hunyuanvideo_foley/models/nn/embed_layers.py ADDED
@@ -0,0 +1,136 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from ...utils.helper import to_1tuple
6
+
7
+ class PatchEmbed1D(nn.Module):
8
+ """1D Audio to Patch Embedding
9
+
10
+ A convolution based approach to patchifying a 1D audio w/ embedding projection.
11
+
12
+ Based on the impl in https://github.com/google-research/vision_transformer
13
+
14
+ Hacked together by / Copyright 2020 Ross Wightman
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ patch_size=1,
20
+ in_chans=768,
21
+ embed_dim=768,
22
+ norm_layer=None,
23
+ flatten=True,
24
+ bias=True,
25
+ dtype=None,
26
+ device=None,
27
+ ):
28
+ factory_kwargs = {"dtype": dtype, "device": device}
29
+ super().__init__()
30
+ patch_size = to_1tuple(patch_size)
31
+ self.patch_size = patch_size
32
+ self.flatten = flatten
33
+
34
+ self.proj = nn.Conv1d(
35
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
36
+ )
37
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
38
+ if bias:
39
+ nn.init.zeros_(self.proj.bias)
40
+
41
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
42
+
43
+ def forward(self, x):
44
+ assert (
45
+ x.shape[2] % self.patch_size[0] == 0
46
+ ), f"The patch_size of {self.patch_size[0]} must be divisible by the token number ({x.shape[2]}) of x."
47
+
48
+ x = self.proj(x)
49
+ if self.flatten:
50
+ x = x.transpose(1, 2) # BCN -> BNC
51
+ x = self.norm(x)
52
+ return x
53
+
54
+
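A shape sketch of the patchify step (a plain Conv1d stand-in, not the class itself): every `patch_size` input tokens are merged into one embedding.

```python
import torch
import torch.nn as nn

# Stand-in for PatchEmbed1D.proj: stride == kernel_size == patch_size.
patch_size, in_chans, embed_dim = 2, 768, 768
proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(1, in_chans, 100)      # (B, C, N) with N divisible by patch_size
out = proj(x).transpose(1, 2)          # (B, N // patch_size, embed_dim)
assert out.shape == (1, 50, 768)
```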
55
+ class ConditionProjection(nn.Module):
56
+ """
57
+ Projects condition embeddings with a two-layer MLP. Note that, unlike the PixArt original, this class does not itself apply classifier-free-guidance dropout.
58
+
59
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60
+ """
61
+
62
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63
+ factory_kwargs = {'dtype': dtype, 'device': device}
64
+ super().__init__()
65
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66
+ self.act_1 = act_layer()
67
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68
+
69
+ def forward(self, caption):
70
+ hidden_states = self.linear_1(caption)
71
+ hidden_states = self.act_1(hidden_states)
72
+ hidden_states = self.linear_2(hidden_states)
73
+ return hidden_states
74
+
75
+
76
+ def timestep_embedding(t, dim, max_period=10000):
77
+ """
78
+ Create sinusoidal timestep embeddings.
79
+
80
+ Args:
81
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82
+ dim (int): the dimension of the output.
83
+ max_period (int): controls the minimum frequency of the embeddings.
84
+
85
+ Returns:
86
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87
+
88
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89
+ """
90
+ half = dim // 2
91
+ freqs = torch.exp(
92
+ -math.log(max_period)
93
+ * torch.arange(start=0, end=half, dtype=torch.float32)
94
+ / half
95
+ ).to(device=t.device)
96
+ args = t[:, None].float() * freqs[None]
97
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
98
+ if dim % 2:
99
+ embedding = torch.cat(
100
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
101
+ )
102
+ return embedding
103
+
104
+
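A worked example of the formula above: for `dim=8` there are 4 frequencies `max_period^(-i/4)`, i = 0..3, and the embedding concatenates the cosine and sine halves.

```python
import math
import torch

t = torch.tensor([0.0, 10.0, 100.0])
dim, max_period = 8, 10000
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = t[:, None] * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
assert emb.shape == (3, 8)
# At t=0: cos terms are all 1, sin terms are all 0 -> [1, 1, 1, 1, 0, 0, 0, 0].
```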
105
+ class TimestepEmbedder(nn.Module):
106
+ """
107
+ Embeds scalar timesteps into vector representations.
108
+ """
109
+ def __init__(self,
110
+ hidden_size,
111
+ act_layer,
112
+ frequency_embedding_size=256,
113
+ max_period=10000,
114
+ out_size=None,
115
+ dtype=None,
116
+ device=None
117
+ ):
118
+ factory_kwargs = {'dtype': dtype, 'device': device}
119
+ super().__init__()
120
+ self.frequency_embedding_size = frequency_embedding_size
121
+ self.max_period = max_period
122
+ if out_size is None:
123
+ out_size = hidden_size
124
+
125
+ self.mlp = nn.Sequential(
126
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
127
+ act_layer(),
128
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
129
+ )
130
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
131
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
132
+
133
+ def forward(self, t):
134
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
135
+ t_emb = self.mlp(t_freq)
136
+ return t_emb
hunyuanvideo_foley/models/nn/mlp_layers.py ADDED
@@ -0,0 +1,149 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .modulate_layers import modulate
11
+ from ...utils.helper import to_2tuple
12
+
13
+ class MLP(nn.Module):
14
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
+
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ hidden_channels=None,
20
+ out_features=None,
21
+ act_layer=nn.GELU,
22
+ norm_layer=None,
23
+ bias=True,
24
+ drop=0.0,
25
+ use_conv=False,
26
+ device=None,
27
+ dtype=None,
28
+ ):
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ out_features = out_features or in_channels
32
+ hidden_channels = hidden_channels or in_channels
33
+ bias = to_2tuple(bias)
34
+ drop_probs = to_2tuple(drop)
35
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
+
37
+ self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
38
+ self.act = act_layer()
39
+ self.drop1 = nn.Dropout(drop_probs[0])
40
+ self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
41
+ self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
42
+ self.drop2 = nn.Dropout(drop_probs[1])
43
+
44
+ def forward(self, x):
45
+ x = self.fc1(x)
46
+ x = self.act(x)
47
+ x = self.drop1(x)
48
+ x = self.norm(x)
49
+ x = self.fc2(x)
50
+ x = self.drop2(x)
51
+ return x
52
+
53
+
54
+ # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
55
+ # only used when use_vanilla is True
56
+ class MLPEmbedder(nn.Module):
57
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
58
+ factory_kwargs = {"device": device, "dtype": dtype}
59
+ super().__init__()
60
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
61
+ self.silu = nn.SiLU()
62
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
63
+
64
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
65
+ return self.out_layer(self.silu(self.in_layer(x)))
66
+
67
+
68
+ class LinearWarpforSingle(nn.Module):
69
+ def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
70
+ factory_kwargs = {"device": device, "dtype": dtype}
71
+ super().__init__()
72
+ self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)
73
+
74
+ def forward(self, x, y):
75
+ z = torch.cat([x, y], dim=2)
76
+ return self.fc(z)
77
+
78
+ class FinalLayer1D(nn.Module):
79
+ def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
80
+ factory_kwargs = {"device": device, "dtype": dtype}
81
+ super().__init__()
82
+
83
+ # Just use LayerNorm for the final layer
84
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
85
+ self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
86
+ nn.init.zeros_(self.linear.weight)
87
+ nn.init.zeros_(self.linear.bias)
88
+
89
+ # Here we don't distinguish between the modulate types. Just use the simple one.
90
+ self.adaLN_modulation = nn.Sequential(
91
+ act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
92
+ )
93
+ # Zero-initialize the modulation
94
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
95
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
96
+
97
+ def forward(self, x, c):
98
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
99
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
100
+ x = self.linear(x)
101
+ return x
102
+
103
+
104
+ class ChannelLastConv1d(nn.Conv1d):
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ x = x.permute(0, 2, 1)
108
+ x = super().forward(x)
109
+ x = x.permute(0, 2, 1)
110
+ return x
111
+
112
+
113
+ class ConvMLP(nn.Module):
114
+
115
+ def __init__(
116
+ self,
117
+ dim: int,
118
+ hidden_dim: int,
119
+ multiple_of: int = 256,
120
+ kernel_size: int = 3,
121
+ padding: int = 1,
122
+ device=None,
123
+ dtype=None,
124
+ ):
125
+ """
126
+ Convolutional MLP module.
127
+
128
+ Args:
129
+ dim (int): Input dimension.
130
+ hidden_dim (int): Hidden dimension of the feedforward layer.
131
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
132
+
133
+ Attributes:
134
+ w1: Linear transformation for the first layer.
135
+ w2: Linear transformation for the second layer.
136
+ w3: Linear transformation for the third layer.
137
+
138
+ """
139
+ factory_kwargs = {"device": device, "dtype": dtype}
140
+ super().__init__()
141
+ hidden_dim = int(2 * hidden_dim / 3)
142
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
143
+
144
+ self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
145
+ self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
146
+ self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
147
+
148
+ def forward(self, x):
149
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
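The hidden width follows the LLaMA-style SwiGLU convention: two thirds of the nominal size, rounded up to a multiple of `multiple_of`. A quick check of the arithmetic (values chosen purely for illustration):

```python
# hidden_dim=6000, multiple_of=256:
#   int(2 * 6000 / 3) = 4000 -> 256 * ceil(4000 / 256) = 4096
hidden_dim, multiple_of = 6000, 256
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
assert hidden_dim == 4096
```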
hunyuanvideo_foley/models/nn/modulate_layers.py ADDED
@@ -0,0 +1,49 @@
1
+ from typing import Callable
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class ModulateDiT(nn.Module):
6
+ def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
7
+ factory_kwargs = {"dtype": dtype, "device": device}
8
+ super().__init__()
9
+ self.act = act_layer()
10
+ self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
11
+ # Zero-initialize the modulation
12
+ nn.init.zeros_(self.linear.weight)
13
+ nn.init.zeros_(self.linear.bias)
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ return self.linear(self.act(x))
17
+
18
+
19
+ def modulate(x, shift=None, scale=None):
20
+ if x.ndim == 3:
21
+ shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else shift
22
+ scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else scale
23
+ if scale is None and shift is None:
24
+ return x
25
+ elif shift is None:
26
+ return x * (1 + scale)
27
+ elif scale is None:
28
+ return x + shift
29
+ else:
30
+ return x * (1 + scale) + shift
31
+
32
+
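A minimal numeric sketch of adaLN-style modulation (assuming the `modulate` function above is in scope): per-sample shift/scale vectors broadcast over the sequence dimension, and zero-initialized modulation, as produced by `ModulateDiT` at the start of training, acts as the identity.

```python
import torch

b, s, d = 2, 16, 64
x = torch.randn(b, s, d)
shift = torch.zeros(b, d)   # zero-init, as ModulateDiT produces at step 0
scale = torch.zeros(b, d)
out = modulate(x, shift=shift, scale=scale)
assert torch.allclose(out, x)  # x * (1 + 0) + 0 == x
```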
33
+ def apply_gate(x, gate=None, tanh=False):
34
+ if gate is None:
35
+ return x
36
+ if gate.ndim == 2 and x.ndim == 3:
37
+ gate = gate.unsqueeze(1)
38
+ if tanh:
39
+ return x * gate.tanh()
40
+ else:
41
+ return x * gate
42
+
43
+
44
+ def ckpt_wrapper(module):
45
+ def ckpt_forward(*inputs):
46
+ outputs = module(*inputs)
47
+ return outputs
48
+
49
+ return ckpt_forward
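Presumed usage of `ckpt_wrapper`: a thin closure over a module so its positional inputs can be passed through activation checkpointing.

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
# Recompute block's activations in the backward pass instead of storing them.
y = checkpoint(ckpt_wrapper(block), x, use_reentrant=False)
y.sum().backward()
```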
hunyuanvideo_foley/models/nn/norm_layers.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class RMSNorm(nn.Module):
5
+ def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
6
+ device=None, dtype=None):
7
+ """
8
+ Initialize the RMSNorm normalization layer.
9
+
10
+ Args:
11
+ dim (int): The dimension of the input tensor.
12
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13
+
14
+ Attributes:
15
+ eps (float): A small value added to the denominator for numerical stability.
16
+ weight (nn.Parameter): Learnable scaling parameter.
17
+
18
+ """
19
+ factory_kwargs = {'device': device, 'dtype': dtype}
20
+ super().__init__()
21
+ self.eps = eps
22
+ if elementwise_affine:
23
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
24
+
25
+ def _norm(self, x):
26
+ """
27
+ Apply the RMSNorm normalization to the input tensor.
28
+
29
+ Args:
30
+ x (torch.Tensor): The input tensor.
31
+
32
+ Returns:
33
+ torch.Tensor: The normalized tensor.
34
+
35
+ """
36
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
37
+
38
+ def forward(self, x):
39
+ """
40
+ Forward pass through the RMSNorm layer.
41
+
42
+ Args:
43
+ x (torch.Tensor): The input tensor.
44
+
45
+ Returns:
46
+ torch.Tensor: The output tensor after applying RMSNorm.
47
+
48
+ """
49
+ output = self._norm(x.float()).type_as(x)
50
+ if hasattr(self, "weight"):
51
+ output = output * self.weight
52
+ return output
53
+
54
+
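A numeric sanity check of the formula above (standalone, without the learned weight): each feature vector is divided by its root-mean-square, so the normalized vector has unit mean square.

```python
import torch

x = torch.tensor([[3.0, 4.0]])   # RMS = sqrt((9 + 16) / 2) ≈ 3.5355
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(y.pow(2).mean(-1), torch.ones(1), atol=1e-4)
```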
55
+ def get_norm_layer(norm_layer):
56
+ """
57
+ Get the normalization layer.
58
+
59
+ Args:
60
+ norm_layer (str): The type of normalization layer.
61
+
62
+ Returns:
63
+ norm_layer (nn.Module): The normalization layer.
64
+ """
65
+ if norm_layer == "layer":
66
+ return nn.LayerNorm
67
+ elif norm_layer == "rms":
68
+ return RMSNorm
69
+ else:
70
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")