diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e2d9729aae670eb88273b99ec1e813b08ea59162 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a27ca3f82e8814609adedf4bfae2d1146d13e4ea --- /dev/null +++ b/LICENSE @@ -0,0 +1,77 @@ +TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT +Tencent HunyuanVideo-Foley Release Date: August 28, 2025 +THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW. +By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately. +1. DEFINITIONS. +a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A. +b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein. +c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent. +d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means. +e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use. +f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement. +g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives. +h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service. +i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials. +j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo-Foley released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley]. +k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof. +l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea. +m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You. +n. “including” shall mean including but not limited to. +2. GRANT OF RIGHTS. +We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy. +3. DISTRIBUTION. +You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions: +a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement; +b. You must cause any modified files to carry prominent notices stating that You changed the files; +c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and +d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.” +You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You. +4. ADDITIONAL COMMERCIAL TERMS. +If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights. +5. RULES OF USE. +a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b). +b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof). +c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement. +6. INTELLECTUAL PROPERTY. +a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You. +b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent. +c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works. +d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses. +7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY. +a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto. +b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT. +c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. +8. SURVIVAL AND TERMINATION. +a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. +b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement. +9. GOVERNING LAW AND JURISDICTION. +a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. +b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute. +  +EXHIBIT A +ACCEPTABLE USE POLICY + +Tencent reserves the right to update this Acceptable Use Policy from time to time. +Last modified: November 5, 2024 + +Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives: +1. Outside the Territory; +2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation; +3. To harm Yourself or others; +4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others; +5. To override or circumvent the safety guardrails and safeguards We have put in place; +6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections; +8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement; +9. To intentionally defame, disparage or otherwise harass others; +10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems; +11. To generate or disseminate personal identifiable information with the purpose of harming others; +12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated; +13. To impersonate another individual without consent, authorization, or legal right; +14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance); +15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions; +16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism; +17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics; +18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +19. For military purposes; +20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d6fd0c679f20b45a96043fe61293a5d0baf0e251 --- /dev/null +++ b/app.py @@ -0,0 +1,814 @@ +import os +import tempfile +import gradio as gr +import torch +import torchaudio +from loguru import logger +from typing import Optional, Tuple +import random +import numpy as np + +from hunyuanvideo_foley.utils.model_utils import load_model +from hunyuanvideo_foley.utils.feature_utils import feature_process +from hunyuanvideo_foley.utils.model_utils import denoise_process +from hunyuanvideo_foley.utils.media_utils import merge_audio_video + +# Global variables for model storage +model_dict = None +cfg = None +device = None + +# need to modify the model path +MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/") +CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" + +def setup_device(device_str: str = "auto", gpu_id: int = 0) -> torch.device: + """Setup computing device""" + if device_str == "auto": + if torch.cuda.is_available(): + device = torch.device(f"cuda:{gpu_id}") + logger.info(f"Using CUDA device: {device}") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + logger.info("Using MPS device") + else: + device = torch.device("cpu") + logger.info("Using CPU device") + else: + if device_str == "cuda": + device = torch.device(f"cuda:{gpu_id}") + else: + device = torch.device(device_str) + logger.info(f"Using specified device: {device}") + + return device + +def auto_load_models() -> str: + """Automatically load preset models""" + global model_dict, cfg, device + + try: + if not os.path.exists(MODEL_PATH): + return f"❌ Model file not found: {MODEL_PATH}" + if not os.path.exists(CONFIG_PATH): + return f"❌ Config file not found: {CONFIG_PATH}" + + # Use GPU by default + device = setup_device("auto", 0) + + # Load model + logger.info("Auto-loading model...") + logger.info(f"Model path: {MODEL_PATH}") + logger.info(f"Config path: {CONFIG_PATH}") + + model_dict, cfg = load_model(MODEL_PATH, CONFIG_PATH, device) + + logger.info("✅ Model loaded successfully!") + return "✅ Model loaded successfully!" + + except Exception as e: + logger.error(f"Model loading failed: {str(e)}") + return f"❌ Model loading failed: {str(e)}" + +def infer_single_video( + video_file, + text_prompt: str, + guidance_scale: float = 4.5, + num_inference_steps: int = 50, + sample_nums: int = 1 +) -> Tuple[list, str]: + """Single video inference""" + global model_dict, cfg, device + + if model_dict is None or cfg is None: + return [], "❌ Please load the model first!" + + if video_file is None: + return [], "❌ Please upload a video file!" + + # Allow empty text prompt, use empty string if no prompt provided + if text_prompt is None: + text_prompt = "" + text_prompt = text_prompt.strip() + + try: + logger.info(f"Processing video: {video_file}") + logger.info(f"Text prompt: {text_prompt}") + + # Feature processing + visual_feats, text_feats, audio_len_in_s = feature_process( + video_file, + text_prompt, + model_dict, + cfg + ) + + # Denoising process to generate multiple audio samples + # Note: The model now generates sample_nums audio samples per inference + # The denoise_process function returns audio with shape [batch_size, channels, samples] + logger.info(f"Generating {sample_nums} audio samples...") + audio, sample_rate = denoise_process( + visual_feats, + text_feats, + audio_len_in_s, + model_dict, + cfg, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + batch_size=sample_nums + ) + + # Create temporary files to save results + temp_dir = tempfile.mkdtemp() + video_outputs = [] + + # Process each generated audio sample + for i in range(sample_nums): + # Save audio file + audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav") + torchaudio.save(audio_output, audio[i], sample_rate) + + # Merge video and audio + video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4") + merge_audio_video(audio_output, video_file, video_output) + video_outputs.append(video_output) + + logger.info(f"Inference completed! Generated {sample_nums} samples.") + return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully!" + + except Exception as e: + logger.error(f"Inference failed: {str(e)}") + return [], f"❌ Inference failed: {str(e)}" + +def update_video_outputs(video_list, status_msg): + """Update video outputs based on the number of generated samples""" + # Initialize all outputs as None + outputs = [None] * 6 + + # Set values based on generated videos + for i, video_path in enumerate(video_list[:6]): # Max 6 samples + outputs[i] = video_path + + # Return all outputs plus status message + return tuple(outputs + [status_msg]) + +def create_gradio_interface(): + """Create Gradio interface""" + + # Custom CSS for beautiful interface with better contrast + css = """ + .gradio-container { + font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; + background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); + min-height: 100vh; + } + + .main-header { + text-align: center; + padding: 2rem 0; + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + border-radius: 20px; + margin-bottom: 2rem; + box-shadow: 0 8px 32px rgba(0,0,0,0.15); + } + + .main-header h1 { + color: white; + font-size: 3rem; + font-weight: 700; + margin-bottom: 0.5rem; + text-shadow: 0 2px 10px rgba(0,0,0,0.3); + } + + .main-header p { + color: rgba(255, 255, 255, 0.95); + font-size: 1.2rem; + font-weight: 300; + } + + .status-card { + background: white; + border-radius: 15px; + padding: 1rem; + margin-bottom: 1.5rem; + border: 1px solid #e1e5e9; + box-shadow: 0 4px 20px rgba(0,0,0,0.08); + } + + .status-card label { + color: #2d3748 !important; + font-weight: 600 !important; + } + + .usage-guide h3 { + color: #2d3748 !important; + font-weight: 600 !important; + margin-bottom: 0.5rem !important; + } + + .usage-guide p { + color: #4a5568 !important; + font-size: 1rem !important; + line-height: 1.6 !important; + margin: 0.5rem 0 !important; + } + + .usage-guide strong { + color: #1a202c !important; + font-weight: 700 !important; + } + + .usage-guide em { + color: #1a202c !important; + font-weight: 700 !important; + font-style: normal !important; + } + + .main-interface { + margin-bottom: 2rem; + } + + .input-section { + background: white; + border-radius: 20px; + padding: 2rem; + margin-right: 1rem; + box-shadow: 0 8px 32px rgba(0,0,0,0.1); + border: 1px solid #e1e5e9; + } + + .input-section h3 { + color: #2d3748 !important; + font-weight: 600 !important; + margin-bottom: 1rem !important; + } + + .input-section label { + color: #4a5568 !important; + font-weight: 500 !important; + } + + .output-section { + background: white; + border-radius: 20px; + padding: 2rem; + margin-left: 1rem; + box-shadow: 0 8px 32px rgba(0,0,0,0.1); + border: 1px solid #e1e5e9; + } + + .output-section h3 { + color: #2d3748 !important; + font-weight: 600 !important; + margin-bottom: 1rem !important; + } + + .output-section label { + color: #4a5568 !important; + font-weight: 500 !important; + } + + .examples-section h3 { + color: #2d3748 !important; + font-weight: 600 !important; + margin-bottom: 1.5rem !important; + } + + .generate-btn { + background: linear-gradient(45deg, #667eea, #764ba2) !important; + border: none !important; + color: white !important; + font-weight: 600 !important; + font-size: 1.1rem !important; + padding: 12px 30px !important; + border-radius: 25px !important; + box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; + transition: all 0.3s ease !important; + } + + .generate-btn:hover { + transform: translateY(-2px) !important; + box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important; + } + + + + .examples-section { + background: white; + border-radius: 20px; + padding: 2rem; + margin-top: 2rem; + box-shadow: 0 8px 32px rgba(0,0,0,0.1); + border: 1px solid #e1e5e9; + } + + .examples-section p { + color: #4a5568 !important; + margin-bottom: 1rem !important; + } + + .example-row { + background: #f8fafc; + border: 1px solid #e2e8f0; + border-radius: 15px; + padding: 1.5rem; + margin: 1rem 0; + transition: all 0.3s ease; + align-items: center; + } + + .example-row:hover { + border-color: #667eea; + transform: translateY(-2px); + box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15); + } + + .example-row .markdown { + color: #2d3748 !important; + } + + .example-row .markdown p { + color: #2d3748 !important; + margin: 0.5rem 0 !important; + line-height: 1.5 !important; + } + + .example-row .markdown strong { + color: #1a202c !important; + font-weight: 600 !important; + } + + /* Example grid layout styles */ + .example-grid-row { + margin: 1rem 0; + gap: 1rem; + } + + .example-item { + background: #f8fafc; + border: 1px solid #e2e8f0; + border-radius: 15px; + padding: 1rem; + transition: all 0.3s ease; + margin: 0.25rem; + max-width: 250px; + margin-left: auto; + margin-right: auto; + } + + .example-item:hover { + border-color: #667eea; + transform: translateY(-2px); + box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15); + } + + .example-caption { + margin: 0.5rem 0 !important; + min-height: 2.8rem !important; + display: flex !important; + align-items: flex-start !important; + } + + .example-caption p { + color: #2d3748 !important; + font-size: 0.9rem !important; + line-height: 1.4 !important; + margin: 0.5rem 0 !important; + } + + /* Multi-video gallery styles */ + .additional-samples { + margin-top: 1rem; + gap: 0.5rem; + } + + .additional-samples .gradio-video { + border-radius: 10px; + overflow: hidden; + } + + /* Video gallery responsive layout */ + .video-gallery { + display: grid; + gap: 1rem; + margin-top: 1rem; + } + + .video-gallery.single { + grid-template-columns: 1fr; + } + + .video-gallery.dual { + grid-template-columns: 1fr 1fr; + } + + .video-gallery.multi { + grid-template-columns: repeat(2, 1fr); + grid-template-rows: auto auto auto; + } + + .footer-text { + color: #718096 !important; + text-align: center; + padding: 2rem; + font-size: 0.9rem; + } + + /* Video component styling for consistent size */ + .input-section video, + .output-section video, + .example-row video { + width: 100% !important; + height: 300px !important; + object-fit: contain !important; + border-radius: 10px !important; + background-color: #000 !important; + } + + .example-row video { + height: 150px !important; + } + + /* Fix for additional samples video display */ + .additional-samples video { + height: 150px !important; + object-fit: contain !important; + border-radius: 10px !important; + background-color: #000 !important; + } + + .additional-samples .gradio-video { + border-radius: 10px !important; + overflow: hidden !important; + background-color: #000 !important; + } + + .additional-samples .gradio-video > div { + background-color: #000 !important; + border-radius: 10px !important; + } + + /* Video container styling */ + .input-section .video-container, + .output-section .video-container, + .example-row .video-container { + background-color: #000 !important; + border-radius: 10px !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + overflow: hidden !important; + } + + /* Ensure proper alignment */ + .example-row { + display: flex !important; + align-items: stretch !important; + } + + .example-row > div { + display: flex !important; + flex-direction: column !important; + justify-content: center !important; + } + + /* Video wrapper for better control */ + .video-wrapper { + position: relative !important; + width: 100% !important; + background: #000 !important; + border-radius: 10px !important; + overflow: hidden !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + } + """ + + with gr.Blocks(css=css, title="HunyuanVideo-Foley") as app: + + # Main header + with gr.Column(elem_classes=["main-header"]): + gr.HTML(""" +

🎵 HunyuanVideo-Foley

+

Text-Video-to-Audio Synthesis: Generate realistic audio from video and text descriptions

+ """) + + # Usage Guide + with gr.Column(elem_classes=["status-card"]): + gr.Markdown(""" + ### 📋 Quick Start Guide + **1.** Upload your video file\t**2.** Add optional text description\t**3.** Adjust sample numbers (1-6)\t**4.** Click Generate Audio + + 💡 For quick start, you can load the prepared examples by clicking the button. + """, elem_classes=["usage-guide"]) + + # Main inference interface - Input and Results side by side + with gr.Row(elem_classes=["main-interface"]): + # Input section + with gr.Column(scale=1, elem_classes=["input-section"]): + gr.Markdown("### 📹 Video Input") + + video_input = gr.Video( + label="Upload Video", + info="Supported formats: MP4, AVI, MOV, etc.", + height=300 + ) + + text_input = gr.Textbox( + label="🎯 Audio Description (English)", + placeholder="A person walks on frozen ice", + lines=3, + info="Describe the audio you want to generate (optional)" + ) + + with gr.Row(): + guidance_scale = gr.Slider( + minimum=1.0, + maximum=10.0, + value=4.5, + step=0.1, + label="🎚️ CFG Scale", + ) + + inference_steps = gr.Slider( + minimum=10, + maximum=100, + value=50, + step=5, + label="⚡ Steps", + ) + + sample_nums = gr.Slider( + minimum=1, + maximum=6, + value=1, + step=1, + label="🎲 Sample Nums", + ) + + generate_btn = gr.Button( + "🎵 Generate Audio", + variant="primary", + elem_classes=["generate-btn"] + ) + + # Results section + with gr.Column(scale=1, elem_classes=["output-section"]): + gr.Markdown("### 🎥 Generated Results") + + # Multi-video gallery for displaying multiple generated samples + with gr.Column(): + # Primary video (Sample 1) + video_output_1 = gr.Video( + label="Sample 1", + height=250, + visible=True + ) + + # Additional videos (Samples 2-6) - initially hidden + with gr.Row(elem_classes=["additional-samples"]): + with gr.Column(scale=1): + video_output_2 = gr.Video( + label="Sample 2", + height=150, + visible=False + ) + video_output_3 = gr.Video( + label="Sample 3", + height=150, + visible=False + ) + with gr.Column(scale=1): + video_output_4 = gr.Video( + label="Sample 4", + height=150, + visible=False + ) + video_output_5 = gr.Video( + label="Sample 5", + height=150, + visible=False + ) + + # Sample 6 - full width + video_output_6 = gr.Video( + label="Sample 6", + height=150, + visible=False + ) + + result_text = gr.Textbox( + label="Status", + interactive=False, + lines=2 + ) + + # Examples section at the bottom + with gr.Column(elem_classes=["examples-section"]): + gr.Markdown("### 🌟 Examples") + gr.Markdown("Click on any example to load it into the interface above") + + # Define your custom examples here - 8 examples total + examples_data = [ + # Example 1 + { + "caption": "A person walks on frozen ice", + "video_path": "examples/1_video.mp4", + "result_path": "examples/1_result.mp4" + }, + # Example 2 + { + "caption": "With a faint sound as their hands parted, the two embraced, a soft 'mm' escaping between them.", + "video_path": "examples/2_video.mp4", + "result_path": "examples/2_result.mp4" + }, + # Example 3 + { + "caption": "The sound of the number 3's bouncing footsteps is as light and clear as glass marbles hitting the ground. Each step carries a magical sound.", + "video_path": "examples/3_video.mp4", + "result_path": "examples/3_result.mp4" + }, + # Example 4 + { + "caption": "gentle gurgling of the stream's current, and music plays in the background which is a beautiful and serene piano solo with a hint of classical charm, evoking a sense of peace and serenity in people's hearts.", + "video_path": "examples/4_video.mp4", + "result_path": "examples/4_result.mp4" + }, + # Example 5 - Add your new examples here + { + "caption": "snow crunching under the snowboard's edge.", + "video_path": "examples/5_video.mp4", + "result_path": "examples/5_result.mp4" + }, + # Example 6 + { + "caption": "The crackling of the fire, the whooshing of the flames, and the occasional crisp popping of charred leaves filled the forest.", + "video_path": "examples/6_video.mp4", + "result_path": "examples/6_result.mp4" + }, + # Example 7 + { + "caption": "humming of the scooter engine accelerates slowly.", + "video_path": "examples/7_video.mp4", + "result_path": "examples/7_result.mp4" + }, + # Example 8 + { + "caption": "splash of water and loud thud as person hits the surface.", + "video_path": "examples/8_video.mp4", + "result_path": "examples/8_result.mp4" + } + ] + + # Create example grid - 4 examples per row, 2 rows total + example_buttons = [] + for row in range(2): # 2 rows + with gr.Row(elem_classes=["example-grid-row"]): + for col in range(4): # 4 columns + idx = row * 4 + col + if idx < len(examples_data): + example = examples_data[idx] + + with gr.Column(scale=1, elem_classes=["example-item"]): + # Video thumbnail + if os.path.exists(example['video_path']): + example_video = gr.Video( + value=example['video_path'], + label=f"Example {idx+1}", + interactive=False, + show_label=True, + height=180 + ) + else: + example_video = gr.HTML(f""" +
+
+

📹 Video not found

