Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +7 -0
- .gitignore +42 -0
- LICENSE.txt +17 -0
- README.md +127 -12
- assets/comp_effic.png +3 -0
- assets/data_for_diff_stage.jpg +3 -0
- assets/i2v_res.png +3 -0
- assets/logo.png +0 -0
- assets/t2v_res.jpg +3 -0
- assets/vben_vs_sota.png +3 -0
- assets/video_dit_arch.jpg +3 -0
- assets/video_vae_res.jpg +3 -0
- docs/CHANGELOG.md +157 -0
- docs/CLI.md +226 -0
- docs/GETTING_STARTED.md +194 -0
- docs/INSTALLATION.md +170 -0
- docs/LORAS.md +224 -0
- docs/MODELS.md +268 -0
- docs/TROUBLESHOOTING.md +338 -0
- docs/VACE.md +190 -0
- fantasytalking/infer.py +36 -0
- fantasytalking/model.py +162 -0
- fantasytalking/utils.py +52 -0
- hyvideo/__init__.py +0 -0
- hyvideo/config.py +534 -0
- hyvideo/constants.py +164 -0
- hyvideo/data_kits/audio_dataset.py +170 -0
- hyvideo/data_kits/audio_preprocessor.py +76 -0
- hyvideo/data_kits/data_tools.py +41 -0
- hyvideo/data_kits/face_align/__init__.py +1 -0
- hyvideo/data_kits/face_align/align.py +34 -0
- hyvideo/data_kits/face_align/detface.py +283 -0
- hyvideo/diffusion/__init__.py +2 -0
- hyvideo/diffusion/pipelines/__init__.py +2 -0
- hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +1438 -0
- hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py +1362 -0
- hyvideo/diffusion/schedulers/__init__.py +1 -0
- hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py +255 -0
- hyvideo/hunyuan.py +1062 -0
- hyvideo/modules/__init__.py +26 -0
- hyvideo/modules/activation_layers.py +23 -0
- hyvideo/modules/attenion.py +362 -0
- hyvideo/modules/audio_adapters.py +220 -0
- hyvideo/modules/embed_layers.py +158 -0
- hyvideo/modules/mlp_layers.py +131 -0
- hyvideo/modules/models.py +1159 -0
- hyvideo/modules/modulate_layers.py +136 -0
- hyvideo/modules/norm_layers.py +88 -0
- hyvideo/modules/original models.py +760 -0
- hyvideo/modules/placement.py +389 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/comp_effic.png filter=lfs diff=lfs merge=lfs -text
+ assets/data_for_diff_stage.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/i2v_res.png filter=lfs diff=lfs merge=lfs -text
+ assets/t2v_res.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/vben_vs_sota.png filter=lfs diff=lfs merge=lfs -text
+ assets/video_dit_arch.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/video_vae_res.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,42 @@
.*
*.py[cod]
# *.jpg
*.jpeg
# *.png
*.gif
*.bmp
*.mp4
*.mov
*.mkv
*.log
*.zip
*.pt
*.pth
*.ckpt
*.safetensors
*.json
# *.txt
*.backup
*.pkl
*.html
*.pdf
*.whl
*.exe
cache
__pycache__/
storage/
samples/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
outputs/
gradio_outputs/
ckpts/
loras/
loras_i2v/
LICENSE.txt
ADDED
@@ -0,0 +1,17 @@
FREE for Non-Commercial USE

You are free to:
- Share — copy and redistribute the material in any medium or format
- Adapt — remix, transform, and build upon the material
The licensor cannot revoke these freedoms as long as you follow the license terms.

Under the following terms:
- Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
- NonCommercial — You may not use the material for commercial purposes.

- No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
Notices:

- You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation.

No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material.
README.md
CHANGED
@@ -1,12 +1,127 @@
---
title: Wan2GP
app_file: wgp.py
sdk: gradio
sdk_version: 5.23.0
---
# WanGP

-----
<p align="center">
<b>WanGP by DeepBeepMeep: The best Open Source Video Generative Models Accessible to the GPU Poor</b>
</p>

WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models with:
- Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models)
- Support for old GPUs (RTX 10XX, 20XX, ...)
- Very fast on the latest GPUs
- Easy-to-use, full web-based interface
- Automatic download of the required model, adapted to your specific architecture
- Integrated tools to facilitate video generation: Mask Editor, Prompt Enhancer, Temporal and Spatial Generation
- Loras support to customize each model
- Queuing system: make your shopping list of videos to generate and come back later

**Discord Server to get Help from Other Users and show your Best Videos:** https://discord.gg/g7efUW9jGV

**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates
### June 11 2025: WanGP v5.5
👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there is no lower limit on the number of frames, and you can use your reference images in a different context than the image itself\
*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping their poses. Similar to Vace but less restricted than the Wan models in terms of content...


### June 6 2025: WanGP v5.41
👋 Bonus release: Support for the **AccVideo** Lora to speed up Wan video generation by 2x. Check the Loras documentation for AccVideo usage instructions.\
You will need to do a *pip install -r requirements.txt*

### June 6 2025: WanGP v5.4
👋 World Exclusive: **Hunyuan Video Avatar** Support! You won't need 80 GB or even 32 GB of VRAM: just 10 GB is sufficient to generate up to 15s of high quality speech / song driven video at high speed with no quality degradation. Support for TeaCache included.\
Here is a link to the original repo where you will find some very interesting documentation and examples: https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\
Also many thanks to Reevoy24 for repackaging and completing the documentation

### May 28 2025: WanGP v5.31
👋 Added **Phantom 14B**, a model that you can use to transfer objects / people into a video. My preference still goes to Vace, which remains the king of controlnets.
VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options.

### May 26, 2025: WanGP v5.3
👋 Settings management revolution! Now you can:
- Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration
- Drag & drop videos to automatically extract their settings metadata
- Export/import settings as JSON files for easy sharing and backup

### May 20, 2025: WanGP v5.2
👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation to get the usage instructions of CausVid.

### May 18, 2025: WanGP v5.1
👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute!

### May 17, 2025: WanGP v5.0
👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer.

See full changelog: **[Changelog](docs/CHANGELOG.md)**

## 📋 Table of Contents

- [🚀 Quick Start](#-quick-start)
- [📦 Installation](#-installation)
- [🎯 Usage](#-usage)
- [📚 Documentation](#-documentation)
- [🔗 Related Projects](#-related-projects)

## 🚀 Quick Start

**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)

**Manual installation:**
```bash
git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP
conda create -n wan2gp python=3.10.9
conda activate wan2gp
pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
pip install -r requirements.txt
```

**Run the application:**
```bash
python wgp.py        # Text-to-video (default)
python wgp.py --i2v  # Image-to-video
```

## 📦 Installation

For detailed installation instructions for different GPU generations:
- **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX

## 🎯 Usage

### Basic Usage
- **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage
- **[Models Overview](docs/MODELS.md)** - Available models and their capabilities

### Advanced Features
- **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization
- **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation
- **[Command Line Reference](docs/CLI.md)** - All available command line options

## 📚 Documentation

- **[Changelog](docs/CHANGELOG.md)** - Latest updates and version history
- **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions

## 🔗 Related Projects

### Other Models for the GPU Poor
- **[HunyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators
- **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool
- **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux
- **[Cosmos1GP](https://github.com/deepbeepmeep/Cosmos1GP)** - Text to world generator and image/video to world
- **[OminiControlGP](https://github.com/deepbeepmeep/OminiControlGP)** - Flux-derived application for object transfer
- **[YuE GP](https://github.com/deepbeepmeep/YuEGP)** - Song generator with instruments and singer's voice

---

<p align="center">
Made with ❤️ by DeepBeepMeep
</p>
assets/comp_effic.png
ADDED (binary image, stored with Git LFS)
assets/data_for_diff_stage.jpg
ADDED (binary image, stored with Git LFS)
assets/i2v_res.png
ADDED (binary image, stored with Git LFS)
assets/logo.png
ADDED (binary image)
assets/t2v_res.jpg
ADDED (binary image, stored with Git LFS)
assets/vben_vs_sota.png
ADDED (binary image, stored with Git LFS)
assets/video_dit_arch.jpg
ADDED (binary image, stored with Git LFS)
assets/video_vae_res.jpg
ADDED (binary image, stored with Git LFS)
docs/CHANGELOG.md
ADDED
@@ -0,0 +1,157 @@
# Changelog

## 🔥 Latest News
### June 11 2025: WanGP v5.5
👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there is no lower limit on the number of frames, and you can use your reference images in a different context than the image itself\
*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping their poses. Similar to Vace but less restricted than the Wan models in terms of content...

### June 6 2025: WanGP v5.41
👋 Bonus release: Support for the **AccVideo** Lora to speed up Wan video generation by 2x. Check the Loras documentation for AccVideo usage instructions.

### June 6 2025: WanGP v5.4
👋 World Exclusive: Hunyuan Video Avatar Support! You won't need 80 GB or even 32 GB of VRAM: just 10 GB is sufficient to generate up to 15s of high quality speech / song driven video at high speed with no quality degradation. Support for TeaCache included.

### May 26, 2025: WanGP v5.3
👋 Happy with a video generation and want to do more generations using the same settings, but you can't remember what you did or you find it too hard to copy/paste each setting one by one from the file metadata? Rejoice! There are now multiple ways to turn this tedious process into a one-click task:
- Select a video recently generated in the Video Gallery and click *Use Selected Video Settings*
- Click *Drop File Here* and select a video you saved somewhere; if the settings metadata have been saved with the video, you will be able to extract them automatically
- Click *Export Settings to File* to save the current settings on your hard drive. You will be able to use them again later by clicking *Drop File Here* and selecting a Settings json file

### May 23, 2025: WanGP v5.21
👋 Improvements for Vace: better transitions between Sliding Windows, support for image masks in Matanyone, new Extend Video for Vace, different types of automated background removal

### May 20, 2025: WanGP v5.2
👋 Added support for Wan CausVid, which is a distilled Wan model that can generate nice looking videos in only 4 to 12 steps. The great thing is that Kijai (Kudos to him!) has created a CausVid Lora that can be combined with any existing Wan t2v 14B model like Wan Vace 14B. See [LORAS.md](LORAS.md) for instructions on how to use CausVid.

Also, as an experiment I have added support for MoviiGen, the first model that claims to be capable of generating 1080p videos (if you have enough VRAM (20GB...) and are ready to wait for a long time...). Don't hesitate to share your impressions on the Discord server.

### May 18, 2025: WanGP v5.1
👋 Bonus Day, added LTX Video 13B Distilled: generate very high quality videos in less than one minute!

### May 17, 2025: WanGP v5.0
👋 One App to Rule Them All! Added support for the other great open source architectures:
- **Hunyuan Video**: text 2 video (one of the best, if not the best t2v), image 2 video and the recently released Hunyuan Custom (very good identity preservation when injecting a person into a video)
- **LTX Video 13B** (released last week): very long video support and fast 720p generation. The WanGP version has been greatly optimized and reduced LTX Video VRAM requirements by 4!

Also:
- Added support for the best Control Video Model, released 2 days ago: Vace 14B
- New integrated prompt enhancer to increase the quality of the generated videos

*You will need one more `pip install -r requirements.txt`*

### May 5, 2025: WanGP v4.5
👋 FantasySpeaking model, you can animate a talking head using a voice track. This works not only on people but also on objects. Also better seamless transitions between Vace sliding windows for very long videos. New high quality processing features (mixed 16/32 bits calculation and 32 bits VAE)

### April 27, 2025: WanGP v4.4
👋 Phantom model support, a very good model to transfer people or objects into a video; works quite well at 720p and with the number of steps > 30

### April 25, 2025: WanGP v4.3
👋 Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos". Note that Skyreel uses causal attention that is only supported by Sdpa attention, so even if you choose another type of attention, some of the processes will use Sdpa attention.

### April 18, 2025: WanGP v4.2
👋 FLF2V model support, official support from Wan for image2video start and end frames, specialized for 720p.

### April 17, 2025: WanGP v4.1
👋 Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 denoising steps to get good results.

### April 13, 2025: WanGP v4.0
👋 Lots of goodies for you!
- A new UI, tabs were replaced by a Dropdown box to easily switch models
- A new queuing system that lets you stack in a queue as many text2video, image2video tasks, ... as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...). Many thanks to **Tophness** for being a big contributor on this new feature
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
- Wan Vace Control Net support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... See [VACE.md](VACE.md) for an introduction guide.
- Integrated the *Matanyone* tool directly inside WanGP so that you can easily create the inpainting masks used in Vace
- Sliding Window generation for Vace, create windows that can last dozens of seconds
- New optimizations for old generation GPUs: generate 5s (81 frames, 15 steps) of Vace 1.3B with only 5GB and in only 6 minutes on a RTX 2080Ti, and 5s of t2v 14B in less than 10 minutes.

### March 27, 2025
👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP has probably better end image support but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2video 1.3B model: Image 2 Video is now accessible to even lower hardware configurations. It is not as good as the 14B models but very impressive for its size. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun)

### March 26, 2025
👋 Good news! Official support for RTX 50xx, please check the [installation instructions](INSTALLATION.md).

### March 24, 2025: Wan2.1GP v3.2
👋
- Added Classifier-Free Guidance Zero Star. The video should match the text prompt better (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team**. Don't hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
- Added back support for PyTorch compilation with Loras. It seems it had been broken for some time
- Added the possibility to keep a number of pregenerated videos in the Video Gallery (useful to compare outputs of different settings)

*You will need one more `pip install -r requirements.txt`*

### March 19, 2025: Wan2.1GP v3.1
👋 Faster launch and RAM optimizations (should require less RAM to run)

*You will need one more `pip install -r requirements.txt`*

### March 18, 2025: Wan2.1GP v3.0
👋
- New tab-based interface, you can switch from i2v to t2v and conversely without restarting the app
- Experimental Dual Frames mode for i2v, you can also specify an End frame. It doesn't always work, so you will need a few attempts.
- You can save default settings in the files *i2v_settings.json* and *t2v_settings.json* that will be used when launching the app (you can also specify the path to different settings files)
- Slight acceleration with loras

*You will need one more `pip install -r requirements.txt`*

Many thanks to *Tophness* who created the framework (and did a big part of the work) of the multitabs and saved settings features

### March 18, 2025: Wan2.1GP v2.11
👋 Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to *Tophness* for his contributions.

*You will need one more `pip install -r requirements.txt` to reflect new dependencies*

### March 18, 2025: Wan2.1GP v2.1
👋 More Loras!: added support for 'Safetensors' and 'Replicate' Lora formats.

*You will need to refresh the requirements with a `pip install -r requirements.txt`*

### March 17, 2025: Wan2.1GP v2.0
👋 The Lora festival continues:
- Clearer user interface
- Download 30 Loras in one click to try them all (expand the info section)
- Very easy to use Loras, as Lora presets can now input the subject (or other needed terms) of the Lora so that you don't have to modify a prompt manually
- Added basic macro prompt language to prefill prompts with different values. With one prompt template, you can generate multiple prompts.
- New Multiple images prompts: you can now combine any number of images with any number of text prompts (need to launch the app with --multiple-images)
- New command line options to launch directly the 1.3B t2v model or the 14B t2v model

### March 14, 2025: Wan2.1GP v1.7
👋
- Lora Fest special edition: very fast loading/unloading of loras for the Lora collectors around. You can also now add/remove loras in the Lora folder without restarting the app.
- Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to *AmericanPresidentJimmyCarter* for the original implementation

*You will need to refresh the requirements `pip install -r requirements.txt`*

### March 13, 2025: Wan2.1GP v1.6
👋 Better Loras support, accelerated loading of Loras.

*You will need to refresh the requirements `pip install -r requirements.txt`*

### March 10, 2025: Wan2.1GP v1.5
👋 Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user)

### March 7, 2025: Wan2.1GP v1.4
👋 Fix PyTorch compilation, now it is really 20% faster when activated

### March 4, 2025: Wan2.1GP v1.3
👋 Support for Image to Video with multiple images for different images/prompts combinations (requires the *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.

*If you upgrade you will need to do a `pip install -r requirements.txt` again.*

### March 4, 2025: Wan2.1GP v1.2
👋 Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end

### March 3, 2025: Wan2.1GP v1.1
👋 Added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)

### March 2, 2025: Wan2.1GP by DeepBeepMeep v1
👋 Brings:
- Support for all Wan models including the Image to Video model
- Reduced memory consumption by 2, with the possibility to generate more than 10s of video at 720p with a RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice looking videos longer than 5s.
- The usual perks: web interface, multiple generations, loras support, sage attention, auto download of models, ...

## Original Wan Releases

### February 25, 2025
👋 We've released the inference code and weights of Wan2.1.

### February 27, 2025
👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
docs/CLI.md
ADDED
@@ -0,0 +1,226 @@
# Command Line Reference

This document covers all available command line options for WanGP.

## Basic Usage

```bash
# Default launch
python wgp.py

# Specific model modes
python wgp.py --i2v        # Image-to-video
python wgp.py --t2v        # Text-to-video (default)
python wgp.py --t2v-14B    # 14B text-to-video model
python wgp.py --t2v-1-3B   # 1.3B text-to-video model
python wgp.py --i2v-14B    # 14B image-to-video model
python wgp.py --i2v-1-3B   # Fun InP 1.3B image-to-video model
python wgp.py --vace-1-3B  # VACE ControlNet 1.3B model
```

## Model and Performance Options

### Model Configuration
```bash
--quantize-transformer BOOL   # Enable/disable transformer quantization (default: True)
--compile                     # Enable PyTorch compilation (requires Triton)
--attention MODE              # Force attention mode: sdpa, flash, sage, sage2
--profile NUMBER              # Performance profile 1-5 (default: 4)
--preload NUMBER              # Preload N MB of diffusion model in VRAM
--fp16                        # Force fp16 instead of bf16 models
--gpu DEVICE                  # Run on specific GPU device (e.g., "cuda:1")
```

### Performance Profiles
- **Profile 1**: Load the entire current model in VRAM and keep all unused models in reserved RAM for fast VRAM transfers
- **Profile 2**: Load model parts as needed, keep all unused models in reserved RAM for fast VRAM transfers
- **Profile 3**: Load the entire current model in VRAM (requires 24GB for the 14B model)
- **Profile 4**: Default and recommended, load model parts as needed, most flexible option
- **Profile 5**: Minimum RAM usage

### Memory Management
```bash
--perc-reserved-mem-max FLOAT  # Max percentage of RAM for reserved memory (< 0.5)
```

## Lora Configuration

```bash
--lora-dir PATH               # Path to Wan t2v loras directory
--lora-dir-i2v PATH           # Path to Wan i2v loras directory
--lora-dir-hunyuan PATH       # Path to Hunyuan t2v loras directory
--lora-dir-hunyuan-i2v PATH   # Path to Hunyuan i2v loras directory
--lora-dir-ltxv PATH          # Path to LTX Video loras directory
--lora-preset PRESET          # Load lora preset file (.lset) on startup
--check-loras                 # Filter incompatible loras (slower startup)
```

## Generation Settings

### Basic Generation
```bash
--seed NUMBER                 # Set default seed value
--frames NUMBER               # Set default number of frames to generate
--steps NUMBER                # Set default number of denoising steps
--advanced                    # Launch with advanced mode enabled
```

### Advanced Generation
```bash
--teacache MULTIPLIER         # TeaCache speed multiplier: 0, 1.5, 1.75, 2.0, 2.25, 2.5
```

## Interface and Server Options

### Server Configuration
```bash
--server-port PORT            # Gradio server port (default: 7860)
--server-name NAME            # Gradio server name (default: localhost)
--listen                      # Make server accessible on network
--share                       # Create shareable HuggingFace URL for remote access
--open-browser                # Open browser automatically when launching
```

### Interface Options
```bash
--lock-config                 # Prevent modifying video engine configuration from interface
--theme THEME_NAME            # UI theme: "default" or "gradio"
```

## File and Directory Options

```bash
--settings PATH               # Path to folder containing default settings for all models
--verbose LEVEL               # Information level 0-2 (default: 1)
```

## Examples

### Basic Usage Examples
```bash
# Launch with specific model and loras
python wgp.py --t2v-14B --lora-preset mystyle.lset

# High-performance setup with compilation
python wgp.py --compile --attention sage2 --profile 3

# Low VRAM setup
python wgp.py --t2v-1-3B --profile 4 --attention sdpa

# Multiple images with custom lora directory
python wgp.py --i2v --multiple-images --lora-dir /path/to/shared/loras
```

### Server Configuration Examples
```bash
# Network accessible server
python wgp.py --listen --server-port 8080

# Shareable server with custom theme
python wgp.py --share --theme gradio --open-browser

# Locked configuration for public use
python wgp.py --lock-config --share
```

### Advanced Performance Examples
```bash
# Maximum performance (requires high-end GPU)
python wgp.py --compile --attention sage2 --profile 3 --preload 2000

# Optimized for RTX 2080Ti
python wgp.py --profile 4 --attention sdpa --teacache 2.0

# Memory-efficient setup
python wgp.py --fp16 --profile 4 --perc-reserved-mem-max 0.3
```

### TeaCache Configuration
```bash
# Different speed multipliers
python wgp.py --teacache 1.5  # 1.5x speed, minimal quality loss
python wgp.py --teacache 2.0  # 2x speed, some quality loss
python wgp.py --teacache 2.5  # 2.5x speed, noticeable quality loss
python wgp.py --teacache 0    # Disable TeaCache
```

## Attention Modes

### SDPA (Default)
```bash
python wgp.py --attention sdpa
```
- Available by default with PyTorch
- Good compatibility with all GPUs
- Moderate performance

### Sage Attention
```bash
python wgp.py --attention sage
```
- Requires Triton installation
- 30% faster than SDPA
- Small quality cost

### Sage2 Attention
```bash
python wgp.py --attention sage2
```
- Requires Triton and SageAttention 2.x
- 40% faster than SDPA
- Best performance option

### Flash Attention
```bash
python wgp.py --attention flash
```
- May require CUDA kernel compilation
- Good performance
- Can be complex to install on Windows

## Troubleshooting Command Lines

### Fallback to Basic Setup
```bash
# If advanced features don't work
python wgp.py --attention sdpa --profile 4 --fp16
```

### Debug Mode
```bash
# Maximum verbosity for troubleshooting
python wgp.py --verbose 2 --check-loras
```

### Memory Issue Debugging
```bash
# Minimal memory usage
python wgp.py --profile 4 --attention sdpa --perc-reserved-mem-max 0.2
```


## Configuration Files

### Settings Files
Load custom settings:
```bash
python wgp.py --settings /path/to/settings/folder
```

### Lora Presets
Create and share lora configurations:
```bash
# Load specific preset
python wgp.py --lora-preset anime_style.lset

# With custom lora directory
python wgp.py --lora-preset mystyle.lset --lora-dir /shared/loras
```

## Environment Variables

While not command line options, these environment variables can affect behavior (see the usage example after this list):
- `CUDA_VISIBLE_DEVICES` - Limit visible GPUs
- `PYTORCH_CUDA_ALLOC_CONF` - CUDA memory allocation settings
- `TRITON_CACHE_DIR` - Triton cache directory (for Sage attention)
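For illustration, here is one way these variables can be combined with a normal launch. This is standard shell syntax; the specific values are examples and not recommendations from the WanGP docs:

```bash
# Run WanGP on the second GPU only, with a custom Triton cache location
CUDA_VISIBLE_DEVICES=1 \
TRITON_CACHE_DIR=/tmp/triton_cache \
python wgp.py --attention sage --profile 4
```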
docs/GETTING_STARTED.md
ADDED
@@ -0,0 +1,194 @@
# Getting Started with WanGP

This guide will help you get started with WanGP video generation quickly and easily.

## Prerequisites

Before starting, ensure you have:
- A compatible GPU (RTX 10XX or newer recommended)
- Python 3.10.9 installed
- At least 6GB of VRAM for basic models
- Internet connection for model downloads

## Quick Setup

### Option 1: One-Click Installation (Recommended)
Use [Pinokio App](https://pinokio.computer/) for the easiest installation experience.

### Option 2: Manual Installation
```bash
git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP
conda create -n wan2gp python=3.10.9
conda activate wan2gp
pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
pip install -r requirements.txt
```

For detailed installation instructions, see [INSTALLATION.md](INSTALLATION.md).

## First Launch

### Basic Launch
```bash
python wgp.py
```
This launches the WanGP generator with default settings. You will be able to pick the model you want to use from a dropdown menu.

### Alternative Modes
```bash
python wgp.py --i2v       # Wan image-to-video mode
python wgp.py --t2v-1-3B  # Smaller, faster Wan model
```

## Understanding the Interface

When you launch WanGP, you'll see a web interface with several sections:

### Main Generation Panel
- **Model Selection**: Dropdown to choose between different models
- **Prompt**: Text description of what you want to generate
- **Generate Button**: Start the video generation process

### Advanced Settings (click checkbox to enable)
- **Generation Settings**: Steps, guidance, seeds
- **Loras**: Additional style customizations
- **Sliding Window**: For longer videos

## Your First Video

Let's generate a simple text-to-video:

1. **Launch WanGP**: `python wgp.py`
2. **Open Browser**: Navigate to `http://localhost:7860`
3. **Enter Prompt**: "A cat walking in a garden"
4. **Click Generate**: Wait for the video to be created
5. **View Result**: The video will appear in the output section

### Recommended First Settings
- **Model**: Wan 2.1 text2video 1.3B (faster, lower VRAM)
- **Frames**: 49 (about 2 seconds)
- **Steps**: 20 (good balance of speed/quality)
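These starting values can also be prefilled from the command line using the flags documented in [CLI.md](CLI.md); the UI settings above remain the authoritative reference, this is just a convenience sketch:

```bash
# Launch the small t2v model with 49 frames and 20 steps as defaults
python wgp.py --t2v-1-3B --frames 49 --steps 20
```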

## Model Selection

### Text-to-Video Models
- **Wan 2.1 T2V 1.3B**: Fastest, lowest VRAM (6GB), good quality
- **Wan 2.1 T2V 14B**: Best quality, requires more VRAM (12GB+)
- **Hunyuan Video**: Excellent quality, slower generation
- **LTX Video**: Good for longer videos

### Image-to-Video Models
- **Wan Fun InP 1.3B**: Fast image animation
- **Wan Fun InP 14B**: Higher quality image animation
- **VACE**: Advanced control over video generation

### Choosing the Right Model
- **Low VRAM (6-8GB)**: Use 1.3B models
- **Medium VRAM (10-12GB)**: Use 14B models or Hunyuan
- **High VRAM (16GB+)**: Any model, longer videos

## Basic Settings Explained

### Generation Settings
- **Frames**: Number of frames (more = longer video)
  - 25 frames ≈ 1 second
  - 49 frames ≈ 2 seconds
  - 73 frames ≈ 3 seconds

- **Steps**: Quality vs Speed tradeoff
  - 15 steps: Fast, lower quality
  - 20 steps: Good balance
  - 30+ steps: High quality, slower

- **Guidance Scale**: How closely to follow the prompt
  - 3-5: More creative interpretation
  - 7-10: Closer to prompt description
  - 12+: Very literal interpretation

### Seeds
- **Random Seed**: Different result each time
- **Fixed Seed**: Reproducible results
- **Use same seed + prompt**: Generate variations

## Common Beginner Issues

### "Out of Memory" Errors
1. Use smaller models (1.3B instead of 14B)
2. Reduce frame count
3. Lower resolution in advanced settings
4. Enable quantization (usually on by default)

### Slow Generation
1. Use 1.3B models for speed
2. Reduce number of steps
3. Install Sage attention (see [INSTALLATION.md](INSTALLATION.md))
4. Enable TeaCache: `python wgp.py --teacache 2.0`

### Poor Quality Results
1. Increase number of steps (25-30)
2. Improve prompt description
3. Use 14B models if you have enough VRAM
4. Enable Skip Layer Guidance in advanced settings

## Writing Good Prompts

### Basic Structure
```
[Subject] [Action] [Setting] [Style/Quality modifiers]
```

### Examples
```
A red sports car driving through a mountain road at sunset, cinematic, high quality

A woman with long hair walking on a beach, waves in the background, realistic, detailed

A cat sitting on a windowsill watching rain, cozy atmosphere, soft lighting
```

### Tips
- Be specific about what you want
- Include style descriptions (cinematic, realistic, etc.)
- Mention lighting and atmosphere
- Describe the setting in detail
- Use quality modifiers (high quality, detailed, etc.)

## Next Steps

Once you're comfortable with basic generation:

1. **Explore Advanced Features**:
   - [Loras Guide](LORAS.md) - Customize styles and characters
   - [VACE ControlNet](VACE.md) - Advanced video control
   - [Command Line Options](CLI.md) - Optimize performance

2. **Improve Performance**:
   - Install better attention mechanisms
   - Optimize memory settings
   - Use compilation for speed

3. **Join the Community**:
   - [Discord Server](https://discord.gg/g7efUW9jGV) - Get help and share videos
   - Share your best results
   - Learn from other users

## Troubleshooting First Steps

### Installation Issues
- Ensure Python 3.10.9 is used
- Check CUDA version compatibility
- See [INSTALLATION.md](INSTALLATION.md) for detailed steps

### Generation Issues
- Check GPU compatibility
- Verify sufficient VRAM
- Try basic settings first
- See [TROUBLESHOOTING.md](TROUBLESHOOTING.md) for specific issues

### Performance Issues
- Use appropriate model for your hardware
- Enable performance optimizations
- Check [CLI.md](CLI.md) for optimization flags

Remember: Start simple and gradually explore more advanced features as you become comfortable with the basics!
docs/INSTALLATION.md
ADDED
@@ -0,0 +1,170 @@
# Installation Guide

This guide covers installation for different GPU generations and operating systems.

## Requirements

- Python 3.10.9
- Conda or Python venv
- Compatible GPU (RTX 10XX or newer recommended)

## Installation for RTX 10XX to RTX 40XX (Stable)

This installation uses PyTorch 2.6.0, which is well-tested and stable.

### Step 1: Download and Setup Environment

```shell
# Clone the repository
git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP

# Create Python 3.10.9 environment using conda
conda create -n wan2gp python=3.10.9
conda activate wan2gp
```

### Step 2: Install PyTorch

```shell
# Install PyTorch 2.6.0 with CUDA 12.4
pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
```

### Step 3: Install Dependencies

```shell
# Install core dependencies
pip install -r requirements.txt
```

### Step 4: Optional Performance Optimizations

#### Sage Attention (30% faster)

```shell
# Windows only: Install Triton
pip install triton-windows

# For both Windows and Linux
pip install sageattention==1.0.6
```

#### Sage 2 Attention (40% faster)

```shell
# Windows
pip install triton-windows
pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp310-cp310-win_amd64.whl

# Linux (manual compilation required)
git clone https://github.com/thu-ml/SageAttention
cd SageAttention
pip install -e .
```

#### Flash Attention

```shell
# May require CUDA kernel compilation on Windows
pip install flash-attn==2.7.2.post1
```

## Installation for RTX 50XX (Beta)

RTX 50XX GPUs require PyTorch 2.7.0 (beta). This version may be less stable.

⚠️ **Important:** Use Python 3.10 for compatibility with pip wheels.

### Step 1: Setup Environment

```shell
# Clone and setup (same as above)
git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP
conda create -n wan2gp python=3.10.9
conda activate wan2gp
```

### Step 2: Install PyTorch Beta

```shell
# Install PyTorch 2.7.0 with CUDA 12.8
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
```

### Step 3: Install Dependencies

```shell
pip install -r requirements.txt
```

### Step 4: Optional Optimizations for RTX 50XX

#### Sage Attention

```shell
# Windows
pip install triton-windows
pip install sageattention==1.0.6

# Linux
pip install sageattention==1.0.6
```

#### Sage 2 Attention

```shell
# Windows
pip install triton-windows
pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl

# Linux (manual compilation)
git clone https://github.com/thu-ml/SageAttention
cd SageAttention
pip install -e .
```

## Attention Modes

WanGP supports several attention implementations:

- **SDPA** (default): Available by default with PyTorch
- **Sage**: 30% speed boost with small quality cost
- **Sage2**: 40% speed boost
- **Flash**: Good performance, may be complex to install on Windows

## Performance Profiles

Choose a profile based on your hardware (the profile is selected at launch, as shown after this list):

- **Profile 3 (LowRAM_HighVRAM)**: Loads the entire model in VRAM, requires 24GB of VRAM for the 8-bit quantized 14B model
- **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but lower VRAM requirement
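As a reminder, the profile is passed with the `--profile` flag described in [CLI.md](CLI.md):

```bash
# Keep the whole model in VRAM (high-end GPUs)
python wgp.py --profile 3

# Default low-VRAM behaviour
python wgp.py --profile 4
```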

## Troubleshooting

### Sage Attention Issues

If Sage attention doesn't work:

1. Check if Triton is properly installed
2. Clear the Triton cache
3. Fall back to SDPA attention:
```bash
python wgp.py --attention sdpa
```

### Memory Issues

- Use lower resolution or shorter videos
- Enable quantization (default)
- Use Profile 4 for lower VRAM usage
- Consider using 1.3B models instead of 14B models

### GPU Compatibility

- RTX 10XX, 20XX: Supported with SDPA attention
- RTX 30XX, 40XX: Full feature support
- RTX 50XX: Beta support with PyTorch 2.7.0

For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md)
docs/LORAS.md
ADDED
@@ -0,0 +1,224 @@
# Loras Guide

Loras (Low-Rank Adaptations) allow you to customize video generation models by adding specific styles, characters, or effects to your videos.

## Directory Structure

Loras are organized in different folders based on the model they're designed for:

### Text-to-Video Models
- `loras/` - General t2v loras
- `loras/1.3B/` - Loras specifically for 1.3B models
- `loras/14B/` - Loras specifically for 14B models

### Image-to-Video Models
- `loras_i2v/` - Image-to-video loras

### Other Models
- `loras_hunyuan/` - Hunyuan Video t2v loras
- `loras_hunyuan_i2v/` - Hunyuan Video i2v loras
- `loras_ltxv/` - LTX Video loras
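If any of these folders are missing from a fresh checkout, they can simply be created before dropping lora files in (a minimal sketch using the folder names listed above):

```bash
# Create the standard lora folders next to wgp.py
mkdir -p loras/1.3B loras/14B loras_i2v loras_hunyuan loras_hunyuan_i2v loras_ltxv
```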

## Custom Lora Directory

You can specify custom lora directories when launching the app:

```bash
# Use a shared lora directory for both t2v and i2v
python wgp.py --lora-dir /path/to/shared/loras --lora-dir-i2v /path/to/shared/loras

# Specify different directories for different models
python wgp.py --lora-dir-hunyuan /path/to/hunyuan/loras --lora-dir-ltxv /path/to/ltx/loras
```

## Using Loras

### Basic Usage

1. Place your lora files in the appropriate directory
2. Launch WanGP
3. In the Advanced Tab, select the "Loras" section
4. Check the loras you want to activate
5. Set multipliers for each lora (default is 1.0)

### Lora Multipliers

Multipliers control the strength of each lora's effect:

#### Simple Multipliers
```
1.2 0.8
```
- First lora: 1.2 strength
- Second lora: 0.8 strength

#### Time-based Multipliers
For dynamic effects over generation steps, use comma-separated values:
```
0.9,0.8,0.7
1.2,1.1,1.0
```
- For 30 steps: steps 0-9 use the first value, 10-19 the second, 20-29 the third
- First lora: 0.9 → 0.8 → 0.7
- Second lora: 1.2 → 1.1 → 1.0

## Lora Presets

Presets are combinations of loras with predefined multipliers and prompts.

### Creating Presets
1. Configure your loras and multipliers
2. Write a prompt with comments (lines starting with #)
3. Save as a preset with the `.lset` extension

### Example Preset
```
# Use the keyword "ohnvx" to trigger the lora
A ohnvx character is driving a car through the city
```

### Using Presets
```bash
# Load preset on startup
python wgp.py --lora-preset mypreset.lset
```

### Managing Presets
- Edit, save, or delete presets directly from the web interface
- Presets include comments with usage instructions
- Share `.lset` files with other users

## CausVid Lora (Video Generation Accelerator)

CausVid is a distilled Wan model that generates videos in 4-12 steps with a 2x speed improvement.

### Setup Instructions
1. Download the CausVid Lora:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors
```
2. Place it in your `loras/` directory
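For command-line downloads, note that the link above is the Hugging Face "blob" page; the direct file is normally served under `/resolve/` instead. This URL rewrite follows the usual Hugging Face layout and is an assumption, not something the WanGP docs specify:

```bash
# Fetch the CausVid lora straight into the t2v loras folder
wget -P loras/ \
  "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
```

The same pattern applies to the AccVid lora links further down, saved into `loras/` or `loras_i2v/` as appropriate.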
101 |
+
|
102 |
+
### Usage
|
103 |
+
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B)
|
104 |
+
2. Enable Advanced Mode
|
105 |
+
3. In Advanced Generation Tab:
|
106 |
+
- Set Guidance Scale = 1
|
107 |
+
- Set Shift Scale = 7
|
108 |
+
4. In Advanced Lora Tab:
|
109 |
+
- Select CausVid Lora
|
110 |
+
- Set multiplier to 0.3
|
111 |
+
5. Set generation steps to 12
|
112 |
+
6. Generate!
|
113 |
+
|
114 |
+
### CausVid Step/Multiplier Relationship
|
115 |
+
- **12 steps**: 0.3 multiplier (recommended)
|
116 |
+
- **8 steps**: 0.5-0.7 multiplier
|
117 |
+
- **4 steps**: 0.8-1.0 multiplier
|
118 |
+
|
119 |
+
*Note: Lower steps = lower quality (especially motion)*
|
120 |
+
|
121 |
+
## Supported Formats
|
122 |
+
|
123 |
+
WanGP supports multiple lora formats:
|
124 |
+
- **Safetensors** (.safetensors)
|
125 |
+
- **Replicate** format
|
126 |
+
- **Standard PyTorch** (.pt, .pth)
|
127 |
+
|
128 |
+
## AccVid Lora (Video Generation Accelerator)
|
129 |
+
|
130 |
+
AccVid is a distilled Wan model that generates videos with a 2x speed improvement since classifier free guidance is no longer needed (that is cfg = 1).
|
131 |
+
|
132 |
+
### Setup Instructions
|
133 |
+
1. Download the CausVid Lora:
|
134 |
+
|
135 |
+
- for t2v models:
|
136 |
+
```
|
137 |
+
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors
|
138 |
+
```
|
139 |
+
|
140 |
+
- for i2v models:
|
141 |
+
```
|
142 |
+
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_I2V_480P_14B_lora_rank32_fp16.safetensors
|
143 |
+
```
|
144 |
+
|
145 |
+
2. Place in your `loras/` directory or `loras_i2v/` directory
|
146 |
+
|
147 |
+
### Usage
|
148 |
+
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) or Wan i2v model
|
149 |
+
2. Enable Advanced Mode
|
150 |
+
3. In Advanced Generation Tab:
|
151 |
+
- Set Guidance Scale = 1
|
152 |
+
- Set Shift Scale = 5
|
153 |
+
4. The number of steps remains unchanged compared to what you would use with the original model, but generation is twice as fast since classifier-free guidance is not needed
|
154 |
+
|
155 |
+
## Performance Tips
|
156 |
+
|
157 |
+
### Fast Loading/Unloading
|
158 |
+
- Loras can be added/removed without restarting the app
|
159 |
+
- Use the "Refresh" button to detect new loras
|
160 |
+
- Enable `--check-loras` to filter incompatible loras (slower startup)
|
161 |
+
|
162 |
+
### Memory Management
|
163 |
+
- Loras are loaded on-demand to save VRAM
|
164 |
+
- Multiple loras can be used simultaneously
|
165 |
+
- Time-based multipliers don't use extra memory
|
166 |
+
|
167 |
+
## Finding Loras
|
168 |
+
|
169 |
+
### Sources
|
170 |
+
- **[Civitai](https://civitai.com/)** - Large community collection
|
171 |
+
- **HuggingFace** - Official and community loras
|
172 |
+
- **Discord Server** - Community recommendations
|
173 |
+
|
174 |
+
### Creating Loras
|
175 |
+
- **Kohya** - Popular training tool
|
176 |
+
- **OneTrainer** - Alternative training solution
|
177 |
+
- **Custom datasets** - Train on your own content
|
178 |
+
|
179 |
+
## Macro System (Advanced)
|
180 |
+
|
181 |
+
Create multiple prompts from templates using macros:
|
182 |
+
|
183 |
+
```
|
184 |
+
! {Subject}="cat","woman","man", {Location}="forest","lake","city", {Possessive}="its","her","his"
|
185 |
+
In the video, a {Subject} is presented. The {Subject} is in a {Location} and looks at {Possessive} watch.
|
186 |
+
```
|
187 |
+
|
188 |
+
This generates:
|
189 |
+
1. "In the video, a cat is presented. The cat is in a forest and looks at its watch."
|
190 |
+
2. "In the video, a woman is presented. The woman is in a lake and looks at her watch."
|
191 |
+
3. "In the video, a man is presented. The man is in a city and looks at his watch."
|
192 |
+
|
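The sketch below reproduces the expansion behaviour described above, purely for illustration (it is not WanGP's macro parser): the i-th value of every variable is substituted into the i-th generated prompt.

```python
def expand_macro(template: str, variables: dict[str, list[str]]) -> list[str]:
    """Expand {Name} placeholders by walking every value list in lockstep."""
    count = len(next(iter(variables.values())))   # number of prompts to generate
    prompts = []
    for i in range(count):
        prompt = template
        for name, values in variables.items():
            prompt = prompt.replace("{" + name + "}", values[i])
        prompts.append(prompt)
    return prompts

variables = {
    "Subject": ["cat", "woman", "man"],
    "Location": ["forest", "lake", "city"],
    "Possessive": ["its", "her", "his"],
}
template = ("In the video, a {Subject} is presented. "
            "The {Subject} is in a {Location} and looks at {Possessive} watch.")
for p in expand_macro(template, variables):
    print(p)
```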
193 |
+
## Troubleshooting
|
194 |
+
|
195 |
+
### Lora Not Working
|
196 |
+
1. Check if lora is compatible with your model size (1.3B vs 14B)
|
197 |
+
2. Verify lora format is supported
|
198 |
+
3. Try different multiplier values
|
199 |
+
4. Check the lora was trained for your model type (t2v vs i2v)
|
200 |
+
|
201 |
+
### Performance Issues
|
202 |
+
1. Reduce number of active loras
|
203 |
+
2. Lower multiplier values
|
204 |
+
3. Use `--check-loras` to filter incompatible files
|
205 |
+
4. Clear lora cache if issues persist
|
206 |
+
|
207 |
+
### Memory Errors
|
208 |
+
1. Use fewer loras simultaneously
|
209 |
+
2. Reduce model size (use 1.3B instead of 14B)
|
210 |
+
3. Lower video resolution or frame count
|
211 |
+
4. Enable quantization if not already active
|
212 |
+
|
213 |
+
## Command Line Options
|
214 |
+
|
215 |
+
```bash
|
216 |
+
# Lora-related command line options
|
217 |
+
--lora-dir path # Path to t2v loras directory
|
218 |
+
--lora-dir-i2v path # Path to i2v loras directory
|
219 |
+
--lora-dir-hunyuan path # Path to Hunyuan t2v loras
|
220 |
+
--lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras
|
221 |
+
--lora-dir-ltxv path # Path to LTX Video loras
|
222 |
+
--lora-preset preset # Load preset on startup
|
223 |
+
--check-loras # Filter incompatible loras
|
224 |
+
```
|
docs/MODELS.md
ADDED
@@ -0,0 +1,268 @@
1 |
+
# Models Overview
|
2 |
+
|
3 |
+
WanGP supports multiple video generation models, each optimized for different use cases and hardware configurations.
|
4 |
+
|
5 |
+
|
6 |
+
## Wan 2.1 Text2Video Models
|
7 |
+
Please note that the term *Text2Video* refers to the underlying Wan architecture; as it has been greatly improved over time, many derived Text2Video models can now also generate videos from images.
|
8 |
+
|
9 |
+
#### Wan 2.1 Text2Video 1.3B
|
10 |
+
- **Size**: 1.3 billion parameters
|
11 |
+
- **VRAM**: 6GB minimum
|
12 |
+
- **Speed**: Fast generation
|
13 |
+
- **Quality**: Good quality for the size
|
14 |
+
- **Best for**: Quick iterations, lower-end hardware
|
15 |
+
- **Command**: `python wgp.py --t2v-1-3B`
|
16 |
+
|
17 |
+
#### Wan 2.1 Text2Video 14B
|
18 |
+
- **Size**: 14 billion parameters
|
19 |
+
- **VRAM**: 12GB+ recommended
|
20 |
+
- **Speed**: Slower but higher quality
|
21 |
+
- **Quality**: Excellent detail and coherence
|
22 |
+
- **Best for**: Final production videos
|
23 |
+
- **Command**: `python wgp.py --t2v-14B`
|
24 |
+
|
25 |
+
#### Wan Vace 1.3B
|
26 |
+
- **Type**: ControlNet for advanced video control
|
27 |
+
- **VRAM**: 6GB minimum
|
28 |
+
- **Features**: Motion transfer, object injection, inpainting
|
29 |
+
- **Best for**: Advanced video manipulation
|
30 |
+
- **Command**: `python wgp.py --vace-1.3B`
|
31 |
+
|
32 |
+
#### Wan Vace 14B
|
33 |
+
- **Type**: Large ControlNet model
|
34 |
+
- **VRAM**: 12GB+ recommended
|
35 |
+
- **Features**: All Vace features with higher quality
|
36 |
+
- **Best for**: Professional video editing workflows
|
37 |
+
|
38 |
+
#### MoviiGen (Experimental)
|
39 |
+
- **Resolution**: Claims 1080p capability
|
40 |
+
- **VRAM**: 20GB+ required
|
41 |
+
- **Speed**: Very slow generation
|
42 |
+
- **Features**: Aims to generate cinema-like video, specialized for 2.1:1 aspect ratios
|
43 |
+
- **Status**: Experimental, feedback welcome
|
44 |
+
|
45 |
+
<BR>
|
46 |
+
|
47 |
+
## Wan 2.1 Image-to-Video Models
|
48 |
+
|
49 |
+
#### Wan 2.1 Image2Video 14B
|
50 |
+
- **Size**: 14 billion parameters
|
51 |
+
- **VRAM**: 12GB+ recommended
|
52 |
+
- **Speed**: Slower but higher quality
|
53 |
+
- **Quality**: Excellent detail and coherence
|
54 |
+
- **Best for**: General image-to-video work; most available Loras are compatible with this model
|
55 |
+
- **Command**: `python wgp.py --i2v-14B`
|
56 |
+
|
57 |
+
#### FLF2V
|
58 |
+
- **Type**: Start/end frame specialist
|
59 |
+
- **Resolution**: Optimized for 720p
|
60 |
+
- **Official**: Wan team supported
|
61 |
+
- **Use case**: Image-to-video with specific endpoints
|
62 |
+
|
63 |
+
|
64 |
+
<BR>
|
65 |
+
|
66 |
+
## Wan 2.1 Specialized Models
|
67 |
+
|
68 |
+
#### FantasySpeaking
|
69 |
+
- **Type**: Talking head animation
|
70 |
+
- **Input**: Voice track + image
|
71 |
+
- **Works on**: People and objects
|
72 |
+
- **Use case**: Lip-sync and voice-driven animation
|
73 |
+
|
74 |
+
#### Phantom
|
75 |
+
- **Type**: Person/object transfer
|
76 |
+
- **Resolution**: Works well at 720p
|
77 |
+
- **Requirements**: 30+ steps for good results
|
78 |
+
- **Best for**: Transferring subjects between videos
|
79 |
+
|
80 |
+
#### Recam Master
|
81 |
+
- **Type**: Viewpoint change
|
82 |
+
- **Requirements**: 81+ frame input videos, 15+ denoising steps
|
83 |
+
- **Use case**: View same scene from different angles
|
84 |
+
|
85 |
+
#### Sky Reels v2
|
86 |
+
- **Type**: Diffusion Forcing model
|
87 |
+
- **Specialty**: "Infinite length" videos
|
88 |
+
- **Features**: High quality continuous generation
|
89 |
+
|
90 |
+
|
91 |
+
<BR>
|
92 |
+
|
93 |
+
## Wan Fun InP Models
|
94 |
+
|
95 |
+
#### Wan Fun InP 1.3B
|
96 |
+
- **Size**: 1.3 billion parameters
|
97 |
+
- **VRAM**: 6GB minimum
|
98 |
+
- **Quality**: Good for the size, accessible to lower hardware
|
99 |
+
- **Best for**: Entry-level image animation
|
100 |
+
- **Command**: `python wgp.py --i2v-1-3B`
|
101 |
+
|
102 |
+
#### Wan Fun InP 14B
|
103 |
+
- **Size**: 14 billion parameters
|
104 |
+
- **VRAM**: 12GB+ recommended
|
105 |
+
- **Quality**: Better end image support
|
106 |
+
- **Limitation**: Existing loras don't work as well
|
107 |
+
|
108 |
+
<BR>
|
109 |
+
|
110 |
+
## Wan Special Loras
|
111 |
+
### Causvid
|
112 |
+
- **Type**: Distilled model (Lora implementation)
|
113 |
+
- **Speed**: 4-12 steps generation, 2x faster
|
114 |
+
- **Compatible**: Works with Wan 14B models
|
115 |
+
- **Setup**: Requires CausVid Lora (see [LORAS.md](LORAS.md))
|
116 |
+
|
117 |
+
|
118 |
+
<BR>
|
119 |
+
|
120 |
+
## Hunyuan Video Models
|
121 |
+
|
122 |
+
#### Hunyuan Video Text2Video
|
123 |
+
- **Quality**: Among the best open source t2v models
|
124 |
+
- **VRAM**: 12GB+ recommended
|
125 |
+
- **Speed**: Slower generation but excellent results
|
126 |
+
- **Features**: Superior text adherence and video quality, up to 10s of video
|
127 |
+
- **Best for**: High-quality text-to-video generation
|
128 |
+
|
129 |
+
#### Hunyuan Video Custom
|
130 |
+
- **Specialty**: Identity preservation
|
131 |
+
- **Use case**: Injecting specific people into videos
|
132 |
+
- **Quality**: Excellent for character consistency
|
133 |
+
- **Best for**: Character-focused video generation
|
134 |
+
|
135 |
+
#### Hunyuan Video Avatar
|
136 |
+
- **Specialty**: Generates up to 15s of high-quality speech- or song-driven video
|
137 |
+
- **Use case**: Injecting specific people into videos
|
138 |
+
- **Quality**: Excellent for character consistency
|
139 |
+
- **Best for**: Character-focused video generation, Video synchronized with voice
|
140 |
+
|
141 |
+
|
142 |
+
<BR>
|
143 |
+
|
144 |
+
## LTX Video Models
|
145 |
+
|
146 |
+
#### LTX Video 13B
|
147 |
+
- **Specialty**: Long video generation
|
148 |
+
- **Resolution**: Fast 720p generation
|
149 |
+
- **VRAM**: Optimized by WanGP (4x reduction in requirements)
|
150 |
+
- **Best for**: Longer duration videos
|
151 |
+
|
152 |
+
#### LTX Video 13B Distilled
|
153 |
+
- **Speed**: Generate in less than one minute
|
154 |
+
- **Quality**: Very high quality despite speed
|
155 |
+
- **Best for**: Rapid prototyping and quick results
|
156 |
+
|
157 |
+
<BR>
|
158 |
+
|
159 |
+
## Model Selection Guide
|
160 |
+
|
161 |
+
### By Hardware (VRAM)
|
162 |
+
|
163 |
+
#### 6-8GB VRAM
|
164 |
+
- Wan 2.1 T2V 1.3B
|
165 |
+
- Wan Fun InP 1.3B
|
166 |
+
- Wan Vace 1.3B
|
167 |
+
|
168 |
+
#### 10-12GB VRAM
|
169 |
+
- Wan 2.1 T2V 14B
|
170 |
+
- Wan Fun InP 14B
|
171 |
+
- Hunyuan Video (with optimizations)
|
172 |
+
- LTX Video 13B
|
173 |
+
|
174 |
+
#### 16GB+ VRAM
|
175 |
+
- All models supported
|
176 |
+
- Longer videos possible
|
177 |
+
- Higher resolutions
|
178 |
+
- Multiple simultaneous Loras
|
179 |
+
|
180 |
+
#### 20GB+ VRAM
|
181 |
+
- MoviiGen (experimental 1080p)
|
182 |
+
- Very long videos
|
183 |
+
- Maximum quality settings
|
184 |
+
|
185 |
+
### By Use Case
|
186 |
+
|
187 |
+
#### Quick Prototyping
|
188 |
+
1. **LTX Video 13B Distilled** - Fastest, high quality
|
189 |
+
2. **Wan 2.1 T2V 1.3B** - Fast, good quality
|
190 |
+
3. **CausVid Lora** - 4-12 steps, very fast
|
191 |
+
|
192 |
+
#### Best Quality
|
193 |
+
1. **Hunyuan Video** - Overall best t2v quality
|
194 |
+
2. **Wan 2.1 T2V 14B** - Excellent Wan quality
|
195 |
+
3. **Wan Vace 14B** - Best for controlled generation
|
196 |
+
|
197 |
+
#### Advanced Control
|
198 |
+
1. **Wan Vace 14B/1.3B** - Motion transfer, object injection
|
199 |
+
2. **Phantom** - Person/object transfer
|
200 |
+
3. **FantasySpeaking** - Voice-driven animation
|
201 |
+
|
202 |
+
#### Long Videos
|
203 |
+
1. **LTX Video 13B** - Specialized for length
|
204 |
+
2. **Sky Reels v2** - Infinite length videos
|
205 |
+
3. **Wan Vace + Sliding Windows** - Up to 1 minute
|
206 |
+
|
207 |
+
#### Lower Hardware
|
208 |
+
1. **Wan Fun InP 1.3B** - Image-to-video
|
209 |
+
2. **Wan 2.1 T2V 1.3B** - Text-to-video
|
210 |
+
3. **Wan Vace 1.3B** - Advanced control
|
211 |
+
|
212 |
+
<BR>
|
213 |
+
|
214 |
+
## Performance Comparison
|
215 |
+
|
216 |
+
### Speed (Relative)
|
217 |
+
1. **CausVid Lora** (4-12 steps) - Fastest
|
218 |
+
2. **LTX Video Distilled** - Very fast
|
219 |
+
3. **Wan 1.3B models** - Fast
|
220 |
+
4. **Wan 14B models** - Medium
|
221 |
+
5. **Hunyuan Video** - Slower
|
222 |
+
6. **MoviiGen** - Slowest
|
223 |
+
|
224 |
+
### Quality (Subjective)
|
225 |
+
1. **Hunyuan Video** - Highest overall
|
226 |
+
2. **Wan 14B models** - Excellent
|
227 |
+
3. **LTX Video models** - Very good
|
228 |
+
4. **Wan 1.3B models** - Good
|
229 |
+
5. **CausVid** - Good (varies with steps)
|
230 |
+
|
231 |
+
### VRAM Efficiency
|
232 |
+
1. **Wan 1.3B models** - Most efficient
|
233 |
+
2. **LTX Video** (with WanGP optimizations)
|
234 |
+
3. **Wan 14B models**
|
235 |
+
4. **Hunyuan Video**
|
236 |
+
5. **MoviiGen** - Least efficient
|
237 |
+
|
238 |
+
<BR>
|
239 |
+
|
240 |
+
## Model Switching
|
241 |
+
|
242 |
+
WanGP allows switching between models without restarting:
|
243 |
+
|
244 |
+
1. Use the dropdown menu in the web interface
|
245 |
+
2. Models are loaded on-demand
|
246 |
+
3. Previous model is unloaded to save VRAM
|
247 |
+
4. Settings are preserved when possible
|
248 |
+
|
249 |
+
<BR>
|
250 |
+
|
251 |
+
## Tips for Model Selection
|
252 |
+
|
253 |
+
### First Time Users
|
254 |
+
Start with **Wan 2.1 T2V 1.3B** to learn the interface and test your hardware.
|
255 |
+
|
256 |
+
### Production Work
|
257 |
+
Use **Hunyuan Video** or **Wan 14B** models for final output quality.
|
258 |
+
|
259 |
+
### Experimentation
|
260 |
+
**CausVid Lora** or **LTX Distilled** for rapid iteration and testing.
|
261 |
+
|
262 |
+
### Specialized Tasks
|
263 |
+
- **VACE** for advanced control
|
264 |
+
- **FantasySpeaking** for talking heads
|
265 |
+
- **LTX Video** for long sequences
|
266 |
+
|
267 |
+
### Hardware Optimization
|
268 |
+
Always start with the largest model your VRAM can handle, then optimize settings for speed vs quality based on your needs.
|
docs/TROUBLESHOOTING.md
ADDED
@@ -0,0 +1,338 @@
1 |
+
# Troubleshooting Guide
|
2 |
+
|
3 |
+
This guide covers common issues and their solutions when using WanGP.
|
4 |
+
|
5 |
+
## Installation Issues
|
6 |
+
|
7 |
+
### PyTorch Installation Problems
|
8 |
+
|
9 |
+
#### CUDA Version Mismatch
|
10 |
+
**Problem**: PyTorch can't detect GPU or CUDA errors
|
11 |
+
**Solution**:
|
12 |
+
```bash
|
13 |
+
# Check your CUDA version
|
14 |
+
nvidia-smi
|
15 |
+
|
16 |
+
# Install matching PyTorch version
|
17 |
+
# For CUDA 12.4 (RTX 10XX-40XX)
|
18 |
+
pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
|
19 |
+
|
20 |
+
# For CUDA 12.8 (RTX 50XX)
|
21 |
+
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
|
22 |
+
```
|
23 |
+
|
24 |
+
#### Python Version Issues
|
25 |
+
**Problem**: Package compatibility errors
|
26 |
+
**Solution**: Ensure you're using Python 3.10.9
|
27 |
+
```bash
|
28 |
+
python --version # Should show 3.10.9
|
29 |
+
conda create -n wan2gp python=3.10.9
|
30 |
+
```
|
31 |
+
|
32 |
+
### Dependency Installation Failures
|
33 |
+
|
34 |
+
#### Triton Installation (Windows)
|
35 |
+
**Problem**: `pip install triton-windows` fails
|
36 |
+
**Solution**:
|
37 |
+
1. Update pip: `pip install --upgrade pip`
|
38 |
+
2. Try pre-compiled wheel
|
39 |
+
3. Fallback to SDPA attention: `python wgp.py --attention sdpa`
|
40 |
+
|
41 |
+
#### SageAttention Compilation Issues
|
42 |
+
**Problem**: SageAttention installation fails
|
43 |
+
**Solution**:
|
44 |
+
1. Install Visual Studio Build Tools (Windows)
|
45 |
+
2. Use pre-compiled wheels when available
|
46 |
+
3. Fallback to basic attention modes
|
47 |
+
|
48 |
+
## Memory Issues
|
49 |
+
|
50 |
+
### CUDA Out of Memory
|
51 |
+
|
52 |
+
#### During Model Loading
|
53 |
+
**Problem**: "CUDA out of memory" when loading model
|
54 |
+
**Solutions**:
|
55 |
+
```bash
|
56 |
+
# Use smaller model
|
57 |
+
python wgp.py --t2v-1-3B
|
58 |
+
|
59 |
+
# Enable quantization (usually default)
|
60 |
+
python wgp.py --quantize-transformer True
|
61 |
+
|
62 |
+
# Use memory-efficient profile
|
63 |
+
python wgp.py --profile 4
|
64 |
+
|
65 |
+
# Reduce preloaded model size
|
66 |
+
python wgp.py --preload 0
|
67 |
+
```
|
68 |
+
|
69 |
+
#### During Video Generation
|
70 |
+
**Problem**: Memory error during generation
|
71 |
+
**Solutions**:
|
72 |
+
1. Reduce frame count (shorter videos)
|
73 |
+
2. Lower resolution in advanced settings
|
74 |
+
3. Use lower batch size
|
75 |
+
4. Clear GPU cache between generations
|
76 |
+
|
77 |
+
### System RAM Issues
|
78 |
+
|
79 |
+
#### High RAM Usage
|
80 |
+
**Problem**: System runs out of RAM
|
81 |
+
**Solutions**:
|
82 |
+
```bash
|
83 |
+
# Limit reserved memory
|
84 |
+
python wgp.py --perc-reserved-mem-max 0.3
|
85 |
+
|
86 |
+
# Use minimal RAM profile
|
87 |
+
python wgp.py --profile 5
|
88 |
+
|
89 |
+
# Enable swap file (OS level)
|
90 |
+
```
|
91 |
+
|
92 |
+
## Performance Issues
|
93 |
+
|
94 |
+
### Slow Generation Speed
|
95 |
+
|
96 |
+
#### General Optimization
|
97 |
+
```bash
|
98 |
+
# Enable compilation (requires Triton)
|
99 |
+
python wgp.py --compile
|
100 |
+
|
101 |
+
# Use faster attention
|
102 |
+
python wgp.py --attention sage2
|
103 |
+
|
104 |
+
# Enable TeaCache
|
105 |
+
python wgp.py --teacache 2.0
|
106 |
+
|
107 |
+
# Use high-performance profile
|
108 |
+
python wgp.py --profile 3
|
109 |
+
```
|
110 |
+
|
111 |
+
#### GPU-Specific Optimizations
|
112 |
+
|
113 |
+
**RTX 10XX/20XX Series**:
|
114 |
+
```bash
|
115 |
+
python wgp.py --attention sdpa --profile 4 --teacache 1.5
|
116 |
+
```
|
117 |
+
|
118 |
+
**RTX 30XX/40XX Series**:
|
119 |
+
```bash
|
120 |
+
python wgp.py --compile --attention sage --profile 3 --teacache 2.0
|
121 |
+
```
|
122 |
+
|
123 |
+
**RTX 50XX Series**:
|
124 |
+
```bash
|
125 |
+
python wgp.py --attention sage --profile 4 --fp16
|
126 |
+
```
|
127 |
+
|
128 |
+
### Attention Mechanism Issues
|
129 |
+
|
130 |
+
#### Sage Attention Not Working
|
131 |
+
**Problem**: Sage attention fails to compile or work
|
132 |
+
**Diagnostic Steps**:
|
133 |
+
1. Check Triton installation:
|
134 |
+
```python
|
135 |
+
import triton
|
136 |
+
print(triton.__version__)
|
137 |
+
```
|
138 |
+
2. Clear Triton cache:
|
139 |
+
```bash
|
140 |
+
# Windows
|
141 |
+
rmdir /s %USERPROFILE%\.triton
|
142 |
+
# Linux
|
143 |
+
rm -rf ~/.triton
|
144 |
+
```
|
145 |
+
3. Fallback solution:
|
146 |
+
```bash
|
147 |
+
python wgp.py --attention sdpa
|
148 |
+
```
|
149 |
+
|
150 |
+
#### Flash Attention Issues
|
151 |
+
**Problem**: Flash attention compilation fails
|
152 |
+
**Solution**:
|
153 |
+
- Windows: Often requires manual CUDA kernel compilation
|
154 |
+
- Linux: Usually works with `pip install flash-attn`
|
155 |
+
- Fallback: Use Sage or SDPA attention
|
156 |
+
|
157 |
+
## Model-Specific Issues
|
158 |
+
|
159 |
+
### Lora Problems
|
160 |
+
|
161 |
+
#### Loras Not Loading
|
162 |
+
**Problem**: Loras don't appear in the interface
|
163 |
+
**Solutions**:
|
164 |
+
1. Check file format (should be .safetensors, .pt, or .pth)
|
165 |
+
2. Verify correct directory:
|
166 |
+
```
|
167 |
+
loras/ # For t2v models
|
168 |
+
loras_i2v/ # For i2v models
|
169 |
+
loras_hunyuan/ # For Hunyuan models
|
170 |
+
```
|
171 |
+
3. Click "Refresh" button in interface
|
172 |
+
4. Use `--check-loras` to filter incompatible files
|
173 |
+
|
174 |
+
#### Lora Compatibility Issues
|
175 |
+
**Problem**: Lora causes errors or poor results
|
176 |
+
**Solutions**:
|
177 |
+
1. Check model size compatibility (1.3B vs 14B)
|
178 |
+
2. Verify lora was trained for your model type
|
179 |
+
3. Try different multiplier values
|
180 |
+
4. Use `--check-loras` flag to auto-filter
|
181 |
+
|
182 |
+
### VACE-Specific Issues
|
183 |
+
|
184 |
+
#### Poor VACE Results
|
185 |
+
**Problem**: VACE generates poor quality or unexpected results
|
186 |
+
**Solutions**:
|
187 |
+
1. Enable Skip Layer Guidance
|
188 |
+
2. Use detailed prompts describing all elements
|
189 |
+
3. Ensure proper mask creation with Matanyone
|
190 |
+
4. Check reference image quality
|
191 |
+
5. Use at least 15 steps, preferably 30+
|
192 |
+
|
193 |
+
#### Matanyone Tool Issues
|
194 |
+
**Problem**: Mask creation difficulties
|
195 |
+
**Solutions**:
|
196 |
+
1. Use negative point prompts to refine selection
|
197 |
+
2. Create multiple sub-masks and combine them
|
198 |
+
3. Try different background removal options
|
199 |
+
4. Ensure sufficient contrast in source video
|
200 |
+
|
201 |
+
## Network and Server Issues
|
202 |
+
|
203 |
+
### Gradio Interface Problems
|
204 |
+
|
205 |
+
#### Port Already in Use
|
206 |
+
**Problem**: "Port 7860 is already in use"
|
207 |
+
**Solution**:
|
208 |
+
```bash
|
209 |
+
# Use different port
|
210 |
+
python wgp.py --server-port 7861
|
211 |
+
|
212 |
+
# Or kill existing process
|
213 |
+
# Windows
|
214 |
+
netstat -ano | findstr :7860
|
215 |
+
taskkill /PID <PID> /F
|
216 |
+
|
217 |
+
# Linux
|
218 |
+
lsof -i :7860
|
219 |
+
kill <PID>
|
220 |
+
```
|
221 |
+
|
222 |
+
#### Interface Not Loading
|
223 |
+
**Problem**: Browser shows "connection refused"
|
224 |
+
**Solutions**:
|
225 |
+
1. Check if server started successfully
|
226 |
+
2. Try `http://127.0.0.1:7860` instead of `localhost:7860`
|
227 |
+
3. Disable firewall temporarily
|
228 |
+
4. Use `--listen` flag for network access
|
229 |
+
|
230 |
+
### Remote Access Issues
|
231 |
+
|
232 |
+
#### Sharing Not Working
|
233 |
+
**Problem**: `--share` flag doesn't create public URL
|
234 |
+
**Solutions**:
|
235 |
+
1. Check internet connection
|
236 |
+
2. Try different network
|
237 |
+
3. Use `--listen` with port forwarding
|
238 |
+
4. Check firewall settings
|
239 |
+
|
240 |
+
## Quality Issues
|
241 |
+
|
242 |
+
### Poor Video Quality
|
243 |
+
|
244 |
+
#### General Quality Improvements
|
245 |
+
1. Increase number of steps (25-30+)
|
246 |
+
2. Use larger models (14B instead of 1.3B)
|
247 |
+
3. Enable Skip Layer Guidance
|
248 |
+
4. Improve prompt descriptions
|
249 |
+
5. Use higher resolution settings
|
250 |
+
|
251 |
+
#### Specific Quality Issues
|
252 |
+
|
253 |
+
**Blurry Videos**:
|
254 |
+
- Increase steps
|
255 |
+
- Check source image quality (i2v)
|
256 |
+
- Reduce TeaCache multiplier
|
257 |
+
- Use higher guidance scale
|
258 |
+
|
259 |
+
**Inconsistent Motion**:
|
260 |
+
- Use longer overlap in sliding windows
|
261 |
+
- Reduce window size
|
262 |
+
- Improve prompt consistency
|
263 |
+
- Check control video quality (VACE)
|
264 |
+
|
265 |
+
**Color Issues**:
|
266 |
+
- Check model compatibility
|
267 |
+
- Adjust guidance scale
|
268 |
+
- Verify input image color space
|
269 |
+
- Try different VAE settings
|
270 |
+
|
271 |
+
## Advanced Debugging
|
272 |
+
|
273 |
+
### Enable Verbose Output
|
274 |
+
```bash
|
275 |
+
# Maximum verbosity
|
276 |
+
python wgp.py --verbose 2
|
277 |
+
|
278 |
+
# Check lora compatibility
|
279 |
+
python wgp.py --check-loras --verbose 2
|
280 |
+
```
|
281 |
+
|
282 |
+
### Memory Debugging
|
283 |
+
```bash
|
284 |
+
# Monitor GPU memory
|
285 |
+
nvidia-smi -l 1
|
286 |
+
|
287 |
+
# Reduce memory usage
|
288 |
+
python wgp.py --profile 4 --perc-reserved-mem-max 0.2
|
289 |
+
```
|
290 |
+
|
291 |
+
### Performance Profiling
|
292 |
+
```bash
|
293 |
+
# Test different configurations
|
294 |
+
python wgp.py --attention sdpa --profile 4 # Baseline
|
295 |
+
python wgp.py --attention sage --profile 3 # Performance
|
296 |
+
python wgp.py --compile --teacache 2.0 # Maximum speed
|
297 |
+
```
|
298 |
+
|
299 |
+
## Getting Help
|
300 |
+
|
301 |
+
### Before Asking for Help
|
302 |
+
1. Check this troubleshooting guide
|
303 |
+
2. Read the relevant documentation:
|
304 |
+
- [Installation Guide](INSTALLATION.md)
|
305 |
+
- [Getting Started](GETTING_STARTED.md)
|
306 |
+
- [Command Line Reference](CLI.md)
|
307 |
+
3. Try basic fallback configuration:
|
308 |
+
```bash
|
309 |
+
python wgp.py --attention sdpa --profile 4
|
310 |
+
```
|
311 |
+
|
312 |
+
### Community Support
|
313 |
+
- **Discord Server**: https://discord.gg/g7efUW9jGV
|
314 |
+
- Provide relevant information:
|
315 |
+
- GPU model and VRAM amount
|
316 |
+
- Python and PyTorch versions
|
317 |
+
- Complete error messages
|
318 |
+
- Command used to launch WanGP
|
319 |
+
- Operating system
|
320 |
+
|
321 |
+
### Reporting Bugs
|
322 |
+
When reporting issues:
|
323 |
+
1. Include system specifications
|
324 |
+
2. Provide complete error logs
|
325 |
+
3. List the exact steps to reproduce
|
326 |
+
4. Mention any modifications to default settings
|
327 |
+
5. Include command line arguments used
|
328 |
+
|
329 |
+
## Emergency Fallback
|
330 |
+
|
331 |
+
If nothing works, try this minimal configuration:
|
332 |
+
```bash
|
333 |
+
# Absolute minimum setup
|
334 |
+
python wgp.py --t2v-1-3B --attention sdpa --profile 4 --teacache 0 --fp16
|
335 |
+
|
336 |
+
# If that fails, check basic PyTorch installation
|
337 |
+
python -c "import torch; print(torch.cuda.is_available())"
|
338 |
+
```
|
docs/VACE.md
ADDED
@@ -0,0 +1,190 @@
1 |
+
# VACE ControlNet Guide
|
2 |
+
|
3 |
+
VACE is a powerful ControlNet that enables Video-to-Video and Reference-to-Video generation. It allows you to inject your own images into output videos, animate characters, perform inpainting/outpainting, and continue videos.
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
|
7 |
+
VACE is probably one of the most powerful Wan models available. With it, you can:
|
8 |
+
- Inject people or objects into scenes
|
9 |
+
- Animate characters
|
10 |
+
- Perform video inpainting and outpainting
|
11 |
+
- Continue existing videos
|
12 |
+
- Transfer motion from one video to another
|
13 |
+
- Change the style of scenes while preserving depth
|
14 |
+
|
15 |
+
## Getting Started
|
16 |
+
|
17 |
+
### Model Selection
|
18 |
+
1. Select either "Vace 1.3B" or "Vace 13B" from the dropdown menu
|
19 |
+
2. Note: VACE works best with videos up to 7 seconds with the Riflex option enabled
|
20 |
+
|
21 |
+
### Input Types
|
22 |
+
|
23 |
+
VACE accepts three types of visual hints (which can be combined):
|
24 |
+
|
25 |
+
#### 1. Control Video
|
26 |
+
- Transfer motion or depth to a new video
|
27 |
+
- Use only the first n frames and extrapolate the rest
|
28 |
+
- Perform inpainting by marking mask areas with grey (value 127)
|
29 |
+
- Grey areas will be filled based on text prompt and reference images
|
30 |
+
|
31 |
+
#### 2. Reference Images
|
32 |
+
- Use as background/setting for the video
|
33 |
+
- Inject people or objects of your choice
|
34 |
+
- Select multiple reference images
|
35 |
+
- **Tip**: Replace complex backgrounds with white for better object integration
|
36 |
+
- Always describe injected objects/people explicitly in your text prompt
|
37 |
+
|
38 |
+
#### 3. Video Mask
|
39 |
+
- Stronger control over which parts to keep (black) or replace (white)
|
40 |
+
- Perfect for inpainting/outpainting
|
41 |
+
- Example: White mask except at beginning/end (black) keeps first/last frames while generating middle content
|
42 |
+
|
43 |
+
## Common Use Cases
|
44 |
+
|
45 |
+
### Motion Transfer
|
46 |
+
**Goal**: Animate a character of your choice using motion from another video
|
47 |
+
**Setup**:
|
48 |
+
- Reference Images: Your character
|
49 |
+
- Control Video: Person performing desired motion
|
50 |
+
- Text Prompt: Describe your character and the action
|
51 |
+
|
52 |
+
### Object/Person Injection
|
53 |
+
**Goal**: Insert people or objects into a scene
|
54 |
+
**Setup**:
|
55 |
+
- Reference Images: The people/objects to inject
|
56 |
+
- Text Prompt: Describe the scene and explicitly mention the injected elements
|
57 |
+
|
58 |
+
### Character Animation
|
59 |
+
**Goal**: Animate a character based on text description
|
60 |
+
**Setup**:
|
61 |
+
- Control Video: Video of person moving
|
62 |
+
- Text Prompt: Detailed description of your character
|
63 |
+
|
64 |
+
### Style Transfer with Depth
|
65 |
+
**Goal**: Change scene style while preserving spatial relationships
|
66 |
+
**Setup**:
|
67 |
+
- Control Video: Original video (for depth information)
|
68 |
+
- Text Prompt: New style description
|
69 |
+
|
70 |
+
## Integrated Matanyone Tool
|
71 |
+
|
72 |
+
WanGP includes the Matanyone tool, specifically tuned for VACE workflows. This helps create control videos and masks simultaneously.
|
73 |
+
|
74 |
+
### Creating Face Replacement Masks
|
75 |
+
1. Load your video in Matanyone
|
76 |
+
2. Click on the face in the first frame
|
77 |
+
3. Create a mask for the face
|
78 |
+
4. Generate both control video and mask video with "Generate Video Matting"
|
79 |
+
5. Export to VACE with "Export to current Video Input and Video Mask"
|
80 |
+
6. Load replacement face image in Reference Images field
|
81 |
+
|
82 |
+
### Advanced Matanyone Tips
|
83 |
+
- **Negative Point Prompts**: Remove parts from current selection
|
84 |
+
- **Sub Masks**: Create multiple independent masks, then combine them
|
85 |
+
- **Background Masks**: Select everything except the character (useful for background replacement)
|
86 |
+
- Enable/disable sub masks in Matanyone settings
|
87 |
+
|
88 |
+
## Recommended Settings
|
89 |
+
|
90 |
+
### Quality Settings
|
91 |
+
- **Skip Layer Guidance**: Turn ON with default configuration for better results
|
92 |
+
- **Long Prompts**: Use detailed descriptions, especially for background elements not in reference images
|
93 |
+
- **Steps**: Use at least 15 steps for good quality, 30+ for best results
|
94 |
+
|
95 |
+
### Sliding Window Settings
|
96 |
+
For very long videos, configure sliding windows properly:
|
97 |
+
|
98 |
+
- **Window Size**: Set appropriate duration for your content
|
99 |
+
- **Overlap Frames**: Long enough for motion continuity, short enough to avoid blur propagation
|
100 |
+
- **Discard Last Frames**: Remove at least 4 frames from each window (VACE 1.3B tends to blur final frames)
|
101 |
+
|
102 |
+
### Background Removal
|
103 |
+
VACE includes automatic background removal options:
|
104 |
+
- Use for reference images containing people/objects
|
105 |
+
- **Don't use** for landscape/setting reference images (first reference image)
|
106 |
+
- Multiple background removal types available
|
107 |
+
|
108 |
+
## Window Sliding for Long Videos
|
109 |
+
|
110 |
+
Generate videos up to 1 minute by merging multiple windows:
|
111 |
+
|
112 |
+
### How It Works
|
113 |
+
- Each window uses corresponding time segment from control video
|
114 |
+
- Example: 0-4s control video → first window, 4-8s → second window, etc.
|
115 |
+
- Automatic overlap management ensures smooth transitions
|
116 |
+
|
117 |
+
### Settings
|
118 |
+
- **Window Size**: Duration of each generation window
|
119 |
+
- **Overlap Frames**: Frames shared between windows for continuity
|
120 |
+
- **Discard Last Frames**: Remove poor-quality ending frames
|
121 |
+
- **Add Overlapped Noise**: Reduce quality degradation over time
|
122 |
+
|
123 |
+
### Formula
|
124 |
+
```
|
125 |
+
Generated Frames = [Windows - 1] × [Window Size - Overlap - Discard] + Window Size
|
126 |
+
```
|
127 |
+
|
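To sanity-check window settings, the formula above can be expressed as a small helper; the example numbers are illustrative only, not recommended values:

```python
def generated_frames(windows: int, window_size: int, overlap: int, discard: int) -> int:
    """Total frames produced by merging sliding windows, per the formula above."""
    return (windows - 1) * (window_size - overlap - discard) + window_size

# Illustrative only: 3 windows of 81 frames, 8 overlapped and 4 discarded per window
print(generated_frames(windows=3, window_size=81, overlap=8, discard=4))  # 219
```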
128 |
+
### Multi-Line Prompts (Experimental)
|
129 |
+
- Each line of the prompt is used for a different window
|
130 |
+
- If there are more windows than prompt lines, the last line repeats
|
131 |
+
- Separate lines with carriage return
|
132 |
+
|
133 |
+
## Advanced Features
|
134 |
+
|
135 |
+
### Extend Video
|
136 |
+
Click "Extend the Video Sample, Please!" during generation to add more windows dynamically.
|
137 |
+
|
138 |
+
### Noise Addition
|
139 |
+
Add noise to overlapped frames to hide accumulated errors and quality degradation.
|
140 |
+
|
141 |
+
### Frame Truncation
|
142 |
+
Automatically remove lower-quality final frames from each window (recommended: 4 frames for VACE 1.3B).
|
143 |
+
|
144 |
+
## External Resources
|
145 |
+
|
146 |
+
### Official VACE Resources
|
147 |
+
- **GitHub**: https://github.com/ali-vilab/VACE/tree/main/vace/gradios
|
148 |
+
- **User Guide**: https://github.com/ali-vilab/VACE/blob/main/UserGuide.md
|
149 |
+
- **Preprocessors**: Gradio tools for preparing materials
|
150 |
+
|
151 |
+
### Recommended External Tools
|
152 |
+
- **Annotation Tools**: For creating precise masks
|
153 |
+
- **Video Editors**: For preparing control videos
|
154 |
+
- **Background Removal**: For cleaning reference images
|
155 |
+
|
156 |
+
## Troubleshooting
|
157 |
+
|
158 |
+
### Poor Quality Results
|
159 |
+
1. Use longer, more detailed prompts
|
160 |
+
2. Enable Skip Layer Guidance
|
161 |
+
3. Increase number of steps (30+)
|
162 |
+
4. Check reference image quality
|
163 |
+
5. Ensure proper mask creation
|
164 |
+
|
165 |
+
### Inconsistent Windows
|
166 |
+
1. Increase overlap frames
|
167 |
+
2. Use consistent prompting across windows
|
168 |
+
3. Add noise to overlapped frames
|
169 |
+
4. Reduce discard frames if losing too much content
|
170 |
+
|
171 |
+
### Memory Issues
|
172 |
+
1. Use VACE 1.3B instead of 13B
|
173 |
+
2. Reduce video length or resolution
|
174 |
+
3. Decrease window size
|
175 |
+
4. Enable quantization
|
176 |
+
|
177 |
+
### Blurry Results
|
178 |
+
1. Reduce overlap frames
|
179 |
+
2. Increase discard last frames
|
180 |
+
3. Use higher resolution reference images
|
181 |
+
4. Check control video quality
|
182 |
+
|
183 |
+
## Tips for Best Results
|
184 |
+
|
185 |
+
1. **Detailed Prompts**: Describe everything in the scene, especially elements not in reference images
|
186 |
+
2. **Quality Reference Images**: Use high-resolution, well-lit reference images
|
187 |
+
3. **Proper Masking**: Take time to create precise masks with Matanyone
|
188 |
+
4. **Iterative Approach**: Start with short videos, then extend successful results
|
189 |
+
5. **Background Preparation**: Remove complex backgrounds from object/person reference images
|
190 |
+
6. **Consistent Lighting**: Match lighting between reference images and intended scene
|
fantasytalking/infer.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
# Copyright Alibaba Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Processor
|
4 |
+
|
5 |
+
from .model import FantasyTalkingAudioConditionModel
|
6 |
+
from .utils import get_audio_features
|
7 |
+
import gc, torch
|
8 |
+
|
9 |
+
def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
|
10 |
+
fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
|
11 |
+
from mmgp import offload
|
12 |
+
from accelerate import init_empty_weights
|
13 |
+
from fantasytalking.model import AudioProjModel
|
14 |
+
|
15 |
+
torch.set_grad_enabled(False)
|
16 |
+
|
17 |
+
with init_empty_weights():
|
18 |
+
proj_model = AudioProjModel( 768, 2048)
|
19 |
+
offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
|
20 |
+
proj_model.to("cpu").eval().requires_grad_(False)
|
21 |
+
|
22 |
+
wav2vec_model_dir = "ckpts/wav2vec"
|
23 |
+
wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
|
24 |
+
wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False)
|
25 |
+
wav2vec.to(device)
|
26 |
+
proj_model.to(device)
|
27 |
+
audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )
|
28 |
+
|
29 |
+
audio_proj_fea = proj_model(audio_wav2vec_fea)
|
30 |
+
pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
|
31 |
+
audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
|
32 |
+
wav2vec, proj_model= None, None
|
33 |
+
gc.collect()
|
34 |
+
torch.cuda.empty_cache()
|
35 |
+
|
36 |
+
return audio_proj_split, audio_context_lens
|
fantasytalking/model.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from wan.modules.attention import pay_attention
|
5 |
+
|
6 |
+
|
7 |
+
class AudioProjModel(nn.Module):
|
8 |
+
def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
|
9 |
+
super().__init__()
|
10 |
+
self.cross_attention_dim = cross_attention_dim
|
11 |
+
self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
|
12 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
13 |
+
|
14 |
+
def forward(self, audio_embeds):
|
15 |
+
context_tokens = self.proj(audio_embeds)
|
16 |
+
context_tokens = self.norm(context_tokens)
|
17 |
+
return context_tokens # [B,L,C]
|
18 |
+
|
19 |
+
class WanCrossAttentionProcessor(nn.Module):
|
20 |
+
def __init__(self, context_dim, hidden_dim):
|
21 |
+
super().__init__()
|
22 |
+
|
23 |
+
self.context_dim = context_dim
|
24 |
+
self.hidden_dim = hidden_dim
|
25 |
+
|
26 |
+
self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
27 |
+
self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
28 |
+
|
29 |
+
nn.init.zeros_(self.k_proj.weight)
|
30 |
+
nn.init.zeros_(self.v_proj.weight)
|
31 |
+
|
32 |
+
def __call__(
|
33 |
+
self,
|
34 |
+
q: torch.Tensor,
|
35 |
+
audio_proj: torch.Tensor,
|
36 |
+
latents_num_frames: int = 21,
|
37 |
+
audio_context_lens = None
|
38 |
+
) -> torch.Tensor:
|
39 |
+
"""
|
40 |
+
audio_proj: [B, 21, L3, C]
|
41 |
+
audio_context_lens: [B*21].
|
42 |
+
"""
|
43 |
+
b, l, n, d = q.shape
|
44 |
+
|
45 |
+
if len(audio_proj.shape) == 4:
|
46 |
+
audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
|
47 |
+
ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
48 |
+
ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
49 |
+
qkv_list = [audio_q, ip_key, ip_value]
|
50 |
+
del q, audio_q, ip_key, ip_value
|
51 |
+
audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
|
52 |
+
audio_x = audio_x.view(b, l, n, d)
|
53 |
+
audio_x = audio_x.flatten(2)
|
54 |
+
elif len(audio_proj.shape) == 3:
|
55 |
+
ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
|
56 |
+
ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
|
57 |
+
qkv_list = [q, ip_key, ip_value]
|
58 |
+
del q, ip_key, ip_value
|
59 |
+
audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
|
60 |
+
audio_x = audio_x.flatten(2)
|
61 |
+
return audio_x
|
62 |
+
|
63 |
+
|
64 |
+
class FantasyTalkingAudioConditionModel(nn.Module):
|
65 |
+
def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.audio_in_dim = audio_in_dim
|
69 |
+
self.audio_proj_dim = audio_proj_dim
|
70 |
+
|
71 |
+
def split_audio_sequence(self, audio_proj_length, num_frames=81):
|
72 |
+
"""
|
73 |
+
Map the audio feature sequence to corresponding latent frame slices.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
audio_proj_length (int): The total length of the audio feature sequence
|
77 |
+
(e.g., 173 in audio_proj[1, 173, 768]).
|
78 |
+
num_frames (int): The number of video frames in the training data (default: 81).
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
|
82 |
+
(within the audio feature sequence) corresponding to a latent frame.
|
83 |
+
"""
|
84 |
+
# Average number of tokens per original video frame
|
85 |
+
tokens_per_frame = audio_proj_length / num_frames
|
86 |
+
|
87 |
+
# Each latent frame covers 4 video frames, and we want the center
|
88 |
+
tokens_per_latent_frame = tokens_per_frame * 4
|
89 |
+
half_tokens = int(tokens_per_latent_frame / 2)
|
90 |
+
|
91 |
+
pos_indices = []
|
92 |
+
for i in range(int((num_frames - 1) / 4) + 1):
|
93 |
+
if i == 0:
|
94 |
+
pos_indices.append(0)
|
95 |
+
else:
|
96 |
+
start_token = tokens_per_frame * ((i - 1) * 4 + 1)
|
97 |
+
end_token = tokens_per_frame * (i * 4 + 1)
|
98 |
+
center_token = int((start_token + end_token) / 2) - 1
|
99 |
+
pos_indices.append(center_token)
|
100 |
+
|
101 |
+
# Build index ranges centered around each position
|
102 |
+
pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
|
103 |
+
|
104 |
+
# Adjust the first range to avoid negative start index
|
105 |
+
pos_idx_ranges[0] = [
|
106 |
+
-(half_tokens * 2 - pos_idx_ranges[1][0]),
|
107 |
+
pos_idx_ranges[1][0],
|
108 |
+
]
|
109 |
+
|
110 |
+
return pos_idx_ranges
|
111 |
+
|
112 |
+
def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
|
113 |
+
"""
|
114 |
+
Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
|
115 |
+
if the range exceeds the input boundaries.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
|
119 |
+
pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
|
120 |
+
expand_length (int): Number of tokens to expand on both sides of each subsequence.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
|
124 |
+
Each element is a padded subsequence.
|
125 |
+
k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
|
126 |
+
Useful for ignoring padding tokens in attention masks.
|
127 |
+
"""
|
128 |
+
pos_idx_ranges = [
|
129 |
+
[idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
|
130 |
+
]
|
131 |
+
sub_sequences = []
|
132 |
+
seq_len = input_tensor.size(1) # 173
|
133 |
+
max_valid_idx = seq_len - 1 # 172
|
134 |
+
k_lens_list = []
|
135 |
+
for start, end in pos_idx_ranges:
|
136 |
+
# Calculate the fill amount
|
137 |
+
pad_front = max(-start, 0)
|
138 |
+
pad_back = max(end - max_valid_idx, 0)
|
139 |
+
|
140 |
+
# Calculate the start and end indices of the valid part
|
141 |
+
valid_start = max(start, 0)
|
142 |
+
valid_end = min(end, max_valid_idx)
|
143 |
+
|
144 |
+
# Extract the valid part
|
145 |
+
if valid_start <= valid_end:
|
146 |
+
valid_part = input_tensor[:, valid_start : valid_end + 1, :]
|
147 |
+
else:
|
148 |
+
valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
|
149 |
+
|
150 |
+
# In the sequence dimension (the 1st dimension) perform padding
|
151 |
+
padded_subseq = F.pad(
|
152 |
+
valid_part,
|
153 |
+
(0, 0, 0, pad_back + pad_front, 0, 0),
|
154 |
+
mode="constant",
|
155 |
+
value=0,
|
156 |
+
)
|
157 |
+
k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
|
158 |
+
|
159 |
+
sub_sequences.append(padded_subseq)
|
160 |
+
return torch.stack(sub_sequences, dim=1), torch.tensor(
|
161 |
+
k_lens_list, dtype=torch.long
|
162 |
+
)
|
fantasytalking/utils.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
# Copyright Alibaba Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
import imageio
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
def resize_image_by_longest_edge(image_path, target_size):
|
12 |
+
image = Image.open(image_path).convert("RGB")
|
13 |
+
width, height = image.size
|
14 |
+
scale = target_size / max(width, height)
|
15 |
+
new_size = (int(width * scale), int(height * scale))
|
16 |
+
return image.resize(new_size, Image.LANCZOS)
|
17 |
+
|
18 |
+
|
19 |
+
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
20 |
+
writer = imageio.get_writer(
|
21 |
+
save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
|
22 |
+
)
|
23 |
+
for frame in tqdm(frames, desc="Saving video"):
|
24 |
+
frame = np.array(frame)
|
25 |
+
writer.append_data(frame)
|
26 |
+
writer.close()
|
27 |
+
|
28 |
+
|
29 |
+
def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
|
30 |
+
sr = 16000
|
31 |
+
audio_input, sample_rate = librosa.load(audio_path, sr=sr) # sample rate is 16 kHz
|
32 |
+
|
33 |
+
start_time = 0
|
34 |
+
# end_time = (0 + (num_frames - 1) * 1) / fps
|
35 |
+
end_time = num_frames / fps
|
36 |
+
|
37 |
+
start_sample = int(start_time * sr)
|
38 |
+
end_sample = int(end_time * sr)
|
39 |
+
|
40 |
+
try:
|
41 |
+
audio_segment = audio_input[start_sample:end_sample]
|
42 |
+
except Exception: # fall back to the full clip if slicing fails
|
43 |
+
audio_segment = audio_input
|
44 |
+
|
45 |
+
input_values = audio_processor(
|
46 |
+
audio_segment, sampling_rate=sample_rate, return_tensors="pt"
|
47 |
+
).input_values.to("cuda")
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
fea = wav2vec(input_values).last_hidden_state
|
51 |
+
|
52 |
+
return fea
|
hyvideo/__init__.py
ADDED
File without changes
|
hyvideo/config.py
ADDED
@@ -0,0 +1,534 @@
1 |
+
import argparse
|
2 |
+
from .constants import *
|
3 |
+
import re
|
4 |
+
from .modules.models import HUNYUAN_VIDEO_CONFIG
|
5 |
+
|
6 |
+
|
7 |
+
def parse_args(namespace=None):
|
8 |
+
parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
|
9 |
+
|
10 |
+
parser = add_network_args(parser)
|
11 |
+
parser = add_extra_models_args(parser)
|
12 |
+
parser = add_denoise_schedule_args(parser)
|
13 |
+
parser = add_inference_args(parser)
|
14 |
+
parser = add_parallel_args(parser)
|
15 |
+
|
16 |
+
args = parser.parse_args(namespace=namespace)
|
17 |
+
args = sanity_check_args(args)
|
18 |
+
|
19 |
+
return args
|
20 |
+
|
21 |
+
|
22 |
+
def add_network_args(parser: argparse.ArgumentParser):
|
23 |
+
group = parser.add_argument_group(title="HunyuanVideo network args")
|
24 |
+
|
25 |
+
|
26 |
+
group.add_argument(
|
27 |
+
"--quantize-transformer",
|
28 |
+
action="store_true",
|
29 |
+
help="On the fly 'transformer' quantization"
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
group.add_argument(
|
34 |
+
"--lora-dir-i2v",
|
35 |
+
type=str,
|
36 |
+
default="loras_i2v",
|
37 |
+
help="Path to a directory that contains Loras for i2v"
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
group.add_argument(
|
42 |
+
"--lora-dir",
|
43 |
+
type=str,
|
44 |
+
default="",
|
45 |
+
help="Path to a directory that contains Loras"
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
group.add_argument(
|
50 |
+
"--lora-preset",
|
51 |
+
type=str,
|
52 |
+
default="",
|
53 |
+
help="Lora preset to preload"
|
54 |
+
)
|
55 |
+
|
56 |
+
# group.add_argument(
|
57 |
+
# "--lora-preset-i2v",
|
58 |
+
# type=str,
|
59 |
+
# default="",
|
60 |
+
# help="Lora preset to preload for i2v"
|
61 |
+
# )
|
62 |
+
|
63 |
+
group.add_argument(
|
64 |
+
"--profile",
|
65 |
+
type=str,
|
66 |
+
default=-1,
|
67 |
+
help="Profile No"
|
68 |
+
)
|
69 |
+
|
70 |
+
group.add_argument(
|
71 |
+
"--verbose",
|
72 |
+
type=str,
|
73 |
+
default=1,
|
74 |
+
help="Verbose level"
|
75 |
+
)
|
76 |
+
|
77 |
+
group.add_argument(
|
78 |
+
"--server-port",
|
79 |
+
type=str,
|
80 |
+
default=0,
|
81 |
+
help="Server port"
|
82 |
+
)
|
83 |
+
|
84 |
+
group.add_argument(
|
85 |
+
"--server-name",
|
86 |
+
type=str,
|
87 |
+
default="",
|
88 |
+
help="Server name"
|
89 |
+
)
|
90 |
+
|
91 |
+
group.add_argument(
|
92 |
+
"--open-browser",
|
93 |
+
action="store_true",
|
94 |
+
help="open browser"
|
95 |
+
)
|
96 |
+
|
97 |
+
group.add_argument(
|
98 |
+
"--t2v",
|
99 |
+
action="store_true",
|
100 |
+
help="text to video mode"
|
101 |
+
)
|
102 |
+
|
103 |
+
group.add_argument(
|
104 |
+
"--i2v",
|
105 |
+
action="store_true",
|
106 |
+
help="image to video mode"
|
107 |
+
)
|
108 |
+
|
109 |
+
group.add_argument(
|
110 |
+
"--compile",
|
111 |
+
action="store_true",
|
112 |
+
help="Enable pytorch compilation"
|
113 |
+
)
|
114 |
+
|
115 |
+
group.add_argument(
|
116 |
+
"--fast",
|
117 |
+
action="store_true",
|
118 |
+
help="use Fast HunyuanVideo model"
|
119 |
+
)
|
120 |
+
|
121 |
+
group.add_argument(
|
122 |
+
"--fastest",
|
123 |
+
action="store_true",
|
124 |
+
help="activate the best config"
|
125 |
+
)
|
126 |
+
|
127 |
+
group.add_argument(
|
128 |
+
"--attention",
|
129 |
+
type=str,
|
130 |
+
default="",
|
131 |
+
help="attention mode"
|
132 |
+
)
|
133 |
+
|
134 |
+
group.add_argument(
|
135 |
+
"--vae-config",
|
136 |
+
type=str,
|
137 |
+
default="",
|
138 |
+
help="vae config mode"
|
139 |
+
)
|
140 |
+
|
141 |
+
parser.add_argument(
|
142 |
+
"--share",
|
143 |
+
action="store_true",
|
144 |
+
help="Create a shared URL to access webserver remotely"
|
145 |
+
)
|
146 |
+
|
147 |
+
parser.add_argument(
|
148 |
+
"--lock-config",
|
149 |
+
action="store_true",
|
150 |
+
help="Prevent modifying the configuration from the web interface"
|
151 |
+
)
|
152 |
+
|
153 |
+
parser.add_argument(
|
154 |
+
"--preload",
|
155 |
+
type=str,
|
156 |
+
default="0",
|
157 |
+
help="Megabytes of the diffusion model to preload in VRAM"
|
158 |
+
)
|
159 |
+
|
160 |
+
parser.add_argument(
|
161 |
+
"--multiple-images",
|
162 |
+
action="store_true",
|
163 |
+
help="Allow inputting multiple images with image to video"
|
164 |
+
)
|
165 |
+
|
166 |
+
|
167 |
+
# Main model
|
168 |
+
group.add_argument(
|
169 |
+
"--model",
|
170 |
+
type=str,
|
171 |
+
choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
|
172 |
+
default="HYVideo-T/2-cfgdistill",
|
173 |
+
)
|
174 |
+
group.add_argument(
|
175 |
+
"--latent-channels",
|
176 |
+
type=str,
|
177 |
+
default=16,
|
178 |
+
help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
|
179 |
+
"it still needs to match the latent channels of the VAE model.",
|
180 |
+
)
|
181 |
+
group.add_argument(
|
182 |
+
"--precision",
|
183 |
+
type=str,
|
184 |
+
default="bf16",
|
185 |
+
choices=PRECISIONS,
|
186 |
+
help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
|
187 |
+
)
|
188 |
+
|
189 |
+
# RoPE
|
190 |
+
group.add_argument(
|
191 |
+
"--rope-theta", type=int, default=256, help="Theta used in RoPE."
|
192 |
+
)
|
193 |
+
return parser
|
194 |
+
|
195 |
+
|
196 |
+
def add_extra_models_args(parser: argparse.ArgumentParser):
|
197 |
+
group = parser.add_argument_group(
|
198 |
+
title="Extra models args, including vae, text encoders and tokenizers)"
|
199 |
+
)
|
200 |
+
|
201 |
+
# - VAE
|
202 |
+
group.add_argument(
|
203 |
+
"--vae",
|
204 |
+
type=str,
|
205 |
+
default="884-16c-hy",
|
206 |
+
choices=list(VAE_PATH),
|
207 |
+
help="Name of the VAE model.",
|
208 |
+
)
|
209 |
+
group.add_argument(
|
210 |
+
"--vae-precision",
|
211 |
+
type=str,
|
212 |
+
default="fp16",
|
213 |
+
choices=PRECISIONS,
|
214 |
+
help="Precision mode for the VAE model.",
|
215 |
+
)
|
216 |
+
group.add_argument(
|
217 |
+
"--vae-tiling",
|
218 |
+
action="store_true",
|
219 |
+
help="Enable tiling for the VAE model to save GPU memory.",
|
220 |
+
)
|
221 |
+
group.set_defaults(vae_tiling=True)
|
222 |
+
|
223 |
+
group.add_argument(
|
224 |
+
"--text-encoder",
|
225 |
+
type=str,
|
226 |
+
default="llm",
|
227 |
+
choices=list(TEXT_ENCODER_PATH),
|
228 |
+
help="Name of the text encoder model.",
|
229 |
+
)
|
230 |
+
group.add_argument(
|
231 |
+
"--text-encoder-precision",
|
232 |
+
type=str,
|
233 |
+
default="fp16",
|
234 |
+
choices=PRECISIONS,
|
235 |
+
help="Precision mode for the text encoder model.",
|
236 |
+
)
|
237 |
+
group.add_argument(
|
238 |
+
"--text-states-dim",
|
239 |
+
type=int,
|
240 |
+
default=4096,
|
241 |
+
help="Dimension of the text encoder hidden states.",
|
242 |
+
)
|
243 |
+
group.add_argument(
|
244 |
+
"--text-len", type=int, default=256, help="Maximum length of the text input."
|
245 |
+
)
|
246 |
+
group.add_argument(
|
247 |
+
"--tokenizer",
|
248 |
+
type=str,
|
249 |
+
default="llm",
|
250 |
+
choices=list(TOKENIZER_PATH),
|
251 |
+
help="Name of the tokenizer model.",
|
252 |
+
)
|
253 |
+
group.add_argument(
|
254 |
+
"--prompt-template",
|
255 |
+
type=str,
|
256 |
+
default="dit-llm-encode",
|
257 |
+
choices=PROMPT_TEMPLATE,
|
258 |
+
help="Image prompt template for the decoder-only text encoder model.",
|
259 |
+
)
|
260 |
+
group.add_argument(
|
261 |
+
"--prompt-template-video",
|
262 |
+
type=str,
|
263 |
+
default="dit-llm-encode-video",
|
264 |
+
choices=PROMPT_TEMPLATE,
|
265 |
+
help="Video prompt template for the decoder-only text encoder model.",
|
266 |
+
)
|
267 |
+
group.add_argument(
|
268 |
+
"--hidden-state-skip-layer",
|
269 |
+
type=int,
|
270 |
+
default=2,
|
271 |
+
help="Skip layer for hidden states.",
|
272 |
+
)
|
273 |
+
group.add_argument(
|
274 |
+
"--apply-final-norm",
|
275 |
+
action="store_true",
|
276 |
+
help="Apply final normalization to the used text encoder hidden states.",
|
277 |
+
)
|
278 |
+
|
279 |
+
# - CLIP
|
280 |
+
group.add_argument(
|
281 |
+
"--text-encoder-2",
|
282 |
+
type=str,
|
283 |
+
default="clipL",
|
284 |
+
choices=list(TEXT_ENCODER_PATH),
|
285 |
+
help="Name of the second text encoder model.",
|
286 |
+
)
|
287 |
+
group.add_argument(
|
288 |
+
"--text-encoder-precision-2",
|
289 |
+
type=str,
|
290 |
+
default="fp16",
|
291 |
+
choices=PRECISIONS,
|
292 |
+
help="Precision mode for the second text encoder model.",
|
293 |
+
)
|
294 |
+
group.add_argument(
|
295 |
+
"--text-states-dim-2",
|
296 |
+
type=int,
|
297 |
+
default=768,
|
298 |
+
help="Dimension of the second text encoder hidden states.",
|
299 |
+
)
|
300 |
+
group.add_argument(
|
301 |
+
"--tokenizer-2",
|
302 |
+
type=str,
|
303 |
+
default="clipL",
|
304 |
+
choices=list(TOKENIZER_PATH),
|
305 |
+
help="Name of the second tokenizer model.",
|
306 |
+
)
|
307 |
+
group.add_argument(
|
308 |
+
"--text-len-2",
|
309 |
+
type=int,
|
310 |
+
default=77,
|
311 |
+
help="Maximum length of the second text input.",
|
312 |
+
)
|
313 |
+
|
314 |
+
return parser
|
315 |
+
|
316 |
+
|
317 |
+
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
|
318 |
+
group = parser.add_argument_group(title="Denoise schedule args")
|
319 |
+
|
320 |
+
group.add_argument(
|
321 |
+
"--denoise-type",
|
322 |
+
type=str,
|
323 |
+
default="flow",
|
324 |
+
help="Denoise type for noised inputs.",
|
325 |
+
)
|
326 |
+
|
327 |
+
# Flow Matching
|
328 |
+
group.add_argument(
|
329 |
+
"--flow-shift",
|
330 |
+
type=float,
|
331 |
+
default=7.0,
|
332 |
+
help="Shift factor for flow matching schedulers.",
|
333 |
+
)
|
334 |
+
group.add_argument(
|
335 |
+
"--flow-reverse",
|
336 |
+
action="store_true",
|
337 |
+
help="If reverse, learning/sampling from t=1 -> t=0.",
|
338 |
+
)
|
339 |
+
group.add_argument(
|
340 |
+
"--flow-solver",
|
341 |
+
type=str,
|
342 |
+
default="euler",
|
343 |
+
help="Solver for flow matching.",
|
344 |
+
)
|
345 |
+
group.add_argument(
|
346 |
+
"--use-linear-quadratic-schedule",
|
347 |
+
action="store_true",
|
348 |
+
help="Use linear quadratic schedule for flow matching."
|
349 |
+
"Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
|
350 |
+
)
|
351 |
+
group.add_argument(
|
352 |
+
"--linear-schedule-end",
|
353 |
+
type=int,
|
354 |
+
default=25,
|
355 |
+
help="End step for linear quadratic schedule for flow matching.",
|
356 |
+
)
|
357 |
+
|
358 |
+
return parser
|
359 |
+
|
360 |
+
|
361 |
+
def add_inference_args(parser: argparse.ArgumentParser):
|
362 |
+
group = parser.add_argument_group(title="Inference args")
|
363 |
+
|
364 |
+
# ======================== Model loads ========================
|
365 |
+
group.add_argument(
|
366 |
+
"--model-base",
|
367 |
+
type=str,
|
368 |
+
default="ckpts",
|
369 |
+
help="Root path of all the models, including t2v models and extra models.",
|
370 |
+
)
|
371 |
+
group.add_argument(
|
372 |
+
"--dit-weight",
|
373 |
+
type=str,
|
374 |
+
default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
|
375 |
+
help="Path to the HunyuanVideo model. If None, search the model in the args.model_root."
|
376 |
+
"1. If it is a file, load the model directly."
|
377 |
+
"2. If it is a directory, search the model in the directory. Support two types of models: "
|
378 |
+
"1) named `pytorch_model_*.pt`"
|
379 |
+
"2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
|
380 |
+
)
|
381 |
+
group.add_argument(
|
382 |
+
"--model-resolution",
|
383 |
+
type=str,
|
384 |
+
default="540p",
|
385 |
+
choices=["540p", "720p"],
|
386 |
+
help="Root path of all the models, including t2v models and extra models.",
|
387 |
+
)
|
388 |
+
group.add_argument(
|
389 |
+
"--load-key",
|
390 |
+
type=str,
|
391 |
+
default="module",
|
392 |
+
help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
|
393 |
+
)
|
394 |
+
group.add_argument(
|
395 |
+
"--use-cpu-offload",
|
396 |
+
action="store_true",
|
397 |
+
help="Use CPU offload for the model load.",
|
398 |
+
)
|
399 |
+
|
400 |
+
# ======================== Inference general setting ========================
|
401 |
+
group.add_argument(
|
402 |
+
"--batch-size",
|
403 |
+
type=int,
|
404 |
+
default=1,
|
405 |
+
help="Batch size for inference and evaluation.",
|
406 |
+
)
|
407 |
+
group.add_argument(
|
408 |
+
"--infer-steps",
|
409 |
+
type=int,
|
410 |
+
default=50,
|
411 |
+
help="Number of denoising steps for inference.",
|
412 |
+
)
|
413 |
+
group.add_argument(
|
414 |
+
"--disable-autocast",
|
415 |
+
action="store_true",
|
416 |
+
help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
|
417 |
+
)
|
418 |
+
group.add_argument(
|
419 |
+
"--save-path",
|
420 |
+
type=str,
|
421 |
+
default="./results",
|
422 |
+
help="Path to save the generated samples.",
|
423 |
+
)
|
424 |
+
group.add_argument(
|
425 |
+
"--save-path-suffix",
|
426 |
+
type=str,
|
427 |
+
default="",
|
428 |
+
help="Suffix for the directory of saved samples.",
|
429 |
+
)
|
430 |
+
group.add_argument(
|
431 |
+
"--name-suffix",
|
432 |
+
type=str,
|
433 |
+
default="",
|
434 |
+
help="Suffix for the names of saved samples.",
|
435 |
+
)
|
436 |
+
group.add_argument(
|
437 |
+
"--num-videos",
|
438 |
+
type=int,
|
439 |
+
default=1,
|
440 |
+
help="Number of videos to generate for each prompt.",
|
441 |
+
)
|
442 |
+
# ---sample size---
|
443 |
+
group.add_argument(
|
444 |
+
"--video-size",
|
445 |
+
type=int,
|
446 |
+
nargs="+",
|
447 |
+
default=(720, 1280),
|
448 |
+
help="Video size for training. If a single value is provided, it will be used for both height "
|
449 |
+
"and width. If two values are provided, they will be used for height and width "
|
450 |
+
"respectively.",
|
451 |
+
)
|
452 |
+
group.add_argument(
|
453 |
+
"--video-length",
|
454 |
+
type=int,
|
455 |
+
default=129,
|
456 |
+
help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
|
457 |
+
)
|
458 |
+
# --- prompt ---
|
459 |
+
group.add_argument(
|
460 |
+
"--prompt",
|
461 |
+
type=str,
|
462 |
+
default=None,
|
463 |
+
help="Prompt for sampling during evaluation.",
|
464 |
+
)
|
465 |
+
group.add_argument(
|
466 |
+
"--seed-type",
|
467 |
+
type=str,
|
468 |
+
default="auto",
|
469 |
+
choices=["file", "random", "fixed", "auto"],
|
470 |
+
help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
|
471 |
+
"random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
|
472 |
+
"seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
|
473 |
+
"fixed `seed` value.",
|
474 |
+
)
|
475 |
+
group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
|
476 |
+
|
477 |
+
# Classifier-Free Guidance
|
478 |
+
group.add_argument(
|
479 |
+
"--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
|
480 |
+
)
|
481 |
+
group.add_argument(
|
482 |
+
"--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
|
483 |
+
)
|
484 |
+
group.add_argument(
|
485 |
+
"--embedded-cfg-scale",
|
486 |
+
type=float,
|
487 |
+
default=6.0,
|
488 |
+
help="Embeded classifier free guidance scale.",
|
489 |
+
)
|
490 |
+
|
491 |
+
group.add_argument(
|
492 |
+
"--reproduce",
|
493 |
+
action="store_true",
|
494 |
+
help="Enable reproducibility by setting random seeds and deterministic algorithms.",
|
495 |
+
)
|
496 |
+
|
497 |
+
return parser
|
498 |
+
|
499 |
+
|
500 |
+
def add_parallel_args(parser: argparse.ArgumentParser):
|
501 |
+
group = parser.add_argument_group(title="Parallel args")
|
502 |
+
|
503 |
+
# ======================== Model loads ========================
|
504 |
+
group.add_argument(
|
505 |
+
"--ulysses-degree",
|
506 |
+
type=int,
|
507 |
+
default=1,
|
508 |
+
help="Ulysses degree.",
|
509 |
+
)
|
510 |
+
group.add_argument(
|
511 |
+
"--ring-degree",
|
512 |
+
type=int,
|
513 |
+
default=1,
|
514 |
+
help="Ulysses degree.",
|
515 |
+
)
|
516 |
+
|
517 |
+
return parser
|
518 |
+
|
519 |
+
|
520 |
+
def sanity_check_args(args):
|
521 |
+
# VAE channels
|
522 |
+
vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
|
523 |
+
if not re.match(vae_pattern, args.vae):
|
524 |
+
raise ValueError(
|
525 |
+
f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
|
526 |
+
)
|
527 |
+
vae_channels = int(args.vae.split("-")[1][:-1])
|
528 |
+
if args.latent_channels is None:
|
529 |
+
args.latent_channels = vae_channels
|
530 |
+
if vae_channels != args.latent_channels:
|
531 |
+
raise ValueError(
|
532 |
+
f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
|
533 |
+
)
|
534 |
+
return args
|
hyvideo/constants.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"C_SCALE",
|
6 |
+
"PROMPT_TEMPLATE",
|
7 |
+
"MODEL_BASE",
|
8 |
+
"PRECISIONS",
|
9 |
+
"NORMALIZATION_TYPE",
|
10 |
+
"ACTIVATION_TYPE",
|
11 |
+
"VAE_PATH",
|
12 |
+
"TEXT_ENCODER_PATH",
|
13 |
+
"TOKENIZER_PATH",
|
14 |
+
"TEXT_PROJECTION",
|
15 |
+
"DATA_TYPE",
|
16 |
+
"NEGATIVE_PROMPT",
|
17 |
+
"NEGATIVE_PROMPT_I2V",
|
18 |
+
"FLOW_PATH_TYPE",
|
19 |
+
"FLOW_PREDICT_TYPE",
|
20 |
+
"FLOW_LOSS_WEIGHT",
|
21 |
+
"FLOW_SNR_TYPE",
|
22 |
+
"FLOW_SOLVER",
|
23 |
+
]
|
24 |
+
|
25 |
+
PRECISION_TO_TYPE = {
|
26 |
+
'fp32': torch.float32,
|
27 |
+
'fp16': torch.float16,
|
28 |
+
'bf16': torch.bfloat16,
|
29 |
+
}
|
30 |
+
|
31 |
+
# =================== Constant Values =====================
|
32 |
+
# Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid
|
33 |
+
# overflow error when tensorboard logging values.
|
34 |
+
C_SCALE = 1_000_000_000_000_000
|
35 |
+
|
36 |
+
# When using decoder-only models, we must provide a prompt template to instruct the text encoder
|
37 |
+
# on how to generate the text.
|
38 |
+
# --------------------------------------------------------------------
|
39 |
+
PROMPT_TEMPLATE_ENCODE = (
|
40 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
41 |
+
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
42 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
43 |
+
)
|
44 |
+
PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
45 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
46 |
+
"1. The main content and theme of the video."
|
47 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
48 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
49 |
+
"4. background environment, light, style and atmosphere."
|
50 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
51 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
52 |
+
)
|
53 |
+
|
54 |
+
PROMPT_TEMPLATE_ENCODE_I2V = (
|
55 |
+
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
|
56 |
+
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
57 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
58 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
59 |
+
)
|
60 |
+
|
61 |
+
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
62 |
+
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
63 |
+
"1. The main content and theme of the video."
|
64 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
65 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
66 |
+
"4. background environment, light, style and atmosphere."
|
67 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
68 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
69 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
70 |
+
)
|
71 |
+
|
72 |
+
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
73 |
+
NEGATIVE_PROMPT_I2V = "deformation, a poor composition and deformed video, bad teeth, bad eyes, bad limbs"
|
74 |
+
|
75 |
+
PROMPT_TEMPLATE = {
|
76 |
+
"dit-llm-encode": {
|
77 |
+
"template": PROMPT_TEMPLATE_ENCODE,
|
78 |
+
"crop_start": 36,
|
79 |
+
},
|
80 |
+
"dit-llm-encode-video": {
|
81 |
+
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
82 |
+
"crop_start": 95,
|
83 |
+
},
|
84 |
+
"dit-llm-encode-i2v": {
|
85 |
+
"template": PROMPT_TEMPLATE_ENCODE_I2V,
|
86 |
+
"crop_start": 36,
|
87 |
+
"image_emb_start": 5,
|
88 |
+
"image_emb_end": 581,
|
89 |
+
"image_emb_len": 576,
|
90 |
+
"double_return_token_id": 271
|
91 |
+
},
|
92 |
+
"dit-llm-encode-video-i2v": {
|
93 |
+
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
|
94 |
+
"crop_start": 103,
|
95 |
+
"image_emb_start": 5,
|
96 |
+
"image_emb_end": 581,
|
97 |
+
"image_emb_len": 576,
|
98 |
+
"double_return_token_id": 271
|
99 |
+
},
|
100 |
+
}
|
101 |
+
|
102 |
+
# ======================= Model ======================
|
103 |
+
PRECISIONS = {"fp32", "fp16", "bf16"}
|
104 |
+
NORMALIZATION_TYPE = {"layer", "rms"}
|
105 |
+
ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
|
106 |
+
|
107 |
+
# =================== Model Path =====================
|
108 |
+
MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts")
|
109 |
+
|
110 |
+
# =================== Data =======================
|
111 |
+
DATA_TYPE = {"image", "video", "image_video"}
|
112 |
+
|
113 |
+
# 3D VAE
|
114 |
+
VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
|
115 |
+
|
116 |
+
# Text Encoder
|
117 |
+
TEXT_ENCODER_PATH = {
|
118 |
+
"clipL": f"{MODEL_BASE}/clip_vit_large_patch14",
|
119 |
+
"llm": f"{MODEL_BASE}/llava-llama-3-8b",
|
120 |
+
"llm-i2v": f"{MODEL_BASE}/llava-llama-3-8b",
|
121 |
+
}
|
122 |
+
|
123 |
+
# Tokenizer
|
124 |
+
TOKENIZER_PATH = {
|
125 |
+
"clipL": f"{MODEL_BASE}/clip_vit_large_patch14",
|
126 |
+
"llm": f"{MODEL_BASE}/llava-llama-3-8b",
|
127 |
+
"llm-i2v": f"{MODEL_BASE}/llava-llama-3-8b",
|
128 |
+
}
|
129 |
+
|
130 |
+
TEXT_PROJECTION = {
|
131 |
+
"linear", # Default, an nn.Linear() layer
|
132 |
+
"single_refiner", # Single TokenRefiner. Refer to LI-DiT
|
133 |
+
}
|
134 |
+
|
135 |
+
# Flow Matching path type
|
136 |
+
FLOW_PATH_TYPE = {
|
137 |
+
"linear", # Linear trajectory between noise and data
|
138 |
+
"gvp", # Generalized variance-preserving SDE
|
139 |
+
"vp", # Variance-preserving SDE
|
140 |
+
}
|
141 |
+
|
142 |
+
# Flow Matching predict type
|
143 |
+
FLOW_PREDICT_TYPE = {
|
144 |
+
"velocity", # Predict velocity
|
145 |
+
"score", # Predict score
|
146 |
+
"noise", # Predict noise
|
147 |
+
}
|
148 |
+
|
149 |
+
# Flow Matching loss weight
|
150 |
+
FLOW_LOSS_WEIGHT = {
|
151 |
+
"velocity", # Weight loss by velocity
|
152 |
+
"likelihood", # Weight loss by likelihood
|
153 |
+
}
|
154 |
+
|
155 |
+
# Flow Matching SNR type
|
156 |
+
FLOW_SNR_TYPE = {
|
157 |
+
"lognorm", # Log-normal SNR
|
158 |
+
"uniform", # Uniform SNR
|
159 |
+
}
|
160 |
+
|
161 |
+
# Flow Matching solvers
|
162 |
+
FLOW_SOLVER = {
|
163 |
+
"euler", # Euler solver
|
164 |
+
}
|
hyvideo/data_kits/audio_dataset.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
import random
|
7 |
+
import librosa
|
8 |
+
import traceback
|
9 |
+
import torchvision
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
from PIL import Image
|
13 |
+
from einops import rearrange
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
from decord import VideoReader, cpu
|
16 |
+
from transformers import CLIPImageProcessor
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
from torchvision.transforms import ToPILImage
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
def get_audio_feature(feature_extractor, audio_path):
|
23 |
+
audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
|
24 |
+
assert sampling_rate == 16000
|
25 |
+
|
26 |
+
audio_features = []
|
27 |
+
window = 750*640
|
28 |
+
for i in range(0, len(audio_input), window):
|
29 |
+
audio_feature = feature_extractor(audio_input[i:i+window],
|
30 |
+
sampling_rate=sampling_rate,
|
31 |
+
return_tensors="pt",
|
32 |
+
).input_features
|
33 |
+
audio_features.append(audio_feature)
|
34 |
+
|
35 |
+
audio_features = torch.cat(audio_features, dim=-1)
|
36 |
+
return audio_features, len(audio_input) // 640
|
37 |
+
|
38 |
+
|
39 |
+
class VideoAudioTextLoaderVal(Dataset):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
image_size: int,
|
43 |
+
meta_file: str,
|
44 |
+
**kwargs,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
self.meta_file = meta_file
|
48 |
+
self.image_size = image_size
|
49 |
+
self.text_encoder = kwargs.get("text_encoder", None) # llava_text_encoder
|
50 |
+
self.text_encoder_2 = kwargs.get("text_encoder_2", None) # clipL_text_encoder
|
51 |
+
self.feature_extractor = kwargs.get("feature_extractor", None)
|
52 |
+
self.meta_files = []
|
53 |
+
|
54 |
+
csv_data = pd.read_csv(meta_file)
|
55 |
+
for idx in range(len(csv_data)):
|
56 |
+
self.meta_files.append(
|
57 |
+
{
|
58 |
+
"videoid": str(csv_data["videoid"][idx]),
|
59 |
+
"image_path": str(csv_data["image"][idx]),
|
60 |
+
"audio_path": str(csv_data["audio"][idx]),
|
61 |
+
"prompt": str(csv_data["prompt"][idx]),
|
62 |
+
"fps": float(csv_data["fps"][idx])
|
63 |
+
}
|
64 |
+
)
|
65 |
+
|
66 |
+
self.llava_transform = transforms.Compose(
|
67 |
+
[
|
68 |
+
transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
|
69 |
+
transforms.ToTensor(),
|
70 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
self.clip_image_processor = CLIPImageProcessor()
|
74 |
+
|
75 |
+
self.device = torch.device("cuda")
|
76 |
+
self.weight_dtype = torch.float16
|
77 |
+
|
78 |
+
|
79 |
+
def __len__(self):
|
80 |
+
return len(self.meta_files)
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def get_text_tokens(text_encoder, description, dtype_encode="video"):
|
84 |
+
text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
|
85 |
+
text_ids = text_inputs["input_ids"].squeeze(0)
|
86 |
+
text_mask = text_inputs["attention_mask"].squeeze(0)
|
87 |
+
return text_ids, text_mask
|
88 |
+
|
89 |
+
def get_batch_data(self, idx):
|
90 |
+
meta_file = self.meta_files[idx]
|
91 |
+
videoid = meta_file["videoid"]
|
92 |
+
image_path = meta_file["image_path"]
|
93 |
+
audio_path = meta_file["audio_path"]
|
94 |
+
prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
|
95 |
+
fps = meta_file["fps"]
|
96 |
+
|
97 |
+
img_size = self.image_size
|
98 |
+
ref_image = Image.open(image_path).convert('RGB')
|
99 |
+
|
100 |
+
# Resize reference image
|
101 |
+
w, h = ref_image.size
|
102 |
+
scale = img_size / min(w, h)
|
103 |
+
new_w = round(w * scale / 64) * 64
|
104 |
+
new_h = round(h * scale / 64) * 64
|
105 |
+
|
106 |
+
if img_size == 704:
|
107 |
+
img_size_long = 1216
|
108 |
+
if new_w * new_h > img_size * img_size_long:
|
109 |
+
import math
|
110 |
+
scale = math.sqrt(img_size * img_size_long / w / h)
|
111 |
+
new_w = round(w * scale / 64) * 64
|
112 |
+
new_h = round(h * scale / 64) * 64
|
113 |
+
|
114 |
+
ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
|
115 |
+
|
116 |
+
ref_image = np.array(ref_image)
|
117 |
+
ref_image = torch.from_numpy(ref_image)
|
118 |
+
|
119 |
+
audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
|
120 |
+
audio_prompts = audio_input[0]
|
121 |
+
|
122 |
+
motion_bucket_id_heads = np.array([25] * 4)
|
123 |
+
motion_bucket_id_exps = np.array([30] * 4)
|
124 |
+
motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
|
125 |
+
motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
|
126 |
+
fps = torch.from_numpy(np.array(fps))
|
127 |
+
|
128 |
+
to_pil = ToPILImage()
|
129 |
+
pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w)
|
130 |
+
|
131 |
+
pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
|
132 |
+
pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
|
133 |
+
pixel_value_ref_clip = self.clip_image_processor(
|
134 |
+
images=Image.fromarray((pixel_value_ref[0].permute(1,2,0)).data.cpu().numpy().astype(np.uint8)),
|
135 |
+
return_tensors="pt"
|
136 |
+
).pixel_values[0]
|
137 |
+
pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)
|
138 |
+
|
139 |
+
# Encode text prompts
|
140 |
+
|
141 |
+
text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
|
142 |
+
text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)
|
143 |
+
|
144 |
+
# Output batch
|
145 |
+
batch = {
|
146 |
+
"text_prompt": prompt, #
|
147 |
+
"videoid": videoid,
|
148 |
+
"pixel_value_ref": pixel_value_ref.to(dtype=torch.float16), # 参考图,用于vae提特征 (1, 3, h, w), 取值范围(0, 255)
|
149 |
+
"pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16), # 参考图,用于llava提特征 (1, 3, 336, 336), 取值范围 = CLIP取值范围
|
150 |
+
"pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16), # 参考图,用于clip_image_encoder提特征 (1, 3, 244, 244), 取值范围 = CLIP取值范围
|
151 |
+
"audio_prompts": audio_prompts.to(dtype=torch.float16),
|
152 |
+
"motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
|
153 |
+
"motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
|
154 |
+
"fps": fps.to(dtype=torch.float16),
|
155 |
+
"text_ids": text_ids.clone(), # 对应llava_text_encoder
|
156 |
+
"text_mask": text_mask.clone(), # 对应llava_text_encoder
|
157 |
+
"text_ids_2": text_ids_2.clone(), # 对应clip_text_encoder
|
158 |
+
"text_mask_2": text_mask_2.clone(), # 对应clip_text_encoder
|
159 |
+
"audio_len": audio_len,
|
160 |
+
"image_path": image_path,
|
161 |
+
"audio_path": audio_path,
|
162 |
+
}
|
163 |
+
return batch
|
164 |
+
|
165 |
+
def __getitem__(self, idx):
|
166 |
+
return self.get_batch_data(idx)
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
hyvideo/data_kits/audio_preprocessor.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
import decord
|
7 |
+
import einops
|
8 |
+
import librosa
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
import argparse
|
12 |
+
import traceback
|
13 |
+
import numpy as np
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
from einops import rearrange
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def get_facemask(ref_image, align_instance, area=1.25):
|
21 |
+
# ref_image: (b f c h w)
|
22 |
+
bsz, f, c, h, w = ref_image.shape
|
23 |
+
images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8)
|
24 |
+
face_masks = []
|
25 |
+
for image in images:
|
26 |
+
image_pil = Image.fromarray(image).convert("RGB")
|
27 |
+
_, _, bboxes_list = align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True)
|
28 |
+
try:
|
29 |
+
bboxSrc = bboxes_list[0]
|
30 |
+
except:
|
31 |
+
bboxSrc = [0, 0, w, h]
|
32 |
+
x1, y1, ww, hh = bboxSrc
|
33 |
+
x2, y2 = x1 + ww, y1 + hh
|
34 |
+
ww, hh = (x2-x1) * area, (y2-y1) * area
|
35 |
+
center = [(x2+x1)//2, (y2+y1)//2]
|
36 |
+
x1 = max(center[0] - ww//2, 0)
|
37 |
+
y1 = max(center[1] - hh//2, 0)
|
38 |
+
x2 = min(center[0] + ww//2, w)
|
39 |
+
y2 = min(center[1] + hh//2, h)
|
40 |
+
|
41 |
+
face_mask = np.zeros_like(np.array(image_pil))
|
42 |
+
face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
|
43 |
+
face_masks.append(torch.from_numpy(face_mask[...,:1]))
|
44 |
+
face_masks = torch.stack(face_masks, dim=0) # (b*f, h, w, c)
|
45 |
+
face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
|
46 |
+
face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
|
47 |
+
return face_masks
|
48 |
+
|
49 |
+
|
50 |
+
def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
|
51 |
+
if fps == 25:
|
52 |
+
start_ts = [0]
|
53 |
+
step_ts = [1]
|
54 |
+
elif fps == 12.5:
|
55 |
+
start_ts = [0]
|
56 |
+
step_ts = [2]
|
57 |
+
else:
|
58 |
+
start_ts = [0]
|
59 |
+
step_ts = [1]
|
60 |
+
|
61 |
+
num_frames = min(num_frames, 400)
|
62 |
+
audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
|
63 |
+
audio_feats = torch.stack(audio_feats, dim=2)
|
64 |
+
audio_feats = torch.cat([torch.zeros_like(audio_feats[:,:4]), audio_feats], 1)
|
65 |
+
|
66 |
+
audio_prompts = []
|
67 |
+
for bb in range(1):
|
68 |
+
audio_feats_list = []
|
69 |
+
for f in range(num_frames):
|
70 |
+
cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
|
71 |
+
audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+10]
|
72 |
+
audio_feats_list.append(audio_clip)
|
73 |
+
audio_feats_list = torch.stack(audio_feats_list, 1)
|
74 |
+
audio_prompts.append(audio_feats_list)
|
75 |
+
audio_prompts = torch.cat(audio_prompts)
|
76 |
+
return audio_prompts
|
hyvideo/data_kits/data_tools.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import imageio
|
6 |
+
import torchvision
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
|
11 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
12 |
+
outputs = []
|
13 |
+
for x in videos:
|
14 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
15 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
16 |
+
if rescale:
|
17 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
18 |
+
x = torch.clamp(x,0,1)
|
19 |
+
x = (x * 255).numpy().astype(np.uint8)
|
20 |
+
outputs.append(x)
|
21 |
+
|
22 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
23 |
+
imageio.mimsave(path, outputs, fps=fps, quality=quality)
|
24 |
+
|
25 |
+
def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
|
26 |
+
crop_h, crop_w = crop_img.shape[:2]
|
27 |
+
target_w, target_h = size
|
28 |
+
scale_h, scale_w = target_h / crop_h, target_w / crop_w
|
29 |
+
if scale_w > scale_h:
|
30 |
+
resize_h = int(target_h*resize_ratio)
|
31 |
+
resize_w = int(crop_w / crop_h * resize_h)
|
32 |
+
else:
|
33 |
+
resize_w = int(target_w*resize_ratio)
|
34 |
+
resize_h = int(crop_h / crop_w * resize_w)
|
35 |
+
crop_img = cv2.resize(crop_img, (resize_w, resize_h))
|
36 |
+
pad_left = (target_w - resize_w) // 2
|
37 |
+
pad_top = (target_h - resize_h) // 2
|
38 |
+
pad_right = target_w - resize_w - pad_left
|
39 |
+
pad_bottom = target_h - resize_h - pad_top
|
40 |
+
crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
|
41 |
+
return crop_img
|
hyvideo/data_kits/face_align/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .align import AlignImage
|
hyvideo/data_kits/face_align/align.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
from .detface import DetFace
|
5 |
+
|
6 |
+
class AlignImage(object):
|
7 |
+
def __init__(self, device='cuda', det_path=''):
|
8 |
+
self.facedet = DetFace(pt_path=det_path, confThreshold=0.5, nmsThreshold=0.45, device=device)
|
9 |
+
|
10 |
+
@torch.no_grad()
|
11 |
+
def __call__(self, im, maxface=False):
|
12 |
+
bboxes, kpss, scores = self.facedet.detect(im)
|
13 |
+
face_num = bboxes.shape[0]
|
14 |
+
|
15 |
+
five_pts_list = []
|
16 |
+
scores_list = []
|
17 |
+
bboxes_list = []
|
18 |
+
for i in range(face_num):
|
19 |
+
five_pts_list.append(kpss[i].reshape(5,2))
|
20 |
+
scores_list.append(scores[i])
|
21 |
+
bboxes_list.append(bboxes[i])
|
22 |
+
|
23 |
+
if maxface and face_num>1:
|
24 |
+
max_idx = 0
|
25 |
+
max_area = (bboxes[0, 2])*(bboxes[0, 3])
|
26 |
+
for i in range(1, face_num):
|
27 |
+
area = (bboxes[i,2])*(bboxes[i,3])
|
28 |
+
if area>max_area:
|
29 |
+
max_idx = i
|
30 |
+
five_pts_list = [five_pts_list[max_idx]]
|
31 |
+
scores_list = [scores_list[max_idx]]
|
32 |
+
bboxes_list = [bboxes_list[max_idx]]
|
33 |
+
|
34 |
+
return five_pts_list, scores_list, bboxes_list
|
hyvideo/data_kits/face_align/detface.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: UTF-8 -*-
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
|
8 |
+
|
9 |
+
def xyxy2xywh(x):
|
10 |
+
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
|
11 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
12 |
+
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
13 |
+
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
14 |
+
y[:, 2] = x[:, 2] - x[:, 0] # width
|
15 |
+
y[:, 3] = x[:, 3] - x[:, 1] # height
|
16 |
+
return y
|
17 |
+
|
18 |
+
|
19 |
+
def xywh2xyxy(x):
|
20 |
+
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
21 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
22 |
+
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
23 |
+
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
24 |
+
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
25 |
+
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
26 |
+
return y
|
27 |
+
|
28 |
+
|
29 |
+
def box_iou(box1, box2):
|
30 |
+
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
31 |
+
"""
|
32 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
33 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
34 |
+
Arguments:
|
35 |
+
box1 (Tensor[N, 4])
|
36 |
+
box2 (Tensor[M, 4])
|
37 |
+
Returns:
|
38 |
+
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
39 |
+
IoU values for every element in boxes1 and boxes2
|
40 |
+
"""
|
41 |
+
|
42 |
+
def box_area(box):
|
43 |
+
# box = 4xn
|
44 |
+
return (box[2] - box[0]) * (box[3] - box[1])
|
45 |
+
|
46 |
+
area1 = box_area(box1.T)
|
47 |
+
area2 = box_area(box2.T)
|
48 |
+
|
49 |
+
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
50 |
+
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
|
51 |
+
torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
|
52 |
+
# iou = inter / (area1 + area2 - inter)
|
53 |
+
return inter / (area1[:, None] + area2 - inter)
|
54 |
+
|
55 |
+
|
56 |
+
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
|
57 |
+
# Rescale coords (xyxy) from img1_shape to img0_shape
|
58 |
+
if ratio_pad is None: # calculate from img0_shape
|
59 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
60 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
61 |
+
else:
|
62 |
+
gain = ratio_pad[0][0]
|
63 |
+
pad = ratio_pad[1]
|
64 |
+
|
65 |
+
coords[:, [0, 2]] -= pad[0] # x padding
|
66 |
+
coords[:, [1, 3]] -= pad[1] # y padding
|
67 |
+
coords[:, :4] /= gain
|
68 |
+
clip_coords(coords, img0_shape)
|
69 |
+
return coords
|
70 |
+
|
71 |
+
|
72 |
+
def clip_coords(boxes, img_shape):
|
73 |
+
# Clip bounding xyxy bounding boxes to image shape (height, width)
|
74 |
+
boxes[:, 0].clamp_(0, img_shape[1]) # x1
|
75 |
+
boxes[:, 1].clamp_(0, img_shape[0]) # y1
|
76 |
+
boxes[:, 2].clamp_(0, img_shape[1]) # x2
|
77 |
+
boxes[:, 3].clamp_(0, img_shape[0]) # y2
|
78 |
+
|
79 |
+
|
80 |
+
def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
|
81 |
+
# Rescale coords (xyxy) from img1_shape to img0_shape
|
82 |
+
if ratio_pad is None: # calculate from img0_shape
|
83 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
84 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
85 |
+
else:
|
86 |
+
gain = ratio_pad[0][0]
|
87 |
+
pad = ratio_pad[1]
|
88 |
+
|
89 |
+
coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
|
90 |
+
coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
|
91 |
+
coords[:, :10] /= gain
|
92 |
+
#clip_coords(coords, img0_shape)
|
93 |
+
coords[:, 0].clamp_(0, img0_shape[1]) # x1
|
94 |
+
coords[:, 1].clamp_(0, img0_shape[0]) # y1
|
95 |
+
coords[:, 2].clamp_(0, img0_shape[1]) # x2
|
96 |
+
coords[:, 3].clamp_(0, img0_shape[0]) # y2
|
97 |
+
coords[:, 4].clamp_(0, img0_shape[1]) # x3
|
98 |
+
coords[:, 5].clamp_(0, img0_shape[0]) # y3
|
99 |
+
coords[:, 6].clamp_(0, img0_shape[1]) # x4
|
100 |
+
coords[:, 7].clamp_(0, img0_shape[0]) # y4
|
101 |
+
coords[:, 8].clamp_(0, img0_shape[1]) # x5
|
102 |
+
coords[:, 9].clamp_(0, img0_shape[0]) # y5
|
103 |
+
return coords
|
104 |
+
|
105 |
+
|
106 |
+
def show_results(img, xywh, conf, landmarks, class_num):
|
107 |
+
h,w,c = img.shape
|
108 |
+
tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness
|
109 |
+
x1 = int(xywh[0] * w - 0.5 * xywh[2] * w)
|
110 |
+
y1 = int(xywh[1] * h - 0.5 * xywh[3] * h)
|
111 |
+
x2 = int(xywh[0] * w + 0.5 * xywh[2] * w)
|
112 |
+
y2 = int(xywh[1] * h + 0.5 * xywh[3] * h)
|
113 |
+
cv2.rectangle(img, (x1,y1), (x2, y2), (0,255,0), thickness=tl, lineType=cv2.LINE_AA)
|
114 |
+
|
115 |
+
clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
|
116 |
+
|
117 |
+
for i in range(5):
|
118 |
+
point_x = int(landmarks[2 * i] * w)
|
119 |
+
point_y = int(landmarks[2 * i + 1] * h)
|
120 |
+
cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1)
|
121 |
+
|
122 |
+
tf = max(tl - 1, 1) # font thickness
|
123 |
+
label = str(conf)[:5]
|
124 |
+
cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
125 |
+
return img
|
126 |
+
|
127 |
+
|
128 |
+
def make_divisible(x, divisor):
|
129 |
+
# Returns x evenly divisible by divisor
|
130 |
+
return (x // divisor) * divisor
|
131 |
+
|
132 |
+
|
133 |
+
def non_max_suppression_face(prediction, conf_thres=0.5, iou_thres=0.45, classes=None, agnostic=False, labels=()):
|
134 |
+
"""Performs Non-Maximum Suppression (NMS) on inference results
|
135 |
+
Returns:
|
136 |
+
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
|
137 |
+
"""
|
138 |
+
|
139 |
+
nc = prediction.shape[2] - 15 # number of classes
|
140 |
+
xc = prediction[..., 4] > conf_thres # candidates
|
141 |
+
|
142 |
+
# Settings
|
143 |
+
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
144 |
+
# time_limit = 10.0 # seconds to quit after
|
145 |
+
redundant = True # require redundant detections
|
146 |
+
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
|
147 |
+
merge = False # use merge-NMS
|
148 |
+
|
149 |
+
# t = time.time()
|
150 |
+
output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
|
151 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
152 |
+
# Apply constraints
|
153 |
+
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
154 |
+
x = x[xc[xi]] # confidence
|
155 |
+
|
156 |
+
# Cat apriori labels if autolabelling
|
157 |
+
if labels and len(labels[xi]):
|
158 |
+
l = labels[xi]
|
159 |
+
v = torch.zeros((len(l), nc + 15), device=x.device)
|
160 |
+
v[:, :4] = l[:, 1:5] # box
|
161 |
+
v[:, 4] = 1.0 # conf
|
162 |
+
v[range(len(l)), l[:, 0].long() + 15] = 1.0 # cls
|
163 |
+
x = torch.cat((x, v), 0)
|
164 |
+
|
165 |
+
# If none remain process next image
|
166 |
+
if not x.shape[0]:
|
167 |
+
continue
|
168 |
+
|
169 |
+
# Compute conf
|
170 |
+
x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
171 |
+
|
172 |
+
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
173 |
+
box = xywh2xyxy(x[:, :4])
|
174 |
+
|
175 |
+
# Detections matrix nx6 (xyxy, conf, landmarks, cls)
|
176 |
+
if multi_label:
|
177 |
+
i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
|
178 |
+
x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15] ,j[:, None].float()), 1)
|
179 |
+
else: # best class only
|
180 |
+
conf, j = x[:, 15:].max(1, keepdim=True)
|
181 |
+
x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
|
182 |
+
|
183 |
+
# Filter by class
|
184 |
+
if classes is not None:
|
185 |
+
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
186 |
+
|
187 |
+
# If none remain process next image
|
188 |
+
n = x.shape[0] # number of boxes
|
189 |
+
if not n:
|
190 |
+
continue
|
191 |
+
|
192 |
+
# Batched NMS
|
193 |
+
c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
|
194 |
+
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
195 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
196 |
+
#if i.shape[0] > max_det: # limit detections
|
197 |
+
# i = i[:max_det]
|
198 |
+
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
199 |
+
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
200 |
+
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
201 |
+
weights = iou * scores[None] # box weights
|
202 |
+
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
203 |
+
if redundant:
|
204 |
+
i = i[iou.sum(1) > 1] # require redundancy
|
205 |
+
|
206 |
+
output[xi] = x[i]
|
207 |
+
# if (time.time() - t) > time_limit:
|
208 |
+
# break # time limit exceeded
|
209 |
+
|
210 |
+
return output
|
211 |
+
|
212 |
+
|
213 |
+
class DetFace():
|
214 |
+
def __init__(self, pt_path, confThreshold=0.5, nmsThreshold=0.45, device='cuda'):
|
215 |
+
assert os.path.exists(pt_path)
|
216 |
+
|
217 |
+
self.inpSize = 416
|
218 |
+
self.conf_thres = confThreshold
|
219 |
+
self.iou_thres = nmsThreshold
|
220 |
+
self.test_device = torch.device(device if torch.cuda.is_available() else "cpu")
|
221 |
+
self.model = torch.jit.load(pt_path).to(self.test_device)
|
222 |
+
self.last_w = 416
|
223 |
+
self.last_h = 416
|
224 |
+
self.grids = None
|
225 |
+
|
226 |
+
@torch.no_grad()
|
227 |
+
def detect(self, srcimg):
|
228 |
+
# t0=time.time()
|
229 |
+
|
230 |
+
h0, w0 = srcimg.shape[:2] # orig hw
|
231 |
+
r = self.inpSize / min(h0, w0) # resize image to img_size
|
232 |
+
h1 = int(h0*r+31)//32*32
|
233 |
+
w1 = int(w0*r+31)//32*32
|
234 |
+
|
235 |
+
img = cv2.resize(srcimg, (w1,h1), interpolation=cv2.INTER_LINEAR)
|
236 |
+
|
237 |
+
# Convert
|
238 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR to RGB
|
239 |
+
|
240 |
+
# Run inference
|
241 |
+
img = torch.from_numpy(img).to(self.test_device).permute(2,0,1)
|
242 |
+
img = img.float()/255 # uint8 to fp16/32 0-1
|
243 |
+
if img.ndimension() == 3:
|
244 |
+
img = img.unsqueeze(0)
|
245 |
+
|
246 |
+
# Inference
|
247 |
+
if h1 != self.last_h or w1 != self.last_w or self.grids is None:
|
248 |
+
grids = []
|
249 |
+
for scale in [8,16,32]:
|
250 |
+
ny = h1//scale
|
251 |
+
nx = w1//scale
|
252 |
+
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
253 |
+
grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float()
|
254 |
+
grids.append(grid.to(self.test_device))
|
255 |
+
self.grids = grids
|
256 |
+
self.last_w = w1
|
257 |
+
self.last_h = h1
|
258 |
+
|
259 |
+
pred = self.model(img, self.grids).cpu()
|
260 |
+
|
261 |
+
# Apply NMS
|
262 |
+
det = non_max_suppression_face(pred, self.conf_thres, self.iou_thres)[0]
|
263 |
+
# Process detections
|
264 |
+
# det = pred[0]
|
265 |
+
bboxes = np.zeros((det.shape[0], 4))
|
266 |
+
kpss = np.zeros((det.shape[0], 5, 2))
|
267 |
+
scores = np.zeros((det.shape[0]))
|
268 |
+
# gn = torch.tensor([w0, h0, w0, h0]).to(pred) # normalization gain whwh
|
269 |
+
# gn_lks = torch.tensor([w0, h0, w0, h0, w0, h0, w0, h0, w0, h0]).to(pred) # normalization gain landmarks
|
270 |
+
det = det.cpu().numpy()
|
271 |
+
|
272 |
+
for j in range(det.shape[0]):
|
273 |
+
# xywh = (xyxy2xywh(det[j, :4].view(1, 4)) / gn).view(4).cpu().numpy()
|
274 |
+
bboxes[j, 0] = det[j, 0] * w0/w1
|
275 |
+
bboxes[j, 1] = det[j, 1] * h0/h1
|
276 |
+
bboxes[j, 2] = det[j, 2] * w0/w1 - bboxes[j, 0]
|
277 |
+
bboxes[j, 3] = det[j, 3] * h0/h1 - bboxes[j, 1]
|
278 |
+
scores[j] = det[j, 4]
|
279 |
+
# landmarks = (det[j, 5:15].view(1, 10) / gn_lks).view(5,2).cpu().numpy()
|
280 |
+
kpss[j, :, :] = det[j, 5:15].reshape(5, 2) * np.array([[w0/w1,h0/h1]])
|
281 |
+
# class_num = det[j, 15].cpu().numpy()
|
282 |
+
# orgimg = show_results(orgimg, xywh, conf, landmarks, class_num)
|
283 |
+
return bboxes, kpss, scores
|
hyvideo/diffusion/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .pipelines import HunyuanVideoPipeline
|
2 |
+
from .schedulers import FlowMatchDiscreteScheduler
|
hyvideo/diffusion/pipelines/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .pipeline_hunyuan_video import HunyuanVideoPipeline
|
2 |
+
from .pipeline_hunyuan_video_audio import HunyuanVideoAudioPipeline
|
hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py
ADDED
@@ -0,0 +1,1438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
import inspect
|
20 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
21 |
+
import torch
|
22 |
+
import torch.distributed as dist
|
23 |
+
import numpy as np
|
24 |
+
from dataclasses import dataclass
|
25 |
+
from packaging import version
|
26 |
+
|
27 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
28 |
+
from diffusers.configuration_utils import FrozenDict
|
29 |
+
from diffusers.image_processor import VaeImageProcessor
|
30 |
+
from diffusers.utils import BaseOutput
|
31 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
32 |
+
from diffusers.models import AutoencoderKL
|
33 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
34 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
35 |
+
from diffusers.utils import (
|
36 |
+
USE_PEFT_BACKEND,
|
37 |
+
deprecate,
|
38 |
+
logging,
|
39 |
+
replace_example_docstring,
|
40 |
+
scale_lora_layers,
|
41 |
+
unscale_lora_layers,
|
42 |
+
)
|
43 |
+
from diffusers.utils.torch_utils import randn_tensor
|
44 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
45 |
+
from diffusers.utils import BaseOutput
|
46 |
+
|
47 |
+
from ...constants import PRECISION_TO_TYPE
|
48 |
+
from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
49 |
+
from ...text_encoder import TextEncoder
|
50 |
+
from ...modules import HYVideoDiffusionTransformer
|
51 |
+
from mmgp import offload
|
52 |
+
from ...utils.data_utils import black_image
|
53 |
+
from einops import rearrange
|
54 |
+
|
55 |
+
EXAMPLE_DOC_STRING = """"""
|
56 |
+
|
57 |
+
|
58 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
59 |
+
"""
|
60 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
61 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
62 |
+
"""
|
63 |
+
std_text = noise_pred_text.std(
|
64 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
|
65 |
+
)
|
66 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
67 |
+
# rescale the results from guidance (fixes overexposure)
|
68 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
69 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
70 |
+
noise_cfg = (
|
71 |
+
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
72 |
+
)
|
73 |
+
return noise_cfg
|
74 |
+
|
75 |
+
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

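# Illustrative sketch (assumes a diffusers scheduler whose `set_timesteps` accepts the
# corresponding keyword; the values are made up): the three mutually exclusive ways to call
# `retrieve_timesteps` defined above.
#
#   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
#   timesteps, n = retrieve_timesteps(scheduler, device="cuda", sigmas=[1.0, 0.75, 0.5, 0.25])
#   timesteps, n = retrieve_timesteps(scheduler, device="cuda", timesteps=[999, 749, 499, 249])
#
# Passing both `timesteps` and `sigmas` raises a ValueError, and the returned step count is
# recomputed from the custom schedule when one of them is given.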
@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
    videos: Union[torch.Tensor, np.ndarray]


class HunyuanVideoPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using HunyuanVideo.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`TextEncoder`]):
            Frozen text-encoder.
        text_encoder_2 ([`TextEncoder`]):
            Frozen text-encoder_2.
        transformer ([`HYVideoDiffusionTransformer`]):
            A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _optional_components = ["text_encoder_2"]
    _exclude_from_cpu_offload = ["transformer"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: TextEncoder,
        transformer: HYVideoDiffusionTransformer,
        scheduler: KarrasDiffusionSchedulers,
        text_encoder_2: Optional[TextEncoder] = None,
        progress_bar_config: Dict[str, Any] = None,
        args=None,
    ):
        super().__init__()

        # ==========================================================================================
        if progress_bar_config is None:
            progress_bar_config = {}
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        self._progress_bar_config.update(progress_bar_config)

        self.args = args
        # ==========================================================================================

        if (
            hasattr(scheduler.config, "steps_offset")
            and scheduler.config.steps_offset != 1
        ):
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if (
            hasattr(scheduler.config, "clip_sample")
            and scheduler.config.clip_sample is True
        ):
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate(
                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.noise_pertub = 0

    def encode_prompt(
        self,
        prompt,
        name,
        device,
        num_videos_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        pixel_value_llava: Optional[torch.Tensor] = None,
        uncond_pixel_value_llava: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        text_encoder: Optional[TextEncoder] = None,
        data_type: Optional[str] = "image",
        semantic_images=None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            pixel_value_llava (`torch.Tensor`, *optional*):
                The image tensor for llava.
            uncond_pixel_value_llava (`torch.Tensor`, *optional*):
                The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            attention_mask (`torch.Tensor`, *optional*):
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_attention_mask (`torch.Tensor`, *optional*):
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            text_encoder (TextEncoder, *optional*):
            data_type (`str`, *optional*):
        """
        if text_encoder is None:
            text_encoder = self.text_encoder

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
            else:
                scale_lora_layers(text_encoder.model, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)

            text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name)

            if pixel_value_llava is not None:
                text_inputs['pixel_value_llava'] = pixel_value_llava
                text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575 * len(pixel_value_llava))).to(text_inputs['attention_mask'])], dim=1)

            if clip_skip is None:
                prompt_outputs = text_encoder.encode(
                    text_inputs, data_type=data_type, semantic_images=semantic_images, device=device
                )
                prompt_embeds = prompt_outputs.hidden_state
            else:
                prompt_outputs = text_encoder.encode(
                    text_inputs,
                    output_hidden_states=True,
                    data_type=data_type,
                    semantic_images=semantic_images,
                    device=device,
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = text_encoder.model.text_model.final_layer_norm(
                    prompt_embeds
                )

            attention_mask = prompt_outputs.attention_mask
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
                bs_embed, seq_len = attention_mask.shape
                attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
                attention_mask = attention_mask.view(
                    bs_embed * num_videos_per_prompt, seq_len
                )

        if text_encoder is not None:
            prompt_embeds_dtype = text_encoder.dtype
        elif self.transformer is not None:
            prompt_embeds_dtype = self.transformer.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        if prompt_embeds.ndim == 2:
            bs_embed, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
            prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
        else:
            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(
                bs_embed * num_videos_per_prompt, seq_len, -1
            )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(
                    uncond_tokens, text_encoder.tokenizer
                )

            # max_length = prompt_embeds.shape[1]
            uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type, name=name)

            if semantic_images is not None:
                uncond_image = [black_image(img.size[0], img.size[1]) for img in semantic_images]
            else:
                uncond_image = None

            if uncond_pixel_value_llava is not None:
                uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
                uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575 * len(uncond_pixel_value_llava))).to(uncond_input['attention_mask'])], dim=1)

            negative_prompt_outputs = text_encoder.encode(
                uncond_input, data_type=data_type, semantic_images=uncond_image, device=device
            )
            negative_prompt_embeds = negative_prompt_outputs.hidden_state

            negative_attention_mask = negative_prompt_outputs.attention_mask
            if negative_attention_mask is not None:
                negative_attention_mask = negative_attention_mask.to(device)
                _, seq_len = negative_attention_mask.shape
                negative_attention_mask = negative_attention_mask.repeat(
                    1, num_videos_per_prompt
                )
                negative_attention_mask = negative_attention_mask.view(
                    batch_size * num_videos_per_prompt, seq_len
                )

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=prompt_embeds_dtype, device=device
            )

            if negative_prompt_embeds.ndim == 2:
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_videos_per_prompt
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_videos_per_prompt, -1
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_videos_per_prompt, 1
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_videos_per_prompt, seq_len, -1
                )

        if text_encoder is not None:
            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(text_encoder.model, lora_scale)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            attention_mask,
            negative_attention_mask,
        )

    def decode_latents(self, latents, enable_tiling=True):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        if enable_tiling:
            self.vae.enable_tiling()
            image = self.vae.decode(latents, return_dict=False)[0]
        else:
            image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        if image.ndim == 4:
            image = image.cpu().permute(0, 2, 3, 1).float()
        else:
            image = image.cpu().float()
        return image

    def prepare_extra_func_kwargs(self, func, kwargs):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        extra_step_kwargs = {}

        for k, v in kwargs.items():
            accepts = k in set(inspect.signature(func).parameters.keys())
            if accepts:
                extra_step_kwargs[k] = v
        return extra_step_kwargs

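    # Illustrative sketch (hypothetical values): `prepare_extra_func_kwargs` drops any keyword
    # the target callable does not accept, so schedulers with different `step` signatures can
    # share one call site.
    #
    #   extra = self.prepare_extra_func_kwargs(self.scheduler.step, {"generator": gen, "eta": 0.0})
    #   # -> {"generator": gen} for schedulers whose `step` has no `eta` parameter,
    #   #    {"generator": gen, "eta": 0.0} for DDIM-style schedulers.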
    def check_inputs(
        self,
        prompt,
        height,
        width,
        video_length,
        callback_steps,
        pixel_value_llava=None,
        uncond_pixel_value_llava=None,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        vae_ver="88-4c-sd",
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if video_length is not None:
            if "884" in vae_ver:
                if video_length != 1 and (video_length - 1) % 4 != 0:
                    raise ValueError(
                        f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
                    )
            elif "888" in vae_ver:
                if video_length != 1 and (video_length - 1) % 8 != 0:
                    raise ValueError(
                        f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
                    )

        if callback_steps is not None and (
            not isinstance(callback_steps, int) or callback_steps <= 0
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs
            for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (
            not isinstance(prompt, str) and not isinstance(prompt, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if pixel_value_llava is not None and uncond_pixel_value_llava is not None:
            if len(pixel_value_llava) != len(uncond_pixel_value_llava):
                raise ValueError(
                    "`pixel_value_llava` and `uncond_pixel_value_llava` must have the same length when passed directly, but"
                    f" got: `pixel_value_llava` {len(pixel_value_llava)} != `uncond_pixel_value_llava`"
                    f" {len(uncond_pixel_value_llava)}."
                )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps.to(device), num_inference_steps - t_start

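    # Illustrative note (assumed numbers): with num_inference_steps=50 and strength=0.6,
    # init_timestep = min(int(50 * 0.6), 50) = 30, so t_start = 20 and only the last 30
    # scheduler timesteps are kept -- the usual img2img/vid2vid behaviour where a lower
    # strength preserves more of the input latents.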
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        num_inference_steps,
        height,
        width,
        video_length,
        dtype,
        device,
        timesteps,
        generator,
        latents=None,
        denoise_strength=1.0,
        img_latents=None,
        i2v_mode=False,
        i2v_condition_type=None,
        i2v_stability=True,
    ):
        if i2v_mode and i2v_condition_type == "latent_concat":
            num_channels_latents = (num_channels_latents - 1) // 2
        shape = (
            batch_size,
            num_channels_latents,
            video_length,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if i2v_mode and i2v_stability:
            if img_latents.shape[2] == 1:
                img_latents = img_latents.repeat(1, 1, video_length, 1, 1)
            x0 = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            x1 = img_latents

            t = torch.tensor([0.999]).to(device=device)
            latents = x0 * t + x1 * (1 - t)
            latents = latents.to(dtype=dtype)

        if denoise_strength == 0:
            if latents is None:
                latents = randn_tensor(
                    shape, generator=generator, device=device, dtype=dtype
                )
            else:
                latents = latents.to(device)
            original_latents = None
            noise = None
            timesteps = timesteps
        else:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)

            if latents is None:
                latents = noise
                original_latents = None
            else:
                latents = latents.to(device)
                latent_timestep = timesteps[:1]
                frames_needed = noise.shape[2]
                current_frames = latents.shape[2]

                if frames_needed > current_frames:
                    repeat_factor = frames_needed - current_frames
                    additional_frame = torch.randn((latents.size(0), latents.size(1), repeat_factor, latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
                    latents = torch.cat((additional_frame, latents), dim=2)
                    self.additional_frames = repeat_factor
                elif frames_needed < current_frames:
                    latents = latents[:, :, :frames_needed, :, :]

                original_latents = latents.clone()
                latents = latents * (1 - latent_timestep / 1000) + latent_timestep / 1000 * noise
                print(f'debug:latent_timestep={latent_timestep}, latents-size={latents.shape}')

        # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
        if hasattr(self.scheduler, "init_noise_sigma"):
            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
        return latents, original_latents, noise, timesteps

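    # Illustrative sketch (hypothetical shapes) of the two initialisation paths above:
    #
    #   # i2v_stability: blend pure noise x0 with the image latents x1 at t = 0.999
    #   #   latents = x0 * 0.999 + x1 * 0.001
    #   # denoise_strength < 1 with caller-provided latents: noise them towards the first
    #   # kept timestep, e.g. latent_timestep = 600 gives
    #   #   latents = latents * 0.4 + noise * 0.6
    #
    # `original_latents` and `noise` are returned so that masked regions can be
    # re-composited during the denoising loop.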
    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(
        self,
        w: torch.Tensor,
        embedding_dim: int = 512,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
        return self._guidance_scale > 1

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

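    # Illustrative note (not executed here): with guidance weight w = guidance_scale, the
    # denoising loop below combines the two predictions as
    #
    #   noise_pred = noise_pred_uncond + w * (noise_pred_text - noise_pred_uncond)
    #
    # so w = 1 returns the text-conditioned prediction unchanged, which is why
    # `do_classifier_free_guidance` only turns on for w > 1.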
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int,
        width: int,
        video_length: int,
        name: Union[str, List[str]] = None,
        data_type: str = "video",
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        pixel_value_ref=None,
        # ref_latents: Optional[torch.Tensor] = None,
        # uncond_ref_latents: Optional[torch.Tensor] = None,
        pixel_value_llava: Optional[torch.Tensor] = None,
        uncond_pixel_value_llava: Optional[torch.Tensor] = None,
        bg_latents: Optional[torch.Tensor] = None,
        audio_prompts: Optional[torch.Tensor] = None,
        ip_cfg_scale: float = 0.0,
        audio_strength: float = 1.0,
        use_deepcache: int = 1,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[
                Callable[[int, int, Dict], None],
                PipelineCallback,
                MultiPipelineCallbacks,
            ]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        vae_ver: str = "88-4c-sd",
        enable_tiling: bool = False,
        n_tokens: Optional[int] = None,
        video_val_flag: bool = False,
        denoise_strength: float = 1.0,
        mask=None,
        embedded_guidance_scale: Optional[float] = None,
        i2v_mode: bool = False,
        i2v_condition_type: str = None,
        i2v_stability: bool = True,
        img_latents: Optional[torch.Tensor] = None,
        semantic_images=None,
        joint_pass=False,
        cfg_star_rescale=False,
        callback=None,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`):
                The height in pixels of the generated image.
            width (`int`):
                The width in pixels of the generated image.
            video_length (`int`):
                The number of frames in the generated video.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            ref_latents (`torch.Tensor`, *optional*):
                The image tensor for time-concat.
            uncond_ref_latents (`torch.Tensor`, *optional*):
                The image tensor for time-concat. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            pixel_value_llava (`torch.Tensor`, *optional*):
                The image tensor for llava.
            uncond_pixel_value_llava (`torch.Tensor`, *optional*):
                The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.

            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during inference with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~HunyuanVideoPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        callback_steps = kwargs.pop("callback_steps", None)

        # if callback is not None:
        #     deprecate(
        #         "callback",
        #         "1.0.0",
        #         "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        #     )
        # if callback_steps is not None:
        #     deprecate(
        #         "callback_steps",
        #         "1.0.0",
        #         "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        #     )

        if self._interrupt:
            return [None]
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        if pixel_value_ref is not None:
            pixel_value_ref = pixel_value_ref * 2 - 1.
            pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b c h w -> b c 1 h w")

            ref_latents = self.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
            uncond_ref_latents = self.vae.encode(torch.ones_like(pixel_value_ref_for_vae)).latent_dist.sample()
            ref_latents.mul_(self.vae.config.scaling_factor)
            uncond_ref_latents.mul_(self.vae.config.scaling_factor)
        else:
            ref_latents = None
            uncond_ref_latents = None

        # 0. Default height and width to unet
        # height = height or self.transformer.config.sample_size * self.vae_scale_factor
        # width = width or self.transformer.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks
        trans = self.transformer
        if trans.enable_teacache:
            teacache_multiplier = trans.teacache_multiplier
            trans.accumulated_rel_l1_distance = 0
            trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
            # trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            video_length,
            callback_steps,
            negative_prompt=negative_prompt,
            pixel_value_llava=pixel_value_llava,
            uncond_pixel_value_llava=uncond_pixel_value_llava,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            vae_ver=vae_ver,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            prompt_mask,
            negative_prompt_mask,
        ) = self.encode_prompt(
            prompt,
            name,
            device,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            pixel_value_llava=pixel_value_llava,
            uncond_pixel_value_llava=uncond_pixel_value_llava,
            prompt_embeds=prompt_embeds,
            attention_mask=attention_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_attention_mask=negative_attention_mask,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
            data_type=data_type,
            semantic_images=semantic_images,
        )
        if self.text_encoder_2 is not None:
            (
                prompt_embeds_2,
                negative_prompt_embeds_2,
                prompt_mask_2,
                negative_prompt_mask_2,
            ) = self.encode_prompt(
                prompt,
                name,
                device,
                num_videos_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=None,
                attention_mask=None,
                negative_prompt_embeds=None,
                negative_attention_mask=None,
                lora_scale=lora_scale,
                clip_skip=self.clip_skip,
                text_encoder=self.text_encoder_2,
                data_type=data_type,
            )
        else:
            prompt_embeds_2 = None
            negative_prompt_embeds_2 = None
            prompt_mask_2 = None
            negative_prompt_mask_2 = None

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            if prompt_mask is not None:
                prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
            if prompt_embeds_2 is not None:
                prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
            if prompt_mask_2 is not None:
                prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])

        if self.do_classifier_free_guidance:
            if ref_latents is not None:
                ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
                if prompt_mask[0].sum() > 575:
                    prompt_mask[0] = torch.cat([torch.ones((1, prompt_mask[0].sum() - 575)).to(prompt_mask),
                                                torch.zeros((1, prompt_mask.shape[1] - prompt_mask[0].sum() + 575)).to(prompt_mask)], dim=1)

            if bg_latents is not None:
                bg_latents = torch.cat([bg_latents, bg_latents], dim=0)

            if audio_prompts is not None:
                audio_prompts = torch.cat([torch.zeros_like(audio_prompts), audio_prompts], dim=0)

            if ip_cfg_scale > 0:
                prompt_embeds = torch.cat([prompt_embeds, prompt_embeds[1:]])
                prompt_embeds_2 = torch.cat([prompt_embeds_2, prompt_embeds_2[1:]])
                prompt_mask = torch.cat([prompt_mask, prompt_mask[1:]], dim=0)
                ref_latents = torch.cat([uncond_ref_latents, uncond_ref_latents, ref_latents[1:]], dim=0)

1078 |
+
# 4. Prepare timesteps
|
1079 |
+
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
|
1080 |
+
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
|
1081 |
+
)
|
1082 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
1083 |
+
self.scheduler,
|
1084 |
+
num_inference_steps,
|
1085 |
+
device,
|
1086 |
+
timesteps,
|
1087 |
+
sigmas,
|
1088 |
+
**extra_set_timesteps_kwargs,
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
if "884" in vae_ver:
|
1092 |
+
video_length = (video_length - 1) // 4 + 1
|
1093 |
+
elif "888" in vae_ver:
|
1094 |
+
video_length = (video_length - 1) // 8 + 1
|
1095 |
+
else:
|
1096 |
+
video_length = video_length
|
1097 |
+
|
1098 |
+
if self.transformer.mixed_precision:
|
1099 |
+
latent_dtype = torch.float32
|
1100 |
+
else:
|
1101 |
+
latent_dtype = torch.bfloat16
|
1102 |
+
if prompt_embeds != None:
|
1103 |
+
prompt_embeds = prompt_embeds.to(torch.bfloat16)
|
1104 |
+
if prompt_embeds_2 != None:
|
1105 |
+
prompt_embeds_2 = prompt_embeds_2.to(torch.bfloat16)
|
1106 |
+
# if prompt_mask != None:
|
1107 |
+
# prompt_mask = prompt_mask.to(torch.bfloat16)
|
1108 |
+
# 5. Prepare latent variables
|
1109 |
+
num_channels_latents = self.transformer.config.in_channels
|
1110 |
+
latents, original_latents, noise, timesteps = self.prepare_latents(
|
1111 |
+
batch_size * num_videos_per_prompt,
|
1112 |
+
num_channels_latents,
|
1113 |
+
num_inference_steps,
|
1114 |
+
height,
|
1115 |
+
width,
|
1116 |
+
video_length,
|
1117 |
+
latent_dtype, #prompt_embeds.dtype,
|
1118 |
+
device,
|
1119 |
+
timesteps,
|
1120 |
+
generator,
|
1121 |
+
latents,
|
1122 |
+
denoise_strength,
|
1123 |
+
img_latents=img_latents,
|
1124 |
+
i2v_mode=i2v_mode,
|
1125 |
+
i2v_condition_type=i2v_condition_type,
|
1126 |
+
i2v_stability=i2v_stability
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
if i2v_mode and i2v_condition_type == "latent_concat":
|
1130 |
+
if img_latents.shape[2] == 1:
|
1131 |
+
img_latents_concat = img_latents.repeat(1, 1, video_length, 1, 1)
|
1132 |
+
else:
|
1133 |
+
img_latents_concat = img_latents
|
1134 |
+
img_latents_concat[:, :, 1:, ...] = 0
|
1135 |
+
|
1136 |
+
i2v_mask = torch.zeros(video_length)
|
1137 |
+
i2v_mask[0] = 1
|
1138 |
+
|
1139 |
+
mask_concat = torch.ones(img_latents_concat.shape[0], 1, img_latents_concat.shape[2], img_latents_concat.shape[3],
|
1140 |
+
img_latents_concat.shape[4]).to(device=img_latents.device)
|
1141 |
+
mask_concat[:, :, 1:, ...] = 0
|
1142 |
+
|
1143 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1144 |
+
extra_step_kwargs = self.prepare_extra_func_kwargs(
|
1145 |
+
self.scheduler.step,
|
1146 |
+
{"generator": generator, "eta": eta},
|
1147 |
+
)
|
1148 |
+
|
1149 |
+
vae_precision = "fp16" # torch.float16
|
1150 |
+
precision = "bf16" # torch.bfloat16
|
1151 |
+
|
1152 |
+
disable_autocast = True
|
1153 |
+
|
1154 |
+
target_dtype = PRECISION_TO_TYPE[precision]
|
1155 |
+
autocast_enabled = target_dtype != torch.float32 and not disable_autocast
|
1156 |
+
vae_dtype = self.vae._model_dtype # PRECISION_TO_TYPE[vae_precision]
|
1157 |
+
vae_autocast_enabled = vae_dtype != torch.float32 and not disable_autocast
|
1158 |
+
|
1159 |
+
# 7. Denoising loop
|
1160 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1161 |
+
self._num_timesteps = len(timesteps)
|
1162 |
+
start_scale = ip_cfg_scale # 3.0
|
1163 |
+
end_scale = 1.0
|
1164 |
+
step_scale = (start_scale - end_scale) / (self._num_timesteps - 1 + 1e-3)
|
1165 |
+
|
1166 |
+
# print('sigmas used in generation:', self.scheduler.sigmas)
|
1167 |
+
# print('inference timesteps used in generation:', timesteps)
|
1168 |
+
|
1169 |
+
|
1170 |
+
# 8. Mask latents
|
1171 |
+
mask_latents = None
|
1172 |
+
if mask is not None:
|
1173 |
+
target_video_length = mask.shape[0]
|
1174 |
+
target_height = mask.shape[1]
|
1175 |
+
target_width = mask.shape[2]
|
1176 |
+
|
1177 |
+
mask_length = (target_video_length - 1) // 4 + 1
|
1178 |
+
mask_height = target_height // 8
|
1179 |
+
mask_width = target_width // 8
|
1180 |
+
|
1181 |
+
mask = mask[...,0:1]
|
1182 |
+
mask = mask.unsqueeze(0)
|
1183 |
+
mask = rearrange(mask, "b t h w c -> b c t h w")
|
1184 |
+
|
1185 |
+
mask_latents = torch.nn.functional.interpolate(mask, size=(mask_length, mask_height, mask_width))
|
1186 |
+
mask_latents = mask_latents.to(device)
|
1187 |
+
|
1188 |
+
if mask_latents is not None:
|
1189 |
+
mask_latents_model_input = (
|
1190 |
+
torch.cat([mask_latents] * 2)
|
1191 |
+
if self.do_classifier_free_guidance
|
1192 |
+
else mask_latents
|
1193 |
+
)
|
1194 |
+
print(f'maskinfo, mask={mask.shape}, mask_latents_model_input={mask_latents_model_input.shape} ')
|
1195 |
+
|
1196 |
+
|
1197 |
+
if callback != None:
|
1198 |
+
callback(-1, None, True)
|
1199 |
+
|
1200 |
+
load_latent = True
|
1201 |
+
load_latent = False
|
1202 |
+
|
1203 |
+
multi_passes_free_guidance = not joint_pass
|
1204 |
+
if load_latent:
|
1205 |
+
timesteps = []
|
1206 |
+
|
1207 |
+
latent_items = 2 if self.do_classifier_free_guidance else 1
|
1208 |
+
if ip_cfg_scale>0:
|
1209 |
+
latent_items += 1
|
1210 |
+
|
1211 |
+
if self.transformer.enable_teacache:
|
1212 |
+
self.transformer.previous_residual = [None] * latent_items
|
1213 |
+
|
1214 |
+
# if is_progress_bar:
|
1215 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1216 |
+
for i, t in enumerate(timesteps):
|
1217 |
+
offload.set_step_no_for_lora(self.transformer, i)
|
1218 |
+
if self.interrupt:
|
1219 |
+
continue
|
1220 |
+
if i2v_mode and i2v_condition_type == "token_replace":
|
1221 |
+
latents = torch.concat([img_latents, latents[:, :, 1:, :, :]], dim=2)
|
1222 |
+
|
1223 |
+
# expand the latents if we are doing classifier free guidance
|
1224 |
+
if i2v_mode and i2v_condition_type == "latent_concat":
|
1225 |
+
latent_model_input = torch.concat([latents, img_latents_concat, mask_concat], dim=1)
|
1226 |
+
else:
|
1227 |
+
latent_model_input = latents
|
1228 |
+
|
1229 |
+
latent_model_input = torch.cat([latent_model_input] * latent_items) if latent_items > 1 else latent_model_input
|
1230 |
+
|
1231 |
+
latent_model_input = self.scheduler.scale_model_input(
|
1232 |
+
latent_model_input, t
|
1233 |
+
)
|
1234 |
+
|
1235 |
+
if mask_latents is not None:
|
1236 |
+
original_latents_noise = original_latents * (1 - t / 1000.0) + t / 1000.0 * noise
|
1237 |
+
original_latent_noise_model_input = (
|
1238 |
+
torch.cat([original_latents_noise] * 2)
|
1239 |
+
if self.do_classifier_free_guidance
|
1240 |
+
else original_latents_noise
|
1241 |
+
)
|
1242 |
+
original_latent_noise_model_input = self.scheduler.scale_model_input(original_latent_noise_model_input, t)
|
1243 |
+
latent_model_input = mask_latents_model_input * latent_model_input + (1 - mask_latents_model_input) * original_latent_noise_model_input
|
1244 |
+
|
1245 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
1246 |
+
guidance_expand = (
|
1247 |
+
torch.tensor(
|
1248 |
+
[embedded_guidance_scale] * latent_model_input.shape[0],
|
1249 |
+
dtype=torch.float32,
|
1250 |
+
device=device,
|
1251 |
+
).to(latent_dtype)
|
1252 |
+
* 1000.0
|
1253 |
+
if embedded_guidance_scale is not None
|
1254 |
+
else None
|
1255 |
+
)
|
1256 |
+
|
1257 |
+
# predict the noise residual
|
1258 |
+
with torch.autocast(
|
1259 |
+
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
|
1260 |
+
):
|
1261 |
+
|
1262 |
+
if self.do_classifier_free_guidance and multi_passes_free_guidance:
|
1263 |
+
for j in range(len(latent_model_input)):
|
1264 |
+
ret = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
|
1265 |
+
latent_model_input[j].unsqueeze(0), # [2, 16, 33, 24, 42]
|
1266 |
+
t_expand[j].unsqueeze(0), # [2]
|
1267 |
+
text_states=prompt_embeds[j].unsqueeze(0), # [2, 256, 4096]
|
1268 |
+
text_mask=prompt_mask[j].unsqueeze(0), # [2, 256]
|
1269 |
+
text_states_2=prompt_embeds_2[j].unsqueeze(0), # [2, 768]
|
1270 |
+
ref_latents=ref_latents[j].unsqueeze(0),
|
1271 |
+
                            freqs_cos=freqs_cis[0],  # [seqlen, head_dim]
                            freqs_sin=freqs_cis[1],  # [seqlen, head_dim]
                            guidance=guidance_expand,
                            pipeline=self,
                            x_id=j,
                            step_no=i,
                            bg_latents=bg_latents[j].unsqueeze(0) if bg_latents is not None else None,
                            audio_prompts=audio_prompts[j].unsqueeze(0) if audio_prompts is not None else None,
                            audio_strength=audio_strength,
                            callback=callback,
                        )
                        if self._interrupt:
                            return [None]
                        if j == 0:
                            noise_pred_uncond = ret[0]
                        elif j == 1:
                            noise_pred_text = ret[0]
                        else:
                            noise_pred_ip = ret[0]
                        ret = None
                else:
                    # if self.do_classifier_free_guidance:
                    #     noise_pred_uncond = self.transformer(latent_model_input[:1], t_expand[:1], ref_latents=ref_latents[:1], text_states=prompt_embeds[:1], text_mask=prompt_mask[:1], text_states_2=prompt_embeds_2[:1], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True)['x']
                    #     noise_pred_text = self.transformer(latent_model_input[1:], t_expand[1:], ref_latents=ref_latents[1:], text_states=prompt_embeds[1:], text_mask=prompt_mask[1:], text_states_2=prompt_embeds_2[1:], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True)['x']
                    #     noise_pred = torch.cat([noise_pred_uncond, noise_pred_text], dim=0)
                    # else:
                    ret = self.transformer(  # For an input image (129, 192, 336) (1, 256, 256)
                        latent_model_input,  # [2, 16, 33, 24, 42]
                        t_expand,  # [2]
                        text_states=prompt_embeds,  # [2, 256, 4096]
                        text_mask=prompt_mask,  # [2, 256]
                        text_states_2=prompt_embeds_2,  # [2, 768]
                        ref_latents=ref_latents,
                        freqs_cos=freqs_cis[0],  # [seqlen, head_dim]
                        freqs_sin=freqs_cis[1],  # [seqlen, head_dim]
                        guidance=guidance_expand,
                        pipeline=self,
                        step_no=i,
                        bg_latents=bg_latents,
                        audio_prompts=audio_prompts,
                        audio_strength=audio_strength,
                        callback=callback,
                    )
                    if self._interrupt:
                        return [None]
                    if self.do_classifier_free_guidance:
                        if ip_cfg_scale > 0:
                            noise_pred_uncond, noise_pred_text, noise_pred_ip = ret
                        else:
                            noise_pred_uncond, noise_pred_text = noise_pred = ret
                    else:
                        noise_pred = ret[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    if cfg_star_rescale:
                        batch_size = 1
                        positive_flat = noise_pred_text.view(batch_size, -1)
                        negative_flat = noise_pred_uncond.view(batch_size, -1)
                        dot_product = torch.sum(
                            positive_flat * negative_flat, dim=1, keepdim=True
                        )
                        squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
                        positive_flat, negative_flat = None, None
                        alpha = dot_product / squared_norm
                        noise_pred_uncond *= alpha

                    if ip_cfg_scale > 0:
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + start_scale * (noise_pred_ip - noise_pred_text)
                        start_scale -= step_scale
                        if i == 0:
                            print(f'i={i}, noise_pred shape={noise_pred.shape}')
                    else:
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                if i2v_mode and i2v_condition_type == "token_replace":
                    noise_pred = noise_pred.unsqueeze(0)
                    latents = self.scheduler.step(
                        noise_pred[:, :, 1:, :, :], t, latents[:, :, 1:, :, :], **extra_step_kwargs, return_dict=False
                    )[0]
                    latents = torch.concat(
                        [img_latents, latents], dim=2
                    )
                else:
                    latents = self.scheduler.step(
                        noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                    )[0]

                noise_pred_uncond, noise_pred_text, noise_pred, noise_pred_ip, ret = None, None, None, None, None

                if callback is not None:
                    callback(i, latents.squeeze(0), False)

                if self.interrupt:
                    return [None]

        # if load_latent:
        #     latents = torch.load("latent.pt")
        # else:
        #     torch.save(latents, "latent.pt")

        if mask_latents is not None:
            latents = mask_latents * latents + (1 - mask_latents) * original_latents

        if not output_type == "latent":
            expand_temporal_dim = False
            if len(latents.shape) == 4:
                if isinstance(self.vae, AutoencoderKLCausal3D):
                    latents = latents.unsqueeze(2)
                    expand_temporal_dim = True
            elif len(latents.shape) == 5:
                pass
            else:
                raise ValueError(
                    f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
                )

            if (
                hasattr(self.vae.config, "shift_factor")
                and self.vae.config.shift_factor
            ):
                latents = (
                    latents / self.vae.config.scaling_factor
                    + self.vae.config.shift_factor
                )
            else:
                latents = latents / self.vae.config.scaling_factor

            with torch.autocast(
                device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
            ):
                if enable_tiling:
                    self.vae.enable_tiling()
                    image = self.vae.decode(
                        latents, return_dict=False, generator=generator
                    )[0]
                else:
                    image = self.vae.decode(
                        latents, return_dict=False, generator=generator
                    )[0]

            if expand_temporal_dim or image.shape[2] == 1:
                image = image.squeeze(2)

        else:
            image = latents

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().float()

        if i2v_mode and i2v_condition_type == "latent_concat":
            image = image[:, :, 4:, :, :]

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return image

        return HunyuanVideoPipelineOutput(videos=image)
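The denoising loop above combines up to three transformer passes per step: an unconditional pass, a text-conditioned pass, and an image-prompt (IP) pass whose weight starts at `start_scale` and decays by `step_scale` each step, optionally preceded by a projection of the unconditional prediction (the `cfg_star_rescale` branch). A minimal sketch of that combination, using a hypothetical helper name `combine_guidance` and a fixed `ip_scale` in place of the decaying schedule:

import torch

def combine_guidance(uncond, text, ip, guidance_scale=7.5, ip_scale=3.0, cfg_star_rescale=True):
    # Optionally rescale the unconditional branch by its projection onto the text branch,
    # mirroring the cfg_star_rescale block in the loop above.
    if cfg_star_rescale:
        pos = text.view(1, -1)
        neg = uncond.view(1, -1)
        alpha = (pos * neg).sum(dim=1, keepdim=True) / (neg.pow(2).sum(dim=1, keepdim=True) + 1e-8)
        uncond = uncond * alpha
    # Classifier-free guidance plus an image-prompt correction term.
    return uncond + guidance_scale * (text - uncond) + ip_scale * (ip - text)

# Toy tensors standing in for [1, C, F, H, W] noise predictions.
preds = [torch.randn(1, 16, 33, 24, 42) for _ in range(3)]
print(combine_guidance(*preds).shape)  # torch.Size([1, 16, 33, 24, 42])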
hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py
ADDED
@@ -0,0 +1,1362 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
import inspect
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
import torch
from packaging import version
from diffusers.utils import BaseOutput
from dataclasses import dataclass
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from hyvideo.constants import PRECISION_TO_TYPE
from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from hyvideo.text_encoder import TextEncoder
from einops import rearrange
from ...modules import HYVideoDiffusionTransformer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """"""


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg

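As a quick sanity check of what `rescale_noise_cfg` does (matching the standard deviation of the guided prediction to the text-conditioned branch, then blending), a small usage sketch, assuming the repository root is on PYTHONPATH so the module path in the file header is importable:

import torch
from hyvideo.diffusion.pipelines.pipeline_hunyuan_video_audio import rescale_noise_cfg

noise_pred_text = torch.randn(1, 16, 33, 24, 42)
noise_cfg = 5.0 * torch.randn(1, 16, 33, 24, 42)  # exaggerated std, as after strong CFG

rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(noise_cfg.std().item(), rescaled.std().item())  # the rescaled std sits closer to the text branch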
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

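`retrieve_timesteps` is scheduler-agnostic; it probes `set_timesteps` with `inspect` before forwarding custom `timesteps` or `sigmas`. A hedged usage sketch with a stock diffusers scheduler (any scheduler exposing `set_timesteps` and `timesteps` works the same way):

from diffusers import DDPMScheduler
from hyvideo.diffusion.pipelines.pipeline_hunyuan_video_audio import retrieve_timesteps

scheduler = DDPMScheduler()
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(n, timesteps[:5])  # 30 steps, returned in descending order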
@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
    videos: Union[torch.Tensor, np.ndarray]


class HunyuanVideoAudioPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using HunyuanVideo.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`TextEncoder`]):
            Frozen text-encoder.
        text_encoder_2 ([`TextEncoder`]):
            Frozen text-encoder_2.
        transformer ([`HYVideoDiffusionTransformer`]):
            A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _optional_components = ["text_encoder_2"]
    _exclude_from_cpu_offload = ["transformer"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: TextEncoder,
        transformer: HYVideoDiffusionTransformer,
        scheduler: KarrasDiffusionSchedulers,
        text_encoder_2: Optional[TextEncoder] = None,
        progress_bar_config: Dict[str, Any] = None,
        args=None,
    ):
        super().__init__()

        # ==========================================================================================
        if progress_bar_config is None:
            progress_bar_config = {}
        if not hasattr(self, '_progress_bar_config'):
            self._progress_bar_config = {}
        self._progress_bar_config.update(progress_bar_config)

        self.args = args
        # ==========================================================================================

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

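The spatial scale factor above comes straight from the VAE config: each downsampling stage halves the resolution, so a config with four `block_out_channels` entries gives a factor of 8. A small illustration with hypothetical config values:

block_out_channels = [128, 256, 512, 512]  # hypothetical VAE config, not the shipped checkpoint
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8 -> a 192x336 frame maps to a 24x42 latent grid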
    def encode_prompt(
        self,
        prompt,
        name,
        device,
        num_videos_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        pixel_value_llava: Optional[torch.Tensor] = None,
        uncond_pixel_value_llava: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        text_encoder: Optional[TextEncoder] = None,
        data_type: Optional[str] = "image",
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            pixel_value_llava (`torch.Tensor`, *optional*):
                The image tensor for llava.
            uncond_pixel_value_llava (`torch.Tensor`, *optional*):
                The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            attention_mask (`torch.Tensor`, *optional*):
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_attention_mask (`torch.Tensor`, *optional*):
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            text_encoder (TextEncoder, *optional*):
        """
        if text_encoder is None:
            text_encoder = self.text_encoder

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
            else:
                scale_lora_layers(text_encoder.model, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
            text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name)

            if pixel_value_llava is not None:
                text_inputs['pixel_value_llava'] = pixel_value_llava
                text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575 * len(pixel_value_llava))).to(text_inputs['attention_mask'])], dim=1)

            if clip_skip is None:
                prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
                prompt_embeds = prompt_outputs.hidden_state
            else:
                prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)

            attention_mask = prompt_outputs.attention_mask
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
                bs_embed, seq_len = attention_mask.shape
                attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
                attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)

        if text_encoder is not None:
            prompt_embeds_dtype = text_encoder.dtype
        elif self.transformer is not None:
            prompt_embeds_dtype = self.transformer.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        if prompt_embeds.ndim == 2:
            bs_embed, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
            prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
        else:
            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
            uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
            if uncond_pixel_value_llava is not None:
                uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
                uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575 * len(uncond_pixel_value_llava))).to(uncond_input['attention_mask'])], dim=1)

            negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
            negative_prompt_embeds = negative_prompt_outputs.hidden_state

            negative_attention_mask = negative_prompt_outputs.attention_mask
            if negative_attention_mask is not None:
                negative_attention_mask = negative_attention_mask.to(device)
                _, seq_len = negative_attention_mask.shape
                negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt)
                negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            if negative_prompt_embeds.ndim == 2:
                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt)
                negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
            else:
                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
                negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        if text_encoder is not None:
            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(text_encoder.model, lora_scale)

        return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask

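`encode_prompt` duplicates embeddings per prompt with a `repeat`/`view` pair rather than `repeat_interleave`, which the upstream comment calls the mps-friendly method. The shape bookkeeping on a dummy tensor, for reference:

import torch

num_videos_per_prompt = 2
prompt_embeds = torch.randn(1, 256, 4096)          # [batch, seq_len, dim]
bs, seq_len, _ = prompt_embeds.shape
dup = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
dup = dup.view(bs * num_videos_per_prompt, seq_len, -1)
print(dup.shape)  # torch.Size([2, 256, 4096])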
    def encode_prompt_audio_text_base(
        self,
        prompt,
        uncond_prompt,
        pixel_value_llava,
        uncond_pixel_value_llava,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        text_encoder: Optional[TextEncoder] = None,
        data_type: Optional[str] = "image",
        name="person"
    ):
        if text_encoder is None:
            text_encoder = self.text_encoder

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
            else:
                scale_lora_layers(text_encoder.model, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        prompt_embeds = None

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
            text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name)  # data_type: video, text_inputs: {'input_ids', 'attention_mask'}

            text_keys = ['input_ids', 'attention_mask']

            if pixel_value_llava is not None:
                text_inputs['pixel_value_llava'] = pixel_value_llava
                text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575)).to(text_inputs['attention_mask'])], dim=1)

            if clip_skip is None:
                prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
                prompt_embeds = prompt_outputs.hidden_state
            else:
                prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)

            attention_mask = prompt_outputs.attention_mask
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
                bs_embed, seq_len = attention_mask.shape
                attention_mask = attention_mask.repeat(1, num_images_per_prompt)
                attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, seq_len)

        if text_encoder is not None:
            prompt_embeds_dtype = text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        if prompt_embeds.ndim == 2:
            bs_embed, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, -1)
        else:
            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
            # max_length = prompt_embeds.shape[1]
            uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type, name=name)

            # if hasattr(text_encoder.model.config, "use_attention_mask") and text_encoder.model.config.use_attention_mask:
            #     attention_mask = uncond_input.attention_mask.to(device)
            # else:
            #     attention_mask = None
            if uncond_pixel_value_llava is not None:
                uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
                uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575)).to(uncond_input['attention_mask'])], dim=1)

            negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
            negative_prompt_embeds = negative_prompt_outputs.hidden_state

            negative_attention_mask = negative_prompt_outputs.attention_mask
            if negative_attention_mask is not None:
                negative_attention_mask = negative_attention_mask.to(device)
                _, seq_len = negative_attention_mask.shape
                negative_attention_mask = negative_attention_mask.repeat(1, num_images_per_prompt)
                negative_attention_mask = negative_attention_mask.view(batch_size * num_images_per_prompt, seq_len)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            if negative_prompt_embeds.ndim == 2:
                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
                negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
            else:
                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
                negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if text_encoder is not None:
            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(text_encoder.model, lora_scale)

        return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask

    def decode_latents(self, latents, enable_tiling=True):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        if enable_tiling:
            self.vae.enable_tiling()
            image = self.vae.decode(latents, return_dict=False)[0]
            self.vae.disable_tiling()
        else:
            image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        if image.ndim == 4:
            image = image.cpu().permute(0, 2, 3, 1).float()
        else:
            image = image.cpu().float()
        return image

    def prepare_extra_func_kwargs(self, func, kwargs):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        extra_step_kwargs = {}

        for k, v in kwargs.items():
            accepts = k in set(inspect.signature(func).parameters.keys())
            if accepts:
                extra_step_kwargs[k] = v
        return extra_step_kwargs

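`prepare_extra_func_kwargs` keeps only the keyword arguments that the target callable actually declares, so `eta` and `generator` can be offered to every scheduler without breaking the ones that do not accept them. The same filtering in isolation, on a stand-in function:

import inspect

def step(sample, generator=None):  # stand-in for a scheduler.step signature
    return sample

offered = {"eta": 0.0, "generator": None}
accepted = {k: v for k, v in offered.items() if k in inspect.signature(step).parameters}
print(accepted)  # {'generator': None} -- 'eta' is dropped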
    def check_inputs(
        self,
        prompt,
        height,
        width,
        frame,
        callback_steps,
        pixel_value_llava=None,
        uncond_pixel_value_llava=None,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        vae_ver='88-4c-sd'
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if frame is not None:
            if '884' in vae_ver:
                if frame != 1 and (frame - 1) % 4 != 0:
                    raise ValueError(f'`frame` has to be 1 or a multiple of 4 but is {frame}.')
            elif '888' in vae_ver:
                if frame != 1 and (frame - 1) % 8 != 0:
                    raise ValueError(f'`frame` has to be 1 or a multiple of 8 but is {frame}.')

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if pixel_value_llava is not None and uncond_pixel_value_llava is not None:
            if len(pixel_value_llava) != len(uncond_pixel_value_llava):
                raise ValueError(
                    "`pixel_value_llava` and `uncond_pixel_value_llava` must have the same length when passed directly, but"
                    f" got: `pixel_value_llava` {len(pixel_value_llava)} != `uncond_pixel_value_llava`"
                    f" {len(uncond_pixel_value_llava)}."
                )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

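The frame check above encodes the temporal compression of the causal VAE: with an "884"-style VAE each latent frame beyond the first covers 4 pixel frames, so valid counts are 1, 5, 9, and so on. A quick standalone check of the rule, with an illustrative version string:

def is_valid_frame_count(frame, vae_ver):
    if "884" in vae_ver:
        return frame == 1 or (frame - 1) % 4 == 0
    if "888" in vae_ver:
        return frame == 1 or (frame - 1) % 8 == 0
    return True

print([f for f in range(1, 20) if is_valid_frame_count(f, "884")])  # [1, 5, 9, 13, 17]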
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps.to(device), num_inference_steps - t_start

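`get_timesteps` implements the usual img2img-style truncation: `strength` decides how many of the scheduled steps are actually run, counting from the noisier end. The arithmetic on its own:

num_inference_steps = 50
strength = 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
print(t_start, num_inference_steps - t_start)  # 20 30 -> skip the first 20 steps, run the last 30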
    def prepare_latents(self, batch_size, num_channels_latents, height, width, frame, dtype, device, generator, latents=None, ref_latents=None, timestep=None):
        shape = (
            batch_size,
            num_channels_latents,
            frame,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        if timestep is not None:
            init_latents = ref_latents.clone().repeat(1, 1, frame, 1, 1).to(device).to(dtype)
            latents = latents

        # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
        if hasattr(self.scheduler, "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma

        return latents

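`prepare_latents` draws noise directly on the latent grid: channels by latent frames by height and width divided by `vae_scale_factor`. A sketch of the shape it produces, using `torch.randn` in place of `randn_tensor` and illustrative sizes that match the shape comments in the denoising loop:

import torch

batch_size, num_channels_latents, frame = 1, 16, 33
height, width, vae_scale_factor = 192, 336, 8
shape = (batch_size, num_channels_latents, frame, height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(shape)
print(latents.shape)  # torch.Size([1, 16, 33, 24, 42])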
    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

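`get_guidance_scale_embedding` is the standard sinusoidal embedding applied to `w * 1000`. The same computation as a standalone function, for a quick shape check:

import torch

def guidance_embedding(w, embedding_dim=512, dtype=torch.float32):
    # Sinusoidal embedding of the (scaled) guidance weight, half sin / half cos.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

print(guidance_embedding(torch.tensor([7.5])).shape)  # torch.Size([1, 512])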
    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
        return self._guidance_scale > 1

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],

        ref_latents: Union[torch.Tensor],  # [1, 16, 1, h//8, w//8]
        # uncond_ref_latents: Union[torch.Tensor],
        pixel_value_llava: Union[torch.Tensor],  # [1, 3, 336, 336]
        uncond_pixel_value_llava: Union[torch.Tensor],
        pixel_value_ref: Union[torch.Tensor],
        face_masks: Union[torch.Tensor],  # [b f h w]
        audio_prompts: Union[torch.Tensor],
        uncond_audio_prompts: Union[torch.Tensor],
        motion_exp: Union[torch.Tensor],
        motion_pose: Union[torch.Tensor],
        fps: Union[torch.Tensor],

        height: int,
        width: int,
        video_length: int,
        data_type: str = "video",
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[
                Callable[[int, int, Dict], None],
                PipelineCallback,
                MultiPipelineCallbacks,
            ]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        vae_ver: str = "88-4c-sd",
        enable_tiling: bool = False,
        n_tokens: Optional[int] = None,
        embedded_guidance_scale: Optional[float] = None,
        joint_pass=False,
        cfg_star_rescale=False,
        name=None,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`):
                The height in pixels of the generated image.
            width (`int`):
                The width in pixels of the generated image.
            video_length (`int`):
                The number of frames in the generated video.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.

            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~HunyuanVideoPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        if self._interrupt:
            return [None]

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # num_inference_steps = 50

        # 0. Default height and width to transformer
        # height = height or self.transformer.config.sample_size * self.vae_scale_factor
        # width = width or self.transformer.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks

        transformer = self.transformer

        if transformer.enable_teacache:
            teacache_multiplier = transformer.teacache_multiplier
            transformer.accumulated_rel_l1_distance = 0
            transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            video_length,
            callback_steps,
            pixel_value_llava,
            uncond_pixel_value_llava,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
            vae_ver=vae_ver
        )

        self._guidance_scale = guidance_scale
        self.start_cfg_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        # ========== Encode text prompt (image prompt) ==========
        prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask = \
            self.encode_prompt_audio_text_base(
                prompt=prompt,
                uncond_prompt=negative_prompt,
                pixel_value_llava=pixel_value_llava,
                uncond_pixel_value_llava=uncond_pixel_value_llava,
                device=device,
                num_images_per_prompt=num_videos_per_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                negative_prompt=negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=lora_scale,
                clip_skip=self.clip_skip,
                text_encoder=self.text_encoder,
                data_type=data_type,
                name=name,
                # **kwargs
            )
        if self.text_encoder_2 is not None:
            prompt_embeds_2, negative_prompt_embeds_2, prompt_mask_2, negative_prompt_mask_2 = \
                self.encode_prompt_audio_text_base(
                    prompt=prompt,
                    uncond_prompt=negative_prompt,
                    pixel_value_llava=None,
                    uncond_pixel_value_llava=None,
                    device=device,
                    num_images_per_prompt=num_videos_per_prompt,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    negative_prompt=negative_prompt,
                    prompt_embeds=None,
1011 |
+
prompt_embeds=None,
|
1012 |
+
negative_prompt_embeds=None,
|
1013 |
+
lora_scale=lora_scale,
|
1014 |
+
clip_skip=self.clip_skip,
|
1015 |
+
text_encoder=self.text_encoder_2,
|
1016 |
+
# **kwargs
|
1017 |
+
)
|
1018 |
+
else:
|
1019 |
+
prompt_embeds_2 = None
|
1020 |
+
negative_prompt_embeds_2 = None
|
1021 |
+
prompt_mask_2 = None
|
1022 |
+
negative_prompt_mask_2 = None
|
1023 |
+
|
1024 |
+
if self.transformer.mixed_precision:
|
1025 |
+
latent_dtype = torch.float32
|
1026 |
+
else:
|
1027 |
+
latent_dtype = torch.bfloat16
|
1028 |
+
if prompt_embeds != None:
|
1029 |
+
prompt_embeds = prompt_embeds.to(torch.bfloat16)
|
1030 |
+
if negative_prompt_embeds != None:
|
1031 |
+
negative_prompt_embeds = negative_prompt_embeds.to(torch.bfloat16)
|
1032 |
+
if prompt_embeds_2 != None:
|
1033 |
+
prompt_embeds_2 = prompt_embeds_2.to(torch.bfloat16)
|
1034 |
+
if negative_prompt_embeds_2 != None:
|
1035 |
+
negative_prompt_embeds_2 = negative_prompt_embeds_2.to(torch.bfloat16)
|
1036 |
+
if audio_prompts != None:
|
1037 |
+
audio_prompts = audio_prompts.to(torch.bfloat16)
|
1038 |
+
if face_masks!= None:
|
1039 |
+
face_masks = face_masks.to(torch.bfloat16)
|
1040 |
+
if ref_latents != None:
|
1041 |
+
ref_latents = ref_latents.to(torch.bfloat16)
|
1042 |
+
|
1043 |
+
# For classifier free guidance, we need to do two forward passes.
|
1044 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
1045 |
+
# to avoid doing two forward passes
|
1046 |
+
if self.do_classifier_free_guidance:
|
1047 |
+
prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds])
|
1048 |
+
if prompt_mask is not None:
|
1049 |
+
prompt_mask_input = torch.cat([negative_prompt_mask, prompt_mask])
|
1050 |
+
if prompt_embeds_2 is not None:
|
1051 |
+
prompt_embeds_2_input = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
1052 |
+
if prompt_mask_2 is not None:
|
1053 |
+
prompt_mask_2_input = torch.cat([negative_prompt_mask_2, prompt_mask_2])
|
1054 |
+
|
1055 |
+
if self.do_classifier_free_guidance and ref_latents != None:
|
1056 |
+
ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
|
1057 |
+
|
1058 |
+
|
1059 |
+
# 4. Prepare timesteps
|
1060 |
+
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
|
1061 |
+
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
|
1062 |
+
)
|
1063 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
1064 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas, **extra_set_timesteps_kwargs,
|
1065 |
+
)
|
1066 |
+
|
1067 |
+
video_length = audio_prompts.shape[1] // 4 * 4 + 1
|
1068 |
+
if "884" in vae_ver:
|
1069 |
+
video_length = (video_length - 1) // 4 + 1
|
1070 |
+
elif "888" in vae_ver:
|
1071 |
+
video_length = (video_length - 1) // 8 + 1
|
1072 |
+
else:
|
1073 |
+
video_length = video_length
|
1074 |
+
|
1075 |
+
|
1076 |
+
# 5. Prepare latent variables
|
1077 |
+
num_channels_latents = self.transformer.config.in_channels
|
1078 |
+
infer_length = (audio_prompts.shape[1] // 128 + 1) * 32 + 1
|
1079 |
+
latents = self.prepare_latents(
|
1080 |
+
batch_size * num_videos_per_prompt,
|
1081 |
+
num_channels_latents,
|
1082 |
+
height,
|
1083 |
+
width,
|
1084 |
+
infer_length,
|
1085 |
+
latent_dtype, #prompt_embeds.dtype,
|
1086 |
+
device,
|
1087 |
+
generator,
|
1088 |
+
latents,
|
1089 |
+
ref_latents[-1:] if ref_latents != None else None,
|
1090 |
+
timesteps[:1]
|
1091 |
+
)
|
1092 |
+
|
1093 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1094 |
+
extra_step_kwargs = self.prepare_extra_func_kwargs(
|
1095 |
+
self.scheduler.step, {"generator": generator, "eta": eta},
|
1096 |
+
)
|
1097 |
+
|
1098 |
+
vae_precision = "fp16" # torch.float16
|
1099 |
+
precision = "bf16" # torch.bfloat16
|
1100 |
+
disable_autocast = True
|
1101 |
+
|
1102 |
+
target_dtype = PRECISION_TO_TYPE[precision]
|
1103 |
+
autocast_enabled = (target_dtype != torch.float32) and not disable_autocast
|
1104 |
+
vae_dtype = self.vae._model_dtype #PRECISION_TO_TYPE[vae_precision]
|
1105 |
+
vae_autocast_enabled = (vae_dtype != torch.float32) and not disable_autocast
|
1106 |
+
|
1107 |
+
# 7. Denoising loop
|
1108 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1109 |
+
self._num_timesteps = len(timesteps)
|
1110 |
+
|
1111 |
+
latents_all = latents.clone()
|
1112 |
+
pad_audio_length = (audio_prompts.shape[1] // 128 + 1) * 128 + 4 - audio_prompts.shape[1]
|
1113 |
+
audio_prompts_all = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :pad_audio_length])], dim=1)
|
1114 |
+
|
1115 |
+
|
1116 |
+
shift = 0
|
1117 |
+
shift_offset = 10
|
1118 |
+
frames_per_batch = 33
|
1119 |
+
self.cache_tensor = None
|
1120 |
+
|
1121 |
+
""" If the total length is shorter than 129, shift is not required """
|
1122 |
+
if video_length == 33 or infer_length == 33:
|
1123 |
+
infer_length = 33
|
1124 |
+
shift_offset = 0
|
1125 |
+
latents_all = latents_all[:, :, :33]
|
1126 |
+
audio_prompts_all = audio_prompts_all[:, :132]
|
1127 |
+
joint_pass = joint_pass or not self.do_classifier_free_guidance
|
1128 |
+
|
1129 |
+
if callback != None:
|
1130 |
+
callback(-1, None, True, override_num_inference_steps = num_inference_steps)
|
1131 |
+
|
1132 |
+
latent_items = 2 if self.do_classifier_free_guidance else 1
|
1133 |
+
|
1134 |
+
fps = torch.from_numpy(np.array(fps)).unsqueeze(0).to(dtype=torch.float16)
|
1135 |
+
|
1136 |
+
if self._interrupt:
|
1137 |
+
return [None]
|
1138 |
+
|
1139 |
+
if transformer.enable_teacache:
|
1140 |
+
cache_size = round( infer_length / frames_per_batch )
|
1141 |
+
transformer.previous_residual = [None] * latent_items
|
1142 |
+
cache_all_previous_residual = [None] * latent_items
|
1143 |
+
cache_all_previous_modulated_input = None
|
1144 |
+
cache_should_calc = [True] * cache_size
|
1145 |
+
cache_accumulated_rel_l1_distance = [0.] * cache_size
|
1146 |
+
cache_teacache_skipped_steps = [0] * cache_size
|
1147 |
+
|
1148 |
+
|
1149 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1150 |
+
for i, t in enumerate(timesteps):
|
1151 |
+
# init
|
1152 |
+
pred_latents = torch.zeros_like(
|
1153 |
+
latents_all,
|
1154 |
+
dtype=latents_all.dtype,
|
1155 |
+
)
|
1156 |
+
counter = torch.zeros(
|
1157 |
+
(latents_all.shape[0], latents_all.shape[1], infer_length, 1, 1),
|
1158 |
+
dtype=latents_all.dtype,
|
1159 |
+
).to(device=latents_all.device)
|
1160 |
+
|
1161 |
+
cache_slot_no = 0
|
1162 |
+
for index_start in range(0, infer_length, frames_per_batch):
|
1163 |
+
self.scheduler._step_index = None
|
1164 |
+
|
1165 |
+
index_start = index_start - shift
|
1166 |
+
idx_list = [ii % latents_all.shape[2] for ii in range(index_start, index_start + frames_per_batch)]
|
1167 |
+
latents = latents_all[:, :, idx_list].clone()
|
1168 |
+
|
1169 |
+
idx_list_audio = [ii % audio_prompts_all.shape[1] for ii in range(index_start * 4, (index_start + frames_per_batch) * 4 - 3)]
|
1170 |
+
audio_prompts = audio_prompts_all[:, idx_list_audio].clone()
|
1171 |
+
|
1172 |
+
# expand the latents if we are doing classifier free guidance
|
1173 |
+
if self.do_classifier_free_guidance:
|
1174 |
+
latent_model_input = torch.cat([latents] * 2)
|
1175 |
+
else:
|
1176 |
+
latent_model_input = latents
|
1177 |
+
|
1178 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1179 |
+
embedded_hw = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * 3072
|
1180 |
+
img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1)
|
1181 |
+
img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3]
|
1182 |
+
|
1183 |
+
if transformer.enable_teacache and cache_size > 1:
|
1184 |
+
for l in range(latent_items):
|
1185 |
+
if cache_all_previous_residual[l] != None:
|
1186 |
+
bsz = cache_all_previous_residual[l].shape[0]
|
1187 |
+
transformer.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
|
1188 |
+
if cache_all_previous_modulated_input != None:
|
1189 |
+
transformer.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
|
1190 |
+
transformer.should_calc = cache_should_calc[cache_slot_no]
|
1191 |
+
transformer.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no]
|
1192 |
+
transformer.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no]
|
1193 |
+
|
1194 |
+
|
1195 |
+
if self.do_classifier_free_guidance:
|
1196 |
+
if i < num_inference_steps * 0.2 :
|
1197 |
+
self._guidance_scale = (1 - i / len(timesteps)) * (self.start_cfg_scale - 2) + 2
|
1198 |
+
audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0)
|
1199 |
+
face_masks_input = torch.cat([face_masks * 0.6] * 2, dim=0)
|
1200 |
+
else:
|
1201 |
+
# cfg schedule for the remaining steps (after the first 20% of denoising)
|
1202 |
+
self._guidance_scale = (1 - i / len(timesteps)) * (6.5 - 3.5) + 3.5 # linearly decay guidance from 6.5 down to 3.5 over the remaining steps
|
1203 |
+
|
1204 |
+
prompt_embeds_input = torch.cat([prompt_embeds, prompt_embeds])
|
1205 |
+
if prompt_mask is not None:
|
1206 |
+
prompt_mask_input = torch.cat([prompt_mask, prompt_mask])
|
1207 |
+
if prompt_embeds_2 is not None:
|
1208 |
+
prompt_embeds_2_input = torch.cat([prompt_embeds_2, prompt_embeds_2])
|
1209 |
+
if prompt_mask_2 is not None:
|
1210 |
+
prompt_mask_2_input = torch.cat([prompt_mask_2, prompt_mask_2])
|
1211 |
+
audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0)
|
1212 |
+
face_masks_input = torch.cat([face_masks] * 2, dim=0)
|
1213 |
+
|
1214 |
+
motion_exp_input = torch.cat([motion_exp] * 2, dim=0)
|
1215 |
+
motion_pose_input = torch.cat([motion_pose] * 2, dim=0)
|
1216 |
+
fps_input = torch.cat([fps] * 2, dim=0)
|
1217 |
+
|
1218 |
+
else:
|
1219 |
+
audio_prompts_input = audio_prompts
|
1220 |
+
face_masks_input = face_masks
|
1221 |
+
motion_exp_input = motion_exp
|
1222 |
+
motion_pose_input = motion_pose
|
1223 |
+
fps_input = fps
|
1224 |
+
|
1225 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
1226 |
+
guidance_expand = None
|
1227 |
+
|
1228 |
+
with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
|
1229 |
+
additional_kwargs = {
|
1230 |
+
"pipeline": self,
|
1231 |
+
"step_no": i,
|
1232 |
+
}
|
1233 |
+
if joint_pass:
|
1234 |
+
additional_kwargs.update({
|
1235 |
+
"motion_exp": motion_exp_input,
|
1236 |
+
"motion_pose": motion_pose_input,
|
1237 |
+
"fps": fps_input,
|
1238 |
+
"audio_prompts": audio_prompts_input,
|
1239 |
+
"face_mask": face_masks_input
|
1240 |
+
})
|
1241 |
+
noise_pred = self.transformer(latent_model_input, t_expand, ref_latents=ref_latents, text_states=prompt_embeds_input, text_mask=prompt_mask_input, text_states_2=prompt_embeds_2_input, freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, **additional_kwargs,)
|
1242 |
+
if self._interrupt:
|
1243 |
+
return [None]
|
1244 |
+
if self.do_classifier_free_guidance:
|
1245 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1246 |
+
else:
|
1247 |
+
additional_kwargs.update({
|
1248 |
+
"motion_exp": motion_exp_input[:1],
|
1249 |
+
"motion_pose": motion_pose_input[:1],
|
1250 |
+
"fps": fps_input[:1],
|
1251 |
+
"audio_prompts": audio_prompts_input[:1],
|
1252 |
+
"face_mask": face_masks_input[:1]
|
1253 |
+
})
|
1254 |
+
noise_pred_uncond = self.transformer(latent_model_input[:1], t_expand[:1], ref_latents=ref_latents[:1], text_states=prompt_embeds_input[:1], text_mask=prompt_mask_input[:1], text_states_2=prompt_embeds_2_input[:1], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, x_id = 0, **additional_kwargs,)
|
1255 |
+
if self._interrupt:
|
1256 |
+
return [None]
|
1257 |
+
noise_pred_uncond = noise_pred_uncond[0]
|
1258 |
+
additional_kwargs.update({
|
1259 |
+
"motion_exp": motion_exp_input[1:],
|
1260 |
+
"motion_pose": motion_pose_input[1:],
|
1261 |
+
"fps": fps_input[1:],
|
1262 |
+
"audio_prompts": audio_prompts_input[1:],
|
1263 |
+
"face_mask": face_masks_input[1:]
|
1264 |
+
})
|
1265 |
+
noise_pred_text = self.transformer(latent_model_input[1:], t_expand[1:], ref_latents=ref_latents[1:], text_states=prompt_embeds_input[1:], text_mask=prompt_mask_input[1:], text_states_2=prompt_embeds_2_input[1:], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, x_id = 1, **additional_kwargs,)
|
1266 |
+
if self._interrupt:
|
1267 |
+
return [None]
|
1268 |
+
noise_pred_text = noise_pred_text[0]
|
1269 |
+
|
1270 |
+
# perform guidance
|
1271 |
+
if self.do_classifier_free_guidance:
|
1272 |
+
if cfg_star_rescale:
|
1273 |
+
batch_size = 1
|
1274 |
+
positive_flat = noise_pred_text.view(batch_size, -1)
|
1275 |
+
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
1276 |
+
dot_product = torch.sum(
|
1277 |
+
positive_flat * negative_flat, dim=1, keepdim=True
|
1278 |
+
)
|
1279 |
+
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
|
1280 |
+
positive_flat, negative_flat = None, None
|
1281 |
+
alpha = dot_product / squared_norm
|
1282 |
+
noise_pred_uncond *= alpha
|
1283 |
+
|
1284 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1285 |
+
|
1286 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1287 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1288 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
1289 |
+
noise_pred_text, noise_pred_uncond = None, None
|
1290 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1291 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1292 |
+
noise_pred = None
|
1293 |
+
|
1294 |
+
latents = latents.to(torch.bfloat16)
|
1295 |
+
for iii in range(frames_per_batch):
|
1296 |
+
p = (index_start + iii) % pred_latents.shape[2]
|
1297 |
+
pred_latents[:, :, p] += latents[:, :, iii]
|
1298 |
+
counter[:, :, p] += 1
|
1299 |
+
|
1300 |
+
if transformer.enable_teacache and cache_size > 1:
|
1301 |
+
for l in range(latent_items):
|
1302 |
+
if transformer.previous_residual[l] != None:
|
1303 |
+
bsz = transformer.previous_residual[l].shape[0]
|
1304 |
+
if cache_all_previous_residual[l] == None:
|
1305 |
+
cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=transformer.previous_residual[l].device, dtype=transformer.previous_residual[l].dtype)
|
1306 |
+
cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = transformer.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw)
|
1307 |
+
|
1308 |
+
if transformer.previous_modulated_input != None:
|
1309 |
+
if cache_all_previous_modulated_input == None:
|
1310 |
+
cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=transformer.previous_modulated_input.device, dtype=transformer.previous_modulated_input.dtype)
|
1311 |
+
cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = transformer.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw)
|
1312 |
+
cache_should_calc[cache_slot_no] = transformer.should_calc
|
1313 |
+
cache_accumulated_rel_l1_distance[cache_slot_no] = transformer.accumulated_rel_l1_distance
|
1314 |
+
cache_teacache_skipped_steps[cache_slot_no] = transformer.teacache_skipped_steps
|
1315 |
+
|
1316 |
+
cache_slot_no += 1
|
1317 |
+
|
1318 |
+
shift += shift_offset
|
1319 |
+
shift = shift % frames_per_batch
|
1320 |
+
pred_latents = pred_latents / counter
|
1321 |
+
latents_all = pred_latents
|
1322 |
+
|
1323 |
+
if callback is not None:
|
1324 |
+
callback(i, latents_all.squeeze(0), False)
|
1325 |
+
|
1326 |
+
latents = latents_all.float()[:, :, :video_length]
|
1327 |
+
|
1328 |
+
if not output_type == "latent":
|
1329 |
+
expand_temporal_dim = False
|
1330 |
+
if len(latents.shape) == 4:
|
1331 |
+
if isinstance(self.vae, AutoencoderKLCausal3D):
|
1332 |
+
latents = latents.unsqueeze(2)
|
1333 |
+
expand_temporal_dim = True
|
1334 |
+
elif len(latents.shape) == 5:
|
1335 |
+
pass
|
1336 |
+
else:
|
1337 |
+
raise ValueError(
|
1338 |
+
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
|
1339 |
+
|
1340 |
+
if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
|
1341 |
+
latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
1342 |
+
else:
|
1343 |
+
latents = latents / self.vae.config.scaling_factor
|
1344 |
+
|
1345 |
+
with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
|
1346 |
+
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
1347 |
+
if image is None:
|
1348 |
+
return (None, )
|
1349 |
+
|
1350 |
+
if expand_temporal_dim or image.shape[2] == 1:
|
1351 |
+
image = image.squeeze(2)
|
1352 |
+
|
1353 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
1354 |
+
image = image.cpu().float()
|
1355 |
+
|
1356 |
+
# Offload all models
|
1357 |
+
self.maybe_free_model_hooks()
|
1358 |
+
|
1359 |
+
if not return_dict:
|
1360 |
+
return image
|
1361 |
+
|
1362 |
+
return HunyuanVideoPipelineOutput(videos=image)
|
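For reference, the `cfg_star_rescale` branch in the denoising loop above can be summarized as a small standalone helper. This is a minimal sketch; the name `cfg_star_combine` is illustrative and not part of the pipeline:

import torch

def cfg_star_combine(noise_uncond: torch.Tensor,
                     noise_text: torch.Tensor,
                     guidance_scale: float) -> torch.Tensor:
    # Rescale the unconditional branch by the projection coefficient of the
    # text prediction onto it (the `alpha` factor computed in the loop above),
    # then apply the usual classifier-free-guidance extrapolation.
    b = noise_text.shape[0]
    pos = noise_text.reshape(b, -1)
    neg = noise_uncond.reshape(b, -1)
    alpha = (pos * neg).sum(dim=1, keepdim=True) / (neg.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    noise_uncond = noise_uncond * alpha.view(b, *([1] * (noise_uncond.ndim - 1)))
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

Rescaling the unconditional prediction keeps both branches at a comparable magnitude before extrapolation, which is why `noise_pred_uncond *= alpha` is applied before the `guidance_scale` mix in the loop above.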
hyvideo/diffusion/schedulers/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py
ADDED
@@ -0,0 +1,255 @@
1 |
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.utils import BaseOutput, logging
|
28 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
|
34 |
+
"""
|
35 |
+
Output class for the scheduler's `step` function output.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
39 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
40 |
+
denoising loop.
|
41 |
+
"""
|
42 |
+
|
43 |
+
prev_sample: torch.FloatTensor
|
44 |
+
|
45 |
+
|
46 |
+
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
47 |
+
"""
|
48 |
+
Euler scheduler.
|
49 |
+
|
50 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
51 |
+
methods the library implements for all schedulers such as loading and saving.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
num_train_timesteps (`int`, defaults to 1000):
|
55 |
+
The number of diffusion steps to train the model.
|
56 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
57 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
58 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
59 |
+
shift (`float`, defaults to 1.0):
|
60 |
+
The shift value for the timestep schedule.
|
61 |
+
reverse (`bool`, defaults to `True`):
|
62 |
+
Whether to reverse the timestep schedule.
|
63 |
+
"""
|
64 |
+
|
65 |
+
_compatibles = []
|
66 |
+
order = 1
|
67 |
+
|
68 |
+
@register_to_config
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
num_train_timesteps: int = 1000,
|
72 |
+
shift: float = 1.0,
|
73 |
+
reverse: bool = True,
|
74 |
+
solver: str = "euler",
|
75 |
+
n_tokens: Optional[int] = None,
|
76 |
+
):
|
77 |
+
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
|
78 |
+
|
79 |
+
if not reverse:
|
80 |
+
sigmas = sigmas.flip(0)
|
81 |
+
|
82 |
+
self.sigmas = sigmas
|
83 |
+
# the value fed to model
|
84 |
+
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
|
85 |
+
|
86 |
+
self._step_index = None
|
87 |
+
self._begin_index = None
|
88 |
+
|
89 |
+
self.supported_solver = ["euler"]
|
90 |
+
if solver not in self.supported_solver:
|
91 |
+
raise ValueError(
|
92 |
+
f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
|
93 |
+
)
|
94 |
+
|
95 |
+
@property
|
96 |
+
def step_index(self):
|
97 |
+
"""
|
98 |
+
The index counter for the current timestep. It increases by 1 after each scheduler step.
|
99 |
+
"""
|
100 |
+
return self._step_index
|
101 |
+
|
102 |
+
@property
|
103 |
+
def begin_index(self):
|
104 |
+
"""
|
105 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
106 |
+
"""
|
107 |
+
return self._begin_index
|
108 |
+
|
109 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
110 |
+
def set_begin_index(self, begin_index: int = 0):
|
111 |
+
"""
|
112 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
begin_index (`int`):
|
116 |
+
The begin index for the scheduler.
|
117 |
+
"""
|
118 |
+
self._begin_index = begin_index
|
119 |
+
|
120 |
+
def _sigma_to_t(self, sigma):
|
121 |
+
return sigma * self.config.num_train_timesteps
|
122 |
+
|
123 |
+
def set_timesteps(
|
124 |
+
self,
|
125 |
+
num_inference_steps: int,
|
126 |
+
device: Union[str, torch.device] = None,
|
127 |
+
n_tokens: int = None,
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
131 |
+
|
132 |
+
Args:
|
133 |
+
num_inference_steps (`int`):
|
134 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
135 |
+
device (`str` or `torch.device`, *optional*):
|
136 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
137 |
+
n_tokens (`int`, *optional*):
|
138 |
+
Number of tokens in the input sequence.
|
139 |
+
"""
|
140 |
+
self.num_inference_steps = num_inference_steps
|
141 |
+
|
142 |
+
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
|
143 |
+
sigmas = self.sd3_time_shift(sigmas)
|
144 |
+
|
145 |
+
if not self.config.reverse:
|
146 |
+
sigmas = 1 - sigmas
|
147 |
+
|
148 |
+
self.sigmas = sigmas
|
149 |
+
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
|
150 |
+
dtype=torch.float32, device=device
|
151 |
+
)
|
152 |
+
|
153 |
+
# Reset step index
|
154 |
+
self._step_index = None
|
155 |
+
|
156 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
157 |
+
if schedule_timesteps is None:
|
158 |
+
schedule_timesteps = self.timesteps
|
159 |
+
|
160 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
161 |
+
|
162 |
+
# The sigma index that is taken for the **very** first `step`
|
163 |
+
# is always the second index (or the last index if there is only 1)
|
164 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
165 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
166 |
+
pos = 1 if len(indices) > 1 else 0
|
167 |
+
|
168 |
+
return indices[pos].item()
|
169 |
+
|
170 |
+
def _init_step_index(self, timestep):
|
171 |
+
if self.begin_index is None:
|
172 |
+
if isinstance(timestep, torch.Tensor):
|
173 |
+
timestep = timestep.to(self.timesteps.device)
|
174 |
+
self._step_index = self.index_for_timestep(timestep)
|
175 |
+
else:
|
176 |
+
self._step_index = self._begin_index
|
177 |
+
|
178 |
+
def scale_model_input(
|
179 |
+
self, sample: torch.Tensor, timestep: Optional[int] = None
|
180 |
+
) -> torch.Tensor:
|
181 |
+
return sample
|
182 |
+
|
183 |
+
def sd3_time_shift(self, t: torch.Tensor):
|
184 |
+
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
|
185 |
+
|
186 |
+
def step(
|
187 |
+
self,
|
188 |
+
model_output: torch.FloatTensor,
|
189 |
+
timestep: Union[float, torch.FloatTensor],
|
190 |
+
sample: torch.FloatTensor,
|
191 |
+
return_dict: bool = True,
|
192 |
+
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
|
193 |
+
"""
|
194 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
195 |
+
process from the learned model outputs (most often the predicted noise).
|
196 |
+
|
197 |
+
Args:
|
198 |
+
model_output (`torch.FloatTensor`):
|
199 |
+
The direct output from learned diffusion model.
|
200 |
+
timestep (`float`):
|
201 |
+
The current discrete timestep in the diffusion chain.
|
202 |
+
sample (`torch.FloatTensor`):
|
203 |
+
A current instance of a sample created by the diffusion process.
|
204 |
+
generator (`torch.Generator`, *optional*):
|
205 |
+
A random number generator.
|
206 |
+
n_tokens (`int`, *optional*):
|
207 |
+
Number of tokens in the input sequence.
|
208 |
+
return_dict (`bool`):
|
209 |
+
Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or
|
210 |
+
tuple.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
[`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
|
214 |
+
If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
|
215 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
216 |
+
"""
|
217 |
+
|
218 |
+
if (
|
219 |
+
isinstance(timestep, int)
|
220 |
+
or isinstance(timestep, torch.IntTensor)
|
221 |
+
or isinstance(timestep, torch.LongTensor)
|
222 |
+
):
|
223 |
+
raise ValueError(
|
224 |
+
(
|
225 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
226 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
227 |
+
" one of the `scheduler.timesteps` as a timestep."
|
228 |
+
),
|
229 |
+
)
|
230 |
+
|
231 |
+
if self.step_index is None:
|
232 |
+
self._init_step_index(timestep)
|
233 |
+
|
234 |
+
# Upcast to avoid precision issues when computing prev_sample
|
235 |
+
sample = sample.to(torch.float32)
|
236 |
+
|
237 |
+
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
|
238 |
+
|
239 |
+
if self.config.solver == "euler":
|
240 |
+
prev_sample = sample + model_output.to(torch.float32) * dt
|
241 |
+
else:
|
242 |
+
raise ValueError(
|
243 |
+
f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
|
244 |
+
)
|
245 |
+
|
246 |
+
# upon completion increase step index by one
|
247 |
+
self._step_index += 1
|
248 |
+
|
249 |
+
if not return_dict:
|
250 |
+
return (prev_sample,)
|
251 |
+
|
252 |
+
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
|
253 |
+
|
254 |
+
def __len__(self):
|
255 |
+
return self.config.num_train_timesteps
|
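The scheduler defined above is self-contained, so a minimal usage sketch looks as follows. The constant-zero `fake_denoiser` and the `shift=7.0` value are placeholders chosen for illustration, not values taken from this repository:

import torch
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler

def fake_denoiser(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the video transformer: predicts a zero velocity field.
    return torch.zeros_like(x)

scheduler = FlowMatchDiscreteScheduler(shift=7.0, solver="euler")
scheduler.set_timesteps(num_inference_steps=30, device="cpu")

# (batch, latent channels, latent frames, height, width)
sample = torch.randn(1, 16, 9, 32, 32)
for t in scheduler.timesteps:
    velocity = fake_denoiser(sample, t)
    # One Euler update: x <- x + v * (sigma_{i+1} - sigma_i)
    sample = scheduler.step(velocity, t, sample, return_dict=False)[0]

`set_timesteps` applies the SD3-style shift to the sigma schedule and `step` integrates the predicted velocity against the sigma difference, so the loop is an explicit Euler solve of the flow-matching ODE.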
hyvideo/hunyuan.py
ADDED
@@ -0,0 +1,1062 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import random
|
4 |
+
import functools
|
5 |
+
from typing import List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
from einops import rearrange
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
|
12 |
+
from hyvideo.vae import load_vae
|
13 |
+
from hyvideo.modules import load_model
|
14 |
+
from hyvideo.text_encoder import TextEncoder
|
15 |
+
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
|
16 |
+
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
|
17 |
+
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
|
18 |
+
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
|
19 |
+
from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline
|
20 |
+
from PIL import Image
|
21 |
+
import numpy as np
|
22 |
+
import torchvision.transforms as transforms
|
23 |
+
import cv2
|
24 |
+
from wan.utils.utils import resize_lanczos, calculate_new_dimensions
|
25 |
+
from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask
|
26 |
+
from transformers import WhisperModel
|
27 |
+
from transformers import AutoFeatureExtractor
|
28 |
+
from hyvideo.data_kits.face_align import AlignImage
|
29 |
+
import librosa
|
30 |
+
|
31 |
+
def get_audio_feature(feature_extractor, audio_path, duration):
|
32 |
+
audio_input, sampling_rate = librosa.load(audio_path, duration=duration, sr=16000)
|
33 |
+
assert sampling_rate == 16000
|
34 |
+
|
35 |
+
audio_features = []
|
36 |
+
window = 750*640
|
37 |
+
for i in range(0, len(audio_input), window):
|
38 |
+
audio_feature = feature_extractor(audio_input[i:i+window],
|
39 |
+
sampling_rate=sampling_rate,
|
40 |
+
return_tensors="pt",
|
41 |
+
device="cuda"
|
42 |
+
).input_features
|
43 |
+
audio_features.append(audio_feature)
|
44 |
+
|
45 |
+
audio_features = torch.cat(audio_features, dim=-1)
|
46 |
+
return audio_features, len(audio_input) // 640
|
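# Illustrative note (not part of the original file): get_audio_feature above assumes
# 16 kHz audio with 640 samples per video frame (~25 fps), processed in
# 750 * 640 = 480,000-sample (30 s) windows. A minimal sketch of that arithmetic,
# with hypothetical names:
SAMPLE_RATE = 16_000                        # matches librosa.load(..., sr=16000)
SAMPLES_PER_FRAME = 640                     # len(audio_input) // 640 -> number of frames
FEATURE_WINDOW = 750 * SAMPLES_PER_FRAME    # 480,000 samples, i.e. 30 s per chunk

def frames_for_duration(seconds: float) -> int:
    # ~25 video frames per second of audio at 16 kHz.
    return int(seconds * SAMPLE_RATE) // SAMPLES_PER_FRAME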
47 |
+
|
48 |
+
def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
|
49 |
+
crop_h, crop_w = crop_img.shape[:2]
|
50 |
+
target_w, target_h = size
|
51 |
+
scale_h, scale_w = target_h / crop_h, target_w / crop_w
|
52 |
+
if scale_w > scale_h:
|
53 |
+
resize_h = int(target_h*resize_ratio)
|
54 |
+
resize_w = int(crop_w / crop_h * resize_h)
|
55 |
+
else:
|
56 |
+
resize_w = int(target_w*resize_ratio)
|
57 |
+
resize_h = int(crop_h / crop_w * resize_w)
|
58 |
+
crop_img = cv2.resize(crop_img, (resize_w, resize_h))
|
59 |
+
pad_left = (target_w - resize_w) // 2
|
60 |
+
pad_top = (target_h - resize_h) // 2
|
61 |
+
pad_right = target_w - resize_w - pad_left
|
62 |
+
pad_bottom = target_h - resize_h - pad_top
|
63 |
+
crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
|
64 |
+
return crop_img
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
70 |
+
num_images, num_image_patches, embed_dim = image_features.shape
|
71 |
+
batch_size, sequence_length = input_ids.shape
|
72 |
+
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
73 |
+
# 1. Create a mask to know where special image tokens are
|
74 |
+
special_image_token_mask = input_ids == self.config.image_token_index
|
75 |
+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
76 |
+
# Compute the maximum embed dimension
|
77 |
+
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
78 |
+
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
79 |
+
|
80 |
+
# 2. Compute the positions where text should be written
|
81 |
+
# Calculate new positions for text tokens in merged image-text sequence.
|
82 |
+
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
83 |
+
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
84 |
+
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
85 |
+
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
86 |
+
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
87 |
+
if left_padding:
|
88 |
+
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
89 |
+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
90 |
+
|
91 |
+
# 3. Create the full embedding, already padded to the maximum position
|
92 |
+
final_embedding = torch.zeros(
|
93 |
+
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
94 |
+
)
|
95 |
+
final_attention_mask = torch.zeros(
|
96 |
+
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
97 |
+
)
|
98 |
+
if labels is not None:
|
99 |
+
final_labels = torch.full(
|
100 |
+
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
101 |
+
)
|
102 |
+
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
103 |
+
# set the corresponding tensors into their correct target device.
|
104 |
+
target_device = inputs_embeds.device
|
105 |
+
batch_indices, non_image_indices, text_to_overwrite = (
|
106 |
+
batch_indices.to(target_device),
|
107 |
+
non_image_indices.to(target_device),
|
108 |
+
text_to_overwrite.to(target_device),
|
109 |
+
)
|
110 |
+
attention_mask = attention_mask.to(target_device)
|
111 |
+
|
112 |
+
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
113 |
+
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
114 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
115 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
116 |
+
if labels is not None:
|
117 |
+
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
118 |
+
|
119 |
+
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
120 |
+
image_to_overwrite = torch.full(
|
121 |
+
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
122 |
+
)
|
123 |
+
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
124 |
+
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
125 |
+
|
126 |
+
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
127 |
+
raise ValueError(
|
128 |
+
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
129 |
+
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
130 |
+
)
|
131 |
+
|
132 |
+
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
133 |
+
final_attention_mask |= image_to_overwrite
|
134 |
+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
135 |
+
|
136 |
+
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
137 |
+
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
138 |
+
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
139 |
+
|
140 |
+
final_embedding[batch_indices, indices_to_mask] = 0
|
141 |
+
|
142 |
+
if labels is None:
|
143 |
+
final_labels = None
|
144 |
+
|
145 |
+
return final_embedding, final_attention_mask, final_labels, position_ids
|
146 |
+
|
147 |
+
def patched_llava_forward(
|
148 |
+
self,
|
149 |
+
input_ids: torch.LongTensor = None,
|
150 |
+
pixel_values: torch.FloatTensor = None,
|
151 |
+
attention_mask: Optional[torch.Tensor] = None,
|
152 |
+
position_ids: Optional[torch.LongTensor] = None,
|
153 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
154 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
155 |
+
vision_feature_layer: Optional[int] = None,
|
156 |
+
vision_feature_select_strategy: Optional[str] = None,
|
157 |
+
labels: Optional[torch.LongTensor] = None,
|
158 |
+
use_cache: Optional[bool] = None,
|
159 |
+
output_attentions: Optional[bool] = None,
|
160 |
+
output_hidden_states: Optional[bool] = None,
|
161 |
+
return_dict: Optional[bool] = None,
|
162 |
+
cache_position: Optional[torch.LongTensor] = None,
|
163 |
+
num_logits_to_keep: int = 0,
|
164 |
+
):
|
165 |
+
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
|
166 |
+
|
167 |
+
|
168 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
169 |
+
output_hidden_states = (
|
170 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
171 |
+
)
|
172 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
173 |
+
vision_feature_layer = (
|
174 |
+
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
175 |
+
)
|
176 |
+
vision_feature_select_strategy = (
|
177 |
+
vision_feature_select_strategy
|
178 |
+
if vision_feature_select_strategy is not None
|
179 |
+
else self.config.vision_feature_select_strategy
|
180 |
+
)
|
181 |
+
|
182 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
183 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
184 |
+
|
185 |
+
if pixel_values is not None and inputs_embeds is not None:
|
186 |
+
raise ValueError(
|
187 |
+
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
188 |
+
)
|
189 |
+
|
190 |
+
if inputs_embeds is None:
|
191 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
192 |
+
|
193 |
+
image_features = None
|
194 |
+
if pixel_values is not None:
|
195 |
+
image_features = self.get_image_features(
|
196 |
+
pixel_values=pixel_values,
|
197 |
+
vision_feature_layer=vision_feature_layer,
|
198 |
+
vision_feature_select_strategy=vision_feature_select_strategy,
|
199 |
+
)
|
200 |
+
|
201 |
+
|
202 |
+
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
203 |
+
image_features, inputs_embeds, input_ids, attention_mask, labels
|
204 |
+
)
|
205 |
+
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
206 |
+
|
207 |
+
|
208 |
+
outputs = self.language_model(
|
209 |
+
attention_mask=attention_mask,
|
210 |
+
position_ids=position_ids,
|
211 |
+
past_key_values=past_key_values,
|
212 |
+
inputs_embeds=inputs_embeds,
|
213 |
+
use_cache=use_cache,
|
214 |
+
output_attentions=output_attentions,
|
215 |
+
output_hidden_states=output_hidden_states,
|
216 |
+
return_dict=return_dict,
|
217 |
+
cache_position=cache_position,
|
218 |
+
num_logits_to_keep=num_logits_to_keep,
|
219 |
+
)
|
220 |
+
|
221 |
+
logits = outputs[0]
|
222 |
+
|
223 |
+
loss = None
|
224 |
+
|
225 |
+
if not return_dict:
|
226 |
+
output = (logits,) + outputs[1:]
|
227 |
+
return (loss,) + output if loss is not None else output
|
228 |
+
|
229 |
+
return LlavaCausalLMOutputWithPast(
|
230 |
+
loss=loss,
|
231 |
+
logits=logits,
|
232 |
+
past_key_values=outputs.past_key_values,
|
233 |
+
hidden_states=outputs.hidden_states,
|
234 |
+
attentions=outputs.attentions,
|
235 |
+
image_hidden_states=image_features if pixel_values is not None else None,
|
236 |
+
)
|
237 |
+
|
238 |
+
def adapt_model(model, audio_block_name):
|
239 |
+
modules_dict= { k: m for k, m in model.named_modules()}
|
240 |
+
for model_layer, avatar_layer in model.double_stream_map.items():
|
241 |
+
module = modules_dict[f"{audio_block_name}.{avatar_layer}"]
|
242 |
+
target = modules_dict[f"double_blocks.{model_layer}"]
|
243 |
+
setattr(target, "audio_adapter", module )
|
244 |
+
delattr(model, audio_block_name)
|
245 |
+
|
246 |
+
class DataPreprocess(object):
|
247 |
+
def __init__(self):
|
248 |
+
self.llava_size = (336, 336)
|
249 |
+
self.llava_transform = transforms.Compose(
|
250 |
+
[
|
251 |
+
transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
|
252 |
+
transforms.ToTensor(),
|
253 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
|
254 |
+
]
|
255 |
+
)
|
256 |
+
|
257 |
+
def get_batch(self, image , size, pad = False):
|
258 |
+
image = np.asarray(image)
|
259 |
+
if pad:
|
260 |
+
llava_item_image = pad_image(image.copy(), self.llava_size)
|
261 |
+
else:
|
262 |
+
llava_item_image = image.copy()
|
263 |
+
uncond_llava_item_image = np.ones_like(llava_item_image) * 255
|
264 |
+
|
265 |
+
if pad:
|
266 |
+
cat_item_image = pad_image(image.copy(), size)
|
267 |
+
else:
|
268 |
+
cat_item_image = image.copy()
|
269 |
+
llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
|
270 |
+
uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
|
271 |
+
cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
|
272 |
+
# batch = {
|
273 |
+
# "pixel_value_llava": llava_item_tensor.unsqueeze(0),
|
274 |
+
# "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
|
275 |
+
# 'pixel_value_ref': cat_item_tensor.unsqueeze(0),
|
276 |
+
# }
|
277 |
+
return llava_item_tensor.unsqueeze(0), uncond_llava_item_tensor.unsqueeze(0), cat_item_tensor.unsqueeze(0)
|
278 |
+
|
279 |
+
class Inference(object):
|
280 |
+
def __init__(
|
281 |
+
self,
|
282 |
+
i2v,
|
283 |
+
custom,
|
284 |
+
avatar,
|
285 |
+
enable_cfg,
|
286 |
+
vae,
|
287 |
+
vae_kwargs,
|
288 |
+
text_encoder,
|
289 |
+
model,
|
290 |
+
text_encoder_2=None,
|
291 |
+
pipeline=None,
|
292 |
+
feature_extractor=None,
|
293 |
+
wav2vec=None,
|
294 |
+
align_instance=None,
|
295 |
+
device=None,
|
296 |
+
):
|
297 |
+
self.i2v = i2v
|
298 |
+
self.custom = custom
|
299 |
+
self.avatar = avatar
|
300 |
+
self.enable_cfg = enable_cfg
|
301 |
+
self.vae = vae
|
302 |
+
self.vae_kwargs = vae_kwargs
|
303 |
+
|
304 |
+
self.text_encoder = text_encoder
|
305 |
+
self.text_encoder_2 = text_encoder_2
|
306 |
+
|
307 |
+
self.model = model
|
308 |
+
self.pipeline = pipeline
|
309 |
+
|
310 |
+
self.feature_extractor=feature_extractor
|
311 |
+
self.wav2vec=wav2vec
|
312 |
+
self.align_instance=align_instance
|
313 |
+
|
314 |
+
self.device = "cuda"
|
315 |
+
|
316 |
+
|
317 |
+
@classmethod
|
318 |
+
def from_pretrained(cls, model_filepath, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , **kwargs):
|
319 |
+
|
320 |
+
device = "cuda"
|
321 |
+
|
322 |
+
import transformers
|
323 |
+
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour so that transformers versions > 4.47 can still be used
|
324 |
+
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features
|
325 |
+
|
326 |
+
torch.set_grad_enabled(False)
|
327 |
+
text_len = 512
|
328 |
+
latent_channels = 16
|
329 |
+
precision = "bf16"
|
330 |
+
vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
|
331 |
+
embedded_cfg_scale = 6
|
332 |
+
filepath = model_filepath[0]
|
333 |
+
i2v_condition_type = None
|
334 |
+
i2v_mode = "i2v" in filepath
|
335 |
+
custom = False
|
336 |
+
custom_audio = False
|
337 |
+
avatar = False
|
338 |
+
if i2v_mode:
|
339 |
+
model_id = "HYVideo-T/2"
|
340 |
+
i2v_condition_type = "token_replace"
|
341 |
+
elif "custom" in filepath:
|
342 |
+
if "audio" in filepath:
|
343 |
+
model_id = "HYVideo-T/2-custom-audio"
|
344 |
+
custom_audio = True
|
345 |
+
elif "edit" in filepath:
|
346 |
+
model_id = "HYVideo-T/2-custom-edit"
|
347 |
+
else:
|
348 |
+
model_id = "HYVideo-T/2-custom"
|
349 |
+
custom = True
|
350 |
+
elif "avatar" in filepath :
|
351 |
+
model_id = "HYVideo-T/2-avatar"
|
352 |
+
text_len = 256
|
353 |
+
avatar = True
|
354 |
+
else:
|
355 |
+
model_id = "HYVideo-T/2-cfgdistill"
|
356 |
+
|
357 |
+
|
358 |
+
        if i2v_mode and i2v_condition_type == "latent_concat":
            in_channels = latent_channels * 2 + 1
            image_embed_interleave = 2
        elif i2v_mode and i2v_condition_type == "token_replace":
            in_channels = latent_channels
            image_embed_interleave = 4
        else:
            in_channels = latent_channels
            image_embed_interleave = 1
        out_channels = latent_channels
        pinToMemory = kwargs.pop("pinToMemory", False)
        partialPinning = kwargs.pop("partialPinning", False)
        factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[precision]}

        if embedded_cfg_scale and i2v_mode:
            factor_kwargs["guidance_embed"] = True

        model = load_model(
            model=model_id,
            i2v_condition_type=i2v_condition_type,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

        from mmgp import offload
        # model = Inference.load_state_dict(args, model, model_filepath)
        # model_filepath = "c:/temp/hc/mp_rank_00_model_states_video.pt"
        offload.load_model_data(model, model_filepath, pinToMemory=pinToMemory, partialPinning=partialPinning)
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize=True)

        model.mixed_precision = mixed_precision_transformer

        if model.mixed_precision:
            model._lock_dtype = torch.float32
            model.lock_layers_dtypes(torch.float32)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        if custom or avatar:
            vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors"
        # elif avatar:
        #     vae_configpath = "ckpts/config_vae_avatar.json"
        #     vae_filepath = "ckpts/vae_avatar.pt"
        else:
            vae_configpath = "ckpts/hunyuan_video_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors"

        # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json")
        # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json")

        vae, _, s_ratio, t_ratio = load_vae(
            "884-16c-hy",
            vae_path=vae_filepath,
            vae_config_path=vae_configpath,
            vae_precision=vae_precision,
            device="cpu",
        )

        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else (torch.float16 if avatar else torch.bfloat16)
        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
        enable_cfg = False
        # Text encoder
        if i2v_mode:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode-i2v"
            prompt_template_video = "dit-llm-encode-video-i2v"
        elif custom or avatar:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"
            enable_cfg = True
        else:
            text_encoder = "llm"
            tokenizer = "llm"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"

        if prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
        elif prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = text_len + crop_start

        # prompt_template
        prompt_template = PROMPT_TEMPLATE[prompt_template] if prompt_template is not None else None

        # prompt_template_video
        prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] if prompt_template_video is not None else None

        text_encoder = TextEncoder(
            text_encoder_type=text_encoder,
            max_length=max_length,
            text_encoder_precision="fp16",
            tokenizer_type=tokenizer,
            i2v_mode=i2v_mode,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=2,
            apply_final_norm=False,
            reproduce=True,
            device="cpu",
            image_embed_interleave=image_embed_interleave,
            text_encoder_path=text_encoder_filepath,
        )

        text_encoder_2 = TextEncoder(
            text_encoder_type="clipL",
            max_length=77,
            text_encoder_precision="fp16",
            tokenizer_type="clipL",
            reproduce=True,
            device="cpu",
        )

        feature_extractor = None
        wav2vec = None
        align_instance = None

        if avatar or custom_audio:
            feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/")
            wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32)
            wav2vec._model_dtype = torch.float32
            wav2vec.requires_grad_(False)
        if avatar:
            align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
            align_instance.facedet.model.to("cpu")
            adapt_model(model, "audio_adapter_blocks")
        elif custom_audio:
            adapt_model(model, "audio_models")

        return cls(
            i2v=i2v_mode,
            custom=custom,
            avatar=avatar,
            enable_cfg=enable_cfg,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )


class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        i2v,
        custom,
        avatar,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        feature_extractor=None,
        wav2vec=None,
        align_instance=None,
        device=0,
    ):
        super().__init__(
            i2v,
            custom,
            avatar,
            enable_cfg,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )

        self.i2v_mode = i2v
        self.enable_cfg = enable_cfg
        self.pipeline = self.load_diffusion_pipeline(
            avatar=self.avatar,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        if self.i2v_mode:
            self.default_negative_prompt = NEGATIVE_PROMPT_I2V
        else:
            self.default_negative_prompt = NEGATIVE_PROMPT

    @property
    def _interrupt(self):
        return self.pipeline._interrupt

    @_interrupt.setter
    def _interrupt(self, value):
        self.pipeline._interrupt = value

    def load_diffusion_pipeline(
        self,
        avatar,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        # data_type="video",
    ):
        """Load the denoising scheduler for inference."""
        if scheduler is None:
            scheduler = FlowMatchDiscreteScheduler(
                shift=6.0,
                reverse=True,
                solver="euler",
            )

        if avatar:
            pipeline = HunyuanVideoAudioPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )
        else:
            pipeline = HunyuanVideoPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )

        return pipeline

    def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(
            rope_dim_list,
            rope_sizes,
            theta=256,
            use_real=True,
            theta_rescale_factor=1,
            concat_dict=concat_dict,
            L_test=(video_length - 1) // 4 + 1,
            enable_riflex=enable_riflex,
        )
        return freqs_cos, freqs_sin

    def get_rotary_pos_embed(self, video_length, height, width, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        # 884
        vae = "884-16c-hy"
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=256,
            use_real=True,
            theta_rescale_factor=1,
            L_test=(video_length - 1) // 4 + 1,
            enable_riflex=enable_riflex,
        )
        return freqs_cos, freqs_sin

def generate(
|
699 |
+
self,
|
700 |
+
input_prompt,
|
701 |
+
input_ref_images = None,
|
702 |
+
audio_guide = None,
|
703 |
+
input_frames = None,
|
704 |
+
input_masks = None,
|
705 |
+
input_video = None,
|
706 |
+
fps = 24,
|
707 |
+
height=192,
|
708 |
+
width=336,
|
709 |
+
frame_num=129,
|
710 |
+
seed=None,
|
711 |
+
n_prompt=None,
|
712 |
+
sampling_steps=50,
|
713 |
+
guide_scale=1.0,
|
714 |
+
shift=5.0,
|
715 |
+
embedded_guidance_scale=6.0,
|
716 |
+
batch_size=1,
|
717 |
+
num_videos_per_prompt=1,
|
718 |
+
i2v_resolution="720p",
|
719 |
+
image_start=None,
|
720 |
+
enable_RIFLEx = False,
|
721 |
+
i2v_condition_type: str = "token_replace",
|
722 |
+
i2v_stability=True,
|
723 |
+
VAE_tile_size = None,
|
724 |
+
joint_pass = False,
|
725 |
+
cfg_star_switch = False,
|
726 |
+
fit_into_canvas = True,
|
727 |
+
conditioning_latents_size = 0,
|
728 |
+
**kwargs,
|
729 |
+
):
|
730 |
+
|
731 |
+
if VAE_tile_size != None:
|
732 |
+
self.vae.tile_sample_min_tsize = VAE_tile_size["tile_sample_min_tsize"]
|
733 |
+
self.vae.tile_latent_min_tsize = VAE_tile_size["tile_latent_min_tsize"]
|
734 |
+
self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"]
|
735 |
+
self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"]
|
736 |
+
self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"]
|
737 |
+
self.vae.enable_tiling()
|
738 |
+
|
739 |
+
i2v_mode= self.i2v_mode
|
740 |
+
if not self.enable_cfg:
|
741 |
+
guide_scale=1.0
|
742 |
+
|
743 |
+
# ========================================================================
|
744 |
+
# Arguments: seed
|
745 |
+
# ========================================================================
|
746 |
+
if isinstance(seed, torch.Tensor):
|
747 |
+
seed = seed.tolist()
|
748 |
+
if seed is None:
|
749 |
+
seeds = [
|
750 |
+
random.randint(0, 1_000_000)
|
751 |
+
for _ in range(batch_size * num_videos_per_prompt)
|
752 |
+
]
|
753 |
+
elif isinstance(seed, int):
|
754 |
+
seeds = [
|
755 |
+
seed + i
|
756 |
+
for _ in range(batch_size)
|
757 |
+
for i in range(num_videos_per_prompt)
|
758 |
+
]
|
759 |
+
elif isinstance(seed, (list, tuple)):
|
760 |
+
if len(seed) == batch_size:
|
761 |
+
seeds = [
|
762 |
+
int(seed[i]) + j
|
763 |
+
for i in range(batch_size)
|
764 |
+
for j in range(num_videos_per_prompt)
|
765 |
+
]
|
766 |
+
elif len(seed) == batch_size * num_videos_per_prompt:
|
767 |
+
seeds = [int(s) for s in seed]
|
768 |
+
else:
|
769 |
+
raise ValueError(
|
770 |
+
f"Length of seed must be equal to number of prompt(batch_size) or "
|
771 |
+
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
|
772 |
+
)
|
773 |
+
else:
|
774 |
+
raise ValueError(
|
775 |
+
f"Seed must be an integer, a list of integers, or None, got {seed}."
|
776 |
+
)
|
777 |
+
from wan.utils.utils import seed_everything
|
778 |
+
seed_everything(seed)
|
779 |
+
generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
|
780 |
+
# generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
|
781 |
+
|
782 |
+
# ========================================================================
|
783 |
+
# Arguments: target_width, target_height, target_frame_num
|
784 |
+
# ========================================================================
|
785 |
+
if width <= 0 or height <= 0 or frame_num <= 0:
|
786 |
+
raise ValueError(
|
787 |
+
f"`height` and `width` and `frame_num` must be positive integers, got height={height}, width={width}, frame_num={frame_num}"
|
788 |
+
)
|
789 |
+
if (frame_num - 1) % 4 != 0:
|
790 |
+
raise ValueError(
|
791 |
+
f"`frame_num-1` must be a multiple of 4, got {frame_num}"
|
792 |
+
)
|
793 |
+
|
794 |
+
target_height = align_to(height, 16)
|
795 |
+
target_width = align_to(width, 16)
|
796 |
+
target_frame_num = frame_num
|
797 |
+
audio_strength = 1
|
798 |
+
|
799 |
+
if input_ref_images != None:
|
800 |
+
# ip_cfg_scale = 3.0
|
801 |
+
ip_cfg_scale = 0
|
802 |
+
denoise_strength = 1
|
803 |
+
# guide_scale=7.5
|
804 |
+
# shift=13
|
805 |
+
name = "person"
|
806 |
+
input_ref_images = input_ref_images[0]
|
807 |
+
|
808 |
+
# ========================================================================
|
809 |
+
# Arguments: prompt, new_prompt, negative_prompt
|
810 |
+
# ========================================================================
|
811 |
+
if not isinstance(input_prompt, str):
|
812 |
+
raise TypeError(f"`prompt` must be a string, but got {type(input_prompt)}")
|
813 |
+
input_prompt = [input_prompt.strip()]
|
814 |
+
|
815 |
+
# negative prompt
|
816 |
+
if n_prompt is None or n_prompt == "":
|
817 |
+
n_prompt = self.default_negative_prompt
|
818 |
+
if guide_scale == 1.0:
|
819 |
+
n_prompt = ""
|
820 |
+
if not isinstance(n_prompt, str):
|
821 |
+
raise TypeError(
|
822 |
+
f"`negative_prompt` must be a string, but got {type(n_prompt)}"
|
823 |
+
)
|
824 |
+
n_prompt = [n_prompt.strip()]
|
825 |
+
|
826 |
+
# ========================================================================
|
827 |
+
# Scheduler
|
828 |
+
# ========================================================================
|
829 |
+
scheduler = FlowMatchDiscreteScheduler(
|
830 |
+
shift=shift,
|
831 |
+
reverse=True,
|
832 |
+
solver="euler"
|
833 |
+
)
|
834 |
+
self.pipeline.scheduler = scheduler
|
835 |
+
|
836 |
+
# ---------------------------------
|
837 |
+
# Reference condition
|
838 |
+
# ---------------------------------
|
839 |
+
img_latents = None
|
840 |
+
semantic_images = None
|
841 |
+
denoise_strength = 0
|
842 |
+
ip_cfg_scale = 0
|
843 |
+
if i2v_mode:
|
844 |
+
if i2v_resolution == "720p":
|
845 |
+
bucket_hw_base_size = 960
|
846 |
+
elif i2v_resolution == "540p":
|
847 |
+
bucket_hw_base_size = 720
|
848 |
+
elif i2v_resolution == "360p":
|
849 |
+
bucket_hw_base_size = 480
|
850 |
+
else:
|
851 |
+
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
|
852 |
+
|
853 |
+
# semantic_images = [Image.open(i2v_image_path).convert('RGB')]
|
854 |
+
semantic_images = [image_start.convert('RGB')] #
|
855 |
+
origin_size = semantic_images[0].size
|
856 |
+
h, w = origin_size
|
857 |
+
h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
|
858 |
+
closest_size = (w, h)
|
859 |
+
# crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
|
860 |
+
# aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list])
|
861 |
+
# closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
|
862 |
+
ref_image_transform = transforms.Compose([
|
863 |
+
transforms.Resize(closest_size),
|
864 |
+
transforms.CenterCrop(closest_size),
|
865 |
+
transforms.ToTensor(),
|
866 |
+
transforms.Normalize([0.5], [0.5])
|
867 |
+
])
|
868 |
+
|
869 |
+
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
|
870 |
+
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
|
871 |
+
|
872 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
|
873 |
+
img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode() # B, C, F, H, W
|
874 |
+
img_latents.mul_(self.pipeline.vae.config.scaling_factor)
|
875 |
+
|
876 |
+
target_height, target_width = closest_size
|
877 |
+
|
878 |
+
# ========================================================================
|
879 |
+
# Build Rope freqs
|
880 |
+
# ========================================================================
|
881 |
+
|
882 |
+
if input_ref_images == None:
|
883 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
|
884 |
+
else:
|
885 |
+
if self.avatar:
|
886 |
+
w, h = input_ref_images.size
|
887 |
+
target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
|
888 |
+
if target_width != w or target_height != h:
|
889 |
+
input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS)
|
890 |
+
|
891 |
+
concat_dict = {'mode': 'timecat', 'bias': -1}
|
892 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
|
893 |
+
else:
|
894 |
+
if input_frames != None:
|
895 |
+
target_height, target_width = input_frames.shape[-3:-1]
|
896 |
+
elif input_video != None:
|
897 |
+
target_height, target_width = input_video.shape[-2:]
|
898 |
+
|
899 |
+
concat_dict = {'mode': 'timecat-w', 'bias': -1}
|
900 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict, enable_RIFLEx)
|
901 |
+
|
902 |
+
n_tokens = freqs_cos.shape[0]
|
903 |
+
|
904 |
+
callback = kwargs.pop("callback", None)
|
905 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
906 |
+
# ========================================================================
|
907 |
+
# Pipeline inference
|
908 |
+
# ========================================================================
|
909 |
+
|
910 |
+
pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None
|
911 |
+
if input_ref_images == None:
|
912 |
+
name = None
|
913 |
+
else:
|
914 |
+
pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height), pad = self.custom)
|
915 |
+
|
916 |
+
ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None
|
917 |
+
|
918 |
+
|
919 |
+
bg_latents = None
|
920 |
+
if input_video != None:
|
921 |
+
pixel_value_bg = input_video.unsqueeze(0)
|
922 |
+
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
|
923 |
+
if input_frames != None:
|
924 |
+
pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
|
925 |
+
pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
|
926 |
+
pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
|
927 |
+
if input_video != None:
|
928 |
+
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
|
929 |
+
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
|
930 |
+
else:
|
931 |
+
pixel_value_bg = pixel_value_video_bg
|
932 |
+
pixel_value_mask = pixel_value_video_mask
|
933 |
+
pixel_value_video_mask, pixel_value_video_bg = None, None
|
934 |
+
if input_video != None or input_frames != None:
|
935 |
+
if pixel_value_bg.shape[2] < frame_num:
|
936 |
+
padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
|
937 |
+
pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
|
938 |
+
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
|
939 |
+
|
940 |
+
bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
|
941 |
+
pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
|
942 |
+
mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
|
943 |
+
bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
|
944 |
+
bg_latents.mul_(self.vae.config.scaling_factor)
|
945 |
+
|
946 |
+
if self.avatar:
|
947 |
+
if n_prompt == None or len(n_prompt) == 0:
|
948 |
+
n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"
|
949 |
+
|
950 |
+
uncond_pixel_value_llava = pixel_value_llava.clone()
|
951 |
+
|
952 |
+
pixel_value_ref = pixel_value_ref.unsqueeze(0)
|
953 |
+
self.align_instance.facedet.model.to("cuda")
|
954 |
+
face_masks = get_facemask(pixel_value_ref.to("cuda")*255, self.align_instance, area=3.0)
|
955 |
+
# iii = (face_masks.squeeze(0).squeeze(0).permute(1,2,0).repeat(1,1,3)*255).cpu().numpy().astype(np.uint8)
|
956 |
+
# image = Image.fromarray(iii)
|
957 |
+
# image.save("mask.png")
|
958 |
+
# jjj = (pixel_value_ref.squeeze(0).squeeze(0).permute(1,2,0)*255).cpu().numpy().astype(np.uint8)
|
959 |
+
|
960 |
+
self.align_instance.facedet.model.to("cpu")
|
961 |
+
# pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1)
|
962 |
+
|
963 |
+
pixel_value_ref = pixel_value_ref.repeat(1,1+4*2,1,1,1)
|
964 |
+
pixel_value_ref = pixel_value_ref * 2 - 1
|
965 |
+
pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")
|
966 |
+
|
967 |
+
vae_dtype = self.vae.dtype
|
968 |
+
with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
|
969 |
+
ref_latents = self.vae.encode(pixel_value_ref_for_vae).latent_dist.sample()
|
970 |
+
ref_latents = torch.cat( [ref_latents[:,:, :1], ref_latents[:,:, 1:2].repeat(1,1,31,1,1), ref_latents[:,:, -1:]], dim=2)
|
971 |
+
pixel_value_ref, pixel_value_ref_for_vae = None, None
|
972 |
+
|
973 |
+
if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
|
974 |
+
ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
|
975 |
+
else:
|
976 |
+
ref_latents.mul_(self.vae.config.scaling_factor)
|
977 |
+
|
978 |
+
# out_latents= ref_latents / self.vae.config.scaling_factor
|
979 |
+
# image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0]
|
980 |
+
# image = image.clamp(-1, 1)
|
981 |
+
# from wan.utils.utils import cache_video
|
982 |
+
# cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))
|
983 |
+
|
984 |
+
motion_pose = np.array([25] * 4)
|
985 |
+
motion_exp = np.array([30] * 4)
|
986 |
+
motion_pose = torch.from_numpy(motion_pose).unsqueeze(0)
|
987 |
+
motion_exp = torch.from_numpy(motion_exp).unsqueeze(0)
|
988 |
+
|
989 |
+
face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
|
990 |
+
(ref_latents.shape[-2],
|
991 |
+
ref_latents.shape[-1]),
|
992 |
+
mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)
|
993 |
+
|
994 |
+
|
995 |
+
if audio_guide != None:
|
996 |
+
audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration = frame_num/fps )
|
997 |
+
audio_prompts = audio_input[0]
|
998 |
+
weight_dtype = audio_prompts.dtype
|
999 |
+
if self.custom:
|
1000 |
+
audio_len = min(audio_len, frame_num)
|
1001 |
+
audio_input = audio_input[:, :audio_len]
|
1002 |
+
audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len)
|
1003 |
+
audio_prompts = audio_prompts.to(self.model.dtype)
|
1004 |
+
segment_size = 129 if self.avatar else frame_num
|
1005 |
+
if audio_prompts.shape[1] <= segment_size:
|
1006 |
+
audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,segment_size-audio_prompts.shape[1], 1, 1, 1)], dim=1)
|
1007 |
+
else:
|
1008 |
+
audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
|
1009 |
+
uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129])
|
1010 |
+
|
1011 |
+
samples = self.pipeline(
|
1012 |
+
prompt=input_prompt,
|
1013 |
+
height=target_height,
|
1014 |
+
width=target_width,
|
1015 |
+
video_length=target_frame_num,
|
1016 |
+
num_inference_steps=sampling_steps,
|
1017 |
+
guidance_scale=guide_scale,
|
1018 |
+
negative_prompt=n_prompt,
|
1019 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
1020 |
+
generator=generator,
|
1021 |
+
output_type="pil",
|
1022 |
+
name = name,
|
1023 |
+
|
1024 |
+
pixel_value_ref = pixel_value_ref,
|
1025 |
+
ref_latents=ref_latents, # [1, 16, 1, h//8, w//8]
|
1026 |
+
pixel_value_llava=pixel_value_llava, # [1, 3, 336, 336]
|
1027 |
+
uncond_pixel_value_llava=uncond_pixel_value_llava,
|
1028 |
+
face_masks=face_masks, # [b f h w]
|
1029 |
+
audio_prompts=audio_prompts,
|
1030 |
+
uncond_audio_prompts=uncond_audio_prompts,
|
1031 |
+
motion_exp=motion_exp,
|
1032 |
+
motion_pose=motion_pose,
|
1033 |
+
fps= torch.from_numpy(np.array(fps)),
|
1034 |
+
|
1035 |
+
bg_latents = bg_latents,
|
1036 |
+
audio_strength = audio_strength,
|
1037 |
+
|
1038 |
+
denoise_strength=denoise_strength,
|
1039 |
+
ip_cfg_scale=ip_cfg_scale,
|
1040 |
+
freqs_cis=(freqs_cos, freqs_sin),
|
1041 |
+
n_tokens=n_tokens,
|
1042 |
+
embedded_guidance_scale=embedded_guidance_scale,
|
1043 |
+
data_type="video" if target_frame_num > 1 else "image",
|
1044 |
+
is_progress_bar=True,
|
1045 |
+
vae_ver="884-16c-hy",
|
1046 |
+
enable_tiling=True,
|
1047 |
+
i2v_mode=i2v_mode,
|
1048 |
+
i2v_condition_type=i2v_condition_type,
|
1049 |
+
i2v_stability=i2v_stability,
|
1050 |
+
img_latents=img_latents,
|
1051 |
+
semantic_images=semantic_images,
|
1052 |
+
joint_pass = joint_pass,
|
1053 |
+
cfg_star_rescale = cfg_star_switch,
|
1054 |
+
callback = callback,
|
1055 |
+
callback_steps = callback_steps,
|
1056 |
+
)[0]
|
1057 |
+
|
1058 |
+
if samples == None:
|
1059 |
+
return None
|
1060 |
+
samples = samples.squeeze(0)
|
1061 |
+
|
1062 |
+
return samples
|
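For orientation, a minimal usage sketch of the sampler defined above. The `from_pretrained` keyword names and checkpoint paths are assumptions inferred from the variables used in its body (`model_filepath`, `text_encoder_filepath`), not a confirmed signature; per the `generate` method above, the call returns the decoded video tensor (batch dimension squeezed) or None if generation is interrupted.

# Hypothetical usage sketch -- argument names and paths are assumptions, not a confirmed API.
from hyvideo.hunyuan import HunyuanVideoSampler

sampler = HunyuanVideoSampler.from_pretrained(
    model_filepath="ckpts/hunyuan_video_720_bf16.safetensors",   # assumed path
    text_encoder_filepath="ckpts/llava-llama-3-8b",              # assumed path
)
samples = sampler.generate(
    input_prompt="A corgi running on a beach at sunset",
    height=720,
    width=1280,
    frame_num=129,                 # frame_num - 1 must be a multiple of 4
    sampling_steps=50,
    embedded_guidance_scale=6.0,
    seed=42,
)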
hyvideo/modules/__init__.py
ADDED
@@ -0,0 +1,26 @@
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG


def load_model(model, i2v_condition_type, in_channels, out_channels, factor_kwargs):
    """load hunyuan video model

    Args:
        model (str): key of the model configuration in HUNYUAN_VIDEO_CONFIG
        i2v_condition_type (str): image-to-video conditioning type
        in_channels (int): input channels number
        out_channels (int): output channels number
        factor_kwargs (dict): extra kwargs (device, dtype, ...) forwarded to the transformer

    Returns:
        model (nn.Module): The hunyuan video model
    """
    if model in HUNYUAN_VIDEO_CONFIG.keys():
        model = HYVideoDiffusionTransformer(
            i2v_condition_type=i2v_condition_type,
            in_channels=in_channels,
            out_channels=out_channels,
            **HUNYUAN_VIDEO_CONFIG[model],
            **factor_kwargs,
        )
        return model
    else:
        raise NotImplementedError()
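For reference, a sketch of how `load_model` is invoked (compare the call site in `hyvideo/hunyuan.py` above). The config key "HYVideo-T/2" and the channel counts below are illustrative assumptions; the key must exist in HUNYUAN_VIDEO_CONFIG.

# Illustrative only: "HYVideo-T/2" is an assumed HUNYUAN_VIDEO_CONFIG key.
import torch
from hyvideo.modules import load_model

transformer = load_model(
    model="HYVideo-T/2",
    i2v_condition_type="token_replace",
    in_channels=16,      # latent channels (hunyuan.py uses latent_channels * 2 + 1 for latent_concat i2v)
    out_channels=16,
    factor_kwargs={"device": "meta", "dtype": torch.bfloat16},
)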
hyvideo/modules/activation_layers.py
ADDED
@@ -0,0 +1,23 @@
import torch.nn as nn


def get_activation_layer(act_type):
    """get activation layer

    Args:
        act_type (str): the activation type

    Returns:
        Callable: a factory (class or zero-argument lambda) that builds the activation layer
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        # Approximate `tanh` requires torch >= 1.13
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")
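Note that `get_activation_layer` returns a factory rather than an instantiated module, so callers construct the layer themselves:

import torch
from hyvideo.modules.activation_layers import get_activation_layer

act_layer = get_activation_layer("gelu_tanh")   # factory building nn.GELU(approximate="tanh")
act = act_layer()                               # instantiate before use
y = act(torch.randn(4, 8))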
hyvideo/modules/attenion.py
ADDED
@@ -0,0 +1,362 @@
import importlib.metadata
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from importlib.metadata import version

def clear_list(l):
    for i in range(len(l)):
        l[i] = None

try:
    import flash_attn
    from flash_attn.flash_attn_interface import _flash_attn_forward
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None

try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    memory_efficient_attention = None

try:
    from sageattention import sageattn_varlen
    def sageattn_varlen_wrapper(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    ):
        return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
except ImportError:
    sageattn_varlen_wrapper = None

try:
    from sageattention import sageattn
    @torch.compiler.disable()
    def sageattn_wrapper(
        qkv_list,
        attention_length
    ):
        q, k, v = qkv_list
        padding_length = q.shape[1] - attention_length
        q = q[:, :attention_length, :, :]
        k = k[:, :attention_length, :, :]
        v = v[:, :attention_length, :, :]

        o = sageattn(q, k, v, tensor_layout="NHD")
        del q, k, v
        clear_list(qkv_list)

        if padding_length > 0:
            o = torch.cat([o, torch.empty((o.shape[0], padding_length, *o.shape[-2:]), dtype=o.dtype, device=o.device)], 1)

        return o

except ImportError:
    sageattn = None


def get_attention_modes():
    ret = ["sdpa", "auto"]
    if flash_attn != None:
        ret.append("flash")
    if memory_efficient_attention != None:
        ret.append("xformers")
    if sageattn_varlen_wrapper != None:
        ret.append("sage")
    if sageattn != None and version("sageattention").startswith("2"):
        ret.append("sage2")

    return ret


MEMORY_LAYOUT = {
|
84 |
+
"sdpa": (
|
85 |
+
lambda x: x.transpose(1, 2),
|
86 |
+
lambda x: x.transpose(1, 2),
|
87 |
+
),
|
88 |
+
"xformers": (
|
89 |
+
lambda x: x,
|
90 |
+
lambda x: x,
|
91 |
+
),
|
92 |
+
"sage2": (
|
93 |
+
lambda x: x,
|
94 |
+
lambda x: x,
|
95 |
+
),
|
96 |
+
"sage": (
|
97 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
98 |
+
lambda x: x,
|
99 |
+
),
|
100 |
+
"flash": (
|
101 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
102 |
+
lambda x: x,
|
103 |
+
),
|
104 |
+
"torch": (
|
105 |
+
lambda x: x.transpose(1, 2),
|
106 |
+
lambda x: x.transpose(1, 2),
|
107 |
+
),
|
108 |
+
"vanilla": (
|
109 |
+
lambda x: x.transpose(1, 2),
|
110 |
+
lambda x: x.transpose(1, 2),
|
111 |
+
),
|
112 |
+
}
|
113 |
+
|
114 |
+
@torch.compiler.disable()
|
115 |
+
def sdpa_wrapper(
|
116 |
+
qkv_list,
|
117 |
+
attention_length
|
118 |
+
):
|
119 |
+
q,k, v = qkv_list
|
120 |
+
padding_length = q.shape[2] -attention_length
|
121 |
+
q = q[:, :, :attention_length, :]
|
122 |
+
k = k[:, :, :attention_length, :]
|
123 |
+
v = v[:, :, :attention_length, :]
|
124 |
+
|
125 |
+
o = F.scaled_dot_product_attention(
|
126 |
+
q, k, v, attn_mask=None, is_causal=False
|
127 |
+
)
|
128 |
+
del q, k ,v
|
129 |
+
clear_list(qkv_list)
|
130 |
+
|
131 |
+
if padding_length > 0:
|
132 |
+
o = torch.cat([o, torch.empty( (*o.shape[:2], padding_length, o.shape[-1]), dtype= o.dtype, device=o.device ) ], 2)
|
133 |
+
|
134 |
+
return o
|
135 |
+
|
136 |
+
def get_cu_seqlens(text_mask, img_len):
|
137 |
+
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
|
138 |
+
|
139 |
+
Args:
|
140 |
+
text_mask (torch.Tensor): the mask of text
|
141 |
+
img_len (int): the length of image
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
torch.Tensor: the calculated cu_seqlens for flash attention
|
145 |
+
"""
|
146 |
+
batch_size = text_mask.shape[0]
|
147 |
+
text_len = text_mask.sum(dim=1)
|
148 |
+
max_len = text_mask.shape[1] + img_len
|
149 |
+
|
150 |
+
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
|
151 |
+
|
152 |
+
for i in range(batch_size):
|
153 |
+
s = text_len[i] + img_len
|
154 |
+
s1 = i * max_len + s
|
155 |
+
s2 = (i + 1) * max_len
|
156 |
+
cu_seqlens[2 * i + 1] = s1
|
157 |
+
cu_seqlens[2 * i + 2] = s2
|
158 |
+
|
159 |
+
return cu_seqlens
|
160 |
+
|
161 |
+
|
162 |
+
def attention(
|
163 |
+
qkv_list,
|
164 |
+
mode="flash",
|
165 |
+
drop_rate=0,
|
166 |
+
attn_mask=None,
|
167 |
+
causal=False,
|
168 |
+
cu_seqlens_q=None,
|
169 |
+
cu_seqlens_kv=None,
|
170 |
+
max_seqlen_q=None,
|
171 |
+
max_seqlen_kv=None,
|
172 |
+
batch_size=1,
|
173 |
+
):
|
174 |
+
"""
|
175 |
+
Perform QKV self attention.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
179 |
+
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
180 |
+
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
181 |
+
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
182 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
183 |
+
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
184 |
+
(default: None)
|
185 |
+
causal (bool): Whether to use causal attention. (default: False)
|
186 |
+
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
187 |
+
used to index into q.
|
188 |
+
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
189 |
+
used to index into kv.
|
190 |
+
max_seqlen_q (int): The maximum sequence length in the batch of q.
|
191 |
+
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
195 |
+
"""
|
196 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
197 |
+
q , k , v = qkv_list
|
198 |
+
clear_list(qkv_list)
|
199 |
+
del qkv_list
|
200 |
+
padding_length = 0
|
201 |
+
# if attn_mask == None and mode == "sdpa":
|
202 |
+
# padding_length = q.shape[1] - cu_seqlens_q
|
203 |
+
# q = q[:, :cu_seqlens_q, ... ]
|
204 |
+
# k = k[:, :cu_seqlens_kv, ... ]
|
205 |
+
# v = v[:, :cu_seqlens_kv, ... ]
|
206 |
+
|
207 |
+
q = pre_attn_layout(q)
|
208 |
+
k = pre_attn_layout(k)
|
209 |
+
v = pre_attn_layout(v)
|
210 |
+
|
211 |
+
if mode == "torch":
|
212 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
213 |
+
attn_mask = attn_mask.to(q.dtype)
|
214 |
+
x = F.scaled_dot_product_attention(
|
215 |
+
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
216 |
+
)
|
217 |
+
|
218 |
+
elif mode == "sdpa":
|
219 |
+
# if attn_mask is not None and attn_mask.dtype != torch.bool:
|
220 |
+
# attn_mask = attn_mask.to(q.dtype)
|
221 |
+
# x = F.scaled_dot_product_attention(
|
222 |
+
# q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
223 |
+
# )
|
224 |
+
assert attn_mask==None
|
225 |
+
qkv_list = [q, k, v]
|
226 |
+
del q, k , v
|
227 |
+
x = sdpa_wrapper( qkv_list, cu_seqlens_q )
|
228 |
+
|
229 |
+
elif mode == "xformers":
|
230 |
+
x = memory_efficient_attention(
|
231 |
+
q, k, v , attn_bias= attn_mask
|
232 |
+
)
|
233 |
+
|
234 |
+
elif mode == "sage2":
|
235 |
+
qkv_list = [q, k, v]
|
236 |
+
del q, k , v
|
237 |
+
x = sageattn_wrapper(qkv_list, cu_seqlens_q)
|
238 |
+
|
239 |
+
elif mode == "sage":
|
240 |
+
x = sageattn_varlen_wrapper(
|
241 |
+
q,
|
242 |
+
k,
|
243 |
+
v,
|
244 |
+
cu_seqlens_q,
|
245 |
+
cu_seqlens_kv,
|
246 |
+
max_seqlen_q,
|
247 |
+
max_seqlen_kv,
|
248 |
+
)
|
249 |
+
# x with shape [(bxs), a, d]
|
250 |
+
x = x.view(
|
251 |
+
batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
|
252 |
+
) # reshape x to [b, s, a, d]
|
253 |
+
|
254 |
+
elif mode == "flash":
|
255 |
+
x = flash_attn_varlen_func(
|
256 |
+
q,
|
257 |
+
k,
|
258 |
+
v,
|
259 |
+
cu_seqlens_q,
|
260 |
+
cu_seqlens_kv,
|
261 |
+
max_seqlen_q,
|
262 |
+
max_seqlen_kv,
|
263 |
+
)
|
264 |
+
# x with shape [(bxs), a, d]
|
265 |
+
x = x.view(
|
266 |
+
batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
|
267 |
+
) # reshape x to [b, s, a, d]
|
268 |
+
elif mode == "vanilla":
|
269 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
270 |
+
|
271 |
+
b, a, s, _ = q.shape
|
272 |
+
s1 = k.size(2)
|
273 |
+
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
274 |
+
if causal:
|
275 |
+
# Only applied to self attention
|
276 |
+
assert (
|
277 |
+
attn_mask is None
|
278 |
+
), "Causal mask and attn_mask cannot be used together"
|
279 |
+
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
|
280 |
+
diagonal=0
|
281 |
+
)
|
282 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
283 |
+
attn_bias.to(q.dtype)
|
284 |
+
|
285 |
+
if attn_mask is not None:
|
286 |
+
if attn_mask.dtype == torch.bool:
|
287 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
288 |
+
else:
|
289 |
+
attn_bias += attn_mask
|
290 |
+
|
291 |
+
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
|
292 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
293 |
+
attn += attn_bias
|
294 |
+
attn = attn.softmax(dim=-1)
|
295 |
+
attn = torch.dropout(attn, p=drop_rate, train=True)
|
296 |
+
x = attn @ v
|
297 |
+
else:
|
298 |
+
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
299 |
+
|
300 |
+
x = post_attn_layout(x)
|
301 |
+
b, s, a, d = x.shape
|
302 |
+
out = x.reshape(b, s, -1)
|
303 |
+
if padding_length > 0 :
|
304 |
+
out = torch.cat([out, torch.empty( (out.shape[0], padding_length, out.shape[2]), dtype= out.dtype, device=out.device ) ], 1)
|
305 |
+
|
306 |
+
return out
|
307 |
+
|
308 |
+
|
309 |
+
def parallel_attention(
|
310 |
+
hybrid_seq_parallel_attn,
|
311 |
+
q,
|
312 |
+
k,
|
313 |
+
v,
|
314 |
+
img_q_len,
|
315 |
+
img_kv_len,
|
316 |
+
cu_seqlens_q,
|
317 |
+
cu_seqlens_kv
|
318 |
+
):
|
319 |
+
attn1 = hybrid_seq_parallel_attn(
|
320 |
+
None,
|
321 |
+
q[:, :img_q_len, :, :],
|
322 |
+
k[:, :img_kv_len, :, :],
|
323 |
+
v[:, :img_kv_len, :, :],
|
324 |
+
dropout_p=0.0,
|
325 |
+
causal=False,
|
326 |
+
joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
|
327 |
+
joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
|
328 |
+
joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
|
329 |
+
joint_strategy="rear",
|
330 |
+
)
|
331 |
+
if flash_attn.__version__ >= '2.7.0':
|
332 |
+
attn2, *_ = _flash_attn_forward(
|
333 |
+
q[:,cu_seqlens_q[1]:],
|
334 |
+
k[:,cu_seqlens_kv[1]:],
|
335 |
+
v[:,cu_seqlens_kv[1]:],
|
336 |
+
dropout_p=0.0,
|
337 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
338 |
+
causal=False,
|
339 |
+
window_size_left=-1,
|
340 |
+
window_size_right=-1,
|
341 |
+
softcap=0.0,
|
342 |
+
alibi_slopes=None,
|
343 |
+
return_softmax=False,
|
344 |
+
)
|
345 |
+
else:
|
346 |
+
attn2, *_ = _flash_attn_forward(
|
347 |
+
q[:,cu_seqlens_q[1]:],
|
348 |
+
k[:,cu_seqlens_kv[1]:],
|
349 |
+
v[:,cu_seqlens_kv[1]:],
|
350 |
+
dropout_p=0.0,
|
351 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
352 |
+
causal=False,
|
353 |
+
window_size=(-1, -1),
|
354 |
+
softcap=0.0,
|
355 |
+
alibi_slopes=None,
|
356 |
+
return_softmax=False,
|
357 |
+
)
|
358 |
+
attn = torch.cat([attn1, attn2], dim=1)
|
359 |
+
b, s, a, d = attn.shape
|
360 |
+
attn = attn.reshape(b, s, -1)
|
361 |
+
|
362 |
+
return attn
|
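A small sketch of the single-batch "sdpa" path through `attention` above, to make the calling convention concrete. The tensor sizes are arbitrary, and for this mode `cu_seqlens_q` is simply the unpadded sequence length consumed by `sdpa_wrapper`:

import torch
from hyvideo.modules.attenion import attention, get_attention_modes

print(get_attention_modes())   # e.g. ["sdpa", "auto", ...] depending on installed backends

b, s, heads, dim = 1, 256, 24, 128
q, k, v = (torch.randn(b, s, heads, dim) for _ in range(3))
out = attention([q, k, v], mode="sdpa", cu_seqlens_q=s, batch_size=b)
print(out.shape)               # torch.Size([1, 256, 3072]) -- heads folded into the last dim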
hyvideo/modules/audio_adapters.py
ADDED
@@ -0,0 +1,220 @@
1 |
+
"""
|
2 |
+
This module provides the implementation of an Audio Projection Model, which is designed for
|
3 |
+
audio processing tasks. The model takes audio embeddings as input and outputs context tokens
|
4 |
+
that can be used for various downstream applications, such as audio analysis or synthesis.
|
5 |
+
|
6 |
+
The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
|
7 |
+
provides a foundation for building custom models. This implementation includes multiple linear
|
8 |
+
layers with ReLU activation functions and a LayerNorm for normalization.
|
9 |
+
|
10 |
+
Key Features:
|
11 |
+
- Audio embedding input with flexible sequence length and block structure.
|
12 |
+
- Multiple linear layers for feature transformation.
|
13 |
+
- ReLU activation for non-linear transformation.
|
14 |
+
- LayerNorm for stabilizing and speeding up training.
|
15 |
+
- Rearrangement of input embeddings to match the model's expected input shape.
|
16 |
+
- Customizable number of blocks, channels, and context tokens for adaptability.
|
17 |
+
|
18 |
+
The module is structured to be easily integrated into larger systems or used as a standalone
|
19 |
+
component for audio feature extraction and processing.
|
20 |
+
|
21 |
+
Classes:
|
22 |
+
- AudioProjModel: A class representing the audio projection model with configurable parameters.
|
23 |
+
|
24 |
+
Functions:
|
25 |
+
- (none)
|
26 |
+
|
27 |
+
Dependencies:
|
28 |
+
- torch: For tensor operations and neural network components.
|
29 |
+
- diffusers: For the ModelMixin base class.
|
30 |
+
- einops: For tensor rearrangement operations.
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
import torch
|
35 |
+
from diffusers import ModelMixin
|
36 |
+
from einops import rearrange
|
37 |
+
|
38 |
+
import math
|
39 |
+
import torch.nn as nn
|
40 |
+
|
41 |
+
class AudioProjNet2(ModelMixin):
|
42 |
+
"""Audio Projection Model
|
43 |
+
|
44 |
+
This class defines an audio projection model that takes audio embeddings as input
|
45 |
+
and produces context tokens as output. The model is based on the ModelMixin class
|
46 |
+
and consists of multiple linear layers and activation functions. It can be used
|
47 |
+
for various audio processing tasks.
|
48 |
+
|
49 |
+
Attributes:
|
50 |
+
seq_len (int): The length of the audio sequence.
|
51 |
+
blocks (int): The number of blocks in the audio projection model.
|
52 |
+
channels (int): The number of channels in the audio projection model.
|
53 |
+
intermediate_dim (int): The intermediate dimension of the model.
|
54 |
+
context_tokens (int): The number of context tokens in the output.
|
55 |
+
output_dim (int): The output dimension of the context tokens.
|
56 |
+
|
57 |
+
Methods:
|
58 |
+
__init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
|
59 |
+
Initializes the AudioProjModel with the given parameters.
|
60 |
+
forward(self, audio_embeds):
|
61 |
+
Defines the forward pass for the AudioProjModel.
|
62 |
+
Parameters:
|
63 |
+
audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
|
64 |
+
Returns:
|
65 |
+
context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
|
66 |
+
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
seq_len=5,
|
72 |
+
blocks=12, # add a new parameter blocks
|
73 |
+
channels=768, # add a new parameter channels
|
74 |
+
intermediate_dim=512,
|
75 |
+
output_dim=768,
|
76 |
+
context_tokens=4,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
self.seq_len = seq_len
|
81 |
+
self.blocks = blocks
|
82 |
+
self.channels = channels
|
83 |
+
self.input_dim = (
|
84 |
+
seq_len * blocks * channels
|
85 |
+
)
|
86 |
+
self.intermediate_dim = intermediate_dim
|
87 |
+
self.context_tokens = context_tokens
|
88 |
+
self.output_dim = output_dim
|
89 |
+
|
90 |
+
# define multiple linear layers
|
91 |
+
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
|
92 |
+
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
|
93 |
+
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
|
94 |
+
|
95 |
+
self.norm = nn.LayerNorm(output_dim)
|
96 |
+
|
97 |
+
|
98 |
+
def forward(self, audio_embeds):
|
99 |
+
|
100 |
+
video_length = audio_embeds.shape[1]
|
101 |
+
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
102 |
+
batch_size, window_size, blocks, channels = audio_embeds.shape
|
103 |
+
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
104 |
+
|
105 |
+
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
106 |
+
audio_embeds = torch.relu(self.proj2(audio_embeds))
|
107 |
+
|
108 |
+
context_tokens = self.proj3(audio_embeds).reshape(
|
109 |
+
batch_size, self.context_tokens, self.output_dim
|
110 |
+
)
|
111 |
+
context_tokens = self.norm(context_tokens)
|
112 |
+
out_all = rearrange(
|
113 |
+
context_tokens, "(bz f) m c -> bz f m c", f=video_length
|
114 |
+
)
|
115 |
+
|
116 |
+
return out_all
|
117 |
+
|
118 |
+
|
119 |
+
def reshape_tensor(x, heads):
|
120 |
+
bs, length, width = x.shape
|
121 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
122 |
+
x = x.view(bs, length, heads, -1)
|
123 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
124 |
+
x = x.transpose(1, 2)
|
125 |
+
# (bs, n_heads, length, dim_per_head)
|
126 |
+
x = x.reshape(bs, heads, length, -1)
|
127 |
+
return x
|
128 |
+
|
129 |
+
|
130 |
+
class PerceiverAttentionCA(nn.Module):
|
131 |
+
def __init__(self, *, dim=3072, dim_head=1024, heads=33):
|
132 |
+
super().__init__()
|
133 |
+
self.scale = dim_head ** -0.5
|
134 |
+
self.dim_head = dim_head
|
135 |
+
self.heads = heads
|
136 |
+
inner_dim = dim_head #* heads
|
137 |
+
|
138 |
+
self.norm1 = nn.LayerNorm(dim)
|
139 |
+
self.norm2 = nn.LayerNorm(dim)
|
140 |
+
|
141 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
142 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
143 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
144 |
+
|
145 |
+
import torch.nn.init as init
|
146 |
+
init.zeros_(self.to_out.weight)
|
147 |
+
if self.to_out.bias is not None:
|
148 |
+
init.zeros_(self.to_out.bias)
|
149 |
+
|
150 |
+
def forward(self, x, latents):
|
151 |
+
"""
|
152 |
+
Args:
|
153 |
+
x (torch.Tensor): image features
|
154 |
+
shape (b, t, aa, D)
|
155 |
+
latent (torch.Tensor): latent features
|
156 |
+
shape (b, t, hw, D)
|
157 |
+
"""
|
158 |
+
x = self.norm1(x)
|
159 |
+
latents = self.norm2(latents)
|
160 |
+
# print("latents shape: ", latents.shape)
|
161 |
+
# print("x shape: ", x.shape)
|
162 |
+
q = self.to_q(latents)
|
163 |
+
k, v = self.to_kv(x).chunk(2, dim=-1)
|
164 |
+
|
165 |
+
|
166 |
+
# attention
|
167 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
168 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
169 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
170 |
+
out = weight @ v
|
171 |
+
|
172 |
+
# out = out.permute(0, 2, 1, 3)
|
173 |
+
return self.to_out(out)
|
174 |
+
#def forward(self, x, latents):
|
175 |
+
# """
|
176 |
+
# Args:
|
177 |
+
# x (torch.Tensor): image features
|
178 |
+
# shape (b, t, aa, D)
|
179 |
+
# latent (torch.Tensor): latent features
|
180 |
+
# shape (b, t, hw, D)
|
181 |
+
# """
|
182 |
+
# if get_sequence_parallel_state():
|
183 |
+
# sp_size = nccl_info.sp_size
|
184 |
+
# sp_rank = nccl_info.rank_within_group
|
185 |
+
# print("rank:", latents.shape, sp_size, sp_rank)
|
186 |
+
# latents = torch.chunk(latents, sp_size, dim=1)[sp_rank]
|
187 |
+
|
188 |
+
# x = self.norm1(x)
|
189 |
+
# latents = self.norm2(latents)
|
190 |
+
# # print("latents shape: ", latents.shape)
|
191 |
+
# # print("x shape: ", x.shape)
|
192 |
+
# q = self.to_q(latents)
|
193 |
+
# k, v = self.to_kv(x).chunk(2, dim=-1)
|
194 |
+
|
195 |
+
# # print("q, k, v: ", q.shape, k.shape, v.shape)
|
196 |
+
|
197 |
+
# # attention
|
198 |
+
# #scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
199 |
+
# #weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
200 |
+
# #weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
201 |
+
# #out = weight @ v
|
202 |
+
# def shrink_head(encoder_state, dim):
|
203 |
+
# local_heads = encoder_state.shape[dim] // nccl_info.sp_size
|
204 |
+
# return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
|
205 |
+
|
206 |
+
# if get_sequence_parallel_state():
|
207 |
+
# # batch_size, seq_len, attn_heads, head_dim
|
208 |
+
# q = all_to_all_4D(q, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128]
|
209 |
+
# k = shrink_head(k ,dim=2)
|
210 |
+
# v = shrink_head(v ,dim=2)
|
211 |
+
# qkv = torch.stack([query, key, value], dim=2)
|
212 |
+
# attn = flash_attn_no_pad(qkv, causal=False, dropout_p=0.0, softmax_scale=None)
|
213 |
+
# # out = out.permute(0, 2, 1, 3)
|
214 |
+
# #b, s, a, d = attn.shape
|
215 |
+
# #attn = attn.reshape(b, s, -1)
|
216 |
+
#
|
217 |
+
# out = self.to_out(attn)
|
218 |
+
# if get_sequence_parallel_state():
|
219 |
+
# out = all_gather(out, dim=1)
|
220 |
+
# return out
|
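As a quick shape check of `AudioProjNet2` above, with dummy Whisper-style embeddings. The dimensions are illustrative only; output_dim=3072 matching the transformer hidden size is an assumption here, not taken from this file.

import torch
from hyvideo.modules.audio_adapters import AudioProjNet2

proj = AudioProjNet2(seq_len=5, blocks=12, channels=768,
                     intermediate_dim=512, output_dim=3072, context_tokens=4)
audio_embeds = torch.randn(1, 8, 5, 12, 768)   # (batch, frames, window, blocks, channels)
tokens = proj(audio_embeds)
print(tokens.shape)                            # torch.Size([1, 8, 4, 3072])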
hyvideo/modules/embed_layers.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
|
6 |
+
from ..utils.helpers import to_2tuple
|
7 |
+
|
8 |
+
|
9 |
+
class PatchEmbed(nn.Module):
|
10 |
+
"""2D Image to Patch Embedding
|
11 |
+
|
12 |
+
Image to Patch Embedding using Conv2d
|
13 |
+
|
14 |
+
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
15 |
+
|
16 |
+
Based on the impl in https://github.com/google-research/vision_transformer
|
17 |
+
|
18 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
19 |
+
|
20 |
+
Remove the _assert function in forward function to be compatible with multi-resolution images.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
patch_size=16,
|
26 |
+
in_chans=3,
|
27 |
+
embed_dim=768,
|
28 |
+
norm_layer=None,
|
29 |
+
flatten=True,
|
30 |
+
bias=True,
|
31 |
+
dtype=None,
|
32 |
+
device=None,
|
33 |
+
):
|
34 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
35 |
+
super().__init__()
|
36 |
+
patch_size = to_2tuple(patch_size)
|
37 |
+
self.patch_size = patch_size
|
38 |
+
self.flatten = flatten
|
39 |
+
|
40 |
+
self.proj = nn.Conv3d(
|
41 |
+
in_chans,
|
42 |
+
embed_dim,
|
43 |
+
kernel_size=patch_size,
|
44 |
+
stride=patch_size,
|
45 |
+
bias=bias,
|
46 |
+
**factory_kwargs
|
47 |
+
)
|
48 |
+
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
|
49 |
+
if bias:
|
50 |
+
nn.init.zeros_(self.proj.bias)
|
51 |
+
|
52 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = self.proj(x)
|
56 |
+
shape = x.shape
|
57 |
+
if self.flatten:
|
58 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
59 |
+
x = self.norm(x)
|
60 |
+
return x, shape
|
61 |
+
|
62 |
+
|
63 |
+
class TextProjection(nn.Module):
|
64 |
+
"""
|
65 |
+
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
66 |
+
|
67 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
71 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
72 |
+
super().__init__()
|
73 |
+
self.linear_1 = nn.Linear(
|
74 |
+
in_features=in_channels,
|
75 |
+
out_features=hidden_size,
|
76 |
+
bias=True,
|
77 |
+
**factory_kwargs
|
78 |
+
)
|
79 |
+
self.act_1 = act_layer()
|
80 |
+
self.linear_2 = nn.Linear(
|
81 |
+
in_features=hidden_size,
|
82 |
+
out_features=hidden_size,
|
83 |
+
bias=True,
|
84 |
+
**factory_kwargs
|
85 |
+
)
|
86 |
+
|
87 |
+
def forward(self, caption):
|
88 |
+
hidden_states = self.linear_1(caption)
|
89 |
+
hidden_states = self.act_1(hidden_states)
|
90 |
+
hidden_states = self.linear_2(hidden_states)
|
91 |
+
return hidden_states
|
92 |
+
|
93 |
+
|
94 |
+
def timestep_embedding(t, dim, max_period=10000):
|
95 |
+
"""
|
96 |
+
Create sinusoidal timestep embeddings.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
100 |
+
dim (int): the dimension of the output.
|
101 |
+
max_period (int): controls the minimum frequency of the embeddings.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
105 |
+
|
106 |
+
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
107 |
+
"""
|
108 |
+
half = dim // 2
|
109 |
+
freqs = torch.exp(
|
110 |
+
-math.log(max_period)
|
111 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
112 |
+
/ half
|
113 |
+
).to(device=t.device)
|
114 |
+
args = t[:, None].float() * freqs[None]
|
115 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
116 |
+
if dim % 2:
|
117 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
118 |
+
return embedding
|
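To make the shapes concrete, here is a standalone recomputation of the sinusoidal embedding defined above (the timestep values are arbitrary examples).

    import math
    import torch

    t = torch.tensor([0.0, 250.0, 999.0])    # one (possibly fractional) timestep per batch element
    dim, max_period = 256, 10000
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None] * freqs[None]                               # (N, dim/2)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # (N, dim) == (3, 256)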
119 |
+
|
120 |
+
|
121 |
+
class TimestepEmbedder(nn.Module):
|
122 |
+
"""
|
123 |
+
Embeds scalar timesteps into vector representations.
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
hidden_size,
|
129 |
+
act_layer,
|
130 |
+
frequency_embedding_size=256,
|
131 |
+
max_period=10000,
|
132 |
+
out_size=None,
|
133 |
+
dtype=None,
|
134 |
+
device=None,
|
135 |
+
):
|
136 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
137 |
+
super().__init__()
|
138 |
+
self.frequency_embedding_size = frequency_embedding_size
|
139 |
+
self.max_period = max_period
|
140 |
+
if out_size is None:
|
141 |
+
out_size = hidden_size
|
142 |
+
|
143 |
+
self.mlp = nn.Sequential(
|
144 |
+
nn.Linear(
|
145 |
+
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
|
146 |
+
),
|
147 |
+
act_layer(),
|
148 |
+
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
149 |
+
)
|
150 |
+
nn.init.normal_(self.mlp[0].weight, std=0.02)
|
151 |
+
nn.init.normal_(self.mlp[2].weight, std=0.02)
|
152 |
+
|
153 |
+
def forward(self, t):
|
154 |
+
t_freq = timestep_embedding(
|
155 |
+
t, self.frequency_embedding_size, self.max_period
|
156 |
+
).type(self.mlp[0].weight.dtype)
|
157 |
+
t_emb = self.mlp(t_freq)
|
158 |
+
return t_emb
|
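Usage-wise, the embedder above maps a batch of scalar timesteps to vectors; a hypothetical call (the hidden size is an example, not a repo default):

    # Hypothetical usage of the TimestepEmbedder defined above.
    import torch
    import torch.nn as nn

    embedder = TimestepEmbedder(hidden_size=128, act_layer=nn.SiLU)
    t = torch.tensor([0.0, 500.0, 999.0])
    vec = embedder(t)          # (3, 128): sinusoidal features passed through the small MLP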
hyvideo/modules/mlp_layers.py
ADDED
@@ -0,0 +1,131 @@
|
1 |
+
# Modified from timm library:
|
2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
3 |
+
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .modulate_layers import modulate_
|
10 |
+
from ..utils.helpers import to_2tuple
|
11 |
+
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
in_channels,
|
19 |
+
hidden_channels=None,
|
20 |
+
out_features=None,
|
21 |
+
act_layer=nn.GELU,
|
22 |
+
norm_layer=None,
|
23 |
+
bias=True,
|
24 |
+
drop=0.0,
|
25 |
+
use_conv=False,
|
26 |
+
device=None,
|
27 |
+
dtype=None,
|
28 |
+
):
|
29 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
30 |
+
super().__init__()
|
31 |
+
out_features = out_features or in_channels
|
32 |
+
hidden_channels = hidden_channels or in_channels
|
33 |
+
bias = to_2tuple(bias)
|
34 |
+
drop_probs = to_2tuple(drop)
|
35 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
36 |
+
|
37 |
+
self.fc1 = linear_layer(
|
38 |
+
in_channels, hidden_channels, bias=bias[0], **factory_kwargs
|
39 |
+
)
|
40 |
+
self.act = act_layer()
|
41 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
42 |
+
self.norm = (
|
43 |
+
norm_layer(hidden_channels, **factory_kwargs)
|
44 |
+
if norm_layer is not None
|
45 |
+
else nn.Identity()
|
46 |
+
)
|
47 |
+
self.fc2 = linear_layer(
|
48 |
+
hidden_channels, out_features, bias=bias[1], **factory_kwargs
|
49 |
+
)
|
50 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = self.fc1(x)
|
54 |
+
x = self.act(x)
|
55 |
+
x = self.drop1(x)
|
56 |
+
x = self.norm(x)
|
57 |
+
x = self.fc2(x)
|
58 |
+
x = self.drop2(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
def apply_(self, x, divide = 4):
|
62 |
+
x_shape = x.shape
|
63 |
+
x = x.view(-1, x.shape[-1])
|
64 |
+
chunk_size = int(x_shape[1]/divide)
|
65 |
+
x_chunks = torch.split(x, chunk_size)
|
66 |
+
for i, x_chunk in enumerate(x_chunks):
|
67 |
+
mlp_chunk = self.fc1(x_chunk)
|
68 |
+
mlp_chunk = self.act(mlp_chunk)
|
69 |
+
mlp_chunk = self.drop1(mlp_chunk)
|
70 |
+
mlp_chunk = self.norm(mlp_chunk)
|
71 |
+
mlp_chunk = self.fc2(mlp_chunk)
|
72 |
+
x_chunk[...] = self.drop2(mlp_chunk)
|
73 |
+
return x
|
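The apply_ method above runs the MLP over chunks of the flattened token tensor and writes the result back in place, so only one chunk's hidden activation is alive at a time. A standalone sketch of the same pattern, with made-up layer sizes:

    # Illustrative sketch of chunked, in-place MLP application to cap peak activation memory.
    import torch
    import torch.nn as nn

    fc1, act, fc2 = nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)
    x = torch.randn(2, 4096, 64)              # (B, L, D)
    flat = x.view(-1, x.shape[-1])            # shares storage with x
    chunk_size = x.shape[1] // 4              # same "divide" idea as apply_
    with torch.no_grad():                     # in-place writes; inference-style usage assumed
        for chunk in torch.split(flat, chunk_size):
            chunk[...] = fc2(act(fc1(chunk)))
    # x now holds the MLP output, written back through the shared storage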
74 |
+
|
75 |
+
#
|
76 |
+
class MLPEmbedder(nn.Module):
|
77 |
+
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
|
78 |
+
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
|
79 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
80 |
+
super().__init__()
|
81 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
|
82 |
+
self.silu = nn.SiLU()
|
83 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
|
84 |
+
|
85 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
86 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
87 |
+
|
88 |
+
|
89 |
+
class FinalLayer(nn.Module):
|
90 |
+
"""The final layer of DiT."""
|
91 |
+
|
92 |
+
def __init__(
|
93 |
+
self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
|
94 |
+
):
|
95 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
96 |
+
super().__init__()
|
97 |
+
|
98 |
+
# Just use LayerNorm for the final layer
|
99 |
+
self.norm_final = nn.LayerNorm(
|
100 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
101 |
+
)
|
102 |
+
if isinstance(patch_size, int):
|
103 |
+
self.linear = nn.Linear(
|
104 |
+
hidden_size,
|
105 |
+
patch_size * patch_size * out_channels,
|
106 |
+
bias=True,
|
107 |
+
**factory_kwargs
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
self.linear = nn.Linear(
|
111 |
+
hidden_size,
|
112 |
+
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
|
113 |
+
bias=True,
|
114 |
+
)
|
115 |
+
nn.init.zeros_(self.linear.weight)
|
116 |
+
nn.init.zeros_(self.linear.bias)
|
117 |
+
|
118 |
+
# Here we don't distinguish between the modulate types. Just use the simple one.
|
119 |
+
self.adaLN_modulation = nn.Sequential(
|
120 |
+
act_layer(),
|
121 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
122 |
+
)
|
123 |
+
# Zero-initialize the modulation
|
124 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
125 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
126 |
+
|
127 |
+
def forward(self, x, c):
|
128 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
129 |
+
x = modulate_(self.norm_final(x), shift=shift, scale=scale)
|
130 |
+
x = self.linear(x)
|
131 |
+
return x
|
hyvideo/modules/models.py
ADDED
@@ -0,0 +1,1159 @@
|
1 |
+
from typing import Any, List, Tuple, Optional, Union, Dict
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from diffusers.models import ModelMixin
|
9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
10 |
+
|
11 |
+
from .activation_layers import get_activation_layer
|
12 |
+
from .norm_layers import get_norm_layer
|
13 |
+
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
|
14 |
+
from .attenion import attention, parallel_attention, get_cu_seqlens
|
15 |
+
from .posemb_layers import apply_rotary_emb
|
16 |
+
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
|
17 |
+
from .modulate_layers import ModulateDiT, modulate, modulate_ , apply_gate, apply_gate_and_accumulate_
|
18 |
+
from .token_refiner import SingleTokenRefiner
|
19 |
+
import numpy as np
|
20 |
+
from mmgp import offload
|
21 |
+
from wan.modules.attention import pay_attention
|
22 |
+
from .audio_adapters import AudioProjNet2, PerceiverAttentionCA
|
23 |
+
|
24 |
+
def get_linear_split_map():
|
25 |
+
hidden_size = 3072
|
26 |
+
split_linear_modules_map = {
|
27 |
+
"img_attn_qkv" : {"mapped_modules" : ["img_attn_q", "img_attn_k", "img_attn_v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
|
28 |
+
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
|
29 |
+
}
|
30 |
+
return split_linear_modules_map
|
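For illustration, a hypothetical sketch (not the repo's actual loader) of how a fused projection can be split into separate nn.Linear modules using split sizes like the ones returned above:

    import torch
    import torch.nn as nn

    hidden_size = 3072
    fused = nn.Linear(hidden_size, 3 * hidden_size, bias=True)     # e.g. a fused img_attn_qkv
    split_sizes = [hidden_size, hidden_size, hidden_size]

    parts = []
    for w, b in zip(torch.split(fused.weight, split_sizes, dim=0),
                    torch.split(fused.bias, split_sizes, dim=0)):
        lin = nn.Linear(hidden_size, w.shape[0], bias=True)
        with torch.no_grad():
            lin.weight.copy_(w)
            lin.bias.copy_(b)
        parts.append(lin)

    x = torch.randn(1, 4, hidden_size)
    q, k, v = (p(x) for p in parts)
    # The split modules reproduce the fused projection.
    assert torch.allclose(torch.cat([q, k, v], dim=-1), fused(x), atol=1e-4)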
31 |
+
try:
|
32 |
+
from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
|
33 |
+
except ImportError:
|
34 |
+
BlockDiagonalPaddedKeysMask = None
|
35 |
+
|
36 |
+
|
37 |
+
class MMDoubleStreamBlock(nn.Module):
|
38 |
+
"""
|
39 |
+
A multimodal DiT block with separate modulation for
|
40 |
+
text and image/video; for more details see (SD3): https://arxiv.org/abs/2403.03206
|
41 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
hidden_size: int,
|
47 |
+
heads_num: int,
|
48 |
+
mlp_width_ratio: float,
|
49 |
+
mlp_act_type: str = "gelu_tanh",
|
50 |
+
qk_norm: bool = True,
|
51 |
+
qk_norm_type: str = "rms",
|
52 |
+
qkv_bias: bool = False,
|
53 |
+
dtype: Optional[torch.dtype] = None,
|
54 |
+
device: Optional[torch.device] = None,
|
55 |
+
attention_mode: str = "sdpa",
|
56 |
+
):
|
57 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
self.attention_mode = attention_mode
|
61 |
+
self.deterministic = False
|
62 |
+
self.heads_num = heads_num
|
63 |
+
head_dim = hidden_size // heads_num
|
64 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
65 |
+
|
66 |
+
self.img_mod = ModulateDiT(
|
67 |
+
hidden_size,
|
68 |
+
factor=6,
|
69 |
+
act_layer=get_activation_layer("silu"),
|
70 |
+
**factory_kwargs,
|
71 |
+
)
|
72 |
+
self.img_norm1 = nn.LayerNorm(
|
73 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
74 |
+
)
|
75 |
+
|
76 |
+
self.img_attn_qkv = nn.Linear(
|
77 |
+
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
|
78 |
+
)
|
79 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
80 |
+
self.img_attn_q_norm = (
|
81 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
82 |
+
if qk_norm
|
83 |
+
else nn.Identity()
|
84 |
+
)
|
85 |
+
self.img_attn_k_norm = (
|
86 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
87 |
+
if qk_norm
|
88 |
+
else nn.Identity()
|
89 |
+
)
|
90 |
+
self.img_attn_proj = nn.Linear(
|
91 |
+
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
92 |
+
)
|
93 |
+
|
94 |
+
self.img_norm2 = nn.LayerNorm(
|
95 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
96 |
+
)
|
97 |
+
self.img_mlp = MLP(
|
98 |
+
hidden_size,
|
99 |
+
mlp_hidden_dim,
|
100 |
+
act_layer=get_activation_layer(mlp_act_type),
|
101 |
+
bias=True,
|
102 |
+
**factory_kwargs,
|
103 |
+
)
|
104 |
+
|
105 |
+
self.txt_mod = ModulateDiT(
|
106 |
+
hidden_size,
|
107 |
+
factor=6,
|
108 |
+
act_layer=get_activation_layer("silu"),
|
109 |
+
**factory_kwargs,
|
110 |
+
)
|
111 |
+
self.txt_norm1 = nn.LayerNorm(
|
112 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
113 |
+
)
|
114 |
+
|
115 |
+
self.txt_attn_qkv = nn.Linear(
|
116 |
+
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
|
117 |
+
)
|
118 |
+
self.txt_attn_q_norm = (
|
119 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
120 |
+
if qk_norm
|
121 |
+
else nn.Identity()
|
122 |
+
)
|
123 |
+
self.txt_attn_k_norm = (
|
124 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
125 |
+
if qk_norm
|
126 |
+
else nn.Identity()
|
127 |
+
)
|
128 |
+
self.txt_attn_proj = nn.Linear(
|
129 |
+
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
130 |
+
)
|
131 |
+
|
132 |
+
self.txt_norm2 = nn.LayerNorm(
|
133 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
134 |
+
)
|
135 |
+
self.txt_mlp = MLP(
|
136 |
+
hidden_size,
|
137 |
+
mlp_hidden_dim,
|
138 |
+
act_layer=get_activation_layer(mlp_act_type),
|
139 |
+
bias=True,
|
140 |
+
**factory_kwargs,
|
141 |
+
)
|
142 |
+
self.hybrid_seq_parallel_attn = None
|
143 |
+
|
144 |
+
def enable_deterministic(self):
|
145 |
+
self.deterministic = True
|
146 |
+
|
147 |
+
def disable_deterministic(self):
|
148 |
+
self.deterministic = False
|
149 |
+
|
150 |
+
def forward(
|
151 |
+
self,
|
152 |
+
img: torch.Tensor,
|
153 |
+
txt: torch.Tensor,
|
154 |
+
vec: torch.Tensor,
|
155 |
+
attn_mask = None,
|
156 |
+
seqlens_q: Optional[torch.Tensor] = None,
|
157 |
+
seqlens_kv: Optional[torch.Tensor] = None,
|
158 |
+
freqs_cis: tuple = None,
|
159 |
+
condition_type: str = None,
|
160 |
+
token_replace_vec: torch.Tensor = None,
|
161 |
+
frist_frame_token_num: int = None,
|
162 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
163 |
+
|
164 |
+
if condition_type == "token_replace":
|
165 |
+
img_mod1, token_replace_img_mod1 = self.img_mod(vec, condition_type=condition_type, \
|
166 |
+
token_replace_vec=token_replace_vec)
|
167 |
+
(img_mod1_shift,
|
168 |
+
img_mod1_scale,
|
169 |
+
img_mod1_gate,
|
170 |
+
img_mod2_shift,
|
171 |
+
img_mod2_scale,
|
172 |
+
img_mod2_gate) = img_mod1.chunk(6, dim=-1)
|
173 |
+
(tr_img_mod1_shift,
|
174 |
+
tr_img_mod1_scale,
|
175 |
+
tr_img_mod1_gate,
|
176 |
+
tr_img_mod2_shift,
|
177 |
+
tr_img_mod2_scale,
|
178 |
+
tr_img_mod2_gate) = token_replace_img_mod1.chunk(6, dim=-1)
|
179 |
+
else:
|
180 |
+
(
|
181 |
+
img_mod1_shift,
|
182 |
+
img_mod1_scale,
|
183 |
+
img_mod1_gate,
|
184 |
+
img_mod2_shift,
|
185 |
+
img_mod2_scale,
|
186 |
+
img_mod2_gate,
|
187 |
+
) = self.img_mod(vec).chunk(6, dim=-1)
|
188 |
+
(
|
189 |
+
txt_mod1_shift,
|
190 |
+
txt_mod1_scale,
|
191 |
+
txt_mod1_gate,
|
192 |
+
txt_mod2_shift,
|
193 |
+
txt_mod2_scale,
|
194 |
+
txt_mod2_gate,
|
195 |
+
) = self.txt_mod(vec).chunk(6, dim=-1)
|
196 |
+
|
197 |
+
##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep!
|
198 |
+
# I am sure you are a nice person and as you copy this code, you will give me officially proper credits:
|
199 |
+
# Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter
|
200 |
+
|
201 |
+
# Prepare image for attention.
|
202 |
+
img_modulated = self.img_norm1(img)
|
203 |
+
img_modulated = img_modulated.to(torch.bfloat16)
|
204 |
+
|
205 |
+
if condition_type == "token_replace":
|
206 |
+
modulate_(img_modulated[:, :frist_frame_token_num], shift=tr_img_mod1_shift, scale=tr_img_mod1_scale)
|
207 |
+
modulate_(img_modulated[:, frist_frame_token_num:], shift=img_mod1_shift, scale=img_mod1_scale)
|
208 |
+
else:
|
209 |
+
modulate_( img_modulated, shift=img_mod1_shift, scale=img_mod1_scale )
|
210 |
+
|
211 |
+
shape = (*img_modulated.shape[:2], self.heads_num, int(img_modulated.shape[-1] / self.heads_num) )
|
212 |
+
img_q = self.img_attn_q(img_modulated).view(*shape)
|
213 |
+
img_k = self.img_attn_k(img_modulated).view(*shape)
|
214 |
+
img_v = self.img_attn_v(img_modulated).view(*shape)
|
215 |
+
del img_modulated
|
216 |
+
|
217 |
+
# Apply QK-Norm if needed
|
218 |
+
self.img_attn_q_norm.apply_(img_q).to(img_v)
|
219 |
+
img_q_len = img_q.shape[1]
|
220 |
+
self.img_attn_k_norm.apply_(img_k).to(img_v)
|
221 |
+
img_kv_len= img_k.shape[1]
|
222 |
+
batch_size = img_k.shape[0]
|
223 |
+
# Apply RoPE if needed.
|
224 |
+
qklist = [img_q, img_k]
|
225 |
+
del img_q, img_k
|
226 |
+
img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
|
227 |
+
# Prepare txt for attention.
|
228 |
+
txt_modulated = self.txt_norm1(txt)
|
229 |
+
modulate_(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale )
|
230 |
+
|
231 |
+
txt_qkv = self.txt_attn_qkv(txt_modulated)
|
232 |
+
del txt_modulated
|
233 |
+
txt_q, txt_k, txt_v = rearrange(
|
234 |
+
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
|
235 |
+
)
|
236 |
+
del txt_qkv
|
237 |
+
# Apply QK-Norm if needed.
|
238 |
+
self.txt_attn_q_norm.apply_(txt_q).to(txt_v)
|
239 |
+
self.txt_attn_k_norm.apply_(txt_k).to(txt_v)
|
240 |
+
|
241 |
+
# Run actual attention.
|
242 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
243 |
+
del img_q, txt_q
|
244 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
245 |
+
del img_k, txt_k
|
246 |
+
v = torch.cat((img_v, txt_v), dim=1)
|
247 |
+
del img_v, txt_v
|
248 |
+
|
249 |
+
# attention computation start
|
250 |
+
qkv_list = [q,k,v]
|
251 |
+
del q, k, v
|
252 |
+
|
253 |
+
attn = pay_attention(
|
254 |
+
qkv_list,
|
255 |
+
attention_mask=attn_mask,
|
256 |
+
q_lens=seqlens_q,
|
257 |
+
k_lens=seqlens_kv,
|
258 |
+
)
|
259 |
+
b, s, a, d = attn.shape
|
260 |
+
attn = attn.reshape(b, s, -1)
|
261 |
+
del qkv_list
|
262 |
+
|
263 |
+
# attention computation end
|
264 |
+
|
265 |
+
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
|
266 |
+
del attn
|
267 |
+
# Calculate the img blocks.
|
268 |
+
|
269 |
+
if condition_type == "token_replace":
|
270 |
+
img_attn = self.img_attn_proj(img_attn)
|
271 |
+
apply_gate_and_accumulate_(img[:, :frist_frame_token_num], img_attn[:, :frist_frame_token_num], gate=tr_img_mod1_gate)
|
272 |
+
apply_gate_and_accumulate_(img[:, frist_frame_token_num:], img_attn[:, frist_frame_token_num:], gate=img_mod1_gate)
|
273 |
+
del img_attn
|
274 |
+
img_modulated = self.img_norm2(img)
|
275 |
+
img_modulated = img_modulated.to(torch.bfloat16)
|
276 |
+
modulate_( img_modulated[:, :frist_frame_token_num], shift=tr_img_mod2_shift, scale=tr_img_mod2_scale)
|
277 |
+
modulate_( img_modulated[:, frist_frame_token_num:], shift=img_mod2_shift, scale=img_mod2_scale)
|
278 |
+
self.img_mlp.apply_(img_modulated)
|
279 |
+
apply_gate_and_accumulate_(img[:, :frist_frame_token_num], img_modulated[:, :frist_frame_token_num], gate=tr_img_mod2_gate)
|
280 |
+
apply_gate_and_accumulate_(img[:, frist_frame_token_num:], img_modulated[:, frist_frame_token_num:], gate=img_mod2_gate)
|
281 |
+
del img_modulated
|
282 |
+
else:
|
283 |
+
img_attn = self.img_attn_proj(img_attn)
|
284 |
+
apply_gate_and_accumulate_(img, img_attn, gate=img_mod1_gate)
|
285 |
+
del img_attn
|
286 |
+
img_modulated = self.img_norm2(img)
|
287 |
+
img_modulated = img_modulated.to(torch.bfloat16)
|
288 |
+
modulate_( img_modulated , shift=img_mod2_shift, scale=img_mod2_scale)
|
289 |
+
self.img_mlp.apply_(img_modulated)
|
290 |
+
apply_gate_and_accumulate_(img, img_modulated, gate=img_mod2_gate)
|
291 |
+
del img_modulated
|
292 |
+
|
293 |
+
# Calculate the txt blocks.
|
294 |
+
txt_attn = self.txt_attn_proj(txt_attn)
|
295 |
+
apply_gate_and_accumulate_(txt, txt_attn, gate=txt_mod1_gate)
|
296 |
+
del txt_attn
|
297 |
+
txt_modulated = self.txt_norm2(txt)
|
298 |
+
txt_modulated = txt_modulated.to(torch.bfloat16)
|
299 |
+
modulate_(txt_modulated, shift=txt_mod2_shift, scale=txt_mod2_scale)
|
300 |
+
txt_mlp = self.txt_mlp(txt_modulated)
|
301 |
+
del txt_modulated
|
302 |
+
apply_gate_and_accumulate_(txt, txt_mlp, gate=txt_mod2_gate)
|
303 |
+
return img, txt
|
304 |
+
|
305 |
+
|
306 |
+
class MMSingleStreamBlock(nn.Module):
|
307 |
+
"""
|
308 |
+
A DiT block with parallel linear layers as described in
|
309 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
310 |
+
Also refer to (SD3): https://arxiv.org/abs/2403.03206
|
311 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
312 |
+
"""
|
313 |
+
|
314 |
+
def __init__(
|
315 |
+
self,
|
316 |
+
hidden_size: int,
|
317 |
+
heads_num: int,
|
318 |
+
mlp_width_ratio: float = 4.0,
|
319 |
+
mlp_act_type: str = "gelu_tanh",
|
320 |
+
qk_norm: bool = True,
|
321 |
+
qk_norm_type: str = "rms",
|
322 |
+
qk_scale: float = None,
|
323 |
+
dtype: Optional[torch.dtype] = None,
|
324 |
+
device: Optional[torch.device] = None,
|
325 |
+
attention_mode: str = "sdpa",
|
326 |
+
):
|
327 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
328 |
+
super().__init__()
|
329 |
+
self.attention_mode = attention_mode
|
330 |
+
self.deterministic = False
|
331 |
+
self.hidden_size = hidden_size
|
332 |
+
self.heads_num = heads_num
|
333 |
+
head_dim = hidden_size // heads_num
|
334 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
335 |
+
self.mlp_hidden_dim = mlp_hidden_dim
|
336 |
+
self.scale = qk_scale or head_dim ** -0.5
|
337 |
+
|
338 |
+
# qkv and mlp_in
|
339 |
+
self.linear1 = nn.Linear(
|
340 |
+
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
|
341 |
+
)
|
342 |
+
# proj and mlp_out
|
343 |
+
self.linear2 = nn.Linear(
|
344 |
+
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
|
345 |
+
)
|
346 |
+
|
347 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
348 |
+
self.q_norm = (
|
349 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
350 |
+
if qk_norm
|
351 |
+
else nn.Identity()
|
352 |
+
)
|
353 |
+
self.k_norm = (
|
354 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
355 |
+
if qk_norm
|
356 |
+
else nn.Identity()
|
357 |
+
)
|
358 |
+
|
359 |
+
self.pre_norm = nn.LayerNorm(
|
360 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
361 |
+
)
|
362 |
+
|
363 |
+
self.mlp_act = get_activation_layer(mlp_act_type)()
|
364 |
+
self.modulation = ModulateDiT(
|
365 |
+
hidden_size,
|
366 |
+
factor=3,
|
367 |
+
act_layer=get_activation_layer("silu"),
|
368 |
+
**factory_kwargs,
|
369 |
+
)
|
370 |
+
self.hybrid_seq_parallel_attn = None
|
371 |
+
|
372 |
+
def enable_deterministic(self):
|
373 |
+
self.deterministic = True
|
374 |
+
|
375 |
+
def disable_deterministic(self):
|
376 |
+
self.deterministic = False
|
377 |
+
|
378 |
+
def forward(
|
379 |
+
self,
|
380 |
+
# x: torch.Tensor,
|
381 |
+
img: torch.Tensor,
|
382 |
+
txt: torch.Tensor,
|
383 |
+
vec: torch.Tensor,
|
384 |
+
txt_len: int,
|
385 |
+
attn_mask= None,
|
386 |
+
seqlens_q: Optional[torch.Tensor] = None,
|
387 |
+
seqlens_kv: Optional[torch.Tensor] = None,
|
388 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
389 |
+
condition_type: str = None,
|
390 |
+
token_replace_vec: torch.Tensor = None,
|
391 |
+
frist_frame_token_num: int = None,
|
392 |
+
) -> torch.Tensor:
|
393 |
+
|
394 |
+
##### More spaghetti VRAM optimizations done by DeepBeepMeep!
|
395 |
+
# I am sure you are a nice person and as you copy this code, you will give me proper credits:
|
396 |
+
# Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter
|
397 |
+
|
398 |
+
if condition_type == "token_replace":
|
399 |
+
mod, tr_mod = self.modulation(vec,
|
400 |
+
condition_type=condition_type,
|
401 |
+
token_replace_vec=token_replace_vec)
|
402 |
+
(mod_shift,
|
403 |
+
mod_scale,
|
404 |
+
mod_gate) = mod.chunk(3, dim=-1)
|
405 |
+
(tr_mod_shift,
|
406 |
+
tr_mod_scale,
|
407 |
+
tr_mod_gate) = tr_mod.chunk(3, dim=-1)
|
408 |
+
else:
|
409 |
+
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
410 |
+
|
411 |
+
img_mod = self.pre_norm(img)
|
412 |
+
img_mod = img_mod.to(torch.bfloat16)
|
413 |
+
if condition_type == "token_replace":
|
414 |
+
modulate_(img_mod[:, :frist_frame_token_num], shift=tr_mod_shift, scale=tr_mod_scale)
|
415 |
+
modulate_(img_mod[:, frist_frame_token_num:], shift=mod_shift, scale=mod_scale)
|
416 |
+
else:
|
417 |
+
modulate_(img_mod, shift=mod_shift, scale=mod_scale)
|
418 |
+
txt_mod = self.pre_norm(txt)
|
419 |
+
txt_mod = txt_mod.to(torch.bfloat16)
|
420 |
+
modulate_(txt_mod, shift=mod_shift, scale=mod_scale)
|
421 |
+
|
422 |
+
shape = (*img_mod.shape[:2], self.heads_num, int(img_mod.shape[-1] / self.heads_num) )
|
423 |
+
img_q = self.linear1_attn_q(img_mod).view(*shape)
|
424 |
+
img_k = self.linear1_attn_k(img_mod).view(*shape)
|
425 |
+
img_v = self.linear1_attn_v(img_mod).view(*shape)
|
426 |
+
|
427 |
+
shape = (*txt_mod.shape[:2], self.heads_num, int(txt_mod.shape[-1] / self.heads_num) )
|
428 |
+
txt_q = self.linear1_attn_q(txt_mod).view(*shape)
|
429 |
+
txt_k = self.linear1_attn_k(txt_mod).view(*shape)
|
430 |
+
txt_v = self.linear1_attn_v(txt_mod).view(*shape)
|
431 |
+
|
432 |
+
batch_size = img_mod.shape[0]
|
433 |
+
|
434 |
+
# Apply QK-Norm if needed.
|
435 |
+
# q = self.q_norm(q).to(v)
|
436 |
+
self.q_norm.apply_(img_q)
|
437 |
+
self.k_norm.apply_(img_k)
|
438 |
+
self.q_norm.apply_(txt_q)
|
439 |
+
self.k_norm.apply_(txt_k)
|
440 |
+
|
441 |
+
qklist = [img_q, img_k]
|
442 |
+
del img_q, img_k
|
443 |
+
img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
|
444 |
+
img_q_len=img_q.shape[1]
|
445 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
446 |
+
del img_q, txt_q
|
447 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
448 |
+
img_kv_len=img_k.shape[1]
|
449 |
+
del img_k, txt_k
|
450 |
+
|
451 |
+
v = torch.cat((img_v, txt_v), dim=1)
|
452 |
+
del img_v, txt_v
|
453 |
+
|
454 |
+
# attention computation start
|
455 |
+
qkv_list = [q,k,v]
|
456 |
+
del q, k, v
|
457 |
+
attn = pay_attention(
|
458 |
+
qkv_list,
|
459 |
+
attention_mask=attn_mask,
|
460 |
+
q_lens = seqlens_q,
|
461 |
+
k_lens = seqlens_kv,
|
462 |
+
)
|
463 |
+
b, s, a, d = attn.shape
|
464 |
+
attn = attn.reshape(b, s, -1)
|
465 |
+
del qkv_list
|
466 |
+
# attention computation end
|
467 |
+
|
468 |
+
x_mod = torch.cat((img_mod, txt_mod), 1)
|
469 |
+
del img_mod, txt_mod
|
470 |
+
x_mod_shape = x_mod.shape
|
471 |
+
x_mod = x_mod.view(-1, x_mod.shape[-1])
|
472 |
+
chunk_size = int(x_mod_shape[1]/6)
|
473 |
+
x_chunks = torch.split(x_mod, chunk_size)
|
474 |
+
attn = attn.view(-1, attn.shape[-1])
|
475 |
+
attn_chunks =torch.split(attn, chunk_size)
|
476 |
+
for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
|
477 |
+
mlp_chunk = self.linear1_mlp(x_chunk)
|
478 |
+
mlp_chunk = self.mlp_act(mlp_chunk)
|
479 |
+
attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
|
480 |
+
del attn_chunk, mlp_chunk
|
481 |
+
x_chunk[...] = self.linear2(attn_mlp_chunk)
|
482 |
+
del attn_mlp_chunk
|
483 |
+
x_mod = x_mod.view(x_mod_shape)
|
484 |
+
|
485 |
+
if condition_type == "token_replace":
|
486 |
+
apply_gate_and_accumulate_(img[:, :frist_frame_token_num, :], x_mod[:, :frist_frame_token_num, :], gate=tr_mod_gate)
|
487 |
+
apply_gate_and_accumulate_(img[:, frist_frame_token_num:, :], x_mod[:, frist_frame_token_num:-txt_len, :], gate=mod_gate)
|
488 |
+
else:
|
489 |
+
apply_gate_and_accumulate_(img, x_mod[:, :-txt_len, :], gate=mod_gate)
|
490 |
+
|
491 |
+
apply_gate_and_accumulate_(txt, x_mod[:, -txt_len:, :], gate=mod_gate)
|
492 |
+
|
493 |
+
return img, txt
|
494 |
+
|
495 |
+
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
496 |
+
def preprocess_loras(self, model_filename, sd):
|
497 |
+
if not "i2v" in model_filename:
|
498 |
+
return sd
|
499 |
+
new_sd = {}
|
500 |
+
for k,v in sd.items():
|
501 |
+
repl_list = ["double_blocks", "single_blocks", "final_layer", "img_mlp", "img_attn_qkv", "img_attn_proj","img_mod", "txt_mlp", "txt_attn_qkv","txt_attn_proj", "txt_mod", "linear1",
|
502 |
+
"linear2", "modulation", "mlp_fc1"]
|
503 |
+
src_list = [k +"_" for k in repl_list] + ["_" + k for k in repl_list]
|
504 |
+
tgt_list = [k +"." for k in repl_list] + ["." + k for k in repl_list]
|
505 |
+
if k.startswith("Hunyuan_video_I2V_lora_"):
|
506 |
+
# crappy conversion script for non-reversible LoRA naming
|
507 |
+
k = k.replace("Hunyuan_video_I2V_lora_","diffusion_model.")
|
508 |
+
k = k.replace("lora_up","lora_B")
|
509 |
+
k = k.replace("lora_down","lora_A")
|
510 |
+
if "txt_in_individual" in k:
|
511 |
+
pass
|
512 |
+
for s,t in zip(src_list, tgt_list):
|
513 |
+
k = k.replace(s,t)
|
514 |
+
if "individual_token_refiner" in k:
|
515 |
+
k = k.replace("txt_in_individual_token_refiner_blocks_", "txt_in.individual_token_refiner.blocks.")
|
516 |
+
k = k.replace("_mlp_fc", ".mlp.fc",)
|
517 |
+
k = k.replace(".mlp_fc", ".mlp.fc",)
|
518 |
+
new_sd[k] = v
|
519 |
+
return new_sd
|
520 |
+
"""
|
521 |
+
HunyuanVideo Transformer backbone
|
522 |
+
|
523 |
+
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
|
524 |
+
|
525 |
+
Reference:
|
526 |
+
[1] Flux.1: https://github.com/black-forest-labs/flux
|
527 |
+
[2] MMDiT: http://arxiv.org/abs/2403.03206
|
528 |
+
|
529 |
+
Parameters
|
530 |
+
----------
|
531 |
+
args: argparse.Namespace
|
532 |
+
The arguments parsed by argparse.
|
533 |
+
patch_size: list
|
534 |
+
The size of the patch.
|
535 |
+
in_channels: int
|
536 |
+
The number of input channels.
|
537 |
+
out_channels: int
|
538 |
+
The number of output channels.
|
539 |
+
hidden_size: int
|
540 |
+
The hidden size of the transformer backbone.
|
541 |
+
heads_num: int
|
542 |
+
The number of attention heads.
|
543 |
+
mlp_width_ratio: float
|
544 |
+
The ratio of the hidden size of the MLP in the transformer block.
|
545 |
+
mlp_act_type: str
|
546 |
+
The activation function of the MLP in the transformer block.
|
547 |
+
depth_double_blocks: int
|
548 |
+
The number of transformer blocks in the double blocks.
|
549 |
+
depth_single_blocks: int
|
550 |
+
The number of transformer blocks in the single blocks.
|
551 |
+
rope_dim_list: list
|
552 |
+
The dimension of the rotary embedding for t, h, w.
|
553 |
+
qkv_bias: bool
|
554 |
+
Whether to use bias in the qkv linear layer.
|
555 |
+
qk_norm: bool
|
556 |
+
Whether to use qk norm.
|
557 |
+
qk_norm_type: str
|
558 |
+
The type of qk norm.
|
559 |
+
guidance_embed: bool
|
560 |
+
Whether to use guidance embedding for distillation.
|
561 |
+
text_projection: str
|
562 |
+
The type of the text projection, default is single_refiner.
|
563 |
+
use_attention_mask: bool
|
564 |
+
Whether to use attention mask for text encoder.
|
565 |
+
dtype: torch.dtype
|
566 |
+
The dtype of the model.
|
567 |
+
device: torch.device
|
568 |
+
The device of the model.
|
569 |
+
"""
|
570 |
+
|
571 |
+
@register_to_config
|
572 |
+
def __init__(
|
573 |
+
self,
|
574 |
+
i2v_condition_type,
|
575 |
+
patch_size: list = [1, 2, 2],
|
576 |
+
in_channels: int = 4, # Should be VAE.config.latent_channels.
|
577 |
+
out_channels: int = None,
|
578 |
+
hidden_size: int = 3072,
|
579 |
+
heads_num: int = 24,
|
580 |
+
mlp_width_ratio: float = 4.0,
|
581 |
+
mlp_act_type: str = "gelu_tanh",
|
582 |
+
mm_double_blocks_depth: int = 20,
|
583 |
+
mm_single_blocks_depth: int = 40,
|
584 |
+
rope_dim_list: List[int] = [16, 56, 56],
|
585 |
+
qkv_bias: bool = True,
|
586 |
+
qk_norm: bool = True,
|
587 |
+
qk_norm_type: str = "rms",
|
588 |
+
guidance_embed: bool = False, # For modulation.
|
589 |
+
text_projection: str = "single_refiner",
|
590 |
+
use_attention_mask: bool = True,
|
591 |
+
dtype: Optional[torch.dtype] = None,
|
592 |
+
device: Optional[torch.device] = None,
|
593 |
+
attention_mode: Optional[str] = "sdpa",
|
594 |
+
video_condition: bool = False,
|
595 |
+
audio_condition: bool = False,
|
596 |
+
avatar = False,
|
597 |
+
custom = False,
|
598 |
+
):
|
599 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
600 |
+
super().__init__()
|
601 |
+
|
602 |
+
# mm_double_blocks_depth , mm_single_blocks_depth = 5, 5
|
603 |
+
|
604 |
+
self.patch_size = patch_size
|
605 |
+
self.in_channels = in_channels
|
606 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
607 |
+
self.unpatchify_channels = self.out_channels
|
608 |
+
self.guidance_embed = guidance_embed
|
609 |
+
self.rope_dim_list = rope_dim_list
|
610 |
+
self.i2v_condition_type = i2v_condition_type
|
611 |
+
self.attention_mode = attention_mode
|
612 |
+
self.video_condition = video_condition
|
613 |
+
self.audio_condition = audio_condition
|
614 |
+
self.avatar = avatar
|
615 |
+
self.custom = custom
|
616 |
+
|
617 |
+
# Text projection. Default to linear projection.
|
618 |
+
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
|
619 |
+
self.use_attention_mask = use_attention_mask
|
620 |
+
self.text_projection = text_projection
|
621 |
+
|
622 |
+
self.text_states_dim = 4096
|
623 |
+
self.text_states_dim_2 = 768
|
624 |
+
|
625 |
+
if hidden_size % heads_num != 0:
|
626 |
+
raise ValueError(
|
627 |
+
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
|
628 |
+
)
|
629 |
+
pe_dim = hidden_size // heads_num
|
630 |
+
if sum(rope_dim_list) != pe_dim:
|
631 |
+
raise ValueError(
|
632 |
+
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
|
633 |
+
)
|
634 |
+
self.hidden_size = hidden_size
|
635 |
+
self.heads_num = heads_num
|
636 |
+
|
637 |
+
# image projection
|
638 |
+
self.img_in = PatchEmbed(
|
639 |
+
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
|
640 |
+
)
|
641 |
+
|
642 |
+
# text projection
|
643 |
+
if self.text_projection == "linear":
|
644 |
+
self.txt_in = TextProjection(
|
645 |
+
self.text_states_dim,
|
646 |
+
self.hidden_size,
|
647 |
+
get_activation_layer("silu"),
|
648 |
+
**factory_kwargs,
|
649 |
+
)
|
650 |
+
elif self.text_projection == "single_refiner":
|
651 |
+
self.txt_in = SingleTokenRefiner(
|
652 |
+
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
|
653 |
+
)
|
654 |
+
else:
|
655 |
+
raise NotImplementedError(
|
656 |
+
f"Unsupported text_projection: {self.text_projection}"
|
657 |
+
)
|
658 |
+
|
659 |
+
# time modulation
|
660 |
+
self.time_in = TimestepEmbedder(
|
661 |
+
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
|
662 |
+
)
|
663 |
+
|
664 |
+
# text modulation
|
665 |
+
self.vector_in = MLPEmbedder(
|
666 |
+
self.text_states_dim_2, self.hidden_size, **factory_kwargs
|
667 |
+
)
|
668 |
+
|
669 |
+
# guidance modulation
|
670 |
+
self.guidance_in = (
|
671 |
+
TimestepEmbedder(
|
672 |
+
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
|
673 |
+
)
|
674 |
+
if guidance_embed
|
675 |
+
else None
|
676 |
+
)
|
677 |
+
|
678 |
+
# double blocks
|
679 |
+
self.double_blocks = nn.ModuleList(
|
680 |
+
[
|
681 |
+
MMDoubleStreamBlock(
|
682 |
+
self.hidden_size,
|
683 |
+
self.heads_num,
|
684 |
+
mlp_width_ratio=mlp_width_ratio,
|
685 |
+
mlp_act_type=mlp_act_type,
|
686 |
+
qk_norm=qk_norm,
|
687 |
+
qk_norm_type=qk_norm_type,
|
688 |
+
qkv_bias=qkv_bias,
|
689 |
+
attention_mode = attention_mode,
|
690 |
+
**factory_kwargs,
|
691 |
+
)
|
692 |
+
for _ in range(mm_double_blocks_depth)
|
693 |
+
]
|
694 |
+
)
|
695 |
+
|
696 |
+
# single blocks
|
697 |
+
self.single_blocks = nn.ModuleList(
|
698 |
+
[
|
699 |
+
MMSingleStreamBlock(
|
700 |
+
self.hidden_size,
|
701 |
+
self.heads_num,
|
702 |
+
mlp_width_ratio=mlp_width_ratio,
|
703 |
+
mlp_act_type=mlp_act_type,
|
704 |
+
qk_norm=qk_norm,
|
705 |
+
qk_norm_type=qk_norm_type,
|
706 |
+
attention_mode = attention_mode,
|
707 |
+
**factory_kwargs,
|
708 |
+
)
|
709 |
+
for _ in range(mm_single_blocks_depth)
|
710 |
+
]
|
711 |
+
)
|
712 |
+
|
713 |
+
self.final_layer = FinalLayer(
|
714 |
+
self.hidden_size,
|
715 |
+
self.patch_size,
|
716 |
+
self.out_channels,
|
717 |
+
get_activation_layer("silu"),
|
718 |
+
**factory_kwargs,
|
719 |
+
)
|
720 |
+
|
721 |
+
if self.video_condition:
|
722 |
+
self.bg_in = PatchEmbed(
|
723 |
+
self.patch_size, self.in_channels * 2, self.hidden_size, **factory_kwargs
|
724 |
+
)
|
725 |
+
self.bg_proj = nn.Linear(self.hidden_size, self.hidden_size)
|
726 |
+
|
727 |
+
if audio_condition:
|
728 |
+
if avatar:
|
729 |
+
self.ref_in = PatchEmbed(
|
730 |
+
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
|
731 |
+
)
|
732 |
+
|
733 |
+
# -------------------- audio_proj_model --------------------
|
734 |
+
self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4)
|
735 |
+
|
736 |
+
# -------------------- motion-embeder --------------------
|
737 |
+
self.motion_exp = TimestepEmbedder(
|
738 |
+
self.hidden_size // 4,
|
739 |
+
get_activation_layer("silu"),
|
740 |
+
**factory_kwargs
|
741 |
+
)
|
742 |
+
self.motion_pose = TimestepEmbedder(
|
743 |
+
self.hidden_size // 4,
|
744 |
+
get_activation_layer("silu"),
|
745 |
+
**factory_kwargs
|
746 |
+
)
|
747 |
+
|
748 |
+
self.fps_proj = TimestepEmbedder(
|
749 |
+
self.hidden_size,
|
750 |
+
get_activation_layer("silu"),
|
751 |
+
**factory_kwargs
|
752 |
+
)
|
753 |
+
|
754 |
+
self.before_proj = nn.Linear(self.hidden_size, self.hidden_size)
|
755 |
+
|
756 |
+
# -------------------- audio_insert_model --------------------
|
757 |
+
self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
|
758 |
+
audio_block_name = "audio_adapter_blocks"
|
759 |
+
elif custom:
|
760 |
+
self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4)
|
761 |
+
self.double_stream_list = [1, 3, 5, 7, 9, 11]
|
762 |
+
audio_block_name = "audio_models"
|
763 |
+
|
764 |
+
self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)}
|
765 |
+
self.single_stream_list = []
|
766 |
+
self.single_stream_map = {str(i): j+len(self.double_stream_list) for j, i in enumerate(self.single_stream_list)}
|
767 |
+
setattr(self, audio_block_name, nn.ModuleList([
|
768 |
+
PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33) for _ in range(len(self.double_stream_list) + len(self.single_stream_list))
|
769 |
+
]))
|
770 |
+
|
771 |
+
|
772 |
+
|
773 |
+
def lock_layers_dtypes(self, dtype = torch.float32):
|
774 |
+
layer_list = [self.final_layer, self.final_layer.linear, self.final_layer.adaLN_modulation[1]]
|
775 |
+
target_dtype = dtype
|
776 |
+
|
777 |
+
for current_layer_list, current_dtype in zip([layer_list], [target_dtype]):
|
778 |
+
for layer in current_layer_list:
|
779 |
+
layer._lock_dtype = dtype
|
780 |
+
|
781 |
+
if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
|
782 |
+
layer.weight.data = layer.weight.data.to(current_dtype)
|
783 |
+
if hasattr(layer, "bias"):
|
784 |
+
layer.bias.data = layer.bias.data.to(current_dtype)
|
785 |
+
|
786 |
+
self._lock_dtype = dtype
|
787 |
+
|
788 |
+
def enable_deterministic(self):
|
789 |
+
for block in self.double_blocks:
|
790 |
+
block.enable_deterministic()
|
791 |
+
for block in self.single_blocks:
|
792 |
+
block.enable_deterministic()
|
793 |
+
|
794 |
+
def disable_deterministic(self):
|
795 |
+
for block in self.double_blocks:
|
796 |
+
block.disable_deterministic()
|
797 |
+
for block in self.single_blocks:
|
798 |
+
block.disable_deterministic()
|
799 |
+
|
800 |
+
def forward(
|
801 |
+
self,
|
802 |
+
x: torch.Tensor,
|
803 |
+
t: torch.Tensor, # Should be in range(0, 1000).
|
804 |
+
ref_latents: torch.Tensor=None,
|
805 |
+
text_states: torch.Tensor = None,
|
806 |
+
text_mask: torch.Tensor = None, # Now we don't use it.
|
807 |
+
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
|
808 |
+
freqs_cos: Optional[torch.Tensor] = None,
|
809 |
+
freqs_sin: Optional[torch.Tensor] = None,
|
810 |
+
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
|
811 |
+
pipeline=None,
|
812 |
+
x_id = 0,
|
813 |
+
step_no = 0,
|
814 |
+
callback = None,
|
815 |
+
audio_prompts = None,
|
816 |
+
motion_exp = None,
|
817 |
+
motion_pose = None,
|
818 |
+
fps = None,
|
819 |
+
face_mask = None,
|
820 |
+
audio_strength = None,
|
821 |
+
bg_latents = None,
|
822 |
+
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
823 |
+
|
824 |
+
img = x
|
825 |
+
bsz, _, ot, oh, ow = x.shape
|
826 |
+
del x
|
827 |
+
txt = text_states
|
828 |
+
tt, th, tw = (
|
829 |
+
ot // self.patch_size[0],
|
830 |
+
oh // self.patch_size[1],
|
831 |
+
ow // self.patch_size[2],
|
832 |
+
)
|
833 |
+
|
834 |
+
# Prepare modulation vectors.
|
835 |
+
vec = self.time_in(t)
|
836 |
+
if motion_exp != None:
|
837 |
+
vec += self.motion_exp(motion_exp.view(-1)).view(bsz, -1) # (b, 3072)
|
838 |
+
if motion_pose != None:
|
839 |
+
vec += self.motion_pose(motion_pose.view(-1)).view(bsz, -1) # (b, 3072)
|
840 |
+
if fps != None:
|
841 |
+
vec += self.fps_proj(fps) # (b, 3072)
|
842 |
+
if audio_prompts != None:
|
843 |
+
audio_feature_all = self.audio_proj(audio_prompts)
|
844 |
+
audio_feature_pad = audio_feature_all[:,:1].repeat(1,3,1,1)
|
845 |
+
audio_feature_all_insert = torch.cat([audio_feature_pad, audio_feature_all], dim=1).view(bsz, ot, 16, 3072)
|
846 |
+
audio_feature_all = None
|
847 |
+
|
848 |
+
if self.i2v_condition_type == "token_replace":
|
849 |
+
token_replace_t = torch.zeros_like(t)
|
850 |
+
token_replace_vec = self.time_in(token_replace_t)
|
851 |
+
frist_frame_token_num = th * tw
|
852 |
+
else:
|
853 |
+
token_replace_vec = None
|
854 |
+
frist_frame_token_num = None
|
855 |
+
# token_replace_mask_img = None
|
856 |
+
# token_replace_mask_txt = None
|
857 |
+
|
858 |
+
# text modulation
|
859 |
+
vec_2 = self.vector_in(text_states_2)
|
860 |
+
del text_states_2
|
861 |
+
vec += vec_2
|
862 |
+
if self.i2v_condition_type == "token_replace":
|
863 |
+
token_replace_vec += vec_2
|
864 |
+
del vec_2
|
865 |
+
|
866 |
+
# guidance modulation
|
867 |
+
if self.guidance_embed:
|
868 |
+
if guidance is None:
|
869 |
+
raise ValueError(
|
870 |
+
"Didn't get guidance strength for guidance distilled model."
|
871 |
+
)
|
872 |
+
|
873 |
+
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
|
874 |
+
vec += self.guidance_in(guidance)
|
875 |
+
|
876 |
+
# Embed image and text.
|
877 |
+
img, shape_mask = self.img_in(img)
|
878 |
+
if self.avatar:
|
879 |
+
ref_latents_first = ref_latents[:, :, :1].clone()
|
880 |
+
ref_latents,_ = self.ref_in(ref_latents)
|
881 |
+
ref_latents_first,_ = self.img_in(ref_latents_first)
|
882 |
+
elif self.custom:
|
883 |
+
if ref_latents != None:
|
884 |
+
ref_latents, _ = self.img_in(ref_latents)
|
885 |
+
if bg_latents is not None and self.video_condition:
|
886 |
+
bg_latents, _ = self.bg_in(bg_latents)
|
887 |
+
img += self.bg_proj(bg_latents)
|
888 |
+
|
889 |
+
if self.text_projection == "linear":
|
890 |
+
txt = self.txt_in(txt)
|
891 |
+
elif self.text_projection == "single_refiner":
|
892 |
+
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
|
893 |
+
else:
|
894 |
+
raise NotImplementedError(
|
895 |
+
f"Unsupported text_projection: {self.text_projection}"
|
896 |
+
)
|
897 |
+
|
898 |
+
if self.avatar:
|
899 |
+
img += self.before_proj(ref_latents)
|
900 |
+
ref_length = ref_latents_first.shape[-2] # [b s c]
|
901 |
+
img = torch.cat([ref_latents_first, img], dim=-2) # t c
|
902 |
+
img_len = img.shape[1]
|
903 |
+
mask_len = img_len - ref_length
|
904 |
+
if face_mask.shape[2] == 1:
|
905 |
+
face_mask = face_mask.repeat(1,1,ot,1,1) # repeat if number of mask frame is 1
|
906 |
+
face_mask = torch.nn.functional.interpolate(face_mask, size=[ot, shape_mask[-2], shape_mask[-1]], mode="nearest")
|
907 |
+
# face_mask = face_mask.view(-1,mask_len,1).repeat(1,1,img.shape[-1]).type_as(img)
|
908 |
+
face_mask = face_mask.view(-1,mask_len,1).type_as(img)
|
909 |
+
elif ref_latents == None:
|
910 |
+
ref_length = None
|
911 |
+
else:
|
912 |
+
ref_length = ref_latents.shape[-2]
|
913 |
+
img = torch.cat([ref_latents, img], dim=-2) # t c
|
914 |
+
txt_seq_len = txt.shape[1]
|
915 |
+
img_seq_len = img.shape[1]
|
916 |
+
|
917 |
+
text_len = text_mask.sum(1)
|
918 |
+
total_len = text_len + img_seq_len
|
919 |
+
seqlens_q = seqlens_kv = total_len
|
920 |
+
attn_mask = None
|
921 |
+
|
922 |
+
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
923 |
+
|
924 |
+
|
925 |
+
if self.enable_teacache:
|
926 |
+
if x_id == 0:
|
927 |
+
self.should_calc = True
|
928 |
+
inp = img[0:1]
|
929 |
+
vec_ = vec[0:1]
|
930 |
+
( img_mod1_shift, img_mod1_scale, _ , _ , _ , _ , ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
|
931 |
+
normed_inp = self.double_blocks[0].img_norm1(inp)
|
932 |
+
normed_inp = normed_inp.to(torch.bfloat16)
|
933 |
+
modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
|
934 |
+
del normed_inp, img_mod1_shift, img_mod1_scale
|
935 |
+
if step_no <= self.teacache_start_step or step_no == self.num_steps-1:
|
936 |
+
self.accumulated_rel_l1_distance = 0
|
937 |
+
else:
|
938 |
+
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
|
939 |
+
rescale_func = np.poly1d(coefficients)
|
940 |
+
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
941 |
+
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
942 |
+
self.should_calc = False
|
943 |
+
self.teacache_skipped_steps += 1
|
944 |
+
else:
|
945 |
+
self.accumulated_rel_l1_distance = 0
|
946 |
+
self.previous_modulated_input = modulated_inp
|
947 |
+
else:
|
948 |
+
self.should_calc = True
|
949 |
+
|
950 |
+
if not self.should_calc:
|
951 |
+
img += self.previous_residual[x_id]
|
952 |
+
else:
|
953 |
+
if self.enable_teacache:
|
954 |
+
self.previous_residual[x_id] = None
|
955 |
+
ori_img = img[0:1].clone()
|
956 |
+
# --------------------- Pass through DiT blocks ------------------------
|
957 |
+
for layer_num, block in enumerate(self.double_blocks):
|
958 |
+
for i in range(len(img)):
|
959 |
+
if callback != None:
|
960 |
+
callback(-1, None, False, True)
|
961 |
+
if pipeline._interrupt:
|
962 |
+
return None
|
963 |
+
double_block_args = [
|
964 |
+
img[i:i+1],
|
965 |
+
txt[i:i+1],
|
966 |
+
vec[i:i+1],
|
967 |
+
attn_mask,
|
968 |
+
seqlens_q[i:i+1],
|
969 |
+
seqlens_kv[i:i+1],
|
970 |
+
freqs_cis,
|
971 |
+
self.i2v_condition_type,
|
972 |
+
token_replace_vec,
|
973 |
+
frist_frame_token_num,
|
974 |
+
]
|
975 |
+
|
976 |
+
img[i], txt[i] = block(*double_block_args)
|
977 |
+
double_block_args = None
|
978 |
+
# insert audio feature to img
|
979 |
+
if audio_prompts != None:
|
980 |
+
audio_adapter = getattr(self.double_blocks[layer_num], "audio_adapter", None)
|
981 |
+
if audio_adapter != None:
|
982 |
+
real_img = img[i:i+1,ref_length:].view(1, ot, -1, 3072)
|
983 |
+
real_img = audio_adapter(audio_feature_all_insert[i:i+1], real_img).view(1, -1, 3072)
|
984 |
+
if face_mask != None:
|
985 |
+
real_img *= face_mask[i:i+1]
|
986 |
+
if audio_strength != None and audio_strength != 1:
|
987 |
+
real_img *= audio_strength
|
988 |
+
img[i:i+1, ref_length:] += real_img
|
989 |
+
real_img = None
|
990 |
+
|
991 |
+
|
992 |
+
for _, block in enumerate(self.single_blocks):
|
993 |
+
for i in range(len(img)):
|
994 |
+
if callback != None:
|
995 |
+
callback(-1, None, False, True)
|
996 |
+
if pipeline._interrupt:
|
997 |
+
return None
|
998 |
+
single_block_args = [
|
999 |
+
# x,
|
1000 |
+
img[i:i+1],
|
1001 |
+
txt[i:i+1],
|
1002 |
+
vec[i:i+1],
|
1003 |
+
txt_seq_len,
|
1004 |
+
attn_mask,
|
1005 |
+
seqlens_q[i:i+1],
|
1006 |
+
seqlens_kv[i:i+1],
|
1007 |
+
(freqs_cos, freqs_sin),
|
1008 |
+
self.i2v_condition_type,
|
1009 |
+
token_replace_vec,
|
1010 |
+
frist_frame_token_num,
|
1011 |
+
]
|
1012 |
+
|
1013 |
+
img[i], txt[i] = block(*single_block_args)
|
1014 |
+
single_block_args = None
|
1015 |
+
|
1016 |
+
# img = x[:, :img_seq_len, ...]
|
1017 |
+
if self.enable_teacache:
|
1018 |
+
if len(img) > 1:
|
1019 |
+
self.previous_residual[0] = torch.empty_like(img)
|
1020 |
+
for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
|
1021 |
+
if i < len(img) - 1:
|
1022 |
+
residual[...] = torch.sub(x, ori_img)
|
1023 |
+
else:
|
1024 |
+
residual[...] = ori_img
|
1025 |
+
torch.sub(x, ori_img, out=residual)
|
1026 |
+
x = None
|
1027 |
+
else:
|
1028 |
+
self.previous_residual[x_id] = ori_img
|
1029 |
+
torch.sub(img, ori_img, out=self.previous_residual[x_id])
|
1030 |
+
|
1031 |
+
|
1032 |
+
if ref_length != None:
|
1033 |
+
img = img[:, ref_length:]
|
1034 |
+
# ---------------------------- Final layer ------------------------------
|
1035 |
+
out_dtype = self.final_layer.linear.weight.dtype
|
1036 |
+
vec = vec.to(out_dtype)
|
1037 |
+
img_list = []
|
1038 |
+
for img_chunk, vec_chunk in zip(img,vec):
|
1039 |
+
img_list.append( self.final_layer(img_chunk.to(out_dtype).unsqueeze(0), vec_chunk.unsqueeze(0))) # (N, T, patch_size ** 2 * out_channels)
|
1040 |
+
img = torch.cat(img_list)
|
1041 |
+
img_list = None
|
1042 |
+
|
1043 |
+
# img = self.unpatchify(img, tt, th, tw)
|
1044 |
+
img = self.unpatchify(img, tt, th, tw)
|
1045 |
+
|
1046 |
+
return img
|
1047 |
+
|
1048 |
+
def unpatchify(self, x, t, h, w):
|
1049 |
+
"""
|
1050 |
+
x: (N, T, patch_size**2 * C)
|
1051 |
+
imgs: (N, H, W, C)
|
1052 |
+
"""
|
1053 |
+
c = self.unpatchify_channels
|
1054 |
+
pt, ph, pw = self.patch_size
|
1055 |
+
assert t * h * w == x.shape[1]
|
1056 |
+
|
1057 |
+
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
|
1058 |
+
x = torch.einsum("nthwcopq->nctohpwq", x)
|
1059 |
+
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
1060 |
+
|
1061 |
+
return imgs
|
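A quick shape check of the unpatchify above (the sizes are examples): the final-layer tokens of shape (N, T*H*W, pt*ph*pw*C) fold back into a (N, C, T*pt, H*ph, W*pw) video latent.

    import torch

    c, (pt, ph, pw) = 16, (1, 2, 2)
    t, h, w = 8, 16, 16
    x = torch.randn(1, t * h * w, pt * ph * pw * c)        # final-layer output tokens

    x = x.reshape(1, t, h, w, c, pt, ph, pw)
    x = torch.einsum("nthwcopq->nctohpwq", x)
    video = x.reshape(1, c, t * pt, h * ph, w * pw)        # (1, 16, 8, 32, 32)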
1062 |
+
|
1063 |
+
def params_count(self):
|
1064 |
+
counts = {
|
1065 |
+
"double": sum(
|
1066 |
+
[
|
1067 |
+
sum(p.numel() for p in block.img_attn_qkv.parameters())
|
1068 |
+
+ sum(p.numel() for p in block.img_attn_proj.parameters())
|
1069 |
+
+ sum(p.numel() for p in block.img_mlp.parameters())
|
1070 |
+
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
|
1071 |
+
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
|
1072 |
+
+ sum(p.numel() for p in block.txt_mlp.parameters())
|
1073 |
+
for block in self.double_blocks
|
1074 |
+
]
|
1075 |
+
),
|
1076 |
+
"single": sum(
|
1077 |
+
[
|
1078 |
+
sum(p.numel() for p in block.linear1.parameters())
|
1079 |
+
+ sum(p.numel() for p in block.linear2.parameters())
|
1080 |
+
for block in self.single_blocks
|
1081 |
+
]
|
1082 |
+
),
|
1083 |
+
"total": sum(p.numel() for p in self.parameters()),
|
1084 |
+
}
|
1085 |
+
counts["attn+mlp"] = counts["double"] + counts["single"]
|
1086 |
+
return counts
|
1087 |
+
|
1088 |
+
|
1089 |
+
#################################################################################
|
1090 |
+
# HunyuanVideo Configs #
|
1091 |
+
#################################################################################
|
1092 |
+
|
1093 |
+
HUNYUAN_VIDEO_CONFIG = {
|
1094 |
+
"HYVideo-T/2": {
|
1095 |
+
"mm_double_blocks_depth": 20,
|
1096 |
+
"mm_single_blocks_depth": 40,
|
1097 |
+
"rope_dim_list": [16, 56, 56],
|
1098 |
+
"hidden_size": 3072,
|
1099 |
+
"heads_num": 24,
|
1100 |
+
"mlp_width_ratio": 4,
|
1101 |
+
},
|
1102 |
+
"HYVideo-T/2-cfgdistill": {
|
1103 |
+
"mm_double_blocks_depth": 20,
|
1104 |
+
"mm_single_blocks_depth": 40,
|
1105 |
+
"rope_dim_list": [16, 56, 56],
|
1106 |
+
"hidden_size": 3072,
|
1107 |
+
"heads_num": 24,
|
1108 |
+
"mlp_width_ratio": 4,
|
1109 |
+
"guidance_embed": True,
|
1110 |
+
},
|
1111 |
+
"HYVideo-S/2": {
|
1112 |
+
"mm_double_blocks_depth": 6,
|
1113 |
+
"mm_single_blocks_depth": 12,
|
1114 |
+
"rope_dim_list": [12, 42, 42],
|
1115 |
+
"hidden_size": 480,
|
1116 |
+
"heads_num": 5,
|
1117 |
+
"mlp_width_ratio": 4,
|
1118 |
+
},
|
1119 |
+
'HYVideo-T/2-custom': { # 9.0B / 12.5B
|
1120 |
+
"mm_double_blocks_depth": 20,
|
1121 |
+
"mm_single_blocks_depth": 40,
|
1122 |
+
"rope_dim_list": [16, 56, 56],
|
1123 |
+
"hidden_size": 3072,
|
1124 |
+
"heads_num": 24,
|
1125 |
+
"mlp_width_ratio": 4,
|
1126 |
+
'custom' : True
|
1127 |
+
},
|
1128 |
+
'HYVideo-T/2-custom-audio': { # 9.0B / 12.5B
|
1129 |
+
"mm_double_blocks_depth": 20,
|
1130 |
+
"mm_single_blocks_depth": 40,
|
1131 |
+
"rope_dim_list": [16, 56, 56],
|
1132 |
+
"hidden_size": 3072,
|
1133 |
+
"heads_num": 24,
|
1134 |
+
"mlp_width_ratio": 4,
|
1135 |
+
'custom' : True,
|
1136 |
+
'audio_condition' : True,
|
1137 |
+
},
|
1138 |
+
'HYVideo-T/2-custom-edit': { # 9.0B / 12.5B
|
1139 |
+
"mm_double_blocks_depth": 20,
|
1140 |
+
"mm_single_blocks_depth": 40,
|
1141 |
+
"rope_dim_list": [16, 56, 56],
|
1142 |
+
"hidden_size": 3072,
|
1143 |
+
"heads_num": 24,
|
1144 |
+
"mlp_width_ratio": 4,
|
1145 |
+
'custom' : True,
|
1146 |
+
'video_condition' : True,
|
1147 |
+
},
|
1148 |
+
'HYVideo-T/2-avatar': { # 9.0B / 12.5B
|
1149 |
+
'mm_double_blocks_depth': 20,
|
1150 |
+
'mm_single_blocks_depth': 40,
|
1151 |
+
'rope_dim_list': [16, 56, 56],
|
1152 |
+
'hidden_size': 3072,
|
1153 |
+
'heads_num': 24,
|
1154 |
+
'mlp_width_ratio': 4,
|
1155 |
+
'avatar': True,
|
1156 |
+
'audio_condition' : True,
|
1157 |
+
},
|
1158 |
+
|
1159 |
+
}
|
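A sketch of how one of these named configs could be expanded into constructor keyword arguments; the extra values below are assumptions for illustration, and the actual instantiation is left commented out because the full transformer is a multi-billion-parameter model.

    # Hypothetical usage sketch; HUNYUAN_VIDEO_CONFIG and HYVideoDiffusionTransformer
    # are the names defined above in this module.
    config = dict(HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"])
    kwargs = {
        "i2v_condition_type": None,   # assumed: plain t2v, no token-replace conditioning
        "in_channels": 16,            # assumed VAE latent channel count
        **config,
    }
    # model = HYVideoDiffusionTransformer(**kwargs)   # would allocate the full transformer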
hyvideo/modules/modulate_layers.py
ADDED
@@ -0,0 +1,136 @@
|
from typing import Callable

import torch
import torch.nn as nn
import math

class ModulateDiT(nn.Module):
    """Modulation layer for DiT."""
    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(
            hidden_size, factor * hidden_size, bias=True, **factory_kwargs
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, condition_type=None, token_replace_vec=None) -> torch.Tensor:
        x_out = self.linear(self.act(x))

        if condition_type == "token_replace":
            x_token_replace_out = self.linear(self.act(token_replace_vec))
            return x_out, x_token_replace_out
        else:
            return x_out

def modulate(x, shift=None, scale=None):
    """Modulate x by shift and scale.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): shift tensor. Defaults to None.
        scale (torch.Tensor, optional): scale tensor. Defaults to None.

    Returns:
        torch.Tensor: the modulated output tensor.
    """
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale.unsqueeze(1))
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def modulate_(x, shift=None, scale=None):
    # In-place variant of modulate(); mutates x to avoid an extra allocation.
    if scale is None and shift is None:
        return x
    elif shift is None:
        scale = scale + 1
        scale = scale.unsqueeze(1)
        return x.mul_(scale)
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        scale = scale + 1
        scale = scale.unsqueeze(1)
        # return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        torch.addcmul(shift.unsqueeze(1), x, scale, out=x)
        return x

# NOTE: this second definition of modulate() supersedes the plain one above and
# adds "token_replace" handling for the first-frame tokens.
def modulate(x, shift=None, scale=None, condition_type=None,
             tr_shift=None, tr_scale=None,
             frist_frame_token_num=None):
    if condition_type == "token_replace":
        x_zero = x[:, :frist_frame_token_num] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
        x_orig = x[:, frist_frame_token_num:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = torch.concat((x_zero, x_orig), dim=1)
        return x
    else:
        if scale is None and shift is None:
            return x
        elif shift is None:
            return x * (1 + scale.unsqueeze(1))
        elif scale is None:
            return x + shift.unsqueeze(1)
        else:
            return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, frist_frame_token_num=None):
    """Apply a gating tensor to x, optionally squashed through tanh.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to pass the gate through tanh. Defaults to False.

    Returns:
        torch.Tensor: the output tensor after applying the gate.
    """
    if condition_type == "token_replace":
        if gate is None:
            return x
        if tanh:
            x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1).tanh()
            x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1).tanh()
            x = torch.concat((x_zero, x_orig), dim=1)
            return x
        else:
            x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1)
            x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1)
            x = torch.concat((x_zero, x_orig), dim=1)
            return x
    else:
        if gate is None:
            return x
        if tanh:
            return x * gate.unsqueeze(1).tanh()
        else:
            return x * gate.unsqueeze(1)

def apply_gate_and_accumulate_(accumulator, x, gate=None, tanh=False):
    if gate is None:
        return accumulator
    if tanh:
        return accumulator.addcmul_(x, gate.unsqueeze(1).tanh())
    else:
        return accumulator.addcmul_(x, gate.unsqueeze(1))

def ckpt_wrapper(module):
    def ckpt_forward(*inputs):
        outputs = module(*inputs)
        return outputs

    return ckpt_forward
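For readers skimming the diff, the following shape sketch (illustrative only, not part of the repository) shows what modulate and apply_gate compute: shift, scale, and gate are per-sample vectors of shape (B, C) that are broadcast over the token dimension of an activation of shape (B, L, C) via unsqueeze(1).

    import torch

    B, L, C = 2, 16, 8
    x = torch.randn(B, L, C)
    shift, scale, gate = (torch.randn(B, C) for _ in range(3))

    y = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # what modulate(x, shift, scale) returns
    out = y * gate.unsqueeze(1).tanh()                     # what apply_gate(y, gate, tanh=True) returns
    assert out.shape == (B, L, C)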
hyvideo/modules/norm_layers.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output

    def apply_(self, x):
        y = x.pow(2).mean(-1, keepdim=True)
        y.add_(self.eps)
        y.rsqrt_()
        x.mul_(y)
        del y
        if hasattr(self, "weight"):
            x.mul_(self.weight)
        return x


def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
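As a reference for the layer above (illustrative, not part of the diff), RMSNorm rescales each token by the reciprocal root-mean-square of its features; with elementwise_affine=True the result is additionally multiplied by the learned weight, which is initialized to ones.

    import torch

    x = torch.randn(4, 8)
    eps = 1e-6
    rms_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    # With the default all-ones weight, RMSNorm(dim=8)(x) matches rms_normed
    # up to the float32 round-trip performed inside forward().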
hyvideo/modules/original models.py
ADDED
@@ -0,0 +1,760 @@
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner


class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for
    text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
                                        (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.img_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.img_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.txt_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.txt_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: tuple = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(
            img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
        )
        img_qkv = self.img_attn_qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        # Apply QK-Norm if needed
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(
            txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
        )
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        # Apply QK-Norm if needed.
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Run actual attention.
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)
        assert (
            cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            attn = attention(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=img_k.shape[0],
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q.shape[1],
                img_kv_len=img_k.shape[1],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv
            )
        # attention computation end

        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]

        # Calculate the img blocks.
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        img = img + apply_gate(
            self.img_mlp(
                modulate(
                    self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
                )
            ),
            gate=img_mod2_gate,
        )

        # Calculate the txt blocks.
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt = txt + apply_gate(
            self.txt_mlp(
                modulate(
                    self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                )
            ),
            gate=txt_mod2_gate,
        )

        return img, txt


class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    Also refer to (SD3): https://arxiv.org/abs/2403.03206
                (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim ** -0.5

        # qkv and mlp_in
        self.linear1 = nn.Linear(
            hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
        )
        # proj and mlp_out
        self.linear2 = nn.Linear(
            hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
        )

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
            img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk
            q = torch.cat((img_q, txt_q), dim=1)
            k = torch.cat((img_k, txt_k), dim=1)

        # Compute attention.
        assert (
            cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            attn = attention(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=x.shape[0],
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q.shape[1],
                img_kv_len=img_k.shape[1],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv
            )
        # attention computation end

        # Compute activation in mlp stream, cat again and run second linear layer.
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + apply_gate(output, gate=mod_gate)


class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
    """
    HunyuanVideo Transformer backbone

    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.

    Reference:
    [1] Flux.1: https://github.com/black-forest-labs/flux
    [2] MMDiT: http://arxiv.org/abs/2403.03206

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    patch_size: list
        The size of the patch.
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    heads_num: int
        The number of attention heads.
    mlp_width_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    mlp_act_type: str
        The activation function of the MLP in the transformer block.
    depth_double_blocks: int
        The number of transformer blocks in the double blocks.
    depth_single_blocks: int
        The number of transformer blocks in the single blocks.
    rope_dim_list: list
        The dimension of the rotary embedding for t, h, w.
    qkv_bias: bool
        Whether to use bias in the qkv linear layer.
    qk_norm: bool
        Whether to use qk norm.
    qk_norm_type: str
        The type of qk norm.
    guidance_embed: bool
        Whether to use guidance embedding for distillation.
    text_projection: str
        The type of the text projection, default is single_refiner.
    use_attention_mask: bool
        Whether to use attention mask for text encoder.
    dtype: torch.dtype
        The dtype of the model.
    device: torch.device
        The device of the model.
    """

    @register_to_config
    def __init__(
        self,
        args: Any,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        mm_double_blocks_depth: int = 20,
        mm_single_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        qkv_bias: bool = True,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        guidance_embed: bool = False,  # For modulation.
        text_projection: str = "single_refiner",
        use_attention_mask: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embed = guidance_embed
        self.rope_dim_list = rope_dim_list

        # Text projection. Default to linear projection.
        # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        self.text_states_dim = args.text_states_dim
        self.text_states_dim_2 = args.text_states_dim_2

        if hidden_size % heads_num != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
            )
        pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(
                f"Got {rope_dim_list} but expected positional dim {pe_dim}"
            )
        self.hidden_size = hidden_size
        self.heads_num = heads_num

        # image projection
        self.img_in = PatchEmbed(
            self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
        )

        # text projection
        if self.text_projection == "linear":
            self.txt_in = TextProjection(
                self.text_states_dim,
                self.hidden_size,
                get_activation_layer("silu"),
                **factory_kwargs,
            )
        elif self.text_projection == "single_refiner":
            self.txt_in = SingleTokenRefiner(
                self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
            )
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )

        # time modulation
        self.time_in = TimestepEmbedder(
            self.hidden_size, get_activation_layer("silu"), **factory_kwargs
        )

        # text modulation
        self.vector_in = MLPEmbedder(
            self.text_states_dim_2, self.hidden_size, **factory_kwargs
        )

        # guidance modulation
        self.guidance_in = (
            TimestepEmbedder(
                self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
            if guidance_embed
            else None
        )

        # double blocks
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # single blocks
        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(mm_single_blocks_depth)
            ]
        )

        self.final_layer = FinalLayer(
            self.hidden_size,
            self.patch_size,
            self.out_channels,
            get_activation_layer("silu"),
            **factory_kwargs,
        )

    def enable_deterministic(self):
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        out = {}
        img = x
        txt = text_states
        _, _, ot, oh, ow = x.shape
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Prepare modulation vectors.
        vec = self.time_in(t)

        # text modulation
        vec = vec + self.vector_in(text_states_2)

        # guidance modulation
        if self.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )

            # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
            vec = vec + self.guidance_in(guidance)

        # Embed image and text.
        img = self.img_in(img)
        if self.text_projection == "linear":
            txt = self.txt_in(txt)
        elif self.text_projection == "single_refiner":
            txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        # Compute cu_seqlens and max_seqlen for flash attention
        cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
        cu_seqlens_kv = cu_seqlens_q
        max_seqlen_q = img_seq_len + txt_seq_len
        max_seqlen_kv = max_seqlen_q

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        # --------------------- Pass through DiT blocks ------------------------
        for _, block in enumerate(self.double_blocks):
            double_block_args = [
                img,
                txt,
                vec,
                cu_seqlens_q,
                cu_seqlens_kv,
                max_seqlen_q,
                max_seqlen_kv,
                freqs_cis,
            ]

            img, txt = block(*double_block_args)

        # Merge txt and img to pass through single stream blocks.
        x = torch.cat((img, txt), 1)
        if len(self.single_blocks) > 0:
            for _, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    txt_seq_len,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    (freqs_cos, freqs_sin),
                ]

                x = block(*single_block_args)

        img = x[:, :img_seq_len, ...]

        # ---------------------------- Final layer ------------------------------
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

        img = self.unpatchify(img, tt, th, tw)
        if return_dict:
            out["x"] = img
            return out
        return img

    def unpatchify(self, x, t, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.unpatchify_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

        return imgs

    def params_count(self):
        counts = {
            "double": sum(
                [
                    sum(p.numel() for p in block.img_attn_qkv.parameters())
                    + sum(p.numel() for p in block.img_attn_proj.parameters())
                    + sum(p.numel() for p in block.img_mlp.parameters())
                    + sum(p.numel() for p in block.txt_attn_qkv.parameters())
                    + sum(p.numel() for p in block.txt_attn_proj.parameters())
                    + sum(p.numel() for p in block.txt_mlp.parameters())
                    for block in self.double_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }
        counts["attn+mlp"] = counts["double"] + counts["single"]
        return counts


#################################################################################
#                             HunyuanVideo Configs                              #
#################################################################################

HUNYUAN_VIDEO_CONFIG = {
    "HYVideo-T/2": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
    },
    "HYVideo-T/2-cfgdistill": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "guidance_embed": True,
    },
}
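The config entries at the end of this file map directly onto the constructor's keyword arguments. Below is a hedged instantiation sketch (not part of the diff) showing how the backbone would be built from the config table in the current models.py; the import path and the text_states_dim values on args are assumptions and must match the text encoders actually used.

    # Hedged sketch: instantiate the backbone from a named config (assumed import path).
    import argparse
    import torch
    from hyvideo.modules.models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG

    args = argparse.Namespace(text_states_dim=4096, text_states_dim_2=768)  # placeholder dims
    model = HYVideoDiffusionTransformer(
        args,
        **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2"],
        dtype=torch.bfloat16,
    )
    print(model.params_count())  # {'double': ..., 'single': ..., 'total': ..., 'attn+mlp': ...}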
hyvideo/modules/placement.py
ADDED
@@ -0,0 +1,389 @@
import torch
import triton
import triton.language as tl

def hunyuan_token_reorder_to_token_major(tensor, fix_len, reorder_len, reorder_num_frame, frame_size):
    """Reorder it from frame major to token major!"""
    assert reorder_len == reorder_num_frame * frame_size
    assert tensor.shape[2] == fix_len + reorder_len

    tensor[:, :, :-fix_len, :] = tensor[:, :, :-fix_len:, :].reshape(tensor.shape[0], tensor.shape[1], reorder_num_frame, frame_size, tensor.shape[3]) \
        .transpose(2, 3).reshape(tensor.shape[0], tensor.shape[1], reorder_len, tensor.shape[3])
    return tensor

def hunyuan_token_reorder_to_frame_major(tensor, fix_len, reorder_len, reorder_num_frame, frame_size):
    """Reorder it from token major to frame major!"""
    assert reorder_len == reorder_num_frame * frame_size
    assert tensor.shape[2] == fix_len + reorder_len

    tensor[:, :, :-fix_len:, :] = tensor[:, :, :-fix_len:, :].reshape(tensor.shape[0], tensor.shape[1], frame_size, reorder_num_frame, tensor.shape[3]) \
        .transpose(2, 3).reshape(tensor.shape[0], tensor.shape[1], reorder_len, tensor.shape[3])
    return tensor


@triton.jit
def hunyuan_sparse_head_placement_kernel(
    query_ptr, key_ptr, value_ptr,  # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
    query_out_ptr, key_out_ptr, value_out_ptr,  # [cfg, num_heads, seq_len, head_dim]
    best_mask_idx_ptr,  # [cfg, num_heads]
    query_stride_b, query_stride_h, query_stride_s, query_stride_d,
    mask_idx_stride_b, mask_idx_stride_h,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    context_length: tl.constexpr,
    num_frame: tl.constexpr,
    frame_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    # Copy query, key, value to output
    # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
    cfg = tl.program_id(0)
    head = tl.program_id(1)
    block_id = tl.program_id(2)

    start_id = block_id * BLOCK_SIZE
    end_id = start_id + BLOCK_SIZE
    end_id = tl.where(end_id > seq_len, seq_len, end_id)

    # Load best mask idx (0 is spatial, 1 is temporal)
    is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)

    offset_token = tl.arange(0, BLOCK_SIZE) + start_id
    offset_mask = offset_token < seq_len
    offset_d = tl.arange(0, head_dim)

    if is_temporal:
        frame_id = offset_token // frame_size
        patch_id = offset_token - frame_id * frame_size
        offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)

        offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:,None] * query_stride_s) + offset_d[None,:] * query_stride_d
        offset_query = query_ptr + offset_load
        offset_key = key_ptr + offset_load
        offset_value = value_ptr + offset_load

        offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:,None] * query_stride_s) + offset_d[None,:] * query_stride_d
        offset_query_out = query_out_ptr + offset_store
        offset_key_out = key_out_ptr + offset_store
        offset_value_out = value_out_ptr + offset_store

        # Maybe tune the pipeline here
        query = tl.load(offset_query, mask=offset_mask[:,None])
        tl.store(offset_query_out, query, mask=offset_mask[:,None])
        key = tl.load(offset_key, mask=offset_mask[:,None])
        tl.store(offset_key_out, key, mask=offset_mask[:,None])
        value = tl.load(offset_value, mask=offset_mask[:,None])
        tl.store(offset_value_out, value, mask=offset_mask[:,None])

    else:
        offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:,None] * query_stride_s) + offset_d[None,:] * query_stride_d
        offset_query = query_ptr + offset_load
        offset_key = key_ptr + offset_load
        offset_value = value_ptr + offset_load

        offset_store = offset_load
        offset_query_out = query_out_ptr + offset_store
        offset_key_out = key_out_ptr + offset_store
        offset_value_out = value_out_ptr + offset_store

        # Maybe tune the pipeline here
        query = tl.load(offset_query, mask=offset_mask[:,None])
        tl.store(offset_query_out, query, mask=offset_mask[:,None])
        key = tl.load(offset_key, mask=offset_mask[:,None])
        tl.store(offset_key_out, key, mask=offset_mask[:,None])
        value = tl.load(offset_value, mask=offset_mask[:,None])
        tl.store(offset_value_out, value, mask=offset_mask[:,None])


def hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
    cfg, num_heads, seq_len, head_dim = query.shape
    BLOCK_SIZE = 128
    assert seq_len == context_length + num_frame * frame_size

    grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

    hunyuan_sparse_head_placement_kernel[grid](
        query, key, value,
        query_out, key_out, value_out,
        best_mask_idx,
        query.stride(0), query.stride(1), query.stride(2), query.stride(3),
        best_mask_idx.stride(0), best_mask_idx.stride(1),
        seq_len, head_dim, context_length, num_frame, frame_size,
        BLOCK_SIZE
    )


def ref_hunyuan_sparse_head_placement(query, key, value, best_mask_idx, context_length, num_frame, frame_size):
    cfg, num_heads, seq_len, head_dim = query.shape
    assert seq_len == context_length + num_frame * frame_size

    query_out = query.clone()
    key_out = key.clone()
    value_out = value.clone()

    # Spatial
    query_out[best_mask_idx == 0], key_out[best_mask_idx == 0], value_out[best_mask_idx == 0] = \
        query[best_mask_idx == 0], key[best_mask_idx == 0], value[best_mask_idx == 0]

    # Temporal
    query_out[best_mask_idx == 1], key_out[best_mask_idx == 1], value_out[best_mask_idx == 1] = \
        hunyuan_token_reorder_to_token_major(query[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0), \
        hunyuan_token_reorder_to_token_major(key[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0), \
        hunyuan_token_reorder_to_token_major(value[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0)

    return query_out, key_out, value_out


def test_hunyuan_sparse_head_placement():

    context_length = 226
    num_frame = 11
    frame_size = 4080

    cfg = 2
    num_heads = 48

    seq_len = context_length + num_frame * frame_size
    head_dim = 64

    dtype = torch.bfloat16
    device = torch.device("cuda")

    query = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    key = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    value = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)

    best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)

    query_out = torch.empty_like(query)
    key_out = torch.empty_like(key)
    value_out = torch.empty_like(value)

    hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
    ref_query_out, ref_key_out, ref_value_out = ref_hunyuan_sparse_head_placement(query, key, value, best_mask_idx, context_length, num_frame, frame_size)

    torch.testing.assert_close(query_out, ref_query_out)
    torch.testing.assert_close(key_out, ref_key_out)
    torch.testing.assert_close(value_out, ref_value_out)


def benchmark_hunyuan_sparse_head_placement():
    import time

    context_length = 226
    num_frame = 11
    frame_size = 4080

    cfg = 2
    num_heads = 48

    seq_len = context_length + num_frame * frame_size
    head_dim = 64

    dtype = torch.bfloat16
    device = torch.device("cuda")

    query = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    key = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    value = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)

    query_out = torch.empty_like(query)
    key_out = torch.empty_like(key)
    value_out = torch.empty_like(value)

    warmup = 10
    all_iter = 1000

    # warmup
    for _ in range(warmup):
        hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(all_iter):
        hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
    torch.cuda.synchronize()
    end = time.time()

    print(f"Triton Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
    print(f"Triton Total Bandwidth: {query.nelement() * query.element_size() * 3 * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(all_iter):
        ref_hunyuan_sparse_head_placement(query, key, value, best_mask_idx, context_length, num_frame, frame_size)
    torch.cuda.synchronize()
    end = time.time()

    print(f"Reference Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
    print(f"Reference Total Bandwidth: {query.nelement() * query.element_size() * 3 * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")


@triton.jit
def hunyuan_hidden_states_placement_kernel(
    hidden_states_ptr,  # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
    hidden_states_out_ptr,  # [cfg, num_heads, seq_len, head_dim]
    best_mask_idx_ptr,  # [cfg, num_heads]
    hidden_states_stride_b, hidden_states_stride_h, hidden_states_stride_s, hidden_states_stride_d,
    mask_idx_stride_b, mask_idx_stride_h,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    context_length: tl.constexpr,
    num_frame: tl.constexpr,
    frame_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    # Copy hidden_states to output
    # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
    cfg = tl.program_id(0)
    head = tl.program_id(1)
    block_id = tl.program_id(2)

    start_id = block_id * BLOCK_SIZE
    end_id = start_id + BLOCK_SIZE
    end_id = tl.where(end_id > seq_len, seq_len, end_id)

    # Load best mask idx (0 is spatial, 1 is temporal)
    is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)

    offset_token = tl.arange(0, BLOCK_SIZE) + start_id
    offset_mask = offset_token < seq_len
    offset_d = tl.arange(0, head_dim)

    if is_temporal:
        patch_id = offset_token // num_frame
        frame_id = offset_token - patch_id * num_frame
        offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)

        offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:,None] * hidden_states_stride_s) + offset_d[None,:] * hidden_states_stride_d
        offset_hidden_states = hidden_states_ptr + offset_load

        offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:,None] * hidden_states_stride_s) + offset_d[None,:] * hidden_states_stride_d
        offset_hidden_states_out = hidden_states_out_ptr + offset_store

        # Maybe tune the pipeline here
        hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:,None])
        tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:,None])
    else:
        offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:,None] * hidden_states_stride_s) + offset_d[None,:] * hidden_states_stride_d
        offset_hidden_states = hidden_states_ptr + offset_load

        offset_store = offset_load
        offset_hidden_states_out = hidden_states_out_ptr + offset_store

        # Maybe tune the pipeline here
        hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:,None])
        tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:,None])


def hunyuan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
    cfg, num_heads, seq_len, head_dim = hidden_states.shape
    BLOCK_SIZE = 128
    assert seq_len == context_length + num_frame * frame_size

    grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

    hunyuan_hidden_states_placement_kernel[grid](
        hidden_states,
        hidden_states_out,
        best_mask_idx,
        hidden_states.stride(0), hidden_states.stride(1), hidden_states.stride(2), hidden_states.stride(3),
        best_mask_idx.stride(0), best_mask_idx.stride(1),
        seq_len, head_dim, context_length, num_frame, frame_size,
        BLOCK_SIZE
    )

    return hidden_states_out

def ref_hunyuan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, context_length, num_frame, frame_size):
    cfg, num_heads, seq_len, head_dim = hidden_states.shape
    assert seq_len == context_length + num_frame * frame_size

    # Spatial
    output_hidden_states[best_mask_idx == 0] = hidden_states[best_mask_idx == 0]
    # Temporal
    output_hidden_states[best_mask_idx == 1] = hunyuan_token_reorder_to_frame_major(hidden_states[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0)

def test_hunyuan_hidden_states_placement():

    context_length = 226
    num_frame = 11
    frame_size = 4080

    cfg = 2
    num_heads = 48

    seq_len = context_length + num_frame * frame_size
    head_dim = 64

    dtype = torch.bfloat16
    device = torch.device("cuda")

    hidden_states = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)

    hidden_states_out1 = torch.empty_like(hidden_states)
    hidden_states_out2 = torch.empty_like(hidden_states)

    hunyuan_hidden_states_placement(hidden_states, hidden_states_out1, best_mask_idx, context_length, num_frame, frame_size)
    ref_hunyuan_hidden_states_placement(hidden_states, hidden_states_out2, best_mask_idx, context_length, num_frame, frame_size)

    torch.testing.assert_close(hidden_states_out1, hidden_states_out2)

def benchmark_hunyuan_hidden_states_placement():
    import time

    context_length = 226
    num_frame = 11
    frame_size = 4080

    cfg = 2
    num_heads = 48

    seq_len = context_length + num_frame * frame_size
    head_dim = 64

    dtype = torch.bfloat16
    device = torch.device("cuda")

    hidden_states = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)

    hidden_states_out = torch.empty_like(hidden_states)

    warmup = 10
    all_iter = 1000

    # warmup
    for _ in range(warmup):
        hunyuan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size)

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(all_iter):
        hunyuan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size)
    torch.cuda.synchronize()
    end = time.time()

    print(f"Triton Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
    print(f"Triton Total Bandwidth: {hidden_states.nelement() * hidden_states.element_size() * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(all_iter):
        ref_hunyuan_hidden_states_placement(hidden_states, hidden_states.clone(), best_mask_idx, context_length, num_frame, frame_size)
    torch.cuda.synchronize()
    end = time.time()

    print(f"Reference Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
    print(f"Reference Total Bandwidth: {hidden_states.nelement() * hidden_states.element_size() * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")


if __name__ == "__main__":
    test_hunyuan_sparse_head_placement()
    benchmark_hunyuan_sparse_head_placement()
    test_hunyuan_hidden_states_placement()
    benchmark_hunyuan_hidden_states_placement()
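To make the reordering concrete, here is a small CPU-only illustration (not part of the diff, assuming the two pure-PyTorch reorder helpers above are in scope) of the frame-major to token-major permutation the Triton kernels apply to the video tokens; the trailing fix_len text tokens are left in place, and the two helpers are exact inverses of each other on the video span.

    import torch

    fix_len, num_frame, frame_size = 2, 3, 4
    reorder_len = num_frame * frame_size
    x = torch.arange(fix_len + reorder_len, dtype=torch.float32).view(1, 1, -1, 1)

    y = hunyuan_token_reorder_to_token_major(x.clone(), fix_len, reorder_len, num_frame, frame_size)
    z = hunyuan_token_reorder_to_frame_major(y.clone(), fix_len, reorder_len, num_frame, frame_size)
    assert torch.equal(z, x)  # the round trip restores the original frame-major layout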