lllyasviel committed · Commit 06fccba · 0 Parent(s)
.gitignore ADDED
@@ -0,0 +1,168 @@
1
+ hf_token.txt
2
+ hf_download/
3
+ results/
4
+ *.csv
5
+ *.onnx
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ .idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,111 @@
1
+ # Paints-Undo
2
+
3
+ PaintsUndo: A Base Model of Drawing Behaviors in Digital Paintings
4
+
5
+ Paints-Undo is a project aimed at providing base models of human drawing behaviors, in the hope that future AI models can better align with the real needs of human artists.
6
+
7
+ The name "Paints-Undo" comes from the fact that the model's outputs look like pressing the "undo" button (usually Ctrl+Z) many times in digital painting software.
8
+
9
+ Paints-Undo presents a family of models that take an image as input and then output the drawing sequence of that image. The model displays all kinds of human behaviors, including but not limited to sketching, inking, coloring, shading, transforming, left-right flipping, color curve tuning, changing the visibility of layers, and even changing the overall idea during the drawing process.
10
+
11
+ **This page does not contain any examples. All examples are on the Git page below:**
12
+
13
+ [>>> Click Here to See the Example Page <<<](https://lllyasviel.github.io/pages/paints_undo/)
14
+
15
+ # Get Started
16
+
17
+ You can deploy PaintsUndo locally via:
18
+
19
+ git clone https://github.com/lllyasviel/Paints-UNDO.git
20
+ cd Paints-UNDO
21
+ conda create -n paints_undo python=3.10
22
+ conda activate paints_undo
23
+ pip install xformers
24
+ pip install -r requirements.txt
25
+ python gradio_app.py
26
+
27
+ (If you do not know how to use these commands, you can paste them into ChatGPT and ask it to explain and give more detailed instructions.)
28
+
29
+ Inference has been tested with 24GB VRAM on an Nvidia 4090 and a 3090 Ti. It may also work with 16GB VRAM, but it does not work with 8GB. My estimate is that, under extreme optimization (including weight offloading and sliced attention), the theoretical minimum VRAM requirement is about 10~12.5 GB.
30
+
31
+ You can expect to process one image in about 5 to 10 minutes, depending on your settings. A typical result is a 25-second video at 4 FPS, with a resolution of 320x512, 512x320, 384x448, or 448x384.
32
+
33
+ Because the processing time is, in most cases, significantly longer than the typical task/quota of a HuggingFace Space, I personally do not recommend deploying this to HuggingFace Space, to avoid placing an unnecessary burden on the HF servers.
34
+
35
+ If you do not have the required computation devices and still want an online solution, one option is to wait for us to release a Colab notebook (though I am not sure whether the Colab free tier will work).
36
+
37
+ # Model Notes
38
+
39
+ We currently release two models, `paints_undo_single_frame` and `paints_undo_multi_frame`. Let's call them the single-frame model and the multi-frame model.
40
+
41
+ The single-frame model takes one image and an `operation step` as inputs, and outputs a single image. We assume that an artwork can always be created with 1000 human operations (for example, one brush stroke is one operation), so the `operation step` is an integer from 0 to 999. Step 0 is the finished final artwork, and step 999 is the first brush stroke drawn on the pure white canvas. You can think of this model as an "undo" (Ctrl+Z) model: you input the final image and indicate how many times you want to press "Ctrl+Z", and the model gives you a "simulated" screenshot after those "Ctrl+Z"s are pressed. If your `operation step` is 100, it means you want to simulate pressing "Ctrl+Z" 100 times on this image to get its appearance after the 100th "Ctrl+Z".
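+
+ For intuition, here is a tiny, purely illustrative helper (not part of this repo) that turns a rough "how finished is the painting" fraction into an `operation step` under the convention above (0 = finished, 999 = first stroke):
+
+ ```python
+ # Illustrative only: map a completion fraction to an `operation step`,
+ # following the convention that 0 = finished artwork and 999 = first stroke.
+
+ def completion_to_operation_step(completion: float, total_steps: int = 1000) -> int:
+     completion = min(max(completion, 0.0), 1.0)           # clamp to [0, 1]
+     return round((1.0 - completion) * (total_steps - 1))  # 1.0 -> 0, 0.0 -> 999
+
+ print(completion_to_operation_step(1.0))  # 0   (finished artwork)
+ print(completion_to_operation_step(0.5))  # 500 (about half-way, ~500 "Ctrl+Z" presses)
+ print(completion_to_operation_step(0.0))  # 999 (first brush stroke on the white canvas)
+ ```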
42
+
43
+ The multi-frame model takes two images as inputs and outputs 16 intermediate frames between the two input images. The result is much more consistent than the single-frame model, but it is also much slower, less "creative", and limited to 16 frames.
44
+
45
+ In this repo, the default method is to use them together. We first run the single-frame model about 5-7 times to get 5-7 "keyframes", and then use the multi-frame model to "interpolate" those keyframes into a relatively long video.
46
+
47
+ In theory this system can be used in many ways and can even produce infinitely long videos, but in practice results are good when the final frame count is about 100-500; a rough frame-count sketch follows.
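+
+ As a quick sanity check on the numbers above, here is a minimal sketch of the frame-count arithmetic (an illustration only, not the repo's actual pipeline code; it assumes each adjacent keyframe pair is expanded into 16 frames by the multi-frame model):
+
+ ```python
+ # Rough frame-count estimate for the default "keyframes + interpolation" pipeline.
+ def estimate_total_frames(num_keyframes: int, frames_per_segment: int = 16) -> int:
+     # (num_keyframes - 1) adjacent pairs, each interpolated into 16 frames
+     return (num_keyframes - 1) * frames_per_segment
+
+ for k in (5, 6, 7):
+     total = estimate_total_frames(k)
+     print(f"{k} keyframes -> ~{total} frames -> ~{total / 4:.0f} s at 4 FPS")
+ # 7 keyframes -> ~96 frames -> ~24 s at 4 FPS, which matches the ~25-second videos above.
+ ```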
48
+
49
+ ### Model Architecture (paints_undo_single_frame)
50
+
51
+ The model is a modified SD1.5 architecture trained with a different beta schedule, a different CLIP skip, and the aforementioned `operation step` condition. To be specific, the model is trained with the following betas:
52
+
53
+ `betas = torch.linspace(0.00085, 0.020, 1000, dtype=torch.float64)`
54
+
55
+ For comparison, the original SD1.5 is trained with the following betas:
56
+
57
+ `betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float64) ** 2`
58
+
59
+ You can notice the difference in the ending betas and the removed square. The choice of this scheduler is based on our internal user study.
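+
+ To make the difference concrete, the following standalone snippet (using only the two formulas quoted above) prints both schedules' endpoints:
+
+ ```python
+ import torch
+
+ # Beta schedule used by paints_undo_single_frame, as quoted above.
+ betas_paints_undo = torch.linspace(0.00085, 0.020, 1000, dtype=torch.float64)
+
+ # Original SD1.5 schedule, as quoted above (a squared ramp over sqrt-betas).
+ betas_sd15 = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float64) ** 2
+
+ print(betas_paints_undo[0].item(), betas_paints_undo[-1].item())  # 0.00085 ... 0.02
+ print(betas_sd15[0].item(), betas_sd15[-1].item())                # 0.00085 ... 0.012
+ # Both start at 0.00085, but Paints-Undo ends at a larger beta (0.020 vs 0.012)
+ # and grows linearly rather than as a squared ramp.
+ ```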
60
+
61
+ The last layer of the text encoder (CLIP ViT-L/14) is permanently removed, so it is now mathematically consistent to always set CLIP Skip to 2 (if you use diffusers).
62
+
63
+ The `operation step` condition is added to layer embeddings in a way similar to SDXL's extra embeddings.
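+
+ Concretely, the helper `unet_add_coded_conds` in `diffusers_helper/code_cond.py` (included in this commit) projects the integer step with a sinusoidal `Timesteps` embedding followed by a small `TimestepEmbedding` MLP, mirroring SDXL's add-embedding path. A toy shape check using the same layers:
+
+ ```python
+ import torch
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+
+ # Same layers as in diffusers_helper/code_cond.py: sinusoidal projection of the
+ # integer step, then an MLP to the UNet's time-embedding width (1280 for SD1.5).
+ add_time_proj = Timesteps(256, True, 0)
+ add_embedding = TimestepEmbedding(256, 1280)
+
+ operation_step = torch.tensor([100])  # a batch of one `operation step`
+ emb = add_embedding(add_time_proj(operation_step.flatten()).reshape(1, -1))
+ print(emb.shape)  # torch.Size([1, 1280]) -> added to the UNet's time embedding
+ ```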
64
+
65
+ Also, since the sole purpose of this model is to process existing images, the model is strictly aligned with the WD14 tagger, without any other augmentations. You should always use the WD14 tagger (the one in this repo) to process the input image and get the prompt. Otherwise, the results may be defective. Human-written prompts are not tested.
66
+
67
+ ### Model Architecture (paints_undo_multi_frame)
68
+
69
+ This model is trained by resuming from the [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter) family, but the original Crafter's `lvdm` is not used and all training/inference code is completely implemented from scratch. (The code is now based on modern Diffusers.) Although the initial weights are resumed from VideoCrafter, the topology of the neural network is modified a lot, and after extensive training the network's behavior is now largely different from the original Crafter.
70
+
71
+ The overall architecture is Crafter-like, with 5 components: a 3D-UNet, a VAE, CLIP, CLIP-Vision, and an Image Projection module.
72
+
73
+ **VAE**: The VAE is exactly the same anime VAE extracted from [ToonCrafter](https://github.com/ToonCrafter/ToonCrafter). Many thanks to ToonCrafter for providing an excellent anime temporal VAE for Crafters.
74
+
75
+ **3D-UNet**: The 3D-UNet is modified from Crafter's `lvdm` with revisions to the attention modules. Other than some minor code changes, the major change is that the UNet is now trained with, and supports, temporal windows in the Spatial Self-Attention layers. You can change `diffusers_vdm.attention.CrossAttention.temporal_window_for_spatial_self_attention` and `temporal_window_type` in the code to activate three types of attention windows (a usage sketch follows the list):
76
+
77
+ 1. "prv" mode: Each frame's Spatial Self-Attention also attends to the full spatial context of its previous frame. The first frame only attends to itself.
78
+ 2. "first": Each frame's Spatial Self-Attention also attends to the full spatial context of the first frame of the entire sequence. The first frame only attends to itself.
79
+ 3. "roll": Each frame's Spatial Self-Attention also attends to the full spatial contexts of its previous and next frames, based on the ordering of `torch.roll`.
80
+
81
+ Note that this is disabled by default during inference to save GPU memory.
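+
+ To see what the window does to the keys/values, here is a small standalone check using the `make_temporal_window` helper from `diffusers_vdm/basics.py` in this repo (toy shapes; batch, frames, spatial tokens, and channels are chosen arbitrarily):
+
+ ```python
+ import torch
+ from diffusers_vdm.basics import make_temporal_window
+
+ b, t, d, c = 2, 16, 64, 320      # toy sizes: batch, frames, spatial tokens, channels
+ k = torch.randn(b * t, d, c)     # keys as seen by Spatial Self-Attention: (b*t, d, c)
+
+ for method in ['prv', 'first', 'roll']:
+     out = make_temporal_window(k, t=t, method=method)
+     print(method, tuple(out.shape))
+ # prv   (32, 128, 320) -> each frame also attends to its previous frame's tokens
+ # first (32, 128, 320) -> each frame also attends to the first frame's tokens
+ # roll  (32, 192, 320) -> each frame also attends to its previous and next frames
+ ```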
82
+
83
+ **CLIP**: The CLIP text encoder from SD2.1.
84
+
85
+ **CLIP-Vision**: Our implementation of CLIP Vision (ViT-H) that supports arbitrary aspect ratios by interpolating the positional embedding. After experimenting with linear interpolation, nearest neighbor, and Rotary Positional Encoding (RoPE), our final choice is nearest neighbor. Note that this is different from Crafter methods, which resize or center-crop images to 224x224.
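+
+ The general idea (a minimal sketch, not this repo's exact implementation) is to reshape the pretrained positional-embedding grid and resize it with nearest-neighbor interpolation so that a non-square token grid can reuse the 224x224 embeddings:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def resize_vit_pos_embed(pos_embed: torch.Tensor, new_hw: tuple) -> torch.Tensor:
+     """Nearest-neighbor resize of a ViT positional embedding (illustrative sketch).
+
+     pos_embed: (1, 1 + H*W, C) with a leading class token, e.g. 1 + 16*16 tokens
+                for a ViT-H/14 at 224x224.
+     new_hw:    target token-grid size (new_h, new_w) for a non-square input.
+     """
+     cls_token, grid = pos_embed[:, :1], pos_embed[:, 1:]
+     old_size = int(grid.shape[1] ** 0.5)                                # square source grid
+     grid = grid.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)  # (1, C, H, W)
+     grid = F.interpolate(grid, size=new_hw, mode="nearest")
+     grid = grid.permute(0, 2, 3, 1).reshape(1, new_hw[0] * new_hw[1], -1)
+     return torch.cat([cls_token, grid], dim=1)
+
+ # Example: adapt a 16x16 grid (224x224, patch 14) to a 16x24 grid (224x336 input).
+ pos = torch.randn(1, 1 + 16 * 16, 1280)
+ print(resize_vit_pos_embed(pos, (16, 24)).shape)  # torch.Size([1, 385, 1280])
+ ```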
86
+
87
+ **Image Projection**: Our implementation of a tiny transformer that takes two frames as inputs and outputs 16 image embeddings for each frame. Note that this is different from Crafter methods that only use one image.
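+
+ As a rough mental model only (hypothetical layer sizes and shapes, not the repo's actual module), such a projector can be a small transformer that cross-attends a fixed set of learned query tokens against the two frames' CLIP-Vision features:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class TinyImageProjector(nn.Module):
+     """Hypothetical sketch: two frames' CLIP-Vision features -> 16 embeddings each."""
+
+     def __init__(self, in_dim=1280, out_dim=1024, num_queries=16, num_layers=2):
+         super().__init__()
+         # 16 learned query tokens per frame, decoded against the concatenated features.
+         self.queries = nn.Parameter(torch.randn(2 * num_queries, out_dim) * 0.02)
+         self.in_proj = nn.Linear(in_dim, out_dim)
+         layer = nn.TransformerDecoderLayer(d_model=out_dim, nhead=8, batch_first=True)
+         self.decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
+
+     def forward(self, frame_a, frame_b):  # each: (B, tokens, in_dim)
+         memory = self.in_proj(torch.cat([frame_a, frame_b], dim=1))
+         queries = self.queries.unsqueeze(0).expand(memory.shape[0], -1, -1)
+         out = self.decoder(queries, memory)  # (B, 32, out_dim)
+         return out.chunk(2, dim=1)           # 16 embeddings for each of the two frames
+
+ proj = TinyImageProjector()
+ a, b = torch.randn(1, 257, 1280), torch.randn(1, 257, 1280)
+ emb_a, emb_b = proj(a, b)
+ print(emb_a.shape, emb_b.shape)  # torch.Size([1, 16, 1024]) twice
+ ```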
88
+
89
+ # Tutorial
90
+
91
+ After you get into the Gradio interface:
92
+
93
+ Step 0: Upload an image, or just click an example image at the bottom of the page.
94
+
95
+ Step 1: In the UI titled "step 1", click "generate prompts" to get the global prompt.
96
+
97
+ Step 2: In the UI titled "step 2", click "Generate Key Frames". You can change seeds or other parameters on the left.
98
+
99
+ Step 3: In the UI titled "step 3", click "Generate Video". You can change seeds or other parameters on the left.
100
+
101
+ # Cite
102
+
103
+ @Misc{paintsundo,
104
+ author = {Paints-Undo Team},
105
+ title = {Paints-Undo GitHub Page},
106
+ year = {2024},
107
+ }
108
+
109
+ # Disclaimer
110
+
111
+ This project aims to develop base models of human drawing behaviors, facilitating future AI systems to better meet the real needs of human artists. Users are granted the freedom to create content using this tool, but they are expected to comply with local laws and use it responsibly. Users must not employ the tool to generate false information or incite confrontation. The developers do not assume any responsibility for potential misuse by users.
diffusers_helper/cat_cond.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+
3
+
4
+ def unet_add_concat_conds(unet, new_channels=4):
5
+ with torch.no_grad():
6
+ new_conv_in = torch.nn.Conv2d(4 + new_channels, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
7
+ new_conv_in.weight.zero_()
8
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
9
+ new_conv_in.bias = unet.conv_in.bias
10
+ unet.conv_in = new_conv_in
11
+
12
+ unet_original_forward = unet.forward
13
+
14
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
15
+ cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
16
+ c_concat = cross_attention_kwargs.pop('concat_conds')
17
+ kwargs['cross_attention_kwargs'] = cross_attention_kwargs
18
+
19
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0).to(sample)
20
+ new_sample = torch.cat([sample, c_concat], dim=1)
21
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
22
+
23
+ unet.forward = hooked_unet_forward
24
+ return
diffusers_helper/code_cond.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+
3
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
4
+
5
+
6
+ def unet_add_coded_conds(unet, added_number_count=1):
7
+ unet.add_time_proj = Timesteps(256, True, 0)
8
+ unet.add_embedding = TimestepEmbedding(256 * added_number_count, 1280)
9
+
10
+ def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs):
11
+ coded_conds = added_cond_kwargs.get("coded_conds")
12
+ batch_size = coded_conds.shape[0]
13
+ time_embeds = unet.add_time_proj(coded_conds.flatten())
14
+ time_embeds = time_embeds.reshape((batch_size, -1))
15
+ time_embeds = time_embeds.to(emb)
16
+ aug_emb = unet.add_embedding(time_embeds)
17
+ return aug_emb
18
+
19
+ unet.get_aug_embed = get_aug_embed
20
+
21
+ unet_original_forward = unet.forward
22
+
23
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
24
+ cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
25
+ coded_conds = cross_attention_kwargs.pop('coded_conds')
26
+ kwargs['cross_attention_kwargs'] = cross_attention_kwargs
27
+
28
+ coded_conds = torch.cat([coded_conds] * (sample.shape[0] // coded_conds.shape[0]), dim=0).to(sample.device)
29
+ kwargs['added_cond_kwargs'] = dict(coded_conds=coded_conds)
30
+ return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs)
31
+
32
+ unet.forward = hooked_unet_forward
33
+
34
+ return
diffusers_helper/k_diffusion.py ADDED
@@ -0,0 +1,145 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from tqdm import tqdm
5
+
6
+
7
+ @torch.no_grad()
8
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, progress_tqdm=None):
9
+ """DPM-Solver++(2M)."""
10
+ extra_args = {} if extra_args is None else extra_args
11
+ s_in = x.new_ones([x.shape[0]])
12
+ sigma_fn = lambda t: t.neg().exp()
13
+ t_fn = lambda sigma: sigma.log().neg()
14
+ old_denoised = None
15
+
16
+ bar = tqdm if progress_tqdm is None else progress_tqdm
17
+
18
+ for i in bar(range(len(sigmas) - 1)):
19
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
20
+ if callback is not None:
21
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
22
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
23
+ h = t_next - t
24
+ if old_denoised is None or sigmas[i + 1] == 0:
25
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
26
+ else:
27
+ h_last = t - t_fn(sigmas[i - 1])
28
+ r = h_last / h
29
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
30
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
31
+ old_denoised = denoised
32
+ return x
33
+
34
+
35
+ class KModel:
36
+ def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012, linear=False):
37
+ if linear:
38
+ betas = torch.linspace(linear_start, linear_end, timesteps, dtype=torch.float64)
39
+ else:
40
+ betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
41
+
42
+ alphas = 1. - betas
43
+ alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
44
+
45
+ self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
46
+ self.log_sigmas = self.sigmas.log()
47
+ self.sigma_data = 1.0
48
+ self.unet = unet
49
+ return
50
+
51
+ @property
52
+ def sigma_min(self):
53
+ return self.sigmas[0]
54
+
55
+ @property
56
+ def sigma_max(self):
57
+ return self.sigmas[-1]
58
+
59
+ def timestep(self, sigma):
60
+ log_sigma = sigma.log()
61
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
62
+ return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
63
+
64
+ def get_sigmas_karras(self, n, rho=7.):
65
+ ramp = torch.linspace(0, 1, n)
66
+ min_inv_rho = self.sigma_min ** (1 / rho)
67
+ max_inv_rho = self.sigma_max ** (1 / rho)
68
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
69
+ return torch.cat([sigmas, sigmas.new_zeros([1])])
70
+
71
+ def __call__(self, x, sigma, **extra_args):
72
+ x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
73
+ x_ddim_space = x_ddim_space.to(dtype=self.unet.dtype)
74
+ t = self.timestep(sigma)
75
+ cfg_scale = extra_args['cfg_scale']
76
+ eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
77
+ eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
78
+ noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
79
+ return x - noise_pred * sigma[:, None, None, None]
80
+
81
+
82
+ class KDiffusionSampler:
83
+ def __init__(self, unet, **kwargs):
84
+ self.unet = unet
85
+ self.k_model = KModel(unet=unet, **kwargs)
86
+
87
+ @torch.inference_mode()
88
+ def __call__(
89
+ self,
90
+ initial_latent = None,
91
+ strength = 1.0,
92
+ num_inference_steps = 25,
93
+ guidance_scale = 5.0,
94
+ batch_size = 1,
95
+ generator = None,
96
+ prompt_embeds = None,
97
+ negative_prompt_embeds = None,
98
+ cross_attention_kwargs = None,
99
+ same_noise_in_batch = False,
100
+ progress_tqdm = None,
101
+ ):
102
+
103
+ device = self.unet.device
104
+
105
+ # Sigmas
106
+
107
+ sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps/strength))
108
+ sigmas = sigmas[-(num_inference_steps + 1):].to(device)
109
+
110
+ # Initial latents
111
+
112
+ if same_noise_in_batch:
113
+ noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype).repeat(batch_size, 1, 1, 1)
114
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
115
+ else:
116
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
117
+ noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype)
118
+
119
+ latents = initial_latent + noise * sigmas[0].to(initial_latent)
120
+
121
+ # Batch
122
+
123
+ latents = latents.to(device)
124
+ prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device)
125
+ negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device)
126
+
127
+ # Feeds
128
+
129
+ sampler_kwargs = dict(
130
+ cfg_scale=guidance_scale,
131
+ positive=dict(
132
+ encoder_hidden_states=prompt_embeds,
133
+ cross_attention_kwargs=cross_attention_kwargs
134
+ ),
135
+ negative=dict(
136
+ encoder_hidden_states=negative_prompt_embeds,
137
+ cross_attention_kwargs=cross_attention_kwargs,
138
+ )
139
+ )
140
+
141
+ # Sample
142
+
143
+ results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
144
+
145
+ return results
diffusers_helper/utils.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import json
3
+ import random
4
+ import glob
5
+ import torch
6
+ import einops
7
+ import torchvision
8
+
9
+ import safetensors.torch as sf
10
+
11
+
12
+ def write_to_json(data, file_path):
13
+ temp_file_path = file_path + ".tmp"
14
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
15
+ json.dump(data, temp_file, indent=4)
16
+ os.replace(temp_file_path, file_path)
17
+ return
18
+
19
+
20
+ def read_from_json(file_path):
21
+ with open(file_path, 'rt', encoding='utf-8') as file:
22
+ data = json.load(file)
23
+ return data
24
+
25
+
26
+ def get_active_parameters(m):
27
+ return {k:v for k, v in m.named_parameters() if v.requires_grad}
28
+
29
+
30
+ def cast_training_params(m, dtype=torch.float32):
31
+ for param in m.parameters():
32
+ if param.requires_grad:
33
+ param.data = param.to(dtype)
34
+ return
35
+
36
+
37
+ def set_attr_recursive(obj, attr, value):
38
+ attrs = attr.split(".")
39
+ for name in attrs[:-1]:
40
+ obj = getattr(obj, name)
41
+ setattr(obj, attrs[-1], value)
42
+ return
43
+
44
+
45
+ @torch.no_grad()
46
+ def batch_mixture(a, b, probability_a=0.5, mask_a=None):
47
+ assert a.shape == b.shape, "Tensors must have the same shape"
48
+ batch_size = a.size(0)
49
+
50
+ if mask_a is None:
51
+ mask_a = torch.rand(batch_size) < probability_a
52
+
53
+ mask_a = mask_a.to(a.device)
54
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
55
+ result = torch.where(mask_a, a, b)
56
+ return result
57
+
58
+
59
+ @torch.no_grad()
60
+ def zero_module(module):
61
+ for p in module.parameters():
62
+ p.detach().zero_()
63
+ return module
64
+
65
+
66
+ def load_last_state(model, folder='accelerator_output'):
67
+ file_pattern = os.path.join(folder, '**', 'model.safetensors')
68
+ files = glob.glob(file_pattern, recursive=True)
69
+
70
+ if not files:
71
+ print("No model.safetensors files found in the specified folder.")
72
+ return
73
+
74
+ newest_file = max(files, key=os.path.getmtime)
75
+ state_dict = sf.load_file(newest_file)
76
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
77
+
78
+ if missing_keys:
79
+ print("Missing keys:", missing_keys)
80
+ if unexpected_keys:
81
+ print("Unexpected keys:", unexpected_keys)
82
+
83
+ print("Loaded model state from:", newest_file)
84
+ return
85
+
86
+
87
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
88
+ tags = tags_str.split(', ')
89
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
90
+ prompt = ', '.join(tags)
91
+ return prompt
92
+
93
+
94
+ def save_bcthw_as_mp4(x, output_filename, fps=10):
95
+ b, c, t, h, w = x.shape
96
+
97
+ per_row = b
98
+ for p in [6, 5, 4, 3, 2]:
99
+ if b % p == 0:
100
+ per_row = p
101
+ break
102
+
103
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
104
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
105
+ x = x.detach().cpu().to(torch.uint8)
106
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
107
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'})
108
+ return x
109
+
110
+
111
+ def save_bcthw_as_png(x, output_filename):
112
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
113
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
114
+ x = x.detach().cpu().to(torch.uint8)
115
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
116
+ torchvision.io.write_png(x, output_filename)
117
+ return output_filename
118
+
119
+
120
+ def add_tensors_with_padding(tensor1, tensor2):
121
+ if tensor1.shape == tensor2.shape:
122
+ return tensor1 + tensor2
123
+
124
+ shape1 = tensor1.shape
125
+ shape2 = tensor2.shape
126
+
127
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
128
+
129
+ padded_tensor1 = torch.zeros(new_shape)
130
+ padded_tensor2 = torch.zeros(new_shape)
131
+
132
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
133
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
134
+
135
+ result = padded_tensor1 + padded_tensor2
136
+ return result
diffusers_vdm/attention.py ADDED
@@ -0,0 +1,385 @@
1
+ import torch
2
+ import xformers.ops
3
+ import torch.nn.functional as F
4
+
5
+ from torch import nn
6
+ from einops import rearrange, repeat
7
+ from functools import partial
8
+ from diffusers_vdm.basics import zero_module, checkpoint, default, make_temporal_window
9
+
10
+
11
+ def sdp(q, k, v, heads):
12
+ b, _, C = q.shape
13
+ dim_head = C // heads
14
+
15
+ q, k, v = map(
16
+ lambda t: t.unsqueeze(3)
17
+ .reshape(b, t.shape[1], heads, dim_head)
18
+ .permute(0, 2, 1, 3)
19
+ .reshape(b * heads, t.shape[1], dim_head)
20
+ .contiguous(),
21
+ (q, k, v),
22
+ )
23
+
24
+ out = xformers.ops.memory_efficient_attention(q, k, v)
25
+
26
+ out = (
27
+ out.unsqueeze(0)
28
+ .reshape(b, heads, out.shape[1], dim_head)
29
+ .permute(0, 2, 1, 3)
30
+ .reshape(b, out.shape[1], heads * dim_head)
31
+ )
32
+
33
+ return out
34
+
35
+
36
+ class RelativePosition(nn.Module):
37
+ """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
38
+
39
+ def __init__(self, num_units, max_relative_position):
40
+ super().__init__()
41
+ self.num_units = num_units
42
+ self.max_relative_position = max_relative_position
43
+ self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
44
+ nn.init.xavier_uniform_(self.embeddings_table)
45
+
46
+ def forward(self, length_q, length_k):
47
+ device = self.embeddings_table.device
48
+ range_vec_q = torch.arange(length_q, device=device)
49
+ range_vec_k = torch.arange(length_k, device=device)
50
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
51
+ distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
52
+ final_mat = distance_mat_clipped + self.max_relative_position
53
+ final_mat = final_mat.long()
54
+ embeddings = self.embeddings_table[final_mat]
55
+ return embeddings
56
+
57
+
58
+ class CrossAttention(nn.Module):
59
+
60
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,
61
+ relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False,
62
+ image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False,
63
+ text_context_len=77, temporal_window_for_spatial_self_attention=False):
64
+ super().__init__()
65
+ inner_dim = dim_head * heads
66
+ context_dim = default(context_dim, query_dim)
67
+
68
+ self.scale = dim_head**-0.5
69
+ self.heads = heads
70
+ self.dim_head = dim_head
71
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
72
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
73
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
74
+
75
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
76
+
77
+ self.is_temporal_attention = temporal_length is not None
78
+
79
+ self.relative_position = relative_position
80
+ if self.relative_position:
81
+ assert self.is_temporal_attention
82
+ self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
83
+ self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
84
+
85
+ self.video_length = video_length
86
+ self.temporal_window_for_spatial_self_attention = temporal_window_for_spatial_self_attention
87
+ self.temporal_window_type = 'prv'
88
+
89
+ self.image_cross_attention = image_cross_attention
90
+ self.image_cross_attention_scale = image_cross_attention_scale
91
+ self.text_context_len = text_context_len
92
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
93
+ if self.image_cross_attention:
94
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
95
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
96
+ if image_cross_attention_scale_learnable:
97
+ self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) )
98
+
99
+ def forward(self, x, context=None, mask=None):
100
+ if self.is_temporal_attention:
101
+ return self.temporal_forward(x, context=context, mask=mask)
102
+ else:
103
+ return self.spatial_forward(x, context=context, mask=mask)
104
+
105
+ def temporal_forward(self, x, context=None, mask=None):
106
+ assert mask is None, 'Attention mask not implemented!'
107
+ assert context is None, 'Temporal attention only supports self attention!'
108
+
109
+ q = self.to_q(x)
110
+ k = self.to_k(x)
111
+ v = self.to_v(x)
112
+
113
+ out = sdp(q, k, v, self.heads)
114
+
115
+ return self.to_out(out)
116
+
117
+ def spatial_forward(self, x, context=None, mask=None):
118
+ assert mask is None, 'Attention mask not implemented!'
119
+
120
+ spatial_self_attn = (context is None)
121
+ k_ip, v_ip, out_ip = None, None, None
122
+
123
+ q = self.to_q(x)
124
+ context = default(context, x)
125
+
126
+ if spatial_self_attn:
127
+ k = self.to_k(context)
128
+ v = self.to_v(context)
129
+
130
+ if self.temporal_window_for_spatial_self_attention:
131
+ k = make_temporal_window(k, t=self.video_length, method=self.temporal_window_type)
132
+ v = make_temporal_window(v, t=self.video_length, method=self.temporal_window_type)
133
+ elif self.image_cross_attention:
134
+ context, context_image = context
135
+ k = self.to_k(context)
136
+ v = self.to_v(context)
137
+ k_ip = self.to_k_ip(context_image)
138
+ v_ip = self.to_v_ip(context_image)
139
+ else:
140
+ raise NotImplementedError('Traditional prompt-only attention without IP-Adapter is illegal now.')
141
+
142
+ out = sdp(q, k, v, self.heads)
143
+
144
+ if k_ip is not None:
145
+ out_ip = sdp(q, k_ip, v_ip, self.heads)
146
+
147
+ if self.image_cross_attention_scale_learnable:
148
+ out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha) + 1)
149
+ else:
150
+ out = out + self.image_cross_attention_scale * out_ip
151
+
152
+ return self.to_out(out)
153
+
154
+
155
+ class BasicTransformerBlock(nn.Module):
156
+
157
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
158
+ disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77):
159
+ super().__init__()
160
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
161
+ self.disable_self_attn = disable_self_attn
162
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
163
+ context_dim=context_dim if self.disable_self_attn else None, video_length=video_length)
164
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
165
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length, image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,text_context_len=text_context_len)
166
+ self.image_cross_attention = image_cross_attention
167
+
168
+ self.norm1 = nn.LayerNorm(dim)
169
+ self.norm2 = nn.LayerNorm(dim)
170
+ self.norm3 = nn.LayerNorm(dim)
171
+ self.checkpoint = checkpoint
172
+
173
+
174
+ def forward(self, x, context=None, mask=None, **kwargs):
175
+ ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
176
+ input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
177
+ if context is not None:
178
+ input_tuple = (x, context)
179
+ if mask is not None:
180
+ forward_mask = partial(self._forward, mask=mask)
181
+ return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint)
182
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint)
183
+
184
+
185
+ def _forward(self, x, context=None, mask=None):
186
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x
187
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
188
+ x = self.ff(self.norm3(x)) + x
189
+ return x
190
+
191
+
192
+ class SpatialTransformer(nn.Module):
193
+ """
194
+ Transformer block for image-like data in spatial axis.
195
+ First, project the input (aka embedding)
196
+ and reshape to b, t, d.
197
+ Then apply standard transformer action.
198
+ Finally, reshape to image
199
+ NEW: use_linear for more efficiency instead of the 1x1 convs
200
+ """
201
+
202
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
203
+ use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None,
204
+ image_cross_attention=False, image_cross_attention_scale_learnable=False):
205
+ super().__init__()
206
+ self.in_channels = in_channels
207
+ inner_dim = n_heads * d_head
208
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
209
+ if not use_linear:
210
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
211
+ else:
212
+ self.proj_in = nn.Linear(in_channels, inner_dim)
213
+
214
+ attention_cls = None
215
+ self.transformer_blocks = nn.ModuleList([
216
+ BasicTransformerBlock(
217
+ inner_dim,
218
+ n_heads,
219
+ d_head,
220
+ dropout=dropout,
221
+ context_dim=context_dim,
222
+ disable_self_attn=disable_self_attn,
223
+ checkpoint=use_checkpoint,
224
+ attention_cls=attention_cls,
225
+ video_length=video_length,
226
+ image_cross_attention=image_cross_attention,
227
+ image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
228
+ ) for d in range(depth)
229
+ ])
230
+ if not use_linear:
231
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
232
+ else:
233
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
234
+ self.use_linear = use_linear
235
+
236
+
237
+ def forward(self, x, context=None, **kwargs):
238
+ b, c, h, w = x.shape
239
+ x_in = x
240
+ x = self.norm(x)
241
+ if not self.use_linear:
242
+ x = self.proj_in(x)
243
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
244
+ if self.use_linear:
245
+ x = self.proj_in(x)
246
+ for i, block in enumerate(self.transformer_blocks):
247
+ x = block(x, context=context, **kwargs)
248
+ if self.use_linear:
249
+ x = self.proj_out(x)
250
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
251
+ if not self.use_linear:
252
+ x = self.proj_out(x)
253
+ return x + x_in
254
+
255
+
256
+ class TemporalTransformer(nn.Module):
257
+ """
258
+ Transformer block for image-like data in temporal axis.
259
+ First, reshape to b, t, d.
260
+ Then apply standard transformer action.
261
+ Finally, reshape to image
262
+ """
263
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
264
+ use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1,
265
+ relative_position=False, temporal_length=None):
266
+ super().__init__()
267
+ self.only_self_att = only_self_att
268
+ self.relative_position = relative_position
269
+ self.causal_attention = causal_attention
270
+ self.causal_block_size = causal_block_size
271
+
272
+ self.in_channels = in_channels
273
+ inner_dim = n_heads * d_head
274
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
275
+ self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
276
+ if not use_linear:
277
+ self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
278
+ else:
279
+ self.proj_in = nn.Linear(in_channels, inner_dim)
280
+
281
+ if relative_position:
282
+ assert(temporal_length is not None)
283
+ attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length)
284
+ else:
285
+ attention_cls = partial(CrossAttention, temporal_length=temporal_length)
286
+ if self.causal_attention:
287
+ assert(temporal_length is not None)
288
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
289
+
290
+ if self.only_self_att:
291
+ context_dim = None
292
+ self.transformer_blocks = nn.ModuleList([
293
+ BasicTransformerBlock(
294
+ inner_dim,
295
+ n_heads,
296
+ d_head,
297
+ dropout=dropout,
298
+ context_dim=context_dim,
299
+ attention_cls=attention_cls,
300
+ checkpoint=use_checkpoint) for d in range(depth)
301
+ ])
302
+ if not use_linear:
303
+ self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
304
+ else:
305
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
306
+ self.use_linear = use_linear
307
+
308
+ def forward(self, x, context=None):
309
+ b, c, t, h, w = x.shape
310
+ x_in = x
311
+ x = self.norm(x)
312
+ x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
313
+ if not self.use_linear:
314
+ x = self.proj_in(x)
315
+ x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
316
+ if self.use_linear:
317
+ x = self.proj_in(x)
318
+
319
+ temp_mask = None
320
+ if self.causal_attention:
321
+             # slice the mask map down to the current temporal length
322
+ temp_mask = self.mask[:,:t,:t].to(x.device)
323
+
324
+ if temp_mask is not None:
325
+ mask = temp_mask.to(x.device)
326
+ mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w)
327
+ else:
328
+ mask = None
329
+
330
+ if self.only_self_att:
331
+ ## note: if no context is given, cross-attention defaults to self-attention
332
+ for i, block in enumerate(self.transformer_blocks):
333
+ x = block(x, mask=mask)
334
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
335
+ else:
336
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
337
+ context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
338
+ for i, block in enumerate(self.transformer_blocks):
339
+                 # calculate each batch one by one (since some packages cannot handle a shape dimension greater than 65,535)
340
+ for j in range(b):
341
+ context_j = repeat(
342
+ context[j],
343
+ 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous()
344
+ ## note: causal mask will not applied in cross-attention case
345
+ x[j] = block(x[j], context=context_j)
346
+
347
+ if self.use_linear:
348
+ x = self.proj_out(x)
349
+ x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
350
+ if not self.use_linear:
351
+ x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
352
+ x = self.proj_out(x)
353
+ x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()
354
+
355
+ return x + x_in
356
+
357
+
358
+ class GEGLU(nn.Module):
359
+ def __init__(self, dim_in, dim_out):
360
+ super().__init__()
361
+ self.proj = nn.Linear(dim_in, dim_out * 2)
362
+
363
+ def forward(self, x):
364
+ x, gate = self.proj(x).chunk(2, dim=-1)
365
+ return x * F.gelu(gate)
366
+
367
+
368
+ class FeedForward(nn.Module):
369
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
370
+ super().__init__()
371
+ inner_dim = int(dim * mult)
372
+ dim_out = default(dim_out, dim)
373
+ project_in = nn.Sequential(
374
+ nn.Linear(dim, inner_dim),
375
+ nn.GELU()
376
+ ) if not glu else GEGLU(dim, inner_dim)
377
+
378
+ self.net = nn.Sequential(
379
+ project_in,
380
+ nn.Dropout(dropout),
381
+ nn.Linear(inner_dim, dim_out)
382
+ )
383
+
384
+ def forward(self, x):
385
+ return self.net(x)
diffusers_vdm/basics.py ADDED
@@ -0,0 +1,148 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import einops
14
+
15
+ from inspect import isfunction
16
+
17
+
18
+ def zero_module(module):
19
+ """
20
+ Zero out the parameters of a module and return it.
21
+ """
22
+ for p in module.parameters():
23
+ p.detach().zero_()
24
+ return module
25
+
26
+ def scale_module(module, scale):
27
+ """
28
+ Scale the parameters of a module and return it.
29
+ """
30
+ for p in module.parameters():
31
+ p.detach().mul_(scale)
32
+ return module
33
+
34
+
35
+ def conv_nd(dims, *args, **kwargs):
36
+ """
37
+ Create a 1D, 2D, or 3D convolution module.
38
+ """
39
+ if dims == 1:
40
+ return nn.Conv1d(*args, **kwargs)
41
+ elif dims == 2:
42
+ return nn.Conv2d(*args, **kwargs)
43
+ elif dims == 3:
44
+ return nn.Conv3d(*args, **kwargs)
45
+ raise ValueError(f"unsupported dimensions: {dims}")
46
+
47
+
48
+ def linear(*args, **kwargs):
49
+ """
50
+ Create a linear module.
51
+ """
52
+ return nn.Linear(*args, **kwargs)
53
+
54
+
55
+ def avg_pool_nd(dims, *args, **kwargs):
56
+ """
57
+ Create a 1D, 2D, or 3D average pooling module.
58
+ """
59
+ if dims == 1:
60
+ return nn.AvgPool1d(*args, **kwargs)
61
+ elif dims == 2:
62
+ return nn.AvgPool2d(*args, **kwargs)
63
+ elif dims == 3:
64
+ return nn.AvgPool3d(*args, **kwargs)
65
+ raise ValueError(f"unsupported dimensions: {dims}")
66
+
67
+
68
+ def nonlinearity(type='silu'):
69
+ if type == 'silu':
70
+ return nn.SiLU()
71
+ elif type == 'leaky_relu':
72
+ return nn.LeakyReLU()
73
+
74
+
75
+ def normalization(channels, num_groups=32):
76
+ """
77
+ Make a standard normalization layer.
78
+ :param channels: number of input channels.
79
+ :return: an nn.Module for normalization.
80
+ """
81
+ return nn.GroupNorm(num_groups, channels)
82
+
83
+
84
+ def default(val, d):
85
+ if exists(val):
86
+ return val
87
+ return d() if isfunction(d) else d
88
+
89
+
90
+ def exists(val):
91
+ return val is not None
92
+
93
+
94
+ def extract_into_tensor(a, t, x_shape):
95
+ b, *_ = t.shape
96
+ out = a.gather(-1, t)
97
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
98
+
99
+
100
+ def make_temporal_window(x, t, method):
101
+ assert method in ['roll', 'prv', 'first']
102
+
103
+ if method == 'roll':
104
+ m = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
105
+ l = torch.roll(m, shifts=1, dims=1)
106
+ r = torch.roll(m, shifts=-1, dims=1)
107
+
108
+ recon = torch.cat([l, m, r], dim=2)
109
+ del l, m, r
110
+
111
+ recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
112
+ return recon
113
+
114
+ if method == 'prv':
115
+ x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
116
+ prv = torch.cat([x[:, :1], x[:, :-1]], dim=1)
117
+
118
+ recon = torch.cat([x, prv], dim=2)
119
+ del x, prv
120
+
121
+ recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
122
+ return recon
123
+
124
+ if method == 'first':
125
+ x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
126
+ prv = x[:, [0], :, :].repeat(1, t, 1, 1)
127
+
128
+ recon = torch.cat([x, prv], dim=2)
129
+ del x, prv
130
+
131
+ recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
132
+ return recon
133
+
134
+
135
+ def checkpoint(func, inputs, params, flag):
136
+ """
137
+ Evaluate a function without caching intermediate activations, allowing for
138
+ reduced memory at the expense of extra compute in the backward pass.
139
+ :param func: the function to evaluate.
140
+ :param inputs: the argument sequence to pass to `func`.
141
+ :param params: a sequence of parameters `func` depends on but does not
142
+ explicitly take as arguments.
143
+ :param flag: if False, disable gradient checkpointing.
144
+ """
145
+ if flag:
146
+ return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False)
147
+ else:
148
+ return func(*inputs)
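A small shape check for make_temporal_window (illustrative only, not part of the commit; assumes the diffusers_vdm package above is importable and einops is installed). With method='roll', every frame's tokens are concatenated with those of the previous and next frame.

import torch
from diffusers_vdm.basics import make_temporal_window

b, t, d, c = 2, 8, 64, 320                       # illustrative sizes
x = torch.randn(b * t, d, c)                     # tokens laid out as (b t) d c
y = make_temporal_window(x, t=t, method='roll')
print(y.shape)                                   # torch.Size([16, 192, 320]): the token dim grows to 3*d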
diffusers_vdm/dynamic_tsnr_sampler.py ADDED
@@ -0,0 +1,177 @@
1
+ # everything that can improve a v-prediction model
2
+ # dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ...
3
+ # written by lvmin at stanford 2024
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from tqdm import tqdm
9
+ from functools import partial
10
+ from diffusers_vdm.basics import extract_into_tensor
11
+
12
+
13
+ to_torch = partial(torch.tensor, dtype=torch.float32)
14
+
15
+
16
+ def rescale_zero_terminal_snr(betas):
17
+ # Convert betas to alphas_bar_sqrt
18
+ alphas = 1.0 - betas
19
+ alphas_cumprod = np.cumprod(alphas, axis=0)
20
+ alphas_bar_sqrt = np.sqrt(alphas_cumprod)
21
+
22
+ # Store old values.
23
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
24
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
25
+
26
+ # Shift so the last timestep is zero.
27
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
28
+
29
+ # Scale so the first timestep is back to the old value.
30
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
31
+
32
+ # Convert alphas_bar_sqrt to betas
33
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
34
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
35
+ alphas = np.concatenate([alphas_bar[0:1], alphas])
36
+ betas = 1 - alphas
37
+
38
+ return betas
39
+
40
+
41
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
42
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
43
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
44
+
45
+ # rescale the results from guidance (fixes overexposure)
46
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
47
+
48
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
49
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
50
+
51
+ return noise_cfg
52
+
53
+
54
+ class SamplerDynamicTSNR(torch.nn.Module):
55
+ @torch.no_grad()
56
+ def __init__(self, unet, terminal_scale=0.7):
57
+ super().__init__()
58
+ self.unet = unet
59
+
60
+ self.is_v = True
61
+ self.n_timestep = 1000
62
+ self.guidance_rescale = 0.7
63
+
64
+ linear_start = 0.00085
65
+ linear_end = 0.012
66
+
67
+ betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
68
+ betas = rescale_zero_terminal_snr(betas)
69
+ alphas = 1. - betas
70
+
71
+ alphas_cumprod = np.cumprod(alphas, axis=0)
72
+
73
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
74
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
75
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))
76
+
77
+ # Dynamic TSNR
78
+ turning_step = 400
79
+ scale_arr = np.concatenate([
80
+ np.linspace(1.0, terminal_scale, turning_step),
81
+ np.full(self.n_timestep - turning_step, terminal_scale)
82
+ ])
83
+ self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))
84
+
85
+ def predict_eps_from_z_and_v(self, x_t, t, v):
86
+ return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t
87
+
88
+ def predict_start_from_z_and_v(self, x_t, t, v):
89
+ return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v
90
+
91
+ def q_sample(self, x0, t, noise):
92
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
93
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
94
+
95
+ def get_v(self, x0, t, noise):
96
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
97
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)
98
+
99
+ def dynamic_x0_rescale(self, x0, t):
100
+ return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)
101
+
102
+ @torch.no_grad()
103
+ def get_ground_truth(self, x0, noise, t):
104
+ x0 = self.dynamic_x0_rescale(x0, t)
105
+ xt = self.q_sample(x0, t, noise)
106
+ target = self.get_v(x0, t, noise) if self.is_v else noise
107
+ return xt, target
108
+
109
+ def get_uniform_trailing_steps(self, steps):
110
+ c = self.n_timestep / steps
111
+ ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64)
112
+ steps_out = ddim_timesteps - 1
113
+ return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)
114
+
115
+ @torch.no_grad()
116
+ def forward(self, latent_shape, steps, extra_args, progress_tqdm=None):
117
+ bar = tqdm if progress_tqdm is None else progress_tqdm
118
+
119
+ eta = 1.0
120
+
121
+ timesteps = self.get_uniform_trailing_steps(steps)
122
+ timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
123
+
124
+ x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)
125
+
126
+ alphas = self.alphas_cumprod[timesteps]
127
+ alphas_prev = self.alphas_cumprod[timesteps_prev]
128
+ scale_arr = self.scale_arr[timesteps]
129
+ scale_arr_prev = self.scale_arr[timesteps_prev]
130
+
131
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
132
+ sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
133
+
134
+ s_in = x.new_ones((x.shape[0]))
135
+ s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))
136
+ for i in bar(range(len(timesteps))):
137
+ index = len(timesteps) - 1 - i
138
+ t = timesteps[index].item()
139
+
140
+ model_output = self.model_apply(x, t * s_in, **extra_args)
141
+
142
+ if self.is_v:
143
+ e_t = self.predict_eps_from_z_and_v(x, t, model_output)
144
+ else:
145
+ e_t = model_output
146
+
147
+ a_prev = alphas_prev[index].item() * s_x
148
+ sigma_t = sigmas[index].item() * s_x
149
+
150
+ if self.is_v:
151
+ pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
152
+ else:
153
+ a_t = alphas[index].item() * s_x
154
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
155
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
156
+
157
+ # dynamic rescale
158
+ scale_t = scale_arr[index].item() * s_x
159
+ prev_scale_t = scale_arr_prev[index].item() * s_x
160
+ rescale = (prev_scale_t / scale_t)
161
+ pred_x0 = pred_x0 * rescale
162
+
163
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
164
+ noise = sigma_t * torch.randn_like(x)
165
+ x = a_prev.sqrt() * pred_x0 + dir_xt + noise
166
+
167
+ return x
168
+
169
+ @torch.no_grad()
170
+ def model_apply(self, x, t, **extra_args):
171
+ x = x.to(device=self.unet.device, dtype=self.unet.dtype)
172
+ cfg_scale = extra_args['cfg_scale']
173
+ p = self.unet(x, t, **extra_args['positive'])
174
+ n = self.unet(x, t, **extra_args['negative'])
175
+ o = n + cfg_scale * (p - n)
176
+ o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
177
+ return o_better
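A quick numeric check of rescale_zero_terminal_snr (illustrative, not part of the commit; assumes the package and its dependencies are installed). The schedule constants match the ones used in SamplerDynamicTSNR.__init__.

import numpy as np
from diffusers_vdm.dynamic_tsnr_sampler import rescale_zero_terminal_snr

betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=np.float64) ** 2
alphas_cumprod = np.cumprod(1.0 - rescale_zero_terminal_snr(betas))
print(alphas_cumprod[0])    # ~0.99915, the first step is left unchanged
print(alphas_cumprod[-1])   # 0.0, the terminal SNR is zero after rescaling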
diffusers_vdm/improved_clip_vision.py ADDED
@@ -0,0 +1,58 @@
1
+ # A CLIP Vision supporting arbitrary aspect ratios, by lllyasviel
2
+ # The input range is changed to [-1, 1] rather than [0, 1] !!!! (same as VAE's range)
3
+
4
+ import torch
5
+ import types
6
+ import einops
7
+
8
+ from abc import ABCMeta
9
+ from transformers import CLIPVisionModelWithProjection
10
+
11
+
12
+ def preprocess(image):
13
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
14
+ std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]
15
+
16
+ scale = 16 / min(image.shape[2], image.shape[3])
17
+ image = torch.nn.functional.interpolate(
18
+ image,
19
+ size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])),
20
+ mode="bicubic",
21
+ antialias=True
22
+ )
23
+
24
+ return (image - mean) / std
25
+
26
+
27
+ def arbitrary_positional_encoding(p, H, W):
28
+ weight = p.weight
29
+ cls = weight[:1]
30
+ pos = weight[1:]
31
+ pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16)
32
+ pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest")
33
+ pos = einops.rearrange(pos, '1 C H W -> (H W) C')
34
+ weight = torch.cat([cls, pos])[None]
35
+ return weight
36
+
37
+
38
+ def improved_clipvision_embedding_forward(self, pixel_values):
39
+ pixel_values = pixel_values * 0.5 + 0.5
40
+ pixel_values = preprocess(pixel_values)
41
+ batch_size = pixel_values.shape[0]
42
+ target_dtype = self.patch_embedding.weight.dtype
43
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
44
+ B, C, H, W = patch_embeds.shape
45
+ patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C')
46
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
47
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
48
+ embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
49
+ return embeddings
50
+
51
+
52
+ class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
53
+ def __init__(self, config):
54
+ super().__init__(config)
55
+ self.vision_model.embeddings.forward = types.MethodType(
56
+ improved_clipvision_embedding_forward,
57
+ self.vision_model.embeddings
58
+ )
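A minimal sketch of what preprocess does with a non-square input (illustrative sizes, not from the repo): the shorter side is mapped to 16 patches of 14 px and the longer side keeps the aspect ratio.

import torch
from diffusers_vdm.improved_clip_vision import preprocess

image01 = torch.rand(1, 3, 480, 832)   # [0, 1] range; the patched forward maps [-1, 1] inputs here first
clip_input = preprocess(image01)
print(clip_input.shape)                # torch.Size([1, 3, 224, 392])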
diffusers_vdm/pipeline.py ADDED
@@ -0,0 +1,188 @@
1
+ import os
2
+ import torch
3
+ import einops
4
+
5
+ from diffusers import DiffusionPipeline
6
+ from transformers import CLIPTextModel, CLIPTokenizer
7
+ from huggingface_hub import snapshot_download
8
+ from diffusers_vdm.vae import VideoAutoencoderKL
9
+ from diffusers_vdm.projection import Resampler
10
+ from diffusers_vdm.unet import UNet3DModel
11
+ from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection
12
+ from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR
13
+
14
+
15
+ class LatentVideoDiffusionPipeline(DiffusionPipeline):
16
+ def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True):
17
+ super().__init__()
18
+
19
+ self.loading_components = dict(
20
+ vae=vae,
21
+ text_encoder=text_encoder,
22
+ tokenizer=tokenizer,
23
+ unet=unet,
24
+ image_encoder=image_encoder,
25
+ image_projection=image_projection
26
+ )
27
+
28
+ for k, v in self.loading_components.items():
29
+ setattr(self, k, v)
30
+
31
+ if fp16:
32
+ self.vae.half()
33
+ self.text_encoder.half()
34
+ self.unet.half()
35
+ self.image_encoder.half()
36
+ self.image_projection.half()
37
+
38
+ self.vae.requires_grad_(False)
39
+ self.text_encoder.requires_grad_(False)
40
+ self.image_encoder.requires_grad_(False)
41
+
42
+ self.vae.eval()
43
+ self.text_encoder.eval()
44
+ self.image_encoder.eval()
45
+
46
+ if eval:
47
+ self.unet.eval()
48
+ self.image_projection.eval()
49
+ else:
50
+ self.unet.train()
51
+ self.image_projection.train()
52
+
53
+ def to(self, *args, **kwargs):
54
+ for k, v in self.loading_components.items():
55
+ if hasattr(v, 'to'):
56
+ v.to(*args, **kwargs)
57
+ return self
58
+
59
+ def save_pretrained(self, save_directory, **kwargs):
60
+ for k, v in self.loading_components.items():
61
+ folder = os.path.join(save_directory, k)
62
+ os.makedirs(folder, exist_ok=True)
63
+ v.save_pretrained(folder)
64
+ return
65
+
66
+ @classmethod
67
+ def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None):
68
+ local_folder = snapshot_download(repo_id=repo_id, token=token)
69
+ return cls(
70
+ tokenizer=CLIPTokenizer.from_pretrained(os.path.join(local_folder, "tokenizer")),
71
+ text_encoder=CLIPTextModel.from_pretrained(os.path.join(local_folder, "text_encoder")),
72
+ image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(os.path.join(local_folder, "image_encoder")),
73
+ vae=VideoAutoencoderKL.from_pretrained(os.path.join(local_folder, "vae")),
74
+ image_projection=Resampler.from_pretrained(os.path.join(local_folder, "image_projection")),
75
+ unet=UNet3DModel.from_pretrained(os.path.join(local_folder, "unet")),
76
+ fp16=fp16,
77
+ eval=eval
78
+ )
79
+
80
+ @torch.inference_mode()
81
+ def encode_cropped_prompt_77tokens(self, prompt: str):
82
+ cond_ids = self.tokenizer(prompt,
83
+ padding="max_length",
84
+ max_length=self.tokenizer.model_max_length,
85
+ truncation=True,
86
+ return_tensors="pt").input_ids.to(self.text_encoder.device)
87
+ cond = self.text_encoder(cond_ids, attention_mask=None).last_hidden_state
88
+ return cond
89
+
90
+ @torch.inference_mode()
91
+ def encode_clip_vision(self, frames):
92
+ b, c, t, h, w = frames.shape
93
+ frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w')
94
+ clipvision_embed = self.image_encoder(frames).last_hidden_state
95
+ clipvision_embed = einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t)
96
+ return clipvision_embed
97
+
98
+ @torch.inference_mode()
99
+ def encode_latents(self, videos, return_hidden_states=True):
100
+ b, c, t, h, w = videos.shape
101
+ x = einops.rearrange(videos, 'b c t h w -> (b t) c h w')
102
+ encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states)
103
+ z = encoder_posterior.mode() * self.vae.scale_factor
104
+ z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
105
+
106
+ if not return_hidden_states:
107
+ return z
108
+
109
+ hidden_states = [einops.rearrange(h, '(b t) c h w -> b c t h w', b=b) for h in hidden_states]
110
+ hidden_states = [h[:, :, [0, -1], :, :] for h in hidden_states] # only need first and last
111
+
112
+ return z, hidden_states
113
+
114
+ @torch.inference_mode()
115
+ def decode_latents(self, latents, hidden_states):
116
+ B, C, T, H, W = latents.shape
117
+ latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w')
118
+ latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor
119
+ pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=T)
120
+ pixels = einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=B, t=T)
121
+ return pixels
122
+
123
+ @torch.inference_mode()
124
+ def __call__(
125
+ self,
126
+ batch_size: int = 1,
127
+ steps: int = 50,
128
+ guidance_scale: float = 5.0,
129
+ positive_text_cond = None,
130
+ negative_text_cond = None,
131
+ positive_image_cond = None,
132
+ negative_image_cond = None,
133
+ concat_cond = None,
134
+ fs = 3,
135
+ progress_tqdm = None,
136
+ ):
137
+ unet_is_training = self.unet.training
138
+
139
+ if unet_is_training:
140
+ self.unet.eval()
141
+
142
+ device = self.unet.device
143
+ dtype = self.unet.dtype
144
+ dynamic_tsnr_model = SamplerDynamicTSNR(self.unet)
145
+
146
+ # Batch
147
+
148
+ concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype) # b, c, t, h, w
149
+ positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
150
+ negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
151
+ positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond) # b, t, l, c
152
+ negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)
153
+
154
+ if isinstance(fs, torch.Tensor):
155
+ fs = fs.repeat(batch_size, ).to(dtype=torch.long, device=device) # b
156
+ else:
157
+ fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device) # b
158
+
159
+ # Initial latents
160
+
161
+ latent_shape = concat_cond.shape
162
+
163
+ # Feeds
164
+
165
+ sampler_kwargs = dict(
166
+ cfg_scale=guidance_scale,
167
+ positive=dict(
168
+ context_text=positive_text_cond,
169
+ context_img=positive_image_cond,
170
+ fs=fs,
171
+ concat_cond=concat_cond
172
+ ),
173
+ negative=dict(
174
+ context_text=negative_text_cond,
175
+ context_img=negative_image_cond,
176
+ fs=fs,
177
+ concat_cond=concat_cond
178
+ )
179
+ )
180
+
181
+ # Sample
182
+
183
+ results = dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
184
+
185
+ if unet_is_training:
186
+ self.unet.train()
187
+
188
+ return results
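A hedged usage sketch for the pipeline (not part of the commit). The repo id below is a placeholder, not an actual checkpoint name, and loading also requires the dependencies used above (transformers, diffusers, xformers).

import torch
from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline

pipe = LatentVideoDiffusionPipeline.from_pretrained('some-user/some-video-model')   # placeholder repo id
pipe = pipe.to('cuda')

positive = pipe.encode_cropped_prompt_77tokens('a cat walking on grass')
negative = pipe.encode_cropped_prompt_77tokens('')
print(positive.shape)   # (1, 77, text encoder hidden size)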
diffusers_vdm/projection.py ADDED
@@ -0,0 +1,160 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+ # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
4
+
5
+
6
+ import math
7
+ import torch
8
+ import einops
9
+ import torch.nn as nn
10
+
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+
14
+ class ImageProjModel(nn.Module):
15
+ """Projection Model"""
16
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
17
+ super().__init__()
18
+ self.cross_attention_dim = cross_attention_dim
19
+ self.clip_extra_context_tokens = clip_extra_context_tokens
20
+ self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
21
+ self.norm = nn.LayerNorm(cross_attention_dim)
22
+
23
+ def forward(self, image_embeds):
24
+ #embeds = image_embeds
25
+ embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
26
+ clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
27
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
28
+ return clip_extra_context_tokens
29
+
30
+
31
+ # FFN
32
+ def FeedForward(dim, mult=4):
33
+ inner_dim = int(dim * mult)
34
+ return nn.Sequential(
35
+ nn.LayerNorm(dim),
36
+ nn.Linear(dim, inner_dim, bias=False),
37
+ nn.GELU(),
38
+ nn.Linear(inner_dim, dim, bias=False),
39
+ )
40
+
41
+
42
+ def reshape_tensor(x, heads):
43
+ bs, length, width = x.shape
44
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
45
+ x = x.view(bs, length, heads, -1)
46
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
47
+ x = x.transpose(1, 2)
48
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
49
+ x = x.reshape(bs, heads, length, -1)
50
+ return x
51
+
52
+
53
+ class PerceiverAttention(nn.Module):
54
+ def __init__(self, *, dim, dim_head=64, heads=8):
55
+ super().__init__()
56
+ self.scale = dim_head**-0.5
57
+ self.dim_head = dim_head
58
+ self.heads = heads
59
+ inner_dim = dim_head * heads
60
+
61
+ self.norm1 = nn.LayerNorm(dim)
62
+ self.norm2 = nn.LayerNorm(dim)
63
+
64
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
65
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
66
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
67
+
68
+
69
+ def forward(self, x, latents):
70
+ """
71
+ Args:
72
+ x (torch.Tensor): image features
73
+ shape (b, n1, D)
74
+ latent (torch.Tensor): latent features
75
+ shape (b, n2, D)
76
+ """
77
+ x = self.norm1(x)
78
+ latents = self.norm2(latents)
79
+
80
+ b, l, _ = latents.shape
81
+
82
+ q = self.to_q(latents)
83
+ kv_input = torch.cat((x, latents), dim=-2)
84
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
85
+
86
+ q = reshape_tensor(q, self.heads)
87
+ k = reshape_tensor(k, self.heads)
88
+ v = reshape_tensor(v, self.heads)
89
+
90
+ # attention
91
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
92
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
93
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
94
+ out = weight @ v
95
+
96
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
97
+
98
+ return self.to_out(out)
99
+
100
+
101
+ class Resampler(nn.Module, PyTorchModelHubMixin):
102
+ def __init__(
103
+ self,
104
+ dim=1024,
105
+ depth=8,
106
+ dim_head=64,
107
+ heads=16,
108
+ num_queries=8,
109
+ embedding_dim=768,
110
+ output_dim=1024,
111
+ ff_mult=4,
112
+ video_length=16,
113
+ input_frames_length=2,
114
+ ):
115
+ super().__init__()
116
+ self.num_queries = num_queries
117
+ self.video_length = video_length
118
+
119
+ self.latents = nn.Parameter(torch.randn(1, num_queries * video_length, dim) / dim**0.5)
120
+ self.input_pos = nn.Parameter(torch.zeros(1, input_frames_length, 1, embedding_dim))
121
+
122
+ self.proj_in = nn.Linear(embedding_dim, dim)
123
+ self.proj_out = nn.Linear(dim, output_dim)
124
+ self.norm_out = nn.LayerNorm(output_dim)
125
+
126
+ self.layers = nn.ModuleList([])
127
+ for _ in range(depth):
128
+ self.layers.append(
129
+ nn.ModuleList(
130
+ [
131
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
132
+ FeedForward(dim=dim, mult=ff_mult),
133
+ ]
134
+ )
135
+ )
136
+
137
+ def forward(self, x):
138
+ latents = self.latents.repeat(x.size(0), 1, 1)
139
+
140
+ x = x + self.input_pos
141
+ x = einops.rearrange(x, 'b ti d c -> b (ti d) c')
142
+ x = self.proj_in(x)
143
+
144
+ for attn, ff in self.layers:
145
+ latents = attn(x, latents) + latents
146
+ latents = ff(latents) + latents
147
+
148
+ latents = self.proj_out(latents)
149
+ latents = self.norm_out(latents)
150
+
151
+ latents = einops.rearrange(latents, 'b (to l) c -> b to l c', to=self.video_length)
152
+ return latents
153
+
154
+ @property
155
+ def device(self):
156
+ return next(self.parameters()).device
157
+
158
+ @property
159
+ def dtype(self):
160
+ return next(self.parameters()).dtype
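A shape-only sketch of the Resampler (illustrative, untrained weights; the constructor arguments below are assumptions, not the values used by any released checkpoint).

import torch
from diffusers_vdm.projection import Resampler

resampler = Resampler(dim=1024, embedding_dim=1024, output_dim=1024,
                      num_queries=16, video_length=16, input_frames_length=2)
clip_tokens = torch.randn(1, 2, 257, 1024)   # b, input frames, CLIP tokens per frame, embedding_dim
out = resampler(clip_tokens)
print(out.shape)                             # torch.Size([1, 16, 16, 1024]): b, video frames, queries, output_dim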
diffusers_vdm/unet.py ADDED
@@ -0,0 +1,650 @@
1
+ # https://github.com/AILab-CVC/VideoCrafter
2
+ # https://github.com/Doubiiu/DynamiCrafter
3
+ # https://github.com/ToonCrafter/ToonCrafter
4
+ # Then edited by lllyasviel
5
+
6
+ from functools import partial
7
+ from abc import abstractmethod
8
+ import torch
9
+ import math
10
+ import torch.nn as nn
11
+ from einops import rearrange, repeat
12
+ import torch.nn.functional as F
13
+ from diffusers_vdm.basics import checkpoint
14
+ from diffusers_vdm.basics import (
15
+ zero_module,
16
+ conv_nd,
17
+ linear,
18
+ avg_pool_nd,
19
+ normalization
20
+ )
21
+ from diffusers_vdm.attention import SpatialTransformer, TemporalTransformer
22
+ from huggingface_hub import PyTorchModelHubMixin
23
+
24
+
25
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
26
+ """
27
+ Create sinusoidal timestep embeddings.
28
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
29
+ These may be fractional.
30
+ :param dim: the dimension of the output.
31
+ :param max_period: controls the minimum frequency of the embeddings.
32
+ :return: an [N x dim] Tensor of positional embeddings.
33
+ """
34
+ if not repeat_only:
35
+ half = dim // 2
36
+ freqs = torch.exp(
37
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
38
+ ).to(device=timesteps.device)
39
+ args = timesteps[:, None].float() * freqs[None]
40
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
41
+ if dim % 2:
42
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
43
+ else:
44
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
45
+ return embedding
46
+
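A quick check of the sinusoidal embedding above (illustrative, not part of the commit; dim=320 is just an example and importing the module assumes the package's dependencies are installed).

import torch
from diffusers_vdm.unet import timestep_embedding

emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=320)
print(emb.shape)   # torch.Size([3, 320]); cosines in the first half, sines in the second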
47
+
48
+ class TimestepBlock(nn.Module):
49
+ """
50
+ Any module where forward() takes timestep embeddings as a second argument.
51
+ """
52
+
53
+ @abstractmethod
54
+ def forward(self, x, emb):
55
+ """
56
+ Apply the module to `x` given `emb` timestep embeddings.
57
+ """
58
+
59
+
60
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
61
+ """
62
+ A sequential module that passes timestep embeddings to the children that
63
+ support it as an extra input.
64
+ """
65
+
66
+ def forward(self, x, emb, context=None, batch_size=None):
67
+ for layer in self:
68
+ if isinstance(layer, TimestepBlock):
69
+ x = layer(x, emb, batch_size=batch_size)
70
+ elif isinstance(layer, SpatialTransformer):
71
+ x = layer(x, context)
72
+ elif isinstance(layer, TemporalTransformer):
73
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
74
+ x = layer(x, context)
75
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
76
+ else:
77
+ x = layer(x)
78
+ return x
79
+
80
+
81
+ class Downsample(nn.Module):
82
+ """
83
+ A downsampling layer with an optional convolution.
84
+ :param channels: channels in the inputs and outputs.
85
+ :param use_conv: a bool determining if a convolution is applied.
86
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
87
+ downsampling occurs in the inner-two dimensions.
88
+ """
89
+
90
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
91
+ super().__init__()
92
+ self.channels = channels
93
+ self.out_channels = out_channels or channels
94
+ self.use_conv = use_conv
95
+ self.dims = dims
96
+ stride = 2 if dims != 3 else (1, 2, 2)
97
+ if use_conv:
98
+ self.op = conv_nd(
99
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
100
+ )
101
+ else:
102
+ assert self.channels == self.out_channels
103
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
104
+
105
+ def forward(self, x):
106
+ assert x.shape[1] == self.channels
107
+ return self.op(x)
108
+
109
+
110
+ class Upsample(nn.Module):
111
+ """
112
+ An upsampling layer with an optional convolution.
113
+ :param channels: channels in the inputs and outputs.
114
+ :param use_conv: a bool determining if a convolution is applied.
115
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
116
+ upsampling occurs in the inner-two dimensions.
117
+ """
118
+
119
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.out_channels = out_channels or channels
123
+ self.use_conv = use_conv
124
+ self.dims = dims
125
+ if use_conv:
126
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
127
+
128
+ def forward(self, x):
129
+ assert x.shape[1] == self.channels
130
+ if self.dims == 3:
131
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
132
+ else:
133
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
134
+ if self.use_conv:
135
+ x = self.conv(x)
136
+ return x
137
+
138
+
139
+ class ResBlock(TimestepBlock):
140
+ """
141
+ A residual block that can optionally change the number of channels.
142
+ :param channels: the number of input channels.
143
+ :param emb_channels: the number of timestep embedding channels.
144
+ :param dropout: the rate of dropout.
145
+ :param out_channels: if specified, the number of out channels.
146
+ :param use_conv: if True and out_channels is specified, use a spatial
147
+ convolution instead of a smaller 1x1 convolution to change the
148
+ channels in the skip connection.
149
+ :param dims: determines if the signal is 1D, 2D, or 3D.
150
+ :param up: if True, use this block for upsampling.
151
+ :param down: if True, use this block for downsampling.
152
+ :param use_temporal_conv: if True, use the temporal convolution.
153
+ :param use_image_dataset: if True, the temporal parameters will not be optimized.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ channels,
159
+ emb_channels,
160
+ dropout,
161
+ out_channels=None,
162
+ use_scale_shift_norm=False,
163
+ dims=2,
164
+ use_checkpoint=False,
165
+ use_conv=False,
166
+ up=False,
167
+ down=False,
168
+ use_temporal_conv=False,
169
+ tempspatial_aware=False
170
+ ):
171
+ super().__init__()
172
+ self.channels = channels
173
+ self.emb_channels = emb_channels
174
+ self.dropout = dropout
175
+ self.out_channels = out_channels or channels
176
+ self.use_conv = use_conv
177
+ self.use_checkpoint = use_checkpoint
178
+ self.use_scale_shift_norm = use_scale_shift_norm
179
+ self.use_temporal_conv = use_temporal_conv
180
+
181
+ self.in_layers = nn.Sequential(
182
+ normalization(channels),
183
+ nn.SiLU(),
184
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
185
+ )
186
+
187
+ self.updown = up or down
188
+
189
+ if up:
190
+ self.h_upd = Upsample(channels, False, dims)
191
+ self.x_upd = Upsample(channels, False, dims)
192
+ elif down:
193
+ self.h_upd = Downsample(channels, False, dims)
194
+ self.x_upd = Downsample(channels, False, dims)
195
+ else:
196
+ self.h_upd = self.x_upd = nn.Identity()
197
+
198
+ self.emb_layers = nn.Sequential(
199
+ nn.SiLU(),
200
+ nn.Linear(
201
+ emb_channels,
202
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
203
+ ),
204
+ )
205
+ self.out_layers = nn.Sequential(
206
+ normalization(self.out_channels),
207
+ nn.SiLU(),
208
+ nn.Dropout(p=dropout),
209
+ zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
210
+ )
211
+
212
+ if self.out_channels == channels:
213
+ self.skip_connection = nn.Identity()
214
+ elif use_conv:
215
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
216
+ else:
217
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
218
+
219
+ if self.use_temporal_conv:
220
+ self.temopral_conv = TemporalConvBlock(
221
+ self.out_channels,
222
+ self.out_channels,
223
+ dropout=0.1,
224
+ spatial_aware=tempspatial_aware
225
+ )
226
+
227
+ def forward(self, x, emb, batch_size=None):
228
+ """
229
+ Apply the block to a Tensor, conditioned on a timestep embedding.
230
+ :param x: an [N x C x ...] Tensor of features.
231
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
232
+ :return: an [N x C x ...] Tensor of outputs.
233
+ """
234
+ input_tuple = (x, emb)
235
+ if batch_size:
236
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
237
+ return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
238
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
239
+
240
+ def _forward(self, x, emb, batch_size=None):
241
+ if self.updown:
242
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
243
+ h = in_rest(x)
244
+ h = self.h_upd(h)
245
+ x = self.x_upd(x)
246
+ h = in_conv(h)
247
+ else:
248
+ h = self.in_layers(x)
249
+ emb_out = self.emb_layers(emb).type(h.dtype)
250
+ while len(emb_out.shape) < len(h.shape):
251
+ emb_out = emb_out[..., None]
252
+ if self.use_scale_shift_norm:
253
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
254
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
255
+ h = out_norm(h) * (1 + scale) + shift
256
+ h = out_rest(h)
257
+ else:
258
+ h = h + emb_out
259
+ h = self.out_layers(h)
260
+ h = self.skip_connection(x) + h
261
+
262
+ if self.use_temporal_conv and batch_size:
263
+ h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
264
+ h = self.temopral_conv(h)
265
+ h = rearrange(h, 'b c t h w -> (b t) c h w')
266
+ return h
267
+
268
+
269
+ class TemporalConvBlock(nn.Module):
270
+ """
271
+ Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
272
+ """
273
+
274
+ def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
275
+ super(TemporalConvBlock, self).__init__()
276
+ if out_channels is None:
277
+ out_channels = in_channels
278
+ self.in_channels = in_channels
279
+ self.out_channels = out_channels
280
+ th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
281
+ th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
282
+ tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
283
+ tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
284
+
285
+ # conv layers
286
+ self.conv1 = nn.Sequential(
287
+ nn.GroupNorm(32, in_channels), nn.SiLU(),
288
+ nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape))
289
+ self.conv2 = nn.Sequential(
290
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
291
+ nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
292
+ self.conv3 = nn.Sequential(
293
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
294
+ nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape))
295
+ self.conv4 = nn.Sequential(
296
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
297
+ nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
298
+
299
+ # zero out the last layer params, so the conv block is identity
300
+ nn.init.zeros_(self.conv4[-1].weight)
301
+ nn.init.zeros_(self.conv4[-1].bias)
302
+
303
+ def forward(self, x):
304
+ identity = x
305
+ x = self.conv1(x)
306
+ x = self.conv2(x)
307
+ x = self.conv3(x)
308
+ x = self.conv4(x)
309
+
310
+ return identity + x
311
+
312
+
313
+ class UNet3DModel(nn.Module, PyTorchModelHubMixin):
314
+ """
315
+ The full UNet model with attention and timestep embedding.
316
+ :param in_channels: channels in the input Tensor.
317
+ :param model_channels: base channel count for the model.
318
+ :param out_channels: channels in the output Tensor.
319
+ :param num_res_blocks: number of residual blocks per downsample.
320
+ :param attention_resolutions: a collection of downsample rates at which
321
+ attention will take place. May be a set, list, or tuple.
322
+ For example, if this contains 4, then at 4x downsampling, attention
323
+ will be used.
324
+ :param dropout: the dropout probability.
325
+ :param channel_mult: channel multiplier for each level of the UNet.
326
+ :param conv_resample: if True, use learned convolutions for upsampling and
327
+ downsampling.
328
+ :param dims: determines if the signal is 1D, 2D, or 3D.
329
+ :param num_classes: if specified (as an int), then this model will be
330
+ class-conditional with `num_classes` classes.
331
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
332
+ :param num_heads: the number of attention heads in each attention layer.
333
+ :param num_head_channels: if specified, ignore num_heads and instead use
334
+ a fixed channel width per attention head.
335
+ :param num_heads_upsample: works with num_heads to set a different number
336
+ of heads for upsampling. Deprecated.
337
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
338
+ :param resblock_updown: use residual blocks for up/downsampling.
339
+ :param use_new_attention_order: use a different attention pattern for potentially
340
+ increased efficiency.
341
+ """
342
+
343
+ def __init__(self,
344
+ in_channels,
345
+ model_channels,
346
+ out_channels,
347
+ num_res_blocks,
348
+ attention_resolutions,
349
+ dropout=0.0,
350
+ channel_mult=(1, 2, 4, 8),
351
+ conv_resample=True,
352
+ dims=2,
353
+ context_dim=None,
354
+ use_scale_shift_norm=False,
355
+ resblock_updown=False,
356
+ num_heads=-1,
357
+ num_head_channels=-1,
358
+ transformer_depth=1,
359
+ use_linear=False,
360
+ temporal_conv=False,
361
+ tempspatial_aware=False,
362
+ temporal_attention=True,
363
+ use_relative_position=True,
364
+ use_causal_attention=False,
365
+ temporal_length=None,
366
+ addition_attention=False,
367
+ temporal_selfatt_only=True,
368
+ image_cross_attention=False,
369
+ image_cross_attention_scale_learnable=False,
370
+ default_fs=4,
371
+ fs_condition=False,
372
+ ):
373
+ super(UNet3DModel, self).__init__()
374
+ if num_heads == -1:
375
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
376
+ if num_head_channels == -1:
377
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
378
+
379
+ self.in_channels = in_channels
380
+ self.model_channels = model_channels
381
+ self.out_channels = out_channels
382
+ self.num_res_blocks = num_res_blocks
383
+ self.attention_resolutions = attention_resolutions
384
+ self.dropout = dropout
385
+ self.channel_mult = channel_mult
386
+ self.conv_resample = conv_resample
387
+ self.temporal_attention = temporal_attention
388
+ time_embed_dim = model_channels * 4
389
+ self.use_checkpoint = use_checkpoint = False # moved to self.enable_gradient_checkpointing()
390
+ temporal_self_att_only = True
391
+ self.addition_attention = addition_attention
392
+ self.temporal_length = temporal_length
393
+ self.image_cross_attention = image_cross_attention
394
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
395
+ self.default_fs = default_fs
396
+ self.fs_condition = fs_condition
397
+
398
+ ## Time embedding blocks
399
+ self.time_embed = nn.Sequential(
400
+ linear(model_channels, time_embed_dim),
401
+ nn.SiLU(),
402
+ linear(time_embed_dim, time_embed_dim),
403
+ )
404
+ if fs_condition:
405
+ self.fps_embedding = nn.Sequential(
406
+ linear(model_channels, time_embed_dim),
407
+ nn.SiLU(),
408
+ linear(time_embed_dim, time_embed_dim),
409
+ )
410
+ nn.init.zeros_(self.fps_embedding[-1].weight)
411
+ nn.init.zeros_(self.fps_embedding[-1].bias)
412
+ ## Input Block
413
+ self.input_blocks = nn.ModuleList(
414
+ [
415
+ TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
416
+ ]
417
+ )
418
+ if self.addition_attention:
419
+ self.init_attn = TimestepEmbedSequential(
420
+ TemporalTransformer(
421
+ model_channels,
422
+ n_heads=8,
423
+ d_head=num_head_channels,
424
+ depth=transformer_depth,
425
+ context_dim=context_dim,
426
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
427
+ causal_attention=False, relative_position=use_relative_position,
428
+ temporal_length=temporal_length))
429
+
430
+ input_block_chans = [model_channels]
431
+ ch = model_channels
432
+ ds = 1
433
+ for level, mult in enumerate(channel_mult):
434
+ for _ in range(num_res_blocks):
435
+ layers = [
436
+ ResBlock(ch, time_embed_dim, dropout,
437
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
438
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
439
+ use_temporal_conv=temporal_conv
440
+ )
441
+ ]
442
+ ch = mult * model_channels
443
+ if ds in attention_resolutions:
444
+ if num_head_channels == -1:
445
+ dim_head = ch // num_heads
446
+ else:
447
+ num_heads = ch // num_head_channels
448
+ dim_head = num_head_channels
449
+ layers.append(
450
+ SpatialTransformer(ch, num_heads, dim_head,
451
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
452
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
453
+ video_length=temporal_length,
454
+ image_cross_attention=self.image_cross_attention,
455
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
456
+ )
457
+ )
458
+ if self.temporal_attention:
459
+ layers.append(
460
+ TemporalTransformer(ch, num_heads, dim_head,
461
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
462
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
463
+ causal_attention=use_causal_attention,
464
+ relative_position=use_relative_position,
465
+ temporal_length=temporal_length
466
+ )
467
+ )
468
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
469
+ input_block_chans.append(ch)
470
+ if level != len(channel_mult) - 1:
471
+ out_ch = ch
472
+ self.input_blocks.append(
473
+ TimestepEmbedSequential(
474
+ ResBlock(ch, time_embed_dim, dropout,
475
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
476
+ use_scale_shift_norm=use_scale_shift_norm,
477
+ down=True
478
+ )
479
+ if resblock_updown
480
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
481
+ )
482
+ )
483
+ ch = out_ch
484
+ input_block_chans.append(ch)
485
+ ds *= 2
486
+
487
+ if num_head_channels == -1:
488
+ dim_head = ch // num_heads
489
+ else:
490
+ num_heads = ch // num_head_channels
491
+ dim_head = num_head_channels
492
+ layers = [
493
+ ResBlock(ch, time_embed_dim, dropout,
494
+ dims=dims, use_checkpoint=use_checkpoint,
495
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
496
+ use_temporal_conv=temporal_conv
497
+ ),
498
+ SpatialTransformer(ch, num_heads, dim_head,
499
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
500
+ use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length,
501
+ image_cross_attention=self.image_cross_attention,
502
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
503
+ )
504
+ ]
505
+ if self.temporal_attention:
506
+ layers.append(
507
+ TemporalTransformer(ch, num_heads, dim_head,
508
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
509
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
510
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
511
+ temporal_length=temporal_length
512
+ )
513
+ )
514
+ layers.append(
515
+ ResBlock(ch, time_embed_dim, dropout,
516
+ dims=dims, use_checkpoint=use_checkpoint,
517
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
518
+ use_temporal_conv=temporal_conv
519
+ )
520
+ )
521
+
522
+ ## Middle Block
523
+ self.middle_block = TimestepEmbedSequential(*layers)
524
+
525
+ ## Output Block
526
+ self.output_blocks = nn.ModuleList([])
527
+ for level, mult in list(enumerate(channel_mult))[::-1]:
528
+ for i in range(num_res_blocks + 1):
529
+ ich = input_block_chans.pop()
530
+ layers = [
531
+ ResBlock(ch + ich, time_embed_dim, dropout,
532
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
533
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
534
+ use_temporal_conv=temporal_conv
535
+ )
536
+ ]
537
+ ch = model_channels * mult
538
+ if ds in attention_resolutions:
539
+ if num_head_channels == -1:
540
+ dim_head = ch // num_heads
541
+ else:
542
+ num_heads = ch // num_head_channels
543
+ dim_head = num_head_channels
544
+ layers.append(
545
+ SpatialTransformer(ch, num_heads, dim_head,
546
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
547
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
548
+ video_length=temporal_length,
549
+ image_cross_attention=self.image_cross_attention,
550
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
551
+ )
552
+ )
553
+ if self.temporal_attention:
554
+ layers.append(
555
+ TemporalTransformer(ch, num_heads, dim_head,
556
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
557
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
558
+ causal_attention=use_causal_attention,
559
+ relative_position=use_relative_position,
560
+ temporal_length=temporal_length
561
+ )
562
+ )
563
+ if level and i == num_res_blocks:
564
+ out_ch = ch
565
+ layers.append(
566
+ ResBlock(ch, time_embed_dim, dropout,
567
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
568
+ use_scale_shift_norm=use_scale_shift_norm,
569
+ up=True
570
+ )
571
+ if resblock_updown
572
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
573
+ )
574
+ ds //= 2
575
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
576
+
577
+ self.out = nn.Sequential(
578
+ normalization(ch),
579
+ nn.SiLU(),
580
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
581
+ )
582
+
583
+ @property
584
+ def device(self):
585
+ return next(self.parameters()).device
586
+
587
+ @property
588
+ def dtype(self):
589
+ return next(self.parameters()).dtype
590
+
591
+ def forward(self, x, timesteps, context_text=None, context_img=None, concat_cond=None, fs=None, **kwargs):
592
+ b, _, t, _, _ = x.shape
593
+
594
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype)
595
+ emb = self.time_embed(t_emb)
596
+
597
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
598
+ context_img = rearrange(context_img, 'b t l c -> (b t) l c')
599
+
600
+ context = (context_text, context_img)
601
+
602
+ emb = emb.repeat_interleave(repeats=t, dim=0)
603
+
604
+ if concat_cond is not None:
605
+ x = torch.cat([x, concat_cond], dim=1)
606
+
607
+ ## always in shape (b t) c h w, except for temporal layer
608
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
609
+
610
+ ## combine emb
611
+ if self.fs_condition:
612
+ if fs is None:
613
+ fs = torch.tensor(
614
+ [self.default_fs] * b, dtype=torch.long, device=x.device)
615
+ fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)
616
+
617
+ fs_embed = self.fps_embedding(fs_emb)
618
+ fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
619
+ emb = emb + fs_embed
620
+
621
+ h = x
622
+ hs = []
623
+ for id, module in enumerate(self.input_blocks):
624
+ h = module(h, emb, context=context, batch_size=b)
625
+ if id == 0 and self.addition_attention:
626
+ h = self.init_attn(h, emb, context=context, batch_size=b)
627
+ hs.append(h)
628
+
629
+ h = self.middle_block(h, emb, context=context, batch_size=b)
630
+
631
+ for module in self.output_blocks:
632
+ h = torch.cat([h, hs.pop()], dim=1)
633
+ h = module(h, emb, context=context, batch_size=b)
634
+ h = h.type(x.dtype)
635
+ y = self.out(h)
636
+
637
+ y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
638
+ return y
639
+
640
+ def enable_gradient_checkpointing(self, enable=True, verbose=False):
641
+ for k, v in self.named_modules():
642
+ if hasattr(v, 'checkpoint'):
643
+ v.checkpoint = enable
644
+ if verbose:
645
+ print(f'{k}.checkpoint = {enable}')
646
+ if hasattr(v, 'use_checkpoint'):
647
+ v.use_checkpoint = enable
648
+ if verbose:
649
+ print(f'{k}.use_checkpoint = {enable}')
650
+ return
diffusers_vdm/utils.py ADDED
@@ -0,0 +1,43 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import einops
5
+ import torchvision
6
+
7
+
8
+ def resize_and_center_crop(image, target_width, target_height, interpolation=cv2.INTER_AREA):
9
+ original_height, original_width = image.shape[:2]
10
+ k = max(target_height / original_height, target_width / original_width)
11
+ new_width = int(round(original_width * k))
12
+ new_height = int(round(original_height * k))
13
+ resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
14
+ x_start = (new_width - target_width) // 2
15
+ y_start = (new_height - target_height) // 2
16
+ cropped_image = resized_image[y_start:y_start + target_height, x_start:x_start + target_width]
17
+ return cropped_image
18
+
19
+
20
+ def save_bcthw_as_mp4(x, output_filename, fps=10):
21
+ b, c, t, h, w = x.shape
22
+
23
+ per_row = b
24
+ for p in [6, 5, 4, 3, 2]:
25
+ if b % p == 0:
26
+ per_row = p
27
+ break
28
+
29
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
30
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
31
+ x = x.detach().cpu().to(torch.uint8)
32
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
33
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '1'})
34
+ return x
35
+
36
+
37
+ def save_bcthw_as_png(x, output_filename):
38
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
39
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
40
+ x = x.detach().cpu().to(torch.uint8)
41
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
42
+ torchvision.io.write_png(x, output_filename)
43
+ return output_filename
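A usage sketch (illustrative values, not part of the commit; writing h264 via torchvision.io.write_video requires PyAV/ffmpeg to be available).

import torch
from diffusers_vdm.utils import save_bcthw_as_mp4

video = torch.rand(1, 3, 16, 64, 64) * 2.0 - 1.0   # b c t h w in [-1, 1], matching the clamp above
save_bcthw_as_mp4(video, './results/demo.mp4', fps=10)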
diffusers_vdm/vae.py ADDED
@@ -0,0 +1,826 @@
1
+ # video VAE with components from several repos
2
+ # collected by lvmin
3
+
4
+
5
+ import torch
6
+ import xformers.ops
7
+ import torch.nn as nn
8
+
9
+ from einops import rearrange, repeat
10
+ from diffusers_vdm.basics import default, exists, zero_module, conv_nd, linear, normalization
11
+ from diffusers_vdm.unet import Upsample, Downsample
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+
14
+
15
+ def chunked_attention(q, k, v, batch_chunk=0):
16
+ # if batch_chunk > 0 and not torch.is_grad_enabled():
17
+ # batch_size = q.size(0)
18
+ # chunks = [slice(i, i + batch_chunk) for i in range(0, batch_size, batch_chunk)]
19
+ #
20
+ # out_chunks = []
21
+ # for chunk in chunks:
22
+ # q_chunk = q[chunk]
23
+ # k_chunk = k[chunk]
24
+ # v_chunk = v[chunk]
25
+ #
26
+ # out_chunk = torch.nn.functional.scaled_dot_product_attention(
27
+ # q_chunk, k_chunk, v_chunk, attn_mask=None
28
+ # )
29
+ # out_chunks.append(out_chunk)
30
+ #
31
+ # out = torch.cat(out_chunks, dim=0)
32
+ # else:
33
+ # out = torch.nn.functional.scaled_dot_product_attention(
34
+ # q, k, v, attn_mask=None
35
+ # )
36
+ out = xformers.ops.memory_efficient_attention(q, k, v)
37
+ return out
38
+
39
+
40
+ def nonlinearity(x):
41
+ return x * torch.sigmoid(x)
42
+
43
+
44
+ def GroupNorm(in_channels, num_groups=32):
45
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
46
+
47
+
48
+ class DiagonalGaussianDistribution:
49
+ def __init__(self, parameters, deterministic=False):
50
+ self.parameters = parameters
51
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
52
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
53
+ self.deterministic = deterministic
54
+ self.std = torch.exp(0.5 * self.logvar)
55
+ self.var = torch.exp(self.logvar)
56
+ if self.deterministic:
57
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
58
+
59
+ def sample(self, noise=None):
60
+ if noise is None:
61
+ noise = torch.randn(self.mean.shape)
62
+
63
+ x = self.mean + self.std * noise.to(device=self.parameters.device)
64
+ return x
65
+
66
+ def mode(self):
67
+ return self.mean
68
+
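A minimal sketch of the posterior object (illustrative shapes, not part of the commit; importing this module assumes xformers is installed).

import torch
from diffusers_vdm.vae import DiagonalGaussianDistribution

params = torch.randn(1, 8, 32, 32)        # mean and log-variance stacked along dim 1 (2 * z_channels)
posterior = DiagonalGaussianDistribution(params)
print(posterior.mode().shape)             # torch.Size([1, 4, 32, 32]), the deterministic latent used by the pipeline
print(posterior.sample().shape)           # torch.Size([1, 4, 32, 32]), a reparameterized sample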
69
+
70
+ class EncoderDownSampleBlock(nn.Module):
71
+ def __init__(self, in_channels, with_conv):
72
+ super().__init__()
73
+ self.with_conv = with_conv
74
+ self.in_channels = in_channels
75
+ if self.with_conv:
76
+ self.conv = torch.nn.Conv2d(in_channels,
77
+ in_channels,
78
+ kernel_size=3,
79
+ stride=2,
80
+ padding=0)
81
+
82
+ def forward(self, x):
83
+ if self.with_conv:
84
+ pad = (0, 1, 0, 1)
85
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
86
+ x = self.conv(x)
87
+ else:
88
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
89
+ return x
90
+
91
+
92
+ class ResnetBlock(nn.Module):
93
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
94
+ dropout, temb_channels=512):
95
+ super().__init__()
96
+ self.in_channels = in_channels
97
+ out_channels = in_channels if out_channels is None else out_channels
98
+ self.out_channels = out_channels
99
+ self.use_conv_shortcut = conv_shortcut
100
+
101
+ self.norm1 = GroupNorm(in_channels)
102
+ self.conv1 = torch.nn.Conv2d(in_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if temb_channels > 0:
108
+ self.temb_proj = torch.nn.Linear(temb_channels,
109
+ out_channels)
110
+ self.norm2 = GroupNorm(out_channels)
111
+ self.dropout = torch.nn.Dropout(dropout)
112
+ self.conv2 = torch.nn.Conv2d(out_channels,
113
+ out_channels,
114
+ kernel_size=3,
115
+ stride=1,
116
+ padding=1)
117
+ if self.in_channels != self.out_channels:
118
+ if self.use_conv_shortcut:
119
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
120
+ out_channels,
121
+ kernel_size=3,
122
+ stride=1,
123
+ padding=1)
124
+ else:
125
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
126
+ out_channels,
127
+ kernel_size=1,
128
+ stride=1,
129
+ padding=0)
130
+
131
+ def forward(self, x, temb):
132
+ h = x
133
+ h = self.norm1(h)
134
+ h = nonlinearity(h)
135
+ h = self.conv1(h)
136
+
137
+ if temb is not None:
138
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
139
+
140
+ h = self.norm2(h)
141
+ h = nonlinearity(h)
142
+ h = self.dropout(h)
143
+ h = self.conv2(h)
144
+
145
+ if self.in_channels != self.out_channels:
146
+ if self.use_conv_shortcut:
147
+ x = self.conv_shortcut(x)
148
+ else:
149
+ x = self.nin_shortcut(x)
150
+
151
+ return x + h
152
+
153
+
154
+ class Encoder(nn.Module):
155
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
156
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
157
+ resolution, z_channels, double_z=True, **kwargs):
158
+ super().__init__()
159
+ self.ch = ch
160
+ self.temb_ch = 0
161
+ self.num_resolutions = len(ch_mult)
162
+ self.num_res_blocks = num_res_blocks
163
+ self.resolution = resolution
164
+ self.in_channels = in_channels
165
+
166
+ # downsampling
167
+ self.conv_in = torch.nn.Conv2d(in_channels,
168
+ self.ch,
169
+ kernel_size=3,
170
+ stride=1,
171
+ padding=1)
172
+
173
+ curr_res = resolution
174
+ in_ch_mult = (1,) + tuple(ch_mult)
175
+ self.in_ch_mult = in_ch_mult
176
+ self.down = nn.ModuleList()
177
+ for i_level in range(self.num_resolutions):
178
+ block = nn.ModuleList()
179
+ attn = nn.ModuleList()
180
+ block_in = ch * in_ch_mult[i_level]
181
+ block_out = ch * ch_mult[i_level]
182
+ for i_block in range(self.num_res_blocks):
183
+ block.append(ResnetBlock(in_channels=block_in,
184
+ out_channels=block_out,
185
+ temb_channels=self.temb_ch,
186
+ dropout=dropout))
187
+ block_in = block_out
188
+ if curr_res in attn_resolutions:
189
+ attn.append(Attention(block_in))
190
+ down = nn.Module()
191
+ down.block = block
192
+ down.attn = attn
193
+ if i_level != self.num_resolutions - 1:
194
+ down.downsample = EncoderDownSampleBlock(block_in, resamp_with_conv)
195
+ curr_res = curr_res // 2
196
+ self.down.append(down)
197
+
198
+ # middle
199
+ self.mid = nn.Module()
200
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
201
+ out_channels=block_in,
202
+ temb_channels=self.temb_ch,
203
+ dropout=dropout)
204
+ self.mid.attn_1 = Attention(block_in)
205
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
206
+ out_channels=block_in,
207
+ temb_channels=self.temb_ch,
208
+ dropout=dropout)
209
+
210
+ # end
211
+ self.norm_out = GroupNorm(block_in)
212
+ self.conv_out = torch.nn.Conv2d(block_in,
213
+ 2 * z_channels if double_z else z_channels,
214
+ kernel_size=3,
215
+ stride=1,
216
+ padding=1)
217
+
218
+ def forward(self, x, return_hidden_states=False):
219
+ # timestep embedding
220
+ temb = None
221
+
222
+ # print(f'encoder-input={x.shape}')
223
+ # downsampling
224
+ hs = [self.conv_in(x)]
225
+
226
+ ## if we return hidden states for decoder usage, we will store them in a list
227
+ if return_hidden_states:
228
+ hidden_states = []
229
+ # print(f'encoder-conv in feat={hs[0].shape}')
230
+ for i_level in range(self.num_resolutions):
231
+ for i_block in range(self.num_res_blocks):
232
+ h = self.down[i_level].block[i_block](hs[-1], temb)
233
+ # print(f'encoder-down feat={h.shape}')
234
+ if len(self.down[i_level].attn) > 0:
235
+ h = self.down[i_level].attn[i_block](h)
236
+ hs.append(h)
237
+ if return_hidden_states:
238
+ hidden_states.append(h)
239
+ if i_level != self.num_resolutions - 1:
240
+ # print(f'encoder-downsample (input)={hs[-1].shape}')
241
+ hs.append(self.down[i_level].downsample(hs[-1]))
242
+ # print(f'encoder-downsample (output)={hs[-1].shape}')
243
+ if return_hidden_states:
244
+ hidden_states.append(hs[0])
245
+ # middle
246
+ h = hs[-1]
247
+ h = self.mid.block_1(h, temb)
248
+ # print(f'encoder-mid1 feat={h.shape}')
249
+ h = self.mid.attn_1(h)
250
+ h = self.mid.block_2(h, temb)
251
+ # print(f'encoder-mid2 feat={h.shape}')
252
+
253
+ # end
254
+ h = self.norm_out(h)
255
+ h = nonlinearity(h)
256
+ h = self.conv_out(h)
257
+ # print(f'end feat={h.shape}')
258
+ if return_hidden_states:
259
+ return h, hidden_states
260
+ else:
261
+ return h
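A rough shape sketch of the encoder with hidden-state passthrough, using a small illustrative config (a CUDA device is effectively required because the attention layers above go through xformers):

enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64, z_channels=4).cuda()
x = torch.randn(1, 3, 64, 64, device='cuda')
z, hidden = enc(x, return_hidden_states=True)
print(z.shape)                            # (1, 8, 16, 16): double_z doubles z_channels
print([tuple(h.shape) for h in hidden])   # per-level features, finest first, with the conv_in feature appended last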
262
+
263
+
264
+ class ConvCombiner(nn.Module):
265
+ def __init__(self, ch):
266
+ super().__init__()
267
+ self.conv = nn.Conv2d(ch, ch, 1, padding=0)
268
+
269
+ nn.init.zeros_(self.conv.weight)
270
+ nn.init.zeros_(self.conv.bias)
271
+
272
+ def forward(self, x, context):
273
+ ## x: b c h w, context: b c 2 h w
274
+ b, c, l, h, w = context.shape
275
+ bt, c, h, w = x.shape
276
+ context = rearrange(context, "b c l h w -> (b l) c h w")
277
+ context = self.conv(context)
278
+ context = rearrange(context, "(b l) c h w -> b c l h w", l=l)
279
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=bt // b)
280
+ x[:, :, 0] = x[:, :, 0] + context[:, :, 0]
281
+ x[:, :, -1] = x[:, :, -1] + context[:, :, -1]
282
+ x = rearrange(x, "b c t h w -> (b t) c h w")
283
+ return x
284
+
285
+
286
+ class AttentionCombiner(nn.Module):
287
+ def __init__(
288
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
289
+ ):
290
+ super().__init__()
291
+
292
+ inner_dim = dim_head * heads
293
+ context_dim = default(context_dim, query_dim)
294
+
295
+ self.heads = heads
296
+ self.dim_head = dim_head
297
+
298
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
299
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
300
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
301
+
302
+ self.to_out = nn.Sequential(
303
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
304
+ )
305
+ self.attention_op = None
306
+
307
+ self.norm = GroupNorm(query_dim)
308
+ nn.init.zeros_(self.to_out[0].weight)
309
+ nn.init.zeros_(self.to_out[0].bias)
310
+
311
+ def forward(
312
+ self,
313
+ x,
314
+ context=None,
315
+ mask=None,
316
+ ):
317
+ bt, c, h, w = x.shape
318
+ h_ = self.norm(x)
319
+ h_ = rearrange(h_, "b c h w -> b (h w) c")
320
+ q = self.to_q(h_)
321
+
322
+ b, c, l, h, w = context.shape
323
+ context = rearrange(context, "b c l h w -> (b l) (h w) c")
324
+ k = self.to_k(context)
325
+ v = self.to_v(context)
326
+
327
+ t = bt // b
328
+ k = repeat(k, "(b l) d c -> (b t) (l d) c", l=l, t=t)
329
+ v = repeat(v, "(b l) d c -> (b t) (l d) c", l=l, t=t)
330
+
331
+ b, _, _ = q.shape
332
+ q, k, v = map(
333
+ lambda t: t.unsqueeze(3)
334
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
335
+ .permute(0, 2, 1, 3)
336
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
337
+ .contiguous(),
338
+ (q, k, v),
339
+ )
340
+
341
+ out = chunked_attention(
342
+ q, k, v, batch_chunk=1
343
+ )
344
+
345
+ if exists(mask):
346
+ raise NotImplementedError
347
+
348
+ out = (
349
+ out.unsqueeze(0)
350
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
351
+ .permute(0, 2, 1, 3)
352
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
353
+ )
354
+ out = self.to_out(out)
355
+ out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c)
356
+ return x + out
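The manual unsqueeze/reshape/permute sequence above is the usual multi-head split and merge; an equivalent einops formulation (illustrative only, `rearrange` is already imported in this file) would be:

def split_heads(t, heads):
    # (b, n, heads * d) -> (b * heads, n, d)
    return rearrange(t, "b n (h d) -> (b h) n d", h=heads)

def merge_heads(t, heads):
    # (b * heads, n, d) -> (b, n, heads * d)
    return rearrange(t, "(b h) n d -> b n (h d)", h=heads)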
357
+
358
+
359
+ class Attention(nn.Module):
360
+ def __init__(self, in_channels):
361
+ super().__init__()
362
+ self.in_channels = in_channels
363
+
364
+ self.norm = GroupNorm(in_channels)
365
+ self.q = torch.nn.Conv2d(
366
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
367
+ )
368
+ self.k = torch.nn.Conv2d(
369
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
370
+ )
371
+ self.v = torch.nn.Conv2d(
372
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
373
+ )
374
+ self.proj_out = torch.nn.Conv2d(
375
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
376
+ )
377
+
378
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
379
+ h_ = self.norm(h_)
380
+ q = self.q(h_)
381
+ k = self.k(h_)
382
+ v = self.v(h_)
383
+
384
+ # compute attention
385
+ B, C, H, W = q.shape
386
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
387
+
388
+ q, k, v = map(
389
+ lambda t: t.unsqueeze(3)
390
+ .reshape(B, t.shape[1], 1, C)
391
+ .permute(0, 2, 1, 3)
392
+ .reshape(B * 1, t.shape[1], C)
393
+ .contiguous(),
394
+ (q, k, v),
395
+ )
396
+
397
+ out = chunked_attention(
398
+ q, k, v, batch_chunk=1
399
+ )
400
+
401
+ out = (
402
+ out.unsqueeze(0)
403
+ .reshape(B, 1, out.shape[1], C)
404
+ .permute(0, 2, 1, 3)
405
+ .reshape(B, out.shape[1], C)
406
+ )
407
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
408
+
409
+ def forward(self, x, **kwargs):
410
+ h_ = x
411
+ h_ = self.attention(h_)
412
+ h_ = self.proj_out(h_)
413
+ return x + h_
414
+
415
+
416
+ class VideoDecoder(nn.Module):
417
+ def __init__(
418
+ self,
419
+ *,
420
+ ch,
421
+ out_ch,
422
+ ch_mult=(1, 2, 4, 8),
423
+ num_res_blocks,
424
+ attn_resolutions,
425
+ dropout=0.0,
426
+ resamp_with_conv=True,
427
+ in_channels,
428
+ resolution,
429
+ z_channels,
430
+ give_pre_end=False,
431
+ tanh_out=False,
432
+ use_linear_attn=False,
433
+ attn_level=[2, 3],
434
+ video_kernel_size=[3, 1, 1],
435
+ alpha: float = 0.0,
436
+ merge_strategy: str = "learned",
437
+ **kwargs,
438
+ ):
439
+ super().__init__()
440
+ self.video_kernel_size = video_kernel_size
441
+ self.alpha = alpha
442
+ self.merge_strategy = merge_strategy
443
+ self.ch = ch
444
+ self.temb_ch = 0
445
+ self.num_resolutions = len(ch_mult)
446
+ self.num_res_blocks = num_res_blocks
447
+ self.resolution = resolution
448
+ self.in_channels = in_channels
449
+ self.give_pre_end = give_pre_end
450
+ self.tanh_out = tanh_out
451
+ self.attn_level = attn_level
452
+ # compute in_ch_mult, block_in and curr_res at lowest res
453
+ in_ch_mult = (1,) + tuple(ch_mult)
454
+ block_in = ch * ch_mult[self.num_resolutions - 1]
455
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
456
+ self.z_shape = (1, z_channels, curr_res, curr_res)
457
+
458
+ # z to block_in
459
+ self.conv_in = torch.nn.Conv2d(
460
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
461
+ )
462
+
463
+ # middle
464
+ self.mid = nn.Module()
465
+ self.mid.block_1 = VideoResBlock(
466
+ in_channels=block_in,
467
+ out_channels=block_in,
468
+ temb_channels=self.temb_ch,
469
+ dropout=dropout,
470
+ video_kernel_size=self.video_kernel_size,
471
+ alpha=self.alpha,
472
+ merge_strategy=self.merge_strategy,
473
+ )
474
+ self.mid.attn_1 = Attention(block_in)
475
+ self.mid.block_2 = VideoResBlock(
476
+ in_channels=block_in,
477
+ out_channels=block_in,
478
+ temb_channels=self.temb_ch,
479
+ dropout=dropout,
480
+ video_kernel_size=self.video_kernel_size,
481
+ alpha=self.alpha,
482
+ merge_strategy=self.merge_strategy,
483
+ )
484
+
485
+ # upsampling
486
+ self.up = nn.ModuleList()
487
+ self.attn_refinement = nn.ModuleList()
488
+ for i_level in reversed(range(self.num_resolutions)):
489
+ block = nn.ModuleList()
490
+ attn = nn.ModuleList()
491
+ block_out = ch * ch_mult[i_level]
492
+ for i_block in range(self.num_res_blocks + 1):
493
+ block.append(
494
+ VideoResBlock(
495
+ in_channels=block_in,
496
+ out_channels=block_out,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout,
499
+ video_kernel_size=self.video_kernel_size,
500
+ alpha=self.alpha,
501
+ merge_strategy=self.merge_strategy,
502
+ )
503
+ )
504
+ block_in = block_out
505
+ if curr_res in attn_resolutions:
506
+ attn.append(Attention(block_in))
507
+ up = nn.Module()
508
+ up.block = block
509
+ up.attn = attn
510
+ if i_level != 0:
511
+ up.upsample = Upsample(block_in, resamp_with_conv)
512
+ curr_res = curr_res * 2
513
+ self.up.insert(0, up) # prepend to get consistent order
514
+
515
+ if i_level in self.attn_level:
516
+ self.attn_refinement.insert(0, AttentionCombiner(block_in))
517
+ else:
518
+ self.attn_refinement.insert(0, ConvCombiner(block_in))
519
+ # end
520
+ self.norm_out = GroupNorm(block_in)
521
+ self.attn_refinement.append(ConvCombiner(block_in))
522
+ self.conv_out = DecoderConv3D(
523
+ block_in, out_ch, kernel_size=3, stride=1, padding=1, video_kernel_size=self.video_kernel_size
524
+ )
525
+
526
+ def forward(self, z, ref_context=None, **kwargs):
527
+ ## ref_context: b c 2 h w, where the 2 frames are the starting and ending reference frames
528
+ # assert z.shape[1:] == self.z_shape[1:]
529
+ self.last_z_shape = z.shape
530
+ # timestep embedding
531
+ temb = None
532
+
533
+ # z to block_in
534
+ h = self.conv_in(z)
535
+
536
+ # middle
537
+ h = self.mid.block_1(h, temb, **kwargs)
538
+ h = self.mid.attn_1(h, **kwargs)
539
+ h = self.mid.block_2(h, temb, **kwargs)
540
+
541
+ # upsampling
542
+ for i_level in reversed(range(self.num_resolutions)):
543
+ for i_block in range(self.num_res_blocks + 1):
544
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
545
+ if len(self.up[i_level].attn) > 0:
546
+ h = self.up[i_level].attn[i_block](h, **kwargs)
547
+ if ref_context:
548
+ h = self.attn_refinement[i_level](x=h, context=ref_context[i_level])
549
+ if i_level != 0:
550
+ h = self.up[i_level].upsample(h)
551
+
552
+ # end
553
+ if self.give_pre_end:
554
+ return h
555
+
556
+ h = self.norm_out(h)
557
+ h = nonlinearity(h)
558
+ if ref_context:
559
+ # print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
560
+ h = self.attn_refinement[-1](x=h, context=ref_context[-1])
561
+ h = self.conv_out(h, **kwargs)
562
+ if self.tanh_out:
563
+ h = torch.tanh(h)
564
+ return h
565
+
566
+
567
+ class TimeStackBlock(torch.nn.Module):
568
+ def __init__(
569
+ self,
570
+ channels: int,
571
+ emb_channels: int,
572
+ dropout: float,
573
+ out_channels: int = None,
574
+ use_conv: bool = False,
575
+ use_scale_shift_norm: bool = False,
576
+ dims: int = 2,
577
+ use_checkpoint: bool = False,
578
+ up: bool = False,
579
+ down: bool = False,
580
+ kernel_size: int = 3,
581
+ exchange_temb_dims: bool = False,
582
+ skip_t_emb: bool = False,
583
+ ):
584
+ super().__init__()
585
+ self.channels = channels
586
+ self.emb_channels = emb_channels
587
+ self.dropout = dropout
588
+ self.out_channels = out_channels or channels
589
+ self.use_conv = use_conv
590
+ self.use_checkpoint = use_checkpoint
591
+ self.use_scale_shift_norm = use_scale_shift_norm
592
+ self.exchange_temb_dims = exchange_temb_dims
593
+
594
+ if isinstance(kernel_size, list):
595
+ padding = [k // 2 for k in kernel_size]
596
+ else:
597
+ padding = kernel_size // 2
598
+
599
+ self.in_layers = nn.Sequential(
600
+ normalization(channels),
601
+ nn.SiLU(),
602
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
603
+ )
604
+
605
+ self.updown = up or down
606
+
607
+ if up:
608
+ self.h_upd = Upsample(channels, False, dims)
609
+ self.x_upd = Upsample(channels, False, dims)
610
+ elif down:
611
+ self.h_upd = Downsample(channels, False, dims)
612
+ self.x_upd = Downsample(channels, False, dims)
613
+ else:
614
+ self.h_upd = self.x_upd = nn.Identity()
615
+
616
+ self.skip_t_emb = skip_t_emb
617
+ self.emb_out_channels = (
618
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
619
+ )
620
+ if self.skip_t_emb:
621
+ # print(f"Skipping timestep embedding in {self.__class__.__name__}")
622
+ assert not self.use_scale_shift_norm
623
+ self.emb_layers = None
624
+ self.exchange_temb_dims = False
625
+ else:
626
+ self.emb_layers = nn.Sequential(
627
+ nn.SiLU(),
628
+ linear(
629
+ emb_channels,
630
+ self.emb_out_channels,
631
+ ),
632
+ )
633
+
634
+ self.out_layers = nn.Sequential(
635
+ normalization(self.out_channels),
636
+ nn.SiLU(),
637
+ nn.Dropout(p=dropout),
638
+ zero_module(
639
+ conv_nd(
640
+ dims,
641
+ self.out_channels,
642
+ self.out_channels,
643
+ kernel_size,
644
+ padding=padding,
645
+ )
646
+ ),
647
+ )
648
+
649
+ if self.out_channels == channels:
650
+ self.skip_connection = nn.Identity()
651
+ elif use_conv:
652
+ self.skip_connection = conv_nd(
653
+ dims, channels, self.out_channels, kernel_size, padding=padding
654
+ )
655
+ else:
656
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
657
+
658
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
659
+ if self.updown:
660
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
661
+ h = in_rest(x)
662
+ h = self.h_upd(h)
663
+ x = self.x_upd(x)
664
+ h = in_conv(h)
665
+ else:
666
+ h = self.in_layers(x)
667
+
668
+ if self.skip_t_emb:
669
+ emb_out = torch.zeros_like(h)
670
+ else:
671
+ emb_out = self.emb_layers(emb).type(h.dtype)
672
+ while len(emb_out.shape) < len(h.shape):
673
+ emb_out = emb_out[..., None]
674
+ if self.use_scale_shift_norm:
675
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
676
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
677
+ h = out_norm(h) * (1 + scale) + shift
678
+ h = out_rest(h)
679
+ else:
680
+ if self.exchange_temb_dims:
681
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
682
+ h = h + emb_out
683
+ h = self.out_layers(h)
684
+ return self.skip_connection(x) + h
685
+
686
+
687
+ class VideoResBlock(ResnetBlock):
688
+ def __init__(
689
+ self,
690
+ out_channels,
691
+ *args,
692
+ dropout=0.0,
693
+ video_kernel_size=3,
694
+ alpha=0.0,
695
+ merge_strategy="learned",
696
+ **kwargs,
697
+ ):
698
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
699
+ if video_kernel_size is None:
700
+ video_kernel_size = [3, 1, 1]
701
+ self.time_stack = TimeStackBlock(
702
+ channels=out_channels,
703
+ emb_channels=0,
704
+ dropout=dropout,
705
+ dims=3,
706
+ use_scale_shift_norm=False,
707
+ use_conv=False,
708
+ up=False,
709
+ down=False,
710
+ kernel_size=video_kernel_size,
711
+ use_checkpoint=True,
712
+ skip_t_emb=True,
713
+ )
714
+
715
+ self.merge_strategy = merge_strategy
716
+ if self.merge_strategy == "fixed":
717
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
718
+ elif self.merge_strategy == "learned":
719
+ self.register_parameter(
720
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
721
+ )
722
+ else:
723
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
724
+
725
+ def get_alpha(self, bs):
726
+ if self.merge_strategy == "fixed":
727
+ return self.mix_factor
728
+ elif self.merge_strategy == "learned":
729
+ return torch.sigmoid(self.mix_factor)
730
+ else:
731
+ raise NotImplementedError()
732
+
733
+ def forward(self, x, temb, skip_video=False, timesteps=None):
734
+ assert isinstance(timesteps, int)
735
+
736
+ b, c, h, w = x.shape
737
+
738
+ x = super().forward(x, temb)
739
+
740
+ if not skip_video:
741
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
742
+
743
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
744
+
745
+ x = self.time_stack(x, temb)
746
+
747
+ alpha = self.get_alpha(bs=b // timesteps)
748
+ x = alpha * x + (1.0 - alpha) * x_mix
749
+
750
+ x = rearrange(x, "b c t h w -> (b t) c h w")
751
+ return x
752
+
753
+
754
+ class DecoderConv3D(torch.nn.Conv2d):
755
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
756
+ super().__init__(in_channels, out_channels, *args, **kwargs)
757
+ if isinstance(video_kernel_size, list):
758
+ padding = [int(k // 2) for k in video_kernel_size]
759
+ else:
760
+ padding = int(video_kernel_size // 2)
761
+
762
+ self.time_mix_conv = torch.nn.Conv3d(
763
+ in_channels=out_channels,
764
+ out_channels=out_channels,
765
+ kernel_size=video_kernel_size,
766
+ padding=padding,
767
+ )
768
+
769
+ def forward(self, input, timesteps, skip_video=False):
770
+ x = super().forward(input)
771
+ if skip_video:
772
+ return x
773
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
774
+ x = self.time_mix_conv(x)
775
+ return rearrange(x, "b c t h w -> (b t) c h w")
776
+
777
+
778
+ class VideoAutoencoderKL(torch.nn.Module, PyTorchModelHubMixin):
779
+ def __init__(self,
780
+ double_z=True,
781
+ z_channels=4,
782
+ resolution=256,
783
+ in_channels=3,
784
+ out_ch=3,
785
+ ch=128,
786
+ ch_mult=[],
787
+ num_res_blocks=2,
788
+ attn_resolutions=[],
789
+ dropout=0.0,
790
+ ):
791
+ super().__init__()
792
+ self.encoder = Encoder(double_z=double_z, z_channels=z_channels, resolution=resolution, in_channels=in_channels,
793
+ out_ch=out_ch, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
794
+ attn_resolutions=attn_resolutions, dropout=dropout)
795
+ self.decoder = VideoDecoder(double_z=double_z, z_channels=z_channels, resolution=resolution,
796
+ in_channels=in_channels, out_ch=out_ch, ch=ch, ch_mult=ch_mult,
797
+ num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout)
798
+ self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
799
+ self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
800
+ self.scale_factor = 0.18215
801
+
802
+ def encode(self, x, return_hidden_states=False, **kwargs):
803
+ if return_hidden_states:
804
+ h, hidden = self.encoder(x, return_hidden_states)
805
+ moments = self.quant_conv(h)
806
+ posterior = DiagonalGaussianDistribution(moments)
807
+ return posterior, hidden
808
+ else:
809
+ h = self.encoder(x)
810
+ moments = self.quant_conv(h)
811
+ posterior = DiagonalGaussianDistribution(moments)
812
+ return posterior, None
813
+
814
+ def decode(self, z, **kwargs):
815
+ if len(kwargs) == 0:
816
+ z = self.post_quant_conv(z)
817
+ dec = self.decoder(z, **kwargs)
818
+ return dec
819
+
820
+ @property
821
+ def device(self):
822
+ return next(self.parameters()).device
823
+
824
+ @property
825
+ def dtype(self):
826
+ return next(self.parameters()).dtype
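A hedged usage sketch of the autoencoder above; the config and variable names are illustrative and the real call sites live in `diffusers_vdm.pipeline`. As written, `decode` applies `post_quant_conv` only when called with no keyword arguments, while the video decoder itself needs a `timesteps` keyword, so the caller presumably handles that projection when passing kwargs. A CUDA device is effectively required since the attention layers call xformers.

vae = VideoAutoencoderKL(ch_mult=[1, 2, 4, 4], attn_resolutions=[]).cuda()   # illustrative config
frames = torch.randn(16, 3, 64, 64, device='cuda')   # (batch * time, 3, H, W), here batch=1, time=16
posterior, hidden = vae.encode(frames, return_hidden_states=True)
z = posterior.sample()                                # (16, 4, 8, 8): three downsamples for a 4-level ch_mult
recon = vae.decode(z, timesteps=16, ref_context=None) # kwargs present -> post_quant_conv is skipped above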
gradio_app.py ADDED
@@ -0,0 +1,321 @@
1
+ import os
2
+
3
+ os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
4
+ result_dir = os.path.join('./', 'results')
5
+ os.makedirs(result_dir, exist_ok=True)
6
+
7
+
8
+ import functools
9
+ import os
10
+ import random
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ import wd14tagger
15
+ import memory_management
16
+ import uuid
17
+
18
+ from PIL import Image
19
+ from diffusers_helper.code_cond import unet_add_coded_conds
20
+ from diffusers_helper.cat_cond import unet_add_concat_conds
21
+ from diffusers_helper.k_diffusion import KDiffusionSampler
22
+ from diffusers import AutoencoderKL, UNet2DConditionModel
23
+ from diffusers.models.attention_processor import AttnProcessor2_0
24
+ from transformers import CLIPTextModel, CLIPTokenizer
25
+ from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
26
+ from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
27
+
28
+
29
+ class ModifiedUNet(UNet2DConditionModel):
30
+ @classmethod
31
+ def from_config(cls, *args, **kwargs):
32
+ m = super().from_config(*args, **kwargs)
33
+ unet_add_concat_conds(unet=m, new_channels=4)
34
+ unet_add_coded_conds(unet=m, added_number_count=1)
35
+ return m
36
+
37
+
38
+ model_name = 'lllyasviel/paints_undo_single_frame'
39
+ tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
40
+ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
41
+ vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16) # bfloat16 vae
42
+ unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
43
+
44
+ unet.set_attn_processor(AttnProcessor2_0())
45
+ vae.set_attn_processor(AttnProcessor2_0())
46
+
47
+ video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
48
+ 'lllyasviel/paints_undo_multi_frame',
49
+ fp16=True
50
+ )
51
+
52
+ memory_management.unload_all_models([
53
+ video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
54
+ unet, vae, text_encoder
55
+ ])
56
+
57
+ k_sampler = KDiffusionSampler(
58
+ unet=unet,
59
+ timesteps=1000,
60
+ linear_start=0.00085,
61
+ linear_end=0.020,
62
+ linear=True
63
+ )
64
+
65
+
66
+ def find_best_bucket(h, w, options):
67
+ min_metric = float('inf')
68
+ best_bucket = None
69
+ for (bucket_h, bucket_w) in options:
70
+ metric = abs(h * bucket_w - w * bucket_h)
71
+ if metric <= min_metric:
72
+ min_metric = metric
73
+ best_bucket = (bucket_h, bucket_w)
74
+ return best_bucket
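`find_best_bucket` picks the resolution bucket whose aspect ratio best matches the input by minimizing |h * bucket_w - w * bucket_h|; ties go to the later bucket. For example:

# a portrait input of height 1280 and width 720 maps to the tallest bucket
print(find_best_bucket(1280, 720, options=[(320, 512), (384, 448), (448, 384), (512, 320)]))
# -> (512, 320)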
75
+
76
+
77
+ @torch.inference_mode()
78
+ def encode_cropped_prompt_77tokens(txt: str):
79
+ memory_management.load_models_to_gpu(text_encoder)
80
+ cond_ids = tokenizer(txt,
81
+ padding="max_length",
82
+ max_length=tokenizer.model_max_length,
83
+ truncation=True,
84
+ return_tensors="pt").input_ids.to(device=text_encoder.device)
85
+ text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
86
+ return text_cond
87
+
88
+
89
+ @torch.inference_mode()
90
+ def pytorch2numpy(imgs):
91
+ results = []
92
+ for x in imgs:
93
+ y = x.movedim(0, -1)
94
+ y = y * 127.5 + 127.5
95
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
96
+ results.append(y)
97
+ return results
98
+
99
+
100
+ @torch.inference_mode()
101
+ def numpy2pytorch(imgs):
102
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
103
+ h = h.movedim(-1, 1)
104
+ return h
105
+
106
+
107
+ def resize_without_crop(image, target_width, target_height):
108
+ pil_image = Image.fromarray(image)
109
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
110
+ return np.array(resized_image)
111
+
112
+
113
+ @torch.inference_mode()
114
+ def interrogator_process(x):
115
+ return wd14tagger.default_interrogator(x)
116
+
117
+
118
+ @torch.inference_mode()
119
+ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
120
+ progress=gr.Progress()):
121
+ rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
122
+
123
+ memory_management.load_models_to_gpu(vae)
124
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
125
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
126
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
127
+
128
+ memory_management.load_models_to_gpu(text_encoder)
129
+ conds = encode_cropped_prompt_77tokens(prompt)
130
+ unconds = encode_cropped_prompt_77tokens(n_prompt)
131
+
132
+ memory_management.load_models_to_gpu(unet)
133
+ fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
134
+ initial_latents = torch.zeros_like(concat_conds)
135
+ concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
136
+ latents = k_sampler(
137
+ initial_latent=initial_latents,
138
+ strength=1.0,
139
+ num_inference_steps=steps,
140
+ guidance_scale=cfg,
141
+ batch_size=len(input_undo_steps),
142
+ generator=rng,
143
+ prompt_embeds=conds,
144
+ negative_prompt_embeds=unconds,
145
+ cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs},
146
+ same_noise_in_batch=True,
147
+ progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
148
+ ).to(vae.dtype) / vae.config.scaling_factor
149
+
150
+ memory_management.load_models_to_gpu(vae)
151
+ pixels = vae.decode(latents).sample
152
+ pixels = pytorch2numpy(pixels)
153
+ pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
154
+
155
+ return pixels
156
+
157
+
158
+ @torch.inference_mode()
159
+ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
160
+ random.seed(seed)
161
+ np.random.seed(seed)
162
+ torch.manual_seed(seed)
163
+ torch.cuda.manual_seed_all(seed)
164
+
165
+ frames = 16
166
+
167
+ target_height, target_width = find_best_bucket(
168
+ image_1.shape[0], image_1.shape[1],
169
+ options=[(320, 512), (384, 448), (448, 384), (512, 320)]
170
+ )
171
+
172
+ image_1 = resize_and_center_crop(image_1, target_width=target_width, target_height=target_height)
173
+ image_2 = resize_and_center_crop(image_2, target_width=target_width, target_height=target_height)
174
+ input_frames = numpy2pytorch([image_1, image_2])
175
+ input_frames = input_frames.unsqueeze(0).movedim(1, 2)
176
+
177
+ memory_management.load_models_to_gpu(video_pipe.text_encoder)
178
+ positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
179
+ negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
180
+
181
+ memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
182
+ input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
183
+ positive_image_cond = video_pipe.encode_clip_vision(input_frames)
184
+ positive_image_cond = video_pipe.image_projection(positive_image_cond)
185
+ negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
186
+ negative_image_cond = video_pipe.image_projection(negative_image_cond)
187
+
188
+ memory_management.load_models_to_gpu([video_pipe.vae])
189
+ input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
190
+ input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
191
+ first_frame = input_frame_latents[:, :, 0]
192
+ last_frame = input_frame_latents[:, :, 1]
193
+ concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
194
+
195
+ memory_management.load_models_to_gpu([video_pipe.unet])
196
+ latents = video_pipe(
197
+ batch_size=1,
198
+ steps=int(steps),
199
+ guidance_scale=cfg_scale,
200
+ positive_text_cond=positive_text_cond,
201
+ negative_text_cond=negative_text_cond,
202
+ positive_image_cond=positive_image_cond,
203
+ negative_image_cond=negative_image_cond,
204
+ concat_cond=concat_cond,
205
+ fs=fs,
206
+ progress_tqdm=progress_tqdm
207
+ )
208
+
209
+ memory_management.load_models_to_gpu([video_pipe.vae])
210
+ video = video_pipe.decode_latents(latents, vae_hidden_states)
211
+ return video, image_1, image_2
212
+
213
+
214
+ @torch.inference_mode()
215
+ def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
216
+ result_frames = []
217
+ cropped_images = []
218
+
219
+ for i, (im1, im2) in enumerate(zip(keyframes[:-1], keyframes[1:])):
220
+ im1 = np.array(Image.open(im1[0]))
221
+ im2 = np.array(Image.open(im2[0]))
222
+ frames, im1, im2 = process_video_inner(
223
+ im1, im2, prompt, seed=seed + i, steps=steps, cfg_scale=cfg, fs=3,
224
+ progress_tqdm=functools.partial(progress.tqdm, desc=f'Generating Videos ({i + 1}/{len(keyframes) - 1})')
225
+ )
226
+ result_frames.append(frames[:, :, :-1, :, :])
227
+ cropped_images.append([im1, im2])
228
+
229
+ video = torch.cat(result_frames, dim=2)
230
+ video = torch.flip(video, dims=[2])
231
+
232
+ uuid_name = str(uuid.uuid4())
233
+ output_filename = os.path.join(result_dir, uuid_name + '.mp4')
234
+ Image.fromarray(cropped_images[0][0]).save(os.path.join(result_dir, uuid_name + '.png'))
235
+ video = save_bcthw_as_mp4(video, output_filename, fps=fps)
236
+ video = [x.cpu().numpy() for x in video]
237
+ return output_filename, video
238
+
239
+
240
+ block = gr.Blocks().queue()
241
+ with block:
242
+ gr.Markdown('# Paints-Undo')
243
+
244
+ with gr.Accordion(label='Step 1: Upload Image and Generate Prompt', open=True):
245
+ with gr.Row():
246
+ with gr.Column():
247
+ input_fg = gr.Image(sources=['upload'], type="numpy", label="Image", height=512)
248
+ with gr.Column():
249
+ prompt_gen_button = gr.Button(value="Generate Prompt", interactive=False)
250
+ prompt = gr.Textbox(label="Output Prompt", interactive=True)
251
+
252
+ with gr.Accordion(label='Step 2: Generate Key Frames', open=True):
253
+ with gr.Row():
254
+ with gr.Column():
255
+ input_undo_steps = gr.Dropdown(label="Operation Steps", value=[400, 600, 800, 900, 950, 999],
256
+ choices=list(range(1000)), multiselect=True)
257
+ seed = gr.Slider(label='Stage 1 Seed', minimum=0, maximum=50000, step=1, value=12345)
258
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
259
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
260
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
261
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=3.0, step=0.01)
262
+ n_prompt = gr.Textbox(label="Negative Prompt",
263
+ value='lowres, bad anatomy, bad hands, cropped, worst quality')
264
+
265
+ with gr.Column():
266
+ key_gen_button = gr.Button(value="Generate Key Frames", interactive=False)
267
+ result_gallery = gr.Gallery(height=512, object_fit='contain', label='Outputs', columns=4)
268
+
269
+ with gr.Accordion(label='Step 3: Generate All Videos', open=True):
270
+ with gr.Row():
271
+ with gr.Column():
272
+ i2v_input_text = gr.Text(label='Prompts', value='1girl, masterpiece, best quality')
273
+ i2v_seed = gr.Slider(label='Stage 2 Seed', minimum=0, maximum=50000, step=1, value=123)
274
+ i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5,
275
+ elem_id="i2v_cfg_scale")
276
+ i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps",
277
+ label="Sampling steps", value=50)
278
+ i2v_fps = gr.Slider(minimum=1, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=4)
279
+ with gr.Column():
280
+ i2v_end_btn = gr.Button("Generate Video", interactive=False)
281
+ i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
282
+ show_share_button=True, height=512)
283
+ with gr.Row():
284
+ i2v_output_images = gr.Gallery(height=512, label="Output Frames", object_fit="contain", columns=8)
285
+
286
+ input_fg.change(lambda: ["", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)],
287
+ outputs=[prompt, prompt_gen_button, key_gen_button, i2v_end_btn])
288
+
289
+ prompt_gen_button.click(
290
+ fn=interrogator_process,
291
+ inputs=[input_fg],
292
+ outputs=[prompt]
293
+ ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)],
294
+ outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
295
+
296
+ key_gen_button.click(
297
+ fn=process,
298
+ inputs=[input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg],
299
+ outputs=[result_gallery]
300
+ ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
301
+ outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
302
+
303
+ i2v_end_btn.click(
304
+ inputs=[result_gallery, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_fps, i2v_seed],
305
+ outputs=[i2v_output_video, i2v_output_images],
306
+ fn=process_video
307
+ )
308
+
309
+ dbs = [
310
+ ['./imgs/1.jpg', 12345, 123],
311
+ ['./imgs/2.jpg', 37000, 12345],
312
+ ['./imgs/3.jpg', 3000, 3000],
313
+ ]
314
+
315
+ gr.Examples(
316
+ examples=dbs,
317
+ inputs=[input_fg, seed, i2v_seed],
318
+ examples_per_page=1024
319
+ )
320
+
321
+ block.queue().launch(server_name='0.0.0.0')
imgs/1.jpg ADDED
imgs/2.jpg ADDED
imgs/3.jpg ADDED
memory_management.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ from contextlib import contextmanager
3
+
4
+
5
+ high_vram = False
6
+ gpu = torch.device('cuda')
7
+ cpu = torch.device('cpu')
8
+
9
+ torch.zeros((1, 1)).to(gpu, torch.float32)
10
+ torch.cuda.empty_cache()
11
+
12
+ models_in_gpu = []
13
+
14
+
15
+ @contextmanager
16
+ def movable_bnb_model(m):
17
+ if hasattr(m, 'quantization_method'):
18
+ m.quantization_method_backup = m.quantization_method
19
+ del m.quantization_method
20
+ try:
21
+ yield None
22
+ finally:
23
+ if hasattr(m, 'quantization_method_backup'):
24
+ m.quantization_method = m.quantization_method_backup
25
+ del m.quantization_method_backup
26
+ return
27
+
28
+
29
+ def load_models_to_gpu(models):
30
+ global models_in_gpu
31
+
32
+ if not isinstance(models, (tuple, list)):
33
+ models = [models]
34
+
35
+ models_to_remain = [m for m in set(models) if m in models_in_gpu]
36
+ models_to_load = [m for m in set(models) if m not in models_in_gpu]
37
+ models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain]
38
+
39
+ if not high_vram:
40
+ for m in models_to_unload:
41
+ with movable_bnb_model(m):
42
+ m.to(cpu)
43
+ print('Unload to CPU:', m.__class__.__name__)
44
+ models_in_gpu = models_to_remain
45
+
46
+ for m in models_to_load:
47
+ with movable_bnb_model(m):
48
+ m.to(gpu)
49
+ print('Load to GPU:', m.__class__.__name__)
50
+
51
+ models_in_gpu = list(set(models_in_gpu + models))
52
+ torch.cuda.empty_cache()
53
+ return
54
+
55
+
56
+ def unload_all_models(extra_models=None):
57
+ global models_in_gpu
58
+
59
+ if extra_models is None:
60
+ extra_models = []
61
+
62
+ if not isinstance(extra_models, (tuple, list)):
63
+ extra_models = [extra_models]
64
+
65
+ models_in_gpu = list(set(models_in_gpu + extra_models))
66
+
67
+ return load_models_to_gpu([])
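A minimal sketch of the swap pattern this module implements (toy modules used for illustration): with `high_vram=False`, loading a new set of models moves anything not in that set back to the CPU.

import torch.nn as nn

a, b = nn.Linear(4, 4), nn.Linear(4, 4)
load_models_to_gpu(a)          # a -> GPU
load_models_to_gpu(b)          # a is offloaded to CPU, b -> GPU
load_models_to_gpu([a, b])     # both on GPU
unload_all_models()            # everything back to CPU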
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ diffusers==0.28.0
2
+ transformers==4.41.1
3
+ gradio==4.31.5
4
+ bitsandbytes==0.43.1
5
+ accelerate==0.30.1
6
+ protobuf==3.20
7
+ opencv-python
8
+ tensorboardX
9
+ safetensors
10
+ pillow
11
+ einops
12
+ torch
13
+ peft
14
+ xformers
15
+ onnxruntime
16
+ av
wd14tagger.py ADDED
@@ -0,0 +1,105 @@
1
+ # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
2
+
3
+
4
+ import os
5
+ import csv
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+
9
+ from PIL import Image
10
+ from onnxruntime import InferenceSession
11
+ from torch.hub import download_url_to_file
12
+
13
+
14
+ global_model = None
15
+ global_csv = None
16
+
17
+
18
+ def download_model(url, local_path):
19
+ if os.path.exists(local_path):
20
+ return local_path
21
+
22
+ temp_path = local_path + '.tmp'
23
+ download_url_to_file(url=url, dst=temp_path)
24
+ os.rename(temp_path, local_path)
25
+ return local_path
26
+
27
+
28
+ def default_interrogator(image, threshold=0.35, character_threshold=0.85, exclude_tags=""):
29
+ global global_model, global_csv
30
+
31
+ model_name = "wd-v1-4-moat-tagger-v2"
32
+
33
+ model_onnx_filename = download_model(
34
+ url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx',
35
+ local_path=f'./{model_name}.onnx',
36
+ )
37
+
38
+ model_csv_filename = download_model(
39
+ url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv',
40
+ local_path=f'./{model_name}.csv',
41
+ )
42
+
43
+ if global_model is not None:
44
+ model = global_model
45
+ else:
46
+ # assert 'CUDAExecutionProvider' in ort.get_available_providers(), 'CUDA Install Failed!'
47
+ # model = InferenceSession(model_onnx_filename, providers=['CUDAExecutionProvider'])
48
+ model = InferenceSession(model_onnx_filename, providers=['CPUExecutionProvider'])
49
+ global_model = model
50
+
51
+ input = model.get_inputs()[0]
52
+ height = input.shape[1]
53
+
54
+ if isinstance(image, str):
55
+ image = Image.open(image) # RGB
56
+ elif isinstance(image, np.ndarray):
57
+ image = Image.fromarray(image)
58
+ else:
59
+ image = image
60
+
61
+ ratio = float(height) / max(image.size)
62
+ new_size = tuple([int(x*ratio) for x in image.size])
63
+ image = image.resize(new_size, Image.LANCZOS)
64
+ square = Image.new("RGB", (height, height), (255, 255, 255))
65
+ square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2))
66
+
67
+ image = np.array(square).astype(np.float32)
68
+ image = image[:, :, ::-1] # RGB -> BGR
69
+ image = np.expand_dims(image, 0)
70
+
71
+ if global_csv is not None:
72
+ csv_lines = global_csv
73
+ else:
74
+ csv_lines = []
75
+ with open(model_csv_filename) as f:
76
+ reader = csv.reader(f)
77
+ next(reader)
78
+ for row in reader:
79
+ csv_lines.append(row)
80
+ global_csv = csv_lines
81
+
82
+ tags = []
83
+ general_index = None
84
+ character_index = None
85
+ for line_num, row in enumerate(csv_lines):
86
+ if general_index is None and row[2] == "0":
87
+ general_index = line_num
88
+ elif character_index is None and row[2] == "4":
89
+ character_index = line_num
90
+ tags.append(row[1])
91
+
92
+ label_name = model.get_outputs()[0].name
93
+ probs = model.run([label_name], {input.name: image})[0]
94
+
95
+ result = list(zip(tags, probs[0]))
96
+
97
+ general = [item for item in result[general_index:character_index] if item[1] > threshold]
98
+ character = [item for item in result[character_index:] if item[1] > character_threshold]
99
+
100
+ all = character + general
101
+ remove = [s.strip() for s in exclude_tags.lower().split(",")]
102
+ all = [tag for tag in all if tag[0] not in remove]
103
+
104
+ res = ", ".join((item[0].replace("(", "\\(").replace(")", "\\)") for item in all)).replace('_', ' ')
105
+ return res
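Typical usage, as wired up in `gradio_app.py`'s `interrogator_process`: pass an RGB numpy array or an image path and get back a comma-separated tag string, with character tags listed before general tags and underscores replaced by spaces.

# illustrative call; the example output is hypothetical
tags = default_interrogator('./imgs/1.jpg', threshold=0.35, exclude_tags="simple background")
print(tags)   # e.g. "1girl, solo, long hair, ..."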