+ {example['video_path']} +
+
+ """) + + # Caption (truncated for grid layout) + caption_preview = example['caption'][:60] + "..." if len(example['caption']) > 60 else example['caption'] + gr.Markdown(f"{caption_preview}", elem_classes=["example-caption"]) + + # Load button + example_btn = gr.Button( + f"Load Example {idx+1}", + variant="secondary", + size="sm" + ) + example_buttons.append((example_btn, example)) + + # Event handlers + def process_inference(video_file, text_prompt, guidance_scale, inference_steps, sample_nums): + # Generate videos + video_list, status_msg = infer_single_video( + video_file, text_prompt, guidance_scale, inference_steps, int(sample_nums) + ) + # Update outputs with proper visibility + return update_video_outputs(video_list, status_msg) + + # Add dynamic visibility control based on sample_nums + def update_visibility(sample_nums): + sample_nums = int(sample_nums) + return [ + gr.update(visible=True), # Sample 1 always visible + gr.update(visible=sample_nums >= 2), # Sample 2 + gr.update(visible=sample_nums >= 3), # Sample 3 + gr.update(visible=sample_nums >= 4), # Sample 4 + gr.update(visible=sample_nums >= 5), # Sample 5 + gr.update(visible=sample_nums >= 6), # Sample 6 + ] + + # Update visibility when sample_nums changes + sample_nums.change( + fn=update_visibility, + inputs=[sample_nums], + outputs=[video_output_1, video_output_2, video_output_3, video_output_4, video_output_5, video_output_6] + ) + + generate_btn.click( + fn=process_inference, + inputs=[video_input, text_input, guidance_scale, inference_steps, sample_nums], + outputs=[ + video_output_1, # Sample 1 value + video_output_2, # Sample 2 value + video_output_3, # Sample 3 value + video_output_4, # Sample 4 value + video_output_5, # Sample 5 value + video_output_6, # Sample 6 value + result_text + ] + ) + + # Add click handlers for example buttons + for btn, example in example_buttons: + def create_example_handler(ex): + def handler(): + # Check if files exist, if not, return placeholder message + if os.path.exists(ex['video_path']): + video_file = ex['video_path'] + else: + video_file = None + + if os.path.exists(ex['result_path']): + result_video = ex['result_path'] + else: + result_video = None + + status_msg = f"✅ Loaded example with caption: {ex['caption'][:50]}..." + if not video_file: + status_msg += f"\n⚠️ Video file not found: {ex['video_path']}" + if not result_video: + status_msg += f"\n⚠️ Result video not found: {ex['result_path']}" + + return video_file, ex['caption'], result_video, status_msg + return handler + + btn.click( + fn=create_example_handler(example), + outputs=[video_input, text_input, video_output_1, result_text] + ) + + # Footer + gr.HTML(""" + + """) + + return app + +def set_manual_seed(global_seed): + random.seed(global_seed) + np.random.seed(global_seed) + torch.manual_seed(global_seed) + +if __name__ == "__main__": + set_manual_seed(1) + # Setup logging + logger.remove() + logger.add(lambda msg: print(msg, end=''), level="INFO") + + # Auto-load model + logger.info("Starting application and loading model...") + model_load_result = auto_load_models() + logger.info(model_load_result) + + # Create and launch Gradio app + app = create_gradio_interface() + + # Log completion status + if "successfully" in model_load_result: + logger.info("Application ready, model loaded") + + app.launch( + server_name="0.0.0.0", + server_port=8080, + share=False, + debug=False, + show_error=True + ) diff --git a/assets/data_pipeline.png b/assets/data_pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..dfb4fd3cba2ffc56156b87b20aef22651e461ba0 --- /dev/null +++ b/assets/data_pipeline.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5c9e5cd92a7ac24d1e8f39db09e0eaa9ee84bedade8ff08bd1d50141fc7867c +size 384649 diff --git a/assets/model_arch.png b/assets/model_arch.png new file mode 100644 index 0000000000000000000000000000000000000000..b4163a7ccef39ea8ef3ee334c8f2ffb77bc897d1 --- /dev/null +++ b/assets/model_arch.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4709a32df5b115e7806e0eb102aaf2e396a0978e12b31fba338730068d6454d7 +size 542135 diff --git a/assets/pan_chart.png b/assets/pan_chart.png new file mode 100644 index 0000000000000000000000000000000000000000..a5558b5b157a9bf8dc21310a5a4443cf18f07253 --- /dev/null +++ b/assets/pan_chart.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16019d3355051f5b470532809a0cf9046d22170d30c860dd01929f6921d29ead +size 303974 diff --git a/configs/hunyuanvideo-foley-xxl.yaml b/configs/hunyuanvideo-foley-xxl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b7da02713647fd3201d378d86924bfdf42e93c9 --- /dev/null +++ b/configs/hunyuanvideo-foley-xxl.yaml @@ -0,0 +1,49 @@ +model_config: + model_name: HunyuanVideo-Foley-XXL + model_type: 1d + model_precision: bf16 + model_kwargs: + depth_triple_blocks: 18 + depth_single_blocks: 36 + hidden_size: 1536 + num_heads: 12 + mlp_ratio: 4 + mlp_act_type: "gelu_tanh" + qkv_bias: True + qk_norm: True + qk_norm_type: "rms" + attn_mode: "torch" + embedder_type: "default" + interleaved_audio_visual_rope: True + enable_learnable_empty_visual_feat: True + sync_modulation: False + add_sync_feat_to_audio: True + cross_attention: True + use_attention_mask: False + condition_projection: "linear" + sync_feat_dim: 768 # syncformer 768 dim + condition_dim: 768 # clap 768 text condition dim (clip-text) + clip_dim: 768 # siglip2 visual dim + audio_vae_latent_dim: 128 + audio_frame_rate: 50 + patch_size: 1 + rope_dim_list: null + rope_theta: 10000 + text_length: 77 + clip_length: 64 + sync_length: 192 + use_mmaudio_singleblock: True + depth_triple_ssl_encoder: null + depth_single_ssl_encoder: 8 + use_repa_with_audiossl: True + +diffusion_config: + denoise_type: "flow" + flow_path_type: "linear" + flow_predict_type: "velocity" + flow_reverse: True + flow_solver: "euler" + sample_flow_shift: 1.0 + sample_use_flux_shift: False + flux_base_shift: 0.5 + flux_max_shift: 1.15 diff --git a/examples/1_result.mp4 b/examples/1_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3e2403ad988b43cef1b4769283044f27dd913538 --- /dev/null +++ b/examples/1_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f3f49d6130592f479b0aca5f02ba25960140ed8d9d17340ff7f6306b39096a8 +size 11357340 diff --git a/examples/1_video.mp4 b/examples/1_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..25ee2eec46f6c24b482422ab5fe4e4256d70caaa --- /dev/null +++ b/examples/1_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54fc2de0b52f6969157b9caff212ffffddd4d34a75efb47ef7e7f8352d0a38db +size 11181543 diff --git a/examples/2_result.mp4 b/examples/2_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..74f78d66da1e3622b503d9a54ca056978e920ead --- /dev/null +++ b/examples/2_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf4f324b158f6a6926e77bbd0791610d79f0c3a600a571ed8ec61b0b7e645e46 +size 1720732 diff --git a/examples/2_video.mp4 b/examples/2_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3c13a1c28ad25f19219baf25909ad6b715b8c53c --- /dev/null +++ b/examples/2_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:512b185682ec8e60407dff65718443d7a28c75c79aec1e733b5abf7433af41a7 +size 1636945 diff --git a/examples/3_result.mp4 b/examples/3_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..81e8da485a92453d83869ed9ae72b2b15b0535ac --- /dev/null +++ b/examples/3_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df90300335b7ab1fb2fc4c020976837c6b3781f796e211bcbbaa30d34353d3e5 +size 1738462 diff --git a/examples/3_video.mp4 b/examples/3_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7fcd81be1e7a625707743b9745a337987be494da --- /dev/null +++ b/examples/3_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e79f80cf939fcb507e3fe218f61146d4fd3949a84e802b0f5b67bb2e981931a7 +size 1652180 diff --git a/examples/4_result.mp4 b/examples/4_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f101d19ac947905f5c721157bb2053c4a48cd6a6 --- /dev/null +++ b/examples/4_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7d2b5b63f6756f719d53e8087f772aa6bb25f31fbcd9f1cbae9e075fc841a2c +size 45242387 diff --git a/examples/4_video.mp4 b/examples/4_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ec23454fbd96194f22b1c619d2b3ac744bab9f10 --- /dev/null +++ b/examples/4_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f94cfd97634f3df085672ce2a91805697320507a716360bf85ba6eabb5a4b6f0 +size 45066257 diff --git a/examples/5_result.mp4 b/examples/5_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7614d7a8380f0f1a0f6571727e0372a8cfa6a19c --- /dev/null +++ b/examples/5_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf1db36336822b54b4e0e0aa8c98334fc2b97b9288271ddc9d42a45417b1f1d9 +size 40423834 diff --git a/examples/5_video.mp4 b/examples/5_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c307d58c1101f588377a1a54931d59572daa96d6 --- /dev/null +++ b/examples/5_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:746954bde2d5e693beecd8e3661bcd66ff0e55a8143f4f3b37f0b6d3873a8fff +size 40248335 diff --git a/examples/6_result.mp4 b/examples/6_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..61a5058e6b4e7626947f400e7b8139c9b32ca3f7 --- /dev/null +++ b/examples/6_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7778c26a677c04e93dc722cee44b896c5281bea8328a10800a08b98865419cd0 +size 4005580 diff --git a/examples/6_video.mp4 b/examples/6_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ee0656d0fdd7f3fb268bc740142480b71675c764 --- /dev/null +++ b/examples/6_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:866b2cff3441ddd686e551181c48ad7ca718626a489cf64198626b42bd732366 +size 3872852 diff --git a/examples/7_result.mp4 b/examples/7_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c8562fa427ff2f1599828d216d93122a0cccd933 --- /dev/null +++ b/examples/7_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:491114177a4ceeb50a6edfa5fca14fc6ce4fdb61ee7dfb0e13983236c42ee10d +size 32307884 diff --git a/examples/7_video.mp4 b/examples/7_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..06598d8ecdc6b41a05f8b3827d0b4f44861296a4 --- /dev/null +++ b/examples/7_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bc8a31e867245f6a6a6fbfa9778ac4c12e816184dc70324dc92d4496a36f62b +size 32131367 diff --git a/examples/8_result.mp4 b/examples/8_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9de5e673c399250fec2e0869e96b0bfe9fc5fa70 --- /dev/null +++ b/examples/8_result.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17053c5f8a373d656be2d2030619a71a5fd55db6365f8d54234402121e6030ce +size 29544164 diff --git a/examples/8_video.mp4 b/examples/8_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6e1ebe6d8fa81a54312fc8c3732047c40aa0c671 --- /dev/null +++ b/examples/8_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45a1f35974ad6e2d86304828fdeb230d9b008aae2f10cff8c87d71a8dcc6491e +size 29367637 diff --git a/hunyuanvideo_foley/__init__.py b/hunyuanvideo_foley/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59ddce2fc4bb8be228e5f310a706361cb3829bb8 Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/constants.py b/hunyuanvideo_foley/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..81519407b44f33cfcbadd550f41a634b11821649 --- /dev/null +++ b/hunyuanvideo_foley/constants.py @@ -0,0 +1,57 @@ +"""Constants used throughout the HunyuanVideo-Foley project.""" + +from typing import Dict, List + +# Model configuration +DEFAULT_AUDIO_SAMPLE_RATE = 48000 +DEFAULT_VIDEO_FPS = 25 +DEFAULT_AUDIO_CHANNELS = 2 + +# Video processing +MAX_VIDEO_DURATION_SECONDS = 15.0 +MIN_VIDEO_DURATION_SECONDS = 1.0 + +# Audio processing +AUDIO_VAE_LATENT_DIM = 128 +AUDIO_FRAME_RATE = 75 # frames per second in latent space + +# Visual features +FPS_VISUAL: Dict[str, int] = { + "siglip2": 8, + "synchformer": 25 +} + +# Model paths (can be overridden by environment variables) +DEFAULT_MODEL_PATH = "./pretrained_models/" +DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" + +# Inference parameters +DEFAULT_GUIDANCE_SCALE = 4.5 +DEFAULT_NUM_INFERENCE_STEPS = 50 +MIN_GUIDANCE_SCALE = 1.0 +MAX_GUIDANCE_SCALE = 10.0 +MIN_INFERENCE_STEPS = 10 +MAX_INFERENCE_STEPS = 100 + +# Text processing +MAX_TEXT_LENGTH = 100 +DEFAULT_NEGATIVE_PROMPT = "noisy, harsh" + +# File extensions +SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"] +SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"] + +# Quality settings +AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = { + "high": ["-b:a", "192k"], + "medium": ["-b:a", "128k"], + "low": ["-b:a", "96k"] +} + +# Error messages +ERROR_MESSAGES: Dict[str, str] = { + "model_not_loaded": "Model is not loaded. Please load the model first.", + "invalid_video_format": "Unsupported video format. Supported formats: {formats}", + "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds", + "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html" +} \ No newline at end of file diff --git a/hunyuanvideo_foley/models/__init__.py b/hunyuanvideo_foley/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc b/hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a5ca630b276db4dbd96c127571aa57e170f1937 Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/__init__.py b/hunyuanvideo_foley/models/dac_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51205ef6ded9c6735a988b76008e0f6bdce8e215 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/__init__.py @@ -0,0 +1,16 @@ +__version__ = "1.0.0" + +# preserved here for legacy reasons +__model_version__ = "latest" + +import audiotools + +audiotools.ml.BaseModel.INTERN += ["dac.**"] +audiotools.ml.BaseModel.EXTERN += ["einops"] + + +from . import nn +from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/hunyuanvideo_foley/models/dac_vae/__main__.py b/hunyuanvideo_foley/models/dac_vae/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fe6531c5bf82f731d8e07ec09c21d79aae4cfa --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/__main__.py @@ -0,0 +1,36 @@ +import sys + +import argbind + +from .utils import download +from .utils.decode import decode +from .utils.encode import encode + +STAGES = ["encode", "decode", "download"] + + +def run(stage: str): + """Run stages. + + Parameters + ---------- + stage : str + Stage to run + """ + if stage not in STAGES: + raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") + stage_fn = globals()[stage] + + if stage == "download": + stage_fn() + return + + stage_fn() + + +if __name__ == "__main__": + group = sys.argv.pop(1) + args = argbind.parse_args(group=group) + + with argbind.scope(args): + run(group) diff --git a/hunyuanvideo_foley/models/dac_vae/model/__init__.py b/hunyuanvideo_foley/models/dac_vae/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02a75b7ad6028f5c41b6a8285b0257d4c23bdfcf --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/__init__.py @@ -0,0 +1,4 @@ +from .base import CodecMixin +from .base import DACFile +from .dac import DAC +from .discriminator import Discriminator diff --git a/hunyuanvideo_foley/models/dac_vae/model/base.py b/hunyuanvideo_foley/models/dac_vae/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e95a84149a767f256a54b7cc3241c09551f39061 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/base.py @@ -0,0 +1,301 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." + ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. + + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + audio_signal = audio_signal.to_mono() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. + verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + if obj.input_db is not None: + recons.normalize(obj.input_db) + + resample_fn(obj.sample_rate) + + if obj.original_length is not None: + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) + else: + loudness_fn() + + self.padding = original_padding + return recons diff --git a/hunyuanvideo_foley/models/dac_vae/model/dac.py b/hunyuanvideo_foley/models/dac_vae/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..3df2cbad1502774ce2519c1795b6e85571ac3fc1 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/dac.py @@ -0,0 +1,410 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from ..nn.layers import Snake1d +from ..nn.layers import WNConv1d +from ..nn.layers import WNConvTranspose1d +from ..nn.quantize import ResidualVectorQuantize +from ..nn.vae_utils import DiagonalGaussianDistribution + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + continuous: bool = False, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + self.continuous = continuous + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + if not continuous: + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + else: + self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) + self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + @property + def dtype(self): + """Get the dtype of the model parameters.""" + # Return the dtype of the first parameter found + for param in self.parameters(): + return param.dtype + return torch.float32 # fallback + + @property + def device(self): + """Get the device of the model parameters.""" + # Return the device of the first parameter found + for param in self.parameters(): + return param.device + return torch.device('cpu') # fallback + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) # [B x D x T] + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) + else: + z = self.quant_conv(z) # [B x 2D x T] + z = DiagonalGaussianDistribution(z) + codes, latents, commitment_loss, codebook_loss = None, None, 0, 0 + + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + if not self.continuous: + audio = self.decoder(z) + else: + z = self.post_quant_conv(z) + audio = self.decoder(z) + + return audio + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + else: + posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) + z = posterior.sample() + x = self.decode(z) + + kl_loss = posterior.kl() + kl_loss = kl_loss.mean() + + return { + "audio": x[..., :length], + "z": z, + "kl_loss": kl_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." + setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/hunyuanvideo_foley/models/dac_vae/model/discriminator.py b/hunyuanvideo_foley/models/dac_vae/model/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..09c79d1342ca46bef21daca64667577f05e61638 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/discriminator.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import ml +from audiotools import STFTParams +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(ml.BaseModel): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. + periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = Discriminator() + x = torch.zeros(1, 1, 44100) + results = disc(x) + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print() diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__init__.py b/hunyuanvideo_foley/models/dac_vae/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/hunyuanvideo_foley/models/dac_vae/nn/layers.py b/hunyuanvideo_foley/models/dac_vae/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/hunyuanvideo_foley/models/dac_vae/nn/loss.py b/hunyuanvideo_foley/models/dac_vae/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/hunyuanvideo_foley/models/dac_vae/nn/quantize.py b/hunyuanvideo_foley/models/dac_vae/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/quantize.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py b/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a97597f5d5ae4aa19a194c24f3c17b2238224bcf --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py @@ -0,0 +1,91 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2], + ) + + def nll(self, sample, dims=[1, 2]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/hunyuanvideo_foley/models/dac_vae/utils/__init__.py b/hunyuanvideo_foley/models/dac_vae/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dce1ed49f1b4e4fe1cb42b054298911207e0e41 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/utils/__init__.py @@ -0,0 +1,121 @@ +from pathlib import Path + +import argbind +from audiotools import ml + +from ..model import DAC +Accelerator = ml.Accelerator + +__MODEL_LATEST_TAGS__ = { + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", + ( + "16khz", + "0.0.5", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", +} + + +@argbind.bind(group="download", positional=True, without_prefix=True) +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): + """ + Function that downloads the weights file from URL if a local cache is not found. + + Parameters + ---------- + model_type : str + The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". + + Returns + ------- + Path + Directory path required to load model via audiotools. + """ + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + "16khz", + ], "model_type must be one of '44khz', '24khz', or '16khz'" + + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] + + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" + ) + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Download the model + import requests + + response = requests.get(download_link) + + if response.status_code != 200: + raise ValueError( + f"Could not download model. Received response code {response.status_code}" + ) + local_path.write_bytes(response.content) + + return local_path + + +def load_model( + model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, +): + if not load_path: + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) + return generator diff --git a/hunyuanvideo_foley/models/dac_vae/utils/decode.py b/hunyuanvideo_foley/models/dac_vae/utils/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..00261a561251b1bef6f11e6594bf80de10b93ff2 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/utils/decode.py @@ -0,0 +1,95 @@ +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from tqdm import tqdm + +from ..model import DACFile +from . import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="decode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def decode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + device: str = "cuda", + model_type: str = "44khz", + verbose: bool = False, +): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + + # Find all .dac files in input directory + _input = Path(input) + input_files = list(_input.glob("**/*.dac")) + + # If input is a .dac file, add it to the list + if _input.suffix == ".dac": + input_files.append(_input) + + # Create output directory + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(input_files)), desc=f"Decoding files"): + # Load file + artifact = DACFile.load(input_files[i]) + + # Reconstruct audio from codes + recons = generator.decompress(artifact, verbose=verbose) + + # Compute output path + relative_path = input_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = input_files[i] + output_name = relative_path.with_suffix(".wav").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to file + recons.write(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + decode() diff --git a/hunyuanvideo_foley/models/dac_vae/utils/encode.py b/hunyuanvideo_foley/models/dac_vae/utils/encode.py new file mode 100644 index 0000000000000000000000000000000000000000..c86946c3c6d6a7ff1d1ea883c600d9b93c41b7d9 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/utils/encode.py @@ -0,0 +1,94 @@ +import math +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.core import util +from tqdm import tqdm + +from . import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="encode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def encode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + n_quantizers: int = None, + device: str = "cuda", + model_type: str = "44khz", + win_duration: float = 5.0, + verbose: bool = False, +): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + kwargs = {"n_quantizers": n_quantizers} + + # Find all audio files in input path + input = Path(input) + audio_files = util.find_audio(input) + + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(audio_files)), desc="Encoding files"): + # Load file + signal = AudioSignal(audio_files[i]) + + # Encode audio to .dac format + artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) + + # Compute output path + relative_path = audio_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = audio_files[i] + output_name = relative_path.with_suffix(".dac").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + artifact.save(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + encode() diff --git a/hunyuanvideo_foley/models/hifi_foley.py b/hunyuanvideo_foley/models/hifi_foley.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5fffcca7a61c3baf686031e9ffdbc3696f3e22 --- /dev/null +++ b/hunyuanvideo_foley/models/hifi_foley.py @@ -0,0 +1,794 @@ +from typing import List, Tuple, Optional, Union, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange +from diffusers.models import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config + +from .nn.activation_layers import SwiGLU, get_activation_layer +from .nn.attn_layers import apply_rotary_emb, attention +from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D +from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d +from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate +from .nn.norm_layers import get_norm_layer +from .nn.posemb_layers import get_nd_rotary_pos_embed + +def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor): + # [B, N1, H, C] & [B, N2, H, C] + B, N1, H, C = x1.shape + B, N2, H, C = x2.shape + assert x1.ndim == x2.ndim == 4 + + if N1 != N2: + x2 = x2.view(B, N2, -1).transpose(1, 2) + x2 = F.interpolate(x2, size=(N1), mode="nearest-exact") + x2 = x2.transpose(1, 2).view(B, N1, H, C) + x = torch.stack((x1, x2), dim=2) + x = x.reshape(B, N1 * 2, H, C) + return x + +def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int): + B, N, H, C = x.shape + assert N % 2 == 0 and N // 2 == len1 + + x = x.reshape(B, -1, 2, H, C) + x1 = x[:, :, 0] + x2 = x[:, :, 1] + if x2.shape[1] != len2: + x2 = x2.view(B, len1, H * C).transpose(1, 2) + x2 = F.interpolate(x2, size=(len2), mode="nearest-exact") + x2 = x2.transpose(1, 2).view(B, len2, H, C) + return x1, x2 + +class TwoStreamCABlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qkv_bias: bool = False, + attn_mode: str = "torch", + reverse: bool = False, + interleaved_audio_visual_rope: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.reverse = reverse + self.attn_mode = attn_mode + self.num_heads = num_heads + self.hidden_size = hidden_size + head_dim = hidden_size // num_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.interleaved_audio_visual_rope = interleaved_audio_visual_rope + + # Self attention for audio + visual + self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs) + self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.audio_self_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.audio_self_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + # visual cond + self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs) + self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + self.v_cond_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.v_cond_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.max_text_len = 100 + self.rope_dim_list = None + + # audio and video norm for cross attention with text + self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + # Cross attention: (video_audio) as query, text as key/value + self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs) + + self.audio_cross_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.v_cond_cross_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.text_cross_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + # MLPs + self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.audio_mlp = MLP( + hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs + ) + + self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_mlp = MLP( + hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs + ) + + def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None): + target_ndim = 1 # n-d RoPE + rope_sizes = [text_len] + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + + text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + ) + return text_freqs_cos, text_freqs_sin + + def set_attn_mode(self, new_mode): + if new_mode != "torch": + raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.") + self.attn_mode = new_mode + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def forward( + self, + audio: torch.Tensor, + cond: torch.Tensor, + v_cond: torch.Tensor, + attn_mask: torch.Tensor, + vec: torch.Tensor, + freqs_cis: tuple = None, + v_freqs_cis: tuple = None, + sync_vec: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Get modulation parameters + if sync_vec is not None: + assert sync_vec.ndim == 3 + (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate, + audio_mod2_shift, audio_mod2_scale, audio_mod2_gate, + audio_mod3_shift, audio_mod3_scale, audio_mod3_gate, + ) = self.audio_mod(sync_vec).chunk(9, dim=-1) + else: + (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate, + audio_mod2_shift, audio_mod2_scale, audio_mod2_gate, + audio_mod3_shift, audio_mod3_scale, audio_mod3_gate, + ) = self.audio_mod(vec).chunk(9, dim=-1) + + ( + v_cond_mod1_shift, + v_cond_mod1_scale, + v_cond_mod1_gate, + v_cond_mod2_shift, + v_cond_mod2_scale, + v_cond_mod2_gate, + v_cond_mod3_shift, + v_cond_mod3_scale, + v_cond_mod3_gate, + ) = self.v_cond_mod(vec).chunk(9, dim=-1) + + # 1. Self Attention for audio + visual + audio_modulated = self.audio_norm1(audio) + audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale) + audio_qkv = self.audio_self_attn_qkv(audio_modulated) + audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + audio_q = self.audio_self_q_norm(audio_q).to(audio_v) + audio_k = self.audio_self_k_norm(audio_k).to(audio_v) + + # Prepare visual cond for attention + v_cond_modulated = self.v_cond_norm1(v_cond) + v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale) + v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated) + v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v) + v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v) + + # Apply RoPE if needed for audio and visual + if freqs_cis is not None: + if not self.interleaved_audio_visual_rope: + audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False) + audio_q, audio_k = audio_qq, audio_kk + else: + ori_audio_len = audio_q.shape[1] + ori_v_con_len = v_cond_q.shape[1] + interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q) + interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k) + interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb( + interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False + ) + audio_qq, v_cond_qq = decouple_interleaved_two_sequences( + interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len + ) + audio_kk, v_cond_kk = decouple_interleaved_two_sequences( + interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len + ) + audio_q, audio_k = audio_qq, audio_kk + v_cond_q, v_cond_k = v_cond_qq, v_cond_kk + + # Apply RoPE to visual if needed and not interleaved + if v_freqs_cis is not None and not self.interleaved_audio_visual_rope: + v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False) + v_cond_q, v_cond_k = v_cond_qq, v_cond_kk + + # Concatenate for self-attention + q = torch.cat((v_cond_q, audio_q), dim=1) + k = torch.cat((v_cond_k, audio_k), dim=1) + v = torch.cat((v_cond_v, audio_v), dim=1) + + # Run self-attention + attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic) + v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1) + + # Apply self-attention output to audio and v_cond + audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate) + v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate) + + # 2. Cross Attention: (v_cond, audio) as query, text as key/value + # audio, v_cond modulation + audio_modulated = self.audio_norm2(audio) + audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale) + v_cond_modulated = self.v_cond_norm2(v_cond) + v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale) + + # Prepare audio query + audio_q = self.audio_cross_q(audio_modulated) + audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads) + audio_q = self.audio_cross_q_norm(audio_q) + + # Prepare v_cond query + v_cond_q = self.v_cond_cross_q(v_cond_modulated) + v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads) + v_cond_q = self.v_cond_cross_q_norm(v_cond_q) + + # Prepare text key/value + text_kv = self.text_cross_kv(cond) + text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads) + text_k = self.text_cross_k_norm(text_k).to(text_v) + + # Apply RoPE to (v_cond, audio) query and text key if needed + head_dim = self.hidden_size // self.num_heads + audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list) + audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device)) + audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0] + + v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list) + v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device)) + v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0] + + text_len = text_k.shape[1] + + text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim, + rope_dim_list=self.rope_dim_list) + text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device)) + text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1] + + # Concat v_cond and audio for cross-attention + v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1) + + # Run cross-attention + cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic) + v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1) + + # Apply cross-attention output + audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate) + v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate) + + # 3. Apply MLPs + audio = audio + apply_gate( + self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)), + gate=audio_mod3_gate, + ) + + # Apply visual MLP + v_cond = v_cond + apply_gate( + self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)), + gate=v_cond_mod3_gate, + ) + + return audio, cond, v_cond + +class SingleStreamBlock(nn.Module): + + def __init__(self, hidden_size: int, + num_heads: int, + mlp_ratio: float, + qk_norm_type: str = "rms", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None,): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + + self.modulation = ModulateDiT( + hidden_size=hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True) + self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs) + self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, **factory_kwargs) + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False) + self.q_norm = nn.RMSNorm(hidden_size // num_heads) + self.k_norm = nn.RMSNorm(hidden_size // num_heads) + self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads) + + def forward(self, x: torch.Tensor, cond: torch.Tensor,freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None): + assert cond.ndim == 3, "Condition should be in shape of [B, T, D]" + modulation = self.modulation(cond) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1) + x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa + + qkv = self.linear_qkv(x_norm1) + q, k, v = self.rearrange(qkv).chunk(3, dim=-1) + q = q.squeeze(-1) + k = k.squeeze(-1) + v = v.squeeze(-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True) + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = F.scaled_dot_product_attention(q, k, v) + out = rearrange(out, 'b h n d -> b n (h d)').contiguous() + + x = x + apply_gate(self.linear1(out),gate=gate_msa) + x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp + x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp) + + return x + +class HunyuanVideoFoley(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + model_config, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + model_args = model_config.model_config.model_kwargs + self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19) + self.depth_single_blocks = model_args.get("depth_single_blocks", 38) + # Gradient checkpoint. + self.gradient_checkpoint = False + self.gradient_checkpoint_layers = None + if self.gradient_checkpoint: + assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, ( + f"Gradient checkpoint layers must be less or equal than the depth of the model. " + f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}." + ) + + self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False) + + # Condition projection. Default to linear projection. + self.condition_projection = model_args.get("condition_projection", "linear") + self.condition_dim = model_args.get("condition_dim", None) + self.use_attention_mask = model_args.get("use_attention_mask", False) + + self.patch_size = model_args.get("patch_size", 1) + self.visual_in_channels = model_args.get("clip_dim", 768) + self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128) + self.out_channels = self.audio_vae_latent_dim + self.unpatchify_channels = self.out_channels + self.reverse = model_args.get("reverse", False) + + self.num_heads = model_args.get("num_heads", 24) + self.hidden_size = model_args.get("hidden_size", 3072) + self.rope_dim_list = model_args.get("rope_dim_list", None) + self.mlp_ratio = model_args.get("mlp_ratio", 4.0) + self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh") + + self.qkv_bias = model_args.get("qkv_bias", True) + self.qk_norm = model_args.get("qk_norm", True) + self.qk_norm_type = model_args.get("qk_norm_type", "rms") + self.attn_mode = model_args.get("attn_mode", "torch") + + self.embedder_type = model_args.get("embedder_type", "default") + + # sync condition things + self.sync_modulation = model_args.get("sync_modulation", False) + self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False) + self.sync_feat_dim = model_args.get("sync_feat_dim", 768) + self.sync_in_ksz = model_args.get("sync_in_ksz", 1) + + # condition tokens length + self.clip_len = model_args.get("clip_length", 64) + self.sync_len = model_args.get("sync_length", 192) + + if self.hidden_size % self.num_heads != 0: + raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}") + + # Build audio patchify layer and visual gated linear projection + self.patch_size = 1 + self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs) + self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size) + + # condition + if self.condition_projection == "linear": + self.cond_in = ConditionProjection( + self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs + ) + else: + raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}") + + # time modulation + self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + + # visual sync embedder if needed + if self.sync_in_ksz == 1: + sync_in_padding = 0 + elif self.sync_in_ksz == 3: + sync_in_padding = 1 + else: + raise ValueError + if self.sync_modulation or self.add_sync_feat_to_audio: + self.sync_in = nn.Sequential( + nn.Linear(self.sync_feat_dim, self.hidden_size), + nn.SiLU(), + ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding), + ) + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim))) + + self.triple_blocks = nn.ModuleList( + [ + TwoStreamCABlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + mlp_act_type=self.mlp_act_type, + qk_norm=self.qk_norm, + qk_norm_type=self.qk_norm_type, + qkv_bias=self.qkv_bias, + attn_mode=self.attn_mode, + reverse=self.reverse, + interleaved_audio_visual_rope=self.interleaved_audio_visual_rope, + **factory_kwargs, + ) + for _ in range(self.depth_triple_blocks) + ] + ) + + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qk_norm_type=self.qk_norm_type, + **factory_kwargs, + ) + for _ in range(self.depth_single_blocks) + ] + ) + + self.final_layer = FinalLayer1D( + self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs + ) + self.unpatchify_channels = self.out_channels + + self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True) + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def get_empty_string_sequence(self, bs=None) -> torch.Tensor: + if bs is None: + return self.empty_string_feat + else: + return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor: + len = len if len is not None else self.clip_len + if bs is None: + return self.empty_clip_feat.expand(len, -1) # 15s + else: + return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s + + def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor: + len = len if len is not None else self.sync_len + if bs is None: + return self.empty_sync_feat.expand(len, -1) + else: + return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1) + + def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len): + assert self.patch_size == 1 + # ======================================== Build RoPE for audio tokens ====================================== + target_ndim = 1 # n-d RoPE + rope_sizes = [audio_emb_len] + head_dim = self.hidden_size // self.num_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + ) + + # ========================== Build RoPE for clip tokens ========================= + target_ndim = 1 # n-d RoPE + rope_sizes = [visual_cond_len] + head_dim = self.hidden_size // self.num_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + freq_scaling=1.0 * audio_emb_len / visual_cond_len, + ) + return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin + + def build_rope_for_interleaved_audio_visual(self, total_len): + assert self.patch_size == 1 + # ========================== Build RoPE for audio tokens ======================== + target_ndim = 1 # n-d RoPE + rope_sizes = [total_len] + head_dim = self.hidden_size // self.num_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + ) + return freqs_cos, freqs_sin + + def set_attn_mode(self, new_mode): + for block in self.triple_blocks: + block.set_attn_mode(new_mode) + for block in self.single_blocks: + block.set_attn_mode(new_mode) + + def enable_deterministic(self): + for block in self.triple_blocks: + block.enable_deterministic() + for block in self.single_blocks: + block.enable_deterministic() + + def disable_deterministic(self): + for block in self.triple_blocks: + block.disable_deterministic() + for block in self.single_blocks: + block.disable_deterministic() + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, # Should be in range(0, 1000). + clip_feat: Optional[torch.Tensor] = None, + cond: torch.Tensor = None, + audio_mask: Optional[torch.Tensor] = None, + cond_mask: torch.Tensor = None, + sync_feat: Optional[torch.Tensor] = None, + drop_visual: Optional[List[bool]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + out = {} + audio = x + bs, _, ol = x.shape + tl = ol // self.patch_size + + # Prepare learnable empty conditions for visual condition + if drop_visual is not None: + clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype) + sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype) + + # ========================= Prepare time & visual modulation ========================= + vec = self.time_in(t) + sync_vec = None + if self.sync_modulation: + assert sync_feat is not None and sync_feat.shape[1] % 8 == 0 + sync_feat = sync_feat.view(bs, int(sync_feat.shape[1] / 8), 8, self.sync_feat_dim) + self.sync_pos_emb + sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels + sync_vec = self.sync_in(sync_feat) # bs, num_segments * 8, c + sync_vec = ( + F.interpolate(sync_vec.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2) + ) # bs, tl, c + sync_vec = sync_vec + vec.unsqueeze(1) + elif self.add_sync_feat_to_audio: + assert sync_feat is not None and sync_feat.shape[1] % 8 == 0 + sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb + sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels + sync_feat = self.sync_in(sync_feat) # bs, num_segments * 8, c + add_sync_feat_to_audio = ( + F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2) + ) # bs, tl, c + + # ========================= Get text, audio and video clip embedding ========================= + cond = self.cond_in(cond) + cond_seq_len = cond.shape[1] + + audio = self.audio_embedder(x) + audio_seq_len = audio.shape[1] + v_cond = self.visual_proj(clip_feat) + v_cond_seq_len = v_cond.shape[1] + + # ========================= Compute attention mask ========================= + attn_mask = None + if self.use_attention_mask: + assert cond_mask is not None + batch_size = audio.shape[0] + seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len + + # get default audio_mask and v_cond_mask + audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device) + v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device) + + # batch_size x seq_len + concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1) + # batch_size x 1 x seq_len x seq_len + attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + # batch_size x 1 x seq_len x seq_len + attn_mask_2 = attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads + attn_mask = (attn_mask_1 & attn_mask_2).bool() + # avoids self-attention weight being NaN for text padding tokens + attn_mask[:, :, :, 0] = True + + + # ========================= Build rope for audio and clip tokens ========================= + if self.interleaved_audio_visual_rope: + freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2) + v_freqs_cos = v_freqs_sin = None + else: + freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual( + audio_seq_len, v_cond_seq_len + ) + + # ========================= Pass through DiT blocks ========================= + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None + + if self.add_sync_feat_to_audio: + add_sync_layer = 0 + assert ( + add_sync_layer < self.depth_triple_blocks + ), f"The layer to add mel_spectrogram feature and sync feature should in the triple_stream_blocks (n: {self.depth_triple_blocks})." + # Triple-stream blocks + for layer_num, block in enumerate(self.triple_blocks): + if self.add_sync_feat_to_audio and layer_num == add_sync_layer: + audio = audio + add_sync_feat_to_audio + triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec] + if ( + self.training + and self.gradient_checkpoint + and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers) + ): + audio, cond, v_cond = torch.utils.checkpoint.checkpoint( + ckpt_wrapper(block), *triple_block_args, use_reentrant=False + ) + else: + audio, cond, v_cond = block(*triple_block_args) + + x = audio + if sync_vec is not None: + vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1) + vec = torch.cat((vec, sync_vec), dim=1) + + freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len) + if self.add_sync_feat_to_audio: + vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1) + if len(self.single_blocks) > 0: + for layer_num, block in enumerate(self.single_blocks): + single_block_args = [ + x, + vec, + (freqs_cos, freqs_sin), + ] + if ( + self.training + and self.gradient_checkpoint + and ( + self.gradient_checkpoint_layers == -1 + or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers + ) + ): + x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False) + else: + x = block(*single_block_args) + + audio = x + + # ========================= Final layer ========================= + if sync_vec is not None: + vec = sync_vec + audio = self.final_layer(audio, vec) # (N, T, patch_size * out_channels) + audio = self.unpatchify1d(audio, tl) + + if return_dict: + out["x"] = audio + return out + return audio + + def unpatchify1d(self, x, l): + # x: (N, L, patch_size * C) + # audio: (N, C, T), T == L * patch_size + c = self.unpatchify_channels + p = self.patch_size + assert l == x.shape[1] + + x = x.reshape(shape=(x.shape[0], l, p, c)) + x = torch.einsum("ntpc->nctp", x) + audio = x.reshape(shape=(x.shape[0], c, l * p)) + return audio + + def params_count(self): + counts = { + "triple": sum( + [ + sum(p.numel() for p in block.audio_cross_q.parameters()) + + sum(p.numel() for p in block.v_cond_cross_q.parameters()) + + sum(p.numel() for p in block.text_cross_kv.parameters()) + + sum(p.numel() for p in block.audio_self_attn_qkv.parameters()) + + sum(p.numel() for p in block.v_cond_attn_qkv.parameters()) + + sum(p.numel() for p in block.audio_mlp.parameters()) + + sum(p.numel() for p in block.audio_self_proj.parameters()) + + sum(p.numel() for p in block.v_cond_self_proj.parameters()) + + sum(p.numel() for p in block.v_cond_mlp.parameters()) + for block in self.triple_blocks + ] + ), + "single": sum( + [ + sum(p.numel() for p in block.linear1.parameters()) + + sum(p.numel() for p in block.linear2.parameters()) + for block in self.single_blocks + ] + ), + "total": sum(p.numel() for p in self.parameters()), + } + + counts["attn+mlp"] = counts["triple"] + counts["single"] + return counts diff --git a/hunyuanvideo_foley/models/nn/__init__.py b/hunyuanvideo_foley/models/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuanvideo_foley/models/nn/activation_layers.py b/hunyuanvideo_foley/models/nn/activation_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..55414cbd054546263e1217f363e9fe02e846a122 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/activation_layers.py @@ -0,0 +1,44 @@ +import torch.nn as nn +import torch.nn.functional as F + +def get_activation_layer(act_type): + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + # Approximate `tanh` requires torch >= 1.13 + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") + +class SwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + out_dim: int, + ): + """ + Initialize the SwiGLU FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + + Attributes: + w1: Linear transformation for the first layer. + w2: Linear transformation for the second layer. + w3: Linear transformation for the third layer. + + """ + super().__init__() + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, out_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/hunyuanvideo_foley/models/nn/attn_layers.py b/hunyuanvideo_foley/models/nn/attn_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c954eebb33038369f3c7d71cda4fdd9c3a8d27dd --- /dev/null +++ b/hunyuanvideo_foley/models/nn/attn_layers.py @@ -0,0 +1,546 @@ +import importlib.metadata +import math +from typing import Tuple, Union +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +try: + from flash_attn import ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + ) + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +except ImportError: + flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None + index_first_axis = None +from packaging import version +from transformers.utils.import_utils import _is_package_available + +from .norm_layers import get_norm_layer + +def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Notes: + When using FlashMHAModified, head_first should be False. + When using Attention, head_first should be True. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +class BasicAttentionLayer(nn.Module): + def __init__(self, attn_mode="flash", deterministic=False): + super().__init__() + self.attn_mode = attn_mode + self.deterministic = deterministic + + def set_attn_mode(self, new_mode): + self.attn_mode = new_mode + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + +MEMORY_LAYOUT = { + "self_flash": ( + lambda x: x, + lambda x: x, + ), + "cross_flash": ( + lambda x: x, + lambda x: x, + ), + "flash_torch_sp": ( + lambda x: x, + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +# Copyed from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6 +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copyed from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822 +def is_flash_attn_greater_or_equal(library_version: str): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) + + +def get_kv_seqlens_with_mask(attn_mask, k, v): + indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask) + b, s1, a, d = k.shape + k = index_first_axis(k.reshape(b * s1, a, d), indices_k) + v = index_first_axis(v.reshape(b * s1, a, d), indices_k) + kv = torch.stack([k, v], dim=1) + return cu_seqlens_k, max_seqlen_k, kv + + +def get_q_seqlens(q): + bs, s, a, d = q.shape + cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device) + q = q.reshape(bs * s, a, d) + return cu_seqlens_q, s, q + +def flash_attn_no_pad( + qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None +): + # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27 + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + nheads = qkv.shape[-2] + x = rearrange(qkv, "b s three h d -> b s (three h d)") + # x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch + # x_unpad, indices, cu_seqlens, max_s + unpad_results = unpad_input( + x, key_padding_mask + ) + + if len(unpad_results) == 4: + x_unpad, indices, cu_seqlens, max_s = unpad_results + elif len(unpad_results) == 5: + x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results + else: + raise ValueError + + x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, + cu_seqlens, + max_s, + dropout_p, + softmax_scale=softmax_scale, + causal=causal, + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen + ), + "b s (h d) -> b s h d", + h=nheads, + ) + return output + + +def attention( + q, + k, + v, + mode, + drop_rate=0, + attn_mask=None, + cond_mask=None, + causal=False, + deterministic=False, + cu_seqlens=None, + max_seqlen=None, + cu_seqlens_k=None, + max_seqlen_k=None, + img_seq_len=None, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + deterministic (bool): Whether to use deterministic attention. (default: False) + cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + max_seqlen (int): The maximum sequence length in the batch of q. + cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_k (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + if mode in ["torch", "vanilla", "self_flash", "cross_flash"]: + if isinstance(q, tuple): + q = torch.cat(q, dim=1) + if isinstance(k, tuple): + k = torch.cat(k, dim=1) + if isinstance(v, tuple): + v = torch.cat(v, dim=1) + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + q = pre_attn_layout(q) + k = pre_attn_layout(k) + v = pre_attn_layout(v) + + if "flash" in mode: + assert ( + flash_attn_qkvpacked_func is not None + ), "Flash attention is not available. Please install flash_attn first." + flash_kwargs = dict(dropout_p=drop_rate, causal=causal) + if deterministic: + if not is_flash_attn_greater_or_equal("2.4.1"): + raise ValueError( + "Flash attention deterministic mode requires flash_attn>=2.4.1. " "Please upgrade flash_attn" + ) + flash_kwargs["deterministic"] = deterministic + + if mode == "self_flash": + qkv = torch.stack([q, k, v], dim=2) + if attn_mask is not None: + raise ValueError("Self attention does not support attention mask") + x = flash_attn_qkvpacked_func(qkv, **flash_kwargs) + + elif mode == "cross_flash": + kv = torch.stack([k, v], dim=2) + if attn_mask is None: + x = flash_attn_kvpacked_func(q, kv, **flash_kwargs) + else: + b, s, a, h = q.shape + cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q) + cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v) + + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + **flash_kwargs, + ) + x = attn_output.reshape(b, s, a, h) + elif mode == 'torch': + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + # TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + if mode in ["torch", "vanilla", "self_flash", "cross_flash"]: + x = post_attn_layout(x).contiguous() + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class SelfAttentionLayer(BasicAttentionLayer): + def __init__( + self, + dim, + num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=0, + proj_drop=0, + dtype=None, + device=None, + norm_type="layer", + attn_mode="self_flash", + deterministic=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(attn_mode, deterministic) + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, "dim must be divisible by num_heads" + self.head_dim = self.dim // num_heads + self.attn_drop = attn_drop + + # This assertion is aligned with flash attention + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs) + + norm_layer = get_norm_layer(norm_type) + self.q_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis=None, attn_mask=None): + """ + Args: + x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim) + freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image + attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention + """ + b, s, d = x.shape + + # Apply QKV projection + qkv = self.Wqkv(x) + qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, a, d] + q, k, v = qkv.unbind(dim=2) # [b, s, a, d] + + # Apply QK-Norm if needed + q = self.q_norm(q) + k = self.k_norm(k) + + # Apply RoPE if needed + if freqs_cis is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis) + assert ( + qq.shape == q.shape and kk.shape == k.shape + ), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}" + q, k = qq, kk + + # Apply self attention + context = attention( + q, + k, + v, + drop_rate=self.attn_drop if self.training else 0, + attn_mask=attn_mask, + mode=self.attn_mode, + deterministic=self.deterministic, + ) + out = self.out_proj(context) + out = self.proj_drop(out) + + return out + + +class CrossAttentionLayer(BasicAttentionLayer): + def __init__( + self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=0, + proj_drop=0, + dtype=None, + device=None, + norm_type="layer", + attn_mode="cross_flash", + deterministic=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(attn_mode, deterministic) + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + self.attn_drop = attn_drop + + # This assertion is aligned with flash attention + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + norm_layer = get_norm_layer(norm_type) + self.q_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, attn_mask=None): + """ + Args: + x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim) + y (torch.Tensor): (batch, seq_len1, hidden_dim1) + attn_mask (torch.Tensor): (batch, seq_len1), mask for attention + """ + b, s, d = x.shape + _, s1, d1 = y.shape + + q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim) + kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(dim=2) + + # Apply QK-Norm if needed + q = self.q_norm(q) + k = self.k_norm(k) + + # Apply cross attention + context = attention( + q, + k, + v, + attn_mask=attn_mask, + drop_rate=self.attn_drop if self.training else 0, + mode=self.attn_mode, + deterministic=self.deterministic, + ) + out = self.out_proj(context) + out = self.proj_drop(out) + + return out diff --git a/hunyuanvideo_foley/models/nn/embed_layers.py b/hunyuanvideo_foley/models/nn/embed_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..fd15167836fd5a5aec7b0c21af296082d7d1b2f4 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/embed_layers.py @@ -0,0 +1,136 @@ +import math +import torch +import torch.nn as nn + +from ...utils.helper import to_2tuple, to_1tuple + +class PatchEmbed1D(nn.Module): + """1D Audio to Patch Embedding + + A convolution based approach to patchifying a 1D audio w/ embedding projection. + + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + """ + + def __init__( + self, + patch_size=1, + in_chans=768, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + patch_size = to_1tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + + self.proj = nn.Conv1d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs + ) + nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) + if bias: + nn.init.zeros_(self.proj.bias) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + assert ( + x.shape[2] % self.patch_size[0] == 0 + ), f"The patch_size of {self.patch_size[0]} must be divisible by the token number ({x.shape[2]}) of x." + + x = self.proj(x) + if self.flatten: + x = x.transpose(1, 2) # BCN -> BNC + x = self.norm(x) + return x + + +class ConditionProjection(nn.Module): + """ + Projects condition embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) + self.act_1 = act_layer() + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None + ): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.normal_(self.mlp[2].weight, std=0.02) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/hunyuanvideo_foley/models/nn/mlp_layers.py b/hunyuanvideo_foley/models/nn/mlp_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..0434559025e280b74fbd7c181247aa1c2a41a409 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/mlp_layers.py @@ -0,0 +1,149 @@ +# Modified from timm library: +# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 + +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modulate_layers import modulate +from ...utils.helper import to_2tuple + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +# only used when use_vanilla is True +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class LinearWarpforSingle(nn.Module): + def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs) + + def forward(self, x, y): + z = torch.cat([x, y], dim=2) + return self.fc(z) + +class FinalLayer1D(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + # Just use LayerNorm for the final layer + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + # Here we don't distinguish between the modulate types. Just use the simple one. + self.adaLN_modulation = nn.Sequential( + act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x + + +class ChannelLastConv1d(nn.Conv1d): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) + x = super().forward(x) + x = x.permute(0, 2, 1) + return x + + +class ConvMLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + device=None, + dtype=None, + ): + """ + Convolutional MLP module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1: Linear transformation for the first layer. + w2: Linear transformation for the second layer. + w3: Linear transformation for the third layer. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) + self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) + self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/hunyuanvideo_foley/models/nn/modulate_layers.py b/hunyuanvideo_foley/models/nn/modulate_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..5235233e996fc17c0476de04428a13bfcf3ba8fe --- /dev/null +++ b/hunyuanvideo_foley/models/nn/modulate_layers.py @@ -0,0 +1,49 @@ +from typing import Callable +import torch +import torch.nn as nn + +class ModulateDiT(nn.Module): + def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) + # Zero-initialize the modulation + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(x)) + + +def modulate(x, shift=None, scale=None): + if x.ndim == 3: + shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None + scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale) + elif scale is None: + return x + shift + else: + return x * (1 + scale) + shift + + +def apply_gate(x, gate=None, tanh=False): + if gate is None: + return x + if gate.ndim == 2 and x.ndim == 3: + gate = gate.unsqueeze(1) + if tanh: + return x * gate.tanh() + else: + return x * gate + + +def ckpt_wrapper(module): + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward diff --git a/hunyuanvideo_foley/models/nn/norm_layers.py b/hunyuanvideo_foley/models/nn/norm_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad30b0ea4faeaa18e22ed25fcc44b97aee2d243 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/norm_layers.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn + +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, + device=None, dtype=None): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") diff --git a/hunyuanvideo_foley/models/nn/posemb_layers.py b/hunyuanvideo_foley/models/nn/posemb_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd188bbeb710d8155e758590446a51f5f0dd038 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/posemb_layers.py @@ -0,0 +1,159 @@ +import torch +from typing import Union, Tuple + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] + """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 + + +def get_nd_rotary_pos_embed( + rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0 +): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. + freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. + + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor, + freq_scaling=freq_scaling, + ) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + freq_scaling: float = 1.0, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 1)) + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + freqs *= freq_scaling + freqs = torch.outer(pos, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis diff --git a/hunyuanvideo_foley/models/synchformer/__init__.py b/hunyuanvideo_foley/models/synchformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6331d467e30b9d24378025bf540e7430ff4fd7ad --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/__init__.py @@ -0,0 +1 @@ +from .synchformer import Synchformer diff --git a/hunyuanvideo_foley/models/synchformer/ast_model.py b/hunyuanvideo_foley/models/synchformer/ast_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4394306ccd08de3a1e6bb556df8f42d2e4cacb --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/ast_model.py @@ -0,0 +1,289 @@ +import logging + +import torch +from transformers.modeling_outputs import BaseModelOutputWithPooling + +from .modeling_ast import ASTForAudioClassification, ASTConfig +from .motionformer import AveragePooling, BaseEncoderLayer, TemporalTransformerEncoderLayer +from .utils import check_if_file_exists_else_download + + +class AST(torch.nn.Module): + def __init__( + self, + extract_features: bool = False, + ckpt_path: str = None, + feat_type: str = None, + max_spec_t: int = None, + factorize_freq_time: bool = None, + agg_freq_module: str = None, + agg_time_module: str = None, + add_global_repr: bool = True, + agg_segments_module: str = None, + max_segments: int = None, + ) -> None: + """ + extract_features: if True, then the model will return the features instead of head's output + ckpt_path: is not a path to a ckpt file, but a name of a model from the HuggingFace model hub. + feat_type: if extract_features is True, this parameter specifies the type of features to return + max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec + factorize_freq_time: if True, then the model will use a factorized freq/time aggregation + agg_freq_module: if specified, then the model will use this module for freq aggregation + agg_time_module: if specified, then the model will use this module for time aggregation + add_global_repr: if True, adds a global representation to the features (aggregation on segments) + agg_segments_module: if specified, then the model will use this module for segments aggregation + max_segments: if specified, the initialization of PE in the global agg module will use this value. + This should correspond to the max number of segments per video (if None, 16 is used) + """ + super().__init__() + self.extract_features = extract_features + self.ckpt_path = ckpt_path + self.max_spec_t = max_spec_t + self.max_segments = max_segments + + # depending on whether the feat extractor was pre-trained contrastively or not, we need to + # load the state dict differently. + + # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model + if ckpt_path == "MIT/ast-finetuned-audioset-10-10-0.4593": + revision = "c1c0c66" # fixing the revision for compatibility (V4.27.4) + self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision) + full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision) + logging.info(f"Loaded AST from {ckpt_path}") + else: + self.config = ASTConfig() + self.config.num_labels = 527 # 2 by default, audioset has 527 labels + full_model = ASTForAudioClassification(self.config) + logging.info("Initialized AST from scratch with the AST AudioSet config") + + was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith(".pt") + + # feature extractor + self.ast = full_model.audio_spectrogram_transformer + + if self.extract_features: + # assign `feat_type` (use default if not specified) + self.feat_type = "last_hidden_state" if feat_type is None else feat_type + # define adapters if needed + self.factorize_freq_time = factorize_freq_time + # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer) + transf_enc_layer_kwargs = dict( + d_model=self.config.hidden_size, + nhead=self.config.num_attention_heads, + dim_feedforward=self.config.intermediate_size, + activation=torch.nn.GELU(), + batch_first=True, + dropout=self.config.attention_probs_dropout_prob, + layer_norm_eps=1e-6, + norm_first=True, + ) + if factorize_freq_time: + self.feat_type = "last_hidden_state" # this feat_type supports factorization + # frequency aggreration + if agg_freq_module == "TransformerEncoderLayer": + self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs) + elif agg_freq_module == "AveragePooling": + self.freq_attn_agg = AveragePooling( + avg_pattern="BS D f t -> BS D t", then_permute_pattern="BS D t -> BS t D" + ) + # time aggreration + if agg_time_module == "TransformerEncoderLayer": + self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs) + elif agg_time_module == "AveragePooling": + self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D") + elif "Identity" in agg_time_module: + self.temp_attn_agg = torch.nn.Identity() + # define a global aggregation layer (aggregarate over segments) + self.add_global_repr = add_global_repr + if add_global_repr: + if agg_segments_module == "TransformerEncoderLayer": + # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D) + # we need to add pos emb (PE) because previously we added the same PE for each segment + pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1 + self.global_attn_agg = TemporalTransformerEncoderLayer( + add_pos_emb=True, + pos_emb_drop=self.config.hidden_dropout_prob, + pos_max_len=pos_max_len, + **transf_enc_layer_kwargs, + ) + elif agg_segments_module == "AveragePooling": + self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D") + else: + self.classifier = full_model.classifier + + # AST.device fails with AttributeError. This is a workaround + self.device = full_model.device + + # pre-trained on 12*101+2=1214 tokens, but we have less (e.g. 12*6+2=74) + self.patch_position_emb() + + if was_pt_on_avclip: + # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors) + # and keep only the state_dict of the feat extractor + check_if_file_exists_else_download(self.ckpt_path) + ckpt = torch.load(ckpt_path, map_location="cpu") + ckpt_weights = dict() + for k, v in ckpt["state_dict"].items(): + if k.startswith(("module.a_encoder.", "a_encoder.")): + k = k.replace("module.", "").replace("a_encoder.", "") + ckpt_weights[k] = v + _load_status = self.load_state_dict(ckpt_weights, strict=False) + if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0: + logging.warning( + f"Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n" + f"Missing keys ({len(_load_status.missing_keys)}): " + f"{_load_status.missing_keys}, \n" + f"Unexpected keys ({len(_load_status.unexpected_keys)}): " + f"{_load_status.unexpected_keys} \n" + f"temp_attn_agg are expected to be missing if ckpt was pt contrastively." + ) + else: + logging.info(f"Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.") + + # print the number of parameters + logging.info(f"AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}") + + def forward( + self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs + ) -> torch.Tensor: + """ + x: (B, S, T, F) where S is number of segments, F is number of (mel) frequency bins, + ast_kwargs: additional arguments for the AST model + cont_mask: (B, S, T, F) where 0s are the values to be masked out + if `for_loop=True`, we use a for loop to extract features for each segment separately. + if `for_loop=False`, we extract features for all segments at once. + Using the for loop is slower but more memory efficient, while using all segments at once + is faster but more memory inefficient. + Using for loop allows to control the memory footprint by varying the number of videos in a + batch (batch size) rather than the number of segments in a video. + """ + B, S, T, F = x.shape + + if for_loop: + assert cont_mask is None, "cont_mask is not supported with for_loop=True" + orig_shape_s = (B, 1, T, F) + # NOTE: since x is (B, S, T, F), and forward_segments expects (BS, T, F). + # (B, S, T, F)[:, s] is (B, T, F) or (BS, T, F) if S=1. + x = torch.cat( + [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1 + ) + else: + orig_shape = (B, S, T, F) + x = x.view(B * S, T, F) + if cont_mask is not None: + cont_mask = cont_mask.reshape(B * S, T, F) + # AST expects a tensor of shape (B*S, T, F). + x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs) + # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D)) + x = x.view(B, S, *x.shape[1:]) + # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity` + + global_x = None + if self.extract_features and self.add_global_repr: # lazy execution, throws AttributeError + assert len(x.shape) == 3, f"Local representation should be (B, S, D) {x.shape}" + global_x = self.global_attn_agg(x) # (B, D) + + return x, global_x # x is (B, S, ...), global_x is (B, D) or None + + def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs): + """x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out""" + # 'pooler_output': (B, D); or 'last_hidden_state: (B, T, D) where T is [CLS, DISTILL, ] + # x_mask is (B, T) where 0s are the values to be masked out + x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs) + + if self.extract_features: + x = self.get_features_by_type(x) + if self.factorize_freq_time: + x = self.restore_freq_temp_dims(x, orig_shape) # (BS, D, f, t) <- (B*S, T, D) + if cont_mask is not None: + # duplicating the mask for the latent dimension (D) to be compatible with the next func + x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size) + x_mask = self.restore_freq_temp_dims(x_mask, orig_shape) # (BS, D, f, t) <- (B*S, T, D) + # again removing the latent + x_mask = x_mask[:, 0, :, :] + else: + x_mask = None + x = self.freq_attn_agg(x, x_mask) # (BS, t, D) + x = self.temp_attn_agg(x) # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity + else: + x = x["pooler_output"] + x = self.classifier(x) + return x + + def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor: + if self.feat_type == "pooler_output": + return x["pooler_output"] # (B, D) + elif self.feat_type == "CLS": + return x["last_hidden_state"][:, 0, :] # (B, D) + elif self.feat_type == "last_hidden_state": + return x["last_hidden_state"] # (B, 2+T, D) + elif self.feat_type == "last_hidden_state_no_AUX": + return x["last_hidden_state"][:, 2:, :] # (B, T, D) removing CLS and distill tokens + else: + raise ValueError(f"Unknown feature type: {self.feat_type}") + + def restore_freq_temp_dims(self, feats, orig_shape: tuple): + """ + feats are of shape (B*S, T, D) + where T = 2 + f * t (if feat_type == 'last_hidden_state') + where T = f * t (if feat_type == 'last_hidden_state_no_AUX') + Our goal is to make them of shape (B*S, f, t, D) where f and t are dimensions after patching. + From `self.ast.embeddings.patch_embeddings`, it follows that we could reshape feats: + `feats.transpose(1, 2).view(B*S, D, f, t)` + + (Similar function is defined in for RGB features in `motionformer.py`) + """ + B, S, T, F = orig_shape + D = self.config.hidden_size + + # num patches in each dimension + f, t = self.ast.embeddings.get_shape(self.config) + + if self.feat_type == "last_hidden_state": + feats = feats[:, 2:, :] # removing CLS and distill tokens + + feats = feats.permute(0, 2, 1) # (B*S, D, T) + feats = feats.view(B * S, D, f, t) # (B*S, D, f, t) + + return feats + + def patch_position_emb(self): + if self.max_spec_t is not None: + self.config.max_length = self.max_spec_t + f, t = self.ast.embeddings.get_shape(self.config) + shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone() # +2 for CLS and distill tokens + self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device) + + def to(self, device): + """AST.device fails with AttributeError. This is a workaround.""" + self.device = torch.device(device) + return super().to(device) + + +class FrequencyTransformerEncoderLayer(BaseEncoderLayer): + """This layer is used to aggregate the features along the frequency axis. + It follows the same logic as spatio-temporal aggregation in visual feature extractor. + Thus, it is recommended to check the definition of `BaseEncoderLayer` in `motionformer.py`""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + """x: (B*S, D, f, t); if specified x_mask (B*S, f, t), 0s are the values to be masked out""" + BS, D, f, t = x.shape + + # time as a batch dimension + x = x.permute(0, 3, 2, 1) # (B*S, t, f, D) + x = x.reshape(BS * t, f, D) # .view() fails with non-contiguous memory + # similar to mask + if x_mask is not None: + x_mask = x_mask.permute(0, 2, 1) # (B*S, t, f) + x_mask = x_mask.reshape(BS * t, f) + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D) + + # reshape back to (B*S, t, D) + x = x.view(BS, t, D) + + return x # (B*S, t, D) diff --git a/hunyuanvideo_foley/models/synchformer/compute_desync_score.py b/hunyuanvideo_foley/models/synchformer/compute_desync_score.py new file mode 100644 index 0000000000000000000000000000000000000000..936c4994bd6d68444980959f78aa710e4ba5a205 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/compute_desync_score.py @@ -0,0 +1,214 @@ +import argparse +import subprocess +from pathlib import Path + +import torch +import torchaudio +import torchvision +from omegaconf import OmegaConf + +import data_transforms +from .synchformer import Synchformer +from .data_transforms import make_class_grid, quantize_offset +from .utils import check_if_file_exists_else_download, which_ffmpeg + + +def prepare_inputs(batch, device): + aud = batch["audio"].to(device) + vid = batch["video"].to(device) + + return aud, vid + + +def get_test_transforms(): + ts = [ + data_transforms.EqualifyFromRight(), + data_transforms.RGBSpatialCrop(input_size=224, is_random=False), + data_transforms.TemporalCropAndOffset( + crop_len_sec=5, + max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml + max_wiggle_sec=0.0, + do_offset=True, + offset_type="grid", + prob_oos="null", + grid_size=21, + segment_size_vframes=16, + n_segments=14, + step_size_seg=0.5, + vfps=25, + ), + data_transforms.GenerateMultipleSegments( + segment_size_vframes=16, + n_segments=14, + is_start_random=False, + step_size_seg=0.5, + ), + data_transforms.RGBToHalfToZeroOne(), + data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization + data_transforms.AudioMelSpectrogram( + sample_rate=16000, + win_length=400, # 25 ms * 16 kHz + hop_length=160, # 10 ms * 16 kHz + n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate))) + n_mels=128, # as in AST + ), + data_transforms.AudioLog(), + data_transforms.PadOrTruncate(max_spec_t=66), + data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet + data_transforms.PermuteStreams( + einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same + ), + ] + transforms = torchvision.transforms.Compose(ts) + + return transforms + + +def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None): + orig_path = path + # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta) + rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW") + assert meta["video_fps"], f"No video fps for {orig_path}" + # (Ta) <- (Ca, Ta) + audio = audio.mean(dim=0) + # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader. + meta = { + "video": {"fps": [meta["video_fps"]]}, + "audio": {"framerate": [meta["audio_fps"]]}, + } + return rgb, audio, meta + + +def reencode_video(path, vfps=25, afps=16000, in_size=256): + assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated." + new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4" + new_path.parent.mkdir(exist_ok=True) + new_path = str(new_path) + cmd = f"{which_ffmpeg()}" + # no info/error printing + cmd += " -hide_banner -loglevel panic" + cmd += f" -y -i {path}" + # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate + cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2" + cmd += f" -ar {afps}" + cmd += f" {new_path}" + subprocess.call(cmd.split()) + cmd = f"{which_ffmpeg()}" + cmd += " -hide_banner -loglevel panic" + cmd += f" -y -i {new_path}" + cmd += f" -acodec pcm_s16le -ac 1" + cmd += f' {new_path.replace(".mp4", ".wav")}' + subprocess.call(cmd.split()) + return new_path + + +def decode_single_video_prediction(off_logits, grid, item): + label = item["targets"]["offset_label"].item() + print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})") + print() + print("Prediction Results:") + off_probs = torch.softmax(off_logits, dim=-1) + k = min(off_probs.shape[-1], 5) + topk_logits, topk_preds = torch.topk(off_logits, k) + # remove batch dimension + assert len(topk_logits) == 1, "batch is larger than 1" + topk_logits = topk_logits[0] + topk_preds = topk_preds[0] + off_logits = off_logits[0] + off_probs = off_probs[0] + for target_hat in topk_preds: + print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})') + return off_probs + + +def main(args): + vfps = 25 + afps = 16000 + in_size = 256 + # making the offset class grid similar to the one used in transforms, + # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml + max_off_sec = 2 + num_cls = 21 + + # checking if the provided video has the correct frame rates + print(f"Using video: {args.vid_path}") + v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec") + _, H, W, _ = v.shape + if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size: + print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ") + print(f'afps: {info["audio_fps"]} -> {afps};', end=" ") + print(f"{(H, W)} -> min(H, W)={in_size}") + args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size) + else: + print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}') + + device = torch.device(args.device) + + # load visual and audio streams + # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1] + rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True) + + # making an item (dict) to apply transformations + # NOTE: here is how it works: + # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3` + # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio + # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will + # start by `offset_sec` earlier than the rgb track. + # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`) + item = dict( + video=rgb, + audio=audio, + meta=meta, + path=args.vid_path, + split="test", + targets={ + "v_start_i_sec": args.v_start_i_sec, + "offset_sec": args.offset_sec, + }, + ) + + grid = make_class_grid(-max_off_sec, max_off_sec, num_cls) + if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)): + print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}') + + # applying the test-time transform + item = get_test_transforms()(item) + + # prepare inputs for inference + batch = torch.utils.data.default_collate([item]) + aud, vid = prepare_inputs(batch, device) + + # TODO: + # sanity check: we will take the input to the `model` and recontruct make a video from it. + # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified) + # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec, + # vfps, afps) + + # forward pass + with torch.set_grad_enabled(False): + with torch.autocast("cuda", enabled=True): + _, logits = synchformer(vid, aud) + + # simply prints the results of the prediction + decode_single_video_prediction(logits, grid, item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx") + parser.add_argument("--vid_path", required=True, help="A path to .mp4 video") + parser.add_argument("--offset_sec", type=float, default=0.0) + parser.add_argument("--v_start_i_sec", type=float, default=0.0) + parser.add_argument("--device", default="cuda:0") + args = parser.parse_args() + + synchformer = Synchformer().cuda().eval() + synchformer.load_state_dict( + torch.load( + os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), + weights_only=True, + map_location="cpu", + ) + ) + + main(args) diff --git a/hunyuanvideo_foley/models/synchformer/data_transforms.py b/hunyuanvideo_foley/models/synchformer/data_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d59331eb3c27454a4c52bdbf8a8b85946c63c0a3 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/data_transforms.py @@ -0,0 +1,1130 @@ +import logging +import math +import random +from typing import Tuple +import torch +import torchvision +import torchaudio +import numpy as np +import einops + + +def sec2frames(sec, fps): + return int(sec * fps) + + +def frames2sec(frames, fps): + return frames / fps + + +class EqualifyFromRight(torch.nn.Module): + + def __init__(self, clip_max_len_sec=10): + """ + Takes the dataset item and makes sure more streams are of an equal size in terms of fps. + It, however, assumes that the signal is synched and trims the ending parts ('from the right'). + """ + super().__init__() + self.clip_max_len_sec = clip_max_len_sec + + def forward(self, item): + """ + `item`: {'video': (Tv, C, H, W), 'audio': (Ta,), + 'meta': { + 'audio': {'framerate': [float], 'duration': [float]} + 'video': {'fps': [float], 'duration': [float]}} + """ + a_fps = item["meta"]["audio"]["framerate"][0] + v_fps = item["meta"]["video"]["fps"][0] + + Ta = item["audio"].shape[0] + Tv, C, H, W = item["video"].shape + + a_len_secs = Ta / a_fps + v_len_secs = Tv / v_fps + min_len = min(self.clip_max_len_sec, a_len_secs, v_len_secs) + + a_frames_per_v_frame = a_fps // v_fps + v_len_frames = int(v_fps * min_len) + a_len_frames = int(a_frames_per_v_frame * v_len_frames) + # print(a_len_frames, v_len_frames) + + assert a_len_frames <= Ta and v_len_frames <= Tv + + item["audio"] = item["audio"][:a_len_frames] + item["video"] = item["video"][:v_len_frames, :, :, :] + + return item + + +class RGBSpatialCrop(torch.nn.Module): + + def __init__(self, input_size, is_random): + super().__init__() + assert input_size is not None, f"smaller_input_size is `{input_size}`" + if isinstance(input_size, int): + input_size = (input_size, input_size) + self.input_size = input_size + self.is_random = is_random + + @staticmethod + def get_random_crop_sides(vid, output_size): + """Slice parameters for random crop""" + h, w = vid.shape[-2:] + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + @staticmethod + def get_center_crop_sides(vid, output_size): + """Slice parameters for center crop""" + h, w = vid.shape[-2:] + th, tw = output_size + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return i, j, th, tw + + def forward(self, item): + # (Tv, C, H, W) + vid = item["video"] + if self.is_random: + i, j, h, w = self.get_random_crop_sides(vid, self.input_size) + else: + i, j, h, w = self.get_center_crop_sides(vid, self.input_size) + item["video"] = vid[..., i : (i + h), j : (j + w)] + return item + + +class Resize(torchvision.transforms.Resize): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, item): + item["video"] = super().forward(item["video"]) + return item + + +class RGBSpatialCropSometimesUpscale(torch.nn.Module): + """This (randomly) crops the input video and with prob `sometimes_p` this crop is smaller but upscaled + to `target_input_size`""" + + def __init__(self, sometimes_p, target_input_size, is_random, smaller_input_size=None): + super().__init__() + self.sometimes_p = sometimes_p + self.do_sometimes_upscale = sometimes_p is not None and sometimes_p > 0 + + self.crop_only = RGBSpatialCrop(target_input_size, is_random) + + if self.do_sometimes_upscale: + self.crop_further_and_upscale = torchvision.transforms.Compose( + [ + RGBSpatialCrop(smaller_input_size, is_random), + Resize(target_input_size, antialias=None), + ] + ) + + def forward(self, item): + assert len(item["video"].shape) == 4, ( + f"{item['video'].shape}: if it is applied after GenerateMultipleClips," + "augs should be applied to each clip separately, not to the whole video array. " + "Otherwise, ignore this warning (comment it)." + ) + if self.do_sometimes_upscale and self.sometimes_p > torch.rand(1): + return self.crop_further_and_upscale(item) + else: + return self.crop_only(item) + + +class RandomApplyColorDistortion(torch.nn.Module): + + def __init__(self, p_gray_scale=0.0, p_color_jitter=0.0, s=1.0) -> None: + super().__init__() + self.p_gray_scale = p_gray_scale + self.p_color_jitter = p_color_jitter + self.s = s + assert 0 <= self.p_color_jitter <= 1 and 0 <= self.p_gray_scale <= 1, (p_color_jitter, p_gray_scale) + # SimCLR params + color_jitter = torchvision.transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) + rand_color_jitter = torchvision.transforms.RandomApply([color_jitter], p_color_jitter) + rand_gray = torchvision.transforms.RandomGrayscale(p_gray_scale) + self.transforms = torchvision.transforms.Compose([rand_color_jitter, rand_gray]) + + def apply_to_single_clip(self, clip): + return self.transforms(clip) + + def apply_to_each_clip(self, clips): + for i, clip in enumerate(clips): + clips[i] = self.apply_to_single_clip(clip) + return clips + + def forward(self, item): + has_batch_dim = len(item["video"].shape) == 5 + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["video"] = fn(item["video"]) + return item + + +class ApplyColorJitterFrameWise(torch.nn.Module): + + def __init__(self, s=1.0) -> None: + super().__init__() + self.s = s + # SimCLR params + self.transform = torchvision.transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) + + def apply_to_single_clip(self, clip): + for i, frame in enumerate(clip): + clip[i] = self.transform(frame) + return clip + + def apply_to_each_clip(self, clips): + for i, clip in enumerate(clips): + clips[i] = self.apply_to_single_clip(clip) + return clips + + def forward(self, item): + has_batch_dim = len(item["video"].shape) == 5 + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["video"] = fn(item["video"]) + return item + + +class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip): + + def __init__(self, p=0.5): + super().__init__(p) + + def apply_to_single_clip(self, clip): + return super().forward(clip) + + def apply_to_each_clip(self, clips): + for i, clip in enumerate(clips): + clips[i] = self.apply_to_single_clip(clip) + return clips + + def forward(self, item): + has_batch_dim = len(item["video"].shape) == 5 + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["video"] = fn(item["video"]) + return item + + +def make_class_grid( + leftmost_val, + rightmost_val, + grid_size, + add_extreme_offset: bool = False, + seg_size_vframes: int = None, + nseg: int = None, + step_size_seg: float = None, + vfps: float = None, +): + assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()" + grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float() + if add_extreme_offset: + assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}" + seg_size_sec = seg_size_vframes / vfps + trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1) + extreme_value = trim_size_in_seg * seg_size_sec + grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid + return grid + + +def quantize_offset(grid: torch.Tensor, off_sec: float) -> Tuple[float, int]: + """Takes in the offset in seconds and snaps it onto the closest grid element. + Returns the grid value and its index.""" + closest_grid_el = (grid - off_sec).abs().argmin() + return grid[closest_grid_el], closest_grid_el + + +def apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps, max_a_jitter_sec): + max_a_start_i = a_len_frames - a_crop_len_frames + max_a_jitter_i = sec2frames(max_a_jitter_sec, a_fps) + max_a_jitter_i_left = min(a_start_i, max_a_jitter_i) + max_a_jitter_i_right = min(max_a_start_i - a_start_i, max_a_jitter_i) + # jitter is U[left, right] + a_jitter_i = random.randint(-max_a_jitter_i_left, max_a_jitter_i_right) + # apply jitter + a_start_i = a_start_i + a_jitter_i + # making sure that any value from `a_start_i + U[left, right]` will be inside of [0, len-crop] region + assert 0 <= a_start_i <= max_a_start_i, f"{a_jitter_i} {max_a_jitter_i_left} {max_a_jitter_i_right} {max_a_start_i}" + return a_start_i, a_jitter_i + + +class TemporalCropAndOffset(torch.nn.Module): + + def __init__( + self, + crop_len_sec: float, + max_off_sec: float, + offset_type="grid", + do_offset: bool = True, + grid_size: int = None, + max_wiggle_sec: float = None, + add_doubt_cls: bool = False, + segment_size_vframes: int = None, + n_segments: int = None, + step_size_seg: float = None, + vfps: float = None, + prob_oos: float = None, + ): + super().__init__() + self.crop_len_sec = crop_len_sec + self.do_offset = do_offset + self.grid_size = grid_size + self.offset_type = offset_type + self.max_off_sec = max_off_sec + self.max_a_jitter_sec = max_wiggle_sec + if do_offset: + if offset_type == "grid": + self.class_grid = make_class_grid( + -max_off_sec, + max_off_sec, + grid_size, + add_doubt_cls, + segment_size_vframes, + n_segments, + step_size_seg, + vfps, + ) + logging.info(f"Offsets class grid: {self.class_grid}") + if self.max_a_jitter_sec is not None: + assert (max_wiggle_sec - 1e-6) <= ( + (self.class_grid[1] - self.class_grid[0]) / 2 + ), f"{self.class_grid}" + elif offset_type == "uniform": + self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec) + logging.info(f"Offset uniform distribution: {self.off_dist}") + elif offset_type == "uniform_binary": + self.itu_t_range = (-0.125, 0.045) + self.prob_oos = prob_oos + self.ins_dist = torch.distributions.uniform.Uniform(self.itu_t_range[0], self.itu_t_range[1]) + self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec) + else: + raise NotImplementedError(f"Unknown offset type: {offset_type}") + + def forward(self, item): + vid = item["video"] + aud = item["audio"] + v_len_frames, C, H, W = vid.shape + a_len_frames = aud.shape[0] + + v_fps = int(item["meta"]["video"]["fps"][0]) + a_fps = int(item["meta"]["audio"]["framerate"][0]) + + v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps) + a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps) + + if self.do_offset: + # trying to get the offset parameters (for instance during valid and test we have fixed offsets) + offset_sec = item["targets"].get("offset_sec", None) + v_start_i_sec = item["targets"].get("v_start_i_sec", None) + if "offset_target" in item["targets"]: + is_oos = item["targets"]["offset_target"].get("oos", None) + # train-time + if offset_sec is None and v_start_i_sec is None: + # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec + if self.offset_type == "grid": + offset_sec = random.choice(self.class_grid.tolist()) + elif self.offset_type == "uniform": + offset_sec = self.off_dist.sample().item() + elif self.offset_type == "uniform_binary": + # in-sync: Uniform(-0.125, 0.045) + # out-of-sync: Uniform(-5.5, 5.5) and resampled until not in Uniform(-0.125, 0.045) + # first, we sample if the offset is out-of-sync with prob_oss + is_oos = (torch.rand(1) < self.prob_oos).item() + if is_oos: + # second, we sample the offset itself (if in in-sync range, trying again) + offset_sec = self.off_dist.sample().item() + while self.itu_t_range[0] <= offset_sec <= self.itu_t_range[1]: + offset_sec = self.off_dist.sample().item() + else: + offset_sec = self.ins_dist.sample().item() + offset_sec = round(offset_sec, 2) + v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps) + assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}' + # `v_start_sec` IS NOT rounded to the fps grid + v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec - offset_sec)) + assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}' + v_start_i = sec2frames(v_start_sec, v_fps) + # `v_start_i_sec` IS rounded to the fps grid + v_start_i_sec = frames2sec(v_start_i, v_fps) + else: + offset_sec = round(offset_sec, 2) + v_start_i = sec2frames(v_start_i_sec, v_fps) + v_end_i = v_start_i + v_crop_len_frames + # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise + # (v_start_sec) we have ±0.1 jittering + a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps) + else: + offset_sec = 0.0 + is_random_crop = item["split"] == "train" + v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop) + v_start_i_sec = frames2sec(v_start_i, v_fps) + a_start_i = sec2frames(v_start_i_sec, a_fps) + + # sometimes due to the rounding error e.g. v_start_sec = 1.505 but sec2frames(1.505, 25) = 1.48 + # given offset is -1.5, the a_start_i will be a small negative value. (likely a_fps * 1/v_fps * 0.5) + if a_start_i < 0: + how_much_out = a_start_i + logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}') + if abs(how_much_out) <= a_fps / v_fps: + logging.info("fixing it") + a_start_i += abs(how_much_out) + else: + raise Exception(f'{how_much_out} {item["path"]}') + + if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0: + a_start_i, a_jitter_i = apply_a_jitter( + a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec + ) + item["meta"]["a_jitter_i"] = a_jitter_i + + a_end_i = a_start_i + a_crop_len_frames + + assert v_start_i < v_end_i and a_start_i < a_end_i + assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}' + assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}' + + vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i] + + item["video"] = vid + item["audio"] = aud + + assert item["video"].shape[0] == v_fps * self.crop_len_sec, f'{item["video"].shape} {item["path"]}' + assert item["audio"].shape[0] == a_fps * self.crop_len_sec, f'{item["audio"].shape} {item["path"]}' + + # caching parameters + if self.do_offset: + if self.offset_type == "grid": + offset_label, offset_target = quantize_offset(self.class_grid, offset_sec) + elif self.offset_type == "uniform": + offset_label, offset_target = offset_sec, offset_sec + elif self.offset_type == "uniform_binary": + offset_label, offset_target = offset_sec, {"oos": is_oos, "offset": offset_sec} + item["targets"]["offset_sec"] = offset_sec + item["targets"]["v_start_i_sec"] = v_start_i_sec + item["targets"]["offset_label"] = offset_label + # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?' + item["targets"]["offset_target"] = offset_target + + return item + + def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True): + if len_frames == crop_len_frames: + return 0, len_frames + if is_random: + left_i = random.randint(0, len_frames - crop_len_frames) + else: + left_i = int(round((len_frames - crop_len_frames) / 2.0)) + return left_i, left_i + crop_len_frames + + +class GenerateMultipleSegments(torch.nn.Module): + """ + Given an item with video and audio, generates a batch of `n_segments` segments + of length `segment_size_vframes` (if None, the max number of segments will be made). + If `is_start_random` is True, the starting position of the 1st segment will be random but respecting + n_segments. + `audio_jitter_sec` is the amount of audio offset in seconds. + """ + + def __init__( + self, + segment_size_vframes: int, + n_segments: int = None, + is_start_random: bool = False, + audio_jitter_sec: float = 0.0, + step_size_seg: float = 1, + ): + super().__init__() + self.segment_size_vframes = segment_size_vframes + self.n_segments = n_segments + self.is_start_random = is_start_random + self.audio_jitter_sec = audio_jitter_sec + self.step_size_seg = step_size_seg + logging.info(f"Segment step size: {self.step_size_seg}") + + def forward(self, item): + v_len_frames, C, H, W = item["video"].shape + a_len_frames = item["audio"].shape[0] + + v_fps = int(item["meta"]["video"]["fps"][0]) + a_fps = int(item["meta"]["audio"]["framerate"][0]) + + ## Determining the number of segments + # segment size + segment_size_vframes = self.segment_size_vframes + segment_size_aframes = sec2frames(frames2sec(self.segment_size_vframes, v_fps), a_fps) + # step size (stride) + stride_vframes = int(self.step_size_seg * segment_size_vframes) + stride_aframes = int(self.step_size_seg * segment_size_aframes) + # calculating the number of segments. (W - F + 2P) / S + 1 + n_segments_max_v = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1 + n_segments_max_a = math.floor((a_len_frames - segment_size_aframes) / stride_aframes) + 1 + # making sure audio and video can accommodate the same number of segments + n_segments_max = min(n_segments_max_v, n_segments_max_a) + n_segments = n_segments_max if self.n_segments is None else self.n_segments + + assert n_segments <= n_segments_max, ( + f"cant make {n_segments} segs of len {self.segment_size_vframes} in a vid " + f'of len {v_len_frames} for {item["path"]}' + ) + + # (n_segments, 2) each + v_ranges, a_ranges = self.get_sequential_seg_ranges( + v_len_frames, a_len_frames, v_fps, a_fps, n_segments, segment_size_aframes + ) + + # segmenting original streams (n_segments, segment_size_frames, C, H, W) + item["video"] = torch.stack([item["video"][s:e] for s, e in v_ranges], dim=0) + item["audio"] = torch.stack([item["audio"][s:e] for s, e in a_ranges], dim=0) + return item + + def get_sequential_seg_ranges(self, v_len_frames, a_len_frames, v_fps, a_fps, n_seg, seg_size_aframes): + # if is_start_random is True, the starting position of the 1st segment will + # be random but respecting n_segments like so: "-CCCCCCCC---" (maybe with fixed overlap), + # else the segments are taken from the middle of the video respecting n_segments: "--CCCCCCCC--" + + seg_size_vframes = self.segment_size_vframes # for brevity + + # calculating the step size in frames + step_size_vframes = int(self.step_size_seg * seg_size_vframes) + step_size_aframes = int(self.step_size_seg * seg_size_aframes) + + # calculating the length of the sequence of segments (and in frames) + seg_seq_len = n_seg * self.step_size_seg + (1 - self.step_size_seg) + vframes_seg_seq_len = int(seg_seq_len * seg_size_vframes) + aframes_seg_seq_len = int(seg_seq_len * seg_size_aframes) + + # doing temporal crop + max_v_start_i = v_len_frames - vframes_seg_seq_len + if self.is_start_random: + v_start_i = random.randint(0, max_v_start_i) + else: + v_start_i = max_v_start_i // 2 + a_start_i = sec2frames(frames2sec(v_start_i, v_fps), a_fps) # vid frames -> seconds -> aud frames + + # make segments starts + v_start_seg_i = torch.tensor([v_start_i + i * step_size_vframes for i in range(n_seg)]).int() + a_start_seg_i = torch.tensor([a_start_i + i * step_size_aframes for i in range(n_seg)]).int() + + # apply jitter to audio + if self.audio_jitter_sec > 0: + jitter_aframes = sec2frames(self.audio_jitter_sec, a_fps) + # making sure after applying jitter, the audio is still within the audio boundaries + jitter_aframes = min(jitter_aframes, a_start_i, a_len_frames - a_start_i - aframes_seg_seq_len) + a_start_seg_i += random.randint(-jitter_aframes, jitter_aframes) # applying jitter to segments + + # make segment ends + v_ends_seg_i = v_start_seg_i + seg_size_vframes + a_ends_seg_i = a_start_seg_i + seg_size_aframes # using the adjusted a_start_seg_i (with jitter) + + # make ranges + v_ranges = torch.stack([v_start_seg_i, v_ends_seg_i], dim=1) + a_ranges = torch.stack([a_start_seg_i, a_ends_seg_i], dim=1) + assert (a_ranges >= 0).all() and (a_ranges <= a_len_frames).all(), f"{a_ranges} out of {a_len_frames}" + assert (v_ranges <= v_len_frames).all(), f"{v_ranges} out of {v_len_frames}" + return v_ranges, a_ranges + + +class TemporalCropAndOffsetForSyncabilityTraining(torch.nn.Module): + + def __init__( + self, + max_off_sec: float, + do_offset: bool = True, + grid_size: int = None, + max_wiggle_sec: float = None, + segment_size_vframes: int = None, + n_segments: int = None, + step_size_seg: float = None, + vfps: float = None, + ): + super().__init__() + seg_size_sec = segment_size_vframes / vfps + trim_size_in_seg = n_segments - (1 - step_size_seg) * (n_segments - 1) + self.crop_len_sec = round(trim_size_in_seg * seg_size_sec, 2) + logging.info(f"Crop len: {self.crop_len_sec}") + self.do_offset = do_offset + self.grid_size = grid_size + self.max_off_sec = max_off_sec + self.max_a_jitter_sec = max_wiggle_sec + self.segment_size_vframes = segment_size_vframes + self.n_segments = n_segments + self.step_size_seg = step_size_seg + self.prob_syncable = 0.5 + if do_offset: + self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size) + logging.info(f"Offset class grid: {self.class_grid}") + if self.max_a_jitter_sec is not None: + assert (max_wiggle_sec - 1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f"{self.class_grid}" + + def forward(self, item): + vid = item["video"] + aud = item["audio"] + v_len_frames, C, H, W = vid.shape + a_len_frames = aud.shape[0] + + v_fps = int(item["meta"]["video"]["fps"][0]) + a_fps = int(item["meta"]["audio"]["framerate"][0]) + + v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps) + a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps) + + if self.do_offset: + # trying to get the offset parameters (for instance during valid and test we have fixed offsets) + offset_sec = item["targets"].get("offset_sec", None) + v_start_i_sec = item["targets"].get("v_start_i_sec", None) + # train-time + if offset_sec is None and v_start_i_sec is None: + + # for the syncability training, we want to have a syncable or non-syncable offset with 50% prob + offset_is_syncable = random.random() < self.prob_syncable # 1=syncable, 0=non-syncable + if offset_is_syncable: + offset_sec = random.choice(self.class_grid.tolist()) + else: + offset_sec = random.choice([-self.crop_len_sec, self.crop_len_sec]) # either - or + offset + # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec + + offset_sec = round(offset_sec, 2) + v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps) + assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}' + # `v_start_sec` IS NOT rounded to the fps grid + v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec - offset_sec)) + assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}' + v_start_i = sec2frames(v_start_sec, v_fps) + v_end_i = v_start_i + v_crop_len_frames + # `v_start_i_sec` IS rounded to the fps grid + v_start_i_sec = frames2sec(v_start_i, v_fps) + # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise + # (v_start_sec) we have ±0.1 jittering + a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps) + if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0: + a_start_i, a_jitter_i = apply_a_jitter( + a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec + ) + item["meta"]["a_jitter_i"] = a_jitter_i + a_end_i = a_start_i + a_crop_len_frames + else: + offset_sec = round(offset_sec, 2) + v_start_i = sec2frames(v_start_i_sec, v_fps) + a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps) + v_end_i = v_start_i + v_crop_len_frames + a_end_i = a_start_i + a_crop_len_frames + else: + offset_sec = 0.0 + is_random_crop = item["split"] == "train" + v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop) + v_start_i_sec = frames2sec(v_start_i, v_fps) + a_start_i = sec2frames(v_start_i_sec, a_fps) + if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0: + a_start_i, a_jitter_i = apply_a_jitter( + a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec + ) + item["meta"]["a_jitter_i"] = a_jitter_i + a_end_i = a_start_i + a_crop_len_frames + + # sometimes due to the rounding error e.g. v_start_sec = 1.505 but sec2frames(1.505, 25) = 1.48 + # given offset is -1.5, the a_start_i will be a small negative value. (likely a_fps * 1/v_fps * 0.5) + if a_start_i < 0: + how_much_out = a_start_i + logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}') + if abs(how_much_out) <= a_fps / v_fps: + logging.info("fixing it") + a_start_i += abs(how_much_out) + a_end_i += abs(how_much_out) + else: + raise Exception(f'{how_much_out} {item["path"]}') + + assert v_start_i < v_end_i and a_start_i < a_end_i + assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}' + assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}' + + vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i] + + item["video"] = vid + item["audio"] = aud + + assert item["video"].shape[0] == int(v_fps * self.crop_len_sec), f'{item["video"].shape} {item["path"]}' + assert item["audio"].shape[0] == int(a_fps * self.crop_len_sec), f'{item["audio"].shape} {item["path"]}' + + # caching parameters + if self.do_offset: + # NOTE: this is useless for the extreme offsetting + offset_label, offset_target = quantize_offset(self.class_grid, offset_sec) + item["targets"]["offset_sec"] = offset_sec + item["targets"]["offset_label"] = offset_label + # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?' + item["targets"]["offset_target"] = offset_target + item["targets"]["v_start_i_sec"] = v_start_i_sec + item["targets"]["sync_target"] = int(offset_is_syncable) + + return item + + def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True): + if len_frames == crop_len_frames: + return 0, len_frames + if is_random: + left_i = random.randint(0, len_frames - crop_len_frames) + else: + left_i = int(round((len_frames - crop_len_frames) / 2.0)) + return left_i, left_i + crop_len_frames + + +class RGBToFloatToZeroOne(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, item): + item["video"] = item["video"].to(torch.float32).div(255.0) + return item + + +class RGBToHalfToZeroOne(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, item): + item["video"] = item["video"].half().div(255.0) + return item + + +class RGBNormalize(torchvision.transforms.Normalize): + """The same as the torchvision`s but with different interface for the dict. + This should work for any shape (..., C, H, W)""" + + def __init__(self, mean, std, inplace=False): + super().__init__(mean, std, inplace) + logging.info(f"RGBNormalize: mean={mean}, std={std}") + + def forward(self, item): + item["video"] = super().forward(item["video"]) + item["meta"]["video"]["norm_stats"] = {"mean": torch.as_tensor(self.mean), "std": torch.as_tensor(self.std)} + return item + + +class AudioRandomVolume(torch.nn.Module): + + def __init__(self, p: float, **kwargs): + super().__init__() + transform = torchaudio.transforms.Vol(**kwargs) + self.transform = torchvision.transforms.RandomApply([transform], p) + + def apply_to_single_clip(self, clip): + return self.transform(clip) + + def apply_to_each_clip(self, clips): + for i, clip in enumerate(clips): + clips[i] = self.apply_to_single_clip(clip) + return clips + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) == 2 + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["audio"] = fn(item["audio"]) + return item + + +class AudioRandomLowpassFilter(torch.nn.Module): + + def __init__(self, p: float, cutoff_freq: float, Q: float = 0.707): + super().__init__() + self.p = p + self.cutoff_freq = cutoff_freq + self.Q = Q + + def apply_to_single_clip(self, clip, sr): + if self.p > torch.rand(1): + return torchaudio.functional.lowpass_biquad(clip, sr, self.cutoff_freq, self.Q) + else: + return clip + + def apply_to_each_clip(self, clips, sr): + for i, clip in enumerate(clips): + clips[i] = self.apply_to_single_clip(clip, sr) + return clips + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) == 2 + sr = int(item["meta"]["audio"]["framerate"][0]) + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["audio"] = fn(item["audio"], sr) + return item + + +class AudioRandomPitchShift(torch.nn.Module): + + def __init__(self, p: float, shift: int) -> None: + super().__init__() + self.p = p + self.shift = shift + + def apply_to_single_clip(self, wave, sr): + if self.p > torch.rand(1): + effects = [["pitch", f"{self.shift}"], ["rate", f"{sr}"]] + wave = wave.unsqueeze(0) + wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, sr, effects) + wave = wave.squeeze(0) + return wave + + def apply_to_each_clip(self, waves, sr): + for i, wave in enumerate(waves): + waves[i] = self.apply_to_single_clip(wave, sr) + return waves + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) == 2 + sr = int(item["meta"]["audio"]["framerate"][0]) + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["audio"] = fn(item["audio"], sr) + return item + + +class AudioRandomReverb(torch.nn.Module): + + def __init__(self, p: float) -> None: + super().__init__() + self.p = p + self.effects = [["reverb", "-w"]] + + def apply_to_single_clip(self, wave, fps): + if self.p > torch.rand(1): + wave = wave.unsqueeze(0) + wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, fps, self.effects) + wave = wave.mean(dim=0) + return wave + + def apply_to_each_clip(self, waves, fps): + for i, wave in enumerate(waves): + waves[i] = self.apply_to_single_clip(wave, fps) + return waves + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) == 2 + sr = int(item["meta"]["audio"]["framerate"][0]) + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["audio"] = fn(item["audio"], sr) + return item + + +class AudioRandomGaussNoise(torch.nn.Module): + + def __init__(self, p: float, amplitude=0.01) -> None: + super().__init__() + self.p = p + self.amplitude = amplitude + + def apply_to_single_clip(self, wave): + if self.p > torch.rand(1): + noise = torch.randn_like(wave, dtype=wave.dtype) + wave = wave + self.amplitude * noise + return wave + + def apply_to_each_clip(self, waves): + for i, wave in enumerate(waves): + waves[i] = self.apply_to_single_clip(wave) + return waves + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) == 2 + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["audio"] = fn(item["audio"]) + return item + + +class AudioMelSpectrogram(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.spec = torchaudio.transforms.MelSpectrogram(**kwargs) + + def forward(self, item): + item["audio"] = self.spec(item["audio"]) # safe for batched input + return item + + +class AudioLog(torch.nn.Module): + + def __init__(self, eps=1e-6) -> None: + super().__init__() + self.eps = eps + + def forward(self, item): + item["audio"] = torch.log(item["audio"] + self.eps) + return item + + +class PadOrTruncate(torch.nn.Module): + + def __init__(self, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0): + super().__init__() + self.max_spec_t = max_spec_t + self.pad_mode = pad_mode + self.pad_value = pad_value + + def forward(self, item): + item["audio"] = self.pad_or_truncate(item["audio"]) + return item + + def pad_or_truncate(self, audio): + difference = self.max_spec_t - audio.shape[-1] # safe for batched input + # pad or truncate, depending on difference + if difference > 0: + # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input + pad_dims = (0, difference) + audio = torch.nn.functional.pad(audio, pad_dims, self.pad_mode, self.pad_value) + elif difference < 0: + logging.warning(f"Truncating spec ({audio.shape}) to max_spec_t ({self.max_spec_t}).") + audio = audio[..., : self.max_spec_t] # safe for batched input + return audio + + +class AudioNormalizeAST(torch.nn.Module): + """Normalization is done with two specified mean and std (half)""" + + def __init__(self, mean: float, std: float) -> None: + super().__init__() + self.mean = mean + self.std = std + + def forward(self, item): + item["audio"] = (item["audio"] - self.mean) / (2 * self.std) + item["meta"]["audio"]["norm_stats"] = {"mean": self.mean, "std": self.std} + return item + + +class PermuteStreams(torch.nn.Module): + + def __init__(self, einops_order_audio: str, einops_order_rgb: str) -> None: + '''For example: + einops_order_audio: "S F T -> S T F" + einops_order_rgb: "S T C H W -> S C T H W"''' + super().__init__() + self.einops_order_audio = einops_order_audio + self.einops_order_rgb = einops_order_rgb + + def forward(self, item): + if self.einops_order_audio is not None: + item["audio"] = einops.rearrange(item["audio"], self.einops_order_audio).contiguous() + if self.einops_order_rgb is not None: + item["video"] = einops.rearrange(item["video"], self.einops_order_rgb).contiguous() + return item + + +class ResampleAudio(torch.nn.Module): + + def __init__(self, new_fps: int): + super().__init__() + self.new_fps = new_fps + + def forward(self, item): + orig_fps = int(item["meta"]["audio"]["framerate"][0]) + item["meta"]["audio"]["orig_shape"] = item["audio"].shape + if orig_fps != self.new_fps: + item["audio"] = torchaudio.functional.resample(item["audio"], orig_fps, self.new_fps) + item["meta"]["audio"]["framerate"][0] = self.new_fps + return item + + +class ResampleRGB(torch.nn.Module): + + def __init__(self, new_fps: int) -> None: + super().__init__() + self.new_fps = new_fps + + def forward(self, item): + orig_fps = float(item["meta"]["video"]["fps"][0]) + item["meta"]["video"]["orig_shape"] = item["video"].shape + if orig_fps != self.new_fps: + duration_sec = item["video"].shape[0] / orig_fps + indices = torch.arange(0, orig_fps * duration_sec - 1e-9, orig_fps / self.new_fps) + # basically, rounding + indices = indices.to(dtype=torch.long) + item["video"] = item["video"][indices] + item["meta"]["video"]["fps"][0] = self.new_fps + return item + + +class ResizeAndLetterboxPad(torch.nn.Module): + """Adapted from WACV24 Amazon`s challenge""" + + def __init__(self, new_h, new_w): + super().__init__() + self.new_h = new_h + self.new_w = new_w + self.aspect_ratio = new_w / new_h + + def forward(self, item): + item["video"] = self.resize_and_pad(item["video"]) + return item + + def resize_and_pad(self, rgb: torch.Tensor): + _, _, height, width = rgb.shape + current_aspect_ratio = width / height + if current_aspect_ratio > self.aspect_ratio: + scaled_height = round(self.new_w / current_aspect_ratio) + rgb = torchvision.transforms.functional.resize(rgb, (scaled_height, self.new_w), antialias=None) + top = (self.new_h - scaled_height) // 2 + bottom = self.new_h - (scaled_height + top) + rgb = torch.nn.ConstantPad2d((0, 0, top, bottom), 0)(rgb) + elif current_aspect_ratio < self.aspect_ratio: + scaled_width = round(self.new_h * current_aspect_ratio) + rgb = torchvision.transforms.functional.resize(rgb, (self.new_h, scaled_width), antialias=None) + left = (self.new_w - scaled_width) // 2 + right = self.new_w - (scaled_width + left) + rgb = torch.nn.ConstantPad2d((left, right, 0, 0), 0)(rgb) + return rgb + + +class ResampleResizeLetterboxPad(torch.nn.Module): + + def __init__(self, afps, vfps, new_h, new_w) -> None: + super().__init__() + self.transforms = torchvision.transforms.Compose( + [ResampleAudio(new_fps=afps), ResampleRGB(new_fps=vfps), ResizeAndLetterboxPad(new_h=new_h, new_w=new_w)] + ) + + def forward(self, x: dict) -> dict: + return self.transforms(x) + + +class DoNothing(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def forward(self, x: dict) -> dict: + return x + + +if __name__ == "__main__": + grid = make_class_grid(-1, 1, 21) + grid = make_class_grid(-2, 2, 41) + print("grid:", grid) + print("value quantization:", quantize_offset(grid, 0.06)) + v_fps = 25.0 + duration = 10.0 + + input = { + "video": torch.randint(0, 256, (int(duration * v_fps), 3, 720 // 2, 1280 // 2), dtype=torch.uint8), + "audio": torch.arange(221184 - 1).float(), + "targets": {}, + "meta": { + "video": {"duration": [duration], "fps": [v_fps]}, + "audio": {"duration": [duration], "framerate": [22050.0]}, + "subtitles": {"duration": []}, + "cc": {"duration": []}, + }, + "path": "/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4", + "split": "train", + } + + print(input["audio"].shape, input["video"].shape) + + fn = EqualifyFromRight(clip_max_len_sec=10) + input = fn(input) + print(input["audio"].shape, input["video"].shape) + + fn = RGBSpatialCrop((224, 224), is_random=True) + # fn = RGBSpatialCrop((112, 112), is_random=True) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = Resize((224, 224)) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = GenerateMultipleSegments( + segment_size_vframes=16, n_segments=14, is_start_random=False, audio_jitter_sec=0.05, step_size_seg=0.5 + ) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = RandomApplyColorDistortion(p_gray_scale=0.5, p_color_jitter=0.5, s=1.0) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = RGBToFloatToZeroOne() + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + print(input["meta"]) + + fn = RGBNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + print(input["video"].mean(dim=(0, 2, 3))) + print(input["meta"]) + + fn = AudioRandomReverb(p=1.0) + input = fn(input) + + fn = AudioRandomVolume(p=1.0, gain=2.0, gain_type="amplitude") + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioRandomPitchShift(p=1.0, shift=1000) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioRandomLowpassFilter(p=1.0, cutoff_freq=100) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioRandomGaussNoise(p=1.0, amplitude=0.01) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioLog() + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + # audio only + input = { + "audio": torch.arange(221184).float(), + "meta": { + "video": {"duration": [10.0], "fps": [10.0]}, + "audio": {"duration": [11.0], "framerate": [22050.0]}, + "subtitles": {"duration": []}, + "cc": {"duration": []}, + }, + "path": "/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4", + } + + print(input["audio"].shape) + + fn = AudioLog() + input = fn(input) + print(input["audio"].shape, input["meta"]["audio"]) + print(input["meta"]) + print(input["audio"].min(), input["audio"].max()) diff --git a/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml b/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9d20b76302a8af7928391643bd4b2d184e970aa --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml @@ -0,0 +1,84 @@ +TRAIN: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 32 + EVAL_PERIOD: 5 + CHECKPOINT_PERIOD: 5 + AUTO_RESUME: True + CHECKPOINT_EPOCH_RESET: True + CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth +DATA: + NUM_FRAMES: 16 + SAMPLING_RATE: 4 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 224 + INPUT_CHANNEL_NUM: [3] + MEAN: [0.5, 0.5, 0.5] + STD: [0.5, 0.5, 0.5] + PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 + PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames + INV_UNIFORM_SAMPLE: True + RANDOM_FLIP: False + REVERSE_INPUT_CHANNEL: True + USE_RAND_AUGMENT: True + RE_PROB: 0.0 + USE_REPEATED_AUG: False + USE_RANDOM_RESIZE_CROPS: False + COLORJITTER: False + GRAYSCALE: False + GAUSSIAN: False +SOLVER: + BASE_LR: 1e-4 + LR_POLICY: steps_with_relative_lrs + LRS: [1, 0.1, 0.01] + STEPS: [0, 20, 30] + MAX_EPOCH: 35 + MOMENTUM: 0.9 + WEIGHT_DECAY: 5e-2 + WARMUP_EPOCHS: 0.0 + OPTIMIZING_METHOD: adamw + USE_MIXED_PRECISION: True + SMOOTHING: 0.2 +SLOWFAST: + ALPHA: 8 +VIT: + PATCH_SIZE: 16 + PATCH_SIZE_TEMP: 2 + CHANNELS: 3 + EMBED_DIM: 768 + DEPTH: 12 + NUM_HEADS: 12 + MLP_RATIO: 4 + QKV_BIAS: True + VIDEO_INPUT: True + TEMPORAL_RESOLUTION: 8 + USE_MLP: True + DROP: 0.0 + POS_DROPOUT: 0.0 + DROP_PATH: 0.2 + IM_PRETRAINED: True + HEAD_DROPOUT: 0.0 + HEAD_ACT: tanh + PRETRAINED_WEIGHTS: vit_1k + ATTN_LAYER: divided +MODEL: + NUM_CLASSES: 174 + ARCH: slow + MODEL_NAME: VisionTransformer + LOSS_FUNC: cross_entropy +TEST: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 64 + NUM_ENSEMBLE_VIEWS: 1 + NUM_SPATIAL_CROPS: 3 +DATA_LOADER: + NUM_WORKERS: 4 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 4 +RNG_SEED: 0 +OUTPUT_DIR: . +TENSORBOARD: + ENABLE: True diff --git a/hunyuanvideo_foley/models/synchformer/modeling_ast.py b/hunyuanvideo_foley/models/synchformer/modeling_ast.py new file mode 100644 index 0000000000000000000000000000000000000000..f456753ecfff180dd36a3d2ff3e50a47ab735d52 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/modeling_ast.py @@ -0,0 +1,673 @@ +# coding=utf-8 +# Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified by v-iashin to support token masking + +"""PyTorch Audio Spectrogram Transformer (AST) model.""" + +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ASTConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593" +_EXPECTED_OUTPUT_SHAPE = [1, 1214, 768] + +# Audio classification docstring +_SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593" +_SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'" +_SEQ_CLASS_EXPECTED_LOSS = 0.17 + + +AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "MIT/ast-finetuned-audioset-10-10-0.4593", + # See all Audio Spectrogram Transformer models at https://huggingface.co/models?filter=ast +] + + +class ASTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. + """ + + def __init__(self, config: ASTConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.patch_embeddings = ASTPatchEmbeddings(config) + + frequency_out_dimension, time_out_dimension = self.get_shape(config) + num_patches = frequency_out_dimension * time_out_dimension + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def get_shape(self, config): + # see Karpathy's cs231n blog on how to calculate the output dimensions + # https://cs231n.github.io/convolutional-networks/#conv + frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1 + time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1 + + return frequency_out_dimension, time_out_dimension + + def forward(self, input_values: torch.Tensor) -> torch.Tensor: + batch_size = input_values.shape[0] + embeddings = self.patch_embeddings(input_values) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + distillation_tokens = self.distillation_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1) + print(self.position_embeddings.shape) + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +class ASTPatchEmbeddings(nn.Module): + """ + This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, + seq_length, hidden_size)` to be consumed by a Transformer. + """ + + def __init__(self, config): + super().__init__() + + patch_size = config.patch_size + frequency_stride = config.frequency_stride + time_stride = config.time_stride + + self.projection = nn.Conv2d( + 1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride) + ) + + def forward(self, input_values: torch.Tensor) -> torch.Tensor: + input_values = input_values.unsqueeze(1) + input_values = input_values.transpose(2, 3) + embeddings = self.projection(input_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST +class ASTSelfAttention(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # apply masking if provided, tok_mask is (BS, N): 1s - keep; attention_scores is (BS, H, N, N) + if tok_mask is not None: + BS, N = tok_mask.shape + attention_scores = attention_scores.masked_fill(tok_mask.view(BS, 1, 1, N) == 0, float("-inf")) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST +class ASTSelfOutput(nn.Module): + """ + The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST +class ASTAttention(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.attention = ASTSelfAttention(config) + self.output = ASTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, tok_mask, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST +class ASTIntermediate(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST +class ASTOutput(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST +class ASTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ASTAttention(config) + self.intermediate = ASTIntermediate(config) + self.output = ASTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in AST, layernorm is applied before self-attention + tok_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in AST, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST +class ASTEncoder(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + tok_mask, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ASTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ASTConfig + base_model_prefix = "audio_spectrogram_transformer" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST + def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None: + if isinstance(module, ASTEncoder): + module.gradient_checkpointing = value + + +AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ASTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`ASTFeatureExtractor.__call__`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare AST Model transformer outputting raw hidden-states without any specific head on top.", + AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING, +) +class ASTModel(ASTPreTrainedModel): + def __init__(self, config: ASTConfig): + super().__init__(config) + self.config = config + + self.embeddings = ASTEmbeddings(config) + self.encoder = ASTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ASTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + cont_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_values is None: + raise ValueError("You have to specify input_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(input_values) + + # transforms the mask that has spectrogram dims to the token masking which is obtained after patching. + # Due to the ovelap in patching, getting the token mask from spectrogram mask is not straightforward, + # because one 16x16 content patch is encoded in two tokens if stride is <16. So, to get the mask for + # tokens I will apply the patching func (self.embeddings) to the tensor with infinities at the masked + # content position. For infs, the patching fn will return nans, which I'll use to get the token mask. + if cont_mask is not None: + indicator = torch.ones_like(input_values).to(input_values.dtype) + # replace content mask (0s) with infs + indicator[~cont_mask] = torch.inf + # apply patching; now nans are where the content mask was + with torch.no_grad(): + indicator = self.embeddings(indicator) # BS, N, D + # replace nans with 0s; these are the tokens that correspond to the masked content + tok_mask = ~torch.isnan(indicator) + # since all values in the D-dimension (latent) will also be nans, we can just use the first el + tok_mask = tok_mask[:, :, 0] # (BS, 2+num_patches) -- 2 is from CLS and DISTIL tokens + else: + tok_mask = None + + encoder_outputs = self.encoder( + embedding_output, + tok_mask=tok_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2 + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return ( + BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ), + tok_mask, + ) + + +class ASTMLPHead(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + def forward(self, hidden_state): + hidden_state = self.layernorm(hidden_state) + hidden_state = self.dense(hidden_state) + return hidden_state + + +@add_start_docstrings( + """ + Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled + output) e.g. for datasets like AudioSet, Speech Commands v2. + """, + AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING, +) +class ASTForAudioClassification(ASTPreTrainedModel): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.audio_spectrogram_transformer = ASTModel(config) + + # Classifier head + self.classifier = ASTMLPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_SEQ_CLASS_CHECKPOINT, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + cont_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the audio classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.audio_spectrogram_transformer( + input_values, + cont_mask=cont_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/hunyuanvideo_foley/models/synchformer/motionformer.py b/hunyuanvideo_foley/models/synchformer/motionformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9980a6f6d667699a275b10c6f613a30493566713 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/motionformer.py @@ -0,0 +1,397 @@ +import logging +from pathlib import Path + +import einops +import torch +from omegaconf import OmegaConf +from timm.layers import trunc_normal_ +from torch import nn + +from .utils import check_if_file_exists_else_download +from .video_model_builder import VisionTransformer + + +FILE2URL = { + # cfg + "motionformer_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml", + "joint_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml", + "divided_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml", + # ckpt + "ssv2_motionformer_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth", + "ssv2_joint_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth", + "ssv2_divided_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth", +} + + +class MotionFormer(VisionTransformer): + """This class serves three puposes: + 1. Renames the class to MotionFormer. + 2. Downloads the cfg from the original repo and patches it if needed. + 3. Takes care of feature extraction by redefining .forward() + - if `extract_features=True` and `factorize_space_time=False`, + the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D) + and spatial and temporal transformer encoder layers are used. + - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True` + the output is of shape (B, D) and spatial and temporal transformer encoder layers + are used as well as the global representation is extracted from segments (extra pos emb + is added). + """ + + def __init__( + self, + extract_features: bool = False, + ckpt_path: str = None, + factorize_space_time: bool = None, + agg_space_module: str = None, + agg_time_module: str = None, + add_global_repr: bool = True, + agg_segments_module: str = None, + max_segments: int = None, + ): + self.extract_features = extract_features + self.ckpt_path = ckpt_path + self.factorize_space_time = factorize_space_time + + if self.ckpt_path is not None: + check_if_file_exists_else_download(self.ckpt_path, FILE2URL) + ckpt = torch.load(self.ckpt_path, map_location="cpu") + mformer_ckpt2cfg = { + "ssv2_motionformer_224_16x4.pyth": "motionformer_224_16x4.yaml", + "ssv2_joint_224_16x4.pyth": "joint_224_16x4.yaml", + "ssv2_divided_224_16x4.pyth": "divided_224_16x4.yaml", + } + # init from motionformer ckpt or from our Stage I ckpt + # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to + # load the state dict differently + was_pt_on_avclip = self.ckpt_path.endswith(".pt") # checks if it is a stage I ckpt (FIXME: a bit generic) + if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())): + cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name] + elif was_pt_on_avclip: + # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it) + s1_cfg = ckpt.get("args", None) # Stage I cfg + if s1_cfg is not None: + s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path + # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch + if s1_vfeat_extractor_ckpt_path is not None: + cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name] + else: + cfg_fname = "divided_224_16x4.yaml" + else: + cfg_fname = "divided_224_16x4.yaml" + else: + raise ValueError(f"ckpt_path {self.ckpt_path} is not supported.") + else: + was_pt_on_avclip = False + cfg_fname = "divided_224_16x4.yaml" + # logging.info(f'No ckpt_path provided, using {cfg_fname} config.') + + if cfg_fname in ["motionformer_224_16x4.yaml", "divided_224_16x4.yaml"]: + pos_emb_type = "separate" + elif cfg_fname == "joint_224_16x4.yaml": + pos_emb_type = "joint" + + self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname + + check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL) + mformer_cfg = OmegaConf.load(self.mformer_cfg_path) + logging.info(f"Loading MotionFormer config from {self.mformer_cfg_path.absolute()}") + + # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`) + mformer_cfg.VIT.ATTN_DROPOUT = 0.0 + mformer_cfg.VIT.POS_EMBED = pos_emb_type + mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True + mformer_cfg.VIT.APPROX_ATTN_TYPE = "none" # guessing + mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg'] + + # finally init VisionTransformer with the cfg + super().__init__(mformer_cfg) + + # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt + if (self.ckpt_path is not None) and (not was_pt_on_avclip): + _ckpt_load_status = self.load_state_dict(ckpt["model_state"], strict=False) + if len(_ckpt_load_status.missing_keys) > 0 or len(_ckpt_load_status.unexpected_keys) > 0: + logging.warning( + f"Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed." + f"Missing keys: {_ckpt_load_status.missing_keys}, " + f"Unexpected keys: {_ckpt_load_status.unexpected_keys}" + ) + else: + logging.info(f"Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.") + + if self.extract_features: + assert isinstance(self.norm, nn.LayerNorm), "early x[:, 1:, :] may not be safe for per-tr weights" + # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger + self.pre_logits = nn.Identity() + # we don't need the classification head (saving memory) + self.head = nn.Identity() + self.head_drop = nn.Identity() + # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer) + transf_enc_layer_kwargs = dict( + d_model=self.embed_dim, + nhead=self.num_heads, + activation=nn.GELU(), + batch_first=True, + dim_feedforward=self.mlp_ratio * self.embed_dim, + dropout=self.drop_rate, + layer_norm_eps=1e-6, + norm_first=True, + ) + # define adapters if needed + if self.factorize_space_time: + if agg_space_module == "TransformerEncoderLayer": + self.spatial_attn_agg = SpatialTransformerEncoderLayer(**transf_enc_layer_kwargs) + elif agg_space_module == "AveragePooling": + self.spatial_attn_agg = AveragePooling( + avg_pattern="BS D t h w -> BS D t", then_permute_pattern="BS D t -> BS t D" + ) + if agg_time_module == "TransformerEncoderLayer": + self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs) + elif agg_time_module == "AveragePooling": + self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D") + elif "Identity" in agg_time_module: + self.temp_attn_agg = nn.Identity() + # define a global aggregation layer (aggregarate over segments) + self.add_global_repr = add_global_repr + if add_global_repr: + if agg_segments_module == "TransformerEncoderLayer": + # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D) + # we need to add pos emb (PE) because previously we added the same PE for each segment + pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1 + self.global_attn_agg = TemporalTransformerEncoderLayer( + add_pos_emb=True, + pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT, + pos_max_len=pos_max_len, + **transf_enc_layer_kwargs, + ) + elif agg_segments_module == "AveragePooling": + self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D") + + if was_pt_on_avclip: + # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors) + # and keep only the state_dict of the feat extractor + ckpt_weights = dict() + for k, v in ckpt["state_dict"].items(): + if k.startswith(("module.v_encoder.", "v_encoder.")): + k = k.replace("module.", "").replace("v_encoder.", "") + ckpt_weights[k] = v + _load_status = self.load_state_dict(ckpt_weights, strict=False) + if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0: + logging.warning( + f"Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n" + f"Missing keys ({len(_load_status.missing_keys)}): " + f"{_load_status.missing_keys}, \n" + f"Unexpected keys ({len(_load_status.unexpected_keys)}): " + f"{_load_status.unexpected_keys} \n" + f"temp_attn_agg are expected to be missing if ckpt was pt contrastively." + ) + else: + logging.info(f"Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.") + + # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1 + # but it used to calculate the number of patches, so we need to set keep it + self.patch_embed.requires_grad_(False) + + def forward(self, x): + """ + x is of shape (B, S, C, T, H, W) where S is the number of segments. + """ + # Batch, Segments, Channels, T=frames, Height, Width + B, S, C, T, H, W = x.shape + # Motionformer expects a tensor of shape (1, B, C, T, H, W). + # The first dimension (1) is a dummy dimension to make the input tensor and won't be used: + # see `video_model_builder.video_input`. + # x = x.unsqueeze(0) # (1, B, S, C, T, H, W) + + orig_shape = (B, S, C, T, H, W) + x = x.view(B * S, C, T, H, W) # flatten batch and segments + x = self.forward_segments(x, orig_shape=orig_shape) + # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D)) + x = x.view(B, S, *x.shape[1:]) + # x is now of shape (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity` + + return x # x is (B, S, ...) + + def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor: + """x is of shape (1, BS, C, T, H, W) where S is the number of segments.""" + x, x_mask = self.forward_features(x) + + assert self.extract_features + + # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + x = x[:, 1:, :] # without the CLS token for efficiency (should be safe for LayerNorm and FC) + x = self.norm(x) + x = self.pre_logits(x) + if self.factorize_space_time: + x = self.restore_spatio_temp_dims(x, orig_shape) # (B*S, D, t, h, w) <- (B*S, t*h*w, D) + + x = self.spatial_attn_agg(x, x_mask) # (B*S, t, D) + x = self.temp_attn_agg(x) # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity` + + return x + + def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor: + """ + feats are of shape (B*S, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + Our goal is to make them of shape (B*S, t, h, w, D) where h, w are the spatial dimensions. + From `self.patch_embed_3d`, it follows that we could reshape feats with: + `feats.transpose(1, 2).view(B*S, D, t, h, w)` + """ + B, S, C, T, H, W = orig_shape + D = self.embed_dim + + # num patches in each dimension + t = T // self.patch_embed_3d.z_block_size + h = self.patch_embed_3d.height + w = self.patch_embed_3d.width + + feats = feats.permute(0, 2, 1) # (B*S, D, T) + feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w) + + return feats + + +class BaseEncoderLayer(nn.TransformerEncoderLayer): + """ + This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token + to the sequence and outputs the CLS token's representation. + This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream + and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream stream. + We also, optionally, add a positional embedding to the input sequence which + allows to reuse it for global aggregation (of segments) for both streams. + """ + + def __init__( + self, + add_pos_emb: bool = False, + pos_emb_drop: float = None, + pos_max_len: int = None, + *args_transformer_enc, + **kwargs_transformer_enc, + ): + super().__init__(*args_transformer_enc, **kwargs_transformer_enc) + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim)) + trunc_normal_(self.cls_token, std=0.02) + + # add positional embedding + self.add_pos_emb = add_pos_emb + if add_pos_emb: + self.pos_max_len = 1 + pos_max_len # +1 (for CLS) + self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim)) + self.pos_drop = nn.Dropout(pos_emb_drop) + trunc_normal_(self.pos_emb, std=0.02) + + self.apply(self._init_weights) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None): + """x is of shape (B, N, D); if provided x_mask is of shape (B, N)""" + batch_dim = x.shape[0] + + # add CLS token + cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension + x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D) + if x_mask is not None: + cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device) # 1=keep; 0=mask + x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len) + B, N = x_mask_w_cls.shape + # torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks + x_mask_w_cls = ( + x_mask_w_cls.reshape(B, 1, 1, N) + .expand(-1, self.self_attn.num_heads, N, -1) + .reshape(B * self.self_attn.num_heads, N, N) + ) + assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, "x_mask_w_cls.dtype != bool" + x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask) + else: + x_mask_w_cls = None + + # add positional embedding + if self.add_pos_emb: + seq_len = x.shape[1] # (don't even think about moving it before the CLS token concatenation) + assert seq_len <= self.pos_max_len, f"Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})" + x = x + self.pos_emb[:, :seq_len, :] + x = self.pos_drop(x) + + # apply encoder layer (calls nn.TransformerEncoderLayer.forward); + x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D) + + # CLS token is expected to hold spatial information for each frame + x = x[:, 0, :] # (batch_dim, D) + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"cls_token", "pos_emb"} + + +class SpatialTransformerEncoderLayer(BaseEncoderLayer): + """Aggregates spatial dimensions by applying attention individually to each frame.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + """x is of shape (B*S, D, t, h, w) where S is the number of segments. + if specified x_mask (B*S, t, h, w), 0=masked, 1=kept + Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame.""" + BS, D, t, h, w = x.shape + + # time as a batch dimension and flatten spatial dimensions as sequence + x = einops.rearrange(x, "BS D t h w -> (BS t) (h w) D") + # similar to mask + if x_mask is not None: + x_mask = einops.rearrange(x_mask, "BS t h w -> (BS t) (h w)") + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D) + + # reshape back to (B*S, t, D) + x = einops.rearrange(x, "(BS t) D -> BS t D", BS=BS, t=t) + + # (B*S, t, D) + return x + + +class TemporalTransformerEncoderLayer(BaseEncoderLayer): + """Aggregates temporal dimension with attention. Also used with pos emb as global aggregation + in both streams.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + """x is of shape (B*S, t, D) where S is the number of segments. + Returns a tensor of shape (B*S, D) pooling temporal information.""" + BS, t, D = x.shape + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x) # (B*S, D) + + return x # (B*S, D) + + +class AveragePooling(nn.Module): + + def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None: + """patterns are e.g. "bs t d -> bs d" """ + super().__init__() + # TODO: need to register them as buffers (but fails because these are strings) + self.reduce_fn = "mean" + self.avg_pattern = avg_pattern + self.then_permute_pattern = then_permute_pattern + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + x = einops.reduce(x, self.avg_pattern, self.reduce_fn) + if self.then_permute_pattern is not None: + x = einops.rearrange(x, self.then_permute_pattern) + return x diff --git a/hunyuanvideo_foley/models/synchformer/synchformer.py b/hunyuanvideo_foley/models/synchformer/synchformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1238bbc68dc451a56121ba7ab1fc00aa290420c3 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/synchformer.py @@ -0,0 +1,355 @@ +import logging +import math +from typing import Any, Mapping + +import einops +import numpy as np +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F + +from .motionformer import MotionFormer +from .ast_model import AST +from .utils import Config + + +class Synchformer(nn.Module): + + def __init__(self): + super().__init__() + + self.vfeat_extractor = MotionFormer( + extract_features=True, + factorize_space_time=True, + agg_space_module="TransformerEncoderLayer", + agg_time_module="torch.nn.Identity", + add_global_repr=False, + ) + self.afeat_extractor = AST( + extract_features=True, + max_spec_t=66, + factorize_freq_time=True, + agg_freq_module="TransformerEncoderLayer", + agg_time_module="torch.nn.Identity", + add_global_repr=False, + ) + + # # bridging the s3d latent dim (1024) into what is specified in the config + # # to match e.g. the transformer dim + self.vproj = nn.Linear(in_features=768, out_features=768) + self.aproj = nn.Linear(in_features=768, out_features=768) + self.transformer = GlobalTransformer( + tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768 + ) + + def forward(self, vis): + B, S, Tv, C, H, W = vis.shape + vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) + # feat extractors return a tuple of segment-level and global features (ignored for sync) + # (B, S, tv, D), e.g. (B, 7, 8, 768) + vis = self.vfeat_extractor(vis) + return vis + + def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor): + vis = self.vproj(vis) + aud = self.aproj(aud) + + B, S, tv, D = vis.shape + B, S, ta, D = aud.shape + vis = vis.view(B, S * tv, D) # (B, S*tv, D) + aud = aud.view(B, S * ta, D) # (B, S*ta, D) + # print(vis.shape, aud.shape) + + # self.transformer will concatenate the vis and aud in one sequence with aux tokens, + # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens + logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer + + return logits + + def extract_vfeats(self, vis): + B, S, Tv, C, H, W = vis.shape + vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) + # feat extractors return a tuple of segment-level and global features (ignored for sync) + # (B, S, tv, D), e.g. (B, 7, 8, 768) + vis = self.vfeat_extractor(vis) + return vis + + def extract_afeats(self, aud): + B, S, _, Fa, Ta = aud.shape + aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F) + # (B, S, ta, D), e.g. (B, 7, 6, 768) + aud, _ = self.afeat_extractor(aud) + return aud + + def compute_loss(self, logits, targets, loss_fn: str = None): + loss = None + if targets is not None: + if loss_fn is None or loss_fn == "cross_entropy": + # logits: (B, cls) and targets: (B,) + loss = F.cross_entropy(logits, targets) + else: + raise NotImplementedError(f"Loss {loss_fn} not implemented") + return loss + + def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): + # discard all entries except vfeat_extractor + # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} + + return super().load_state_dict(sd, strict) + + +class RandInitPositionalEncoding(nn.Module): + """Random inited trainable pos embedding. It is just applied on the sequence, thus respects no priors.""" + + def __init__(self, block_shape: list, n_embd: int): + super().__init__() + self.block_shape = block_shape + self.n_embd = n_embd + self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd)) + + def forward(self, token_embeddings): + return token_embeddings + self.pos_emb + + +class GlobalTransformer(torch.nn.Module): + """Same as in SparseSync but without the selector transformers and the head""" + + def __init__( + self, + tok_pdrop=0.0, + embd_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + n_layer=3, + n_head=8, + n_embd=768, + pos_emb_block_shape=[ + 198, + ], + n_off_head_out=21, + ) -> None: + super().__init__() + self.config = Config( + embd_pdrop=embd_pdrop, + resid_pdrop=resid_pdrop, + attn_pdrop=attn_pdrop, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + ) + # input norm + self.vis_in_lnorm = torch.nn.LayerNorm(n_embd) + self.aud_in_lnorm = torch.nn.LayerNorm(n_embd) + # aux tokens + self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) + self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) + # whole token dropout + self.tok_pdrop = tok_pdrop + self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop) + self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop) + # maybe add pos emb + self.pos_emb_cfg = RandInitPositionalEncoding( + block_shape=pos_emb_block_shape, + n_embd=n_embd, + ) + # the stem + self.drop = torch.nn.Dropout(embd_pdrop) + self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)]) + # pre-output norm + self.ln_f = torch.nn.LayerNorm(n_embd) + # maybe add a head + self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out) + + def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True): + B, Sv, D = v.shape + B, Sa, D = a.shape + # broadcasting special tokens to the batch size + off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B) + mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B) + # norm + v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a) + # maybe whole token dropout + if self.tok_pdrop > 0: + v, a = self.tok_drop_vis(v), self.tok_drop_aud(a) + # (B, 1+Sv+1+Sa, D) + x = torch.cat((off_tok, v, mod_tok, a), dim=1) + # maybe add pos emb + if hasattr(self, "pos_emb_cfg"): + x = self.pos_emb_cfg(x) + # dropout -> stem -> norm + x = self.drop(x) + x = self.blocks(x) + x = self.ln_f(x) + # maybe add heads + if attempt_to_apply_heads and hasattr(self, "off_head"): + x = self.off_head(x[:, 0, :]) + return x + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # # causal mask to ensure that attention is only applied to the left in the input sequence + # mask = torch.tril(torch.ones(config.block_size, + # config.block_size)) + # if hasattr(config, "n_unmasked"): + # mask[:config.n_unmasked, :config.n_unmasked] = 1 + # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + + return y + + +class Block(nn.Module): + """an unassuming Transformer block""" + + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = SelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), # nice + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + + +def make_class_grid( + leftmost_val, + rightmost_val, + grid_size, + add_extreme_offset: bool = False, + seg_size_vframes: int = None, + nseg: int = None, + step_size_seg: float = None, + vfps: float = None, +): + assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()" + grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float() + if add_extreme_offset: + assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}" + seg_size_sec = seg_size_vframes / vfps + trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1) + extreme_value = trim_size_in_seg * seg_size_sec + grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid + return grid + + +# from synchformer +def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0): + difference = max_spec_t - audio.shape[-1] # safe for batched input + # pad or truncate, depending on difference + if difference > 0: + # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input + pad_dims = (0, difference) + audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value) + elif difference < 0: + print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).") + audio = audio[..., :max_spec_t] # safe for batched input + return audio + + +def encode_audio_with_sync( + synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram +) -> torch.Tensor: + b, t = x.shape + + # partition the video + segment_size = 10240 + step_size = 10240 // 2 + num_segments = (t - segment_size) // step_size + 1 + segments = [] + for i in range(num_segments): + segments.append(x[:, i * step_size : i * step_size + segment_size]) + x = torch.stack(segments, dim=1) # (B, S, T, C, H, W) + + x = mel(x) + x = torch.log(x + 1e-6) + x = pad_or_truncate(x, 66) + + mean = -4.2677393 + std = 4.5689974 + x = (x - mean) / (2 * std) + # x: B * S * 128 * 66 + x = synchformer.extract_afeats(x.unsqueeze(2)) + return x + + +def read_audio(filename, expected_length=int(16000 * 4)): + waveform, sr = torchaudio.load(filename) + waveform = waveform.mean(dim=0) + + if sr != 16000: + resampler = torchaudio.transforms.Resample(sr, 16000) + waveform = resampler[sr](waveform) + + waveform = waveform[:expected_length] + if waveform.shape[0] != expected_length: + raise ValueError(f"Audio {filename} is too short") + + waveform = waveform.squeeze() + + return waveform + + +if __name__ == "__main__": + synchformer = Synchformer().cuda().eval() + + # mmaudio provided synchformer ckpt + synchformer.load_state_dict( + torch.load( + os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), + weights_only=True, + map_location="cpu", + ) + ) + + sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram( + sample_rate=16000, + win_length=400, + hop_length=160, + n_fft=1024, + n_mels=128, + ) diff --git a/hunyuanvideo_foley/models/synchformer/utils.py b/hunyuanvideo_foley/models/synchformer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05595cc15b925f52ccd07fea8f131ec810f56bd7 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/utils.py @@ -0,0 +1,87 @@ +from hashlib import md5 +from pathlib import Path +import subprocess + +import requests +from tqdm import tqdm + +PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a" +FNAME2LINK = { + # S3: Synchability: AudioSet (run 2) + "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt", + "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml", + # S2: Synchformer: AudioSet (run 2) + "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt", + "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml", + # S2: Synchformer: AudioSet (run 1) + "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt", + "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml", + # S2: Synchformer: LRS3 (run 2) + "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt", + "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml", + # S2: Synchformer: VGS (run 2) + "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt", + "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml", + # SparseSync: ft VGGSound-Full + "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt", + "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml", + # SparseSync: ft VGGSound-Sparse + "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt", + "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml", + # SparseSync: only pt on LRS3 + "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt", + "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml", + # SparseSync: feature extractors + "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s + "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s + "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s + "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s + "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s + "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s + "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s +} + + +def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024): + """Checks if file exists, if not downloads it from the link to the path""" + path = Path(path) + if not path.exists(): + path.parent.mkdir(exist_ok=True, parents=True) + link = fname2link.get(path.name, None) + if link is None: + raise ValueError( + f"Cant find the checkpoint file: {path}.", f"Please download it manually and ensure the path exists." + ) + with requests.get(fname2link[path.name], stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def which_ffmpeg() -> str: + """Determines the path to ffmpeg library + Returns: + str -- path to the library + """ + result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "") + return ffmpeg_path + + +def get_md5sum(path): + hash_md5 = md5() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(4096 * 8), b""): + hash_md5.update(chunk) + md5sum = hash_md5.hexdigest() + return md5sum + + +class Config: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) diff --git a/hunyuanvideo_foley/models/synchformer/video_model_builder.py b/hunyuanvideo_foley/models/synchformer/video_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..190df1a5f066c2c06ab41178fc1174c7956bc599 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/video_model_builder.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copyright 2020 Ross Wightman +# Modified Model definition + +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +from timm.layers import trunc_normal_ + +from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock + + +class VisionTransformer(nn.Module): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__(self, cfg): + super().__init__() + self.img_size = cfg.DATA.TRAIN_CROP_SIZE + self.patch_size = cfg.VIT.PATCH_SIZE + self.in_chans = cfg.VIT.CHANNELS + if cfg.TRAIN.DATASET == "Epickitchens": + self.num_classes = [97, 300] + else: + self.num_classes = cfg.MODEL.NUM_CLASSES + self.embed_dim = cfg.VIT.EMBED_DIM + self.depth = cfg.VIT.DEPTH + self.num_heads = cfg.VIT.NUM_HEADS + self.mlp_ratio = cfg.VIT.MLP_RATIO + self.qkv_bias = cfg.VIT.QKV_BIAS + self.drop_rate = cfg.VIT.DROP + self.drop_path_rate = cfg.VIT.DROP_PATH + self.head_dropout = cfg.VIT.HEAD_DROPOUT + self.video_input = cfg.VIT.VIDEO_INPUT + self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION + self.use_mlp = cfg.VIT.USE_MLP + self.num_features = self.embed_dim + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT + self.head_act = cfg.VIT.HEAD_ACT + self.cfg = cfg + + # Patch Embedding + self.patch_embed = PatchEmbed( + img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim + ) + + # 3D Patch Embedding + self.patch_embed_3d = PatchEmbed3D( + img_size=self.img_size, + temporal_resolution=self.temporal_resolution, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP, + ) + self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data) + + # Number of patches + if self.video_input: + num_patches = self.patch_embed.num_patches * self.temporal_resolution + else: + num_patches = self.patch_embed.num_patches + self.num_patches = num_patches + + # CLS token + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + trunc_normal_(self.cls_token, std=0.02) + + # Positional embedding + self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) + self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) + trunc_normal_(self.pos_embed, std=0.02) + + if self.cfg.VIT.POS_EMBED == "joint": + self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) + trunc_normal_(self.st_embed, std=0.02) + elif self.cfg.VIT.POS_EMBED == "separate": + self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) + + # Layer Blocks + dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] + if self.cfg.VIT.ATTN_LAYER == "divided": + self.blocks = nn.ModuleList( + [ + DividedSpaceTimeBlock( + attn_type=cfg.VIT.ATTN_LAYER, + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(self.depth) + ] + ) + + self.norm = norm_layer(self.embed_dim) + + # MLP head + if self.use_mlp: + hidden_dim = self.embed_dim + if self.head_act == "tanh": + # logging.info("Using TanH activation in MLP") + act = nn.Tanh() + elif self.head_act == "gelu": + # logging.info("Using GELU activation in MLP") + act = nn.GELU() + else: + # logging.info("Using ReLU activation in MLP") + act = nn.ReLU() + self.pre_logits = nn.Sequential( + OrderedDict( + [ + ("fc", nn.Linear(self.embed_dim, hidden_dim)), + ("act", act), + ] + ) + ) + else: + self.pre_logits = nn.Identity() + + # Classifier Head + self.head_drop = nn.Dropout(p=self.head_dropout) + if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1: + for a, i in enumerate(range(len(self.num_classes))): + setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) + else: + self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + if self.cfg.VIT.POS_EMBED == "joint": + return {"pos_embed", "cls_token", "st_embed"} + else: + return {"pos_embed", "cls_token", "temp_embed"} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + # if self.video_input: + # x = x[0] + B = x.shape[0] + + # Tokenize input + # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: + # for simplicity of mapping between content dimensions (input x) and token dims (after patching) + # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): + + # apply patching on input + x = self.patch_embed_3d(x) + tok_mask = None + + # else: + # tok_mask = None + # # 2D tokenization + # if self.video_input: + # x = x.permute(0, 2, 1, 3, 4) + # (B, T, C, H, W) = x.shape + # x = x.reshape(B * T, C, H, W) + + # x = self.patch_embed(x) + + # if self.video_input: + # (B2, T2, D2) = x.shape + # x = x.reshape(B, T * T2, D2) + + # Append CLS token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + # if tok_mask is not None: + # # prepend 1(=keep) to the mask to account for the CLS token as well + # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1) + + # Interpolate positinoal embeddings + # if self.cfg.DATA.TRAIN_CROP_SIZE != 224: + # pos_embed = self.pos_embed + # N = pos_embed.shape[1] - 1 + # npatch = int((x.size(1) - 1) / self.temporal_resolution) + # class_emb = pos_embed[:, 0] + # pos_embed = pos_embed[:, 1:] + # dim = x.shape[-1] + # pos_embed = torch.nn.functional.interpolate( + # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + # scale_factor=math.sqrt(npatch / N), + # mode='bicubic', + # ) + # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) + # else: + new_pos_embed = self.pos_embed + npatch = self.patch_embed.num_patches + + # Add positional embeddings to input + if self.video_input: + if self.cfg.VIT.POS_EMBED == "separate": + cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) + tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) + tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) + total_pos_embed = tile_pos_embed + tile_temporal_embed + total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) + x = x + total_pos_embed + elif self.cfg.VIT.POS_EMBED == "joint": + x = x + self.st_embed + else: + # image input + x = x + new_pos_embed + + # Apply positional dropout + x = self.pos_drop(x) + + # Encoding using transformer layers + for i, blk in enumerate(self.blocks): + x = blk( + x, + seq_len=npatch, + num_frames=self.temporal_resolution, + approx=self.cfg.VIT.APPROX_ATTN_TYPE, + num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM, + tok_mask=tok_mask, + ) + + ### v-iashin: I moved it to the forward pass + # x = self.norm(x)[:, 0] + # x = self.pre_logits(x) + ### + return x, tok_mask + + # def forward(self, x): + # x = self.forward_features(x) + # ### v-iashin: here. This should leave the same forward output as before + # x = self.norm(x)[:, 0] + # x = self.pre_logits(x) + # ### + # x = self.head_drop(x) + # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: + # output = [] + # for head in range(len(self.num_classes)): + # x_out = getattr(self, "head%d" % head)(x) + # if not self.training: + # x_out = torch.nn.functional.softmax(x_out, dim=-1) + # output.append(x_out) + # return output + # else: + # x = self.head(x) + # if not self.training: + # x = torch.nn.functional.softmax(x, dim=-1) + # return x diff --git a/hunyuanvideo_foley/models/synchformer/vit_helper.py b/hunyuanvideo_foley/models/synchformer/vit_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..29739530ce8692f3124b3ae748f11a4a06aa5fc8 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/vit_helper.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copyright 2020 Ross Wightman +# Modified Model definition +"""Video models.""" + +import math + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from timm.layers import to_2tuple +from torch import einsum +from torch.nn import functional as F + +default_cfgs = { + "vit_1k": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth", + "vit_1k_large": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth", +} + + +def qkv_attn(q, k, v, tok_mask: torch.Tensor = None): + sim = einsum("b i d, b j d -> b i j", q, k) + # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N) + if tok_mask is not None: + BSH, N = tok_mask.shape + sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0, float("-inf")) # 1 - broadcasts across N + attn = sim.softmax(dim=-1) + out = einsum("b i j, b j d -> b i d", attn, v) + return out + + +class DividedAttention(nn.Module): + + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + # init to zeros + self.qkv.weight.data.fill_(0) + self.qkv.bias.data.fill_(0) + self.proj.weight.data.fill_(1) + self.proj.bias.data.fill_(0) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims): + # num of heads variable + h = self.num_heads + + # project x to q, k, v vaalues + q, k, v = self.qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + if tok_mask is not None: + # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d + assert len(tok_mask.shape) == 2 + tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1]) + + # Scale q + q *= self.scale + + # Take out cls_q, cls_k, cls_v + (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v)) + # the same for masking + if tok_mask is not None: + cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:] + else: + cls_mask, mask_ = None, None + + # let CLS token attend to key / values of all patches across time and space + cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask) + + # rearrange across time or space + q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_)) + + # expand CLS token keys and values across time or space and concat + r = q_.shape[0] // cls_k.shape[0] + cls_k, cls_v = map(lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v)) + + k_ = torch.cat((cls_k, k_), dim=1) + v_ = torch.cat((cls_v, v_), dim=1) + + # the same for masking (if provided) + if tok_mask is not None: + # since mask does not have the latent dim (d), we need to remove it from einops dims + mask_ = rearrange(mask_, f"{einops_from} -> {einops_to}".replace(" d", ""), **einops_dims) + cls_mask = repeat(cls_mask, "b () -> (b r) ()", r=r) # expand cls_mask across time or space + mask_ = torch.cat((cls_mask, mask_), dim=1) + + # attention + out = qkv_attn(q_, k_, v_, tok_mask=mask_) + + # merge back time or space + out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims) + + # concat back the cls token + out = torch.cat((cls_out, out), dim=1) + + # merge back the heads + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + + ## to out + x = self.proj(out) + x = self.proj_drop(x) + return x + + +class DividedSpaceTimeBlock(nn.Module): + + def __init__( + self, + dim=768, + num_heads=12, + attn_type="divided", + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + + self.einops_from_space = "b (f n) d" + self.einops_to_space = "(b f) n d" + self.einops_from_time = "b (f n) d" + self.einops_to_time = "(b n) f d" + + self.norm1 = norm_layer(dim) + + self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + + self.timeattn = DividedAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.norm3 = norm_layer(dim) + + def forward(self, x, seq_len=196, num_frames=8, approx="none", num_landmarks=128, tok_mask: torch.Tensor = None): + time_output = self.timeattn( + self.norm3(x), self.einops_from_time, self.einops_to_time, n=seq_len, tok_mask=tok_mask + ) + time_residual = x + time_output + + space_output = self.attn( + self.norm1(time_residual), self.einops_from_space, self.einops_to_space, f=num_frames, tok_mask=tok_mask + ) + space_residual = time_residual + self.drop_path(space_output) + + x = space_residual + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = img_size if type(img_size) is tuple else to_2tuple(img_size) + patch_size = img_size if type(patch_size) is tuple else to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + temporal_resolution=4, + in_chans=3, + patch_size=16, + z_block_size=2, + embed_dim=768, + flatten=True, + ): + super().__init__() + self.height = img_size // patch_size + self.width = img_size // patch_size + ### v-iashin: these two are incorrect + # self.frames = (temporal_resolution // z_block_size) + # self.num_patches = self.height * self.width * self.frames + self.z_block_size = z_block_size + ### + self.proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=(z_block_size, patch_size, patch_size), + stride=(z_block_size, patch_size, patch_size), + ) + self.flatten = flatten + + def forward(self, x): + B, C, T, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + return x + + +class HeadMLP(nn.Module): + + def __init__(self, n_input, n_classes, n_hidden=512, p=0.1): + super(HeadMLP, self).__init__() + self.n_input = n_input + self.n_classes = n_classes + self.n_hidden = n_hidden + if n_hidden is None: + # use linear classifier + self.block_forward = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_input, n_classes, bias=True)) + else: + # use simple MLP classifier + self.block_forward = nn.Sequential( + nn.Dropout(p=p), + nn.Linear(n_input, n_hidden, bias=True), + nn.BatchNorm1d(n_hidden), + nn.ReLU(inplace=True), + nn.Dropout(p=p), + nn.Linear(n_hidden, n_classes, bias=True), + ) + print(f"Dropout-NLP: {p}") + + def forward(self, x): + return self.block_forward(x) + + +def _conv_filter(state_dict, patch_size=16): + """convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if "patch_embed.proj.weight" in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + +def adapt_input_conv(in_chans, conv_weight, agg="sum"): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + if agg == "sum": + print("Summing conv1 weights") + conv_weight = conv_weight.sum(dim=1, keepdim=True) + else: + print("Averaging conv1 weights") + conv_weight = conv_weight.mean(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError("Weight format not supported by conversion.") + else: + if agg == "sum": + print("Summing conv1 weights") + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= 3 / float(in_chans) + else: + print("Averaging conv1 weights") + conv_weight = conv_weight.mean(dim=1, keepdim=True) + conv_weight = conv_weight.repeat(1, in_chans, 1, 1) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): + # Load state dict + assert f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]" + state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS]) + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + input_convs = "patch_embed.proj" + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + ".weight" + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name], agg="avg") + print(f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)") + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + print(f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer.") + + classifier_name = "head" + label_offset = cfg.get("label_offset", 0) + pretrain_classes = 1000 + if num_classes != pretrain_classes: + # completely discard fully connected if model num_classes doesn't match pretrained weights + del state_dict[classifier_name + ".weight"] + del state_dict[classifier_name + ".bias"] + strict = False + elif label_offset > 0: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + ".weight"] + state_dict[classifier_name + ".weight"] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + ".bias"] + state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:] + + loaded_state = state_dict + self_state = model.state_dict() + all_names = set(self_state.keys()) + saved_names = set([]) + for name, param in loaded_state.items(): + param = param + if "module." in name: + name = name.replace("module.", "") + if name in self_state.keys() and param.shape == self_state[name].shape: + saved_names.add(name) + self_state[name].copy_(param) + else: + print(f"didnt load: {name} of shape: {param.shape}") + print("Missing Keys:") + print(all_names - saved_names) diff --git a/hunyuanvideo_foley/utils/__init__.py b/hunyuanvideo_foley/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecb2a5bb45f0d2d676ace18907b62c2d2acd3b28 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd82ed9bff287f4f271fb47350d786e6a056846a Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/utils/config_utils.py b/hunyuanvideo_foley/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1bdf9108dfebf68c0a27cfdd675c8cc5e4c1d88 --- /dev/null +++ b/hunyuanvideo_foley/utils/config_utils.py @@ -0,0 +1,109 @@ +"""Configuration utilities for the HunyuanVideo-Foley project.""" + +import yaml +from pathlib import Path +from typing import Any, Dict, List, Union + +class AttributeDict: + + def __init__(self, data: Union[Dict, List, Any]): + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, (dict, list)): + value = AttributeDict(value) + setattr(self, self._sanitize_key(key), value) + elif isinstance(data, list): + self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item + for item in data] + else: + self._value = data + + def _sanitize_key(self, key: str) -> str: + import re + sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key)) + if sanitized[0].isdigit(): + sanitized = f'_{sanitized}' + return sanitized + + def __getitem__(self, key): + if hasattr(self, '_list'): + return self._list[key] + return getattr(self, self._sanitize_key(key)) + + def __setitem__(self, key, value): + if hasattr(self, '_list'): + self._list[key] = value + else: + setattr(self, self._sanitize_key(key), value) + + def __iter__(self): + if hasattr(self, '_list'): + return iter(self._list) + return iter(self.__dict__.keys()) + + def __len__(self): + if hasattr(self, '_list'): + return len(self._list) + return len(self.__dict__) + + def get(self, key, default=None): + try: + return self[key] + except (KeyError, AttributeError, IndexError): + return default + + def keys(self): + if hasattr(self, '_list'): + return range(len(self._list)) + elif hasattr(self, '_value'): + return [] + else: + return [key for key in self.__dict__.keys() if not key.startswith('_')] + + def values(self): + if hasattr(self, '_list'): + return self._list + elif hasattr(self, '_value'): + return [self._value] + else: + return [value for key, value in self.__dict__.items() if not key.startswith('_')] + + def items(self): + if hasattr(self, '_list'): + return enumerate(self._list) + elif hasattr(self, '_value'): + return [] + else: + return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')] + + def __repr__(self): + if hasattr(self, '_list'): + return f"AttributeDict({self._list})" + elif hasattr(self, '_value'): + return f"AttributeDict({self._value})" + return f"AttributeDict({dict(self.__dict__)})" + + def to_dict(self) -> Union[Dict, List, Any]: + if hasattr(self, '_list'): + return [item.to_dict() if isinstance(item, AttributeDict) else item + for item in self._list] + elif hasattr(self, '_value'): + return self._value + else: + result = {} + for key, value in self.__dict__.items(): + if isinstance(value, AttributeDict): + result[key] = value.to_dict() + else: + result[key] = value + return result + +def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict: + try: + with open(file_path, 'r', encoding=encoding) as file: + data = yaml.safe_load(file) + return AttributeDict(data) + except FileNotFoundError: + raise FileNotFoundError(f"YAML file not found: {file_path}") + except yaml.YAMLError as e: + raise yaml.YAMLError(f"YAML format error: {e}") diff --git a/hunyuanvideo_foley/utils/feature_utils.py b/hunyuanvideo_foley/utils/feature_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39f474479fdac18000e89d4db0b5211e35dffe23 --- /dev/null +++ b/hunyuanvideo_foley/utils/feature_utils.py @@ -0,0 +1,156 @@ +"""Feature extraction utilities for video and text processing.""" + +import os +import numpy as np +import torch +import av +from PIL import Image +from einops import rearrange +from typing import Any, Dict, List, Union, Tuple +from loguru import logger + +from .config_utils import AttributeDict +from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS + + +class FeatureExtractionError(Exception): + """Exception raised for feature extraction errors.""" + pass + +def get_frames_av( + video_path: str, + fps: float, + max_length: float = None, +) -> Tuple[np.ndarray, float]: + end_sec = max_length if max_length is not None else 15 + next_frame_time_for_each_fps = 0.0 + time_delta_for_each_fps = 1 / fps + + all_frames = [] + output_frames = [] + + with av.open(video_path) as container: + stream = container.streams.video[0] + ori_fps = stream.guessed_rate + stream.thread_type = "AUTO" + for packet in container.demux(stream): + for frame in packet.decode(): + frame_time = frame.time + if frame_time < 0: + continue + if frame_time > end_sec: + break + + frame_np = None + + this_time = frame_time + while this_time >= next_frame_time_for_each_fps: + if frame_np is None: + frame_np = frame.to_ndarray(format="rgb24") + + output_frames.append(frame_np) + next_frame_time_for_each_fps += time_delta_for_each_fps + + output_frames = np.stack(output_frames) + + vid_len_in_s = len(output_frames) / fps + if max_length is not None and len(output_frames) > int(max_length * fps): + output_frames = output_frames[: int(max_length * fps)] + vid_len_in_s = max_length + + return output_frames, vid_len_in_s + +@torch.inference_mode() +def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1): + b, t, c, h, w = x.shape + if batch_size < 0: + batch_size = b * t + x = rearrange(x, "b t c h w -> (b t) c h w") + outputs = [] + for i in range(0, b * t, batch_size): + outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size])) + res = torch.cat(outputs, dim=0) + res = rearrange(res, "(b t) d -> b t d", b=b) + return res + +@torch.inference_mode() +def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1): + """ + The input video of x is best to be in fps of 24 of greater than 24. + Input: + x: tensor in shape of [B, T, C, H, W] + batch_size: the batch_size for synchformer inference + """ + b, t, c, h, w = x.shape + assert c == 3 and h == 224 and w == 224 + + segment_size = 16 + step_size = 8 + num_segments = (t - segment_size) // step_size + 1 + segments = [] + for i in range(num_segments): + segments.append(x[:, i * step_size : i * step_size + segment_size]) + x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224) + + outputs = [] + if batch_size < 0: + batch_size = b * num_segments + x = rearrange(x, "b s t c h w -> (b s) 1 t c h w") + for i in range(0, b * num_segments, batch_size): + with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half): + outputs.append(model_dict.syncformer_model(x[i : i + batch_size])) + x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768] + x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b) + return x + + +@torch.inference_mode() +def encode_video_features(video_path, model_dict): + visual_features = {} + # siglip2 visual features + frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"]) + images = [Image.fromarray(frame).convert('RGB') for frame in frames] + images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W] + clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0) + visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device) + + # synchformer visual features + frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"]) + images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W] + sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224] + # [1, num_segments * 8, channel_dim], e.g. [1, 240, 768] for 10s video + visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict) + + vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"] + visual_features = AttributeDict(visual_features) + + return visual_features, vid_len_in_s + +@torch.inference_mode() +def encode_text_feat(text: List[str], model_dict): + # x: (B, L) + inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device) + outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True) + return outputs.last_hidden_state, outputs.attentions + + +def feature_process(video_path, prompt, model_dict, cfg): + visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict) + neg_prompt = "noisy, harsh" + prompts = [neg_prompt, prompt] + text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict) + + text_feat = text_feat_res[1:] + uncond_text_feat = text_feat_res[:1] + + if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]: + text_seq_length = cfg.model_config.model_kwargs.text_length + text_feat = text_feat[:, :text_seq_length] + uncond_text_feat = uncond_text_feat[:, :text_seq_length] + + text_feats = AttributeDict({ + 'text_feat': text_feat, + 'uncond_text_feat': uncond_text_feat, + }) + + return visual_feats, text_feats, audio_len_in_s diff --git a/hunyuanvideo_foley/utils/helper.py b/hunyuanvideo_foley/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..04840dfcdf847e3f61ec327ebc691b0dbd139da4 --- /dev/null +++ b/hunyuanvideo_foley/utils/helper.py @@ -0,0 +1,134 @@ +import collections.abc +from itertools import repeat +import importlib +import yaml +import time + +def default(value, default_val): + return default_val if value is None else value + + +def default_dtype(value, default_val): + if value is not None: + assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}." + return value + return default_val + + +def repeat_interleave(lst, num_repeats): + return [item for item in lst for _ in range(num_repeats)] + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + x = tuple(x) + if len(x) == 1: + x = tuple(repeat(x[0], n)) + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) + + +def as_tuple(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + if x is None or isinstance(x, (int, float, str)): + return (x,) + else: + raise ValueError(f"Unknown type {type(x)}") + + +def as_list_of_2tuple(x): + x = as_tuple(x) + if len(x) == 1: + x = (x[0], x[0]) + assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." + lst = [] + for i in range(0, len(x), 2): + lst.append((x[i], x[i + 1])) + return lst + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n - (n % k) + k + + +def merge_dicts(dict1, dict2): + for key, value in dict2.items(): + if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict): + merge_dicts(dict1[key], value) + else: + dict1[key] = value + return dict1 + + +def merge_yaml_files(file_list): + merged_config = {} + + for file in file_list: + with open(file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + if config: + # Remove the first level + for key, value in config.items(): + if isinstance(value, dict): + merged_config = merge_dicts(merged_config, value) + else: + merged_config[key] = value + + return merged_config + + +def merge_dict(file_list): + merged_config = {} + + for file in file_list: + with open(file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + if config: + merged_config = merge_dicts(merged_config, config) + + return merged_config + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def readable_time(seconds): + """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """ + seconds = int(seconds) + days, seconds = divmod(seconds, 86400) + hours, seconds = divmod(seconds, 3600) + minutes, seconds = divmod(seconds, 60) + if days > 0: + return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds" + if hours > 0: + return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds" + if minutes > 0: + return f"{minutes} Minutes, {seconds} Seconds" + return f"{seconds} Seconds" + + +def get_obj_from_cfg(cfg, reload=False): + if isinstance(cfg, str): + return get_obj_from_str(cfg, reload) + elif isinstance(cfg, (list, tuple,)): + return tuple([get_obj_from_str(c, reload) for c in cfg]) + else: + raise NotImplementedError(f"Not implemented for {type(cfg)}.") diff --git a/hunyuanvideo_foley/utils/media_utils.py b/hunyuanvideo_foley/utils/media_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d80e6cbfe5164caeb7340fa2fbbcb8a11ddd9c13 --- /dev/null +++ b/hunyuanvideo_foley/utils/media_utils.py @@ -0,0 +1,101 @@ +"""Media utilities for audio/video processing.""" + +import os +import subprocess +from pathlib import Path +from typing import Optional + +from loguru import logger + + +class MediaProcessingError(Exception): + """Exception raised for media processing errors.""" + pass + + +def merge_audio_video( + audio_path: str, + video_path: str, + output_path: str, + overwrite: bool = True, + quality: str = "high" +) -> str: + """ + Merge audio and video files using ffmpeg. + + Args: + audio_path: Path to input audio file + video_path: Path to input video file + output_path: Path for output video file + overwrite: Whether to overwrite existing output file + quality: Quality setting ('high', 'medium', 'low') + + Returns: + Path to the output file + + Raises: + MediaProcessingError: If input files don't exist or ffmpeg fails + FileNotFoundError: If ffmpeg is not installed + """ + # Validate input files + if not os.path.exists(audio_path): + raise MediaProcessingError(f"Audio file not found: {audio_path}") + if not os.path.exists(video_path): + raise MediaProcessingError(f"Video file not found: {video_path}") + + # Create output directory if needed + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Quality settings + quality_settings = { + "high": ["-b:a", "192k"], + "medium": ["-b:a", "128k"], + "low": ["-b:a", "96k"] + } + + # Build ffmpeg command + ffmpeg_command = [ + "ffmpeg", + "-i", video_path, + "-i", audio_path, + "-c:v", "copy", + "-c:a", "aac", + "-ac", "2", + "-af", "pan=stereo|c0=c0|c1=c0", + "-map", "0:v:0", + "-map", "1:a:0", + *quality_settings.get(quality, quality_settings["high"]), + ] + + if overwrite: + ffmpeg_command.append("-y") + + ffmpeg_command.append(output_path) + + try: + logger.info(f"Merging audio '{audio_path}' with video '{video_path}'") + process = subprocess.Popen( + ffmpeg_command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + stdout, stderr = process.communicate() + + if process.returncode != 0: + error_msg = f"FFmpeg failed with return code {process.returncode}: {stderr}" + logger.error(error_msg) + raise MediaProcessingError(error_msg) + else: + logger.info(f"Successfully merged video saved to: {output_path}") + + except FileNotFoundError: + raise FileNotFoundError( + "ffmpeg not found. Please install ffmpeg: " + "https://ffmpeg.org/download.html" + ) + except Exception as e: + raise MediaProcessingError(f"Unexpected error during media processing: {e}") + + return output_path diff --git a/hunyuanvideo_foley/utils/model_utils.py b/hunyuanvideo_foley/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..354ea223c7ef8ae4f8759f6763028d27e6ad13bc --- /dev/null +++ b/hunyuanvideo_foley/utils/model_utils.py @@ -0,0 +1,241 @@ +import torch +import os +from loguru import logger +from torchvision import transforms +from torchvision.transforms import v2 +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection +from ..models.dac_vae.model.dac import DAC +from ..models.synchformer import Synchformer +from ..models.hifi_foley import HunyuanVideoFoley +from .config_utils import load_yaml, AttributeDict +from .schedulers import FlowMatchDiscreteScheduler +from tqdm import tqdm + +def load_state_dict(model, model_path): + logger.info(f"Loading model state dict from: {model_path}") + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) + + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if missing_keys: + logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):") + for key in missing_keys: + logger.warning(f" - {key}") + else: + logger.info("No missing keys found") + + if unexpected_keys: + logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):") + for key in unexpected_keys: + logger.warning(f" - {key}") + else: + logger.info("No unexpected keys found") + + logger.info("Model state dict loaded successfully") + return model + +def load_model(model_path, config_path, device): + logger.info("Starting model loading process...") + logger.info(f"Configuration file: {config_path}") + logger.info(f"Model weights dir: {model_path}") + logger.info(f"Target device: {device}") + + cfg = load_yaml(config_path) + logger.info("Configuration loaded successfully") + + # HunyuanVideoFoley + logger.info("Loading HunyuanVideoFoley main model...") + foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16) + foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth")) + foley_model.eval() + logger.info("HunyuanVideoFoley model loaded and set to evaluation mode") + + # DAC-VAE + dac_path = os.path.join(model_path, "vae_128d_48k.pth") + logger.info(f"Loading DAC VAE model from: {dac_path}") + dac_model = DAC.load(dac_path) + dac_model = dac_model.to(device) + dac_model.requires_grad_(False) + dac_model.eval() + logger.info("DAC VAE model loaded successfully") + + # Siglip2 visual-encoder + logger.info("Loading SigLIP2 visual encoder...") + siglip2_preprocess = transforms.Compose([ + transforms.Resize((512, 512)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval() + logger.info("SigLIP2 model and preprocessing pipeline loaded successfully") + + # clap text-encoder + logger.info("Loading CLAP text encoder...") + clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general") + clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device) + logger.info("CLAP tokenizer and model loaded successfully") + + # syncformer + syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth") + logger.info(f"Loading Synchformer model from: {syncformer_path}") + syncformer_preprocess = v2.Compose( + [ + v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(224), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + + syncformer_model = Synchformer() + syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu")) + syncformer_model = syncformer_model.to(device).eval() + logger.info("Synchformer model and preprocessing pipeline loaded successfully") + + + logger.info("Creating model dictionary with attribute access...") + model_dict = AttributeDict({ + 'foley_model': foley_model, + 'dac_model': dac_model, + 'siglip2_preprocess': siglip2_preprocess, + 'siglip2_model': siglip2_model, + 'clap_tokenizer': clap_tokenizer, + 'clap_model': clap_model, + 'syncformer_preprocess': syncformer_preprocess, + 'syncformer_model': syncformer_model, + 'device': device, + }) + + logger.info("All models loaded successfully!") + logger.info("Available model components:") + for key in model_dict.keys(): + logger.info(f" - {key}") + logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)") + + return model_dict, cfg + +def retrieve_timesteps( + scheduler, + num_inference_steps, + device, + **kwargs, +): + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device): + shape = (batch_size, num_channels_latents, int(length)) + latents = randn_tensor(shape, device=device, dtype=dtype) + + # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler + if hasattr(scheduler, "init_noise_sigma"): + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler.init_noise_sigma + + return latents + + +@torch.no_grad() +def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1): + + target_dtype = model_dict.foley_model.dtype + autocast_enabled = target_dtype != torch.float32 + device = model_dict.device + + scheduler = FlowMatchDiscreteScheduler( + shift=cfg.diffusion_config.sample_flow_shift, + reverse=cfg.diffusion_config.flow_reverse, + solver=cfg.diffusion_config.flow_solver, + use_flux_shift=cfg.diffusion_config.sample_use_flux_shift, + flux_base_shift=cfg.diffusion_config.flux_base_shift, + flux_max_shift=cfg.diffusion_config.flux_max_shift, + ) + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + device, + ) + + latents = prepare_latents( + scheduler, + batch_size=batch_size, + num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim, + length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate, + dtype=target_dtype, + device=device, + ) + + # Denoise loop + for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"): + # noise latents + latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents + latent_input = scheduler.scale_model_input(latent_input, t) + + t_expand = t.repeat(latent_input.shape[0]) + + # siglip2 features + siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence( + bs=batch_size, len=siglip2_feat.shape[1] + ).to(device) + + if guidance_scale is not None and guidance_scale > 1.0: + siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0) + else: + siglip2_feat_input = siglip2_feat + + # syncformer features + syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence( + bs=batch_size, len=syncformer_feat.shape[1] + ).to(device) + if guidance_scale is not None and guidance_scale > 1.0: + syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0) + else: + syncformer_feat_input = syncformer_feat + + # text features + text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + if guidance_scale is not None and guidance_scale > 1.0: + text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0) + else: + text_feat_input = text_feat_repeated + + with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype): + # Predict the noise residual + noise_pred = model_dict.foley_model( + x=latent_input, + t=t_expand, + cond=text_feat_input, + clip_feat=siglip2_feat_input, + sync_feat=syncformer_feat_input, + return_dict=True, + )["x"] + + noise_pred = noise_pred.to(dtype=torch.float32) + + if guidance_scale is not None and guidance_scale > 1.0: + # Perform classifier-free guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Post-process the latents to audio + + with torch.no_grad(): + audio = model_dict.dac_model.decode(latents) + audio = audio.float().cpu() + + audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)] + + return audio, model_dict.dac_model.sample_rate + + diff --git a/hunyuanvideo_foley/utils/schedulers/__init__.py b/hunyuanvideo_foley/utils/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fea4433560c92fc9d8569993447d0bdb456dc9e --- /dev/null +++ b/hunyuanvideo_foley/utils/schedulers/__init__.py @@ -0,0 +1,2 @@ +from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler +from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler \ No newline at end of file diff --git a/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py b/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..2f58814479643b42eb18158db2fcc2a29544424e --- /dev/null +++ b/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py @@ -0,0 +1,376 @@ +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + reverse (`bool`, defaults to `True`): + Whether to reverse the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + reverse: bool = True, + solver: str = "euler", + use_flux_shift: bool = False, + flux_base_shift: float = 0.5, + flux_max_shift: float = 1.15, + n_tokens: Optional[int] = None, + ): + sigmas = torch.linspace(1, 0, num_train_timesteps + 1) + + if not reverse: + sigmas = sigmas.flip(0) + + self.sigmas = sigmas + # the value fed to model + self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) + self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32) + + self._step_index = None + self._begin_index = None + + self.supported_solver = [ + "euler", + "heun-2", "midpoint-2", + "kutta-4", + ] + if solver not in self.supported_solver: + raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}") + + # empty dt and derivative (for heun) + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + @property + def state_in_first_order(self): + return self.derivative_1 is None + + @property + def state_in_second_order(self): + return self.derivative_2 is None + + @property + def state_in_third_order(self): + return self.derivative_3 is None + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, + n_tokens: int = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + """ + self.num_inference_steps = num_inference_steps + + sigmas = torch.linspace(1, 0, num_inference_steps + 1) + + # Apply timestep shift + if self.config.use_flux_shift: + assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift" + mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens) + sigmas = self.flux_time_shift(mu, 1.0, sigmas) + elif self.config.shift != 1.: + sigmas = self.sd3_time_shift(sigmas) + + if not self.config.reverse: + sigmas = 1 - sigmas + + self.sigmas = sigmas + self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) + self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) + + # empty dt and derivative (for kutta) + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + + # Reset step index + self._step_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + return sample + + @staticmethod + def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15): + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + @staticmethod + def flux_time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def sd3_time_shift(self, t: torch.Tensor): + return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + pred_uncond: torch.FloatTensor = None, + generator: Optional[torch.Generator] = None, + n_tokens: Optional[int] = None, + return_dict: bool = True, + ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + model_output = model_output.to(torch.float32) + pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None + + # dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + last_inner_step = True + if self.config.solver == "euler": + derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample) + elif self.config.solver in ["heun-2", "midpoint-2"]: + derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample) + elif self.config.solver == "kutta-4": + derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample) + else: + raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}") + + prev_sample = sample + derivative * dt + + # Cast sample back to model compatible dtype + # prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + if last_inner_step: + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) + + def first_order_method(self, model_output, sigma, sigma_next, sample): + derivative = model_output.float() + dt = sigma_next - sigma + return derivative, dt, sample, True + + def second_order_method(self, model_output, sigma, sigma_next, sample): + if self.state_in_first_order: + # store for 2nd order step + self.derivative_1 = model_output + self.dt = sigma_next - sigma + self.sample = sample + + derivative = model_output + if self.config.solver == 'heun-2': + dt = self.dt + elif self.config.solver == 'midpoint-2': + dt = self.dt / 2 + else: + raise NotImplementedError(f"Solver {self.config.solver} not supported.") + last_inner_step = False + + else: + if self.config.solver == 'heun-2': + derivative = 0.5 * (self.derivative_1 + model_output) + elif self.config.solver == 'midpoint-2': + derivative = model_output + else: + raise NotImplementedError(f"Solver {self.config.solver} not supported.") + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + last_inner_step = True + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.derivative_1 = None + self.dt = None + self.sample = None + + return derivative, dt, sample, last_inner_step + + def fourth_order_method(self, model_output, sigma, sigma_next, sample): + if self.state_in_first_order: + self.derivative_1 = model_output + self.dt = sigma_next - sigma + self.sample = sample + derivative = model_output + dt = self.dt / 2 + last_inner_step = False + + elif self.state_in_second_order: + self.derivative_2 = model_output + derivative = model_output + dt = self.dt / 2 + last_inner_step = False + + elif self.state_in_third_order: + self.derivative_3 = model_output + derivative = model_output + dt = self.dt + last_inner_step = False + + else: + derivative = 1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 + 1/6 * model_output + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + last_inner_step = True + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + self.sample = None + + return derivative, dt, sample, last_inner_step + + def __len__(self): + return self.config.num_train_timesteps diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..32db4a62273e6775384c32d431232bc8449c3987 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,61 @@ +[tool.black] +line-length = 120 +target-version = ['py38', 'py39', 'py310', 'py311'] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true + +[tool.flake8] +max-line-length = 120 +select = ["E", "W", "F"] +ignore = [ + "E203", # whitespace before ':' + "E501", # line too long + "W503", # line break before binary operator +] +exclude = [ + ".git", + "__pycache__", + "build", + "dist", + ".eggs", + "*.egg-info", + ".venv", + ".tox", +] + +[tool.mypy] +python_version = "3.8" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +check_untyped_defs = true +disallow_untyped_decorators = false +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a765757c67aebb12abf380c53ea8f316ec55870 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,48 @@ +# Core ML dependencies +torch>=2.0.0 +torchvision>=0.15.0 +torchaudio>=2.0.0 +numpy==1.26.4 +scipy + +# Deep Learning frameworks +diffusers +timm +accelerate + +# Transformers and NLP +git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2 +sentencepiece + +# Audio processing +git+https://github.com/descriptinc/audiotools + +# Video/Image processing +pillow +av +einops + +# Configuration and utilities +pyyaml +omegaconf +easydict +loguru +tqdm +setuptools + +# Data handling +pandas +pyarrow + +# Web interface +gradio==3.50.2 + +# Network +urllib3==2.4.0 + +# Development dependencies (optional) +black>=23.0.0 +isort>=5.12.0 +flake8>=6.0.0 +mypy>=1.3.0 +pre-commit>=3.0.0