l-li committed on
Commit 0b23d5a · 0 Parent(s)

init(*): initialization.

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +58 -0
  2. .gitignore +7 -0
  3. LICENSE +302 -0
  4. README.md +11 -0
  5. app.py +809 -0
  6. assets/sample1.jpg +3 -0
  7. assets/sample1.mp4 +3 -0
  8. assets/sample2.jpg +3 -0
  9. assets/sample2.mp4 +3 -0
  10. assets/sample3.jpg +3 -0
  11. assets/sample3.mp4 +3 -0
  12. assets/sample4.jpg +3 -0
  13. assets/sample4.mp4 +3 -0
  14. assets/sample5-1.png +3 -0
  15. assets/sample5-2.png +3 -0
  16. assets/sample5.mp4 +3 -0
  17. configs/dual_stream/nvcomposer.yaml +139 -0
  18. core/basics.py +95 -0
  19. core/common.py +167 -0
  20. core/data/__init__.py +0 -0
  21. core/data/camera_pose_utils.py +277 -0
  22. core/data/combined_multi_view_dataset.py +341 -0
  23. core/data/utils.py +184 -0
  24. core/distributions.py +102 -0
  25. core/ema.py +84 -0
  26. core/losses/__init__.py +1 -0
  27. core/losses/contperceptual.py +173 -0
  28. core/losses/vqperceptual.py +217 -0
  29. core/models/autoencoder.py +395 -0
  30. core/models/diffusion.py +1679 -0
  31. core/models/samplers/__init__.py +0 -0
  32. core/models/samplers/ddim.py +546 -0
  33. core/models/samplers/dpm_solver/__init__.py +1 -0
  34. core/models/samplers/dpm_solver/dpm_solver.py +1298 -0
  35. core/models/samplers/dpm_solver/sampler.py +91 -0
  36. core/models/samplers/plms.py +358 -0
  37. core/models/samplers/uni_pc/__init__.py +0 -0
  38. core/models/samplers/uni_pc/sampler.py +67 -0
  39. core/models/samplers/uni_pc/uni_pc.py +998 -0
  40. core/models/utils_diffusion.py +186 -0
  41. core/modules/attention.py +710 -0
  42. core/modules/attention_mv.py +316 -0
  43. core/modules/attention_temporal.py +1111 -0
  44. core/modules/encoders/__init__.py +0 -0
  45. core/modules/encoders/adapter.py +485 -0
  46. core/modules/encoders/condition.py +511 -0
  47. core/modules/encoders/resampler.py +264 -0
  48. core/modules/networks/ae_modules.py +1023 -0
  49. core/modules/networks/unet_modules.py +1047 -0
  50. core/modules/position_encoding.py +97 -0
.gitattributes ADDED
@@ -0,0 +1,58 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ __asset__/sample3.jpg filter=lfs diff=lfs merge=lfs -text
+ __asset__/sample1-1.png filter=lfs diff=lfs merge=lfs -text
+ __asset__/sample1-2.png filter=lfs diff=lfs merge=lfs -text
+ __asset__/sample2.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample1-1.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample1-2.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample2.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample3.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/sample1.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/sample2.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/sample3.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/sample3-2.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample1.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/sample2.jpeg filter=lfs diff=lfs merge=lfs -text
+ assets/sample3-1.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample4.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/sample5-1.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample5-2.png filter=lfs diff=lfs merge=lfs -text
+ assets/sample5.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/sample2.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/sample3.jpeg filter=lfs diff=lfs merge=lfs -text
+ assets/sample4.jpeg filter=lfs diff=lfs merge=lfs -text
+ assets/sample4.jpg filter=lfs diff=lfs merge=lfs -text
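
Each rule above routes matching paths through Git LFS instead of storing them as regular Git blobs. As an editorial illustration (not part of this commit), the sketch below checks which paths would be caught by a few of the patterns listed above. It is only an approximation: real .gitattributes matching uses Git's wildmatch rules rather than Python's fnmatch, and the candidate paths are simply names that appear elsewhere in this commit.

import fnmatch

# A small subset of the LFS patterns declared above.
lfs_patterns = ["*.ckpt", "*.safetensors", "assets/sample1.mp4"]

# Paths taken from this commit (the checkpoint name appears in app.py).
candidate_paths = ["NVComposer-V0.1.ckpt", "assets/sample1.mp4", "app.py"]

for path in candidate_paths:
    # A path goes through LFS if any declared pattern matches it.
    tracked = any(fnmatch.fnmatch(path, pattern) for pattern in lfs_patterns)
    print(f"{path}: {'LFS' if tracked else 'regular Git object'}")
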
.gitignore ADDED
@@ -0,0 +1,7 @@
+ .idea
+ __pycache__
+ .git
+ *.pyc
+ .DS_Store
+ ._*
+ cache
LICENSE ADDED
@@ -0,0 +1,302 @@
1
+ Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
2
+
3
+
4
+ License Terms of the NVComposer:
5
+ --------------------------------------------------------------------
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+
9
+ - You agree to use the NVComposer only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
10
+
11
+ - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12
+
13
+ For avoidance of doubts, "Software" means the NVComposer model inference-enabling code, parameters and weights made available under this license.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16
+
17
+
18
+ Other dependencies and licenses:
19
+
20
+
21
+ Open Source Model Licensed under the CreativeML OpenRAIL M license:
22
+ The below model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"), as model weights provided for the NVComposer Project hereunder is fine-tuned with the assistance of below model.
23
+
24
+ All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
25
+ --------------------------------------------------------------------
26
+ 1. stable-diffusion-v1-5
27
+ This stable-diffusion-v1-5 is licensed under the CreativeML OpenRAIL M license, Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
28
+ The original model is available at: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
29
+
30
+
31
+ Terms of the CreativeML OpenRAIL M license:
32
+ --------------------------------------------------------------------
33
+ Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
34
+
35
+ CreativeML Open RAIL-M
36
+ dated August 22, 2022
37
+
38
+ Section I: PREAMBLE
39
+
40
+ Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
41
+
42
+ Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
43
+
44
+ In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
45
+
46
+ Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
47
+
48
+ This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
49
+
50
+ NOW THEREFORE, You and Licensor agree as follows:
51
+
52
+ 1. Definitions
53
+
54
+ - "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
55
+ - "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
56
+ - "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
57
+ - "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
58
+ - "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
59
+ - "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
60
+ - "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
61
+ - "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
62
+ - "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
63
+ - "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
64
+ - "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
65
+ - "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
66
+
67
+ Section II: INTELLECTUAL PROPERTY RIGHTS
68
+
69
+ Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
70
+
71
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
72
+ 3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
73
+
74
+ Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
75
+
76
+ 4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
77
+ Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
78
+ You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
79
+ You must cause any modified files to carry prominent notices stating that You changed the files;
80
+ You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
81
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
82
+ 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
83
+ 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
84
+
85
+ Section IV: OTHER PROVISIONS
86
+
87
+ 7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
88
+ 8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
89
+ 9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
90
+ 10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
91
+ 11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
92
+ 12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
93
+
94
+ END OF TERMS AND CONDITIONS
95
+
96
+
97
+
98
+
99
+ Attachment A
100
+
101
+ Use Restrictions
102
+
103
+ You agree not to use the Model or Derivatives of the Model:
104
+ - In any way that violates any applicable national, federal, state, local or international law or regulation;
105
+ - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
106
+ - To generate or disseminate verifiably false information and/or content with the purpose of harming others;
107
+ - To generate or disseminate personal identifiable information that can be used to harm an individual;
108
+ - To defame, disparage or otherwise harass others;
109
+ - For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
110
+ - For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
111
+ - To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
112
+ - For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
113
+ - To provide medical advice and medical results interpretation;
114
+ - To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
115
+
116
+
117
+
118
+ Open Source Software Licensed under the Apache License Version 2.0:
119
+ --------------------------------------------------------------------
120
+ 1. pytorch_lightning
121
+ Copyright 2018-2021 William Falcon
122
+
123
+ 2. gradio
124
+ Copyright (c) gradio original author and authors
125
+
126
+
127
+ Terms of the Apache License Version 2.0:
128
+ --------------------------------------------------------------------
129
+ Apache License
130
+
131
+ Version 2.0, January 2004
132
+
133
+ http://www.apache.org/licenses/
134
+
135
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
136
+ 1. Definitions.
137
+
138
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
139
+
140
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
141
+
142
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
143
+
144
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
145
+
146
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
147
+
148
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
149
+
150
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
151
+
152
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
153
+
154
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
155
+
156
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
157
+
158
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
159
+
160
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
161
+
162
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
163
+
164
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
165
+
166
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
167
+
168
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
169
+
170
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
171
+
172
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
173
+
174
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
175
+
176
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
177
+
178
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
179
+
180
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
181
+
182
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
183
+
184
+ END OF TERMS AND CONDITIONS
185
+
186
+
187
+
188
+ Open Source Software Licensed under the BSD 3-Clause License:
189
+ --------------------------------------------------------------------
190
+ 1. torchvision
191
+ Copyright (c) Soumith Chintala 2016,
192
+ All rights reserved.
193
+
194
+ 2. scikit-learn
195
+ Copyright (c) 2007-2024 The scikit-learn developers.
196
+ All rights reserved.
197
+
198
+
199
+ Terms of the BSD 3-Clause License:
200
+ --------------------------------------------------------------------
201
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
202
+
203
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
204
+
205
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
206
+
207
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
208
+
209
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
210
+
211
+
212
+
213
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
214
+ --------------------------------------------------------------------
215
+ 1. torch
216
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
217
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
218
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
219
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
220
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
221
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
222
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
223
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
224
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
225
+
226
+
227
+
228
+ A copy of the BSD 3-Clause is included in this file.
229
+
230
+ For the license of other third party components, please refer to the following URL:
231
+ https://github.com/pytorch/pytorch/tree/v2.1.2/third_party
232
+
233
+
234
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
235
+ --------------------------------------------------------------------
236
+ 1. numpy
237
+ Copyright (c) 2005-2023, NumPy Developers.
238
+ All rights reserved.
239
+
240
+
241
+ A copy of the BSD 3-Clause is included in this file.
242
+
243
+ For the license of other third party components, please refer to the following URL:
244
+ https://github.com/numpy/numpy/blob/v1.26.3/LICENSES_bundled.txt
245
+
246
+
247
+ Open Source Software Licensed under the HPND License:
248
+ --------------------------------------------------------------------
249
+ 1. Pillow
250
+ Copyright © 2010-2024 by Jeffrey A. Clark (Alex) and contributors.
251
+
252
+
253
+ Terms of the HPND License:
254
+ --------------------------------------------------------------------
255
+ By obtaining, using, and/or copying this software and/or its associated
256
+ documentation, you agree that you have read, understood, and will comply
257
+ with the following terms and conditions:
258
+
259
+ Permission to use, copy, modify and distribute this software and its
260
+ documentation for any purpose and without fee is hereby granted,
261
+ provided that the above copyright notice appears in all copies, and that
262
+ both that copyright notice and this permission notice appear in supporting
263
+ documentation, and that the name of Secret Labs AB or the author not be
264
+ used in advertising or publicity pertaining to distribution of the software
265
+ without specific, written prior permission.
266
+
267
+ SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
268
+ SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.
269
+ IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,
270
+ INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
271
+ LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
272
+ OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
273
+ PERFORMANCE OF THIS SOFTWARE.
274
+
275
+
276
+
277
+ Open Source Software Licensed under the MIT License:
278
+ --------------------------------------------------------------------
279
+ 1. einops
280
+ Copyright (c) 2018 Alex Rogozhnikov
281
+
282
+
283
+ Terms of the MIT License:
284
+ --------------------------------------------------------------------
285
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
286
+
287
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
288
+
289
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
290
+
291
+
292
+
293
+ Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
294
+ --------------------------------------------------------------------
295
+ 1. opencv-python
296
+ Copyright (c) Olli-Pekka Heinisuo
297
+
298
+
299
+ A copy of the MIT is included in this file.
300
+
301
+ For the license of other third party components, please refer to the following URL:
302
+ https://github.com/opencv/opencv-python/blob/4.x/LICENSE-3RD-PARTY.txt
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: NVComposer
+ emoji: 📸
+ colorFrom: indigo
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 4.38.1
+ app_file: app.py
+ pinned: false
+ python_version: "3.10"
+ ---
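
The YAML front matter above is the Hugging Face Spaces configuration for this repository: the Space runs app.py with the Gradio SDK pinned to 4.38.1. As a hedged illustration (not part of the commit), the sketch below shows one way to read that front matter locally; the helper name read_space_config and the use of PyYAML are assumptions made for this example.

import yaml  # PyYAML, assumed to be installed

def read_space_config(readme_path="README.md"):
    """Return the YAML front matter between the leading '---' markers."""
    text = open(readme_path, encoding="utf-8").read()
    # The file starts with '---', so splitting on '---' twice isolates the front matter.
    _, front_matter, _ = text.split("---", 2)
    return yaml.safe_load(front_matter)

config = read_space_config()
print(config["sdk"], config["sdk_version"], config["app_file"])
# Expected, given the README above: gradio 4.38.1 app.py
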
app.py ADDED
@@ -0,0 +1,809 @@
1
+ import datetime
2
+ import json
3
+ import os
4
+
5
+ import gradio as gr
6
+ from huggingface_hub import hf_hub_download
7
+ import spaces
8
+ import PIL.Image
9
+ import numpy as np
10
+ import torch
11
+ import torchvision.transforms.functional
12
+ from numpy import deg2rad
13
+ from omegaconf import OmegaConf
14
+
15
+ from core.data.camera_pose_utils import convert_w2c_between_c2w
16
+ from core.data.combined_multi_view_dataset import (
17
+ get_ray_embeddings,
18
+ normalize_w2c_camera_pose_sequence,
19
+ crop_and_resize,
20
+ )
21
+ from main.evaluation.funcs import load_model_checkpoint
22
+ from main.evaluation.pose_interpolation import (
23
+ move_pose,
24
+ interpolate_camera_poses,
25
+ generate_spherical_trajectory,
26
+ )
27
+ from main.evaluation.utils_eval import process_inference_batch
28
+ from utils.utils import instantiate_from_config
29
+ from core.models.samplers.ddim import DDIMSampler
30
+
31
+ torch.set_float32_matmul_precision("medium")
32
+
33
+ gpu_no = 0
34
+ config = "./configs/dual_stream/nvcomposer.yaml"
35
+ ckpt = hf_hub_download(
36
+ repo_id="TencentARC/NVComposer", filename="NVComposer-V0.1.ckpt", repo_type="model"
37
+ )
38
+
39
+ model_resolution_height, model_resolution_width = 576, 1024
40
+ num_views = 16
41
+ dtype = torch.float16
42
+ config = OmegaConf.load(config)
43
+ model_config = config.pop("model", OmegaConf.create())
44
+ model_config.params.train_with_multi_view_feature_alignment = False
45
+ model = instantiate_from_config(model_config).cuda(gpu_no).to(dtype=dtype)
46
+ assert os.path.exists(ckpt), f"Error: checkpoint [{ckpt}] Not Found!"
47
+ print(f"Loading checkpoint from {ckpt}...")
48
+ model = load_model_checkpoint(model, ckpt)
49
+ model.eval()
50
+ latent_h, latent_w = (
51
+ model_resolution_height // 8,
52
+ model_resolution_width // 8,
53
+ )
54
+ channels = model.channels
55
+ sampler = DDIMSampler(model)
56
+
57
+ EXAMPLES = [
58
+ [
59
+ "./assets/sample1.jpg",
60
+ None,
61
+ 1,
62
+ 0,
63
+ 0,
64
+ 1,
65
+ 0,
66
+ 0,
67
+ 0,
68
+ 0,
69
+ 0,
70
+ -0.2,
71
+ 3,
72
+ 1.5,
73
+ 20,
74
+ "./assets/sample1.mp4",
75
+ 1,
76
+ ],
77
+ [
78
+ "./assets/sample2.jpg",
79
+ None,
80
+ 0,
81
+ 0,
82
+ 25,
83
+ 1,
84
+ 0,
85
+ 0,
86
+ 0,
87
+ 0,
88
+ 0,
89
+ 0,
90
+ 3,
91
+ 1.5,
92
+ 20,
93
+ "./assets/sample2.mp4",
94
+ 1,
95
+ ],
96
+ [
97
+ "./assets/sample3.jpg",
98
+ None,
99
+ 0,
100
+ 0,
101
+ 15,
102
+ 1,
103
+ 0,
104
+ 0,
105
+ 0,
106
+ 0,
107
+ 0,
108
+ 0,
109
+ 3,
110
+ 1.5,
111
+ 20,
112
+ "./assets/sample3.mp4",
113
+ 1,
114
+ ],
115
+ [
116
+ "./assets/sample4.jpg",
117
+ None,
118
+ 0,
119
+ 0,
120
+ -15,
121
+ 1,
122
+ 0,
123
+ 0,
124
+ 0,
125
+ 0,
126
+ 0,
127
+ 0,
128
+ 3,
129
+ 1.5,
130
+ 20,
131
+ "./assets/sample4.mp4",
132
+ 1,
133
+ ],
134
+ [
135
+ "./assets/sample5-1.png",
136
+ "./assets/sample5-2.png",
137
+ 0,
138
+ 0,
139
+ -30,
140
+ 1,
141
+ 0,
142
+ 0,
143
+ 0,
144
+ 0,
145
+ 0,
146
+ 0,
147
+ 3,
148
+ 1.5,
149
+ 20,
150
+ "./assets/sample5.mp4",
151
+ 2,
152
+ ],
153
+ ]
154
+
155
+
156
+ def compose_data_item(
157
+ num_views,
158
+ cond_pil_image_list,
159
+ caption="",
160
+ camera_mode=False,
161
+ input_pose_format="c2w",
162
+ model_pose_format="c2w",
163
+ x_rotation_angle=10,
164
+ y_rotation_angle=10,
165
+ z_rotation_angle=10,
166
+ x_translation=0.5,
167
+ y_translation=0.5,
168
+ z_translation=0.5,
169
+ image_size=None,
170
+ spherical_angle_x=10,
171
+ spherical_angle_y=10,
172
+ spherical_radius=10,
173
+ ):
174
+ if image_size is None:
175
+ image_size = [512, 512]
176
+ latent_size = [image_size[0] // 8, image_size[1] // 8]
177
+
178
+ def image_processing_function(x):
179
+ return (
180
+ torch.from_numpy(
181
+ np.array(
182
+ crop_and_resize(
183
+ x, target_height=image_size[0], target_width=image_size[1]
184
+ )
185
+ ).transpose((2, 0, 1))
186
+ ).float()
187
+ / 255.0
188
+ )
189
+
190
+ resizer_image_to_latent_size = torchvision.transforms.Resize(
191
+ size=latent_size,
192
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
193
+ antialias=True,
194
+ )
195
+ num_cond_views = len(cond_pil_image_list)
196
+ print(f"Number of received condition images: {num_cond_views}.")
197
+ num_target_views = num_views - num_cond_views
198
+ if camera_mode == 1:
199
+ print("Camera Mode: Movement with Rotation and Translation.")
200
+ start_pose = torch.tensor(
201
+ [
202
+ [1, 0, 0, 0],
203
+ [0, 1, 0, 0],
204
+ [0, 0, 1, 0],
205
+ ]
206
+ ).float()
207
+ end_pose = move_pose(
208
+ start_pose,
209
+ x_angle=torch.tensor(deg2rad(x_rotation_angle)),
210
+ y_angle=torch.tensor(deg2rad(y_rotation_angle)),
211
+ z_angle=torch.tensor(deg2rad(z_rotation_angle)),
212
+ translation=torch.tensor([x_translation, y_translation, z_translation]),
213
+ )
214
+ target_poses = interpolate_camera_poses(
215
+ start_pose, end_pose, num_steps=num_target_views
216
+ )
217
+ elif camera_mode == 0:
218
+ print("Camera Mode: Spherical Movement.")
219
+ target_poses = generate_spherical_trajectory(
220
+ end_angles=(spherical_angle_x, spherical_angle_y),
221
+ radius=spherical_radius,
222
+ num_steps=num_target_views,
223
+ )
224
+ print("Target pose sequence (before normalization): \n ", target_poses)
225
+ cond_poses = [
226
+ torch.tensor(
227
+ [
228
+ [1, 0, 0, 0],
229
+ [0, 1, 0, 0],
230
+ [0, 0, 1, 0],
231
+ ]
232
+ ).float()
233
+ ] * num_cond_views
234
+ target_poses = torch.stack(target_poses, dim=0).float()
235
+ cond_poses = torch.stack(cond_poses, dim=0).float()
236
+ if camera_mode == 0 and input_pose_format != "w2c":
237
+ # c2w to w2c. Input for normalize_w2c_camera_pose_sequence() should be w2c.
238
+ target_poses = convert_w2c_between_c2w(target_poses)
239
+ cond_poses = convert_w2c_between_c2w(cond_poses)
240
+ target_poses, cond_poses = normalize_w2c_camera_pose_sequence(
241
+ target_poses,
242
+ cond_poses,
243
+ output_c2w=model_pose_format == "c2w",
244
+ translation_norm_mode="disabled",
245
+ )
246
+ target_and_condition_camera_poses = torch.cat([target_poses, cond_poses], dim=0)
247
+
248
+ print("Target pose sequence (after normalization): \n ", target_poses)
249
+ fov_xy = [80, 45]
250
+ target_rays = get_ray_embeddings(
251
+ target_poses,
252
+ size_h=image_size[0],
253
+ size_w=image_size[1],
254
+ fov_xy_list=[fov_xy for _ in range(num_target_views)],
255
+ )
256
+ condition_rays = get_ray_embeddings(
257
+ cond_poses,
258
+ size_h=image_size[0],
259
+ size_w=image_size[1],
260
+ fov_xy_list=[fov_xy for _ in range(num_cond_views)],
261
+ )
262
+ target_images_tensor = torch.zeros(
263
+ num_target_views, 3, image_size[0], image_size[1]
264
+ )
265
+ condition_images = [image_processing_function(x) for x in cond_pil_image_list]
266
+ condition_images_tensor = torch.stack(condition_images, dim=0) * 2.0 - 1.0
267
+ target_images_tensor[0, :, :, :] = condition_images_tensor[0, :, :, :]
268
+ target_and_condition_images_tensor = torch.cat(
269
+ [target_images_tensor, condition_images_tensor], dim=0
270
+ )
271
+ target_and_condition_rays_tensor = torch.cat([target_rays, condition_rays], dim=0)
272
+ target_and_condition_rays_tensor = resizer_image_to_latent_size(
273
+ target_and_condition_rays_tensor * 5.0
274
+ )
275
+ mask_preserving_target = torch.ones(size=[num_views, 1], dtype=torch.float16)
276
+ mask_preserving_target[num_target_views:] = 0.0
277
+ combined_fovs = torch.stack([torch.tensor(fov_xy)] * num_views, dim=0)
278
+
279
+ mask_only_preserving_first_target = torch.zeros_like(mask_preserving_target)
280
+ mask_only_preserving_first_target[0] = 1.0
281
+ mask_only_preserving_first_condition = torch.zeros_like(mask_preserving_target)
282
+ mask_only_preserving_first_condition[num_target_views] = 1.0
283
+ test_data = {
284
+ # T, C, H, W
285
+ "combined_images": target_and_condition_images_tensor.unsqueeze(0),
286
+ "mask_preserving_target": mask_preserving_target.unsqueeze(0), # T, 1
287
+ # T, 1
288
+ "mask_only_preserving_first_target": mask_only_preserving_first_target.unsqueeze(
289
+ 0
290
+ ),
291
+ # T, 1
292
+ "mask_only_preserving_first_condition": mask_only_preserving_first_condition.unsqueeze(
293
+ 0
294
+ ),
295
+ # T, C, H//8, W//8
296
+ "combined_rays": target_and_condition_rays_tensor.unsqueeze(0),
297
+ "combined_fovs": combined_fovs.unsqueeze(0),
298
+ "target_and_condition_camera_poses": target_and_condition_camera_poses.unsqueeze(
299
+ 0
300
+ ),
301
+ "num_target_images": torch.tensor([num_target_views]),
302
+ "num_cond_images": torch.tensor([num_cond_views]),
303
+ "num_cond_images_str": [str(num_cond_views)],
304
+ "item_idx": [0],
305
+ "subset_key": ["evaluation"],
306
+ "caption": [caption],
307
+ "fov_xy": torch.tensor(fov_xy).float().unsqueeze(0),
308
+ }
309
+ return test_data
310
+
311
+
312
+ def tensor_to_mp4(video, savepath, fps, nrow=None):
313
+ """
314
+ video: torch.Tensor, b,t,c,h,w, value range: 0-1
315
+ """
316
+ n = video.shape[0]
317
+ print("Video shape=", video.shape)
318
+ video = video.permute(1, 0, 2, 3, 4) # t,n,c,h,w
319
+ nrow = int(np.sqrt(n)) if nrow is None else nrow
320
+ frame_grids = [
321
+ torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video
322
+ ] # [3, grid_h, grid_w]
323
+ # stack in temporal dim [T, 3, grid_h, grid_w]
324
+ grid = torch.stack(frame_grids, dim=0)
325
+ grid = torch.clamp(grid.float(), -1.0, 1.0)
326
+ # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
327
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
328
+ # print(f'Save video to {savepath}')
329
+ torchvision.io.write_video(
330
+ savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
331
+ )
332
+
333
+
334
+ def parse_to_np_array(input_string):
335
+ try:
336
+ # Try to parse the input as JSON first
337
+ data = json.loads(input_string)
338
+ arr = np.array(data)
339
+ except json.JSONDecodeError:
340
+ # If JSON parsing fails, assume it's a multi-line string and handle accordingly
341
+ lines = input_string.strip().splitlines()
342
+ data = []
343
+ for line in lines:
344
+ # Split the line by spaces and convert to floats
345
+ data.append([float(x) for x in line.split()])
346
+ arr = np.array(data)
347
+
348
+ # Check if the resulting array is 3x4
349
+ if arr.shape != (3, 4):
350
+ raise ValueError(f"Expected array shape (3, 4), but got {arr.shape}")
351
+
352
+ return arr
353
+
354
+
355
+ @spaces.GPU(duration=180)
356
+ def run_inference(
357
+ camera_mode,
358
+ input_cond_image1=None,
359
+ input_cond_image2=None,
360
+ input_cond_image3=None,
361
+ input_cond_image4=None,
362
+ input_pose_format="c2w",
363
+ model_pose_format="c2w",
364
+ x_rotation_angle=None,
365
+ y_rotation_angle=None,
366
+ z_rotation_angle=None,
367
+ x_translation=None,
368
+ y_translation=None,
369
+ z_translation=None,
370
+ trajectory_extension_factor=1,
371
+ cfg_scale=1.0,
372
+ cfg_scale_extra=1.0,
373
+ sample_steps=50,
374
+ num_images_slider=None,
375
+ spherical_angle_x=10,
376
+ spherical_angle_y=10,
377
+ spherical_radius=10,
378
+ random_seed=1,
379
+ ):
380
+ cfg_scale_extra = 1.0 # Disable extra CFG due to the ZeroGPU time limit
381
+ os.makedirs("./cache/", exist_ok=True)
382
+ with torch.no_grad():
383
+ with torch.cuda.amp.autocast(dtype=dtype):
384
+ torch.manual_seed(random_seed)
385
+ input_cond_images = []
386
+ for _cond_image in [
387
+ input_cond_image1,
388
+ input_cond_image2,
389
+ input_cond_image3,
390
+ input_cond_image4,
391
+ ]:
392
+ if _cond_image is not None:
393
+ if isinstance(_cond_image, np.ndarray):
394
+ _cond_image = PIL.Image.fromarray(_cond_image)
395
+ input_cond_images.append(_cond_image)
396
+ num_condition_views = len(input_cond_images)
397
+ assert (
398
+ num_images_slider == num_condition_views
399
+ ), f"The `num_condition_views`={num_condition_views} while got `num_images_slider`={num_images_slider}."
400
+ input_caption = ""
401
+ num_target_views = num_views - num_condition_views
402
+ data_item = compose_data_item(
403
+ num_views=num_views,
404
+ cond_pil_image_list=input_cond_images,
405
+ caption=input_caption,
406
+ camera_mode=camera_mode,
407
+ input_pose_format=input_pose_format,
408
+ model_pose_format=model_pose_format,
409
+ x_rotation_angle=x_rotation_angle,
410
+ y_rotation_angle=y_rotation_angle,
411
+ z_rotation_angle=z_rotation_angle,
412
+ x_translation=x_translation,
413
+ y_translation=y_translation,
414
+ z_translation=z_translation,
415
+ image_size=[model_resolution_height, model_resolution_width],
416
+ spherical_angle_x=spherical_angle_x,
417
+ spherical_angle_y=spherical_angle_y,
418
+ spherical_radius=spherical_radius,
419
+ )
420
+ batch = data_item
421
+ if trajectory_extension_factor == 1:
422
+ print("No trajectory extension.")
423
+ else:
424
+ print(f"Trajectory is enabled: {trajectory_extension_factor}.")
425
+ full_x_samples = []
426
+ for repeat_idx in range(int(trajectory_extension_factor)):
427
+ if repeat_idx != 0:
428
+ batch["combined_images"][:, 0, :, :, :] = full_x_samples[-1][
429
+ :, -1, :, :, :
430
+ ]
431
+ batch["combined_images"][:, num_target_views, :, :, :] = (
432
+ full_x_samples[-1][:, -1, :, :, :]
433
+ )
434
+ cond, uc, uc_extra, x_rec = process_inference_batch(
435
+ cfg_scale, batch, model, with_uncondition_extra=True
436
+ )
437
+
438
+ batch_size = x_rec.shape[0]
439
+ shape_without_batch = (num_views, channels, latent_h, latent_w)
440
+ samples, _ = sampler.sample(
441
+ sample_steps,
442
+ batch_size=batch_size,
443
+ shape=shape_without_batch,
444
+ conditioning=cond,
445
+ verbose=True,
446
+ unconditional_conditioning=uc,
447
+ unconditional_guidance_scale=cfg_scale,
448
+ unconditional_conditioning_extra=uc_extra,
449
+ unconditional_guidance_scale_extra=cfg_scale_extra,
450
+ x_T=None,
451
+ expand_mode=False,
452
+ num_target_views=num_views - num_condition_views,
453
+ num_condition_views=num_condition_views,
454
+ dense_expansion_ratio=None,
455
+ pred_x0_post_process_function=None,
456
+ pred_x0_post_process_function_kwargs=None,
457
+ )
458
+
459
+ if samples.size(2) > 4:
460
+ image_samples = samples[:, :num_target_views, :4, :, :]
461
+ else:
462
+ image_samples = samples
463
+ per_instance_decoding = False
464
+ if per_instance_decoding:
465
+ x_samples = []
466
+ for item_idx in range(image_samples.shape[0]):
467
+ image_samples_chunk = image_samples[
468
+ item_idx : item_idx + 1, :, :, :, :
469
+ ]
470
+ x_sample = model.decode_first_stage(image_samples)
471
+ x_samples.append(x_sample)
472
+ x_samples = torch.cat(x_samples, dim=0)
473
+ else:
474
+ x_samples = model.decode_first_stage(image_samples)
475
+ full_x_samples.append(x_samples[:, :num_target_views, ...])
476
+
477
+ full_x_samples = torch.concat(full_x_samples, dim=1)
478
+ x_samples = full_x_samples
479
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, 0.0, 1.0)
480
+ video_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".mp4"
481
+ video_path = "./cache/" + video_name
482
+ tensor_to_mp4(x_samples.detach().cpu(), fps=6, savepath=video_path)
483
+ return video_path
484
+
485
+
486
+ with gr.Blocks() as demo:
487
+ gr.HTML(
488
+ """
489
+ <div style="text-align: center;">
490
+ <h1 style="text-align: center; color: #333333;">📸 NVComposer</h1>
491
+ <h3 style="text-align: center; color: #333333;">Generative Novel View Synthesis with Sparse and
492
+ Unposed Images</h3>
493
+ <p style="text-align: center; font-weight: bold">
494
+ <a href="https://lg-li.github.io/project/nvcomposer">🌍 Project Page</a> |
495
+ <a href="https://arxiv.org/abs/2412.03517">📃 ArXiv Preprint</a> |
496
+ <a href="https://github.com/TencentARC/NVComposer">🧑‍💻 Github Repository</a>
497
+ </p>
498
+ <p style="text-align: left; font-size: 1.1em;">
499
+ Welcome to the demo of <strong>NVComposer</strong>. Follow the steps below to explore its capabilities:
500
+ </p>
501
+ </div>
502
+ <div style="text-align: left; margin: 0 auto; ">
503
+ <ol style="font-size: 1.1em;">
504
+ <li><strong>Choose camera movement mode:</strong> Spherical Mode or Rotation & Translation Mode.</li>
505
+ <li><strong>Customize the camera trajectory:</strong> Adjust the spherical parameters or rotation/translations along the X, Y,
506
+ and Z axes.</li>
507
+ <li><strong>Upload images:</strong> You can upload up to 4 images as input conditions.</li>
508
+ <li><strong>Set sampling parameters (optional):</strong> Tweak the settings and click the <b>Generate</b> button.</li>
509
+ </ol>
510
+ <p>
511
+ ⏱️ <b>ZeroGPU Time Limit</b>: Hugging Face ZeroGPU has an inference time limit of 180 seconds.
512
+ You may need to <b>log in with a free account</b> to use this demo.
513
+ A large number of sampling steps may lead to a timeout (GPU Abort).
514
+ In that case, please consider logging in with a Pro account or running the demo on your local machine.
515
+ </p>
516
+ <p style="text-align: left; font-size: 1.1em;">🤗 Please 🌟 star our <a href="https://github.com/TencentARC/NVComposer"> GitHub repo </a>
517
+ and click on the ❤️ like button above if you find our work helpful. <br>
518
+ <a href="https://github.com/TencentARC/NVComposer"><img src="https://img.shields.io/github/stars/TencentARC%2FNVComposer"/></a> </p>
519
+ </div>
520
+ """
521
+ )
522
+ with gr.Row():
523
+ with gr.Column(scale=1):
524
+ with gr.Accordion("Camera Movement Settings", open=True):
525
+ camera_mode = gr.Radio(
526
+ choices=[("Spherical Mode", 0), ("Rotation & Translation Mode", 1)],
527
+ label="Camera Mode",
528
+ value=0,
529
+ interactive=True,
530
+ )
531
+
532
+ with gr.Group(visible=True) as group_spherical:
533
+ # This tab can be left blank for now as per your request
534
+ # Add extra options manually here in the future
535
+ gr.HTML(
536
+ """<p style="padding: 10px">
537
+ <b>Spherical Mode</b> allows you to control the camera's movement by specifying its position on a sphere centered around the scene.
538
+ Adjust the Polar Angle (vertical rotation), Azimuth Angle (horizontal rotation), and Radius (distance from the center of the anchor view) to define the camera's viewpoint.
539
+ The anchor view is considered located on the sphere at the specified radius, aligned with a zero polar angle and zero azimuth angle, oriented toward the origin.
540
+ </p>
541
+ """
542
+ )
543
+ spherical_angle_x = gr.Slider(
544
+ minimum=-30,
545
+ maximum=30,
546
+ step=1,
547
+ value=0,
548
+ label="Polar Angle (Theta)",
549
+ )
550
+ spherical_angle_y = gr.Slider(
551
+ minimum=-30,
552
+ maximum=30,
553
+ step=1,
554
+ value=5,
555
+ label="Azimuth Angle (Phi)",
556
+ )
557
+ spherical_radius = gr.Slider(
558
+ minimum=0.5, maximum=1.5, step=0.1, value=1, label="Radius"
559
+ )
560
+
561
+ with gr.Group(visible=False) as group_move_rotation_translation:
562
+ gr.HTML(
563
+ """<p style="padding: 10px">
564
+ <b>Rotation & Translation Mode</b> lets you directly define how the camera moves and rotates in the 3D space.
565
+ Use Rotation X/Y/Z to control the camera's orientation and Translation X/Y/Z to shift its position.
566
+ The anchor view serves as the starting point, with no initial rotation or translation applied.
567
+ </p>
568
+ """
569
+ )
570
+ rotation_x = gr.Slider(
571
+ minimum=-20, maximum=20, step=1, value=0, label="Rotation X"
572
+ )
573
+ rotation_y = gr.Slider(
574
+ minimum=-20, maximum=20, step=1, value=0, label="Rotation Y"
575
+ )
576
+ rotation_z = gr.Slider(
577
+ minimum=-20, maximum=20, step=1, value=0, label="Rotation Z"
578
+ )
579
+ translation_x = gr.Slider(
580
+ minimum=-1, maximum=1, step=0.1, value=0, label="Translation X"
581
+ )
582
+ translation_y = gr.Slider(
583
+ minimum=-1, maximum=1, step=0.1, value=0, label="Translation Y"
584
+ )
585
+ translation_z = gr.Slider(
586
+ minimum=-1,
587
+ maximum=1,
588
+ step=0.1,
589
+ value=-0.2,
590
+ label="Translation Z",
591
+ )
592
+
593
+ input_camera_pose_format = gr.Radio(
594
+ choices=["W2C", "C2W"],
595
+ value="C2W",
596
+ label="Input Camera Pose Format",
597
+ visible=False,
598
+ )
599
+ model_camera_pose_format = gr.Radio(
600
+ choices=["W2C", "C2W"],
601
+ value="C2W",
602
+ label="Model Camera Pose Format",
603
+ visible=False,
604
+ )
605
+
606
+ def on_change_selected_camera_settings(_id):
607
+ return [gr.update(visible=_id == 0), gr.update(visible=_id == 1)]
608
+
609
+ camera_mode.change(
610
+ fn=on_change_selected_camera_settings,
611
+ inputs=camera_mode,
612
+ outputs=[group_spherical, group_move_rotation_translation],
613
+ )
614
+
615
+ with gr.Accordion("Advanced Sampling Settings"):
616
+ cfg_scale = gr.Slider(
617
+ value=3.0,
618
+ label="Classifier-Free Guidance Scale",
619
+ minimum=1,
620
+ maximum=10,
621
+ step=0.1,
622
+ )
623
+ extra_cfg_scale = gr.Slider(
624
+ value=1.0,
625
+ label="Extra Classifier-Free Guidance Scale",
626
+ minimum=1,
627
+ maximum=10,
628
+ step=0.1,
629
+ visible=False,
630
+ )
631
+ sample_steps = gr.Slider(
632
+ value=18, label="DDIM Sample Steps", minimum=0, maximum=25, step=1
633
+ )
634
+ trajectory_extension_factor = gr.Slider(
635
+ value=1,
636
+ label="Trajectory Extension (proportional to runtime)",
637
+ minimum=1,
638
+ maximum=3,
639
+ step=1,
640
+ )
641
+ random_seed = gr.Slider(
642
+ value=1024, minimum=1, maximum=9999, step=1, label="Random Seed"
643
+ )
644
+
645
+ def on_change_trajectory_extension_factor(_val):
646
+ if _val == 1:
647
+ return [
648
+ gr.update(minimum=-30, maximum=30),
649
+ gr.update(minimum=-30, maximum=30),
650
+ gr.update(minimum=0.5, maximum=1.5),
651
+ gr.update(minimum=-20, maximum=20),
652
+ gr.update(minimum=-20, maximum=20),
653
+ gr.update(minimum=-20, maximum=20),
654
+ gr.update(minimum=-1, maximum=1),
655
+ gr.update(minimum=-1, maximum=1),
656
+ gr.update(minimum=-1, maximum=1),
657
+ ]
658
+ elif _val == 2:
659
+ return [
660
+ gr.update(minimum=-15, maximum=15),
661
+ gr.update(minimum=-15, maximum=15),
662
+ gr.update(minimum=0.5, maximum=1.5),
663
+ gr.update(minimum=-10, maximum=10),
664
+ gr.update(minimum=-10, maximum=10),
665
+ gr.update(minimum=-10, maximum=10),
666
+ gr.update(minimum=-0.5, maximum=0.5),
667
+ gr.update(minimum=-0.5, maximum=0.5),
668
+ gr.update(minimum=-0.5, maximum=0.5),
669
+ ]
670
+ elif _val == 3:
671
+ return [
672
+ gr.update(minimum=-10, maximum=10),
673
+ gr.update(minimum=-10, maximum=10),
674
+ gr.update(minimum=0.5, maximum=1.5),
675
+ gr.update(minimum=-6, maximum=6),
676
+ gr.update(minimum=-6, maximum=6),
677
+ gr.update(minimum=-6, maximum=6),
678
+ gr.update(minimum=-0.3, maximum=0.3),
679
+ gr.update(minimum=-0.3, maximum=0.3),
680
+ gr.update(minimum=-0.3, maximum=0.3),
681
+ ]
682
+
683
+ trajectory_extension_factor.change(
684
+ fn=on_change_trajectory_extension_factor,
685
+ inputs=trajectory_extension_factor,
686
+ outputs=[
687
+ spherical_angle_x,
688
+ spherical_angle_y,
689
+ spherical_radius,
690
+ rotation_x,
691
+ rotation_y,
692
+ rotation_z,
693
+ translation_x,
694
+ translation_y,
695
+ translation_z,
696
+ ],
697
+ )
698
+
699
+ with gr.Column(scale=1):
700
+ with gr.Accordion("Input Image(s)", open=True):
701
+ num_images_slider = gr.Slider(
702
+ minimum=1,
703
+ maximum=4,
704
+ step=1,
705
+ value=1,
706
+ label="Number of Input Image(s)",
707
+ )
708
+ condition_image_1 = gr.Image(label="Input Image 1 (Anchor View)")
709
+ condition_image_2 = gr.Image(label="Input Image 2", visible=False)
710
+ condition_image_3 = gr.Image(label="Input Image 3", visible=False)
711
+ condition_image_4 = gr.Image(label="Input Image 4", visible=False)
712
+
713
+ with gr.Column(scale=1):
714
+ with gr.Accordion("Output Video", open=True):
715
+ output_video = gr.Video(label="Output Video")
716
+ run_btn = gr.Button("Generate")
717
+ with gr.Accordion("Notes", open=True):
718
+ gr.HTML(
719
+ """
720
+ <p style="font-size: 1.1em; line-height: 1.6; color: #555;">
721
+ 🧐 <b>Reminder</b>:
722
+ As a generative model, NVComposer may occasionally produce unexpected outputs.
723
+ Try adjusting the random seed, sampling steps, or CFG scales to explore different results.
724
+ <br>
725
+ 🤔 <b>Longer Generation</b>:
726
+ If you need a longer video, you can increase the trajectory extension value in the advanced sampling settings and run the demo on your own GPU.
727
+ This extends the defined camera trajectory by repeating it, allowing for a longer output.
728
+ This also requires using smaller rotation or translation scales to maintain smooth transitions and will increase the generation time. <br>
729
+ 🤗 <b>Limitation</b>:
730
+ This is the initial beta version of NVComposer.
731
+ Its generalizability may be limited in certain scenarios, and artifacts can appear with large camera motions due to the current foundation model's constraints.
732
+ We’re actively working on an improved version with enhanced datasets and a more powerful foundation model,
733
+ and we are looking for <b>collaboration opportunities from the community</b>. <br>
734
+ ✨ We welcome your feedback and questions. Thank you! </p>
735
+ """
736
+ )
737
+
738
+ with gr.Row():
739
+ gr.Examples(
740
+ label="Quick Examples",
741
+ examples=EXAMPLES,
742
+ inputs=[
743
+ condition_image_1,
744
+ condition_image_2,
745
+ camera_mode,
746
+ spherical_angle_x,
747
+ spherical_angle_y,
748
+ spherical_radius,
749
+ rotation_x,
750
+ rotation_y,
751
+ rotation_z,
752
+ translation_x,
753
+ translation_y,
754
+ translation_z,
755
+ cfg_scale,
756
+ extra_cfg_scale,
757
+ sample_steps,
758
+ output_video,
759
+ num_images_slider,
760
+ ],
761
+ examples_per_page=5,
762
+ cache_examples=False,
763
+ )
764
+
765
+ # Update visibility of condition images based on the slider
766
+ def update_visible_images(num_images):
767
+ return [
768
+ gr.update(visible=num_images >= 2),
769
+ gr.update(visible=num_images >= 3),
770
+ gr.update(visible=num_images >= 4),
771
+ ]
772
+
773
+ # Trigger visibility update when the slider value changes
774
+ num_images_slider.change(
775
+ fn=update_visible_images,
776
+ inputs=num_images_slider,
777
+ outputs=[condition_image_2, condition_image_3, condition_image_4],
778
+ )
779
+
780
+ run_btn.click(
781
+ fn=run_inference,
782
+ inputs=[
783
+ camera_mode,
784
+ condition_image_1,
785
+ condition_image_2,
786
+ condition_image_3,
787
+ condition_image_4,
788
+ input_camera_pose_format,
789
+ model_camera_pose_format,
790
+ rotation_x,
791
+ rotation_y,
792
+ rotation_z,
793
+ translation_x,
794
+ translation_y,
795
+ translation_z,
796
+ trajectory_extension_factor,
797
+ cfg_scale,
798
+ extra_cfg_scale,
799
+ sample_steps,
800
+ num_images_slider,
801
+ spherical_angle_x,
802
+ spherical_angle_y,
803
+ spherical_radius,
804
+ random_seed,
805
+ ],
806
+ outputs=output_video,
807
+ )
808
+
809
+ demo.launch()
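Note on the trajectory-extension loop in `run_inference` above: each extra repetition simply feeds the last decoded frame of the previous segment back in as the anchor view of the next segment before concatenating everything into one video. The sketch below reproduces only that bookkeeping with dummy tensors; `fake_generate_segment` is a hypothetical stand-in for the sampler plus first-stage decoding and is not part of the repository.

import torch

def fake_generate_segment(anchor_frame, num_target_views=15):
    # Hypothetical stand-in for sampler.sample + decode_first_stage:
    # returns [B, num_target_views, C, H, W] frames conditioned on the anchor.
    return anchor_frame.unsqueeze(1).repeat(1, num_target_views, 1, 1, 1)

anchor = torch.rand(1, 3, 72, 128)            # first conditioning image
full_x_samples = []
for repeat_idx in range(3):                   # trajectory_extension_factor = 3
    if repeat_idx != 0:
        # Reuse the last frame of the previous segment as the new anchor,
        # mirroring how batch["combined_images"] is overwritten above.
        anchor = full_x_samples[-1][:, -1]
    full_x_samples.append(fake_generate_segment(anchor))

video = torch.cat(full_x_samples, dim=1)      # [1, 45, 3, 72, 128]
print(video.shape)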
assets/sample1.jpg ADDED

Git LFS Details

  • SHA256: 821bdc48093db88c75b89e124de2d3511ee3d6f17617ffc94bcc5b30ebe7d295
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
assets/sample1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:242b3617d2c50a9f175619827974b90dd665fa51ae06b2cc7bb9373248f5f8d1
3
+ size 2513866
assets/sample2.jpg ADDED

Git LFS Details

  • SHA256: 085b781f0330692c746e6f9e2d28f24fbfe0285db1b5ec94383037200b673b0a
  • Pointer size: 131 Bytes
  • Size of remote file: 154 kB
assets/sample2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:754a3246734802261d445878fa7ad5f5b860deceb597d3bb439547078b7f0281
3
+ size 2369420
assets/sample3.jpg ADDED

Git LFS Details

  • SHA256: 0d55565e60ea8a80e7c09ef8f1f3ee4e64b507571174cf79f0c65c3d8cdcb1de
  • Pointer size: 131 Bytes
  • Size of remote file: 757 kB
assets/sample3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:047c45a0e93627e63464fe939484acbfbfc9087c43d366618c8b34dd331ba3f5
3
+ size 4129878
assets/sample4.jpg ADDED

Git LFS Details

  • SHA256: 0d55565e60ea8a80e7c09ef8f1f3ee4e64b507571174cf79f0c65c3d8cdcb1de
  • Pointer size: 131 Bytes
  • Size of remote file: 757 kB
assets/sample4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a76a2950e6b7cc82bc2e36c04d312a1e5babc25cdd489b25a542862912b9f62
3
+ size 4118935
assets/sample5-1.png ADDED

Git LFS Details

  • SHA256: 6c41016f7cc5acd012ab0251d4e4a2a698b9160aaacf017e0aa6053786d87f58
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
assets/sample5-2.png ADDED

Git LFS Details

  • SHA256: c06371aa6dd628f733adec128bb650a4c2aa710f26f9c8e266f18b1bf9b536a2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
assets/sample5.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4846600b8e47774729e3bc199d4f08399d983cd7213b724ede7a5ed9057a3d5
3
+ size 4124063
configs/dual_stream/nvcomposer.yaml ADDED
@@ -0,0 +1,139 @@
1
+ num_frames: &num_frames 16
2
+ resolution: &resolution [576, 1024]
3
+ model:
4
+ base_learning_rate: 1.0e-5
5
+ scale_lr: false
6
+ target: core.models.diffusion.DualStreamMultiViewDiffusionModel
7
+ params:
8
+ use_task_embedding: false
9
+ ray_as_image: false
10
+ apply_condition_mask_in_training_loss: true
11
+ separate_noise_and_condition: true
12
+ condition_padding_with_anchor: false
13
+ use_ray_decoder_loss_high_frequency_isolation: false
14
+ train_with_multi_view_feature_alignment: true
15
+ use_text_cross_attention_condition: false
16
+
17
+ linear_start: 0.00085
18
+ linear_end: 0.012
19
+ num_time_steps_cond: 1
20
+ log_every_t: 200
21
+ time_steps: 1000
22
+
23
+ data_key_images: combined_images
24
+ data_key_rays: combined_rays
25
+ data_key_text_condition: caption
26
+ cond_stage_trainable: false
27
+ image_size: [72, 128]
28
+
29
+ channels: 10
30
+ monitor: global_step
31
+ scale_by_std: false
32
+ scale_factor: 0.18215
33
+ use_dynamic_rescale: true
34
+ base_scale: 0.3
35
+
36
+ use_ema: false
37
+ uncond_prob: 0.05
38
+ uncond_type: 'empty_seq'
39
+
40
+ use_camera_pose_query_transformer: false
41
+ random_cond: false
42
+ cond_concat: true
43
+ frame_mask: false
44
+ padding: true
45
+ per_frame_auto_encoding: true
46
+ parameterization: "v"
47
+ rescale_betas_zero_snr: true
48
+ use_noise_offset: false
49
+ scheduler_config:
50
+ target: utils.lr_scheduler.LambdaLRScheduler
51
+ interval: 'step'
52
+ frequency: 100
53
+ params:
54
+ start_step: 0
55
+ final_decay_ratio: 0.1
56
+ decay_steps: 100
57
+ bd_noise: false
58
+
59
+ unet_config:
60
+ target: core.modules.networks.unet_modules.UNetModel
61
+ params:
62
+ in_channels: 20
63
+ out_channels: 10
64
+ model_channels: 320
65
+ attention_resolutions:
66
+ - 4
67
+ - 2
68
+ - 1
69
+ num_res_blocks: 2
70
+ channel_mult:
71
+ - 1
72
+ - 2
73
+ - 4
74
+ - 4
75
+ dropout: 0.1
76
+ num_head_channels: 64
77
+ transformer_depth: 1
78
+ context_dim: 1024
79
+ use_linear: true
80
+ use_checkpoint: true
81
+ temporal_conv: true
82
+ temporal_attention: true
83
+ temporal_selfatt_only: true
84
+ use_relative_position: false
85
+ use_causal_attention: false
86
+ temporal_length: *num_frames
87
+ addition_attention: true
88
+ image_cross_attention: true
89
+ image_cross_attention_scale_learnable: true
90
+ default_fs: 3
91
+ fs_condition: false
92
+ use_spatial_temporal_attention: true
93
+ use_addition_ray_output_head: true
94
+ ray_channels: 6
95
+ use_lora_for_rays_in_output_blocks: false
96
+ use_task_embedding: false
97
+ use_ray_decoder: true
98
+ use_ray_decoder_residual: true
99
+ full_spatial_temporal_attention: true
100
+ enhance_multi_view_correspondence: false
101
+ camera_pose_condition: true
102
+ use_feature_alignment: true
103
+
104
+ first_stage_config:
105
+ target: core.models.autoencoder.AutoencoderKL
106
+ params:
107
+ embed_dim: 4
108
+ monitor: val/rec_loss
109
+ ddconfig:
110
+ double_z: true
111
+ z_channels: 4
112
+ resolution: 256
113
+ in_channels: 3
114
+ out_ch: 3
115
+ ch: 128
116
+ ch_mult: [1, 2, 4, 4]
117
+ num_res_blocks: 2
118
+ attn_resolutions: []
119
+ dropout: 0.0
120
+ lossconfig:
121
+ target: torch.nn.Identity
122
+
123
+ cond_img_config:
124
+ target: core.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
125
+ params:
126
+ freeze: true
127
+
128
+ image_proj_model_config:
129
+ target: core.modules.encoders.resampler.Resampler
130
+ params:
131
+ dim: 1024
132
+ depth: 4
133
+ dim_head: 64
134
+ heads: 12
135
+ num_queries: 16
136
+ embedding_dim: 1280
137
+ output_dim: 1024
138
+ ff_mult: 4
139
+ video_length: *num_frames
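The YAML above follows the `target`/`params` convention used throughout the codebase (see `instantiate_from_config` imported in core/basics.py below). A minimal sketch of how such a block could be resolved, assuming OmegaConf is available and the repository root is on the import path; the helper shown here is my own re-implementation for illustration, not the repository's utils/utils.py, and actually constructing the full model requires all of its dependencies and weights.

import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # Illustrative re-implementation of the usual target/params convention.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

config = OmegaConf.load("configs/dual_stream/nvcomposer.yaml")
# The &num_frames / *num_frames YAML anchors resolve transparently:
print(config.model.params.unet_config.params.temporal_length)  # 16
model = instantiate_from_config(config.model)  # builds the DualStreamMultiViewDiffusionModel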
core/basics.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch.nn as nn
2
+
3
+ from utils.utils import instantiate_from_config
4
+
5
+
6
+ def disabled_train(self, mode=True):
7
+ """Overwrite model.train with this function to make sure train/eval mode
8
+ does not change anymore."""
9
+ return self
10
+
11
+
12
+ def zero_module(module):
13
+ """
14
+ Zero out the parameters of a module and return it.
15
+ """
16
+ for p in module.parameters():
17
+ p.detach().zero_()
18
+ return module
19
+
20
+
21
+ def scale_module(module, scale):
22
+ """
23
+ Scale the parameters of a module and return it.
24
+ """
25
+ for p in module.parameters():
26
+ p.detach().mul_(scale)
27
+ return module
28
+
29
+
30
+ def conv_nd(dims, *args, **kwargs):
31
+ """
32
+ Create a 1D, 2D, or 3D convolution module.
33
+ """
34
+ if dims == 1:
35
+ return nn.Conv1d(*args, **kwargs)
36
+ elif dims == 2:
37
+ return nn.Conv2d(*args, **kwargs)
38
+ elif dims == 3:
39
+ return nn.Conv3d(*args, **kwargs)
40
+ raise ValueError(f"unsupported dimensions: {dims}")
41
+
42
+
43
+ def linear(*args, **kwargs):
44
+ """
45
+ Create a linear module.
46
+ """
47
+ return nn.Linear(*args, **kwargs)
48
+
49
+
50
+ def avg_pool_nd(dims, *args, **kwargs):
51
+ """
52
+ Create a 1D, 2D, or 3D average pooling module.
53
+ """
54
+ if dims == 1:
55
+ return nn.AvgPool1d(*args, **kwargs)
56
+ elif dims == 2:
57
+ return nn.AvgPool2d(*args, **kwargs)
58
+ elif dims == 3:
59
+ return nn.AvgPool3d(*args, **kwargs)
60
+ raise ValueError(f"unsupported dimensions: {dims}")
61
+
62
+
63
+ def nonlinearity(type="silu"):
64
+ if type == "silu":
65
+ return nn.SiLU()
66
+ elif type == "leaky_relu":
67
+ return nn.LeakyReLU()
68
+
69
+
70
+ class GroupNormSpecific(nn.GroupNorm):
71
+ def forward(self, x):
72
+ return super().forward(x.float()).type(x.dtype)
73
+
74
+
75
+ def normalization(channels, num_groups=32):
76
+ """
77
+ Make a standard normalization layer.
78
+ :param channels: number of input channels.
79
+ :param num_groups: number of groups.
80
+ :return: an nn.Module for normalization.
81
+ """
82
+ return GroupNormSpecific(num_groups, channels)
83
+
84
+
85
+ class HybridConditioner(nn.Module):
86
+
87
+ def __init__(self, c_concat_config, c_crossattn_config):
88
+ super().__init__()
89
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
90
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
91
+
92
+ def forward(self, c_concat, c_crossattn):
93
+ c_concat = self.concat_conditioner(c_concat)
94
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
95
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
core/common.py ADDED
@@ -0,0 +1,167 @@
1
+ import math
2
+ from inspect import isfunction
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch import nn
7
+
8
+
9
+ def gather_data(data, return_np=True):
10
+ """gather data from multiple processes to one list"""
11
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
12
+ dist.all_gather(data_list, data) # gather not supported with NCCL
13
+ if return_np:
14
+ data_list = [data.cpu().numpy() for data in data_list]
15
+ return data_list
16
+
17
+
18
+ def autocast(f):
19
+ def do_autocast(*args, **kwargs):
20
+ with torch.cuda.amp.autocast(
21
+ enabled=True,
22
+ dtype=torch.get_autocast_gpu_dtype(),
23
+ cache_enabled=torch.is_autocast_cache_enabled(),
24
+ ):
25
+ return f(*args, **kwargs)
26
+
27
+ return do_autocast
28
+
29
+
30
+ def extract_into_tensor(a, t, x_shape):
31
+ b, *_ = t.shape
32
+ out = a.gather(-1, t)
33
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
34
+
35
+
36
+ def noise_like(shape, device, repeat=False):
37
+ def repeat_noise():
38
+ return torch.randn((1, *shape[1:]), device=device).repeat(
39
+ shape[0], *((1,) * (len(shape) - 1))
40
+ )
41
+
42
+ def noise():
43
+ return torch.randn(shape, device=device)
44
+
45
+ return repeat_noise() if repeat else noise()
46
+
47
+
48
+ def default(val, d):
49
+ if exists(val):
50
+ return val
51
+ return d() if isfunction(d) else d
52
+
53
+
54
+ def exists(val):
55
+ return val is not None
56
+
57
+
58
+ def identity(*args, **kwargs):
59
+ return nn.Identity()
60
+
61
+
62
+ def uniq(arr):
63
+ return {el: True for el in arr}.keys()
64
+
65
+
66
+ def mean_flat(tensor):
67
+ """
68
+ Take the mean over all non-batch dimensions.
69
+ """
70
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
71
+
72
+
73
+ def ismap(x):
74
+ if not isinstance(x, torch.Tensor):
75
+ return False
76
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
77
+
78
+
79
+ def isimage(x):
80
+ if not isinstance(x, torch.Tensor):
81
+ return False
82
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
83
+
84
+
85
+ def max_neg_value(t):
86
+ return -torch.finfo(t.dtype).max
87
+
88
+
89
+ def shape_to_str(x):
90
+ shape_str = "x".join([str(x) for x in x.shape])
91
+ return shape_str
92
+
93
+
94
+ def init_(tensor):
95
+ dim = tensor.shape[-1]
96
+ std = 1 / math.sqrt(dim)
97
+ tensor.uniform_(-std, std)
98
+ return tensor
99
+
100
+
101
+ # USE_DEEP_SPEED_CHECKPOINTING = False
102
+ # if USE_DEEP_SPEED_CHECKPOINTING:
103
+ # import deepspeed
104
+ #
105
+ # _gradient_checkpoint_function = deepspeed.checkpointing.checkpoint
106
+ # else:
107
+ _gradient_checkpoint_function = torch.utils.checkpoint.checkpoint
108
+
109
+
110
+ def gradient_checkpoint(func, inputs, params, flag):
111
+ """
112
+ Evaluate a function without caching intermediate activations, allowing for
113
+ reduced memory at the expense of extra compute in the backward pass.
114
+ :param func: the function to evaluate.
115
+ :param inputs: the argument sequence to pass to `func`.
116
+ :param params: a sequence of parameters `func` depends on but does not
117
+ explicitly take as arguments.
118
+ :param flag: if False, disable gradient checkpointing.
119
+ """
120
+ if flag:
121
+ # args = tuple(inputs) + tuple(params)
122
+ # return CheckpointFunction.apply(func, len(inputs), *args)
123
+ if isinstance(inputs, tuple):
124
+ return _gradient_checkpoint_function(func, *inputs, use_reentrant=False)
125
+ else:
126
+ return _gradient_checkpoint_function(func, inputs, use_reentrant=False)
127
+ else:
128
+ return func(*inputs)
129
+
130
+
131
+ class CheckpointFunction(torch.autograd.Function):
132
+ @staticmethod
133
+ @torch.cuda.amp.custom_fwd
134
+ def forward(ctx, run_function, length, *args):
135
+ ctx.run_function = run_function
136
+ ctx.input_tensors = list(args[:length])
137
+ ctx.input_params = list(args[length:])
138
+
139
+ with torch.no_grad():
140
+ output_tensors = ctx.run_function(*ctx.input_tensors)
141
+ return output_tensors
142
+
143
+ @staticmethod
144
+ @torch.cuda.amp.custom_bwd # add this
145
+ def backward(ctx, *output_grads):
146
+ """
147
+ for x in ctx.input_tensors:
148
+ if isinstance(x, int):
149
+ print('-----------------', ctx.run_function)
150
+ """
151
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
152
+ with torch.enable_grad():
153
+ # Fixes a bug where the first op in run_function modifies the
154
+ # Tensor storage in place, which is not allowed for detach()'d
155
+ # Tensors.
156
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
157
+ output_tensors = ctx.run_function(*shallow_copies)
158
+ input_grads = torch.autograd.grad(
159
+ output_tensors,
160
+ ctx.input_tensors + ctx.input_params,
161
+ output_grads,
162
+ allow_unused=True,
163
+ )
164
+ del ctx.input_tensors
165
+ del ctx.input_params
166
+ del output_tensors
167
+ return (None, None) + input_grads
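A short usage sketch for `gradient_checkpoint` and `extract_into_tensor` from core/common.py above. This is my own example, not repository code; it assumes the repository root is on the import path.

import torch
import torch.nn as nn

from core.common import gradient_checkpoint, extract_into_tensor

layer = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Recompute the activations of `layer` during backward instead of caching them.
y = gradient_checkpoint(layer, (x,), tuple(layer.parameters()), flag=True)
y.sum().backward()

# extract_into_tensor: gather per-sample schedule values and reshape for broadcasting.
alphas = torch.linspace(0.9999, 0.01, 1000)
t = torch.randint(0, 1000, (8,))
a_t = extract_into_tensor(alphas, t, x.shape)  # shape [8, 1]
print(a_t.shape)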
core/data/__init__.py ADDED
File without changes
core/data/camera_pose_utils.py ADDED
@@ -0,0 +1,277 @@
1
+ import copy
2
+ import numpy as np
3
+ import torch
4
+ from scipy.spatial.transform import Rotation as R
5
+
6
+
7
+ def get_opencv_from_blender(matrix_world, fov, image_size):
8
+ # convert matrix_world to opencv format extrinsics
9
+ opencv_world_to_cam = matrix_world.inverse()
10
+ opencv_world_to_cam[1, :] *= -1
11
+ opencv_world_to_cam[2, :] *= -1
12
+ R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
13
+ R, T = R.unsqueeze(0), T.unsqueeze(0)
14
+
15
+ # convert fov to opencv format intrinsics
16
+ focal = 1 / np.tan(fov / 2)
17
+ intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
18
+ opencv_cam_matrix = torch.from_numpy(intrinsics).unsqueeze(0).float()
19
+ opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2])
20
+ opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
21
+
22
+ return R, T, opencv_cam_matrix
23
+
24
+
25
+ def cartesian_to_spherical(xyz):
26
+ xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
27
+ z = np.sqrt(xy + xyz[:, 2] ** 2)
28
+ # for elevation angle defined from z-axis down
29
+ theta = np.arctan2(np.sqrt(xy), xyz[:, 2])
30
+ azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
31
+ return np.stack([theta, azimuth, z], axis=-1)
32
+
33
+
34
+ def spherical_to_cartesian(spherical_coords):
35
+ # convert from spherical to cartesian coordinates
36
+ theta, azimuth, radius = spherical_coords.T
37
+ x = radius * np.sin(theta) * np.cos(azimuth)
38
+ y = radius * np.sin(theta) * np.sin(azimuth)
39
+ z = radius * np.cos(theta)
40
+ return np.stack([x, y, z], axis=-1)
41
+
42
+
43
+ def look_at(eye, center, up):
44
+ # Create a normalized direction vector from eye to center
45
+ f = np.array(center) - np.array(eye)
46
+ f /= np.linalg.norm(f)
47
+
48
+ # Create a normalized right vector
49
+ up_norm = np.array(up) / np.linalg.norm(up)
50
+ s = np.cross(f, up_norm)
51
+ s /= np.linalg.norm(s)
52
+
53
+ # Recompute the up vector
54
+ u = np.cross(s, f)
55
+
56
+ # Create rotation matrix R
57
+ R = np.array([[s[0], s[1], s[2]], [u[0], u[1], u[2]], [-f[0], -f[1], -f[2]]])
58
+
59
+ # Create translation vector T
60
+ T = -np.dot(R, np.array(eye))
61
+
62
+ return R, T
63
+
64
+
65
+ def get_blender_from_spherical(elevation, azimuth):
66
+ """Generates blender camera from spherical coordinates."""
67
+
68
+ cartesian_coords = spherical_to_cartesian(np.array([[elevation, azimuth, 3.5]]))
69
+
70
+ # get camera rotation
71
+ center = np.array([0, 0, 0])
72
+ eye = cartesian_coords[0]
73
+ up = np.array([0, 0, 1])
74
+
75
+ R, T = look_at(eye, center, up)
76
+ R = R.T
77
+ T = -np.dot(R, T)
78
+ RT = np.concatenate([R, T.reshape(3, 1)], axis=-1)
79
+
80
+ blender_cam = torch.from_numpy(RT).float()
81
+ blender_cam = torch.cat([blender_cam, torch.tensor([[0, 0, 0, 1]])], dim=0)
82
+ print(blender_cam)
83
+ return blender_cam
84
+
85
+
86
+ def invert_pose(r, t):
87
+ r_inv = r.T
88
+ t_inv = -np.dot(r_inv, t)
89
+ return r_inv, t_inv
90
+
91
+
92
+ def transform_pose_sequence_to_relative(poses, as_z_up=False):
93
+ """
94
+ poses: a sequence of 3*4 C2W camera pose matrices
95
+ as_z_up: output in z-up format. If False, the output is in y-up format
96
+ """
97
+ r0, t0 = poses[0][:3, :3], poses[0][:3, 3]
98
+ # r0_inv, t0_inv = invert_pose(r0, t0)
99
+ r0_inv = r0.T
100
+ new_rt0 = np.hstack([np.eye(3, 3), np.zeros((3, 1))])
101
+ if as_z_up:
102
+ new_rt0 = c2w_y_up_to_z_up(new_rt0)
103
+ transformed_poses = [new_rt0]
104
+ for pose in poses[1:]:
105
+ r, t = pose[:3, :3], pose[:3, 3]
106
+ new_r = np.dot(r0_inv, r)
107
+ new_t = np.dot(r0_inv, t - t0)
108
+ new_rt = np.hstack([new_r, new_t[:, None]])
109
+ if as_z_up:
110
+ new_rt = c2w_y_up_to_z_up(new_rt)
111
+ transformed_poses.append(new_rt)
112
+ return transformed_poses
113
+
114
+
115
+ def c2w_y_up_to_z_up(c2w_3x4):
116
+ R_y_up_to_z_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
117
+
118
+ R = c2w_3x4[:, :3]
119
+ t = c2w_3x4[:, 3]
120
+
121
+ R_z_up = R_y_up_to_z_up @ R
122
+ t_z_up = R_y_up_to_z_up @ t
123
+
124
+ T_z_up = np.hstack((R_z_up, t_z_up.reshape(3, 1)))
125
+
126
+ return T_z_up
127
+
128
+
129
+ def transform_pose_sequence_to_relative_w2c(poses):
130
+ new_rt_list = []
131
+ first_frame_rt = copy.deepcopy(poses[0])
132
+ first_frame_r_inv = first_frame_rt[:, :3].T
133
+ first_frame_t = first_frame_rt[:, -1]
134
+ for rt in poses:
135
+ rt[:, :3] = np.matmul(rt[:, :3], first_frame_r_inv)
136
+ rt[:, -1] = rt[:, -1] - np.matmul(rt[:, :3], first_frame_t)
137
+ new_rt_list.append(copy.deepcopy(rt))
138
+ return new_rt_list
139
+
140
+
141
+ def transform_pose_sequence_to_relative_c2w(poses):
142
+ first_frame_rt = poses[0]
143
+ first_frame_r_inv = first_frame_rt[:, :3].T
144
+ first_frame_t = first_frame_rt[:, -1]
145
+ rotations = poses[:, :, :3]
146
+ translations = poses[:, :, 3]
147
+
148
+ # Compute new rotations and translations in batch
149
+ new_rotations = torch.matmul(first_frame_r_inv, rotations)
150
+ new_translations = torch.matmul(
151
+ first_frame_r_inv, (translations - first_frame_t.unsqueeze(0)).unsqueeze(-1)
152
+ )
153
+ # Concatenate new rotations and translations
154
+ new_rt = torch.cat([new_rotations, new_translations], dim=-1)
155
+
156
+ return new_rt
157
+
158
+
159
+ def convert_w2c_between_c2w(poses):
160
+ rotations = poses[:, :, :3]
161
+ translations = poses[:, :, 3]
162
+ new_rotations = rotations.transpose(-1, -2)
163
+ new_translations = torch.matmul(-new_rotations, translations.unsqueeze(-1))
164
+ new_rt = torch.cat([new_rotations, new_translations], dim=-1)
165
+ return new_rt
166
+
167
+
168
+ def slerp(q1, q2, t):
169
+ """
170
+ Performs spherical linear interpolation (SLERP) between two quaternions.
171
+
172
+ Args:
173
+ q1 (torch.Tensor): Start quaternion (4,).
174
+ q2 (torch.Tensor): End quaternion (4,).
175
+ t (float or torch.Tensor): Interpolation parameter in [0, 1].
176
+
177
+ Returns:
178
+ torch.Tensor: Interpolated quaternion (4,).
179
+ """
180
+ q1 = q1 / torch.linalg.norm(q1) # Normalize q1
181
+ q2 = q2 / torch.linalg.norm(q2) # Normalize q2
182
+
183
+ dot = torch.dot(q1, q2)
184
+
185
+ # Ensure shortest path (flip q2 if needed)
186
+ if dot < 0.0:
187
+ q2 = -q2
188
+ dot = -dot
189
+
190
+ # Avoid numerical precision issues
191
+ dot = torch.clamp(dot, -1.0, 1.0)
192
+
193
+ theta = torch.acos(dot) # Angle between q1 and q2
194
+
195
+ if theta < 1e-6: # If very close, use linear interpolation
196
+ return (1 - t) * q1 + t * q2
197
+
198
+ sin_theta = torch.sin(theta)
199
+
200
+ return (torch.sin((1 - t) * theta) / sin_theta) * q1 + (
201
+ torch.sin(t * theta) / sin_theta
202
+ ) * q2
203
+
204
+
205
+ def interpolate_camera_poses(c2w: torch.Tensor, factor: int) -> torch.Tensor:
206
+ """
207
+ Interpolates a sequence of camera c2w poses to N times the length of the original sequence.
208
+
209
+ Args:
210
+ c2w (torch.Tensor): Input camera poses of shape (N, 3, 4).
211
+ factor (int): The upsampling factor (e.g., 2 for doubling the length).
212
+
213
+ Returns:
214
+ torch.Tensor: Interpolated camera poses of shape (N * factor, 3, 4).
215
+ """
216
+ assert c2w.ndim == 3 and c2w.shape[1:] == (
217
+ 3,
218
+ 4,
219
+ ), "Input tensor must have shape (N, 3, 4)."
220
+ assert factor > 1, "Upsampling factor must be greater than 1."
221
+
222
+ N = c2w.shape[0]
223
+ new_length = N * factor
224
+
225
+ # Extract rotations (R) and translations (T)
226
+ rotations = c2w[:, :3, :3] # Shape (N, 3, 3)
227
+ translations = c2w[:, :3, 3] # Shape (N, 3)
228
+
229
+ # Convert rotations to quaternions for interpolation
230
+ quaternions = torch.tensor(
231
+ R.from_matrix(rotations.numpy()).as_quat()
232
+ ) # Shape (N, 4)
233
+
234
+ # Initialize interpolated quaternions and translations
235
+ interpolated_quats = []
236
+ interpolated_translations = []
237
+
238
+ # Perform interpolation
239
+ for i in range(N - 1):
240
+ # Start and end quaternions and translations for this segment
241
+ q1, q2 = quaternions[i], quaternions[i + 1]
242
+ t1, t2 = translations[i], translations[i + 1]
243
+
244
+ # Time steps for interpolation within this segment
245
+ t_values = torch.linspace(0, 1, factor, dtype=torch.float32)
246
+
247
+ # Interpolate quaternions using SLERP
248
+ for t in t_values:
249
+ interpolated_quats.append(slerp(q1, q2, t))
250
+
251
+ # Interpolate translations linearly
252
+ interp_t = t1 * (1 - t_values[:, None]) + t2 * t_values[:, None]
253
+ interpolated_translations.append(interp_t)
254
+
255
+ interpolated_quats.append(quaternions[0])
256
+ interpolated_translations.append(translations[0].unsqueeze(0))
257
+ # Add the last pose (end of sequence)
258
+ interpolated_quats.append(quaternions[-1])
259
+ interpolated_translations.append(translations[-1].unsqueeze(0)) # Add as 2D tensor
260
+
261
+ # Combine interpolated results
262
+ interpolated_quats = torch.stack(interpolated_quats, dim=0) # Shape (new_length, 4)
263
+ interpolated_translations = torch.cat(
264
+ interpolated_translations, dim=0
265
+ ) # Shape (new_length, 3)
266
+
267
+ # Convert quaternions back to rotation matrices
268
+ interpolated_rotations = torch.tensor(
269
+ R.from_quat(interpolated_quats.numpy()).as_matrix()
270
+ ) # Shape (new_length, 3, 3)
271
+
272
+ # Form final c2w matrix
273
+ interpolated_c2w = torch.zeros((new_length, 3, 4), dtype=torch.float32)
274
+ interpolated_c2w[:, :3, :3] = interpolated_rotations
275
+ interpolated_c2w[:, :3, 3] = interpolated_translations
276
+
277
+ return interpolated_c2w
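The pose utilities above re-express a camera trajectory relative to its first frame and convert between W2C and C2W conventions. A minimal sketch of my own (not repository code), assuming the repository root is importable:

import torch

from core.data.camera_pose_utils import (
    convert_w2c_between_c2w,
    transform_pose_sequence_to_relative_c2w,
)

# Two toy C2W poses (3x4): identity, then a small translation along +x.
poses = torch.eye(4)[:3].unsqueeze(0).repeat(2, 1, 1)
poses[1, 0, 3] = 0.5

rel = transform_pose_sequence_to_relative_c2w(poses)
print(rel[0])                       # first pose becomes identity rotation + zero translation

w2c = convert_w2c_between_c2w(rel)  # the same routine converts in either direction
print(w2c[1, :, 3])                 # tensor([-0.5000, 0.0000, 0.0000])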
core/data/combined_multi_view_dataset.py ADDED
@@ -0,0 +1,341 @@
1
+ import PIL
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from .camera_pose_utils import (
7
+ convert_w2c_between_c2w,
8
+ transform_pose_sequence_to_relative_c2w,
9
+ )
10
+
11
+
12
+ def get_ray_embeddings(
13
+ poses, size_h=256, size_w=256, fov_xy_list=None, focal_xy_list=None
14
+ ):
15
+ """
16
+ poses: sequence of cameras poses (y-up format)
17
+ """
18
+ use_focal = False
19
+ if fov_xy_list is None or fov_xy_list[0] is None or fov_xy_list[0][0] is None:
20
+ assert focal_xy_list is not None
21
+ use_focal = True
22
+
23
+ rays_embeddings = []
24
+ for i in range(poses.shape[0]):
25
+ cur_pose = poses[i]
26
+ if use_focal:
27
+ rays_o, rays_d = get_rays(
28
+ # [h, w, 3]
29
+ cur_pose,
30
+ size_h,
31
+ size_w,
32
+ focal_xy=focal_xy_list[i],
33
+ )
34
+ else:
35
+ rays_o, rays_d = get_rays(
36
+ cur_pose, size_h, size_w, fov_xy=fov_xy_list[i]
37
+ ) # [h, w, 3]
38
+
39
+ rays_plucker = torch.cat(
40
+ [torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1
41
+ ) # [h, w, 6]
42
+ rays_embeddings.append(rays_plucker)
43
+
44
+ rays_embeddings = (
45
+ torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()
46
+ ) # [V, 6, h, w]
47
+ return rays_embeddings
48
+
49
+
50
+ def get_rays(pose, h, w, fov_xy=None, focal_xy=None, opengl=True):
51
+ x, y = torch.meshgrid(
52
+ torch.arange(w, device=pose.device),
53
+ torch.arange(h, device=pose.device),
54
+ indexing="xy",
55
+ )
56
+ x = x.flatten()
57
+ y = y.flatten()
58
+
59
+ cx = w * 0.5
60
+ cy = h * 0.5
61
+
62
+ # print("fov_xy=", fov_xy)
63
+ # print("focal_xy=", focal_xy)
64
+
65
+ if focal_xy is None:
66
+ assert fov_xy is not None, "fov_x/y and focal_x/y cannot both be None."
67
+ focal_x = w * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[0]))
68
+ focal_y = h * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[1]))
69
+ else:
70
+ assert (
71
+ len(focal_xy) == 2
72
+ ), "focal_xy should be a list-like object containing only two elements (focal length in x and y direction)."
73
+ focal_x = w * focal_xy[0]
74
+ focal_y = h * focal_xy[1]
75
+
76
+ camera_dirs = torch.nn.functional.pad(
77
+ torch.stack(
78
+ [
79
+ (x - cx + 0.5) / focal_x,
80
+ (y - cy + 0.5) / focal_y * (-1.0 if opengl else 1.0),
81
+ ],
82
+ dim=-1,
83
+ ),
84
+ (0, 1),
85
+ value=(-1.0 if opengl else 1.0),
86
+ ) # [hw, 3]
87
+
88
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
89
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
90
+
91
+ rays_o = rays_o.view(h, w, 3)
92
+ rays_d = safe_normalize(rays_d).view(h, w, 3)
93
+
94
+ return rays_o, rays_d
95
+
96
+
97
+ def safe_normalize(x, eps=1e-20):
98
+ return x / length(x, eps)
99
+
100
+
101
+ def length(x, eps=1e-20):
102
+ if isinstance(x, np.ndarray):
103
+ return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
104
+ else:
105
+ return torch.sqrt(torch.clamp(dot(x, x), min=eps))
106
+
107
+
108
+ def dot(x, y):
109
+ if isinstance(x, np.ndarray):
110
+ return np.sum(x * y, -1, keepdims=True)
111
+ else:
112
+ return torch.sum(x * y, -1, keepdim=True)
113
+
114
+
115
+ def extend_list_by_repeating(original_list, target_length, repeat_idx, at_front):
116
+ if not original_list:
117
+ raise ValueError("The original list cannot be empty.")
118
+
119
+ extended_list = []
120
+ original_length = len(original_list)
121
+ for i in range(target_length - original_length):
122
+ extended_list.append(original_list[repeat_idx])
123
+
124
+ if at_front:
125
+ extended_list.extend(original_list)
126
+ return extended_list
127
+ else:
128
+ original_list.extend(extended_list)
129
+ return original_list
130
+
131
+
132
+ def select_evenly_spaced_elements(arr, x):
133
+ if x <= 0 or len(arr) == 0:
134
+ return []
135
+
136
+ # Calculate step size as the ratio of length of the list and x
137
+ step = len(arr) / x
138
+
139
+ # Pick elements at indices that are multiples of step (round them to nearest integer)
140
+ selected_elements = [arr[round(i * step)] for i in range(x)]
141
+
142
+ return selected_elements
143
+
144
+
145
+ def convert_co3d_annotation_to_opengl_pose_and_intrinsics(frame_annotation):
146
+ p = frame_annotation.viewpoint.principal_point
147
+ f = frame_annotation.viewpoint.focal_length
148
+ h, w = frame_annotation.image.size
149
+ K = np.eye(3)
150
+ s = (min(h, w) - 1) / 2
151
+ if frame_annotation.viewpoint.intrinsics_format == "ndc_norm_image_bounds":
152
+ K[0, 0] = f[0] * (w - 1) / 2
153
+ K[1, 1] = f[1] * (h - 1) / 2
154
+ elif frame_annotation.viewpoint.intrinsics_format == "ndc_isotropic":
155
+ K[0, 0] = f[0] * s / 2
156
+ K[1, 1] = f[1] * s / 2
157
+ else:
158
+ assert (
159
+ False
160
+ ), f"Invalid intrinsics_format: {frame_annotation.viewpoint.intrinsics_format}"
161
+ K[0, 2] = -p[0] * s + (w - 1) / 2
162
+ K[1, 2] = -p[1] * s + (h - 1) / 2
163
+
164
+ R = np.array(frame_annotation.viewpoint.R).T # note the transpose here
165
+ T = np.array(frame_annotation.viewpoint.T)
166
+ pose = np.concatenate([R, T[:, None]], 1)
167
+ # Need to be converted into OpenGL format. Flip the direction of x, z axis
168
+ pose = np.diag([-1, 1, -1]).astype(np.float32) @ pose
169
+ return pose, K
170
+
171
+
172
+ def normalize_w2c_camera_pose_sequence(
173
+ target_camera_poses,
174
+ condition_camera_poses=None,
175
+ output_c2w=False,
176
+ translation_norm_mode="div_by_max",
177
+ ):
178
+ """
179
+ Normalize camera pose sequence so that the first frame is identity rotation and zero translation,
180
+ and the translation scale is normalized by the farthest point from the first frame (to one).
181
+ :param target_camera_poses: W2C poses tensor in [N, 3, 4]
182
+ :param condition_camera_poses: W2C poses tensor in [N, 3, 4]
183
+ :return: Tuple(Tensor, Tensor), the normalized `target_camera_poses` and `condition_camera_poses`
184
+ """
185
+ # Normalize at w2c, all poses should be in w2c in UnifiedFrame
186
+ num_target_views = target_camera_poses.size(0)
187
+ if condition_camera_poses is not None:
188
+ all_poses = torch.concat([target_camera_poses, condition_camera_poses], dim=0)
189
+ else:
190
+ all_poses = target_camera_poses
191
+ # Convert W2C to C2W
192
+ normalized_poses = transform_pose_sequence_to_relative_c2w(
193
+ convert_w2c_between_c2w(all_poses)
194
+ )
195
+ # Here normalized_poses is C2W
196
+ if not output_c2w:
197
+ # Convert from C2W back to W2C if output_c2w is False.
198
+ normalized_poses = convert_w2c_between_c2w(normalized_poses)
199
+
200
+ t_norms = torch.linalg.norm(normalized_poses[:, :, 3], ord=2, dim=-1)
201
+ # print("t_norms=", t_norms)
202
+ largest_t_norm = torch.max(t_norms)
203
+
204
+ # print("largest_t_norm=", largest_t_norm)
205
+ # normalized_poses[:, :, 3] -= first_t.unsqueeze(0).repeat(normalized_poses.size(0), 1)
206
+ if translation_norm_mode == "div_by_max_plus_one":
207
+ # Always add a constant component to the translation norm
208
+ largest_t_norm = largest_t_norm + 1.0
209
+ elif translation_norm_mode == "div_by_max":
210
+ largest_t_norm = largest_t_norm
211
+ if largest_t_norm <= 0.05:
212
+ largest_t_norm = 0.05
213
+ elif translation_norm_mode == "disabled":
214
+ largest_t_norm = 1.0
215
+ else:
216
+ assert False, f"Invalid translation_norm_mode: {translation_norm_mode}."
217
+ normalized_poses[:, :, 3] /= largest_t_norm
218
+
219
+ target_camera_poses = normalized_poses[:num_target_views]
220
+ if condition_camera_poses is not None:
221
+ condition_camera_poses = normalized_poses[num_target_views:]
222
+ else:
223
+ condition_camera_poses = None
224
+ # print("After First condition:", condition_camera_poses[0])
225
+ # print("After First target:", target_camera_poses[0])
226
+ return target_camera_poses, condition_camera_poses
227
+
228
+
229
+ def central_crop_pil_image(_image, crop_size, use_central_padding=False):
230
+ if use_central_padding:
231
+ # Determine the new size
232
+ _w, _h = _image.size
233
+ new_size = max(_w, _h)
234
+ # Create a new image with white background
235
+ new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
236
+ # Calculate the position to paste the original image
237
+ paste_position = ((new_size - _w) // 2, (new_size - _h) // 2)
238
+ # Paste the original image onto the new image
239
+ new_image.paste(_image, paste_position)
240
+ _image = new_image
241
+ # get the new size again if padded
242
+ _w, _h = _image.size
243
+ scale = crop_size / min(_h, _w)
244
+ # resize shortest side to crop_size
245
+ _w_out, _h_out = int(scale * _w), int(scale * _h)
246
+ _image = _image.resize(
247
+ (_w_out, _h_out),
248
+ resample=(
249
+ PIL.Image.Resampling.LANCZOS if scale < 1 else PIL.Image.Resampling.BICUBIC
250
+ ),
251
+ )
252
+ # center crop
253
+ margin_w = (_image.size[0] - crop_size) // 2
254
+ margin_h = (_image.size[1] - crop_size) // 2
255
+ _image = _image.crop(
256
+ (margin_w, margin_h, margin_w + crop_size, margin_h + crop_size)
257
+ )
258
+ return _image
259
+
260
+
261
+ def crop_and_resize(
262
+ image: Image.Image, target_width: int, target_height: int
263
+ ) -> Image.Image:
264
+ """
265
+ Crops and resizes an image while preserving the aspect ratio.
266
+
267
+ Args:
268
+ image (Image.Image): Input PIL image to be cropped and resized.
269
+ target_width (int): Target width of the output image.
270
+ target_height (int): Target height of the output image.
271
+
272
+ Returns:
273
+ Image.Image: Cropped and resized image.
274
+ """
275
+ # Original dimensions
276
+ original_width, original_height = image.size
277
+ original_aspect = original_width / original_height
278
+ target_aspect = target_width / target_height
279
+
280
+ # Calculate crop box to maintain aspect ratio
281
+ if original_aspect > target_aspect:
282
+ # Crop horizontally
283
+ new_width = int(original_height * target_aspect)
284
+ new_height = original_height
285
+ left = (original_width - new_width) / 2
286
+ top = 0
287
+ right = left + new_width
288
+ bottom = original_height
289
+ else:
290
+ # Crop vertically
291
+ new_width = original_width
292
+ new_height = int(original_width / target_aspect)
293
+ left = 0
294
+ top = (original_height - new_height) / 2
295
+ right = original_width
296
+ bottom = top + new_height
297
+
298
+ # Crop and resize
299
+ cropped_image = image.crop((left, top, right, bottom))
300
+ resized_image = cropped_image.resize((target_width, target_height), Image.LANCZOS)
301
+
302
+ return resized_image
303
+
304
+
305
+ def calculate_fov_after_resize(
306
+ fov_x: float,
307
+ fov_y: float,
308
+ original_width: int,
309
+ original_height: int,
310
+ target_width: int,
311
+ target_height: int,
312
+ ) -> (float, float):
313
+ """
314
+ Calculates the new field of view after cropping and resizing an image.
315
+
316
+ Args:
317
+ fov_x (float): Original field of view in the x-direction (horizontal).
318
+ fov_y (float): Original field of view in the y-direction (vertical).
319
+ original_width (int): Original width of the image.
320
+ original_height (int): Original height of the image.
321
+ target_width (int): Target width of the output image.
322
+ target_height (int): Target height of the output image.
323
+
324
+ Returns:
325
+ (float, float): New field of view (fov_x, fov_y) after cropping and resizing.
326
+ """
327
+ original_aspect = original_width / original_height
328
+ target_aspect = target_width / target_height
329
+
330
+ if original_aspect > target_aspect:
331
+ # Crop horizontally
332
+ new_width = int(original_height * target_aspect)
333
+ new_fov_x = fov_x * (new_width / original_width)
334
+ new_fov_y = fov_y
335
+ else:
336
+ # Crop vertically
337
+ new_height = int(original_width / target_aspect)
338
+ new_fov_y = fov_y * (new_height / original_height)
339
+ new_fov_x = fov_x
340
+
341
+ return new_fov_x, new_fov_y
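The dataset helpers above turn each camera pose into a 6-channel Plücker ray map (cross(o, d) concatenated with d) per pixel. A small sketch of my own (not repository code), assuming the repository root is importable:

import torch

from core.data.combined_multi_view_dataset import get_ray_embeddings

# Two identity C2W poses (4x4) with per-view horizontal/vertical FoV in degrees.
poses = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)
rays = get_ray_embeddings(
    poses, size_h=72, size_w=128, fov_xy_list=[(60.0, 40.0)] * 2
)
print(rays.shape)  # torch.Size([2, 6, 72, 128]) -- per-pixel Plücker coordinates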
core/data/utils.py ADDED
@@ -0,0 +1,184 @@
1
+ import copy
2
+ import random
3
+ from PIL import Image
4
+
5
+ import numpy as np
6
+
7
+
8
+ def create_relative(RT_list, K_1=4.7, dataset="syn"):
9
+ if dataset == "realestate":
10
+ scale_T = 1
11
+ RT_list = [RT.reshape(3, 4) for RT in RT_list]
12
+ elif dataset == "syn":
13
+ scale_T = (470 / K_1) / 7.5
14
+ """
15
+ 4.694746736956946052e+02 0.000000000000000000e+00 4.800000000000000000e+02
16
+ 0.000000000000000000e+00 4.694746736956946052e+02 2.700000000000000000e+02
17
+ 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00
18
+ """
19
+ elif dataset == "zero123":
20
+ scale_T = 0.5
21
+ else:
22
+ raise Exception("invalid dataset type")
23
+
24
+ # convert x y z to x -y -z
25
+ if dataset == "zero123":
26
+ flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
27
+ for i in range(len(RT_list)):
28
+ RT_list[i] = np.dot(flip_matrix, RT_list[i])
29
+
30
+ temp = []
31
+ first_frame_RT = copy.deepcopy(RT_list[0])
32
+ # first_frame_R_inv = np.linalg.inv(first_frame_RT[:,:3])
33
+ first_frame_R_inv = first_frame_RT[:, :3].T
34
+ first_frame_T = first_frame_RT[:, -1]
35
+ for RT in RT_list:
36
+ RT[:, :3] = np.dot(RT[:, :3], first_frame_R_inv)
37
+ RT[:, -1] = RT[:, -1] - np.dot(RT[:, :3], first_frame_T)
38
+ RT[:, -1] = RT[:, -1] * scale_T
39
+ temp.append(RT)
40
+ RT_list = temp
41
+
42
+ if dataset == "realestate":
43
+ RT_list = [RT.reshape(-1) for RT in RT_list]
44
+
45
+ return RT_list
46
+
47
+
48
+ def sigma_matrix2(sig_x, sig_y, theta):
49
+ """Calculate the rotated sigma matrix (two dimensional matrix).
50
+ Args:
51
+ sig_x (float):
52
+ sig_y (float):
53
+ theta (float): Radian measurement.
54
+ Returns:
55
+ ndarray: Rotated sigma matrix.
56
+ """
57
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
58
+ u_matrix = np.array(
59
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
60
+ )
61
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
62
+
63
+
64
+ def mesh_grid(kernel_size):
65
+ """Generate the mesh grid, centering at zero.
66
+ Args:
67
+ kernel_size (int):
68
+ Returns:
69
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
70
+ xx (ndarray): with the shape (kernel_size, kernel_size)
71
+ yy (ndarray): with the shape (kernel_size, kernel_size)
72
+ """
73
+ ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
74
+ xx, yy = np.meshgrid(ax, ax)
75
+ xy = np.hstack(
76
+ (
77
+ xx.reshape((kernel_size * kernel_size, 1)),
78
+ yy.reshape(kernel_size * kernel_size, 1),
79
+ )
80
+ ).reshape(kernel_size, kernel_size, 2)
81
+ return xy, xx, yy
82
+
83
+
84
+ def pdf2(sigma_matrix, grid):
85
+ """Calculate PDF of the bivariate Gaussian distribution.
86
+ Args:
87
+ sigma_matrix (ndarray): with the shape (2, 2)
88
+ grid (ndarray): generated by :func:`mesh_grid`,
89
+ with the shape (K, K, 2), K is the kernel size.
90
+ Returns:
91
+ kernel (ndarray): un-normalized kernel.
92
+ """
93
+ inverse_sigma = np.linalg.inv(sigma_matrix)
94
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
95
+ return kernel
96
+
97
+
98
+ def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
99
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
100
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
101
+ Args:
102
+ kernel_size (int):
103
+ sig_x (float):
104
+ sig_y (float):
105
+ theta (float): Radian measurement.
106
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
107
+ with the shape (K, K, 2), K is the kernel size. Default: None
108
+ isotropic (bool):
109
+ Returns:
110
+ kernel (ndarray): normalized kernel.
111
+ """
112
+ if grid is None:
113
+ grid, _, _ = mesh_grid(kernel_size)
114
+ if isotropic:
115
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
116
+ else:
117
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
118
+ kernel = pdf2(sigma_matrix, grid)
119
+ kernel = kernel / np.sum(kernel)
120
+ return kernel
121
+
122
+
123
+ def rgba_to_rgb_with_bg(rgba_image, bg_color=(255, 255, 255)):
124
+ """
125
+ Convert a PIL RGBA Image to an RGB Image with a white background.
126
+
127
+ Args:
128
+ rgba_image (Image): A PIL Image object in RGBA mode.
129
+
130
+ Returns:
131
+ Image: A PIL Image object in RGB mode with white background.
132
+ """
133
+ # Ensure the image is in RGBA mode
134
+ # Ensure the image is in RGBA mode
135
+ if rgba_image.mode != "RGBA":
136
+ return rgba_image
137
+ # raise ValueError("The image must be in RGBA mode")
138
+
139
+ # Create a white background image
140
+ white_bg_rgb = Image.new("RGB", rgba_image.size, bg_color)
141
+ # Paste the RGBA image onto the white background using alpha channel as mask
142
+ white_bg_rgb.paste(
143
+ rgba_image, mask=rgba_image.split()[3]
144
+ ) # 3 is the alpha channel index
145
+ return white_bg_rgb
146
+
147
+
148
+ def random_order_preserving_selection(items, num):
149
+ if num > len(items):
150
+ print("WARNING: Item list is shorter than `num` given.")
151
+ return items
152
+ selected_indices = sorted(random.sample(range(len(items)), num))
153
+ selected_items = [items[i] for i in selected_indices]
154
+ return selected_items
155
+
156
+
157
+ def pad_pil_image_to_square(image, fill_color=(255, 255, 255)):
158
+ """
159
+ Pad an image to make it square with the given fill color.
160
+
161
+ Args:
162
+ image (PIL.Image): The original image.
163
+ fill_color (tuple): The color to use for padding (default is black).
164
+
165
+ Returns:
166
+ PIL.Image: A new image that is padded to be square.
167
+ """
168
+ width, height = image.size
169
+
170
+ # Determine the new size, which will be the maximum of width or height
171
+ new_size = max(width, height)
172
+
173
+ # Create a new image with the new size and fill color
174
+ new_image = Image.new("RGB", (new_size, new_size), fill_color)
175
+
176
+ # Calculate the position to paste the original image onto the new image
177
+ # This calculation centers the original image in the new square canvas
178
+ left = (new_size - width) // 2
179
+ top = (new_size - height) // 2
180
+
181
+ # Paste the original image into the new image
182
+ new_image.paste(image, (left, top))
183
+
184
+ return new_image
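A short usage sketch for the Gaussian-kernel helpers above (my own example, not repository code), assuming the repository root is importable:

import numpy as np

from core.data.utils import bivariate_Gaussian, mesh_grid

kernel = bivariate_Gaussian(
    kernel_size=21, sig_x=3.0, sig_y=1.0, theta=np.pi / 4, isotropic=False
)
print(kernel.shape, np.isclose(kernel.sum(), 1.0))  # (21, 21) True -- normalized anisotropic kernel

# The mesh grid can be precomputed once and reused across kernels.
grid, _, _ = mesh_grid(21)
iso = bivariate_Gaussian(21, sig_x=2.0, sig_y=2.0, theta=0.0, grid=grid, isotropic=True)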
core/distributions.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(
34
+ device=self.parameters.device
35
+ )
36
+
37
+ def sample(self, noise=None):
38
+ if noise is None:
39
+ noise = torch.randn(self.mean.shape)
40
+
41
+ x = self.mean + self.std * noise.to(device=self.parameters.device)
42
+ return x
43
+
44
+ def kl(self, other=None):
45
+ if self.deterministic:
46
+ return torch.Tensor([0.0])
47
+ else:
48
+ if other is None:
49
+ return 0.5 * torch.sum(
50
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
51
+ dim=[1, 2, 3],
52
+ )
53
+ else:
54
+ return 0.5 * torch.sum(
55
+ torch.pow(self.mean - other.mean, 2) / other.var
56
+ + self.var / other.var
57
+ - 1.0
58
+ - self.logvar
59
+ + other.logvar,
60
+ dim=[1, 2, 3],
61
+ )
62
+
63
+ def nll(self, sample, dims=[1, 2, 3]):
64
+ if self.deterministic:
65
+ return torch.Tensor([0.0])
66
+ logtwopi = np.log(2.0 * np.pi)
67
+ return 0.5 * torch.sum(
68
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
69
+ dim=dims,
70
+ )
71
+
72
+ def mode(self):
73
+ return self.mean
74
+
75
+
76
+ def normal_kl(mean1, logvar1, mean2, logvar2):
77
+ """
78
+ Compute the KL divergence between two gaussians.
79
+ Shapes are automatically broadcasted, so batches can be compared to
80
+ scalars, among other use cases.
81
+ """
82
+ tensor = None
83
+ for obj in (mean1, logvar1, mean2, logvar2):
84
+ if isinstance(obj, torch.Tensor):
85
+ tensor = obj
86
+ break
87
+ assert tensor is not None, "at least one argument must be a Tensor"
88
+
89
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
90
+ # Tensors, but it does not work for torch.exp().
91
+ logvar1, logvar2 = [
92
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
+ for x in (logvar1, logvar2)
94
+ ]
95
+
96
+ return 0.5 * (
97
+ -1.0
98
+ + logvar2
99
+ - logvar1
100
+ + torch.exp(logvar1 - logvar2)
101
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
+ )
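
A short usage sketch of the distributions above, assuming only that core/distributions.py is importable from the repo root. The tensor layout mirrors the mean/log-variance split along dim=1 used by the VAE head.

import torch

from core.distributions import DiagonalGaussianDistribution, normal_kl

# Parameters are mean and log-variance concatenated along dim=1 (8 = 2 x 4 latent channels).
params = torch.randn(2, 8, 32, 32)
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()   # reparameterized sample, shape (2, 4, 32, 32)
kl = posterior.kl()      # KL to N(0, I), summed over non-batch dims -> shape (2,)
nll = posterior.nll(z)   # negative log-likelihood of the sample -> shape (2,)

# normal_kl broadcasts, so batched tensors can be compared against scalars.
print(normal_kl(torch.zeros(4), torch.zeros(4), 0.0, 0.0))  # zeros: identical Gaussians
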
core/ema.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError("Decay must be between 0 and 1")
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer(
14
+ "num_updates",
15
+ (
16
+ torch.tensor(0, dtype=torch.int)
17
+ if use_num_upates
18
+ else torch.tensor(-1, dtype=torch.int)
19
+ ),
20
+ )
21
+
22
+ for name, p in model.named_parameters():
23
+ if p.requires_grad:
24
+ # remove as '.'-character is not allowed in buffers
25
+ s_name = name.replace(".", "")
26
+ self.m_name2s_name.update({name: s_name})
27
+ self.register_buffer(s_name, p.clone().detach().data)
28
+
29
+ self.collected_params = []
30
+
31
+ def forward(self, model):
32
+ decay = self.decay
33
+
34
+ if self.num_updates >= 0:
35
+ self.num_updates += 1
36
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
37
+
38
+ one_minus_decay = 1.0 - decay
39
+
40
+ with torch.no_grad():
41
+ m_param = dict(model.named_parameters())
42
+ shadow_params = dict(self.named_buffers())
43
+
44
+ for key in m_param:
45
+ if m_param[key].requires_grad:
46
+ sname = self.m_name2s_name[key]
47
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
48
+ shadow_params[sname].sub_(
49
+ one_minus_decay * (shadow_params[sname] - m_param[key])
50
+ )
51
+ else:
52
+ assert key not in self.m_name2s_name
53
+
54
+ def copy_to(self, model):
55
+ m_param = dict(model.named_parameters())
56
+ shadow_params = dict(self.named_buffers())
57
+ for key in m_param:
58
+ if m_param[key].requires_grad:
59
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
60
+ else:
61
+ assert key not in self.m_name2s_name
62
+
63
+ def store(self, parameters):
64
+ """
65
+ Save the current parameters for restoring later.
66
+ Args:
67
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
68
+ temporarily stored.
69
+ """
70
+ self.collected_params = [param.clone() for param in parameters]
71
+
72
+ def restore(self, parameters):
73
+ """
74
+ Restore the parameters stored with the `store` method.
75
+ Useful to validate the model with EMA parameters without affecting the
76
+ original optimization process. Store the parameters before the
77
+ `copy_to` method. After validation (or model saving), use this to
78
+ restore the former parameters.
79
+ Args:
80
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
81
+ updated with the stored parameters.
82
+ """
83
+ for c_param, param in zip(self.collected_params, parameters):
84
+ param.data.copy_(c_param.data)
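
LitEma is meant to be driven by the training loop: update the shadow weights after every optimizer step, then temporarily swap them in for evaluation. A minimal standalone sketch on a toy nn.Linear:

import torch
from torch import nn

from core.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)  # forward() nudges the shadow buffers toward the current weights

# Evaluate with EMA weights, then restore the raw training weights.
ema.store(model.parameters())
ema.copy_to(model)
with torch.no_grad():
    _ = model(torch.randn(1, 4))
ema.restore(model.parameters())
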
core/losses/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from core.losses.contperceptual import LPIPSWithDiscriminator
core/losses/contperceptual.py ADDED
@@ -0,0 +1,173 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ from taming.modules.losses.vqperceptual import (LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss, adopt_weight)
5
+
6
+
7
+ class LPIPSWithDiscriminator(nn.Module):
8
+ def __init__(
9
+ self,
10
+ disc_start,
11
+ logvar_init=0.0,
12
+ kl_weight=1.0,
13
+ pixelloss_weight=1.0,
14
+ disc_num_layers=3,
15
+ disc_in_channels=3,
16
+ disc_factor=1.0,
17
+ disc_weight=1.0,
18
+ perceptual_weight=1.0,
19
+ use_actnorm=False,
20
+ disc_conditional=False,
21
+ disc_loss="hinge",
22
+ max_bs=None,
23
+ ):
24
+
25
+ super().__init__()
26
+ assert disc_loss in ["hinge", "vanilla"]
27
+ self.kl_weight = kl_weight
28
+ self.pixel_weight = pixelloss_weight
29
+ self.perceptual_loss = LPIPS().eval()
30
+ self.perceptual_weight = perceptual_weight
31
+ # output log variance
32
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
33
+
34
+ self.discriminator = NLayerDiscriminator(
35
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
36
+ ).apply(weights_init)
37
+ self.discriminator_iter_start = disc_start
38
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
39
+ self.disc_factor = disc_factor
40
+ self.discriminator_weight = disc_weight
41
+ self.disc_conditional = disc_conditional
42
+ self.max_bs = max_bs
43
+
44
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
45
+ if last_layer is not None:
46
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
47
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
48
+ else:
49
+ nll_grads = torch.autograd.grad(
50
+ nll_loss, self.last_layer[0], retain_graph=True
51
+ )[0]
52
+ g_grads = torch.autograd.grad(
53
+ g_loss, self.last_layer[0], retain_graph=True
54
+ )[0]
55
+
56
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
57
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
58
+ d_weight = d_weight * self.discriminator_weight
59
+ return d_weight
60
+
61
+ def forward(
62
+ self,
63
+ inputs,
64
+ reconstructions,
65
+ posteriors,
66
+ optimizer_idx,
67
+ global_step,
68
+ last_layer=None,
69
+ cond=None,
70
+ split="train",
71
+ weights=None,
72
+ ):
73
+ if inputs.dim() == 5:
74
+ inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
75
+ if reconstructions.dim() == 5:
76
+ reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
77
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
78
+ if self.perceptual_weight > 0:
79
+ if self.max_bs is not None and self.max_bs < inputs.shape[0]:
80
+ input_list = torch.split(inputs, self.max_bs, dim=0)
81
+ reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0)
82
+ p_losses = [
83
+ self.perceptual_loss(
84
+ inputs.contiguous(), reconstructions.contiguous()
85
+ )
86
+ for inputs, reconstructions in zip(input_list, reconstruction_list)
87
+ ]
88
+ p_loss = torch.cat(p_losses, dim=0)
89
+ else:
90
+ p_loss = self.perceptual_loss(
91
+ inputs.contiguous(), reconstructions.contiguous()
92
+ )
93
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
94
+
95
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
96
+ weighted_nll_loss = nll_loss
97
+ if weights is not None:
98
+ weighted_nll_loss = weights * nll_loss
99
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
100
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
101
+
102
+ kl_loss = posteriors.kl()
103
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
104
+
105
+ # now the GAN part
106
+ if optimizer_idx == 0:
107
+ # generator update
108
+ if cond is None:
109
+ assert not self.disc_conditional
110
+ logits_fake = self.discriminator(reconstructions.contiguous())
111
+ else:
112
+ assert self.disc_conditional
113
+ logits_fake = self.discriminator(
114
+ torch.cat((reconstructions.contiguous(), cond), dim=1)
115
+ )
116
+ g_loss = -torch.mean(logits_fake)
117
+
118
+ if self.disc_factor > 0.0:
119
+ try:
120
+ d_weight = self.calculate_adaptive_weight(
121
+ nll_loss, g_loss, last_layer=last_layer
122
+ )
123
+ except RuntimeError:
124
+ assert not self.training
125
+ d_weight = torch.tensor(0.0)
126
+ else:
127
+ d_weight = torch.tensor(0.0)
128
+
129
+ disc_factor = adopt_weight(
130
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
131
+ )
132
+ loss = (
133
+ weighted_nll_loss
134
+ + self.kl_weight * kl_loss
135
+ + d_weight * disc_factor * g_loss
136
+ )
137
+
138
+ log = {
139
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
140
+ "{}/logvar".format(split): self.logvar.detach(),
141
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
142
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
143
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
144
+ "{}/d_weight".format(split): d_weight.detach(),
145
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
146
+ "{}/g_loss".format(split): g_loss.detach().mean(),
147
+ }
148
+ return loss, log
149
+
150
+ if optimizer_idx == 1:
151
+ # second pass for discriminator update
152
+ if cond is None:
153
+ logits_real = self.discriminator(inputs.contiguous().detach())
154
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
155
+ else:
156
+ logits_real = self.discriminator(
157
+ torch.cat((inputs.contiguous().detach(), cond), dim=1)
158
+ )
159
+ logits_fake = self.discriminator(
160
+ torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
161
+ )
162
+
163
+ disc_factor = adopt_weight(
164
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
165
+ )
166
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
167
+
168
+ log = {
169
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
170
+ "{}/logits_real".format(split): logits_real.detach().mean(),
171
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
172
+ }
173
+ return d_loss, log
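
This loss is written for a two-optimizer setup: the same forward is called once with optimizer_idx=0 (autoencoder update) and once with optimizer_idx=1 (discriminator update). A hedged sketch of that calling convention, assuming the taming-transformers package is installed (LPIPS weights are downloaded on first use); disc_factor=0.0 keeps the sketch free of the adaptive-weight autograd pass.

import torch

from core.distributions import DiagonalGaussianDistribution
from core.losses.contperceptual import LPIPSWithDiscriminator

loss_fn = LPIPSWithDiscriminator(
    disc_start=1000, kl_weight=1e-6, disc_weight=0.5,
    disc_factor=0.0,  # 0.0 skips calculate_adaptive_weight in this sketch
)

inputs = torch.rand(2, 3, 64, 64) * 2 - 1
reconstructions = inputs + 0.05 * torch.randn_like(inputs)
posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))

ae_loss, ae_log = loss_fn(inputs, reconstructions, posterior,
                          optimizer_idx=0, global_step=0, split="train")
d_loss, d_log = loss_fn(inputs, reconstructions, posterior,
                        optimizer_idx=1, global_step=0, split="train")
print(ae_log["train/rec_loss"].item(), d_log["train/disc_loss"].item())
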
core/losses/vqperceptual.py ADDED
@@ -0,0 +1,217 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from einops import repeat
5
+
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+ from taming.modules.losses.lpips import LPIPS
8
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9
+
10
+
11
+ def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13
+ loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
14
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
15
+ loss_real = (weights * loss_real).sum() / weights.sum()
16
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
17
+ d_loss = 0.5 * (loss_real + loss_fake)
18
+ return d_loss
19
+
20
+
21
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
22
+ if global_step < threshold:
23
+ weight = value
24
+ return weight
25
+
26
+
27
+ def measure_perplexity(predicted_indices, n_embed):
28
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
29
+ avg_probs = encodings.mean(0)
30
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
31
+ cluster_use = torch.sum(avg_probs > 0)
32
+ return perplexity, cluster_use
33
+
34
+
35
+ def l1(x, y):
36
+ return torch.abs(x - y)
37
+
38
+
39
+ def l2(x, y):
40
+ return torch.pow((x - y), 2)
41
+
42
+
43
+ class VQLPIPSWithDiscriminator(nn.Module):
44
+ def __init__(
45
+ self,
46
+ disc_start,
47
+ codebook_weight=1.0,
48
+ pixelloss_weight=1.0,
49
+ disc_num_layers=3,
50
+ disc_in_channels=3,
51
+ disc_factor=1.0,
52
+ disc_weight=1.0,
53
+ perceptual_weight=1.0,
54
+ use_actnorm=False,
55
+ disc_conditional=False,
56
+ disc_ndf=64,
57
+ disc_loss="hinge",
58
+ n_classes=None,
59
+ perceptual_loss="lpips",
60
+ pixel_loss="l1",
61
+ ):
62
+ super().__init__()
63
+ assert disc_loss in ["hinge", "vanilla"]
64
+ assert perceptual_loss in ["lpips", "clips", "dists"]
65
+ assert pixel_loss in ["l1", "l2"]
66
+ self.codebook_weight = codebook_weight
67
+ self.pixel_weight = pixelloss_weight
68
+ if perceptual_loss == "lpips":
69
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
70
+ self.perceptual_loss = LPIPS().eval()
71
+ else:
72
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
73
+ self.perceptual_weight = perceptual_weight
74
+
75
+ if pixel_loss == "l1":
76
+ self.pixel_loss = l1
77
+ else:
78
+ self.pixel_loss = l2
79
+
80
+ self.discriminator = NLayerDiscriminator(
81
+ input_nc=disc_in_channels,
82
+ n_layers=disc_num_layers,
83
+ use_actnorm=use_actnorm,
84
+ ndf=disc_ndf,
85
+ ).apply(weights_init)
86
+ self.discriminator_iter_start = disc_start
87
+ if disc_loss == "hinge":
88
+ self.disc_loss = hinge_d_loss
89
+ elif disc_loss == "vanilla":
90
+ self.disc_loss = vanilla_d_loss
91
+ else:
92
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
93
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
94
+ self.disc_factor = disc_factor
95
+ self.discriminator_weight = disc_weight
96
+ self.disc_conditional = disc_conditional
97
+ self.n_classes = n_classes
98
+
99
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
100
+ if last_layer is not None:
101
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
102
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
103
+ else:
104
+ nll_grads = torch.autograd.grad(
105
+ nll_loss, self.last_layer[0], retain_graph=True
106
+ )[0]
107
+ g_grads = torch.autograd.grad(
108
+ g_loss, self.last_layer[0], retain_graph=True
109
+ )[0]
110
+
111
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
112
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
113
+ d_weight = d_weight * self.discriminator_weight
114
+ return d_weight
115
+
116
+ def forward(
117
+ self,
118
+ codebook_loss,
119
+ inputs,
120
+ reconstructions,
121
+ optimizer_idx,
122
+ global_step,
123
+ last_layer=None,
124
+ cond=None,
125
+ split="train",
126
+ predicted_indices=None,
127
+ ):
128
+ if not exists(codebook_loss):
129
+ codebook_loss = torch.tensor([0.0]).to(inputs.device)
130
+ # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
131
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
132
+ if self.perceptual_weight > 0:
133
+ p_loss = self.perceptual_loss(
134
+ inputs.contiguous(), reconstructions.contiguous()
135
+ )
136
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
137
+ else:
138
+ p_loss = torch.tensor([0.0])
139
+
140
+ nll_loss = rec_loss
141
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
142
+ nll_loss = torch.mean(nll_loss)
143
+
144
+ # now the GAN part
145
+ if optimizer_idx == 0:
146
+ # generator update
147
+ if cond is None:
148
+ assert not self.disc_conditional
149
+ logits_fake = self.discriminator(reconstructions.contiguous())
150
+ else:
151
+ assert self.disc_conditional
152
+ logits_fake = self.discriminator(
153
+ torch.cat((reconstructions.contiguous(), cond), dim=1)
154
+ )
155
+ g_loss = -torch.mean(logits_fake)
156
+
157
+ try:
158
+ d_weight = self.calculate_adaptive_weight(
159
+ nll_loss, g_loss, last_layer=last_layer
160
+ )
161
+ except RuntimeError:
162
+ assert not self.training
163
+ d_weight = torch.tensor(0.0)
164
+
165
+ disc_factor = adopt_weight(
166
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
167
+ )
168
+ loss = (
169
+ nll_loss
170
+ + d_weight * disc_factor * g_loss
171
+ + self.codebook_weight * codebook_loss.mean()
172
+ )
173
+
174
+ log = {
175
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
176
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
177
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
178
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
179
+ "{}/p_loss".format(split): p_loss.detach().mean(),
180
+ "{}/d_weight".format(split): d_weight.detach(),
181
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
182
+ "{}/g_loss".format(split): g_loss.detach().mean(),
183
+ }
184
+ if predicted_indices is not None:
185
+ assert self.n_classes is not None
186
+ with torch.no_grad():
187
+ perplexity, cluster_usage = measure_perplexity(
188
+ predicted_indices, self.n_classes
189
+ )
190
+ log[f"{split}/perplexity"] = perplexity
191
+ log[f"{split}/cluster_usage"] = cluster_usage
192
+ return loss, log
193
+
194
+ if optimizer_idx == 1:
195
+ # second pass for discriminator update
196
+ if cond is None:
197
+ logits_real = self.discriminator(inputs.contiguous().detach())
198
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
199
+ else:
200
+ logits_real = self.discriminator(
201
+ torch.cat((inputs.contiguous().detach(), cond), dim=1)
202
+ )
203
+ logits_fake = self.discriminator(
204
+ torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
205
+ )
206
+
207
+ disc_factor = adopt_weight(
208
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
209
+ )
210
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
211
+
212
+ log = {
213
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
214
+ "{}/logits_real".format(split): logits_real.detach().mean(),
215
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
216
+ }
217
+ return d_loss, log
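
Two helpers in this file are reused throughout the losses and are easy to sanity-check in isolation: adopt_weight gates the GAN term until disc_start, and measure_perplexity reports how uniformly a VQ codebook is used. A small sketch (importing this module still requires the taming-transformers dependency pulled in at the top of the file):

import torch

from core.losses.vqperceptual import adopt_weight, measure_perplexity

# The discriminator term is held at 0 until `threshold` steps have passed.
print(adopt_weight(1.0, global_step=500, threshold=1000))   # 0.0
print(adopt_weight(1.0, global_step=2000, threshold=1000))  # 1.0

# Perplexity approaches n_embed when code usage is uniform.
indices = torch.randint(0, 16, (4096,))
perplexity, cluster_use = measure_perplexity(indices, n_embed=16)
print(perplexity.item(), cluster_use.item())
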
core/models/autoencoder.py ADDED
@@ -0,0 +1,395 @@
1
+ import os
2
+ import json
3
+ from contextlib import contextmanager
4
+
5
+ import torch
6
+ import numpy as np
7
+ from einops import rearrange
8
+
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+ import pytorch_lightning as pl
12
+ from pytorch_lightning.utilities import rank_zero_only
13
+
14
+ from taming.modules.vqvae.quantize import VectorQuantizer as VectorQuantizer
15
+
16
+ from core.modules.networks.ae_modules import Encoder, Decoder
17
+ from core.distributions import DiagonalGaussianDistribution
18
+ from utils.utils import instantiate_from_config
19
+ from utils.save_video import tensor2videogrids
20
+ from core.common import shape_to_str, gather_data
21
+
22
+
23
+ class AutoencoderKL(pl.LightningModule):
24
+ def __init__(
25
+ self,
26
+ ddconfig,
27
+ lossconfig,
28
+ embed_dim,
29
+ ckpt_path=None,
30
+ ignore_keys=[],
31
+ image_key="image",
32
+ colorize_nlabels=None,
33
+ monitor=None,
34
+ test=False,
35
+ logdir=None,
36
+ input_dim=4,
37
+ test_args=None,
38
+ ):
39
+ super().__init__()
40
+ self.image_key = image_key
41
+ self.encoder = Encoder(**ddconfig)
42
+ self.decoder = Decoder(**ddconfig)
43
+ self.loss = instantiate_from_config(lossconfig)
44
+ assert ddconfig["double_z"]
45
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
46
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
47
+ self.embed_dim = embed_dim
48
+ self.input_dim = input_dim
49
+ self.test = test
50
+ self.test_args = test_args
51
+ self.logdir = logdir
52
+ if colorize_nlabels is not None:
53
+ assert type(colorize_nlabels) == int
54
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
55
+ if monitor is not None:
56
+ self.monitor = monitor
57
+ if ckpt_path is not None:
58
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
59
+ if self.test:
60
+ self.init_test()
61
+
62
+ def init_test(
63
+ self,
64
+ ):
65
+ self.test = True
66
+ save_dir = os.path.join(self.logdir, "test")
67
+ if "ckpt" in self.test_args:
68
+ ckpt_name = (
69
+ os.path.basename(self.test_args.ckpt).split(".ckpt")[0]
70
+ + f"_epoch{self._cur_epoch}"
71
+ )
72
+ self.root = os.path.join(save_dir, ckpt_name)
73
+ else:
74
+ self.root = save_dir
75
+ if "test_subdir" in self.test_args:
76
+ self.root = os.path.join(save_dir, self.test_args.test_subdir)
77
+
78
+ self.root_zs = os.path.join(self.root, "zs")
79
+ self.root_dec = os.path.join(self.root, "reconstructions")
80
+ self.root_inputs = os.path.join(self.root, "inputs")
81
+ os.makedirs(self.root, exist_ok=True)
82
+
83
+ if self.test_args.save_z:
84
+ os.makedirs(self.root_zs, exist_ok=True)
85
+ if self.test_args.save_reconstruction:
86
+ os.makedirs(self.root_dec, exist_ok=True)
87
+ if self.test_args.save_input:
88
+ os.makedirs(self.root_inputs, exist_ok=True)
89
+ assert self.test_args is not None
90
+ self.test_maximum = getattr(
91
+ self.test_args, "test_maximum", None
92
+ ) # 1500 # 12000/8
93
+ self.count = 0
94
+ self.eval_metrics = {}
95
+ self.decodes = []
96
+ self.save_decode_samples = 2048
97
+ if getattr(self.test_args, "cal_metrics", False):
98
+ self.EvalLpips = EvalLpips()
99
+
100
+ def init_from_ckpt(self, path, ignore_keys=list()):
101
+ sd = torch.load(path, map_location="cpu")
102
+ try:
103
+ self._cur_epoch = sd["epoch"]
104
+ sd = sd["state_dict"]
105
+ except KeyError:
106
+ self._cur_epoch = "null"
107
+ keys = list(sd.keys())
108
+ for k in keys:
109
+ for ik in ignore_keys:
110
+ if k.startswith(ik):
111
+ print("Deleting key {} from state_dict.".format(k))
112
+ del sd[k]
113
+ self.load_state_dict(sd, strict=False)
114
+ # self.load_state_dict(sd, strict=True)
115
+ print(f"Restored from {path}")
116
+
117
+ def encode(self, x, **kwargs):
118
+
119
+ h = self.encoder(x)
120
+ moments = self.quant_conv(h)
121
+ posterior = DiagonalGaussianDistribution(moments)
122
+ return posterior
123
+
124
+ def decode(self, z, **kwargs):
125
+ z = self.post_quant_conv(z)
126
+ dec = self.decoder(z)
127
+ return dec
128
+
129
+ def forward(self, input, sample_posterior=True):
130
+ posterior = self.encode(input)
131
+ if sample_posterior:
132
+ z = posterior.sample()
133
+ else:
134
+ z = posterior.mode()
135
+ dec = self.decode(z)
136
+ return dec, posterior
137
+
138
+ def get_input(self, batch, k):
139
+ x = batch[k]
140
+ # if len(x.shape) == 3:
141
+ # x = x[..., None]
142
+ # if x.dim() == 4:
143
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
144
+ if x.dim() == 5 and self.input_dim == 4:
145
+ b, c, t, h, w = x.shape
146
+ self.b = b
147
+ self.t = t
148
+ x = rearrange(x, "b c t h w -> (b t) c h w")
149
+
150
+ return x
151
+
152
+ def training_step(self, batch, batch_idx, optimizer_idx):
153
+ inputs = self.get_input(batch, self.image_key)
154
+ reconstructions, posterior = self(inputs)
155
+
156
+ if optimizer_idx == 0:
157
+ # train encoder+decoder+logvar
158
+ aeloss, log_dict_ae = self.loss(
159
+ inputs,
160
+ reconstructions,
161
+ posterior,
162
+ optimizer_idx,
163
+ self.global_step,
164
+ last_layer=self.get_last_layer(),
165
+ split="train",
166
+ )
167
+ self.log(
168
+ "aeloss",
169
+ aeloss,
170
+ prog_bar=True,
171
+ logger=True,
172
+ on_step=True,
173
+ on_epoch=True,
174
+ )
175
+ self.log_dict(
176
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
177
+ )
178
+ return aeloss
179
+
180
+ if optimizer_idx == 1:
181
+ # train the discriminator
182
+ discloss, log_dict_disc = self.loss(
183
+ inputs,
184
+ reconstructions,
185
+ posterior,
186
+ optimizer_idx,
187
+ self.global_step,
188
+ last_layer=self.get_last_layer(),
189
+ split="train",
190
+ )
191
+
192
+ self.log(
193
+ "discloss",
194
+ discloss,
195
+ prog_bar=True,
196
+ logger=True,
197
+ on_step=True,
198
+ on_epoch=True,
199
+ )
200
+ self.log_dict(
201
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
202
+ )
203
+ return discloss
204
+
205
+ def validation_step(self, batch, batch_idx):
206
+ inputs = self.get_input(batch, self.image_key)
207
+ reconstructions, posterior = self(inputs)
208
+ aeloss, log_dict_ae = self.loss(
209
+ inputs,
210
+ reconstructions,
211
+ posterior,
212
+ 0,
213
+ self.global_step,
214
+ last_layer=self.get_last_layer(),
215
+ split="val",
216
+ )
217
+
218
+ discloss, log_dict_disc = self.loss(
219
+ inputs,
220
+ reconstructions,
221
+ posterior,
222
+ 1,
223
+ self.global_step,
224
+ last_layer=self.get_last_layer(),
225
+ split="val",
226
+ )
227
+
228
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
229
+ self.log_dict(log_dict_ae)
230
+ self.log_dict(log_dict_disc)
231
+ return self.log_dict
232
+
233
+ def test_step(self, batch, batch_idx):
234
+ # save z, dec
235
+ inputs = self.get_input(batch, self.image_key)
236
+ # forward
237
+ sample_posterior = True
238
+ posterior = self.encode(inputs)
239
+ if sample_posterior:
240
+ z = posterior.sample()
241
+ else:
242
+ z = posterior.mode()
243
+ dec = self.decode(z)
244
+
245
+ # logs
246
+ if self.test_args.save_z:
247
+ torch.save(
248
+ z,
249
+ os.path.join(
250
+ self.root_zs,
251
+ f"zs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.pt",
252
+ ),
253
+ )
254
+ if self.test_args.save_reconstruction:
255
+ tensor2videogrids(
256
+ dec,
257
+ self.root_dec,
258
+ f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
259
+ fps=10,
260
+ )
261
+ if self.test_args.save_input:
262
+ tensor2videogrids(
263
+ inputs,
264
+ self.root_inputs,
265
+ f"inputs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
266
+ fps=10,
267
+ )
268
+
269
+ if "save_z" in self.test_args and self.test_args.save_z:
270
+ dec_np = (dec.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) + 1) / 2 * 255
271
+ dec_np = dec_np.astype(np.uint8)
272
+ self.root_dec_np = os.path.join(self.root, "reconstructions_np")
273
+ os.makedirs(self.root_dec_np, exist_ok=True)
274
+ np.savez(
275
+ os.path.join(
276
+ self.root_dec_np,
277
+ f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(dec_np)}.npz",
278
+ ),
279
+ dec_np,
280
+ )
281
+
282
+ self.count += z.shape[0]
283
+
284
+ # misc
285
+ self.log("batch_idx", batch_idx, prog_bar=True)
286
+ self.log_dict(self.eval_metrics, prog_bar=True, logger=True)
287
+ torch.cuda.empty_cache()
288
+ if self.test_maximum is not None:
289
+ if self.count > self.test_maximum:
290
+ import sys
291
+
292
+ sys.exit()
293
+ else:
294
+ prog = self.count / self.test_maximum * 100
295
+ print(f"Test progress: {prog:.2f}% [{self.count}/{self.test_maximum}]")
296
+
297
+ @rank_zero_only
298
+ def on_test_end(self):
299
+ if self.test_args.cal_metrics:
300
+ psnrs, ssims, ms_ssims, lpipses = [], [], [], []
301
+ n_batches = 0
302
+ n_samples = 0
303
+ overall = {}
304
+ for k, v in self.eval_metrics.items():
305
+ psnrs.append(v["psnr"])
306
+ ssims.append(v["ssim"])
307
+ lpipses.append(v["lpips"])
308
+ n_batches += 1
309
+ n_samples += v["n_samples"]
310
+
311
+ mean_psnr = sum(psnrs) / len(psnrs)
312
+ mean_ssim = sum(ssims) / len(ssims)
313
+ # overall['ms_ssim'] = min(ms_ssims)
314
+ mean_lpips = sum(lpipses) / len(lpipses)
315
+
316
+ overall = {
317
+ "psnr": mean_psnr,
318
+ "ssim": mean_ssim,
319
+ "lpips": mean_lpips,
320
+ "n_batches": n_batches,
321
+ "n_samples": n_samples,
322
+ }
323
+ overall_t = torch.tensor([mean_psnr, mean_ssim, mean_lpips])
324
+ # dump
325
+ for k, v in overall.items():
326
+ if isinstance(v, torch.Tensor):
327
+ overall[k] = float(v)
328
+ with open(
329
+ os.path.join(self.root, f"reconstruction_metrics.json"), "w"
330
+ ) as f:
331
+ json.dump(overall, f)
332
333
+
334
+ def configure_optimizers(self):
335
+ lr = self.learning_rate
336
+ opt_ae = torch.optim.Adam(
337
+ list(self.encoder.parameters())
338
+ + list(self.decoder.parameters())
339
+ + list(self.quant_conv.parameters())
340
+ + list(self.post_quant_conv.parameters()),
341
+ lr=lr,
342
+ betas=(0.5, 0.9),
343
+ )
344
+ opt_disc = torch.optim.Adam(
345
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
346
+ )
347
+ return [opt_ae, opt_disc], []
348
+
349
+ def get_last_layer(self):
350
+ return self.decoder.conv_out.weight
351
+
352
+ @torch.no_grad()
353
+ def log_images(self, batch, only_inputs=False, **kwargs):
354
+ log = dict()
355
+ x = self.get_input(batch, self.image_key)
356
+ x = x.to(self.device)
357
+ if not only_inputs:
358
+ xrec, posterior = self(x)
359
+ if x.shape[1] > 3:
360
+ # colorize with random projection
361
+ assert xrec.shape[1] > 3
362
+ x = self.to_rgb(x)
363
+ xrec = self.to_rgb(xrec)
364
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
365
+ log["reconstructions"] = xrec
366
+ log["inputs"] = x
367
+ return log
368
+
369
+ def to_rgb(self, x):
370
+ assert self.image_key == "segmentation"
371
+ if not hasattr(self, "colorize"):
372
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
373
+ x = F.conv2d(x, weight=self.colorize)
374
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
375
+ return x
376
+
377
+
378
+ class IdentityFirstStage(torch.nn.Module):
379
+ def __init__(self, *args, vq_interface=False, **kwargs):
380
+ self.vq_interface = vq_interface
381
+ super().__init__()
382
+
383
+ def encode(self, x, *args, **kwargs):
384
+ return x
385
+
386
+ def decode(self, x, *args, **kwargs):
387
+ return x
388
+
389
+ def quantize(self, x, *args, **kwargs):
390
+ if self.vq_interface:
391
+ return x, None, [None, None, None]
392
+ return x
393
+
394
+ def forward(self, x, *args, **kwargs):
395
+ return x
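
A hedged sketch of the intended encode/decode round trip for AutoencoderKL. The config loading below is an assumption: it presumes the repo's YAML exposes a first_stage_config node that instantiate_from_config can consume (the path and key layout may differ).

import torch
from omegaconf import OmegaConf

from utils.utils import instantiate_from_config  # same helper the class itself uses

# Assumed config path and key layout; adjust to the actual YAML in this repo.
config = OmegaConf.load("configs/dual_stream/nvcomposer.yaml")
ae = instantiate_from_config(config.model.params.first_stage_config).eval()

x = torch.rand(1, 3, 256, 256) * 2 - 1   # images in [-1, 1]
with torch.no_grad():
    posterior = ae.encode(x)             # DiagonalGaussianDistribution
    z = posterior.sample()               # latent code (spatially downsampled)
    rec = ae.decode(z)                   # reconstruction, same shape as x
print(z.shape, rec.shape)
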
core/models/diffusion.py ADDED
@@ -0,0 +1,1679 @@
1
+ import logging
2
+ from collections import OrderedDict
3
+ from contextlib import contextmanager
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ from einops import rearrange
8
+ from tqdm import tqdm
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torchvision.utils import make_grid
13
+ import pytorch_lightning as pl
14
+ from pytorch_lightning.utilities import rank_zero_only
15
+
16
+ from core.modules.networks.unet_modules import TASK_IDX_IMAGE, TASK_IDX_RAY
17
+ from utils.utils import instantiate_from_config
18
+ from core.ema import LitEma
19
+ from core.distributions import DiagonalGaussianDistribution
20
+ from core.models.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
21
+ from core.models.samplers.ddim import DDIMSampler
22
+ from core.basics import disabled_train
23
+ from core.common import extract_into_tensor, noise_like, exists, default
24
+
25
+ main_logger = logging.getLogger("main_logger")
26
+
27
+
28
+ class BD(nn.Module):
29
+ def __init__(self, G=10):
30
+ super(BD, self).__init__()
31
+
32
+ self.momentum = 0.9
33
+ self.register_buffer("running_wm", torch.eye(G).expand(G, G))
34
+ self.running_wm = None
35
+
36
+ def forward(self, x, T=5, eps=1e-5):
37
+ N, C, G, H, W = x.size()
38
+ x = torch.permute(x, [0, 2, 1, 3, 4])
39
+ x_in = x.transpose(0, 1).contiguous().view(G, -1)
40
+ if self.training:
41
+ mean = x_in.mean(-1, keepdim=True)
42
+ xc = x_in - mean
43
+ d, m = x_in.size()
44
+ P = [None] * (T + 1)
45
+ P[0] = torch.eye(G, device=x.device)
46
+ Sigma = (torch.matmul(xc, xc.transpose(0, 1))) / float(m) + P[0] * eps
47
+ rTr = (Sigma * P[0]).sum([0, 1], keepdim=True).reciprocal()
48
+ Sigma_N = Sigma * rTr
49
+ wm = torch.linalg.solve_triangular(
50
+ torch.linalg.cholesky(Sigma_N), P[0], upper=False
51
+ )
52
+ self.running_wm = self.momentum * self.running_wm + (1 - self.momentum) * wm
53
+ else:
54
+ wm = self.running_wm
55
+
56
+ x_out = wm @ x_in
57
+ x_out = x_out.view(G, N, C, H, W).permute([1, 2, 0, 3, 4]).contiguous()
58
+
59
+ return x_out
60
+
61
+
62
+ class AbstractDDPM(pl.LightningModule):
63
+
64
+ def __init__(
65
+ self,
66
+ unet_config,
67
+ time_steps=1000,
68
+ beta_schedule="linear",
69
+ loss_type="l2",
70
+ monitor=None,
71
+ use_ema=True,
72
+ first_stage_key="image",
73
+ image_size=256,
74
+ channels=3,
75
+ log_every_t=100,
76
+ clip_denoised=True,
77
+ linear_start=1e-4,
78
+ linear_end=2e-2,
79
+ cosine_s=8e-3,
80
+ given_betas=None,
81
+ original_elbo_weight=0.0,
82
+ # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
83
+ v_posterior=0.0,
84
+ l_simple_weight=1.0,
85
+ conditioning_key=None,
86
+ parameterization="eps",
87
+ rescale_betas_zero_snr=False,
88
+ scheduler_config=None,
89
+ use_positional_encodings=False,
90
+ learn_logvar=False,
91
+ logvar_init=0.0,
92
+ bd_noise=False,
93
+ ):
94
+ super().__init__()
95
+ assert parameterization in [
96
+ "eps",
97
+ "x0",
98
+ "v",
99
+ ], 'only "eps", "x0", and "v" parameterizations are currently supported'
100
+ self.parameterization = parameterization
101
+ main_logger.info(
102
+ f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
103
+ )
104
+ self.cond_stage_model = None
105
+ self.clip_denoised = clip_denoised
106
+ self.log_every_t = log_every_t
107
+ self.first_stage_key = first_stage_key
108
+ self.channels = channels
109
+ self.cond_channels = unet_config.params.in_channels - channels
110
+ self.temporal_length = unet_config.params.temporal_length
111
+ self.image_size = image_size
112
+ self.bd_noise = bd_noise
113
+
114
+ if self.bd_noise:
115
+ self.bd = BD(G=self.temporal_length)
116
+
117
+ if isinstance(self.image_size, int):
118
+ self.image_size = [self.image_size, self.image_size]
119
+ self.use_positional_encodings = use_positional_encodings
120
+ self.model = DiffusionWrapper(unet_config)
121
+ self.use_ema = use_ema
122
+ if self.use_ema:
123
+ self.model_ema = LitEma(self.model)
124
+ main_logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
125
+
126
+ self.rescale_betas_zero_snr = rescale_betas_zero_snr
127
+ self.use_scheduler = scheduler_config is not None
128
+ if self.use_scheduler:
129
+ self.scheduler_config = scheduler_config
130
+
131
+ self.v_posterior = v_posterior
132
+ self.original_elbo_weight = original_elbo_weight
133
+ self.l_simple_weight = l_simple_weight
134
+
135
+ self.linear_end = None
136
+ self.linear_start = None
137
+ self.num_time_steps: int = 1000
138
+
139
+ if monitor is not None:
140
+ self.monitor = monitor
141
+
142
+ self.register_schedule(
143
+ given_betas=given_betas,
144
+ beta_schedule=beta_schedule,
145
+ time_steps=time_steps,
146
+ linear_start=linear_start,
147
+ linear_end=linear_end,
148
+ cosine_s=cosine_s,
149
+ )
150
+
151
+ self.given_betas = given_betas
152
+ self.beta_schedule = beta_schedule
153
+ self.time_steps = time_steps
154
+ self.cosine_s = cosine_s
155
+
156
+ self.loss_type = loss_type
157
+
158
+ self.learn_logvar = learn_logvar
159
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_time_steps,))
160
+ if self.learn_logvar:
161
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
162
+
163
+ def predict_start_from_noise(self, x_t, t, noise):
164
+ return (
165
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
166
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
167
+ * noise
168
+ )
169
+
170
+ def predict_start_from_z_and_v(self, x_t, t, v):
171
+ return (
172
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
173
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
174
+ )
175
+
176
+ def predict_eps_from_z_and_v(self, x_t, t, v):
177
+ return (
178
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
179
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
180
+ * x_t
181
+ )
182
+
183
+ def get_v(self, x, noise, t):
184
+ return (
185
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
186
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
187
+ )
188
+
189
+ @contextmanager
190
+ def ema_scope(self, context=None):
191
+ if self.use_ema:
192
+ self.model_ema.store(self.model.parameters())
193
+ self.model_ema.copy_to(self.model)
194
+ if context is not None:
195
+ main_logger.info(f"{context}: Switched to EMA weights")
196
+ try:
197
+ yield None
198
+ finally:
199
+ if self.use_ema:
200
+ self.model_ema.restore(self.model.parameters())
201
+ if context is not None:
202
+ main_logger.info(f"{context}: Restored training weights")
203
+
204
+ def q_mean_variance(self, x_start, t):
205
+ """
206
+ Get the distribution q(x_t | x_0).
207
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
208
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
209
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
210
+ """
211
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
212
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
213
+ log_variance = extract_into_tensor(
214
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
215
+ )
216
+ return mean, variance, log_variance
217
+
218
+ def q_posterior(self, x_start, x_t, t):
219
+ posterior_mean = (
220
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
221
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
222
+ )
223
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
224
+ posterior_log_variance_clipped = extract_into_tensor(
225
+ self.posterior_log_variance_clipped, t, x_t.shape
226
+ )
227
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
228
+
229
+ def q_sample(self, x_start, t, noise=None):
230
+ noise = default(noise, lambda: torch.randn_like(x_start))
231
+ return (
232
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
233
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
234
+ * noise
235
+ )
236
+
237
+ def get_loss(self, pred, target, mean=True):
238
+ if self.loss_type == "l1":
239
+ loss = (target - pred).abs()
240
+ if mean:
241
+ loss = loss.mean()
242
+ elif self.loss_type == "l2":
243
+ if mean:
244
+ loss = torch.nn.functional.mse_loss(target, pred)
245
+ else:
246
+ loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
247
+ else:
248
+ raise NotImplementedError("unknown loss type '{loss_type}'")
249
+
250
+ return loss
251
+
252
+ def on_train_batch_end(self, *args, **kwargs):
253
+ if self.use_ema:
254
+ self.model_ema(self.model)
255
+
256
+ def _get_rows_from_list(self, samples):
257
+ n_imgs_per_row = len(samples)
258
+ denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
259
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
260
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
261
+ return denoise_grid
262
+
263
+
264
+ class DualStreamMultiViewDiffusionModel(AbstractDDPM):
265
+
266
+ def __init__(
267
+ self,
268
+ first_stage_config,
269
+ data_key_images,
270
+ data_key_rays,
271
+ data_key_text_condition=None,
272
+ ckpt_path=None,
273
+ cond_stage_config=None,
274
+ num_time_steps_cond=None,
275
+ cond_stage_trainable=False,
276
+ cond_stage_forward=None,
277
+ conditioning_key=None,
278
+ uncond_prob=0.2,
279
+ uncond_type="empty_seq",
280
+ scale_factor=1.0,
281
+ scale_by_std=False,
282
+ use_noise_offset=False,
283
+ use_dynamic_rescale=False,
284
+ base_scale=0.3,
285
+ turning_step=400,
286
+ per_frame_auto_encoding=False,
287
+ # added for LVDM
288
+ encoder_type="2d",
289
+ cond_frames=None,
290
+ logdir=None,
291
+ empty_params_only=False,
292
+ # Image Condition
293
+ cond_img_config=None,
294
+ image_proj_model_config=None,
295
+ random_cond=False,
296
+ padding=False,
297
+ cond_concat=False,
298
+ frame_mask=False,
299
+ use_camera_pose_query_transformer=False,
300
+ with_cond_binary_mask=False,
301
+ apply_condition_mask_in_training_loss=True,
302
+ separate_noise_and_condition=False,
303
+ condition_padding_with_anchor=False,
304
+ ray_as_image=False,
305
+ use_task_embedding=False,
306
+ use_ray_decoder_loss_high_frequency_isolation=False,
307
+ disable_ray_stream=False,
308
+ ray_loss_weight=1.0,
309
+ train_with_multi_view_feature_alignment=False,
310
+ use_text_cross_attention_condition=True,
311
+ *args,
312
+ **kwargs,
313
+ ):
314
+
315
+ self.image_proj_model = None
316
+ self.apply_condition_mask_in_training_loss = (
317
+ apply_condition_mask_in_training_loss
318
+ )
319
+ self.separate_noise_and_condition = separate_noise_and_condition
320
+ self.condition_padding_with_anchor = condition_padding_with_anchor
321
+ self.use_text_cross_attention_condition = use_text_cross_attention_condition
322
+
323
+ self.data_key_images = data_key_images
324
+ self.data_key_rays = data_key_rays
325
+ self.data_key_text_condition = data_key_text_condition
326
+
327
+ self.num_time_steps_cond = default(num_time_steps_cond, 1)
328
+ self.scale_by_std = scale_by_std
329
+ assert self.num_time_steps_cond <= kwargs["time_steps"]
330
+ self.shorten_cond_schedule = self.num_time_steps_cond > 1
331
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
332
+
333
+ self.cond_stage_trainable = cond_stage_trainable
334
+ self.empty_params_only = empty_params_only
335
+ self.per_frame_auto_encoding = per_frame_auto_encoding
336
+ try:
337
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
338
+ except:
339
+ self.num_downs = 0
340
+ if not scale_by_std:
341
+ self.scale_factor = scale_factor
342
+ else:
343
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
344
+ self.use_noise_offset = use_noise_offset
345
+ self.use_dynamic_rescale = use_dynamic_rescale
346
+ if use_dynamic_rescale:
347
+ scale_arr1 = np.linspace(1.0, base_scale, turning_step)
348
+ scale_arr2 = np.full(self.num_time_steps, base_scale)
349
+ scale_arr = np.concatenate((scale_arr1, scale_arr2))
350
+ to_torch = partial(torch.tensor, dtype=torch.float32)
351
+ self.register_buffer("scale_arr", to_torch(scale_arr))
352
+ self.instantiate_first_stage(first_stage_config)
353
+
354
+ if self.use_text_cross_attention_condition and cond_stage_config is not None:
355
+ self.instantiate_cond_stage(cond_stage_config)
356
+
357
+ self.first_stage_config = first_stage_config
358
+ self.cond_stage_config = cond_stage_config
359
+ self.clip_denoised = False
360
+
361
+ self.cond_stage_forward = cond_stage_forward
362
+ self.encoder_type = encoder_type
363
+ assert encoder_type in ["2d", "3d"]
364
+ self.uncond_prob = uncond_prob
365
+ self.classifier_free_guidance = True if uncond_prob > 0 else False
366
+ assert uncond_type in ["zero_embed", "empty_seq"]
367
+ self.uncond_type = uncond_type
368
+
369
+ if cond_frames is not None:
370
+ frame_len = self.temporal_length
371
+ assert cond_frames[-1] < frame_len, (
372
+ f"Error: conditioning frame index must be smaller than {frame_len}!"
373
+ )
374
+ cond_mask = torch.zeros(frame_len, dtype=torch.float32)
375
+ cond_mask[cond_frames] = 1.0
376
+ self.cond_mask = cond_mask[None, None, :, None, None]
377
+ else:
378
+ self.cond_mask = None
379
+
380
+ self.restarted_from_ckpt = False
381
+ if ckpt_path is not None:
382
+ self.init_from_ckpt(ckpt_path)
383
+ self.restarted_from_ckpt = True
384
+
385
+ self.logdir = logdir
386
+ self.with_cond_binary_mask = with_cond_binary_mask
387
+ self.random_cond = random_cond
388
+ self.padding = padding
389
+ self.cond_concat = cond_concat
390
+ self.frame_mask = frame_mask
391
+ self.use_img_context = True if cond_img_config is not None else False
392
+ self.use_camera_pose_query_transformer = use_camera_pose_query_transformer
393
+ if self.use_img_context:
394
+ self.init_img_embedder(cond_img_config, freeze=True)
395
+ self.init_projector(image_proj_model_config, trainable=True)
396
+
397
+ self.ray_as_image = ray_as_image
398
+ self.use_task_embedding = use_task_embedding
399
+ self.use_ray_decoder_loss_high_frequency_isolation = (
400
+ use_ray_decoder_loss_high_frequency_isolation
401
+ )
402
+ self.disable_ray_stream = disable_ray_stream
403
+ if disable_ray_stream:
404
+ assert (
405
+ not ray_as_image
406
+ and not self.model.diffusion_model.use_ray_decoder
407
+ and not self.model.diffusion_model.use_ray_decoder_residual
408
+ ), "Options related to ray decoder should not be enabled when disabling ray stream."
409
+ assert (
410
+ not use_task_embedding
411
+ and not self.model.diffusion_model.use_task_embedding
412
+ ), "Task embedding should not be enabled when disabling ray stream."
413
+ assert (
414
+ not self.model.diffusion_model.use_addition_ray_output_head
415
+ ), "Additional ray output head should not be enabled when disabling ray stream."
416
+ assert (
417
+ not self.model.diffusion_model.use_lora_for_rays_in_output_blocks
418
+ ), "LoRA for rays should not be enabled when disabling ray stream."
419
+ self.ray_loss_weight = ray_loss_weight
420
+ self.train_with_multi_view_feature_alignment = False
421
+ if train_with_multi_view_feature_alignment:
422
+ print(f"MultiViewFeatureExtractor is ignored during inference.")
423
+
424
+ def init_from_ckpt(self, checkpoint_path):
425
+ main_logger.info(f"Initializing model from checkpoint {checkpoint_path}...")
426
+
427
+ def grab_ipa_weight(state_dict):
428
+ ipa_state_dict = OrderedDict()
429
+ for n in list(state_dict.keys()):
430
+ if "to_k_ip" in n or "to_v_ip" in n:
431
+ ipa_state_dict[n] = state_dict[n]
432
+ elif "image_proj_model" in n:
433
+ if (
434
+ self.use_camera_pose_query_transformer
435
+ and "image_proj_model.latents" in n
436
+ ):
437
+ ipa_state_dict[n] = torch.cat(
438
+ [state_dict[n] for i in range(16)], dim=1
439
+ )
440
+ else:
441
+ ipa_state_dict[n] = state_dict[n]
442
+ return ipa_state_dict
443
+
444
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
445
+ if "module" in state_dict.keys():
446
+ # deepspeed
447
+ target_state_dict = OrderedDict()
448
+ for key in state_dict["module"].keys():
449
+ target_state_dict[key[16:]] = state_dict["module"][key]
450
+ elif "state_dict" in list(state_dict.keys()):
451
+ target_state_dict = state_dict["state_dict"]
452
+ else:
453
+ raise KeyError("Weight key is not found in the state dict.")
454
+ ipa_state_dict = grab_ipa_weight(target_state_dict)
455
+ self.load_state_dict(ipa_state_dict, strict=False)
456
+ main_logger.info("Checkpoint loaded.")
457
+
458
+ def init_img_embedder(self, config, freeze=True):
459
+ embedder = instantiate_from_config(config)
460
+ if freeze:
461
+ self.embedder = embedder.eval()
462
+ self.embedder.train = disabled_train
463
+ for param in self.embedder.parameters():
464
+ param.requires_grad = False
465
+
466
+ def make_cond_schedule(
467
+ self,
468
+ ):
469
+ self.cond_ids = torch.full(
470
+ size=(self.num_time_steps,),
471
+ fill_value=self.num_time_steps - 1,
472
+ dtype=torch.long,
473
+ )
474
+ ids = torch.round(
475
+ torch.linspace(0, self.num_time_steps - 1, self.num_time_steps_cond)
476
+ ).long()
477
+ self.cond_ids[: self.num_time_steps_cond] = ids
478
+
479
+ def init_projector(self, config, trainable):
480
+ self.image_proj_model = instantiate_from_config(config)
481
+ if not trainable:
482
+ self.image_proj_model.eval()
483
+ self.image_proj_model.train = disabled_train
484
+ for param in self.image_proj_model.parameters():
485
+ param.requires_grad = False
486
+
487
+ @staticmethod
488
+ def pad_cond_images(batch_images):
489
+ h, w = batch_images.shape[-2:]
490
+ border = (w - h) // 2
491
+ # use padding at (W_t,W_b,H_t,H_b)
492
+ batch_images = torch.nn.functional.pad(
493
+ batch_images, (0, 0, border, border), "constant", 0
494
+ )
495
+ return batch_images
496
+
497
+ # Never delete this func: it is used in log_images() and inference stage
498
+ def get_image_embeds(self, batch_images, batch=None):
499
+ # input shape: b c h w
500
+ if self.padding:
501
+ batch_images = self.pad_cond_images(batch_images)
502
+ img_token = self.embedder(batch_images)
503
+ if self.use_camera_pose_query_transformer:
504
+ batch_size, num_views, _ = batch["target_poses"].shape
505
+ img_emb = self.image_proj_model(
506
+ img_token, batch["target_poses"].reshape(batch_size, num_views, 12)
507
+ )
508
+ else:
509
+ img_emb = self.image_proj_model(img_token)
510
+
511
+ return img_emb
512
+
513
+ @staticmethod
514
+ def get_input(batch, k):
515
+ x = batch[k]
516
+ """
517
+ # for image batch from image loader
518
+ if len(x.shape) == 4:
519
+ x = rearrange(x, 'b h w c -> b c h w')
520
+ """
521
+ x = x.to(memory_format=torch.contiguous_format) # .float()
522
+ return x
523
+
524
+ @rank_zero_only
525
+ @torch.no_grad()
526
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
527
+ # only for very first batch, reset the self.scale_factor
528
+ if (
529
+ self.scale_by_std
530
+ and self.current_epoch == 0
531
+ and self.global_step == 0
532
+ and batch_idx == 0
533
+ and not self.restarted_from_ckpt
534
+ ):
535
+ assert (
536
+ self.scale_factor == 1.0
537
+ ), "rather not use custom rescaling and std-rescaling simultaneously"
538
+ # set rescale weight to 1./std of encodings
539
+ main_logger.info("## USING STD-RESCALING ###")
540
+ x = self.get_input(batch, self.first_stage_key)
541
+ x = x.to(self.device)
542
+ encoder_posterior = self.encode_first_stage(x)
543
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
544
+ del self.scale_factor
545
+ self.register_buffer("scale_factor", 1.0 / z.flatten().std())
546
+ main_logger.info(f"setting self.scale_factor to {self.scale_factor}")
547
+ main_logger.info("## USING STD-RESCALING ###")
548
+ main_logger.info(f"std={z.flatten().std()}")
549
+
550
+ def register_schedule(
551
+ self,
552
+ given_betas=None,
553
+ beta_schedule="linear",
554
+ time_steps=1000,
555
+ linear_start=1e-4,
556
+ linear_end=2e-2,
557
+ cosine_s=8e-3,
558
+ ):
559
+ if exists(given_betas):
560
+ betas = given_betas
561
+ else:
562
+ betas = make_beta_schedule(
563
+ beta_schedule,
564
+ time_steps,
565
+ linear_start=linear_start,
566
+ linear_end=linear_end,
567
+ cosine_s=cosine_s,
568
+ )
569
+
570
+ if self.rescale_betas_zero_snr:
571
+ betas = rescale_zero_terminal_snr(betas)
572
+ alphas = 1.0 - betas
573
+ alphas_cumprod = np.cumprod(alphas, axis=0)
574
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
575
+
576
+ (time_steps,) = betas.shape
577
+ self.num_time_steps = int(time_steps)
578
+ self.linear_start = linear_start
579
+ self.linear_end = linear_end
580
+ assert (
581
+ alphas_cumprod.shape[0] == self.num_time_steps
582
+ ), "alphas have to be defined for each timestep"
583
+
584
+ to_torch = partial(torch.tensor, dtype=torch.float32)
585
+
586
+ self.register_buffer("betas", to_torch(betas))
587
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
588
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
589
+
590
+ # calculations for diffusion q(x_t | x_{t-1}) and others
591
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
592
+ self.register_buffer(
593
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
594
+ )
595
+ self.register_buffer(
596
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
597
+ )
598
+ self.register_buffer(
599
+ "sqrt_recip_alphas_cumprod",
600
+ to_torch(np.sqrt(1.0 / (alphas_cumprod + 1e-5))),
601
+ )
602
+ self.register_buffer(
603
+ "sqrt_recipm1_alphas_cumprod",
604
+ to_torch(np.sqrt(1.0 / (alphas_cumprod + 1e-5) - 1)),
605
+ )
606
+
607
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
608
+ posterior_variance = (1 - self.v_posterior) * betas * (
609
+ 1.0 - alphas_cumprod_prev
610
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
611
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
612
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
613
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
614
+ self.register_buffer(
615
+ "posterior_log_variance_clipped",
616
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
617
+ )
618
+ self.register_buffer(
619
+ "posterior_mean_coef1",
620
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
621
+ )
622
+ self.register_buffer(
623
+ "posterior_mean_coef2",
624
+ to_torch(
625
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
626
+ ),
627
+ )
628
+
629
+ if self.parameterization == "eps":
630
+ lvlb_weights = self.betas**2 / (
631
+ 2
632
+ * self.posterior_variance
633
+ * to_torch(alphas)
634
+ * (1 - self.alphas_cumprod)
635
+ )
636
+ elif self.parameterization == "x0":
637
+ lvlb_weights = (
638
+ 0.5
639
+ * torch.sqrt(to_torch(alphas_cumprod))
640
+ / (2.0 * 1 - to_torch(alphas_cumprod))
641
+ )
642
+ elif self.parameterization == "v":
643
+ lvlb_weights = torch.ones_like(
644
+ self.betas**2
645
+ / (
646
+ 2
647
+ * self.posterior_variance
648
+ * to_torch(alphas)
649
+ * (1 - self.alphas_cumprod)
650
+ )
651
+ )
652
+ else:
653
+ raise NotImplementedError("mu not supported")
654
+ lvlb_weights[0] = lvlb_weights[1]
655
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
656
+ assert not torch.isnan(self.lvlb_weights).any(), "lvlb_weights contain NaNs"
657
+
658
+ if self.shorten_cond_schedule:
659
+ self.make_cond_schedule()
660
+
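+ # Note (added for clarity): the buffers registered above encode the standard DDPM closed forms.
+ # With a_bar_t = alphas_cumprod[t], the forward marginal is
+ # q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I),
+ # and the sampling posterior is q(x_{t-1} | x_t, x_0) = N(mu_t, posterior_variance_t) with
+ # mu_t = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t.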
661
+ def instantiate_first_stage(self, config):
662
+ model = instantiate_from_config(config)
663
+ self.first_stage_model = model.eval()
664
+ self.first_stage_model.train = disabled_train
665
+ for param in self.first_stage_model.parameters():
666
+ param.requires_grad = False
667
+
668
+ def instantiate_cond_stage(self, config):
669
+ if not self.cond_stage_trainable:
670
+ model = instantiate_from_config(config)
671
+ self.cond_stage_model = model.eval()
672
+ self.cond_stage_model.train = disabled_train
673
+ for param in self.cond_stage_model.parameters():
674
+ param.requires_grad = False
675
+ else:
676
+ model = instantiate_from_config(config)
677
+ self.cond_stage_model = model
678
+
679
+ def get_learned_conditioning(self, c):
680
+ if self.cond_stage_forward is None:
681
+ if hasattr(self.cond_stage_model, "encode") and callable(
682
+ self.cond_stage_model.encode
683
+ ):
684
+ c = self.cond_stage_model.encode(c)
685
+ if isinstance(c, DiagonalGaussianDistribution):
686
+ c = c.mode()
687
+ else:
688
+ c = self.cond_stage_model(c)
689
+ else:
690
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
691
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
692
+ return c
693
+
694
+ def get_first_stage_encoding(self, encoder_posterior, noise=None):
695
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
696
+ z = encoder_posterior.sample(noise=noise)
697
+ elif isinstance(encoder_posterior, torch.Tensor):
698
+ z = encoder_posterior
699
+ else:
700
+ raise NotImplementedError(
701
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
702
+ )
703
+ return self.scale_factor * z
704
+
705
+ @torch.no_grad()
706
+ def encode_first_stage(self, x):
707
+ assert x.dim() == 5 or x.dim() == 4, (
708
+ "Images should be a either 5-dimensional (batched image sequence) "
709
+ "or 4-dimensional (batched images)."
710
+ )
711
+ if (
712
+ self.encoder_type == "2d"
713
+ and x.dim() == 5
714
+ and not self.per_frame_auto_encoding
715
+ ):
716
+ b, t, _, _, _ = x.shape
717
+ x = rearrange(x, "b t c h w -> (b t) c h w")
718
+ reshape_back = True
719
+ else:
720
+ b = x.shape[0]  # batch size; works for both 4D and 5D inputs
721
+ t = 1
722
+ reshape_back = False
723
+
724
+ if not self.per_frame_auto_encoding:
725
+ encoder_posterior = self.first_stage_model.encode(x)
726
+ results = self.get_first_stage_encoding(encoder_posterior).detach()
727
+ else:
728
+ results = []
729
+ for index in range(x.shape[1]):
730
+ frame_batch = self.first_stage_model.encode(x[:, index, :, :, :])
731
+ frame_result = self.get_first_stage_encoding(frame_batch).detach()
732
+ results.append(frame_result)
733
+ results = torch.stack(results, dim=1)
734
+
735
+ if reshape_back:
736
+ results = rearrange(results, "(b t) c h w -> b t c h w", b=b, t=t)
737
+
738
+ return results
739
+
740
+ def decode_core(self, z, **kwargs):
741
+ assert z.dim() == 5 or z.dim() == 4, (
742
+ "Latents should be a either 5-dimensional (batched latent sequence) "
743
+ "or 4-dimensional (batched latents)."
744
+ )
745
+
746
+ if (
747
+ self.encoder_type == "2d"
748
+ and z.dim() == 5
749
+ and not self.per_frame_auto_encoding
750
+ ):
751
+ b, t, _, _, _ = z.shape
752
+ z = rearrange(z, "b t c h w -> (b t) c h w")
753
+ reshape_back = True
754
+ else:
755
+ b = z.shape[0]  # batch size; works for both 4D and 5D latents
756
+ t = 1
757
+ reshape_back = False
758
+
759
+ if not self.per_frame_auto_encoding:
760
+ z = 1.0 / self.scale_factor * z
761
+ results = self.first_stage_model.decode(z, **kwargs)
762
+ else:
763
+ results = []
764
+ for index in range(z.shape[1]):
765
+ frame_z = 1.0 / self.scale_factor * z[:, index, :, :, :]
766
+ frame_result = self.first_stage_model.decode(frame_z, **kwargs)
767
+ results.append(frame_result)
768
+ results = torch.stack(results, dim=1)
769
+
770
+ if reshape_back:
771
+ results = rearrange(results, "(b t) c h w -> b t c h w", b=b, t=t)
772
+ return results
773
+
774
+ @torch.no_grad()
775
+ def decode_first_stage(self, z, **kwargs):
776
+ return self.decode_core(z, **kwargs)
777
+
778
+ def differentiable_decode_first_stage(self, z, **kwargs):
779
+ return self.decode_core(z, **kwargs)
780
+
781
+ def get_batch_input(
782
+ self,
783
+ batch,
784
+ random_drop_training_conditions,
785
+ return_reconstructed_target_images=False,
786
+ ):
787
+ combined_images = batch[self.data_key_images]
788
+ clean_combined_image_latents = self.encode_first_stage(combined_images)
789
+ mask_preserving_target = batch["mask_preserving_target"].reshape(
790
+ batch["mask_preserving_target"].size(0),
791
+ batch["mask_preserving_target"].size(1),
792
+ 1,
793
+ 1,
794
+ 1,
795
+ )
796
+ mask_preserving_condition = 1.0 - mask_preserving_target
797
+ if self.ray_as_image:
798
+ clean_combined_ray_images = batch[self.data_key_rays]
799
+ clean_combined_ray_o_latents = self.encode_first_stage(
800
+ clean_combined_ray_images[:, :, :3, :, :]
801
+ )
802
+ clean_combined_ray_d_latents = self.encode_first_stage(
803
+ clean_combined_ray_images[:, :, 3:, :, :]
804
+ )
805
+ clean_combined_rays = torch.concat(
806
+ [clean_combined_ray_o_latents, clean_combined_ray_d_latents], dim=2
807
+ )
808
+
809
+ if self.condition_padding_with_anchor:
810
+ condition_ray_images = batch["condition_rays"]
811
+ condition_ray_o_images = self.encode_first_stage(
812
+ condition_ray_images[:, :, :3, :, :]
813
+ )
814
+ condition_ray_d_images = self.encode_first_stage(
815
+ condition_ray_images[:, :, 3:, :, :]
816
+ )
817
+ condition_rays = torch.concat(
818
+ [condition_ray_o_images, condition_ray_d_images], dim=2
819
+ )
820
+ else:
821
+ condition_rays = clean_combined_rays * mask_preserving_target
822
+ else:
823
+ clean_combined_rays = batch[self.data_key_rays]
824
+
825
+ if self.condition_padding_with_anchor:
826
+ condition_rays = batch["condition_rays"]
827
+ else:
828
+ condition_rays = clean_combined_rays * mask_preserving_target
829
+
830
+ if self.condition_padding_with_anchor:
831
+ condition_images_latents = self.encode_first_stage(
832
+ batch["condition_images"]
833
+ )
834
+ else:
835
+ condition_images_latents = (
836
+ clean_combined_image_latents * mask_preserving_condition
837
+ )
838
+
839
+ if random_drop_training_conditions:
840
+ random_num = torch.rand(
841
+ combined_images.size(0), device=combined_images.device
842
+ )
843
+ else:
844
+ random_num = torch.ones(
845
+ combined_images.size(0), device=combined_images.device
846
+ )
847
+
848
+ text_feature_condition_mask = rearrange(
849
+ random_num < 2 * self.uncond_prob, "n -> n 1 1"
850
+ )
851
+ image_feature_condition_mask = 1 - rearrange(
852
+ (random_num >= self.uncond_prob).float()
853
+ * (random_num < 3 * self.uncond_prob).float(),
854
+ "n -> n 1 1 1 1",
855
+ )
856
+ ray_condition_mask = 1 - rearrange(
857
+ (random_num >= 1.5 * self.uncond_prob).float()
858
+ * (random_num < 3.5 * self.uncond_prob).float(),
859
+ "n -> n 1 1 1 1",
860
+ )
861
+ mask_preserving_first_target = batch[
862
+ "mask_only_preserving_first_target"
863
+ ].reshape(
864
+ batch["mask_only_preserving_first_target"].size(0),
865
+ batch["mask_only_preserving_first_target"].size(1),
866
+ 1,
867
+ 1,
868
+ 1,
869
+ )
870
+ mask_preserving_first_condition = batch[
871
+ "mask_only_preserving_first_condition"
872
+ ].reshape(
873
+ batch["mask_only_preserving_first_condition"].size(0),
874
+ batch["mask_only_preserving_first_condition"].size(1),
875
+ 1,
876
+ 1,
877
+ 1,
878
+ )
879
+ mask_preserving_anchors = (
880
+ mask_preserving_first_target + mask_preserving_first_condition
881
+ )
882
+ mask_randomly_preserving_first_target = torch.where(
883
+ ray_condition_mask.repeat(1, mask_preserving_first_target.size(1), 1, 1, 1)
884
+ == 1.0,
885
+ 1.0,
886
+ mask_preserving_first_target,
887
+ )
888
+ mask_randomly_preserving_first_condition = torch.where(
889
+ image_feature_condition_mask.repeat(
890
+ 1, mask_preserving_first_condition.size(1), 1, 1, 1
891
+ )
892
+ == 1.0,
893
+ 1.0,
894
+ mask_preserving_first_condition,
895
+ )
896
+
897
+ if self.use_text_cross_attention_condition:
898
+ text_cond_key = self.data_key_text_condition
899
+ text_cond = batch[text_cond_key]
900
+ if isinstance(text_cond, dict) or isinstance(text_cond, list):
901
+ full_text_cond_emb = self.get_learned_conditioning(text_cond)
902
+ else:
903
+ full_text_cond_emb = self.get_learned_conditioning(
904
+ text_cond.to(self.device)
905
+ )
906
+ null_text_cond_emb = self.get_learned_conditioning([""])
907
+ text_cond_emb = torch.where(
908
+ text_feature_condition_mask,
909
+ null_text_cond_emb,
910
+ full_text_cond_emb.detach(),
911
+ )
912
+
913
+ batch_size, num_views, _, _, _ = batch[self.data_key_images].shape
914
+ if self.condition_padding_with_anchor:
915
+ condition_images = batch["condition_images"]
916
+ else:
917
+ condition_images = combined_images * mask_preserving_condition
918
+ if random_drop_training_conditions:
919
+ condition_image_for_embedder = rearrange(
920
+ condition_images * image_feature_condition_mask,
921
+ "b t c h w -> (b t) c h w",
922
+ )
923
+ else:
924
+ condition_image_for_embedder = rearrange(
925
+ condition_images, "b t c h w -> (b t) c h w"
926
+ )
927
+ img_token = self.embedder(condition_image_for_embedder)
928
+ if self.use_camera_pose_query_transformer:
929
+ img_emb = self.image_proj_model(
930
+ img_token, batch["target_poses"].reshape(batch_size, num_views, 12)
931
+ )
932
+ else:
933
+ img_emb = self.image_proj_model(img_token)
934
+
935
+ img_emb = rearrange(
936
+ img_emb, "(b t) s d -> b (t s) d", b=batch_size, t=num_views
937
+ )
938
+ if self.use_text_cross_attention_condition:
939
+ c_crossattn = [torch.cat([text_cond_emb, img_emb], dim=1)]
940
+ else:
941
+ c_crossattn = [img_emb]
942
+
943
+ cond_dict = {
944
+ "c_crossattn": c_crossattn,
945
+ "target_camera_poses": batch["target_and_condition_camera_poses"]
946
+ * batch["mask_preserving_target"].unsqueeze(-1),
947
+ }
948
+
949
+ if self.disable_ray_stream:
950
+ clean_gt = torch.cat([clean_combined_image_latents], dim=2)
951
+ else:
952
+ clean_gt = torch.cat(
953
+ [clean_combined_image_latents, clean_combined_rays], dim=2
954
+ )
955
+ if random_drop_training_conditions:
956
+ combined_condition = torch.cat(
957
+ [
958
+ condition_images_latents * mask_randomly_preserving_first_condition,
959
+ condition_rays * mask_randomly_preserving_first_target,
960
+ ],
961
+ dim=2,
962
+ )
963
+ else:
964
+ combined_condition = torch.cat(
965
+ [condition_images_latents, condition_rays], dim=2
966
+ )
967
+
968
+ uncond_combined_condition = torch.cat(
969
+ [
970
+ condition_images_latents * mask_preserving_anchors,
971
+ condition_rays * mask_preserving_anchors,
972
+ ],
973
+ dim=2,
974
+ )
975
+
976
+ mask_full_for_input = torch.cat(
977
+ [
978
+ mask_preserving_condition.repeat(
979
+ 1, 1, condition_images_latents.size(2), 1, 1
980
+ ),
981
+ mask_preserving_target.repeat(1, 1, condition_rays.size(2), 1, 1),
982
+ ],
983
+ dim=2,
984
+ )
985
+ cond_dict.update(
986
+ {
987
+ "mask_preserving_target": mask_preserving_target,
988
+ "mask_preserving_condition": mask_preserving_condition,
989
+ "combined_condition": combined_condition,
990
+ "uncond_combined_condition": uncond_combined_condition,
991
+ "clean_combined_rays": clean_combined_rays,
992
+ "mask_full_for_input": mask_full_for_input,
993
+ "num_cond_images": rearrange(
994
+ batch["num_cond_images"].float(), "b -> b 1 1 1 1"
995
+ ),
996
+ "num_target_images": rearrange(
997
+ batch["num_target_images"].float(), "b -> b 1 1 1 1"
998
+ ),
999
+ }
1000
+ )
1001
+
1002
+ out = [clean_gt, cond_dict]
1003
+ if return_reconstructed_target_images:
1004
+ target_images_reconstructed = self.decode_first_stage(
1005
+ clean_combined_image_latents
1006
+ )
1007
+ out.append(target_images_reconstructed)
1008
+ return out
1009
+
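+ # Note (added for clarity): with uncond_prob = u, the overlapping random_num bands in
+ # get_batch_input above implement condition dropout for classifier-free guidance: the text
+ # embedding is replaced by the null embedding for random_num in [0, 2u), image-feature
+ # conditioning is dropped for [u, 3u), and ray conditioning for [1.5u, 3.5u); in the dropped
+ # bands the latent conditions are reduced to the first (anchor) view only.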
1010
+ def get_dynamic_scales(self, t, spin_step=400):
1011
+ base_scale = self.base_scale
1012
+ scale_t = torch.where(
1013
+ t < spin_step,
1014
+ t * (base_scale - 1.0) / spin_step + 1.0,
1015
+ base_scale * torch.ones_like(t),
1016
+ )
1017
+ return scale_t
1018
+
1019
+ def forward(self, x, c, **kwargs):
1020
+ t = torch.randint(
1021
+ 0, self.num_time_steps, (x.shape[0],), device=self.device
1022
+ ).long()
1023
+ if self.use_dynamic_rescale:
1024
+ x = x * extract_into_tensor(self.scale_arr, t, x.shape)
1025
+ return self.p_losses(x, c, t, **kwargs)
1026
+
1027
+ def extract_feature(self, batch, t, **kwargs):
1028
+ z, cond = self.get_batch_input(
1029
+ batch,
1030
+ random_drop_training_conditions=False,
1031
+ return_reconstructed_target_images=False,
1032
+ )
1033
+ if self.use_dynamic_rescale:
1034
+ z = z * extract_into_tensor(self.scale_arr, t, z.shape)
1035
+ noise = torch.randn_like(z)
1036
+ if self.use_noise_offset:
1037
+ noise = noise + 0.1 * torch.randn(
1038
+ noise.shape[0], noise.shape[1], 1, 1, 1
1039
+ ).to(self.device)
1040
+ x_noisy = self.q_sample(x_start=z, t=t, noise=noise)
1041
+ x_noisy = self.process_x_with_condition(x_noisy, condition_dict=cond)
1042
+ c_crossattn = torch.cat(cond["c_crossattn"], 1)
1043
+ target_camera_poses = cond["target_camera_poses"]
1044
+ x_pred, features = self.model(
1045
+ x_noisy,
1046
+ t,
1047
+ context=c_crossattn,
1048
+ return_output_block_features=True,
1049
+ camera_poses=target_camera_poses,
1050
+ **kwargs,
1051
+ )
1052
+ return x_pred, features, z
1053
+
1054
+ def apply_model(self, x_noisy, t, cond, features_to_return=None, **kwargs):
1055
+ if not isinstance(cond, dict):
1056
+ if not isinstance(cond, list):
1057
+ cond = [cond]
1058
+ key = (
1059
+ "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
1060
+ )
1061
+ cond = {key: cond}
1062
+
1063
+ c_crossattn = torch.cat(cond["c_crossattn"], 1)
1064
+ x_noisy = self.process_x_with_condition(x_noisy, condition_dict=cond)
1065
+ target_camera_poses = cond["target_camera_poses"]
1066
+ if self.use_task_embedding:
1067
+ x_pred_images = self.model(
1068
+ x_noisy,
1069
+ t,
1070
+ context=c_crossattn,
1071
+ task_idx=TASK_IDX_IMAGE,
1072
+ camera_poses=target_camera_poses,
1073
+ **kwargs,
1074
+ )
1075
+ x_pred_rays = self.model(
1076
+ x_noisy,
1077
+ t,
1078
+ context=c_crossattn,
1079
+ task_idx=TASK_IDX_RAY,
1080
+ camera_poses=target_camera_poses,
1081
+ **kwargs,
1082
+ )
1083
+ x_pred = torch.concat([x_pred_images, x_pred_rays], dim=2)
1084
+ elif features_to_return is not None:
1085
+ x_pred, features = self.model(
1086
+ x_noisy,
1087
+ t,
1088
+ context=c_crossattn,
1089
+ return_input_block_features="input" in features_to_return,
1090
+ return_middle_feature="middle" in features_to_return,
1091
+ return_output_block_features="output" in features_to_return,
1092
+ camera_poses=target_camera_poses,
1093
+ **kwargs,
1094
+ )
1095
+ return x_pred, features
1096
+ elif self.train_with_multi_view_feature_alignment:
1097
+ x_pred, aligned_features = self.model(
1098
+ x_noisy,
1099
+ t,
1100
+ context=c_crossattn,
1101
+ camera_poses=target_camera_poses,
1102
+ **kwargs,
1103
+ )
1104
+ return x_pred, aligned_features
1105
+ else:
1106
+ x_pred = self.model(
1107
+ x_noisy,
1108
+ t,
1109
+ context=c_crossattn,
1110
+ camera_poses=target_camera_poses,
1111
+ **kwargs,
1112
+ )
1113
+ return x_pred
1114
+
1115
+ def process_x_with_condition(self, x_noisy, condition_dict):
1116
+ combined_condition = condition_dict["combined_condition"]
1117
+ if self.separate_noise_and_condition:
1118
+ if self.disable_ray_stream:
1119
+ x_noisy = torch.concat([x_noisy, combined_condition], dim=2)
1120
+ else:
1121
+ x_noisy = torch.concat(
1122
+ [
1123
+ x_noisy[:, :, :4, :, :],
1124
+ combined_condition[:, :, :4, :, :],
1125
+ x_noisy[:, :, 4:, :, :],
1126
+ combined_condition[:, :, 4:, :, :],
1127
+ ],
1128
+ dim=2,
1129
+ )
1130
+ else:
1131
+ assert (
1132
+ not self.use_ray_decoder_regression
1133
+ ), "`separate_noise_and_condition` must be True when enabling `use_ray_decoder_regression`."
1134
+ mask_preserving_target = condition_dict["mask_preserving_target"]
1135
+ mask_preserving_condition = condition_dict["mask_preserving_condition"]
1136
+ mask_for_combined_condition = torch.cat(
1137
+ [
1138
+ mask_preserving_target.repeat(1, 1, 4, 1, 1),
1139
+ mask_preserving_condition.repeat(1, 1, 6, 1, 1),
1140
+ ]
1141
+ )
1142
+ mask_for_x_noisy = torch.cat(
1143
+ [
1144
+ mask_preserving_target.repeat(1, 1, 4, 1, 1),
1145
+ mask_preserving_condition.repeat(1, 1, 6, 1, 1),
1146
+ ]
1147
+ )
1148
+ x_noisy = (
1149
+ x_noisy * mask_for_x_noisy
1150
+ + combined_condition * mask_for_combined_condition
1151
+ )
1152
+
1153
+ return x_noisy
1154
+
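+ # Note (added for clarity): with separate_noise_and_condition enabled (and the ray stream
+ # active), the channel layout fed to the denoiser is [noisy image latents (4ch), condition
+ # image latents (4ch), noisy ray channels, condition ray channels]; otherwise noise and
+ # condition are blended using the target/condition preserving masks.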
1155
+ def p_losses(self, x_start, cond, t, noise=None, **kwargs):
1156
+
1157
+ noise = default(noise, lambda: torch.randn_like(x_start))
1158
+
1159
+ if self.use_noise_offset:
1160
+ noise = noise + 0.1 * torch.randn(
1161
+ noise.shape[0], noise.shape[1], 1, 1, 1
1162
+ ).to(self.device)
1163
+
1164
+ # When bd_noise is enabled, build frame-correlated noise: normalize the decorrelated noise, then mix each later frame with the first frame's noise using bd_ratio
1165
+ if self.bd_noise:
1166
+ noise_decor = self.bd(noise)
1167
+ noise_decor = (noise_decor - noise_decor.mean()) / (
1168
+ noise_decor.std() + 1e-5
1169
+ )
1170
+ noise_f = noise_decor[:, :, 0:1, :, :]
1171
+ noise = (
1172
+ np.sqrt(self.bd_ratio) * noise_decor[:, :, 1:]
1173
+ + np.sqrt(1 - self.bd_ratio) * noise_f
1174
+ )
1175
+ noise = torch.cat([noise_f, noise], dim=2)
1176
+
1177
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1178
+ if self.train_with_multi_view_feature_alignment:
1179
+ model_output, aligned_features = self.apply_model(
1180
+ x_noisy, t, cond, **kwargs
1181
+ )
1182
+
1183
+ aligned_middle_feature = rearrange(
1184
+ aligned_features,
1185
+ "(b t) c h w -> b (t c h w)",
1186
+ b=cond["pts_anchor_to_all"].size(0),
1187
+ t=cond["pts_anchor_to_all"].size(1),
1188
+ )
1189
+ target_multi_view_feature = rearrange(
1190
+ torch.concat(
1191
+ [cond["pts_anchor_to_all"], cond["pts_all_to_anchor"]], dim=2
1192
+ ),
1193
+ "b t c h w -> b (t c h w)",
1194
+ ).to(aligned_middle_feature.device)
1195
+ else:
1196
+ model_output = self.apply_model(x_noisy, t, cond, **kwargs)
1197
+
1198
+ loss_dict = {}
1199
+ prefix = "train" if self.training else "val"
1200
+
1201
+ if self.parameterization == "x0":
1202
+ target = x_start
1203
+ elif self.parameterization == "eps":
1204
+ target = noise
1205
+ elif self.parameterization == "v":
1206
+ target = self.get_v(x_start, noise, t)
1207
+ else:
1208
+ raise NotImplementedError()
1209
+
1210
+ if self.apply_condition_mask_in_training_loss:
1211
+ mask_full_for_output = 1.0 - cond["mask_full_for_input"]
1212
+ model_output = model_output * mask_full_for_output
1213
+ target = target * mask_full_for_output
1214
+ loss_simple = self.get_loss(model_output, target, mean=False)
1215
+ if self.ray_loss_weight != 1.0:
1216
+ loss_simple[:, :, 4:, :, :] = (
1217
+ loss_simple[:, :, 4:, :, :] * self.ray_loss_weight
1218
+ )
1219
+ if self.apply_condition_mask_in_training_loss:
1220
+ # Image loss: predicted items = # of target images
1221
+ num_total_images = cond["num_cond_images"] + cond["num_target_images"]
1222
+ weight_for_image_loss = num_total_images / cond["num_target_images"]
1223
+ weight_for_ray_loss = num_total_images / cond["num_cond_images"]
1224
+ loss_simple[:, :, :4, :, :] = (
1225
+ loss_simple[:, :, :4, :, :] * weight_for_image_loss
1226
+ )
1227
+ # Ray loss: predicted items = # of condition images
1228
+ loss_simple[:, :, 4:, :, :] = (
1229
+ loss_simple[:, :, 4:, :, :] * weight_for_ray_loss
1230
+ )
1231
+
1232
+ loss_dict.update({f"{prefix}/loss_images": loss_simple[:, :, 0:4, :, :].mean()})
1233
+ if not self.disable_ray_stream:
1234
+ loss_dict.update(
1235
+ {f"{prefix}/loss_rays": loss_simple[:, :, 4:, :, :].mean()}
1236
+ )
1237
+ loss_simple = loss_simple.mean([1, 2, 3, 4])
1238
+ loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})
1239
+
1240
+ if self.logvar.device != self.device:
1241
+ self.logvar = self.logvar.to(self.device)
1242
+ logvar_t = self.logvar[t]
1243
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1244
+ if self.learn_logvar:
1245
+ loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
1246
+ loss_dict.update({"logvar": self.logvar.data.mean()})
1247
+
1248
+ loss = self.l_simple_weight * loss.mean()
1249
+
1250
+ if self.train_with_multi_view_feature_alignment:
1251
+ multi_view_feature_alignment_loss = 0.25 * torch.nn.functional.mse_loss(
1252
+ aligned_middle_feature, target_multi_view_feature
1253
+ )
1254
+ loss += multi_view_feature_alignment_loss
1255
+ loss_dict.update(
1256
+ {f"{prefix}/loss_mv_feat_align": multi_view_feature_alignment_loss}
1257
+ )
1258
+
1259
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(
1260
+ dim=(1, 2, 3, 4)
1261
+ )
1262
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1263
+ loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
1264
+ loss += self.original_elbo_weight * loss_vlb
1265
+ loss_dict.update({f"{prefix}/loss": loss})
1266
+
1267
+ return loss, loss_dict
1268
+
1269
+ def _get_denoise_row_from_list(self, samples, desc=""):
1270
+ denoise_row = []
1271
+ for zd in tqdm(samples, desc=desc):
1272
+ denoise_row.append(self.decode_first_stage(zd.to(self.device)))
1273
+ n_log_time_steps = len(denoise_row)
1274
+
1275
+ denoise_row = torch.stack(denoise_row) # n_log_time_steps, b, C, H, W
1276
+
1277
+ if denoise_row.dim() == 5:
1278
+ denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
1279
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
1280
+ denoise_grid = make_grid(denoise_grid, nrow=n_log_time_steps)
1281
+ elif denoise_row.dim() == 6:
1282
+ video_length = denoise_row.shape[3]
1283
+ denoise_grid = rearrange(denoise_row, "n b c t h w -> b n c t h w")
1284
+ denoise_grid = rearrange(denoise_grid, "b n c t h w -> (b n) c t h w")
1285
+ denoise_grid = rearrange(denoise_grid, "n c t h w -> (n t) c h w")
1286
+ denoise_grid = make_grid(denoise_grid, nrow=video_length)
1287
+ else:
1288
+ raise ValueError
1289
+
1290
+ return denoise_grid
1291
+
1292
+ @torch.no_grad()
1293
+ def log_images(
1294
+ self,
1295
+ batch,
1296
+ sample=True,
1297
+ ddim_steps=50,
1298
+ ddim_eta=1.0,
1299
+ plot_denoise_rows=False,
1300
+ unconditional_guidance_scale=1.0,
1301
+ **kwargs,
1302
+ ):
1303
+ """log images for LatentDiffusion"""
1304
+ use_ddim = ddim_steps is not None
1305
+ log = dict()
1306
+ z, cond, x_rec = self.get_batch_input(
1307
+ batch,
1308
+ random_drop_training_conditions=False,
1309
+ return_reconstructed_target_images=True,
1310
+ )
1311
+ b, t, c, h, w = x_rec.shape
1312
+ log["num_cond_images_str"] = batch["num_cond_images_str"]
1313
+ log["caption"] = batch["caption"]
1314
+ if "condition_images" in batch:
1315
+ log["input_condition_images_all"] = batch["condition_images"]
1316
+ log["input_condition_image_latents_masked"] = cond["combined_condition"][
1317
+ :, :, 0:3, :, :
1318
+ ]
1319
+ log["input_condition_rays_o_masked"] = (
1320
+ cond["combined_condition"][:, :, 4:7, :, :] / 5.0
1321
+ )
1322
+ log["input_condition_rays_d_masked"] = (
1323
+ cond["combined_condition"][:, :, 7:, :, :] / 5.0
1324
+ )
1325
+ log["gt_images_after_vae"] = x_rec
1326
+ if self.train_with_multi_view_feature_alignment:
1327
+ log["pts_anchor_to_all"] = cond["pts_anchor_to_all"]
1328
+ log["pts_all_to_anchor"] = cond["pts_all_to_anchor"]
1329
+ log["pts_anchor_to_all"] = (
1330
+ log["pts_anchor_to_all"] - torch.min(log["pts_anchor_to_all"])
1331
+ ) / torch.max(log["pts_anchor_to_all"])
1332
+ log["pts_all_to_anchor"] = (
1333
+ log["pts_all_to_anchor"] - torch.min(log["pts_all_to_anchor"])
1334
+ ) / torch.max(log["pts_all_to_anchor"])
1335
+
1336
+ if self.ray_as_image:
1337
+ log["gt_rays_o"] = batch["combined_rays"][:, :, 0:3, :, :]
1338
+ log["gt_rays_d"] = batch["combined_rays"][:, :, 3:, :, :]
1339
+ else:
1340
+ log["gt_rays_o"] = batch["combined_rays"][:, :, 0:3, :, :] / 5.0
1341
+ log["gt_rays_d"] = batch["combined_rays"][:, :, 3:, :, :] / 5.0
1342
+
1343
+ if sample:
1344
+ # get uncond embedding for classifier-free guidance sampling
1345
+ if unconditional_guidance_scale != 1.0:
1346
+ uc = self.get_unconditional_dict_for_sampling(batch, cond, x_rec)
1347
+ else:
1348
+ uc = None
1349
+
1350
+ with self.ema_scope("Plotting"):
1351
+ out = self.sample_log(
1352
+ cond=cond,
1353
+ batch_size=b,
1354
+ ddim=use_ddim,
1355
+ ddim_steps=ddim_steps,
1356
+ eta=ddim_eta,
1357
+ unconditional_guidance_scale=unconditional_guidance_scale,
1358
+ unconditional_conditioning=uc,
1359
+ mask=self.cond_mask,
1360
+ x0=z,
1361
+ with_extra_returned_data=False,
1362
+ **kwargs,
1363
+ )
1364
+ samples, z_denoise_row = out
1365
+ per_instance_decoding = False
1366
+
1367
+ if per_instance_decoding:
1368
+ x_sample_images = []
1369
+ for idx in range(b):
1370
+ sample_image = samples[idx : idx + 1, :, 0:4, :, :]
1371
+ x_sample_image = self.decode_first_stage(sample_image)
1372
+ x_sample_images.append(x_sample_image)
1373
+ x_sample_images = torch.cat(x_sample_images, dim=0)
1374
+ else:
1375
+ x_sample_images = self.decode_first_stage(samples[:, :, 0:4, :, :])
1376
+ log["sample_images"] = x_sample_images
1377
+
1378
+ if not self.disable_ray_stream:
1379
+ if self.ray_as_image:
1380
+ log["sample_rays_o"] = self.decode_first_stage(
1381
+ samples[:, :, 4:8, :, :]
1382
+ )
1383
+ log["sample_rays_d"] = self.decode_first_stage(
1384
+ samples[:, :, 8:, :, :]
1385
+ )
1386
+ else:
1387
+ log["sample_rays_o"] = samples[:, :, 4:7, :, :] / 5.0
1388
+ log["sample_rays_d"] = samples[:, :, 7:, :, :] / 5.0
1389
+
1390
+ if plot_denoise_rows:
1391
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1392
+ log["denoise_row"] = denoise_grid
1393
+
1394
+ return log
1395
+
1396
+ def get_unconditional_dict_for_sampling(self, batch, cond, x_rec, is_extra=False):
1397
+ b, t, c, h, w = x_rec.shape
1398
+ if self.use_text_cross_attention_condition:
1399
+ if self.uncond_type == "empty_seq":
1400
+ # NVComposer's cross attention layers accept multi-view images
1401
+ prompts = b * [""]
1402
+ # prompts = b * t * [""] # if is_image_batch=True
1403
+ uc_emb = self.get_learned_conditioning(prompts)
1404
+ elif self.uncond_type == "zero_embed":
1405
+ c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
1406
+ uc_emb = torch.zeros_like(c_emb)
1407
+ else:
1408
+ uc_emb = None
1409
+
1410
+ # process image condition
1411
+ if not is_extra:
1412
+ if hasattr(self, "embedder"):
1413
+ # uc_img = torch.zeros_like(x[:, :, 0, ...]) # b c h w
1414
+ uc_img = torch.zeros(
1415
+ # b c h w
1416
+ size=(b * t, c, h, w),
1417
+ dtype=x_rec.dtype,
1418
+ device=x_rec.device,
1419
+ )
1420
+ # img: b c h w >> b l c
1421
+ uc_img = self.get_image_embeds(uc_img, batch)
1422
+
1423
+ # Modified: The uc embeddings should be reshaped for valid post-processing
1424
+ uc_img = rearrange(
1425
+ uc_img, "(b t) s d -> b (t s) d", b=b, t=uc_img.shape[0] // b
1426
+ )
1427
+ if uc_emb is None:
1428
+ uc_emb = uc_img
1429
+ else:
1430
+ uc_emb = torch.cat([uc_emb, uc_img], dim=1)
1431
+ uc = {key: cond[key] for key in cond.keys()}
1432
+ uc.update({"c_crossattn": [uc_emb]})
1433
+ else:
1434
+ uc = {key: cond[key] for key in cond.keys()}
1435
+ uc.update({"combined_condition": uc["uncond_combined_condition"]})
1436
+
1437
+ return uc
1438
+
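+ # Note (added for clarity): with is_extra=False this builds the standard unconditional branch
+ # (null text embedding and zero-image embedder tokens for c_crossattn); with is_extra=True it
+ # instead swaps combined_condition for the anchor-only uncond_combined_condition. The two dicts
+ # feed unconditional_conditioning and unconditional_conditioning_extra in the DDIM sampler's
+ # two-term classifier-free guidance.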
1439
+ def p_mean_variance(
1440
+ self,
1441
+ x,
1442
+ c,
1443
+ t,
1444
+ clip_denoised: bool,
1445
+ return_x0=False,
1446
+ score_corrector=None,
1447
+ corrector_kwargs=None,
1448
+ **kwargs,
1449
+ ):
1450
+ t_in = t
1451
+ model_out = self.apply_model(x, t_in, c, **kwargs)
1452
+
1453
+ if score_corrector is not None:
1454
+ assert self.parameterization == "eps"
1455
+ model_out = score_corrector.modify_score(
1456
+ self, model_out, x, t, c, **corrector_kwargs
1457
+ )
1458
+
1459
+ if self.parameterization == "eps":
1460
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1461
+ elif self.parameterization == "x0":
1462
+ x_recon = model_out
1463
+ else:
1464
+ raise NotImplementedError()
1465
+
1466
+ if clip_denoised:
1467
+ x_recon.clamp_(-1.0, 1.0)
1468
+
1469
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
1470
+ x_start=x_recon, x_t=x, t=t
1471
+ )
1472
+
1473
+ if return_x0:
1474
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1475
+ else:
1476
+ return model_mean, posterior_variance, posterior_log_variance
1477
+
1478
+ @torch.no_grad()
1479
+ def p_sample(
1480
+ self,
1481
+ x,
1482
+ c,
1483
+ t,
1484
+ clip_denoised=False,
1485
+ repeat_noise=False,
1486
+ return_x0=False,
1487
+ temperature=1.0,
1488
+ noise_dropout=0.0,
1489
+ score_corrector=None,
1490
+ corrector_kwargs=None,
1491
+ **kwargs,
1492
+ ):
1493
+ b, *_, device = *x.shape, x.device
1494
+ outputs = self.p_mean_variance(
1495
+ x=x,
1496
+ c=c,
1497
+ t=t,
1498
+ clip_denoised=clip_denoised,
1499
+ return_x0=return_x0,
1500
+ score_corrector=score_corrector,
1501
+ corrector_kwargs=corrector_kwargs,
1502
+ **kwargs,
1503
+ )
1504
+ if return_x0:
1505
+ model_mean, _, model_log_variance, x0 = outputs
1506
+ else:
1507
+ model_mean, _, model_log_variance = outputs
1508
+
1509
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1510
+ if noise_dropout > 0.0:
1511
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1512
+
1513
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1514
+
1515
+ if return_x0:
1516
+ return (
1517
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
1518
+ x0,
1519
+ )
1520
+ else:
1521
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1522
+
1523
+ @torch.no_grad()
1524
+ def p_sample_loop(
1525
+ self,
1526
+ cond,
1527
+ shape,
1528
+ return_intermediates=False,
1529
+ x_T=None,
1530
+ verbose=True,
1531
+ callback=None,
1532
+ time_steps=None,
1533
+ mask=None,
1534
+ x0=None,
1535
+ img_callback=None,
1536
+ start_T=None,
1537
+ log_every_t=None,
1538
+ **kwargs,
1539
+ ):
1540
+ if not log_every_t:
1541
+ log_every_t = self.log_every_t
1542
+ device = self.betas.device
1543
+ b = shape[0]
1544
+ if x_T is None:
1545
+ img = torch.randn(shape, device=device)
1546
+ else:
1547
+ img = x_T
1548
+
1549
+ intermediates = [img]
1550
+ if time_steps is None:
1551
+ time_steps = self.num_time_steps
1552
+ if start_T is not None:
1553
+ time_steps = min(time_steps, start_T)
1554
+
1555
+ iterator = (
1556
+ tqdm(reversed(range(0, time_steps)), desc="Sampling t", total=time_steps)
1557
+ if verbose
1558
+ else reversed(range(0, time_steps))
1559
+ )
1560
+
1561
+ if mask is not None:
1562
+ assert x0 is not None
1563
+ # spatial size has to match
1564
+ assert x0.shape[2:3] == mask.shape[2:3]
1565
+
1566
+ for i in iterator:
1567
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1568
+ if self.shorten_cond_schedule:
1569
+ assert self.model.conditioning_key != "hybrid"
1570
+ tc = self.cond_ids[ts].to(cond.device)
1571
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1572
+
1573
+ img = self.p_sample(
1574
+ img, cond, ts, clip_denoised=self.clip_denoised, **kwargs
1575
+ )
1576
+
1577
+ if mask is not None:
1578
+ img_orig = self.q_sample(x0, ts)
1579
+ img = img_orig * mask + (1.0 - mask) * img
1580
+
1581
+ if i % log_every_t == 0 or i == time_steps - 1:
1582
+ intermediates.append(img)
1583
+ if callback:
1584
+ callback(i)
1585
+ if img_callback:
1586
+ img_callback(img, i)
1587
+
1588
+ if return_intermediates:
1589
+ return img, intermediates
1590
+ return img
1591
+
1592
+ @torch.no_grad()
1593
+ def sample(
1594
+ self,
1595
+ cond,
1596
+ batch_size=16,
1597
+ return_intermediates=False,
1598
+ x_T=None,
1599
+ verbose=True,
1600
+ time_steps=None,
1601
+ mask=None,
1602
+ x0=None,
1603
+ shape=None,
1604
+ **kwargs,
1605
+ ):
1606
+ if shape is None:
1607
+ shape = (batch_size, self.channels, self.temporal_length, *self.image_size)
1608
+ if cond is not None:
1609
+ if isinstance(cond, dict):
1610
+ cond = {
1611
+ key: (
1612
+ cond[key][:batch_size]
1613
+ if not isinstance(cond[key], list)
1614
+ else list(map(lambda x: x[:batch_size], cond[key]))
1615
+ )
1616
+ for key in cond
1617
+ }
1618
+ else:
1619
+ cond = (
1620
+ [c[:batch_size] for c in cond]
1621
+ if isinstance(cond, list)
1622
+ else cond[:batch_size]
1623
+ )
1624
+ return self.p_sample_loop(
1625
+ cond,
1626
+ shape,
1627
+ return_intermediates=return_intermediates,
1628
+ x_T=x_T,
1629
+ verbose=verbose,
1630
+ time_steps=time_steps,
1631
+ mask=mask,
1632
+ x0=x0,
1633
+ **kwargs,
1634
+ )
1635
+
1636
+ @torch.no_grad()
1637
+ def sample_log(
1638
+ self,
1639
+ cond,
1640
+ batch_size,
1641
+ ddim,
1642
+ ddim_steps,
1643
+ with_extra_returned_data=False,
1644
+ **kwargs,
1645
+ ):
1646
+ if ddim:
1647
+ ddim_sampler = DDIMSampler(self)
1648
+ shape = (self.temporal_length, self.channels, *self.image_size)
1649
+ out = ddim_sampler.sample(
1650
+ ddim_steps,
1651
+ batch_size,
1652
+ shape,
1653
+ cond,
1654
+ verbose=True,
1655
+ with_extra_returned_data=with_extra_returned_data,
1656
+ **kwargs,
1657
+ )
1658
+ if with_extra_returned_data:
1659
+ samples, intermediates, extra_returned_data = out
1660
+ return samples, intermediates, extra_returned_data
1661
+ else:
1662
+ samples, intermediates = out
1663
+ return samples, intermediates
1664
+
1665
+ else:
1666
+ samples, intermediates = self.sample(
1667
+ cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
1668
+ )
1669
+
1670
+ return samples, intermediates
1671
+
1672
+
1673
+ class DiffusionWrapper(pl.LightningModule):
1674
+ def __init__(self, diff_model_config):
1675
+ super().__init__()
1676
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1677
+
1678
+ def forward(self, x, c, **kwargs):
1679
+ return self.diffusion_model(x, c, **kwargs)
core/models/samplers/__init__.py ADDED
File without changes
core/models/samplers/ddim.py ADDED
@@ -0,0 +1,546 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import numpy as np
4
+ import torch
5
+ from einops import rearrange
6
+ from tqdm import tqdm
7
+
8
+ from core.common import noise_like
9
+ from core.models.utils_diffusion import (
10
+ make_ddim_sampling_parameters,
11
+ make_ddim_time_steps,
12
+ rescale_noise_cfg,
13
+ )
14
+
15
+
16
+ class DDIMSampler(object):
17
+ def __init__(self, model, schedule="linear", **kwargs):
18
+ super().__init__()
19
+ self.model = model
20
+ self.ddpm_num_time_steps = model.num_time_steps
21
+ self.schedule = schedule
22
+ self.counter = 0
23
+
24
+ def register_buffer(self, name, attr):
25
+ if isinstance(attr, torch.Tensor):
26
+ if attr.device != torch.device("cuda"):
27
+ attr = attr.to(torch.device("cuda"))
28
+ setattr(self, name, attr)
29
+
30
+ def make_schedule(
31
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
32
+ ):
33
+ self.ddim_time_steps = make_ddim_time_steps(
34
+ ddim_discr_method=ddim_discretize,
35
+ num_ddim_time_steps=ddim_num_steps,
36
+ num_ddpm_time_steps=self.ddpm_num_time_steps,
37
+ verbose=verbose,
38
+ )
39
+ alphas_cumprod = self.model.alphas_cumprod
40
+ assert (
41
+ alphas_cumprod.shape[0] == self.ddpm_num_time_steps
42
+ ), "alphas have to be defined for each timestep"
43
+
44
+ def to_torch(x):
45
+ return x.clone().detach().to(torch.float32).to(self.model.device)
46
+
47
+ if self.model.use_dynamic_rescale:
48
+ self.ddim_scale_arr = self.model.scale_arr[self.ddim_time_steps]
49
+ self.ddim_scale_arr_prev = torch.cat(
50
+ [self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]
51
+ )
52
+
53
+ self.register_buffer("betas", to_torch(self.model.betas))
54
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
55
+ self.register_buffer(
56
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
57
+ )
58
+
59
+ # calculations for diffusion q(x_t | x_{t-1}) and others
60
+ self.register_buffer(
61
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
62
+ )
63
+ self.register_buffer(
64
+ "sqrt_one_minus_alphas_cumprod",
65
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
66
+ )
67
+ self.register_buffer(
68
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
69
+ )
70
+ self.register_buffer(
71
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
72
+ )
73
+ self.register_buffer(
74
+ "sqrt_recipm1_alphas_cumprod",
75
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
76
+ )
77
+
78
+ # ddim sampling parameters
79
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
80
+ alphacums=alphas_cumprod.cpu(),
81
+ ddim_time_steps=self.ddim_time_steps,
82
+ eta=ddim_eta,
83
+ verbose=verbose,
84
+ )
85
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
86
+ self.register_buffer("ddim_alphas", ddim_alphas)
87
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
88
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
89
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
90
+ (1 - self.alphas_cumprod_prev)
91
+ / (1 - self.alphas_cumprod)
92
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
93
+ )
94
+ self.register_buffer(
95
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
96
+ )
97
+
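+ # Note (added for clarity): both sigma arrays above follow the standard DDIM formula
+ # sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev);
+ # eta = 0 gives the deterministic DDIM sampler, eta = 1 recovers DDPM-like stochasticity.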
98
+ @torch.no_grad()
99
+ def sample(
100
+ self,
101
+ S,
102
+ batch_size,
103
+ shape,
104
+ conditioning=None,
105
+ callback=None,
106
+ img_callback=None,
107
+ quantize_x0=False,
108
+ eta=0.0,
109
+ mask=None,
110
+ x0=None,
111
+ temperature=1.0,
112
+ noise_dropout=0.0,
113
+ score_corrector=None,
114
+ corrector_kwargs=None,
115
+ verbose=True,
116
+ schedule_verbose=False,
117
+ x_T=None,
118
+ log_every_t=100,
119
+ unconditional_guidance_scale=1.0,
120
+ unconditional_conditioning=None,
121
+ unconditional_guidance_scale_extra=1.0,
122
+ unconditional_conditioning_extra=None,
123
+ with_extra_returned_data=False,
124
+ **kwargs,
125
+ ):
126
+
127
+ # check condition bs
128
+ if conditioning is not None:
129
+ if isinstance(conditioning, dict):
130
+ try:
131
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
132
+ except Exception:
133
+ cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
134
+
135
+ if cbs != batch_size:
136
+ print(
137
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
138
+ )
139
+ else:
140
+ if conditioning.shape[0] != batch_size:
141
+ print(
142
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
143
+ )
144
+
145
+ self.skip_step = self.ddpm_num_time_steps // S
146
+ discr_method = (
147
+ "uniform_trailing" if self.model.rescale_betas_zero_snr else "uniform"
148
+ )
149
+ self.make_schedule(
150
+ ddim_num_steps=S,
151
+ ddim_discretize=discr_method,
152
+ ddim_eta=eta,
153
+ verbose=schedule_verbose,
154
+ )
155
+
156
+ # make shape
157
+ if len(shape) == 3:
158
+ C, H, W = shape
159
+ size = (batch_size, C, H, W)
160
+ elif len(shape) == 4:
161
+ T, C, H, W = shape
162
+ size = (batch_size, T, C, H, W)
163
+ else:
164
+ assert False, f"Invalid shape: {shape}."
165
+ out = self.ddim_sampling(
166
+ conditioning,
167
+ size,
168
+ callback=callback,
169
+ img_callback=img_callback,
170
+ quantize_denoised=quantize_x0,
171
+ mask=mask,
172
+ x0=x0,
173
+ ddim_use_original_steps=False,
174
+ noise_dropout=noise_dropout,
175
+ temperature=temperature,
176
+ score_corrector=score_corrector,
177
+ corrector_kwargs=corrector_kwargs,
178
+ x_T=x_T,
179
+ log_every_t=log_every_t,
180
+ unconditional_guidance_scale=unconditional_guidance_scale,
181
+ unconditional_conditioning=unconditional_conditioning,
182
+ unconditional_guidance_scale_extra=unconditional_guidance_scale_extra,
183
+ unconditional_conditioning_extra=unconditional_conditioning_extra,
184
+ verbose=verbose,
185
+ with_extra_returned_data=with_extra_returned_data,
186
+ **kwargs,
187
+ )
188
+ if with_extra_returned_data:
189
+ samples, intermediates, extra_returned_data = out
190
+ return samples, intermediates, extra_returned_data
191
+ else:
192
+ samples, intermediates = out
193
+ return samples, intermediates
194
+
195
+ @torch.no_grad()
196
+ def ddim_sampling(
197
+ self,
198
+ cond,
199
+ shape,
200
+ x_T=None,
201
+ ddim_use_original_steps=False,
202
+ callback=None,
203
+ time_steps=None,
204
+ quantize_denoised=False,
205
+ mask=None,
206
+ x0=None,
207
+ img_callback=None,
208
+ log_every_t=100,
209
+ temperature=1.0,
210
+ noise_dropout=0.0,
211
+ score_corrector=None,
212
+ corrector_kwargs=None,
213
+ unconditional_guidance_scale=1.0,
214
+ unconditional_conditioning=None,
215
+ unconditional_guidance_scale_extra=1.0,
216
+ unconditional_conditioning_extra=None,
217
+ verbose=True,
218
+ with_extra_returned_data=False,
219
+ **kwargs,
220
+ ):
221
+ device = self.model.betas.device
222
+ b = shape[0]
223
+ if x_T is None:
224
+ img = torch.randn(shape, device=device, dtype=self.model.dtype)
225
+ if self.model.bd_noise:
226
+ noise_decor = self.model.bd(img)
227
+ noise_decor = (noise_decor - noise_decor.mean()) / (
228
+ noise_decor.std() + 1e-5
229
+ )
230
+ noise_f = noise_decor[:, :, 0:1, :, :]
231
+ noise = (
232
+ np.sqrt(self.model.bd_ratio) * noise_decor[:, :, 1:]
233
+ + np.sqrt(1 - self.model.bd_ratio) * noise_f
234
+ )
235
+ img = torch.cat([noise_f, noise], dim=2)
236
+ else:
237
+ img = x_T
238
+
239
+ if time_steps is None:
240
+ time_steps = (
241
+ self.ddpm_num_time_steps
242
+ if ddim_use_original_steps
243
+ else self.ddim_time_steps
244
+ )
245
+ elif time_steps is not None and not ddim_use_original_steps:
246
+ subset_end = (
247
+ int(
248
+ min(time_steps / self.ddim_time_steps.shape[0], 1)
249
+ * self.ddim_time_steps.shape[0]
250
+ )
251
+ - 1
252
+ )
253
+ time_steps = self.ddim_time_steps[:subset_end]
254
+
255
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
256
+ time_range = (
257
+ reversed(range(0, time_steps))
258
+ if ddim_use_original_steps
259
+ else np.flip(time_steps)
260
+ )
261
+ total_steps = time_steps if ddim_use_original_steps else time_steps.shape[0]
262
+ if verbose:
263
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
264
+ else:
265
+ iterator = time_range
266
+ # Sampling Loop
267
+ for i, step in enumerate(iterator):
268
+ print(f"Sample: i={i}, step={step}.")
269
+ index = total_steps - i - 1
270
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
271
+ print("ts=", ts)
272
+ # use mask to blend noised original latent (img_orig) & new sampled latent (img)
273
+ if mask is not None:
274
+ assert x0 is not None
275
+ img_orig = x0
276
+ # keep original & modify use img
277
+ img = img_orig * mask + (1.0 - mask) * img
278
+ outs = self.p_sample_ddim(
279
+ img,
280
+ cond,
281
+ ts,
282
+ index=index,
283
+ use_original_steps=ddim_use_original_steps,
284
+ quantize_denoised=quantize_denoised,
285
+ temperature=temperature,
286
+ noise_dropout=noise_dropout,
287
+ score_corrector=score_corrector,
288
+ corrector_kwargs=corrector_kwargs,
289
+ unconditional_guidance_scale=unconditional_guidance_scale,
290
+ unconditional_conditioning=unconditional_conditioning,
291
+ unconditional_guidance_scale_extra=unconditional_guidance_scale_extra,
292
+ unconditional_conditioning_extra=unconditional_conditioning_extra,
293
+ with_extra_returned_data=with_extra_returned_data,
294
+ **kwargs,
295
+ )
296
+ if with_extra_returned_data:
297
+ img, pred_x0, extra_returned_data = outs
298
+ else:
299
+ img, pred_x0 = outs
300
+ if callback:
301
+ callback(i)
302
+ if img_callback:
303
+ img_callback(pred_x0, i)
304
+ # log_every_t = 1
305
+ if index % log_every_t == 0 or index == total_steps - 1:
306
+ intermediates["x_inter"].append(img)
307
+ intermediates["pred_x0"].append(pred_x0)
308
+ # intermediates['extra_returned_data'].append(extra_returned_data)
309
+ if with_extra_returned_data:
310
+ return img, intermediates, extra_returned_data
311
+ return img, intermediates
312
+
313
+ def batch_time_transpose(
314
+ self, batch_time_tensor, num_target_views, num_condition_views
315
+ ):
316
+ # Input: N*N; N = T+C
317
+ assert num_target_views + num_condition_views == batch_time_tensor.shape[1]
318
+ target_tensor = batch_time_tensor[:, :num_target_views, ...] # T*T
319
+ condition_tensor = batch_time_tensor[:, num_target_views:, ...] # N*C
320
+ target_tensor = target_tensor.transpose(0, 1) # T*T
321
+ return torch.concat([target_tensor, condition_tensor], dim=1)
322
+
323
+ def ddim_batch_shard_step(
324
+ self,
325
+ pred_x0_post_process_function,
326
+ pred_x0_post_process_function_kwargs,
327
+ cond,
328
+ corrector_kwargs,
329
+ ddim_use_original_steps,
330
+ device,
331
+ img,
332
+ index,
333
+ kwargs,
334
+ noise_dropout,
335
+ quantize_denoised,
336
+ score_corrector,
337
+ step,
338
+ temperature,
339
+ with_extra_returned_data,
340
+ ):
341
+ img_list = []
342
+ pred_x0_list = []
343
+ shard_step = 5
344
+ shard_start = 0
345
+ while shard_start < img.shape[0]:
346
+ shard_end = shard_start + shard_step
347
+ if shard_start >= img.shape[0]:
348
+ break
349
+ if shard_end > img.shape[0]:
350
+ shard_end = img.shape[0]
351
+ print(
352
+ f"Sampling Batch Shard: From #{shard_start} to #{shard_end}. Total: {img.shape[0]}."
353
+ )
354
+ sub_img = img[shard_start:shard_end]
355
+ sub_cond = {
356
+ "combined_condition": cond["combined_condition"][shard_start:shard_end],
357
+ "c_crossattn": [
358
+ cond["c_crossattn"][0][0:1].expand(shard_end - shard_start, -1, -1)
359
+ ],
360
+ }
361
+ ts = torch.full((sub_img.shape[0],), step, device=device, dtype=torch.long)
362
+
363
+ _img, _pred_x0 = self.p_sample_ddim(
364
+ sub_img,
365
+ sub_cond,
366
+ ts,
367
+ index=index,
368
+ use_original_steps=ddim_use_original_steps,
369
+ quantize_denoised=quantize_denoised,
370
+ temperature=temperature,
371
+ noise_dropout=noise_dropout,
372
+ score_corrector=score_corrector,
373
+ corrector_kwargs=corrector_kwargs,
374
+ unconditional_guidance_scale=1.0,
375
+ unconditional_conditioning=None,
376
+ unconditional_guidance_scale_extra=1.0,
377
+ unconditional_conditioning_extra=None,
378
+ pred_x0_post_process_function=pred_x0_post_process_function,
379
+ pred_x0_post_process_function_kwargs=pred_x0_post_process_function_kwargs,
380
+ with_extra_returned_data=with_extra_returned_data,
381
+ **kwargs,
382
+ )
383
+ img_list.append(_img)
384
+ pred_x0_list.append(_pred_x0)
385
+ shard_start += shard_step
386
+ img = torch.concat(img_list, dim=0)
387
+ pred_x0 = torch.concat(pred_x0_list, dim=0)
388
+ return img, pred_x0
389
+
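+ # Note (added for clarity): ddim_batch_shard_step splits a large sampling batch into shards of
+ # 5 views to bound memory, broadcasts the first c_crossattn entry across each shard, runs
+ # p_sample_ddim per shard without classifier-free guidance (scale 1.0), and concatenates the
+ # per-shard results back together.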
390
+ @torch.no_grad()
391
+ def p_sample_ddim(
392
+ self,
393
+ x,
394
+ c,
395
+ t,
396
+ index,
397
+ repeat_noise=False,
398
+ use_original_steps=False,
399
+ quantize_denoised=False,
400
+ temperature=1.0,
401
+ noise_dropout=0.0,
402
+ score_corrector=None,
403
+ corrector_kwargs=None,
404
+ unconditional_guidance_scale=1.0,
405
+ unconditional_conditioning=None,
406
+ unconditional_guidance_scale_extra=1.0,
407
+ unconditional_conditioning_extra=None,
408
+ with_extra_returned_data=False,
409
+ **kwargs,
410
+ ):
411
+ b, *_, device = *x.shape, x.device
412
+ if x.dim() == 5:
413
+ is_video = True
414
+ else:
415
+ is_video = False
416
+
417
+ extra_returned_data = None
418
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
419
+ e_t_cfg = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
420
+ if isinstance(e_t_cfg, tuple):
421
+ extra_returned_data = e_t_cfg[1:]
422
+ e_t_cfg = e_t_cfg[0]
423
+ else:
424
+ # with unconditional condition
425
+ if isinstance(c, torch.Tensor) or isinstance(c, dict):
426
+ e_t = self.model.apply_model(x, t, c, **kwargs)
427
+ e_t_uncond = self.model.apply_model(
428
+ x, t, unconditional_conditioning, **kwargs
429
+ )
430
+ if (
431
+ unconditional_guidance_scale_extra != 1.0
432
+ and unconditional_conditioning_extra is not None
433
+ ):
434
+ print(f"Using extra CFG: {unconditional_guidance_scale_extra}...")
435
+ e_t_uncond_extra = self.model.apply_model(
436
+ x, t, unconditional_conditioning_extra, **kwargs
437
+ )
438
+ else:
439
+ e_t_uncond_extra = None
440
+ else:
441
+ raise NotImplementedError
442
+
443
+ if isinstance(e_t, tuple):
444
+ extra_returned_data = e_t[1:]
445
+ e_t = e_t[0]
446
+
447
+ if isinstance(e_t_uncond, tuple):
448
+ e_t_uncond = e_t_uncond[0]
449
+ if isinstance(e_t_uncond_extra, tuple):
450
+ e_t_uncond_extra = e_t_uncond_extra[0]
451
+
452
+ # text cfg
453
+ if (
454
+ unconditional_guidance_scale_extra != 1.0
455
+ and unconditional_conditioning_extra is not None
456
+ ):
457
+ e_t_cfg = (
458
+ e_t_uncond
459
+ + unconditional_guidance_scale * (e_t - e_t_uncond)
460
+ + unconditional_guidance_scale_extra * (e_t - e_t_uncond_extra)
461
+ )
462
+ else:
463
+ e_t_cfg = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
464
+
465
+ if self.model.rescale_betas_zero_snr:
466
+ e_t_cfg = rescale_noise_cfg(e_t_cfg, e_t, guidance_rescale=0.7)
467
+
468
+ if self.model.parameterization == "v":
469
+ e_t = self.model.predict_eps_from_z_and_v(x, t, e_t_cfg)
470
+ else:
471
+ e_t = e_t_cfg
472
+
473
+ if score_corrector is not None:
474
+ assert self.model.parameterization == "eps", "not implemented"
475
+ e_t = score_corrector.modify_score(
476
+ self.model, e_t, x, t, c, **corrector_kwargs
477
+ )
478
+
479
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
480
+ alphas_prev = (
481
+ self.model.alphas_cumprod_prev
482
+ if use_original_steps
483
+ else self.ddim_alphas_prev
484
+ )
485
+ sqrt_one_minus_alphas = (
486
+ self.model.sqrt_one_minus_alphas_cumprod
487
+ if use_original_steps
488
+ else self.ddim_sqrt_one_minus_alphas
489
+ )
490
+ sigmas = (
491
+ self.model.ddim_sigmas_for_original_num_steps
492
+ if use_original_steps
493
+ else self.ddim_sigmas
494
+ )
495
+ # select parameters corresponding to the currently considered timestep
496
+
497
+ if is_video:
498
+ size = (b, 1, 1, 1, 1)
499
+ else:
500
+ size = (b, 1, 1, 1)
501
+ a_t = torch.full(size, alphas[index], device=device)
502
+ a_prev = torch.full(size, alphas_prev[index], device=device)
503
+ sigma_t = torch.full(size, sigmas[index], device=device)
504
+ sqrt_one_minus_at = torch.full(
505
+ size, sqrt_one_minus_alphas[index], device=device
506
+ )
507
+
508
+ # current prediction for x_0
509
+ if self.model.parameterization != "v":
510
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
511
+ else:
512
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, e_t_cfg)
513
+
514
+ if self.model.use_dynamic_rescale:
515
+ scale_t = torch.full(size, self.ddim_scale_arr[index], device=device)
516
+ prev_scale_t = torch.full(
517
+ size, self.ddim_scale_arr_prev[index], device=device
518
+ )
519
+ rescale = prev_scale_t / scale_t
520
+ pred_x0 *= rescale
521
+
522
+ if quantize_denoised:
523
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
524
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
525
+
526
+ noise = noise_like(x.shape, device, repeat_noise)
527
+ if self.model.bd_noise:
528
+ noise_decor = self.model.bd(noise)
529
+ noise_decor = (noise_decor - noise_decor.mean()) / (
530
+ noise_decor.std() + 1e-5
531
+ )
532
+ noise_f = noise_decor[:, :, 0:1, :, :]
533
+ noise = (
534
+ np.sqrt(self.model.bd_ratio) * noise_decor[:, :, 1:]
535
+ + np.sqrt(1 - self.model.bd_ratio) * noise_f
536
+ )
537
+ noise = torch.cat([noise_f, noise], dim=2)
538
+ noise = sigma_t * noise * temperature
539
+
540
+ if noise_dropout > 0.0:
541
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
542
+
543
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
544
+ if with_extra_returned_data:
545
+ return x_prev, pred_x0, extra_returned_data
546
+ return x_prev, pred_x0
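Reading aid (not part of the commit): the standalone sketch below shows the core DDIM update implemented in p_sample_ddim above for an eps-prediction model: recover x0 from the noise estimate, form the direction term, and step to the previous latent. `ddim_step_demo`, its arguments, and the toy shapes are illustrative assumptions only.

import torch

def ddim_step_demo(x, eps, a_t, a_prev, eta=0.0):
    # sigma controls stochasticity; eta = 0 gives the deterministic DDIM update.
    sigma = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
    pred_x0 = (x - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)   # current estimate of x_0
    dir_xt = torch.sqrt(1 - a_prev - sigma**2) * eps              # "direction pointing to x_t"
    noise = sigma * torch.randn_like(x)
    return torch.sqrt(a_prev) * pred_x0 + dir_xt + noise, pred_x0

x = torch.randn(1, 10, 4, 32, 32)                    # toy (B, T, C, H, W) latent
eps = torch.randn_like(x)                            # stand-in for the network's noise prediction
a_t, a_prev = torch.tensor(0.5), torch.tensor(0.6)   # alpha_cumprod at t and at the previous step
x_prev, pred_x0 = ddim_step_demo(x, eps, a_t, a_prev)
print(x_prev.shape, pred_x0.shape)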
core/models/samplers/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampler import DPMSolverSampler
core/models/samplers/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1298 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from tqdm import tqdm
5
+
6
+
7
+ class NoiseScheduleVP:
8
+ def __init__(
9
+ self,
10
+ schedule='discrete',
11
+ betas=None,
12
+ alphas_cumprod=None,
13
+ continuous_beta_0=0.1,
14
+ continuous_beta_1=20.,
15
+ ):
16
+ """Create a wrapper class for the forward SDE (VP type).
17
+
18
+ ***
19
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
20
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
21
+ ***
22
+
23
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
24
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
25
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
26
+
27
+ log_alpha_t = self.marginal_log_mean_coeff(t)
28
+ sigma_t = self.marginal_std(t)
29
+ lambda_t = self.marginal_lambda(t)
30
+
31
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
32
+
33
+ t = self.inverse_lambda(lambda_t)
34
+
35
+ ===============================================================
36
+
37
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
38
+
39
+ 1. For discrete-time DPMs:
40
+
41
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
42
+ t_i = (i + 1) / N
43
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
44
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
45
+
46
+ Args:
47
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
48
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
49
+
50
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
51
+
52
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
53
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
54
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
55
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
56
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
57
+ and
58
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
59
+
60
+
61
+ 2. For continuous-time DPMs:
62
+
63
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
64
+ schedule are the default settings in DDPM and improved-DDPM:
65
+
66
+ Args:
67
+ beta_min: A `float` number. The smallest beta for the linear schedule.
68
+ beta_max: A `float` number. The largest beta for the linear schedule.
69
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
70
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
71
+ T: A `float` number. The ending time of the forward process.
72
+
73
+ ===============================================================
74
+
75
+ Args:
76
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
77
+ 'linear' or 'cosine' for continuous-time DPMs.
78
+ Returns:
79
+ A wrapper object of the forward SDE (VP type).
80
+
81
+ ===============================================================
82
+
83
+ Example:
84
+
85
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
86
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
87
+
88
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
89
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
90
+
91
+ # For continuous-time DPMs (VPSDE), linear schedule:
92
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
93
+
94
+ """
95
+
96
+ if schedule not in ['discrete', 'linear', 'cosine']:
97
+ raise ValueError(
98
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
99
+ schedule))
100
+
101
+ self.schedule = schedule
102
+ if schedule == 'discrete':
103
+ if betas is not None:
104
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
105
+ else:
106
+ assert alphas_cumprod is not None
107
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
108
+ self.total_N = len(log_alphas)
109
+ self.T = 1.
110
+ self.t_array = torch.linspace(
111
+ 0., 1., self.total_N + 1)[1:].reshape((1, -1))
112
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
113
+ else:
114
+ self.total_N = 1000
115
+ self.beta_0 = continuous_beta_0
116
+ self.beta_1 = continuous_beta_1
117
+ self.cosine_s = 0.008
118
+ self.cosine_beta_max = 999.
119
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
120
+ 1. + self.cosine_s) / math.pi - self.cosine_s
121
+ self.cosine_log_alpha_0 = math.log(
122
+ math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
123
+ self.schedule = schedule
124
+ if schedule == 'cosine':
125
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
126
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
127
+ self.T = 0.9946
128
+ else:
129
+ self.T = 1.
130
+
131
+ def marginal_log_mean_coeff(self, t):
132
+ """
133
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
134
+ """
135
+ if self.schedule == 'discrete':
136
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
137
+ self.log_alpha_array.to(t.device)).reshape((-1))
138
+ elif self.schedule == 'linear':
139
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
140
+ elif self.schedule == 'cosine':
141
+ def log_alpha_fn(s): return torch.log(
142
+ torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
143
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
144
+ return log_alpha_t
145
+
146
+ def marginal_alpha(self, t):
147
+ """
148
+ Compute alpha_t of a given continuous-time label t in [0, T].
149
+ """
150
+ return torch.exp(self.marginal_log_mean_coeff(t))
151
+
152
+ def marginal_std(self, t):
153
+ """
154
+ Compute sigma_t of a given continuous-time label t in [0, T].
155
+ """
156
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
157
+
158
+ def marginal_lambda(self, t):
159
+ """
160
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
161
+ """
162
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
163
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
164
+ return log_mean_coeff - log_std
165
+
166
+ def inverse_lambda(self, lamb):
167
+ """
168
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
169
+ """
170
+ if self.schedule == 'linear':
171
+ tmp = 2. * (self.beta_1 - self.beta_0) * \
172
+ torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
173
+ Delta = self.beta_0 ** 2 + tmp
174
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
175
+ elif self.schedule == 'discrete':
176
+ log_alpha = -0.5 * \
177
+ torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
178
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
179
+ torch.flip(self.t_array.to(lamb.device), [1]))
180
+ return t.reshape((-1,))
181
+ else:
182
+ log_alpha = -0.5 * \
183
+ torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
184
+ def t_fn(log_alpha_t): return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
185
+ 1. + self.cosine_s) / math.pi - self.cosine_s
186
+ t = t_fn(log_alpha)
187
+ return t
188
+
189
+
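As a quick sanity check of the schedule interface documented above, a small sketch (the linear betas are placeholder values, not taken from any checkpoint in this repository):

```python
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)           # placeholder DDPM-style schedule
ns = NoiseScheduleVP('discrete', betas=betas)

t = torch.tensor([0.5])
alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
lambda_t = ns.marginal_lambda(t)

# lambda_t is the half-logSNR: log(alpha_t) - log(sigma_t).
assert torch.allclose(lambda_t, alpha_t.log() - sigma_t.log(), atol=1e-4)
# inverse_lambda approximately recovers the original time label.
print(ns.inverse_lambda(lambda_t))                  # ~0.5
```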
190
+ def model_wrapper(
191
+ model,
192
+ noise_schedule,
193
+ model_type="noise",
194
+ model_kwargs={},
195
+ guidance_type="uncond",
196
+ condition=None,
197
+ unconditional_condition=None,
198
+ guidance_scale=1.,
199
+ classifier_fn=None,
200
+ classifier_kwargs={},
201
+ ):
202
+ """Create a wrapper function for the noise prediction model.
203
+
204
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
205
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
206
+
207
+ We support four types of the diffusion model by setting `model_type`:
208
+
209
+ 1. "noise": noise prediction model. (Trained by predicting noise).
210
+
211
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
212
+
213
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
214
+ The "v" prediction derivation is detailed in Appendix D of [1], and is used in Imagen Video [2].
215
+
216
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
217
+ arXiv preprint arXiv:2202.00512 (2022).
218
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
219
+ arXiv preprint arXiv:2210.02303 (2022).
220
+
221
+ 4. "score": marginal score function. (Trained by denoising score matching).
222
+ Note that the score function and the noise prediction model follow a simple relationship:
223
+ ```
224
+ noise(x_t, t) = -sigma_t * score(x_t, t)
225
+ ```
226
+
227
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
228
+ 1. "uncond": unconditional sampling by DPMs.
229
+ The input `model` has the following format:
230
+ ``
231
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
232
+ ``
233
+
234
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
235
+ The input `model` has the following format:
236
+ ``
237
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
238
+ ``
239
+
240
+ The input `classifier_fn` has the following format:
241
+ ``
242
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
243
+ ``
244
+
245
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
246
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
247
+
248
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
249
+ The input `model` has the following format:
250
+ ``
251
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
252
+ ``
253
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
254
+
255
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
256
+ arXiv preprint arXiv:2207.12598 (2022).
257
+
258
+
259
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
260
+ or continuous-time labels (i.e. epsilon to T).
261
+
262
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
263
+ ``
264
+ def model_fn(x, t_continuous) -> noise:
265
+ t_input = get_model_input_time(t_continuous)
266
+ return noise_pred(model, x, t_input, **model_kwargs)
267
+ ``
268
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
269
+
270
+ ===============================================================
271
+
272
+ Args:
273
+ model: A diffusion model with the corresponding format described above.
274
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
275
+ model_type: A `str`. The parameterization type of the diffusion model.
276
+ "noise" or "x_start" or "v" or "score".
277
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
278
+ guidance_type: A `str`. The type of the guidance for sampling.
279
+ "uncond" or "classifier" or "classifier-free".
280
+ condition: A pytorch tensor. The condition for the guided sampling.
281
+ Only used for "classifier" or "classifier-free" guidance type.
282
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
283
+ Only used for "classifier-free" guidance type.
284
+ guidance_scale: A `float`. The scale for the guided sampling.
285
+ classifier_fn: A classifier function. Only used for the classifier guidance.
286
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
287
+ Returns:
288
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
289
+ """
290
+
291
+ def get_model_input_time(t_continuous):
292
+ """
293
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
294
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
295
+ For continuous-time DPMs, we just use `t_continuous`.
296
+ """
297
+ if noise_schedule.schedule == 'discrete':
298
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
299
+ else:
300
+ return t_continuous
301
+
302
+ def noise_pred_fn(x, t_continuous, cond=None):
303
+ if t_continuous.reshape((-1,)).shape[0] == 1:
304
+ t_continuous = t_continuous.expand((x.shape[0]))
305
+ t_input = get_model_input_time(t_continuous)
306
+ if cond is None:
307
+ output = model(x, t_input, **model_kwargs)
308
+ else:
309
+ output = model(x, t_input, cond, **model_kwargs)
310
+
311
+ if isinstance(output, tuple):
312
+ output = output[0]
313
+
314
+ if model_type == "noise":
315
+ return output
316
+ elif model_type == "x_start":
317
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
318
+ t_continuous), noise_schedule.marginal_std(t_continuous)
319
+ dims = x.dim()
320
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
321
+ elif model_type == "v":
322
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
323
+ t_continuous), noise_schedule.marginal_std(t_continuous)
324
+ dims = x.dim()
325
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
326
+ elif model_type == "score":
327
+ sigma_t = noise_schedule.marginal_std(t_continuous)
328
+ dims = x.dim()
329
+ return -expand_dims(sigma_t, dims) * output
330
+
331
+ def cond_grad_fn(x, t_input):
332
+ """
333
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
334
+ """
335
+ with torch.enable_grad():
336
+ x_in = x.detach().requires_grad_(True)
337
+ log_prob = classifier_fn(
338
+ x_in, t_input, condition, **classifier_kwargs)
339
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
340
+
341
+ def model_fn(x, t_continuous):
342
+ """
343
+ The noise prediction model function that is used for DPM-Solver.
344
+ """
345
+ if t_continuous.reshape((-1,)).shape[0] == 1:
346
+ t_continuous = t_continuous.expand((x.shape[0]))
347
+ if guidance_type == "uncond":
348
+ return noise_pred_fn(x, t_continuous)
349
+ elif guidance_type == "classifier":
350
+ assert classifier_fn is not None
351
+ t_input = get_model_input_time(t_continuous)
352
+ cond_grad = cond_grad_fn(x, t_input)
353
+ sigma_t = noise_schedule.marginal_std(t_continuous)
354
+ noise = noise_pred_fn(x, t_continuous)
355
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
356
+ elif guidance_type == "classifier-free":
357
+ if guidance_scale == 1. or unconditional_condition is None:
358
+ return noise_pred_fn(x, t_continuous, cond=condition)
359
+ else:
360
+ # x_in = torch.cat([x] * 2)
361
+ # t_in = torch.cat([t_continuous] * 2)
362
+ x_in = x
363
+ t_in = t_continuous
364
+ # c_in = torch.cat([unconditional_condition, condition])
365
+ noise = noise_pred_fn(x_in, t_in, cond=condition)
366
+ noise_uncond = noise_pred_fn(
367
+ x_in, t_in, cond=unconditional_condition)
368
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
369
+
370
+ assert model_type in ["noise", "x_start", "v"]
371
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
372
+ return model_fn
373
+
374
+
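The classifier-free branch of `model_fn` above reduces to a linear combination of the conditional and unconditional predictions. A self-contained sketch, with random tensors standing in for real model outputs:

```python
import torch

def cfg_combine(noise_cond, noise_uncond, guidance_scale):
    # guidance_scale == 1 recovers the purely conditional prediction.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

eps_cond, eps_uncond = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
eps_guided = cfg_combine(eps_cond, eps_uncond, guidance_scale=7.5)
```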
375
+ class DPM_Solver:
376
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
377
+ """Construct a DPM-Solver.
378
+
379
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
380
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
381
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
382
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
383
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
384
+
385
+ Args:
386
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
387
+ ``
388
+ def model_fn(x, t_continuous):
389
+ return noise
390
+ ``
391
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
392
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
393
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
394
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
395
+
396
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
397
+ """
398
+ self.model = model_fn
399
+ self.noise_schedule = noise_schedule
400
+ self.predict_x0 = predict_x0
401
+ self.thresholding = thresholding
402
+ self.max_val = max_val
403
+
404
+ def noise_prediction_fn(self, x, t):
405
+ """
406
+ Return the noise prediction model.
407
+ """
408
+ return self.model(x, t)
409
+
410
+ def data_prediction_fn(self, x, t):
411
+ """
412
+ Return the data prediction model (with thresholding).
413
+ """
414
+ noise = self.noise_prediction_fn(x, t)
415
+ dims = x.dim()
416
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
417
+ t), self.noise_schedule.marginal_std(t)
418
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / \
419
+ expand_dims(alpha_t, dims)
420
+ if self.thresholding:
421
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
422
+ s = torch.quantile(torch.abs(x0).reshape(
423
+ (x0.shape[0], -1)), p, dim=1)
424
+ s = expand_dims(torch.maximum(s, self.max_val *
425
+ torch.ones_like(s).to(s.device)), dims)
426
+ x0 = torch.clamp(x0, -s, s) / s
427
+ return x0
428
+
429
+ def model_fn(self, x, t):
430
+ """
431
+ Convert the model to the noise prediction model or the data prediction model.
432
+ """
433
+ if self.predict_x0:
434
+ return self.data_prediction_fn(x, t)
435
+ else:
436
+ return self.noise_prediction_fn(x, t)
437
+
438
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
439
+ """Compute the intermediate time steps for sampling.
440
+
441
+ Args:
442
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
443
+ - 'logSNR': uniform logSNR for the time steps.
444
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
445
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
446
+ t_T: A `float`. The starting time of the sampling (default is T).
447
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
448
+ N: A `int`. The total number of the spacing of the time steps.
449
+ device: A torch device.
450
+ Returns:
451
+ A pytorch tensor of the time steps, with the shape (N + 1,).
452
+ """
453
+ if skip_type == 'logSNR':
454
+ lambda_T = self.noise_schedule.marginal_lambda(
455
+ torch.tensor(t_T).to(device))
456
+ lambda_0 = self.noise_schedule.marginal_lambda(
457
+ torch.tensor(t_0).to(device))
458
+ logSNR_steps = torch.linspace(
459
+ lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
460
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
461
+ elif skip_type == 'time_uniform':
462
+ return torch.linspace(t_T, t_0, N + 1).to(device)
463
+ elif skip_type == 'time_quadratic':
464
+ t_order = 2
465
+ t = torch.linspace(t_T ** (1. / t_order), t_0 **
466
+ (1. / t_order), N + 1).pow(t_order).to(device)
467
+ return t
468
+ else:
469
+ raise ValueError(
470
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
471
+
472
+ def get_orders_and_time_steps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
473
+ """
474
+ Get the order of each step for sampling by the singlestep DPM-Solver.
475
+
476
+ We combine DPM-Solver-1, 2 and 3 to use up all the function evaluations, which is named "DPM-Solver-fast".
477
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
478
+ - If order == 1:
479
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
480
+ - If order == 2:
481
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
482
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
483
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
484
+ - If order == 3:
485
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
486
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
487
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
488
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
489
+
490
+ ============================================
491
+ Args:
492
+ order: A `int`. The max order for the solver (2 or 3).
493
+ steps: A `int`. The total number of function evaluations (NFE).
494
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
495
+ - 'logSNR': uniform logSNR for the time steps.
496
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
497
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
498
+ t_T: A `float`. The starting time of the sampling (default is T).
499
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
500
+ device: A torch device.
501
+ Returns:
502
+ orders: A list of the solver order of each step.
503
+ """
504
+ if order == 3:
505
+ K = steps // 3 + 1
506
+ if steps % 3 == 0:
507
+ orders = [3, ] * (K - 2) + [2, 1]
508
+ elif steps % 3 == 1:
509
+ orders = [3, ] * (K - 1) + [1]
510
+ else:
511
+ orders = [3, ] * (K - 1) + [2]
512
+ elif order == 2:
513
+ if steps % 2 == 0:
514
+ K = steps // 2
515
+ orders = [2, ] * K
516
+ else:
517
+ K = steps // 2 + 1
518
+ orders = [2, ] * (K - 1) + [1]
519
+ elif order == 1:
520
+ K = 1
521
+ orders = [1, ] * steps
522
+ else:
523
+ raise ValueError("'order' must be '1' or '2' or '3'.")
524
+ if skip_type == 'logSNR':
525
+ # To reproduce the results in DPM-Solver paper
526
+ time_steps_outer = self.get_time_steps(
527
+ skip_type, t_T, t_0, K, device)
528
+ else:
529
+ time_steps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
530
+ torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
531
+ return time_steps_outer, orders
532
+
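To make the budget-splitting rule in the docstring above concrete, a worked example for a hypothetical budget of steps=20 with order=3:

```python
import itertools

steps, order = 20, 3
K = steps // 3 + 1                                  # K = 7
orders = [3] * (K - 1) + [2]                        # steps % 3 == 2 branch
assert sum(orders) == steps                         # exactly 20 model evaluations

# For non-'logSNR' spacing, the outer time steps are picked at these indices
# of the (steps + 1)-point grid:
indices = [0] + list(itertools.accumulate(orders))  # [0, 3, 6, 9, 12, 15, 18, 20]
```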
533
+ def denoise_to_zero_fn(self, x, s):
534
+ """
535
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
536
+ """
537
+ return self.data_prediction_fn(x, s)
538
+
539
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
540
+ """
541
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
542
+
543
+ Args:
544
+ x: A pytorch tensor. The initial value at time `s`.
545
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
546
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
547
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
548
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
549
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
550
+ Returns:
551
+ x_t: A pytorch tensor. The approximated solution at time `t`.
552
+ """
553
+ ns = self.noise_schedule
554
+ dims = x.dim()
555
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
556
+ h = lambda_t - lambda_s
557
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
558
+ s), ns.marginal_log_mean_coeff(t)
559
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
560
+ alpha_t = torch.exp(log_alpha_t)
561
+
562
+ if self.predict_x0:
563
+ phi_1 = torch.expm1(-h)
564
+ if model_s is None:
565
+ model_s = self.model_fn(x, s)
566
+ x_t = (
567
+ expand_dims(sigma_t / sigma_s, dims) * x
568
+ - expand_dims(alpha_t * phi_1, dims) * model_s
569
+ )
570
+ if return_intermediate:
571
+ return x_t, {'model_s': model_s}
572
+ else:
573
+ return x_t
574
+ else:
575
+ phi_1 = torch.expm1(h)
576
+ if model_s is None:
577
+ model_s = self.model_fn(x, s)
578
+ x_t = (
579
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
580
+ - expand_dims(sigma_t * phi_1, dims) * model_s
581
+ )
582
+ if return_intermediate:
583
+ return x_t, {'model_s': model_s}
584
+ else:
585
+ return x_t
586
+
587
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
588
+ solver_type='dpm_solver'):
589
+ """
590
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
591
+
592
+ Args:
593
+ x: A pytorch tensor. The initial value at time `s`.
594
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
595
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
596
+ r1: A `float`. The hyperparameter of the second-order solver.
597
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
598
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
599
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
600
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
601
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
602
+ Returns:
603
+ x_t: A pytorch tensor. The approximated solution at time `t`.
604
+ """
605
+ if solver_type not in ['dpm_solver', 'taylor']:
606
+ raise ValueError(
607
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
608
+ if r1 is None:
609
+ r1 = 0.5
610
+ ns = self.noise_schedule
611
+ dims = x.dim()
612
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
613
+ h = lambda_t - lambda_s
614
+ lambda_s1 = lambda_s + r1 * h
615
+ s1 = ns.inverse_lambda(lambda_s1)
616
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
617
+ s1), ns.marginal_log_mean_coeff(t)
618
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(
619
+ s), ns.marginal_std(s1), ns.marginal_std(t)
620
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
621
+
622
+ if self.predict_x0:
623
+ phi_11 = torch.expm1(-r1 * h)
624
+ phi_1 = torch.expm1(-h)
625
+
626
+ if model_s is None:
627
+ model_s = self.model_fn(x, s)
628
+ x_s1 = (
629
+ expand_dims(sigma_s1 / sigma_s, dims) * x
630
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
631
+ )
632
+ model_s1 = self.model_fn(x_s1, s1)
633
+ if solver_type == 'dpm_solver':
634
+ x_t = (
635
+ expand_dims(sigma_t / sigma_s, dims) * x
636
+ - expand_dims(alpha_t * phi_1, dims) * model_s
637
+ - (0.5 / r1) * expand_dims(alpha_t *
638
+ phi_1, dims) * (model_s1 - model_s)
639
+ )
640
+ elif solver_type == 'taylor':
641
+ x_t = (
642
+ expand_dims(sigma_t / sigma_s, dims) * x
643
+ - expand_dims(alpha_t * phi_1, dims) * model_s
644
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
645
+ model_s1 - model_s)
646
+ )
647
+ else:
648
+ phi_11 = torch.expm1(r1 * h)
649
+ phi_1 = torch.expm1(h)
650
+
651
+ if model_s is None:
652
+ model_s = self.model_fn(x, s)
653
+ x_s1 = (
654
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
655
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
656
+ )
657
+ model_s1 = self.model_fn(x_s1, s1)
658
+ if solver_type == 'dpm_solver':
659
+ x_t = (
660
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
661
+ - expand_dims(sigma_t * phi_1, dims) * model_s
662
+ - (0.5 / r1) * expand_dims(sigma_t *
663
+ phi_1, dims) * (model_s1 - model_s)
664
+ )
665
+ elif solver_type == 'taylor':
666
+ x_t = (
667
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
668
+ - expand_dims(sigma_t * phi_1, dims) * model_s
669
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) -
670
+ 1.) / h - 1.), dims) * (model_s1 - model_s)
671
+ )
672
+ if return_intermediate:
673
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
674
+ else:
675
+ return x_t
676
+
677
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
678
+ return_intermediate=False, solver_type='dpm_solver'):
679
+ """
680
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
681
+
682
+ Args:
683
+ x: A pytorch tensor. The initial value at time `s`.
684
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
685
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
686
+ r1: A `float`. The hyperparameter of the third-order solver.
687
+ r2: A `float`. The hyperparameter of the third-order solver.
688
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
689
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
690
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
691
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
692
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
693
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
694
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
695
+ Returns:
696
+ x_t: A pytorch tensor. The approximated solution at time `t`.
697
+ """
698
+ if solver_type not in ['dpm_solver', 'taylor']:
699
+ raise ValueError(
700
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
701
+ if r1 is None:
702
+ r1 = 1. / 3.
703
+ if r2 is None:
704
+ r2 = 2. / 3.
705
+ ns = self.noise_schedule
706
+ dims = x.dim()
707
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
708
+ h = lambda_t - lambda_s
709
+ lambda_s1 = lambda_s + r1 * h
710
+ lambda_s2 = lambda_s + r2 * h
711
+ s1 = ns.inverse_lambda(lambda_s1)
712
+ s2 = ns.inverse_lambda(lambda_s2)
713
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
714
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
715
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
716
+ s2), ns.marginal_std(t)
717
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(
718
+ log_alpha_s2), torch.exp(log_alpha_t)
719
+
720
+ if self.predict_x0:
721
+ phi_11 = torch.expm1(-r1 * h)
722
+ phi_12 = torch.expm1(-r2 * h)
723
+ phi_1 = torch.expm1(-h)
724
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
725
+ phi_2 = phi_1 / h + 1.
726
+ phi_3 = phi_2 / h - 0.5
727
+
728
+ if model_s is None:
729
+ model_s = self.model_fn(x, s)
730
+ if model_s1 is None:
731
+ x_s1 = (
732
+ expand_dims(sigma_s1 / sigma_s, dims) * x
733
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
734
+ )
735
+ model_s1 = self.model_fn(x_s1, s1)
736
+ x_s2 = (
737
+ expand_dims(sigma_s2 / sigma_s, dims) * x
738
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
739
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22,
740
+ dims) * (model_s1 - model_s)
741
+ )
742
+ model_s2 = self.model_fn(x_s2, s2)
743
+ if solver_type == 'dpm_solver':
744
+ x_t = (
745
+ expand_dims(sigma_t / sigma_s, dims) * x
746
+ - expand_dims(alpha_t * phi_1, dims) * model_s
747
+ + (1. / r2) * expand_dims(alpha_t *
748
+ phi_2, dims) * (model_s2 - model_s)
749
+ )
750
+ elif solver_type == 'taylor':
751
+ D1_0 = (1. / r1) * (model_s1 - model_s)
752
+ D1_1 = (1. / r2) * (model_s2 - model_s)
753
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
754
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
755
+ x_t = (
756
+ expand_dims(sigma_t / sigma_s, dims) * x
757
+ - expand_dims(alpha_t * phi_1, dims) * model_s
758
+ + expand_dims(alpha_t * phi_2, dims) * D1
759
+ - expand_dims(alpha_t * phi_3, dims) * D2
760
+ )
761
+ else:
762
+ phi_11 = torch.expm1(r1 * h)
763
+ phi_12 = torch.expm1(r2 * h)
764
+ phi_1 = torch.expm1(h)
765
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
766
+ phi_2 = phi_1 / h - 1.
767
+ phi_3 = phi_2 / h - 0.5
768
+
769
+ if model_s is None:
770
+ model_s = self.model_fn(x, s)
771
+ if model_s1 is None:
772
+ x_s1 = (
773
+ expand_dims(
774
+ torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
775
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
776
+ )
777
+ model_s1 = self.model_fn(x_s1, s1)
778
+ x_s2 = (
779
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
780
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
781
+ - r2 / r1 * expand_dims(sigma_s2 *
782
+ phi_22, dims) * (model_s1 - model_s)
783
+ )
784
+ model_s2 = self.model_fn(x_s2, s2)
785
+ if solver_type == 'dpm_solver':
786
+ x_t = (
787
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
788
+ - expand_dims(sigma_t * phi_1, dims) * model_s
789
+ - (1. / r2) * expand_dims(sigma_t *
790
+ phi_2, dims) * (model_s2 - model_s)
791
+ )
792
+ elif solver_type == 'taylor':
793
+ D1_0 = (1. / r1) * (model_s1 - model_s)
794
+ D1_1 = (1. / r2) * (model_s2 - model_s)
795
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
796
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
797
+ x_t = (
798
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
799
+ - expand_dims(sigma_t * phi_1, dims) * model_s
800
+ - expand_dims(sigma_t * phi_2, dims) * D1
801
+ - expand_dims(sigma_t * phi_3, dims) * D2
802
+ )
803
+
804
+ if return_intermediate:
805
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
806
+ else:
807
+ return x_t
808
+
809
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
810
+ """
811
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
812
+
813
+ Args:
814
+ x: A pytorch tensor. The initial value at time `s`.
815
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
816
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
817
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
818
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
819
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
820
+ Returns:
821
+ x_t: A pytorch tensor. The approximated solution at time `t`.
822
+ """
823
+ if solver_type not in ['dpm_solver', 'taylor']:
824
+ raise ValueError(
825
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
826
+ ns = self.noise_schedule
827
+ dims = x.dim()
828
+ model_prev_1, model_prev_0 = model_prev_list
829
+ t_prev_1, t_prev_0 = t_prev_list
830
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
831
+ t_prev_0), ns.marginal_lambda(t)
832
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
833
+ t_prev_0), ns.marginal_log_mean_coeff(t)
834
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
835
+ alpha_t = torch.exp(log_alpha_t)
836
+
837
+ h_0 = lambda_prev_0 - lambda_prev_1
838
+ h = lambda_t - lambda_prev_0
839
+ r0 = h_0 / h
840
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
841
+ if self.predict_x0:
842
+ if solver_type == 'dpm_solver':
843
+ x_t = (
844
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
845
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.),
846
+ dims) * model_prev_0
847
+ - 0.5 * expand_dims(alpha_t *
848
+ (torch.exp(-h) - 1.), dims) * D1_0
849
+ )
850
+ elif solver_type == 'taylor':
851
+ x_t = (
852
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
853
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.),
854
+ dims) * model_prev_0
855
+ + expand_dims(alpha_t *
856
+ ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
857
+ )
858
+ else:
859
+ if solver_type == 'dpm_solver':
860
+ x_t = (
861
+ expand_dims(
862
+ torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
863
+ - expand_dims(sigma_t * (torch.exp(h) - 1.),
864
+ dims) * model_prev_0
865
+ - 0.5 * expand_dims(sigma_t *
866
+ (torch.exp(h) - 1.), dims) * D1_0
867
+ )
868
+ elif solver_type == 'taylor':
869
+ x_t = (
870
+ expand_dims(
871
+ torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
872
+ - expand_dims(sigma_t * (torch.exp(h) - 1.),
873
+ dims) * model_prev_0
874
+ - expand_dims(sigma_t * ((torch.exp(h) -
875
+ 1.) / h - 1.), dims) * D1_0
876
+ )
877
+ return x_t
878
+
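Stripped of the `expand_dims` plumbing, the data-prediction ('dpm_solver' type) branch above performs the following update; a sketch assuming all schedule quantities are precomputed tensors:

```python
import torch

def multistep_second_order_step(x, x0_prev_0, x0_prev_1,
                                lambda_prev_1, lambda_prev_0, lambda_t,
                                sigma_prev_0, sigma_t, alpha_t):
    h = lambda_t - lambda_prev_0
    h_0 = lambda_prev_0 - lambda_prev_1
    r0 = h_0 / h
    D1_0 = (x0_prev_0 - x0_prev_1) / r0             # first-order difference of x0 predictions
    phi_1 = torch.expm1(-h)                         # exp(-h) - 1
    return ((sigma_t / sigma_prev_0) * x
            - alpha_t * phi_1 * x0_prev_0
            - 0.5 * alpha_t * phi_1 * D1_0)
```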
879
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
880
+ """
881
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
882
+
883
+ Args:
884
+ x: A pytorch tensor. The initial value at time `s`.
885
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
886
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
887
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
888
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
889
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
890
+ Returns:
891
+ x_t: A pytorch tensor. The approximated solution at time `t`.
892
+ """
893
+ ns = self.noise_schedule
894
+ dims = x.dim()
895
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
896
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
897
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
898
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
899
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
900
+ t_prev_0), ns.marginal_log_mean_coeff(t)
901
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
902
+ alpha_t = torch.exp(log_alpha_t)
903
+
904
+ h_1 = lambda_prev_1 - lambda_prev_2
905
+ h_0 = lambda_prev_0 - lambda_prev_1
906
+ h = lambda_t - lambda_prev_0
907
+ r0, r1 = h_0 / h, h_1 / h
908
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
909
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
910
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
911
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
912
+ if self.predict_x0:
913
+ x_t = (
914
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
915
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.),
916
+ dims) * model_prev_0
917
+ + expand_dims(alpha_t *
918
+ ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
919
+ - expand_dims(alpha_t * ((torch.exp(-h) -
920
+ 1. + h) / h ** 2 - 0.5), dims) * D2
921
+ )
922
+ else:
923
+ x_t = (
924
+ expand_dims(
925
+ torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
926
+ - expand_dims(sigma_t * (torch.exp(h) - 1.),
927
+ dims) * model_prev_0
928
+ - expand_dims(sigma_t *
929
+ ((torch.exp(h) - 1.) / h - 1.), dims) * D1
930
+ - expand_dims(sigma_t * ((torch.exp(h) -
931
+ 1. - h) / h ** 2 - 0.5), dims) * D2
932
+ )
933
+ return x_t
934
+
935
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
936
+ r2=None):
937
+ """
938
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
939
+
940
+ Args:
941
+ x: A pytorch tensor. The initial value at time `s`.
942
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
943
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
944
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
945
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
946
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
947
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
948
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
949
+ r2: A `float`. The hyperparameter of the third-order solver.
950
+ Returns:
951
+ x_t: A pytorch tensor. The approximated solution at time `t`.
952
+ """
953
+ if order == 1:
954
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
955
+ elif order == 2:
956
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
957
+ solver_type=solver_type, r1=r1)
958
+ elif order == 3:
959
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
960
+ solver_type=solver_type, r1=r1, r2=r2)
961
+ else:
962
+ raise ValueError(
963
+ "Solver order must be 1 or 2 or 3, got {}".format(order))
964
+
965
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
966
+ """
967
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
968
+
969
+ Args:
970
+ x: A pytorch tensor. The initial value at time `s`.
971
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
972
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
973
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
974
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
975
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
976
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
977
+ Returns:
978
+ x_t: A pytorch tensor. The approximated solution at time `t`.
979
+ """
980
+ if order == 1:
981
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
982
+ elif order == 2:
983
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
984
+ elif order == 3:
985
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
986
+ else:
987
+ raise ValueError(
988
+ "Solver order must be 1 or 2 or 3, got {}".format(order))
989
+
990
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
991
+ solver_type='dpm_solver'):
992
+ """
993
+ The adaptive step size solver based on singlestep DPM-Solver.
994
+
995
+ Args:
996
+ x: A pytorch tensor. The initial value at time `t_T`.
997
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
998
+ t_T: A `float`. The starting time of the sampling (default is T).
999
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
1000
+ h_init: A `float`. The initial step size (for logSNR).
1001
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
1002
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
1003
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
1004
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
1005
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
1006
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
1007
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
1008
+ Returns:
1009
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
1010
+
1011
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
1012
+ """
1013
+ ns = self.noise_schedule
1014
+ s = t_T * torch.ones((x.shape[0],)).to(x)
1015
+ lambda_s = ns.marginal_lambda(s)
1016
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
1017
+ h = h_init * torch.ones_like(s).to(x)
1018
+ x_prev = x
1019
+ nfe = 0
1020
+ if order == 2:
1021
+ r1 = 0.5
1022
+
1023
+ def lower_update(x, s, t): return self.dpm_solver_first_update(
1024
+ x, s, t, return_intermediate=True)
1025
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
1026
+ solver_type=solver_type,
1027
+ **kwargs)
1028
+ elif order == 3:
1029
+ r1, r2 = 1. / 3., 2. / 3.
1030
+
1031
+ def lower_update(x, s, t): return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
1032
+ return_intermediate=True,
1033
+ solver_type=solver_type)
1034
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
1035
+ solver_type=solver_type,
1036
+ **kwargs)
1037
+ else:
1038
+ raise ValueError(
1039
+ "For adaptive step size solver, order must be 2 or 3, got {}".format(order))
1040
+ while torch.abs((s - t_0)).mean() > t_err:
1041
+ t = ns.inverse_lambda(lambda_s + h)
1042
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1043
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1044
+ delta = torch.max(torch.ones_like(x).to(
1045
+ x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1046
+
1047
+ def norm_fn(v): return torch.sqrt(torch.square(
1048
+ v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1049
+ E = norm_fn((x_higher - x_lower) / delta).max()
1050
+ if torch.all(E <= 1.):
1051
+ x = x_higher
1052
+ s = t
1053
+ x_prev = x_lower
1054
+ lambda_s = ns.marginal_lambda(s)
1055
+ h = torch.min(theta * h * torch.float_power(E, -
1056
+ 1. / order).float(), lambda_0 - lambda_s)
1057
+ nfe += order
1058
+ print('adaptive solver nfe', nfe)
1059
+ return x
1060
+
1061
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
1062
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
1063
+ atol=0.0078, rtol=0.05,
1064
+ ):
1065
+ """
1066
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1067
+
1068
+ =====================================================
1069
+
1070
+ We support the following algorithms for both noise prediction model and data prediction model:
1071
+ - 'singlestep':
1072
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1073
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1074
+ The total number of function evaluations (NFE) == `steps`.
1075
+ Given a fixed NFE == `steps`, the sampling procedure is:
1076
+ - If `order` == 1:
1077
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1078
+ - If `order` == 2:
1079
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1080
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1081
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1082
+ - If `order` == 3:
1083
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1084
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1085
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1086
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1087
+ - 'multistep':
1088
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1089
+ We initialize the first `order` values by lower order multistep solvers.
1090
+ Given a fixed NFE == `steps`, the sampling procedure is:
1091
+ Denote K = steps.
1092
+ - If `order` == 1:
1093
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1094
+ - If `order` == 2:
1095
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1096
+ - If `order` == 3:
1097
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1098
+ - 'singlestep_fixed':
1099
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1100
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1101
+ - 'adaptive':
1102
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1103
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1104
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1105
+ (NFE) and the sample quality.
1106
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1107
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1108
+
1109
+ =====================================================
1110
+
1111
+ Some advice on choosing the algorithm:
1112
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1113
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
1114
+ e.g.
1115
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
1116
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1117
+ skip_type='time_uniform', method='singlestep')
1118
+ - For **guided sampling with large guidance scale** by DPMs:
1119
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
1120
+ e.g.
1121
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
1122
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1123
+ skip_type='time_uniform', method='multistep')
1124
+
1125
+ We support three types of `skip_type`:
1126
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
1127
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
1128
+ - 'time_quadratic': quadratic time for the time steps.
1129
+
1130
+ =====================================================
1131
+ Args:
1132
+ x: A pytorch tensor. The initial value at time `t_start`
1133
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1134
+ steps: A `int`. The total number of function evaluations (NFE).
1135
+ t_start: A `float`. The starting time of the sampling.
1136
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1137
+ t_end: A `float`. The ending time of the sampling.
1138
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1139
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1140
+ For discrete-time DPMs:
1141
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1142
+ For continuous-time DPMs:
1143
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1144
+ order: A `int`. The order of DPM-Solver.
1145
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1146
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1147
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1148
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1149
+
1150
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1151
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
1152
+ sampling diffusion models via diffusion SDEs for low-resolution images
1153
+ (such as CIFAR-10). However, we observed that this trick does not matter for
1154
+ high-resolution images. As it needs an additional NFE, we do not recommend
1155
+ it for high-resolution images.
1156
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1157
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1158
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1159
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1160
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1161
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1162
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1163
+ Returns:
1164
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1165
+
1166
+ """
1167
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1168
+ t_T = self.noise_schedule.T if t_start is None else t_start
1169
+ device = x.device
1170
+ if method == 'adaptive':
1171
+ with torch.no_grad():
1172
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1173
+ solver_type=solver_type)
1174
+ elif method == 'multistep':
1175
+ assert steps >= order
1176
+ time_steps = self.get_time_steps(
1177
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1178
+ assert time_steps.shape[0] - 1 == steps
1179
+ with torch.no_grad():
1180
+ vec_t = time_steps[0].expand((x.shape[0]))
1181
+ model_prev_list = [self.model_fn(x, vec_t)]
1182
+ t_prev_list = [vec_t]
1183
+ # Init the first `order` values by lower order multistep DPM-Solver.
1184
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1185
+ vec_t = time_steps[init_order].expand(x.shape[0])
1186
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1187
+ solver_type=solver_type)
1188
+ model_prev_list.append(self.model_fn(x, vec_t))
1189
+ t_prev_list.append(vec_t)
1190
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1191
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1192
+ vec_t = time_steps[step].expand(x.shape[0])
1193
+ if lower_order_final and steps < 15:
1194
+ step_order = min(order, steps + 1 - step)
1195
+ else:
1196
+ step_order = order
1197
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1198
+ solver_type=solver_type)
1199
+ for i in range(order - 1):
1200
+ t_prev_list[i] = t_prev_list[i + 1]
1201
+ model_prev_list[i] = model_prev_list[i + 1]
1202
+ t_prev_list[-1] = vec_t
1203
+ # We do not need to evaluate the final model value.
1204
+ if step < steps:
1205
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1206
+ elif method in ['singlestep', 'singlestep_fixed']:
1207
+ if method == 'singlestep':
1208
+ time_steps_outer, orders = self.get_orders_and_time_steps_for_singlestep_solver(steps=steps, order=order,
1209
+ skip_type=skip_type,
1210
+ t_T=t_T, t_0=t_0,
1211
+ device=device)
1212
+ elif method == 'singlestep_fixed':
1213
+ K = steps // order
1214
+ orders = [order, ] * K
1215
+ time_steps_outer = self.get_time_steps(
1216
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1217
+ for i, order in enumerate(orders):
1218
+ t_T_inner, t_0_inner = time_steps_outer[i], time_steps_outer[i + 1]
1219
+ time_steps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1220
+ N=order, device=device)
1221
+ lambda_inner = self.noise_schedule.marginal_lambda(
1222
+ time_steps_inner)
1223
+ vec_s, vec_t = t_T_inner.tile(
1224
+ x.shape[0]), t_0_inner.tile(x.shape[0])
1225
+ h = lambda_inner[-1] - lambda_inner[0]
1226
+ r1 = None if order <= 1 else (
1227
+ lambda_inner[1] - lambda_inner[0]) / h
1228
+ r2 = None if order <= 2 else (
1229
+ lambda_inner[2] - lambda_inner[0]) / h
1230
+ x = self.singlestep_dpm_solver_update(
1231
+ x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1232
+ if denoise_to_zero:
1233
+ x = self.denoise_to_zero_fn(
1234
+ x, torch.ones((x.shape[0],)).to(device) * t_0)
1235
+ return x
1236
+
1237
+
1238
+ def interpolate_fn(x, xp, yp):
1239
+ """
1240
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1241
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1242
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
1243
+
1244
+ Args:
1245
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1246
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1247
+ yp: PyTorch tensor with shape [C, K].
1248
+ Returns:
1249
+ The function values f(x), with shape [N, C].
1250
+ """
1251
+ N, K = x.shape[0], xp.shape[1]
1252
+ all_x = torch.cat(
1253
+ [x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1254
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1255
+ x_idx = torch.argmin(x_indices, dim=2)
1256
+ cand_start_idx = x_idx - 1
1257
+ start_idx = torch.where(
1258
+ torch.eq(x_idx, 0),
1259
+ torch.tensor(1, device=x.device),
1260
+ torch.where(
1261
+ torch.eq(x_idx, K), torch.tensor(
1262
+ K - 2, device=x.device), cand_start_idx,
1263
+ ),
1264
+ )
1265
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx),
1266
+ start_idx + 2, start_idx + 1)
1267
+ start_x = torch.gather(sorted_all_x, dim=2,
1268
+ index=start_idx.unsqueeze(2)).squeeze(2)
1269
+ end_x = torch.gather(sorted_all_x, dim=2,
1270
+ index=end_idx.unsqueeze(2)).squeeze(2)
1271
+ start_idx2 = torch.where(
1272
+ torch.eq(x_idx, 0),
1273
+ torch.tensor(0, device=x.device),
1274
+ torch.where(
1275
+ torch.eq(x_idx, K), torch.tensor(
1276
+ K - 2, device=x.device), cand_start_idx,
1277
+ ),
1278
+ )
1279
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1280
+ start_y = torch.gather(y_positions_expanded, dim=2,
1281
+ index=start_idx2.unsqueeze(2)).squeeze(2)
1282
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(
1283
+ start_idx2 + 1).unsqueeze(2)).squeeze(2)
1284
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1285
+ return cand
1286
+
1287
+
1288
+ def expand_dims(v, dims):
1289
+ """
1290
+ Expand the tensor `v` to the dim `dims`.
1291
+
1292
+ Args:
1293
+ `v`: a PyTorch tensor with shape [N].
1294
+ `dims`: an `int`.
1295
+ Returns:
1296
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1297
+ """
1298
+ return v[(...,) + (None,) * (dims - 1)]
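The two helpers above do the solver's numerical plumbing: `interpolate_fn` is the differentiable piecewise-linear lookup that maps between the discrete `alphas_cumprod` table and continuous time, and `expand_dims` broadcasts per-sample scalars over image-shaped tensors. A minimal sketch of their behaviour, assuming both functions defined above are in scope; the toy keypoints are illustrative only, not values used by the repository:

    import torch

    # expand_dims: [N] -> [N, 1, 1, 1], so per-sample coefficients broadcast over (C, H, W).
    v = torch.tensor([0.1, 0.2])
    assert expand_dims(v, dims=4).shape == (2, 1, 1, 1)

    # interpolate_fn: batched piecewise-linear interpolation through keypoints (xp, yp).
    # xp and yp have shape [C, K]; queries x have shape [N, C].
    xp = torch.linspace(0.0, 1.0, steps=5).reshape(1, -1)   # [1, 5] keypoint x-positions
    yp = xp ** 2                                            # [1, 5] keypoint y-values
    x = torch.tensor([[0.25], [0.90]])                      # [2, 1] query points
    y = interpolate_fn(x, xp, yp)                           # [2, 1], approximately x ** 2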
core/models/samplers/dpm_solver/sampler.py ADDED
@@ -0,0 +1,91 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6
+
7
+
8
+ MODEL_TYPES = {"eps": "noise", "v": "v"}
9
+
10
+
11
+ class DPMSolverSampler(object):
12
+ def __init__(self, model, **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+
16
+ def to_torch(x):
17
+ return x.clone().detach().to(torch.float32).to(model.device)
18
+
19
+ self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
20
+
21
+ def register_buffer(self, name, attr):
22
+ if type(attr) == torch.Tensor:
23
+ if attr.device != torch.device("cuda"):
24
+ attr = attr.to(torch.device("cuda"))
25
+ setattr(self, name, attr)
26
+
27
+ @torch.no_grad()
28
+ def sample(
29
+ self,
30
+ S,
31
+ batch_size,
32
+ shape,
33
+ conditioning=None,
34
+ x_T=None,
35
+ unconditional_guidance_scale=1.0,
36
+ unconditional_conditioning=None,
37
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
38
+ **kwargs,
39
+ ):
40
+ if conditioning is not None:
41
+ if isinstance(conditioning, dict):
42
+ try:
43
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
44
+ except Exception:  # the first conditioning entry may be a list of tensors
45
+ cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
46
+
47
+ if cbs != batch_size:
48
+ print(
49
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
50
+ )
51
+ else:
52
+ if conditioning.shape[0] != batch_size:
53
+ print(
54
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
55
+ )
56
+
57
+ # sampling
58
+ T, C, H, W = shape
59
+ size = (batch_size, T, C, H, W)
60
+
61
+ print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}")
62
+
63
+ device = self.model.betas.device
64
+ if x_T is None:
65
+ img = torch.randn(size, device=device)
66
+ else:
67
+ img = x_T
68
+
69
+ ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
70
+
71
+ model_fn = model_wrapper(
72
+ lambda x, t, c: self.model.apply_model(x, t, c),
73
+ ns,
74
+ model_type=MODEL_TYPES[self.model.parameterization],
75
+ guidance_type="classifier-free",
76
+ condition=conditioning,
77
+ unconditional_condition=unconditional_conditioning,
78
+ guidance_scale=unconditional_guidance_scale,
79
+ )
80
+
81
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
82
+ x = dpm_solver.sample(
83
+ img,
84
+ steps=S,
85
+ skip_type="time_uniform",
86
+ method="multistep",
87
+ order=2,
88
+ lower_order_final=True,
89
+ )
90
+
91
+ return x.to(device), None
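For orientation, a hedged sketch of how this wrapper is typically driven. Here `model`, `cond` and `uncond` are placeholders for a latent-diffusion model (exposing `alphas_cumprod`, `betas`, `parameterization` and `apply_model`) and its conditioning tensors; none of them are defined in this file:

    # Illustrative only: `model`, `cond` and `uncond` are assumed to exist elsewhere.
    sampler = DPMSolverSampler(model)
    samples, _ = sampler.sample(
        S=25,                       # number of DPM-Solver steps
        batch_size=2,
        shape=(16, 4, 32, 32),      # (T, C, H, W) in latent space
        conditioning=cond,
        unconditional_conditioning=uncond,
        unconditional_guidance_scale=7.5,
    )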
core/models/samplers/plms.py ADDED
@@ -0,0 +1,358 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ from core.models.utils_diffusion import (
8
+ make_ddim_sampling_parameters,
9
+ make_ddim_time_steps,
10
+ )
11
+ from core.common import noise_like
12
+
13
+
14
+ class PLMSSampler(object):
15
+ def __init__(self, model, schedule="linear", **kwargs):
16
+ super().__init__()
17
+ self.model = model
18
+ self.ddpm_num_time_steps = model.num_time_steps
19
+ self.schedule = schedule
20
+
21
+ def register_buffer(self, name, attr):
22
+ if type(attr) == torch.Tensor:
23
+ if attr.device != torch.device("cuda"):
24
+ attr = attr.to(torch.device("cuda"))
25
+ setattr(self, name, attr)
26
+
27
+ def make_schedule(
28
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
29
+ ):
30
+ if ddim_eta != 0:
31
+ raise ValueError("ddim_eta must be 0 for PLMS")
32
+ self.ddim_time_steps = make_ddim_time_steps(
33
+ ddim_discr_method=ddim_discretize,
34
+ num_ddim_time_steps=ddim_num_steps,
35
+ num_ddpm_time_steps=self.ddpm_num_time_steps,
36
+ verbose=verbose,
37
+ )
38
+ alphas_cumprod = self.model.alphas_cumprod
39
+ assert (
40
+ alphas_cumprod.shape[0] == self.ddpm_num_time_steps
41
+ ), "alphas have to be defined for each timestep"
42
+
43
+ def to_torch(x):
44
+ return x.clone().detach().to(torch.float32).to(self.model.device)
45
+
46
+ self.register_buffer("betas", to_torch(self.model.betas))
47
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
48
+ self.register_buffer(
49
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
50
+ )
51
+
52
+ # calculations for diffusion q(x_t | x_{t-1}) and others
53
+ self.register_buffer(
54
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
55
+ )
56
+ self.register_buffer(
57
+ "sqrt_one_minus_alphas_cumprod",
58
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
59
+ )
60
+ self.register_buffer(
61
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
62
+ )
63
+ self.register_buffer(
64
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
65
+ )
66
+ self.register_buffer(
67
+ "sqrt_recipm1_alphas_cumprod",
68
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
69
+ )
70
+
71
+ # ddim sampling parameters
72
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
73
+ alphacums=alphas_cumprod.cpu(),
74
+ ddim_time_steps=self.ddim_time_steps,
75
+ eta=ddim_eta,
76
+ verbose=verbose,
77
+ )
78
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
79
+ self.register_buffer("ddim_alphas", ddim_alphas)
80
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
81
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
82
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
83
+ (1 - self.alphas_cumprod_prev)
84
+ / (1 - self.alphas_cumprod)
85
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
86
+ )
87
+ self.register_buffer(
88
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
89
+ )
90
+
91
+ @torch.no_grad()
92
+ def sample(
93
+ self,
94
+ S,
95
+ batch_size,
96
+ shape,
97
+ conditioning=None,
98
+ callback=None,
99
+ normals_sequence=None,
100
+ img_callback=None,
101
+ quantize_x0=False,
102
+ eta=0.0,
103
+ mask=None,
104
+ x0=None,
105
+ temperature=1.0,
106
+ noise_dropout=0.0,
107
+ score_corrector=None,
108
+ corrector_kwargs=None,
109
+ verbose=True,
110
+ x_T=None,
111
+ log_every_t=100,
112
+ unconditional_guidance_scale=1.0,
113
+ unconditional_conditioning=None,
114
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
115
+ **kwargs,
116
+ ):
117
+ if conditioning is not None:
118
+ if isinstance(conditioning, dict):
119
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
120
+ if cbs != batch_size:
121
+ print(
122
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
123
+ )
124
+ else:
125
+ if conditioning.shape[0] != batch_size:
126
+ print(
127
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
128
+ )
129
+
130
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
131
+ # sampling
132
+ C, H, W = shape
133
+ size = (batch_size, C, H, W)
134
+ print(f"Data shape for PLMS sampling is {size}")
135
+
136
+ samples, intermediates = self.plms_sampling(
137
+ conditioning,
138
+ size,
139
+ callback=callback,
140
+ img_callback=img_callback,
141
+ quantize_denoised=quantize_x0,
142
+ mask=mask,
143
+ x0=x0,
144
+ ddim_use_original_steps=False,
145
+ noise_dropout=noise_dropout,
146
+ temperature=temperature,
147
+ score_corrector=score_corrector,
148
+ corrector_kwargs=corrector_kwargs,
149
+ x_T=x_T,
150
+ log_every_t=log_every_t,
151
+ unconditional_guidance_scale=unconditional_guidance_scale,
152
+ unconditional_conditioning=unconditional_conditioning,
153
+ )
154
+ return samples, intermediates
155
+
156
+ @torch.no_grad()
157
+ def plms_sampling(
158
+ self,
159
+ cond,
160
+ shape,
161
+ x_T=None,
162
+ ddim_use_original_steps=False,
163
+ callback=None,
164
+ time_steps=None,
165
+ quantize_denoised=False,
166
+ mask=None,
167
+ x0=None,
168
+ img_callback=None,
169
+ log_every_t=100,
170
+ temperature=1.0,
171
+ noise_dropout=0.0,
172
+ score_corrector=None,
173
+ corrector_kwargs=None,
174
+ unconditional_guidance_scale=1.0,
175
+ unconditional_conditioning=None,
176
+ ):
177
+ device = self.model.betas.device
178
+ b = shape[0]
179
+ if x_T is None:
180
+ img = torch.randn(shape, device=device)
181
+ else:
182
+ img = x_T
183
+
184
+ if time_steps is None:
185
+ time_steps = (
186
+ self.ddpm_num_time_steps
187
+ if ddim_use_original_steps
188
+ else self.ddim_time_steps
189
+ )
190
+ elif time_steps is not None and not ddim_use_original_steps:
191
+ subset_end = (
192
+ int(
193
+ min(time_steps / self.ddim_time_steps.shape[0], 1)
194
+ * self.ddim_time_steps.shape[0]
195
+ )
196
+ - 1
197
+ )
198
+ time_steps = self.ddim_time_steps[:subset_end]
199
+
200
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
201
+ time_range = (
202
+ list(reversed(range(0, time_steps)))
203
+ if ddim_use_original_steps
204
+ else np.flip(time_steps)
205
+ )
206
+ total_steps = time_steps if ddim_use_original_steps else time_steps.shape[0]
207
+ print(f"Running PLMS Sampling with {total_steps} time_steps")
208
+
209
+ iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
210
+ old_eps = []
211
+
212
+ for i, step in enumerate(iterator):
213
+ index = total_steps - i - 1
214
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
215
+ ts_next = torch.full(
216
+ (b,),
217
+ time_range[min(i + 1, len(time_range) - 1)],
218
+ device=device,
219
+ dtype=torch.long,
220
+ )
221
+
222
+ if mask is not None:
223
+ assert x0 is not None
224
+ img_orig = self.model.q_sample(x0, ts)
225
+ img = img_orig * mask + (1.0 - mask) * img
226
+
227
+ outs = self.p_sample_plms(
228
+ img,
229
+ cond,
230
+ ts,
231
+ index=index,
232
+ use_original_steps=ddim_use_original_steps,
233
+ quantize_denoised=quantize_denoised,
234
+ temperature=temperature,
235
+ noise_dropout=noise_dropout,
236
+ score_corrector=score_corrector,
237
+ corrector_kwargs=corrector_kwargs,
238
+ unconditional_guidance_scale=unconditional_guidance_scale,
239
+ unconditional_conditioning=unconditional_conditioning,
240
+ old_eps=old_eps,
241
+ t_next=ts_next,
242
+ )
243
+ img, pred_x0, e_t = outs
244
+ old_eps.append(e_t)
245
+ if len(old_eps) >= 4:
246
+ old_eps.pop(0)
247
+ if callback:
248
+ callback(i)
249
+ if img_callback:
250
+ img_callback(pred_x0, i)
251
+
252
+ if index % log_every_t == 0 or index == total_steps - 1:
253
+ intermediates["x_inter"].append(img)
254
+ intermediates["pred_x0"].append(pred_x0)
255
+
256
+ return img, intermediates
257
+
258
+ @torch.no_grad()
259
+ def p_sample_plms(
260
+ self,
261
+ x,
262
+ c,
263
+ t,
264
+ index,
265
+ repeat_noise=False,
266
+ use_original_steps=False,
267
+ quantize_denoised=False,
268
+ temperature=1.0,
269
+ noise_dropout=0.0,
270
+ score_corrector=None,
271
+ corrector_kwargs=None,
272
+ unconditional_guidance_scale=1.0,
273
+ unconditional_conditioning=None,
274
+ old_eps=None,
275
+ t_next=None,
276
+ ):
277
+ b, *_, device = *x.shape, x.device
278
+
279
+ def get_model_output(x, t):
280
+ if (
281
+ unconditional_conditioning is None
282
+ or unconditional_guidance_scale == 1.0
283
+ ):
284
+ e_t = self.model.apply_model(x, t, c)
285
+ else:
286
+ x_in = torch.cat([x] * 2)
287
+ t_in = torch.cat([t] * 2)
288
+ c_in = torch.cat([unconditional_conditioning, c])
289
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
290
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
291
+
292
+ if score_corrector is not None:
293
+ assert self.model.parameterization == "eps"
294
+ e_t = score_corrector.modify_score(
295
+ self.model, e_t, x, t, c, **corrector_kwargs
296
+ )
297
+
298
+ return e_t
299
+
300
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
301
+ alphas_prev = (
302
+ self.model.alphas_cumprod_prev
303
+ if use_original_steps
304
+ else self.ddim_alphas_prev
305
+ )
306
+ sqrt_one_minus_alphas = (
307
+ self.model.sqrt_one_minus_alphas_cumprod
308
+ if use_original_steps
309
+ else self.ddim_sqrt_one_minus_alphas
310
+ )
311
+ sigmas = (
312
+ self.model.ddim_sigmas_for_original_num_steps
313
+ if use_original_steps
314
+ else self.ddim_sigmas
315
+ )
316
+
317
+ def get_x_prev_and_pred_x0(e_t, index):
318
+ # select parameters corresponding to the currently considered timestep
319
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
320
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
321
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
322
+ sqrt_one_minus_at = torch.full(
323
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
324
+ )
325
+
326
+ # current prediction for x_0
327
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
328
+ if quantize_denoised:
329
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
330
+ # direction pointing to x_t
331
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
332
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
333
+ if noise_dropout > 0.0:
334
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
335
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
336
+ return x_prev, pred_x0
337
+
338
+ e_t = get_model_output(x, t)
339
+ if len(old_eps) == 0:
340
+ # Pseudo Improved Euler (2nd order)
341
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
342
+ e_t_next = get_model_output(x_prev, t_next)
343
+ e_t_prime = (e_t + e_t_next) / 2
344
+ elif len(old_eps) == 1:
345
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
346
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
347
+ elif len(old_eps) == 2:
348
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
349
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
350
+ elif len(old_eps) >= 3:
351
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
352
+ e_t_prime = (
353
+ 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
354
+ ) / 24
355
+
356
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
357
+
358
+ return x_prev, pred_x0, e_t
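The branching on `len(old_eps)` in `p_sample_plms` is a classical Adams-Bashforth ladder: as more past noise predictions accumulate, higher-order extrapolation weights are used. A standalone sketch of just that weight selection, mirroring the code above:

    def plms_combine(e_t, old_eps):
        """Adams-Bashforth weights of order 1..4, as used in p_sample_plms."""
        if len(old_eps) == 0:
            return e_t                                       # caller then refines with the 2nd-order Euler step
        if len(old_eps) == 1:
            return (3 * e_t - old_eps[-1]) / 2               # 2nd order
        if len(old_eps) == 2:
            return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12          # 3rd order
        return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24  # 4th order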
core/models/samplers/uni_pc/__init__.py ADDED
File without changes
core/models/samplers/uni_pc/sampler.py ADDED
@@ -0,0 +1,67 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
6
+
7
+
8
+ class UniPCSampler(object):
9
+ def __init__(self, model, **kwargs):
10
+ super().__init__()
11
+ self.model = model
12
+
13
+ def to_torch(x):
14
+ return x.clone().detach().to(torch.float32).to(model.device)
15
+
16
+ self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ @torch.no_grad()
25
+ def sample(
26
+ self,
27
+ S,
28
+ batch_size,
29
+ shape,
30
+ conditioning=None,
31
+ x_T=None,
32
+ unconditional_guidance_scale=1.0,
33
+ unconditional_conditioning=None,
34
+ ):
35
+ # sampling
36
+ T, C, H, W = shape
37
+ size = (batch_size, T, C, H, W)
38
+
39
+ device = self.model.betas.device
40
+ if x_T is None:
41
+ img = torch.randn(size, device=device)
42
+ else:
43
+ img = x_T
44
+
45
+ ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
46
+
47
+ model_fn = model_wrapper(
48
+ lambda x, t, c: self.model.apply_model(x, t, c),
49
+ ns,
50
+ model_type="v",
51
+ guidance_type="classifier-free",
52
+ condition=conditioning,
53
+ unconditional_condition=unconditional_conditioning,
54
+ guidance_scale=unconditional_guidance_scale,
55
+ )
56
+
57
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
58
+ x = uni_pc.sample(
59
+ img,
60
+ steps=S,
61
+ skip_type="time_uniform",
62
+ method="multistep",
63
+ order=2,
64
+ lower_order_final=True,
65
+ )
66
+
67
+ return x.to(device), None
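Unlike `DPMSolverSampler`, this wrapper hard-codes `model_type="v"`, so it is only meaningful for v-parameterized diffusion models. A hedged usage sketch with the same hypothetical `model`, `cond` and `uncond` placeholders as before:

    # Illustrative only: `model` must expose `betas`, `alphas_cumprod` and `apply_model`
    # and be trained with v-prediction; `cond` / `uncond` are placeholder tensors.
    sampler = UniPCSampler(model)
    samples, _ = sampler.sample(
        S=20,
        batch_size=1,
        shape=(16, 4, 32, 32),      # (T, C, H, W)
        conditioning=cond,
        unconditional_conditioning=uncond,
        unconditional_guidance_scale=7.5,
    )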
core/models/samplers/uni_pc/uni_pc.py ADDED
@@ -0,0 +1,998 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule="discrete",
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.0,
14
+ ):
15
+ """Create a wrapper class for the forward SDE (VP type).
16
+
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
19
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+
22
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+
30
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
31
+
32
+ t = self.inverse_lambda(lambda_t)
33
+
34
+ ===============================================================
35
+
36
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
37
+
38
+ 1. For discrete-time DPMs:
39
+
40
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
41
+ t_i = (i + 1) / N
42
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
43
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
44
+
45
+ Args:
46
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
47
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
48
+
49
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
50
+
51
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
52
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
53
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
54
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
55
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
56
+ and
57
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
58
+
59
+
60
+ 2. For continuous-time DPMs:
61
+
62
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
63
+ schedule are the default settings in DDPM and improved-DDPM:
64
+
65
+ Args:
66
+ beta_min: A `float` number. The smallest beta for the linear schedule.
67
+ beta_max: A `float` number. The largest beta for the linear schedule.
68
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
69
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
70
+ T: A `float` number. The ending time of the forward process.
71
+
72
+ ===============================================================
73
+
74
+ Args:
75
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
76
+ 'linear' or 'cosine' for continuous-time DPMs.
77
+ Returns:
78
+ A wrapper object of the forward SDE (VP type).
79
+
80
+ ===============================================================
81
+
82
+ Example:
83
+
84
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
85
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
86
+
87
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
88
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
89
+
90
+ # For continuous-time DPMs (VPSDE), linear schedule:
91
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
92
+
93
+ """
94
+
95
+ if schedule not in ["discrete", "linear", "cosine"]:
96
+ raise ValueError(
97
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
98
+ schedule
99
+ )
100
+ )
101
+
102
+ self.schedule = schedule
103
+ if schedule == "discrete":
104
+ if betas is not None:
105
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
106
+ else:
107
+ assert alphas_cumprod is not None
108
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
109
+ self.total_N = len(log_alphas)
110
+ self.T = 1.0
111
+ self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape(
112
+ (1, -1)
113
+ )
114
+ self.log_alpha_array = log_alphas.reshape(
115
+ (
116
+ 1,
117
+ -1,
118
+ )
119
+ )
120
+ else:
121
+ self.total_N = 1000
122
+ self.beta_0 = continuous_beta_0
123
+ self.beta_1 = continuous_beta_1
124
+ self.cosine_s = 0.008
125
+ self.cosine_beta_max = 999.0
126
+ self.cosine_t_max = (
127
+ math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
128
+ * 2.0
129
+ * (1.0 + self.cosine_s)
130
+ / math.pi
131
+ - self.cosine_s
132
+ )
133
+ self.cosine_log_alpha_0 = math.log(
134
+ math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
135
+ )
136
+ self.schedule = schedule
137
+ if schedule == "cosine":
138
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
139
+ # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
140
+ self.T = 0.9946
141
+ else:
142
+ self.T = 1.0
143
+
144
+ def marginal_log_mean_coeff(self, t):
145
+ """
146
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
147
+ """
148
+ if self.schedule == "discrete":
149
+ return interpolate_fn(
150
+ t.reshape((-1, 1)),
151
+ self.t_array.to(t.device),
152
+ self.log_alpha_array.to(t.device),
153
+ ).reshape((-1))
154
+ elif self.schedule == "linear":
155
+ return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
156
+ elif self.schedule == "cosine":
157
+
158
+ def log_alpha_fn(s):
159
+ return torch.log(
160
+ torch.cos(
161
+ (s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
162
+ )
163
+ )
164
+
165
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
166
+ return log_alpha_t
167
+
168
+ def marginal_alpha(self, t):
169
+ """
170
+ Compute alpha_t of a given continuous-time label t in [0, T].
171
+ """
172
+ return torch.exp(self.marginal_log_mean_coeff(t))
173
+
174
+ def marginal_std(self, t):
175
+ """
176
+ Compute sigma_t of a given continuous-time label t in [0, T].
177
+ """
178
+ return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
179
+
180
+ def marginal_lambda(self, t):
181
+ """
182
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
183
+ """
184
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
185
+ log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
186
+ return log_mean_coeff - log_std
187
+
188
+ def inverse_lambda(self, lamb):
189
+ """
190
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
191
+ """
192
+ if self.schedule == "linear":
193
+ tmp = (
194
+ 2.0
195
+ * (self.beta_1 - self.beta_0)
196
+ * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
197
+ )
198
+ Delta = self.beta_0**2 + tmp
199
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
200
+ elif self.schedule == "discrete":
201
+ log_alpha = -0.5 * torch.logaddexp(
202
+ torch.zeros((1,)).to(lamb.device), -2.0 * lamb
203
+ )
204
+ t = interpolate_fn(
205
+ log_alpha.reshape((-1, 1)),
206
+ torch.flip(self.log_alpha_array.to(lamb.device), [1]),
207
+ torch.flip(self.t_array.to(lamb.device), [1]),
208
+ )
209
+ return t.reshape((-1,))
210
+ else:
211
+ log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
212
+
213
+ def t_fn(log_alpha_t):
214
+ return (
215
+ torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
216
+ * 2.0
217
+ * (1.0 + self.cosine_s)
218
+ / math.pi
219
+ - self.cosine_s
220
+ )
221
+
222
+ t = t_fn(log_alpha)
223
+ return t
224
+
225
+
226
+ def model_wrapper(
227
+ model,
228
+ noise_schedule,
229
+ model_type="noise",
230
+ model_kwargs={},
231
+ guidance_type="uncond",
232
+ condition=None,
233
+ unconditional_condition=None,
234
+ guidance_scale=1.0,
235
+ classifier_fn=None,
236
+ classifier_kwargs={},
237
+ ):
238
+ """Create a wrapper function for the noise prediction model.
239
+
240
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
241
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
242
+
243
+ We support four types of the diffusion model by setting `model_type`:
244
+
245
+ 1. "noise": noise prediction model. (Trained by predicting noise).
246
+
247
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
248
+
249
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
250
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
251
+
252
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
253
+ arXiv preprint arXiv:2202.00512 (2022).
254
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
255
+ arXiv preprint arXiv:2210.02303 (2022).
256
+
257
+ 4. "score": marginal score function. (Trained by denoising score matching).
258
+ Note that the score function and the noise prediction model follows a simple relationship:
259
+ ```
260
+ noise(x_t, t) = -sigma_t * score(x_t, t)
261
+ ```
262
+
263
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
264
+ 1. "uncond": unconditional sampling by DPMs.
265
+ The input `model` has the following format:
266
+ ``
267
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
268
+ ``
269
+
270
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
271
+ The input `model` has the following format:
272
+ ``
273
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
274
+ ``
275
+
276
+ The input `classifier_fn` has the following format:
277
+ ``
278
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
279
+ ``
280
+
281
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
282
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
283
+
284
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
285
+ The input `model` has the following format:
286
+ ``
287
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
288
+ ``
289
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
290
+
291
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
292
+ arXiv preprint arXiv:2207.12598 (2022).
293
+
294
+
295
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
296
+ or continuous-time labels (i.e. epsilon to T).
297
+
298
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and to output the predicted noise:
299
+ ``
300
+ def model_fn(x, t_continuous) -> noise:
301
+ t_input = get_model_input_time(t_continuous)
302
+ return noise_pred(model, x, t_input, **model_kwargs)
303
+ ``
304
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
305
+
306
+ ===============================================================
307
+
308
+ Args:
309
+ model: A diffusion model with the corresponding format described above.
310
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
311
+ model_type: A `str`. The parameterization type of the diffusion model.
312
+ "noise" or "x_start" or "v" or "score".
313
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
314
+ guidance_type: A `str`. The type of the guidance for sampling.
315
+ "uncond" or "classifier" or "classifier-free".
316
+ condition: A pytorch tensor. The condition for the guided sampling.
317
+ Only used for "classifier" or "classifier-free" guidance type.
318
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
319
+ Only used for "classifier-free" guidance type.
320
+ guidance_scale: A `float`. The scale for the guided sampling.
321
+ classifier_fn: A classifier function. Only used for the classifier guidance.
322
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
323
+ Returns:
324
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
325
+ """
326
+
327
+ def get_model_input_time(t_continuous):
328
+ """
329
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
330
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
331
+ For continuous-time DPMs, we just use `t_continuous`.
332
+ """
333
+ if noise_schedule.schedule == "discrete":
334
+ return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
335
+ else:
336
+ return t_continuous
337
+
338
+ def noise_pred_fn(x, t_continuous, cond=None):
339
+ if t_continuous.reshape((-1,)).shape[0] == 1:
340
+ t_continuous = t_continuous.expand((x.shape[0]))
341
+ t_input = get_model_input_time(t_continuous)
342
+ if cond is None:
343
+ output = model(x, t_input, None, **model_kwargs)
344
+ else:
345
+ output = model(x, t_input, cond, **model_kwargs)
346
+ if isinstance(output, tuple):
347
+ output = output[0]
348
+ if model_type == "noise":
349
+ return output
350
+ elif model_type == "x_start":
351
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
352
+ t_continuous
353
+ ), noise_schedule.marginal_std(t_continuous)
354
+ dims = x.dim()
355
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(
356
+ sigma_t, dims
357
+ )
358
+ elif model_type == "v":
359
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
360
+ t_continuous
361
+ ), noise_schedule.marginal_std(t_continuous)
362
+ dims = x.dim()
363
+ print("alpha_t.shape", alpha_t.shape)
364
+ print("sigma_t.shape", sigma_t.shape)
365
+ print("dims", dims)
366
+ print("x.shape", x.shape)
367
+ # x: b, t, c, h, w
368
+ alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
369
+ sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
370
+ print("alpha_t.shape", alpha_t.shape)
371
+ print("sigma_t.shape", sigma_t.shape)
372
+ print("output.shape", output.shape)
373
+ return alpha_t * output + sigma_t * x
374
+ elif model_type == "score":
375
+ sigma_t = noise_schedule.marginal_std(t_continuous)
376
+ dims = x.dim()
377
+ return -expand_dims(sigma_t, dims) * output
378
+
379
+ def cond_grad_fn(x, t_input):
380
+ """
381
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
382
+ """
383
+ with torch.enable_grad():
384
+ x_in = x.detach().requires_grad_(True)
385
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
386
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
387
+
388
+ def model_fn(x, t_continuous):
389
+ """
390
+ The noise prediction model function that is used for DPM-Solver.
391
+ """
392
+ if t_continuous.reshape((-1,)).shape[0] == 1:
393
+ t_continuous = t_continuous.expand((x.shape[0]))
394
+ if guidance_type == "uncond":
395
+ return noise_pred_fn(x, t_continuous)
396
+ elif guidance_type == "classifier":
397
+ assert classifier_fn is not None
398
+ t_input = get_model_input_time(t_continuous)
399
+ cond_grad = cond_grad_fn(x, t_input)
400
+ sigma_t = noise_schedule.marginal_std(t_continuous)
401
+ noise = noise_pred_fn(x, t_continuous)
402
+ return (
403
+ noise
404
+ - guidance_scale
405
+ * expand_dims(sigma_t, dims=cond_grad.dim())
406
+ * cond_grad
407
+ )
408
+ elif guidance_type == "classifier-free":
409
+ if guidance_scale == 1.0 or unconditional_condition is None:
410
+ return noise_pred_fn(x, t_continuous, cond=condition)
411
+ else:
412
+ x_in = x
413
+ t_in = t_continuous
414
+ print("x_in.shape=", x_in.shape)
415
+ print("t_in.shape=", t_in.shape)
416
+ noise = noise_pred_fn(x_in, t_in, cond=condition)
417
+
418
+ noise_uncond = noise_pred_fn(x_in, t_in, cond=unconditional_condition)
419
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
420
+
421
+ assert model_type in ["noise", "x_start", "v"]
422
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
423
+ return model_fn
424
+
425
+
426
+ class UniPC:
427
+ def __init__(
428
+ self,
429
+ model_fn,
430
+ noise_schedule,
431
+ predict_x0=True,
432
+ thresholding=False,
433
+ max_val=1.0,
434
+ variant="bh1",
435
+ ):
436
+ """Construct a UniPC.
437
+
438
+ We support both data_prediction and noise_prediction.
439
+ """
440
+ self.model = model_fn
441
+ self.noise_schedule = noise_schedule
442
+ self.variant = variant
443
+ self.predict_x0 = predict_x0
444
+ self.thresholding = thresholding
445
+ self.max_val = max_val
446
+
447
+ def dynamic_thresholding_fn(self, x0, t=None):
448
+ """
449
+ The dynamic thresholding method.
450
+ """
451
+ dims = x0.dim()
452
+ p = self.dynamic_thresholding_ratio
453
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
454
+ s = expand_dims(
455
+ torch.maximum(
456
+ s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
457
+ ),
458
+ dims,
459
+ )
460
+ x0 = torch.clamp(x0, -s, s) / s
461
+ return x0
462
+
463
+ def noise_prediction_fn(self, x, t):
464
+ """
465
+ Return the noise prediction model.
466
+ """
467
+ return self.model(x, t)
468
+
469
+ def data_prediction_fn(self, x, t):
470
+ """
471
+ Return the data prediction model (with thresholding).
472
+ """
473
+ noise = self.noise_prediction_fn(x, t)
474
+ dims = x.dim()
475
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
476
+ t
477
+ ), self.noise_schedule.marginal_std(t)
478
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
479
+ if self.thresholding:
480
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
481
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
482
+ s = expand_dims(
483
+ torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims
484
+ )
485
+ x0 = torch.clamp(x0, -s, s) / s
486
+ return x0
487
+
488
+ def model_fn(self, x, t):
489
+ """
490
+ Convert the model to the noise prediction model or the data prediction model.
491
+ """
492
+ if self.predict_x0:
493
+ return self.data_prediction_fn(x, t)
494
+ else:
495
+ return self.noise_prediction_fn(x, t)
496
+
497
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
498
+ """Compute the intermediate time steps for sampling."""
499
+ if skip_type == "logSNR":
500
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
501
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
502
+ logSNR_steps = torch.linspace(
503
+ lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
504
+ ).to(device)
505
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
506
+ elif skip_type == "time_uniform":
507
+ return torch.linspace(t_T, t_0, N + 1).to(device)
508
+ elif skip_type == "time_quadratic":
509
+ t_order = 2
510
+ t = (
511
+ torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
512
+ .pow(t_order)
513
+ .to(device)
514
+ )
515
+ return t
516
+ else:
517
+ raise ValueError(
518
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
519
+ skip_type
520
+ )
521
+ )
522
+
523
+ def get_orders_and_timesteps_for_singlestep_solver(
524
+ self, steps, order, skip_type, t_T, t_0, device
525
+ ):
526
+ """
527
+ Get the order of each step for sampling by the singlestep DPM-Solver.
528
+ """
529
+ if order == 3:
530
+ K = steps // 3 + 1
531
+ if steps % 3 == 0:
532
+ orders = [
533
+ 3,
534
+ ] * (
535
+ K - 2
536
+ ) + [2, 1]
537
+ elif steps % 3 == 1:
538
+ orders = [
539
+ 3,
540
+ ] * (
541
+ K - 1
542
+ ) + [1]
543
+ else:
544
+ orders = [
545
+ 3,
546
+ ] * (
547
+ K - 1
548
+ ) + [2]
549
+ elif order == 2:
550
+ if steps % 2 == 0:
551
+ K = steps // 2
552
+ orders = [
553
+ 2,
554
+ ] * K
555
+ else:
556
+ K = steps // 2 + 1
557
+ orders = [
558
+ 2,
559
+ ] * (
560
+ K - 1
561
+ ) + [1]
562
+ elif order == 1:
563
+ K = steps
564
+ orders = [
565
+ 1,
566
+ ] * steps
567
+ else:
568
+ raise ValueError("'order' must be '1' or '2' or '3'.")
569
+ if skip_type == "logSNR":
570
+ # To reproduce the results in DPM-Solver paper
571
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
572
+ else:
573
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
574
+ torch.cumsum(
575
+ torch.tensor(
576
+ [
577
+ 0,
578
+ ]
579
+ + orders
580
+ ),
581
+ 0,
582
+ ).to(device)
583
+ ]
584
+ return timesteps_outer, orders
585
+
586
+ def denoise_to_zero_fn(self, x, s):
587
+ """
588
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
589
+ """
590
+ return self.data_prediction_fn(x, s)
591
+
592
+ def multistep_uni_pc_update(
593
+ self, x, model_prev_list, t_prev_list, t, order, **kwargs
594
+ ):
595
+ if len(t.shape) == 0:
596
+ t = t.view(-1)
597
+ if "bh" in self.variant:
598
+ return self.multistep_uni_pc_bh_update(
599
+ x, model_prev_list, t_prev_list, t, order, **kwargs
600
+ )
601
+ else:
602
+ assert self.variant == "vary_coeff"
603
+ return self.multistep_uni_pc_vary_update(
604
+ x, model_prev_list, t_prev_list, t, order, **kwargs
605
+ )
606
+
607
+ def multistep_uni_pc_vary_update(
608
+ self, x, model_prev_list, t_prev_list, t, order, use_corrector=True
609
+ ):
610
+ print(
611
+ f"using unified predictor-corrector with order {order} (solver type: vary coeff)"
612
+ )
613
+ ns = self.noise_schedule
614
+ assert order <= len(model_prev_list)
615
+
616
+ # first compute rks
617
+ t_prev_0 = t_prev_list[-1]
618
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
619
+ lambda_t = ns.marginal_lambda(t)
620
+ model_prev_0 = model_prev_list[-1]
621
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
622
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
623
+ alpha_t = torch.exp(log_alpha_t)
624
+
625
+ h = lambda_t - lambda_prev_0
626
+
627
+ rks = []
628
+ D1s = []
629
+ for i in range(1, order):
630
+ t_prev_i = t_prev_list[-(i + 1)]
631
+ model_prev_i = model_prev_list[-(i + 1)]
632
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
633
+ rk = (lambda_prev_i - lambda_prev_0) / h
634
+ rks.append(rk)
635
+ D1s.append((model_prev_i - model_prev_0) / rk)
636
+
637
+ rks.append(1.0)
638
+ rks = torch.tensor(rks, device=x.device)
639
+
640
+ K = len(rks)
641
+ # build C matrix
642
+ C = []
643
+
644
+ col = torch.ones_like(rks)
645
+ for k in range(1, K + 1):
646
+ C.append(col)
647
+ col = col * rks / (k + 1)
648
+ C = torch.stack(C, dim=1)
649
+
650
+ if len(D1s) > 0:
651
+ D1s = torch.stack(D1s, dim=1) # (B, K)
652
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
653
+ A_p = C_inv_p
654
+
655
+ if use_corrector:
656
+ print("using corrector")
657
+ C_inv = torch.linalg.inv(C)
658
+ A_c = C_inv
659
+
660
+ hh = -h if self.predict_x0 else h
661
+ h_phi_1 = torch.expm1(hh)
662
+ h_phi_ks = []
663
+ factorial_k = 1
664
+ h_phi_k = h_phi_1
665
+ for k in range(1, K + 2):
666
+ h_phi_ks.append(h_phi_k)
667
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
668
+ factorial_k *= k + 1
669
+
670
+ model_t = None
671
+ if self.predict_x0:
672
+ x_t_ = sigma_t / sigma_prev_0 * x - alpha_t * h_phi_1 * model_prev_0
673
+ # now predictor
674
+ x_t = x_t_
675
+ if len(D1s) > 0:
676
+ # compute the residuals for predictor
677
+ for k in range(K - 1):
678
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum(
679
+ "bktchw,k->btchw", D1s, A_p[k]
680
+ )
681
+ # now corrector
682
+ if use_corrector:
683
+ model_t = self.model_fn(x_t, t)
684
+ D1_t = model_t - model_prev_0
685
+ x_t = x_t_
686
+ k = 0
687
+ for k in range(K - 1):
688
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum(
689
+ "bktchw,k->btchw", D1s, A_c[k][:-1]
690
+ )
691
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
692
+ else:
693
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
694
+ t_prev_0
695
+ ), ns.marginal_log_mean_coeff(t)
696
+ x_t_ = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (
697
+ sigma_t * h_phi_1
698
+ ) * model_prev_0
699
+ # now predictor
700
+ x_t = x_t_
701
+ if len(D1s) > 0:
702
+ # compute the residuals for predictor
703
+ for k in range(K - 1):
704
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum(
705
+ "bktchw,k->btchw", D1s, A_p[k]
706
+ )
707
+ # now corrector
708
+ if use_corrector:
709
+ model_t = self.model_fn(x_t, t)
710
+ D1_t = model_t - model_prev_0
711
+ x_t = x_t_
712
+ k = 0
713
+ for k in range(K - 1):
714
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum(
715
+ "bktchw,k->btchw", D1s, A_c[k][:-1]
716
+ )
717
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
718
+ return x_t, model_t
719
+
720
+ def multistep_uni_pc_bh_update(
721
+ self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True
722
+ ):
723
+ print(
724
+ f"using unified predictor-corrector with order {order} (solver type: B(h))"
725
+ )
726
+ ns = self.noise_schedule
727
+ assert order <= len(model_prev_list)
728
+ dims = x.dim()
729
+
730
+ # first compute rks
731
+ t_prev_0 = t_prev_list[-1]
732
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
733
+ lambda_t = ns.marginal_lambda(t)
734
+ model_prev_0 = model_prev_list[-1]
735
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
736
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
737
+ t_prev_0
738
+ ), ns.marginal_log_mean_coeff(t)
739
+ alpha_t = torch.exp(log_alpha_t)
740
+
741
+ h = lambda_t - lambda_prev_0
742
+
743
+ rks = []
744
+ D1s = []
745
+ for i in range(1, order):
746
+ t_prev_i = t_prev_list[-(i + 1)]
747
+ model_prev_i = model_prev_list[-(i + 1)]
748
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
749
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
750
+ rks.append(rk)
751
+ D1s.append((model_prev_i - model_prev_0) / rk)
752
+
753
+ rks.append(1.0)
754
+ rks = torch.tensor(rks, device=x.device)
755
+
756
+ R = []
757
+ b = []
758
+
759
+ hh = -h[0] if self.predict_x0 else h[0]
760
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
761
+ h_phi_k = h_phi_1 / hh - 1
762
+
763
+ factorial_i = 1
764
+
765
+ if self.variant == "bh1":
766
+ B_h = hh
767
+ elif self.variant == "bh2":
768
+ B_h = torch.expm1(hh)
769
+ else:
770
+ raise NotImplementedError()
771
+
772
+ for i in range(1, order + 1):
773
+ R.append(torch.pow(rks, i - 1))
774
+ b.append(h_phi_k * factorial_i / B_h)
775
+ factorial_i *= i + 1
776
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
777
+
778
+ R = torch.stack(R)
779
+ b = torch.tensor(b, device=x.device)
780
+
781
+ # now predictor
782
+ use_predictor = len(D1s) > 0 and x_t is None
783
+ if len(D1s) > 0:
784
+ D1s = torch.stack(D1s, dim=1) # (B, K)
785
+ if x_t is None:
786
+ # for order 2, we use a simplified version
787
+ if order == 2:
788
+ rhos_p = torch.tensor([0.5], device=b.device)
789
+ else:
790
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
791
+ else:
792
+ D1s = None
793
+
794
+ if use_corrector:
795
+ print("using corrector")
796
+ # for order 1, we use a simplified version
797
+ if order == 1:
798
+ rhos_c = torch.tensor([0.5], device=b.device)
799
+ else:
800
+ rhos_c = torch.linalg.solve(R, b)
801
+
802
+ model_t = None
803
+ if self.predict_x0:
804
+ x_t_ = (
805
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
806
+ - expand_dims(alpha_t * h_phi_1, dims) * model_prev_0
807
+ )
808
+
809
+ if x_t is None:
810
+ if use_predictor:
811
+ pred_res = torch.einsum("k,bktchw->btchw", rhos_p, D1s)
812
+ else:
813
+ pred_res = 0
814
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
815
+
816
+ if use_corrector:
817
+ model_t = self.model_fn(x_t, t)
818
+ if D1s is not None:
819
+ corr_res = torch.einsum("k,bktchw->btchw", rhos_c[:-1], D1s)
820
+ else:
821
+ corr_res = 0
822
+ D1_t = model_t - model_prev_0
823
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (
824
+ corr_res + rhos_c[-1] * D1_t
825
+ )
826
+ else:
827
+ x_t_ = (
828
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
829
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
830
+ )
831
+ if x_t is None:
832
+ if use_predictor:
833
+ pred_res = torch.einsum("k,bktchw->btchw", rhos_p, D1s)
834
+ else:
835
+ pred_res = 0
836
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
837
+
838
+ if use_corrector:
839
+ model_t = self.model_fn(x_t, t)
840
+ if D1s is not None:
841
+ corr_res = torch.einsum("k,bktchw->btchw", rhos_c[:-1], D1s)
842
+ else:
843
+ corr_res = 0
844
+ D1_t = model_t - model_prev_0
845
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (
846
+ corr_res + rhos_c[-1] * D1_t
847
+ )
848
+ return x_t, model_t
849
+
850
+ def sample(
851
+ self,
852
+ x,
853
+ steps=20,
854
+ t_start=None,
855
+ t_end=None,
856
+ order=3,
857
+ skip_type="time_uniform",
858
+ method="singlestep",
859
+ lower_order_final=True,
860
+ denoise_to_zero=False,
861
+ solver_type="dpm_solver",
862
+ atol=0.0078,
863
+ rtol=0.05,
864
+ corrector=False,
865
+ ):
866
+ t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
867
+ t_T = self.noise_schedule.T if t_start is None else t_start
868
+ device = x.device
869
+ if method == "multistep":
870
+ assert steps >= order
871
+ timesteps = self.get_time_steps(
872
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
873
+ )
874
+ assert timesteps.shape[0] - 1 == steps
875
+ with torch.no_grad():
876
+ vec_t = timesteps[0].expand((x.shape[0]))
877
+ model_prev_list = [self.model_fn(x, vec_t)]
878
+ t_prev_list = [vec_t]
879
+ # Init the first `order` values by lower order multistep DPM-Solver.
880
+ for init_order in range(1, order):
881
+ vec_t = timesteps[init_order].expand(x.shape[0])
882
+ x, model_x = self.multistep_uni_pc_update(
883
+ x,
884
+ model_prev_list,
885
+ t_prev_list,
886
+ vec_t,
887
+ init_order,
888
+ use_corrector=True,
889
+ )
890
+ if model_x is None:
891
+ model_x = self.model_fn(x, vec_t)
892
+ model_prev_list.append(model_x)
893
+ t_prev_list.append(vec_t)
894
+ for step in range(order, steps + 1):
895
+ vec_t = timesteps[step].expand(x.shape[0])
896
+ print(f"Current step={step}; vec_t={vec_t}.")
897
+ if lower_order_final:
898
+ step_order = min(order, steps + 1 - step)
899
+ else:
900
+ step_order = order
901
+ print("this step order:", step_order)
902
+ if step == steps:
903
+ print("do not run corrector at the last step")
904
+ use_corrector = False
905
+ else:
906
+ use_corrector = True
907
+ x, model_x = self.multistep_uni_pc_update(
908
+ x,
909
+ model_prev_list,
910
+ t_prev_list,
911
+ vec_t,
912
+ step_order,
913
+ use_corrector=use_corrector,
914
+ )
915
+ for i in range(order - 1):
916
+ t_prev_list[i] = t_prev_list[i + 1]
917
+ model_prev_list[i] = model_prev_list[i + 1]
918
+ t_prev_list[-1] = vec_t
919
+ # We do not need to evaluate the final model value.
920
+ if step < steps:
921
+ if model_x is None:
922
+ model_x = self.model_fn(x, vec_t)
923
+ model_prev_list[-1] = model_x
924
+ else:
925
+ raise NotImplementedError()
926
+ if denoise_to_zero:
927
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
928
+ return x
929
+
930
+
931
+ #############################################################
932
+ # other utility functions
933
+ #############################################################
934
+
935
+
936
+ def interpolate_fn(x, xp, yp):
937
+ """
938
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
939
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
940
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
941
+
942
+ Args:
943
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
944
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
945
+ yp: PyTorch tensor with shape [C, K].
946
+ Returns:
947
+ The function values f(x), with shape [N, C].
948
+ """
949
+ N, K = x.shape[0], xp.shape[1]
950
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
951
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
952
+ x_idx = torch.argmin(x_indices, dim=2)
953
+ cand_start_idx = x_idx - 1
954
+ start_idx = torch.where(
955
+ torch.eq(x_idx, 0),
956
+ torch.tensor(1, device=x.device),
957
+ torch.where(
958
+ torch.eq(x_idx, K),
959
+ torch.tensor(K - 2, device=x.device),
960
+ cand_start_idx,
961
+ ),
962
+ )
963
+ end_idx = torch.where(
964
+ torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
965
+ )
966
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
967
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
968
+ start_idx2 = torch.where(
969
+ torch.eq(x_idx, 0),
970
+ torch.tensor(0, device=x.device),
971
+ torch.where(
972
+ torch.eq(x_idx, K),
973
+ torch.tensor(K - 2, device=x.device),
974
+ cand_start_idx,
975
+ ),
976
+ )
977
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
978
+ start_y = torch.gather(
979
+ y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
980
+ ).squeeze(2)
981
+ end_y = torch.gather(
982
+ y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
983
+ ).squeeze(2)
984
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
985
+ return cand
986
+
987
+
988
+ def expand_dims(v, dims):
989
+ """
990
+ Expand the tensor `v` to the dim `dims`.
991
+
992
+ Args:
993
+ `v`: a PyTorch tensor with shape [N].
994
+ `dims`: an `int`.
995
+ Returns:
996
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
997
+ """
998
+ return v[(...,) + (None,) * (dims - 1)]
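The half-logSNR bookkeeping in `NoiseScheduleVP` is easiest to sanity-check numerically: for the VP schedule alpha_t^2 + sigma_t^2 = 1, and `marginal_lambda` / `inverse_lambda` should round-trip. A small sketch, assuming the `NoiseScheduleVP` class defined above is importable; the synthetic beta schedule is an assumption, not taken from the repository's configs:

    import torch

    betas = torch.linspace(1e-4, 2e-2, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)        # \hat{alpha}_n in DDPM notation
    ns = NoiseScheduleVP("discrete", alphas_cumprod=alphas_cumprod)

    t = torch.tensor([0.1, 0.5, 1.0])
    alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
    assert torch.allclose(alpha_t ** 2 + sigma_t ** 2, torch.ones_like(t), atol=1e-5)

    lam = ns.marginal_lambda(t)
    assert torch.allclose(ns.inverse_lambda(lam), t, atol=1e-3)   # piecewise-linear round trip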
core/models/utils_diffusion.py ADDED
@@ -0,0 +1,186 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ from einops import repeat
6
+
7
+
8
+ def timestep_embedding(time_steps, dim, max_period=10000, repeat_only=False):
9
+ """
10
+ Create sinusoidal timestep embeddings.
11
+ :param time_steps: a 1-D Tensor of N indices, one per batch element.
12
+ These may be fractional.
13
+ :param dim: the dimension of the output.
14
+ :param max_period: controls the minimum frequency of the embeddings.
15
+ :return: an [N x dim] Tensor of positional embeddings.
16
+ """
17
+ if not repeat_only:
18
+ half = dim // 2
19
+ freqs = torch.exp(
20
+ -math.log(max_period)
21
+ * torch.arange(start=0, end=half, dtype=torch.float32)
22
+ / half
23
+ ).to(device=time_steps.device)
24
+ args = time_steps[:, None].float() * freqs[None]
25
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
26
+ if dim % 2:
27
+ embedding = torch.cat(
28
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
29
+ )
30
+ else:
31
+ embedding = repeat(time_steps, "b -> b d", d=dim)
32
+ return embedding
33
+
34
+
35
+ def make_beta_schedule(
36
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
37
+ ):
38
+ if schedule == "linear":
39
+ betas = (
40
+ torch.linspace(
41
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
42
+ )
43
+ ** 2
44
+ )
45
+
46
+ elif schedule == "cosine":
47
+ time_steps = (
48
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
49
+ )
50
+ alphas = time_steps / (1 + cosine_s) * np.pi / 2
51
+ alphas = torch.cos(alphas).pow(2)
52
+ alphas = alphas / alphas[0]
53
+ betas = 1 - alphas[1:] / alphas[:-1]
54
+ betas = np.clip(betas, a_min=0, a_max=0.999)
55
+
56
+ elif schedule == "sqrt_linear":
57
+ betas = torch.linspace(
58
+ linear_start, linear_end, n_timestep, dtype=torch.float64
59
+ )
60
+ elif schedule == "sqrt":
61
+ betas = (
62
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
63
+ ** 0.5
64
+ )
65
+ else:
66
+ raise ValueError(f"schedule '{schedule}' unknown.")
67
+ # the cosine branch goes through np.clip and may already be a numpy array
+ return betas.numpy() if isinstance(betas, torch.Tensor) else betas
68
+
69
+
70
+ def make_ddim_time_steps(
71
+ ddim_discr_method, num_ddim_time_steps, num_ddpm_time_steps, verbose=True
72
+ ):
73
+ if ddim_discr_method == "uniform":
74
+ c = num_ddpm_time_steps // num_ddim_time_steps
75
+ ddim_time_steps = np.asarray(list(range(0, num_ddpm_time_steps, c)))
76
+ steps_out = ddim_time_steps + 1
77
+ elif ddim_discr_method == "quad":
78
+ ddim_time_steps = (
79
+ (np.linspace(0, np.sqrt(num_ddpm_time_steps * 0.8), num_ddim_time_steps))
80
+ ** 2
81
+ ).astype(int)
82
+ steps_out = ddim_time_steps + 1
83
+ elif ddim_discr_method == "uniform_trailing":
84
+ c = num_ddpm_time_steps / num_ddim_time_steps
85
+ ddim_time_steps = np.flip(
86
+ np.round(np.arange(num_ddpm_time_steps, 0, -c))
87
+ ).astype(np.int64)
88
+ steps_out = ddim_time_steps - 1
89
+ else:
90
+ raise NotImplementedError(
91
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
92
+ )
93
+
94
+ # assert ddim_time_steps.shape[0] == num_ddim_time_steps
95
+ # add one to get the final alpha values right (the ones used for the first scale-to-data step during sampling)
96
+ if verbose:
97
+ print(f"Selected time_steps for ddim sampler: {steps_out}")
98
+ return steps_out
99
+
100
+
101
+ def make_ddim_sampling_parameters(alphacums, ddim_time_steps, eta, verbose=True):
102
+ # select alphas for computing the variance schedule
103
+ # print(f'ddim_time_steps={ddim_time_steps}, len_alphacums={len(alphacums)}')
104
+ alphas = alphacums[ddim_time_steps]
105
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_time_steps[:-1]].tolist())
106
+
107
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
108
+ sigmas = eta * np.sqrt(
109
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
110
+ )
111
+ if verbose:
112
+ print(
113
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
114
+ )
115
+ print(
116
+ f"For the chosen value of eta, which is {eta}, "
117
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
118
+ )
119
+ return sigmas, alphas, alphas_prev
120
+
121
+
122
+ def betas_for_alpha_bar(num_diffusion_time_steps, alpha_bar, max_beta=0.999):
123
+ """
124
+ Create a beta schedule that discretizes the given alpha_t_bar function,
125
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
126
+ :param num_diffusion_time_steps: the number of betas to produce.
127
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
128
+ produces the cumulative product of (1-beta) up to that
129
+ part of the diffusion process.
130
+ :param max_beta: the maximum beta to use; use values lower than 1 to
131
+ prevent singularities.
132
+ """
133
+ betas = []
134
+ for i in range(num_diffusion_time_steps):
135
+ t1 = i / num_diffusion_time_steps
136
+ t2 = (i + 1) / num_diffusion_time_steps
137
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
138
+ return np.array(betas)
139
+
140
+
141
+ def rescale_zero_terminal_snr(betas):
142
+ """
143
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
144
+
145
+ Args:
146
+ betas (`numpy.ndarray`):
147
+ the betas that the scheduler is being initialized with.
148
+
149
+ Returns:
150
+ `numpy.ndarray`: rescaled betas with zero terminal SNR
151
+ """
152
+ # Convert betas to alphas_bar_sqrt
153
+ alphas = 1.0 - betas
154
+ alphas_cumprod = np.cumprod(alphas, axis=0)
155
+ alphas_bar_sqrt = np.sqrt(alphas_cumprod)
156
+
157
+ # Store old values.
158
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
159
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
160
+
161
+ # Shift so the last timestep is zero.
162
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
163
+
164
+ # Scale so the first timestep is back to the old value.
165
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
166
+
167
+ # Convert alphas_bar_sqrt to betas
168
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
169
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
170
+ alphas = np.concatenate([alphas_bar[0:1], alphas])
171
+ betas = 1 - alphas
172
+
173
+ return betas
174
+
175
+
176
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
177
+ """
178
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
179
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
180
+ """
181
+ std_text = noise_pred_text.std(
182
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
183
+ )
184
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
185
+ factor = guidance_rescale * (std_text / std_cfg) + (1 - guidance_rescale)
186
+ return noise_cfg * factor
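
A minimal sketch of how the helpers in this file fit together; the 1000 training steps, 50 DDIM steps, and embedding size are illustrative assumptions, not values read from the configs:

    import numpy as np
    import torch

    # Sinusoidal embeddings for a batch of (possibly fractional) timesteps.
    emb = timestep_embedding(torch.tensor([0.0, 500.0, 999.0]), dim=320)   # shape [3, 320]

    n_timestep = 1000                                       # assumed DDPM training steps
    betas = make_beta_schedule("linear", n_timestep)        # numpy array, shape [1000]
    alphas_cumprod = np.cumprod(1.0 - betas, axis=0)

    ddim_steps = make_ddim_time_steps("uniform", 50, n_timestep, verbose=False)
    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
        alphacums=alphas_cumprod, ddim_time_steps=ddim_steps, eta=0.0, verbose=False
    )
    # With eta = 0 every sigma_t is zero, i.e. deterministic DDIM sampling.

    # Zero-terminal-SNR rescaling (Algorithm 1 of arXiv:2305.08891):
    betas_zsnr = rescale_zero_terminal_snr(betas)
    assert np.isclose(np.cumprod(1.0 - betas_zsnr)[-1], 0.0)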
core/modules/attention.py ADDED
@@ -0,0 +1,710 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+ from functools import partial
6
+
7
+ try:
8
+ import xformers
9
+ import xformers.ops
10
+
11
+ XFORMERS_IS_AVAILBLE = True
12
+ except ImportError:
13
+ XFORMERS_IS_AVAILBLE = False
14
+ from core.common import (
15
+ gradient_checkpoint,
16
+ exists,
17
+ default,
18
+ )
19
+ from core.basics import zero_module
20
+
21
+
22
+ class RelativePosition(nn.Module):
23
+
24
+ def __init__(self, num_units, max_relative_position):
25
+ super().__init__()
26
+ self.num_units = num_units
27
+ self.max_relative_position = max_relative_position
28
+ self.embeddings_table = nn.Parameter(
29
+ torch.Tensor(max_relative_position * 2 + 1, num_units)
30
+ )
31
+ nn.init.xavier_uniform_(self.embeddings_table)
32
+
33
+ def forward(self, length_q, length_k):
34
+ device = self.embeddings_table.device
35
+ range_vec_q = torch.arange(length_q, device=device)
36
+ range_vec_k = torch.arange(length_k, device=device)
37
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
38
+ distance_mat_clipped = torch.clamp(
39
+ distance_mat, -self.max_relative_position, self.max_relative_position
40
+ )
41
+ final_mat = distance_mat_clipped + self.max_relative_position
42
+ final_mat = final_mat.long()
43
+ embeddings = self.embeddings_table[final_mat]
44
+ return embeddings
45
+
46
+
47
+ class CrossAttention(nn.Module):
48
+
49
+ def __init__(
50
+ self,
51
+ query_dim,
52
+ context_dim=None,
53
+ heads=8,
54
+ dim_head=64,
55
+ dropout=0.0,
56
+ relative_position=False,
57
+ temporal_length=None,
58
+ video_length=None,
59
+ image_cross_attention=False,
60
+ image_cross_attention_scale=1.0,
61
+ image_cross_attention_scale_learnable=False,
62
+ text_context_len=77,
63
+ ):
64
+ super().__init__()
65
+ inner_dim = dim_head * heads
66
+ context_dim = default(context_dim, query_dim)
67
+
68
+ self.scale = dim_head**-0.5
69
+ self.heads = heads
70
+ self.dim_head = dim_head
71
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
72
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
73
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
74
+
75
+ self.to_out = nn.Sequential(
76
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
77
+ )
78
+
79
+ self.relative_position = relative_position
80
+ if self.relative_position:
81
+ assert temporal_length is not None
82
+ self.relative_position_k = RelativePosition(
83
+ num_units=dim_head, max_relative_position=temporal_length
84
+ )
85
+ self.relative_position_v = RelativePosition(
86
+ num_units=dim_head, max_relative_position=temporal_length
87
+ )
88
+ else:
89
+ # only used for spatial attention, while NOT for temporal attention
90
+ if XFORMERS_IS_AVAILBLE and temporal_length is None:
91
+ self.forward = self.efficient_forward
92
+
93
+ self.video_length = video_length
94
+ self.image_cross_attention = image_cross_attention
95
+ self.image_cross_attention_scale = image_cross_attention_scale
96
+ self.text_context_len = text_context_len
97
+ self.image_cross_attention_scale_learnable = (
98
+ image_cross_attention_scale_learnable
99
+ )
100
+ if self.image_cross_attention:
101
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
102
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
103
+ if image_cross_attention_scale_learnable:
104
+ self.register_parameter("alpha", nn.Parameter(torch.tensor(0.0)))
105
+
106
+ def forward(self, x, context=None, mask=None):
107
+ spatial_self_attn = context is None
108
+ k_ip, v_ip, out_ip = None, None, None
109
+
110
+ h = self.heads
111
+ q = self.to_q(x)
112
+ context = default(context, x)
113
+
114
+ if self.image_cross_attention and not spatial_self_attn:
115
+ context, context_image = (
116
+ context[:, : self.text_context_len, :],
117
+ context[:, self.text_context_len :, :],
118
+ )
119
+ k = self.to_k(context)
120
+ v = self.to_v(context)
121
+ k_ip = self.to_k_ip(context_image)
122
+ v_ip = self.to_v_ip(context_image)
123
+ else:
124
+ if not spatial_self_attn:
125
+ context = context[:, : self.text_context_len, :]
126
+ k = self.to_k(context)
127
+ v = self.to_v(context)
128
+
129
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
130
+
131
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
132
+ if self.relative_position:
133
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
134
+ k2 = self.relative_position_k(len_q, len_k)
135
+ sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
136
+ sim += sim2
137
+ del k
138
+
139
+ if exists(mask):
140
+ # feasible for causal attention mask only
141
+ max_neg_value = -torch.finfo(sim.dtype).max
142
+ mask = repeat(mask, "b i j -> (b h) i j", h=h)
143
+ sim.masked_fill_(~(mask > 0.5), max_neg_value)
144
+
145
+ # attention, what we cannot get enough of
146
+ sim = sim.softmax(dim=-1)
147
+
148
+ out = torch.einsum("b i j, b j d -> b i d", sim, v)
149
+ if self.relative_position:
150
+ v2 = self.relative_position_v(len_q, len_v)
151
+ out2 = einsum("b t s, t s d -> b t d", sim, v2)
152
+ out += out2
153
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
154
+
155
+ # for image cross-attention
156
+ if k_ip is not None:
157
+ k_ip, v_ip = map(
158
+ lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k_ip, v_ip)
159
+ )
160
+ sim_ip = torch.einsum("b i d, b j d -> b i j", q, k_ip) * self.scale
161
+ del k_ip
162
+ sim_ip = sim_ip.softmax(dim=-1)
163
+ out_ip = torch.einsum("b i j, b j d -> b i d", sim_ip, v_ip)
164
+ out_ip = rearrange(out_ip, "(b h) n d -> b n (h d)", h=h)
165
+
166
+ if out_ip is not None:
167
+ if self.image_cross_attention_scale_learnable:
168
+ out = out + self.image_cross_attention_scale * out_ip * (
169
+ torch.tanh(self.alpha) + 1
170
+ )
171
+ else:
172
+ out = out + self.image_cross_attention_scale * out_ip
173
+
174
+ return self.to_out(out)
175
+
176
+ def efficient_forward(self, x, context=None, mask=None):
177
+ spatial_self_attn = context is None
178
+ k_ip, v_ip, out_ip = None, None, None
179
+
180
+ q = self.to_q(x)
181
+ context = default(context, x)
182
+
183
+ if self.image_cross_attention and not spatial_self_attn:
184
+ context, context_image = (
185
+ context[:, : self.text_context_len, :],
186
+ context[:, self.text_context_len :, :],
187
+ )
188
+ k = self.to_k(context)
189
+ v = self.to_v(context)
190
+ k_ip = self.to_k_ip(context_image)
191
+ v_ip = self.to_v_ip(context_image)
192
+ else:
193
+ if not spatial_self_attn:
194
+ context = context[:, : self.text_context_len, :]
195
+ k = self.to_k(context)
196
+ v = self.to_v(context)
197
+
198
+ b, _, _ = q.shape
199
+ q, k, v = map(
200
+ lambda t: t.unsqueeze(3)
201
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
202
+ .permute(0, 2, 1, 3)
203
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
204
+ .contiguous(),
205
+ (q, k, v),
206
+ )
207
+ # actually compute the attention, what we cannot get enough of
208
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
209
+
210
+ # for image cross-attention
211
+ if k_ip is not None:
212
+ k_ip, v_ip = map(
213
+ lambda t: t.unsqueeze(3)
214
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
215
+ .permute(0, 2, 1, 3)
216
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
217
+ .contiguous(),
218
+ (k_ip, v_ip),
219
+ )
220
+ out_ip = xformers.ops.memory_efficient_attention(
221
+ q, k_ip, v_ip, attn_bias=None, op=None
222
+ )
223
+ out_ip = (
224
+ out_ip.unsqueeze(0)
225
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
226
+ .permute(0, 2, 1, 3)
227
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
228
+ )
229
+
230
+ if exists(mask):
231
+ raise NotImplementedError
232
+ out = (
233
+ out.unsqueeze(0)
234
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
235
+ .permute(0, 2, 1, 3)
236
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
237
+ )
238
+ if out_ip is not None:
239
+ if self.image_cross_attention_scale_learnable:
240
+ out = out + self.image_cross_attention_scale * out_ip * (
241
+ torch.tanh(self.alpha) + 1
242
+ )
243
+ else:
244
+ out = out + self.image_cross_attention_scale * out_ip
245
+
246
+ return self.to_out(out)
247
+
248
+
249
+ class BasicTransformerBlock(nn.Module):
250
+
251
+ def __init__(
252
+ self,
253
+ dim,
254
+ n_heads,
255
+ d_head,
256
+ dropout=0.0,
257
+ context_dim=None,
258
+ gated_ff=True,
259
+ checkpoint=True,
260
+ disable_self_attn=False,
261
+ attention_cls=None,
262
+ video_length=None,
263
+ image_cross_attention=False,
264
+ image_cross_attention_scale=1.0,
265
+ image_cross_attention_scale_learnable=False,
266
+ text_context_len=77,
267
+ enable_lora=False,
268
+ ):
269
+ super().__init__()
270
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
271
+ self.disable_self_attn = disable_self_attn
272
+ self.attn1 = attn_cls(
273
+ query_dim=dim,
274
+ heads=n_heads,
275
+ dim_head=d_head,
276
+ dropout=dropout,
277
+ context_dim=context_dim if self.disable_self_attn else None,
278
+ )
279
+ self.ff = FeedForward(
280
+ dim, dropout=dropout, glu=gated_ff, enable_lora=enable_lora
281
+ )
282
+ self.attn2 = attn_cls(
283
+ query_dim=dim,
284
+ context_dim=context_dim,
285
+ heads=n_heads,
286
+ dim_head=d_head,
287
+ dropout=dropout,
288
+ video_length=video_length,
289
+ image_cross_attention=image_cross_attention,
290
+ image_cross_attention_scale=image_cross_attention_scale,
291
+ image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
292
+ text_context_len=text_context_len,
293
+ )
294
+ self.image_cross_attention = image_cross_attention
295
+
296
+ self.norm1 = nn.LayerNorm(dim)
297
+ self.norm2 = nn.LayerNorm(dim)
298
+ self.norm3 = nn.LayerNorm(dim)
299
+ self.checkpoint = checkpoint
300
+
301
+ self.enable_lora = enable_lora
302
+
303
+ def forward(self, x, context=None, mask=None, with_lora=False, **kwargs):
304
+ # implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
305
+ # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
306
+ input_tuple = (x,)
307
+ if context is not None:
308
+ input_tuple = (x, context)
309
+ if mask is not None:
310
+ _forward = partial(self._forward, mask=mask, with_lora=with_lora)
311
+ else:
312
+ _forward = partial(self._forward, mask=mask, with_lora=with_lora)
313
+ return gradient_checkpoint(
314
+ _forward, input_tuple, self.parameters(), self.checkpoint
315
+ )
316
+
317
+ def _forward(self, x, context=None, mask=None, with_lora=False):
318
+ x = (
319
+ self.attn1(
320
+ self.norm1(x),
321
+ context=context if self.disable_self_attn else None,
322
+ mask=mask,
323
+ )
324
+ + x
325
+ )
326
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
327
+ x = self.ff(self.norm3(x), with_lora=with_lora) + x
328
+ return x
329
+
330
+
331
+ class SpatialTransformer(nn.Module):
332
+ """
333
+ Transformer block for image-like data in spatial axis.
334
+ First, project the input (aka embedding)
335
+ and reshape to b, t, d.
336
+ Then apply standard transformer action.
337
+ Finally, reshape to image
338
+ NEW: use_linear for more efficiency instead of the 1x1 convs
339
+ """
340
+
341
+ def __init__(
342
+ self,
343
+ in_channels,
344
+ n_heads,
345
+ d_head,
346
+ depth=1,
347
+ dropout=0.0,
348
+ context_dim=None,
349
+ use_checkpoint=True,
350
+ disable_self_attn=False,
351
+ use_linear=False,
352
+ video_length=None,
353
+ image_cross_attention=False,
354
+ image_cross_attention_scale_learnable=False,
355
+ enable_lora=False,
356
+ ):
357
+ super().__init__()
358
+ self.in_channels = in_channels
359
+ inner_dim = n_heads * d_head
360
+ self.norm = torch.nn.GroupNorm(
361
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
362
+ )
363
+ if not use_linear:
364
+ self.proj_in = nn.Conv2d(
365
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
366
+ )
367
+ else:
368
+ self.proj_in = nn.Linear(in_channels, inner_dim)
369
+
370
+ self.enable_lora = enable_lora
371
+
372
+ attention_cls = None
373
+ self.transformer_blocks = nn.ModuleList(
374
+ [
375
+ BasicTransformerBlock(
376
+ inner_dim,
377
+ n_heads,
378
+ d_head,
379
+ dropout=dropout,
380
+ context_dim=context_dim,
381
+ disable_self_attn=disable_self_attn,
382
+ checkpoint=use_checkpoint,
383
+ attention_cls=attention_cls,
384
+ video_length=video_length,
385
+ image_cross_attention=image_cross_attention,
386
+ image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
387
+ enable_lora=self.enable_lora,
388
+ )
389
+ for d in range(depth)
390
+ ]
391
+ )
392
+ if not use_linear:
393
+ self.proj_out = zero_module(
394
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
395
+ )
396
+ else:
397
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
398
+ self.use_linear = use_linear
399
+
400
+ def forward(self, x, context=None, with_lora=False, **kwargs):
401
+ b, c, h, w = x.shape
402
+ x_in = x
403
+ x = self.norm(x)
404
+ if not self.use_linear:
405
+ x = self.proj_in(x)
406
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
407
+ if self.use_linear:
408
+ x = self.proj_in(x)
409
+ for i, block in enumerate(self.transformer_blocks):
410
+ x = block(x, context=context, with_lora=with_lora, **kwargs)
411
+ if self.use_linear:
412
+ x = self.proj_out(x)
413
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
414
+ if not self.use_linear:
415
+ x = self.proj_out(x)
416
+ return x + x_in
417
+
418
+
419
+ class TemporalTransformer(nn.Module):
420
+ """
421
+ Transformer block for image-like data in temporal axis.
422
+ First, reshape to b, t, d.
423
+ Then apply standard transformer action.
424
+ Finally, reshape to image
425
+ """
426
+
427
+ def __init__(
428
+ self,
429
+ in_channels,
430
+ n_heads,
431
+ d_head,
432
+ depth=1,
433
+ dropout=0.0,
434
+ context_dim=None,
435
+ use_checkpoint=True,
436
+ use_linear=False,
437
+ only_self_att=True,
438
+ causal_attention=False,
439
+ causal_block_size=1,
440
+ relative_position=False,
441
+ temporal_length=None,
442
+ use_extra_spatial_temporal_self_attention=False,
443
+ enable_lora=False,
444
+ full_spatial_temporal_attention=False,
445
+ enhance_multi_view_correspondence=False,
446
+ ):
447
+ super().__init__()
448
+ self.only_self_att = only_self_att
449
+ self.relative_position = relative_position
450
+ self.causal_attention = causal_attention
451
+ self.causal_block_size = causal_block_size
452
+
453
+ self.in_channels = in_channels
454
+ inner_dim = n_heads * d_head
455
+ self.norm = torch.nn.GroupNorm(
456
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
457
+ )
458
+ self.proj_in = nn.Conv1d(
459
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
460
+ )
461
+ if not use_linear:
462
+ self.proj_in = nn.Conv1d(
463
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
464
+ )
465
+ else:
466
+ self.proj_in = nn.Linear(in_channels, inner_dim)
467
+
468
+ if relative_position:
469
+ assert temporal_length is not None
470
+ attention_cls = partial(
471
+ CrossAttention, relative_position=True, temporal_length=temporal_length
472
+ )
473
+ else:
474
+ attention_cls = partial(CrossAttention, temporal_length=temporal_length)
475
+ if self.causal_attention:
476
+ assert temporal_length is not None
477
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
478
+
479
+ if self.only_self_att:
480
+ context_dim = None
481
+ self.transformer_blocks = nn.ModuleList(
482
+ [
483
+ BasicTransformerBlock(
484
+ inner_dim,
485
+ n_heads,
486
+ d_head,
487
+ dropout=dropout,
488
+ context_dim=context_dim,
489
+ attention_cls=attention_cls,
490
+ checkpoint=use_checkpoint,
491
+ enable_lora=enable_lora,
492
+ )
493
+ for d in range(depth)
494
+ ]
495
+ )
496
+ if not use_linear:
497
+ self.proj_out = zero_module(
498
+ nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
499
+ )
500
+ else:
501
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
502
+ self.use_linear = use_linear
503
+
504
+ self.use_extra_spatial_temporal_self_attention = (
505
+ use_extra_spatial_temporal_self_attention
506
+ )
507
+ if use_extra_spatial_temporal_self_attention:
508
+ from core.modules.attention_mv import MultiViewSelfAttentionTransformer
509
+
510
+ self.extra_spatial_time_self_attention = MultiViewSelfAttentionTransformer(
511
+ in_channels=in_channels,
512
+ n_heads=n_heads,
513
+ d_head=d_head,
514
+ num_views=temporal_length,
515
+ depth=depth,
516
+ use_linear=use_linear,
517
+ use_checkpoint=use_checkpoint,
518
+ full_spatial_temporal_attention=full_spatial_temporal_attention,
519
+ enhance_multi_view_correspondence=enhance_multi_view_correspondence,
520
+ )
521
+
522
+ def forward(self, x, context=None, with_lora=False, time_steps=None):
523
+ b, c, t, h, w = x.shape
524
+ x_in = x
525
+ x = self.norm(x)
526
+ x = rearrange(x, "b c t h w -> (b h w) c t").contiguous()
527
+ if not self.use_linear:
528
+ x = self.proj_in(x)
529
+ x = rearrange(x, "bhw c t -> bhw t c").contiguous()
530
+ if self.use_linear:
531
+ x = self.proj_in(x)
532
+
533
+ temp_mask = None
534
+ if self.causal_attention:
535
+ # slice the from mask map
536
+ temp_mask = self.mask[:, :t, :t].to(x.device)
537
+
538
+ if temp_mask is not None:
539
+ mask = temp_mask.to(x.device)
540
+ mask = repeat(mask, "l i j -> (l bhw) i j", bhw=b * h * w)
541
+ else:
542
+ mask = None
543
+
544
+ if self.only_self_att:
545
+ # note: if no context is given, cross-attention defaults to self-attention
546
+ for i, block in enumerate(self.transformer_blocks):
547
+ x = block(x, mask=mask, with_lora=with_lora)
548
+ x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
549
+ else:
550
+ x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
551
+ context = rearrange(context, "(b t) l con -> b t l con", t=t).contiguous()
552
+ for i, block in enumerate(self.transformer_blocks):
553
+ # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
554
+ for j in range(b):
555
+ context_j = repeat(
556
+ context[j], "t l con -> (t r) l con", r=(h * w) // t, t=t
557
+ ).contiguous()
558
+ # note: causal mask will not applied in cross-attention case
559
+ x[j] = block(x[j], context=context_j, with_lora=with_lora)
560
+
561
+ if self.use_linear:
562
+ x = self.proj_out(x)
563
+ x = rearrange(x, "b (h w) t c -> b c t h w", h=h, w=w).contiguous()
564
+ if not self.use_linear:
565
+ x = rearrange(x, "b hw t c -> (b hw) c t").contiguous()
566
+ x = self.proj_out(x)
567
+ x = rearrange(x, "(b h w) c t -> b c t h w", b=b, h=h, w=w).contiguous()
568
+
569
+ res = x + x_in
570
+
571
+ if self.use_extra_spatial_temporal_self_attention:
572
+ res = rearrange(res, "b c t h w -> (b t) c h w", b=b, h=h, w=w).contiguous()
573
+ res = self.extra_spatial_time_self_attention(res, time_steps=time_steps)
574
+ res = rearrange(res, "(b t) c h w -> b c t h w", b=b, h=h, w=w).contiguous()
575
+
576
+ return res
577
+
578
+
579
+ class GEGLU(nn.Module):
580
+ def __init__(self, dim_in, dim_out):
581
+ super().__init__()
582
+ self.proj = nn.Linear(dim_in, dim_out * 2)
583
+
584
+ def forward(self, x):
585
+ x, gate = self.proj(x).chunk(2, dim=-1)
586
+ return x * F.gelu(gate)
587
+
588
+
589
+ class FeedForward(nn.Module):
590
+ def __init__(
591
+ self,
592
+ dim,
593
+ dim_out=None,
594
+ mult=4,
595
+ glu=False,
596
+ dropout=0.0,
597
+ enable_lora=False,
598
+ lora_rank=32,
599
+ ):
600
+ super().__init__()
601
+ inner_dim = int(dim * mult)
602
+ dim_out = default(dim_out, dim)
603
+ project_in = (
604
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
605
+ if not glu
606
+ else GEGLU(dim, inner_dim)
607
+ )
608
+
609
+ self.net = nn.Sequential(
610
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
611
+ )
612
+ self.enable_lora = enable_lora
613
+ self.lora_rank = lora_rank
614
+ self.lora_alpha = 16
615
+ if self.enable_lora:
616
+ assert (
617
+ self.lora_rank is not None
618
+ ), "`lora_rank` must be given when `enable_lora` is True."
619
+ assert (
620
+ 0 < self.lora_rank < min(dim, dim_out)
621
+ ), f"`lora_rank` must be range [0, min(inner_dim={inner_dim}, dim_out={dim_out})], but got {self.lora_rank}."
622
+ self.lora_a = nn.Parameter(
623
+ torch.zeros((inner_dim, self.lora_rank), requires_grad=True)
624
+ )
625
+ self.lora_b = nn.Parameter(
626
+ torch.zeros((self.lora_rank, dim_out), requires_grad=True)
627
+ )
628
+ self.scaling = self.lora_alpha / self.lora_rank
629
+
630
+ def forward(self, x, with_lora=False):
631
+ if with_lora:
632
+ projected_x = self.net[1](self.net[0](x))
633
+ lora_x = (
634
+ torch.matmul(projected_x, torch.matmul(self.lora_a, self.lora_b))
635
+ * self.scaling
636
+ )
637
+ original_x = self.net[2](projected_x)
638
+ return original_x + lora_x
639
+ else:
640
+ return self.net(x)
641
+
642
+
643
+ class LinearAttention(nn.Module):
644
+ def __init__(self, dim, heads=4, dim_head=32):
645
+ super().__init__()
646
+ self.heads = heads
647
+ hidden_dim = dim_head * heads
648
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
649
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
650
+
651
+ def forward(self, x):
652
+ b, c, h, w = x.shape
653
+ qkv = self.to_qkv(x)
654
+ q, k, v = rearrange(
655
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
656
+ )
657
+ k = k.softmax(dim=-1)
658
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
659
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
660
+ out = rearrange(
661
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
662
+ )
663
+ return self.to_out(out)
664
+
665
+
666
+ class SpatialSelfAttention(nn.Module):
667
+ def __init__(self, in_channels):
668
+ super().__init__()
669
+ self.in_channels = in_channels
670
+
671
+ self.norm = torch.nn.GroupNorm(
672
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
673
+ )
674
+ self.q = torch.nn.Conv2d(
675
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
676
+ )
677
+ self.k = torch.nn.Conv2d(
678
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
679
+ )
680
+ self.v = torch.nn.Conv2d(
681
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
682
+ )
683
+ self.proj_out = torch.nn.Conv2d(
684
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
685
+ )
686
+
687
+ def forward(self, x):
688
+ h_ = x
689
+ h_ = self.norm(h_)
690
+ q = self.q(h_)
691
+ k = self.k(h_)
692
+ v = self.v(h_)
693
+
694
+ # compute attention
695
+ b, c, h, w = q.shape
696
+ q = rearrange(q, "b c h w -> b (h w) c")
697
+ k = rearrange(k, "b c h w -> b c (h w)")
698
+ w_ = torch.einsum("bij,bjk->bik", q, k)
699
+
700
+ w_ = w_ * (int(c) ** (-0.5))
701
+ w_ = torch.nn.functional.softmax(w_, dim=2)
702
+
703
+ # attend to values
704
+ v = rearrange(v, "b c h w -> b c (h w)")
705
+ w_ = rearrange(w_, "b i j -> b j i")
706
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
707
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
708
+ h_ = self.proj_out(h_)
709
+
710
+ return x + h_
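
One behaviour worth calling out in this file is the LoRA branch of FeedForward: lora_a and lora_b are zero-initialised, so enabling with_lora changes nothing until those weights are trained or loaded from a checkpoint. A small, hypothetical check with made-up sizes:

    import torch

    ff = FeedForward(dim=320, glu=True, enable_lora=True, lora_rank=32)
    x = torch.randn(2, 64, 320)              # (batch, tokens, dim), illustrative sizes
    with torch.no_grad():
        base = ff(x, with_lora=False)        # plain GEGLU feed-forward path
        lora = ff(x, with_lora=True)         # same path plus the (still zero) LoRA update
    assert torch.allclose(base, lora)        # identical while lora_a / lora_b are zero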
core/modules/attention_mv.py ADDED
@@ -0,0 +1,316 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+ from core.common import gradient_checkpoint
7
+
8
+ try:
9
+ import xformers
10
+ import xformers.ops
11
+
12
+ XFORMERS_IS_AVAILBLE = True
13
+ except ImportError:
14
+ XFORMERS_IS_AVAILBLE = False
15
+
16
+ print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")
17
+
18
+
19
+ def get_group_norm_layer(in_channels):
20
+ if in_channels < 32:
21
+ if in_channels % 2 == 0:
22
+ num_groups = in_channels // 2
23
+ else:
24
+ num_groups = in_channels
25
+ else:
26
+ num_groups = 32
27
+ return torch.nn.GroupNorm(
28
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
29
+ )
30
+
31
+
32
+ def zero_module(module):
33
+ """
34
+ Zero out the parameters of a module and return it.
35
+ """
36
+ for p in module.parameters():
37
+ p.detach().zero_()
38
+ return module
39
+
40
+
41
+ def conv_nd(dims, *args, **kwargs):
42
+ """
43
+ Create a 1D, 2D, or 3D convolution module.
44
+ """
45
+ if dims == 1:
46
+ return nn.Conv1d(*args, **kwargs)
47
+ elif dims == 2:
48
+ return nn.Conv2d(*args, **kwargs)
49
+ elif dims == 3:
50
+ return nn.Conv3d(*args, **kwargs)
51
+ raise ValueError(f"unsupported dimensions: {dims}")
52
+
53
+
54
+ class GEGLU(nn.Module):
55
+ def __init__(self, dim_in, dim_out):
56
+ super().__init__()
57
+ self.proj = nn.Linear(dim_in, dim_out * 2)
58
+
59
+ def forward(self, x):
60
+ x, gate = self.proj(x).chunk(2, dim=-1)
61
+ return x * F.gelu(gate)
62
+
63
+
64
+ class FeedForward(nn.Module):
65
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
66
+ super().__init__()
67
+ inner_dim = int(dim * mult)
68
+ if dim_out is None:
69
+ dim_out = dim
70
+ project_in = (
71
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
72
+ if not glu
73
+ else GEGLU(dim, inner_dim)
74
+ )
75
+
76
+ self.net = nn.Sequential(
77
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
78
+ )
79
+
80
+ def forward(self, x):
81
+ return self.net(x)
82
+
83
+
84
+ class SpatialTemporalAttention(nn.Module):
85
+ """Uses xformers to implement efficient epipolar masking for cross-attention between views."""
86
+
87
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
88
+ super().__init__()
89
+ inner_dim = dim_head * heads
90
+ if context_dim is None:
91
+ context_dim = query_dim
92
+
93
+ self.heads = heads
94
+ self.dim_head = dim_head
95
+
96
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
97
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
98
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
99
+
100
+ self.to_out = nn.Sequential(
101
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
102
+ )
103
+ self.attention_op = None
104
+
105
+ def forward(self, x, context=None, enhance_multi_view_correspondence=False):
106
+ q = self.to_q(x)
107
+ if context is None:
108
+ context = x
109
+ k = self.to_k(context)
110
+ v = self.to_v(context)
111
+
112
+ b, _, _ = q.shape
113
+ q, k, v = map(
114
+ lambda t: t.unsqueeze(3)
115
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
116
+ .permute(0, 2, 1, 3)
117
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
118
+ .contiguous(),
119
+ (q, k, v),
120
+ )
121
+
122
+ if enhance_multi_view_correspondence:
123
+ with torch.no_grad():
124
+ normalized_x = torch.nn.functional.normalize(x.detach(), p=2, dim=-1)
125
+ cosine_sim_map = torch.bmm(normalized_x, normalized_x.transpose(-1, -2))
126
+ attn_bias = torch.where(cosine_sim_map > 0.0, 0.0, -1e9).to(
127
+ dtype=q.dtype
128
+ )
129
+ attn_bias = rearrange(
130
+ attn_bias.unsqueeze(1).expand(-1, self.heads, -1, -1),
131
+ "b h d1 d2 -> (b h) d1 d2",
132
+ ).detach()
133
+ else:
134
+ attn_bias = None
135
+
136
+ out = xformers.ops.memory_efficient_attention(
137
+ q, k, v, attn_bias=attn_bias, op=self.attention_op
138
+ )
139
+
140
+ out = (
141
+ out.unsqueeze(0)
142
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
143
+ .permute(0, 2, 1, 3)
144
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
145
+ )
146
+ del q, k, v, attn_bias
147
+ return self.to_out(out)
148
+
149
+
150
+ class MultiViewSelfAttentionTransformerBlock(nn.Module):
151
+
152
+ def __init__(
153
+ self,
154
+ dim,
155
+ n_heads,
156
+ d_head,
157
+ dropout=0.0,
158
+ gated_ff=True,
159
+ use_checkpoint=True,
160
+ full_spatial_temporal_attention=False,
161
+ enhance_multi_view_correspondence=False,
162
+ ):
163
+ super().__init__()
164
+ attn_cls = SpatialTemporalAttention
165
+ # self.self_attention_only = self_attention_only
166
+ self.attn1 = attn_cls(
167
+ query_dim=dim,
168
+ heads=n_heads,
169
+ dim_head=d_head,
170
+ dropout=dropout,
171
+ context_dim=None,
172
+ ) # is a self-attention if not self.disable_self_attn
173
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
174
+
175
+ if enhance_multi_view_correspondence:
176
+ # Zero initialization when MVCorr is enabled.
177
+ zero_module_fn = zero_module
178
+ else:
179
+
180
+ def zero_module_fn(x):
181
+ return x
182
+
183
+ self.attn2 = zero_module_fn(
184
+ attn_cls(
185
+ query_dim=dim,
186
+ heads=n_heads,
187
+ dim_head=d_head,
188
+ dropout=dropout,
189
+ context_dim=None,
190
+ )
191
+ ) # is self-attn if context is none
192
+
193
+ self.norm1 = nn.LayerNorm(dim)
194
+ self.norm2 = nn.LayerNorm(dim)
195
+ self.norm3 = nn.LayerNorm(dim)
196
+ self.use_checkpoint = use_checkpoint
197
+ self.full_spatial_temporal_attention = full_spatial_temporal_attention
198
+ self.enhance_multi_view_correspondence = enhance_multi_view_correspondence
199
+
200
+ def forward(self, x, time_steps=None):
201
+ return gradient_checkpoint(
202
+ self.many_stream_forward, (x, time_steps), None, flag=self.use_checkpoint
203
+ )
204
+
205
+ def many_stream_forward(self, x, time_steps=None):
206
+ n, v, hw = x.shape[:3]
207
+ x = rearrange(x, "n v hw c -> n (v hw) c")
208
+ x = (
209
+ self.attn1(
210
+ self.norm1(x), context=None, enhance_multi_view_correspondence=False
211
+ )
212
+ + x
213
+ )
214
+ if not self.full_spatial_temporal_attention:
215
+ x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
216
+ x = rearrange(x, "n v hw c -> (n v) hw c")
217
+ x = (
218
+ self.attn2(
219
+ self.norm2(x),
220
+ context=None,
221
+ enhance_multi_view_correspondence=self.enhance_multi_view_correspondence
222
+ and hw <= 256,
223
+ )
224
+ + x
225
+ )
226
+ x = self.ff(self.norm3(x)) + x
227
+ if self.full_spatial_temporal_attention:
228
+ x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
229
+ else:
230
+ x = rearrange(x, "(n v) hw c -> n v hw c", v=v)
231
+ return x
232
+
233
+
234
+ class MultiViewSelfAttentionTransformer(nn.Module):
235
+ """Spatial Transformer block with post init to add cross attn."""
236
+
237
+ def __init__(
238
+ self,
239
+ in_channels,
240
+ n_heads,
241
+ d_head,
242
+ num_views,
243
+ depth=1,
244
+ dropout=0.0,
245
+ use_linear=True,
246
+ use_checkpoint=True,
247
+ zero_out_initialization=True,
248
+ full_spatial_temporal_attention=False,
249
+ enhance_multi_view_correspondence=False,
250
+ ):
251
+ super().__init__()
252
+ self.num_views = num_views
253
+ self.in_channels = in_channels
254
+ inner_dim = n_heads * d_head
255
+ self.norm = get_group_norm_layer(in_channels)
256
+ if not use_linear:
257
+ self.proj_in = nn.Conv2d(
258
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
259
+ )
260
+ else:
261
+ self.proj_in = nn.Linear(in_channels, inner_dim)
262
+
263
+ self.transformer_blocks = nn.ModuleList(
264
+ [
265
+ MultiViewSelfAttentionTransformerBlock(
266
+ inner_dim,
267
+ n_heads,
268
+ d_head,
269
+ dropout=dropout,
270
+ use_checkpoint=use_checkpoint,
271
+ full_spatial_temporal_attention=full_spatial_temporal_attention,
272
+ enhance_multi_view_correspondence=enhance_multi_view_correspondence,
273
+ )
274
+ for d in range(depth)
275
+ ]
276
+ )
277
+ self.zero_out_initialization = zero_out_initialization
278
+
279
+ if zero_out_initialization:
280
+ _zero_func = zero_module
281
+ else:
282
+
283
+ def _zero_func(x):
284
+ return x
285
+
286
+ if not use_linear:
287
+ self.proj_out = _zero_func(
288
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
289
+ )
290
+ else:
291
+ self.proj_out = _zero_func(nn.Linear(inner_dim, in_channels))
292
+
293
+ self.use_linear = use_linear
294
+
295
+ def forward(self, x, time_steps=None):
296
+ # x: bt c h w
297
+ _, c, h, w = x.shape
298
+ n_views = self.num_views
299
+ x_in = x
300
+ x = self.norm(x)
301
+ x = rearrange(x, "(n v) c h w -> n v (h w) c", v=n_views)
302
+
303
+ if self.use_linear:
304
+ x = rearrange(x, "n v x c -> (n v) x c")
305
+ x = self.proj_in(x)
306
+ x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
307
+ for i, block in enumerate(self.transformer_blocks):
308
+ x = block(x, time_steps=time_steps)
309
+ if self.use_linear:
310
+ x = rearrange(x, "n v x c -> (n v) x c")
311
+ x = self.proj_out(x)
312
+ x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
313
+
314
+ x = rearrange(x, "n v (h w) c -> (n v) c h w", h=h, w=w).contiguous()
315
+
316
+ return x + x_in
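
A rough usage sketch for the multi-view block above. It assumes xformers is installed and a CUDA device is available (SpatialTemporalAttention always goes through xformers.ops.memory_efficient_attention), that core.common.gradient_checkpoint simply calls the wrapped function when checkpointing is disabled, and the view count and feature sizes are made up for the example:

    import torch

    num_views, channels, h, w = 4, 64, 16, 16
    mv_block = MultiViewSelfAttentionTransformer(
        in_channels=channels, n_heads=4, d_head=16,
        num_views=num_views, use_checkpoint=False,
    ).cuda()
    # Input is laid out as (batch * views, channels, height, width).
    x = torch.randn(2 * num_views, channels, h, w, device="cuda")
    with torch.no_grad():
        out = mv_block(x)
    assert out.shape == x.shape   # residual design: output keeps the input shape
    # With the default zero_out_initialization=True, proj_out is zeroed, so the
    # block acts as an identity mapping at initialization.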
core/modules/attention_temporal.py ADDED
@@ -0,0 +1,1111 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch as th
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ from torch import nn, einsum
8
+
9
+ try:
10
+ import xformers
11
+ import xformers.ops
12
+
13
+ XFORMERS_IS_AVAILBLE = True
14
+ except ImportError:
15
+ XFORMERS_IS_AVAILBLE = False
16
+ from core.common import gradient_checkpoint, exists, default
17
+ from core.basics import conv_nd, zero_module, normalization
18
+
19
+
20
+ class GEGLU(nn.Module):
21
+ def __init__(self, dim_in, dim_out):
22
+ super().__init__()
23
+ self.proj = nn.Linear(dim_in, dim_out * 2)
24
+
25
+ def forward(self, x):
26
+ x, gate = self.proj(x).chunk(2, dim=-1)
27
+ return x * F.gelu(gate)
28
+
29
+
30
+ class FeedForward(nn.Module):
31
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
32
+ super().__init__()
33
+ inner_dim = int(dim * mult)
34
+ dim_out = default(dim_out, dim)
35
+ project_in = (
36
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
37
+ if not glu
38
+ else GEGLU(dim, inner_dim)
39
+ )
40
+
41
+ self.net = nn.Sequential(
42
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
43
+ )
44
+
45
+ def forward(self, x):
46
+ return self.net(x)
47
+
48
+
49
+ def Normalize(in_channels):
50
+ return torch.nn.GroupNorm(
51
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
52
+ )
53
+
54
+
55
+ class RelativePosition(nn.Module):
56
+
57
+ def __init__(self, num_units, max_relative_position):
58
+ super().__init__()
59
+ self.num_units = num_units
60
+ self.max_relative_position = max_relative_position
61
+ self.embeddings_table = nn.Parameter(
62
+ th.Tensor(max_relative_position * 2 + 1, num_units)
63
+ )
64
+ nn.init.xavier_uniform_(self.embeddings_table)
65
+
66
+ def forward(self, length_q, length_k):
67
+ device = self.embeddings_table.device
68
+ range_vec_q = th.arange(length_q, device=device)
69
+ range_vec_k = th.arange(length_k, device=device)
70
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
71
+ distance_mat_clipped = th.clamp(
72
+ distance_mat, -self.max_relative_position, self.max_relative_position
73
+ )
74
+ final_mat = distance_mat_clipped + self.max_relative_position
75
+ final_mat = final_mat.long()
76
+ embeddings = self.embeddings_table[final_mat]
77
+ return embeddings
78
+
79
+
80
+ class TemporalCrossAttention(nn.Module):
81
+ def __init__(
82
+ self,
83
+ query_dim,
84
+ context_dim=None,
85
+ heads=8,
86
+ dim_head=64,
87
+ dropout=0.0,
88
+ # For relative positional representation and image-video joint training.
89
+ temporal_length=None,
90
+ image_length=None, # For image-video joint training.
91
+ # whether to use relative positional representation in temporal attention.
92
+ use_relative_position=False,
93
+ # For image-video joint training.
94
+ img_video_joint_train=False,
95
+ use_tempoal_causal_attn=False,
96
+ bidirectional_causal_attn=False,
97
+ tempoal_attn_type=None,
98
+ joint_train_mode="same_batch",
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ inner_dim = dim_head * heads
103
+ context_dim = default(context_dim, query_dim)
104
+ self.context_dim = context_dim
105
+
106
+ self.scale = dim_head**-0.5
107
+ self.heads = heads
108
+ self.temporal_length = temporal_length
109
+ self.use_relative_position = use_relative_position
110
+ self.img_video_joint_train = img_video_joint_train
111
+ self.bidirectional_causal_attn = bidirectional_causal_attn
112
+ self.joint_train_mode = joint_train_mode
113
+ assert joint_train_mode in ["same_batch", "diff_batch"]
114
+ self.tempoal_attn_type = tempoal_attn_type
115
+
116
+ if bidirectional_causal_attn:
117
+ assert use_tempoal_causal_attn
118
+ if tempoal_attn_type:
119
+ assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
120
+ assert not use_tempoal_causal_attn
121
+ assert not (
122
+ img_video_joint_train and (self.joint_train_mode == "same_batch")
123
+ )
124
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
125
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
126
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
127
+
128
+ assert not (
129
+ img_video_joint_train
130
+ and (self.joint_train_mode == "same_batch")
131
+ and use_tempoal_causal_attn
132
+ )
133
+ if img_video_joint_train:
134
+ if self.joint_train_mode == "same_batch":
135
+ mask = torch.ones(
136
+ [1, temporal_length + image_length, temporal_length + image_length]
137
+ )
138
+ mask[:, temporal_length:, :] = 0
139
+ mask[:, :, temporal_length:] = 0
140
+ self.mask = mask
141
+ else:
142
+ self.mask = None
143
+ elif use_tempoal_causal_attn:
144
+ # normal causal attn
145
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
146
+ elif tempoal_attn_type == "sparse_causal":
147
+ # true indicates keeping
148
+ mask1 = torch.tril(torch.ones([1, temporal_length, temporal_length])).bool()
149
+ # initialize to the same shape as mask1
150
+ mask2 = torch.zeros([1, temporal_length, temporal_length])
151
+ mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
152
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
153
+ )
154
+ mask2 = (1 - mask2).bool() # false indicates masking
155
+ self.mask = mask1 & mask2
156
+ elif tempoal_attn_type == "sparse_causal_first":
157
+ # true indicates keeping
158
+ mask1 = torch.tril(torch.ones([1, temporal_length, temporal_length])).bool()
159
+ mask2 = torch.zeros([1, temporal_length, temporal_length])
160
+ mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
161
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
162
+ )
163
+ mask2 = (1 - mask2).bool() # false indicates masking
164
+ self.mask = mask1 & mask2
165
+ else:
166
+ self.mask = None
167
+
168
+ if use_relative_position:
169
+ assert temporal_length is not None
170
+ self.relative_position_k = RelativePosition(
171
+ num_units=dim_head, max_relative_position=temporal_length
172
+ )
173
+ self.relative_position_v = RelativePosition(
174
+ num_units=dim_head, max_relative_position=temporal_length
175
+ )
176
+
177
+ self.to_out = nn.Sequential(
178
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
179
+ )
180
+
181
+ nn.init.constant_(self.to_q.weight, 0)
182
+ nn.init.constant_(self.to_k.weight, 0)
183
+ nn.init.constant_(self.to_v.weight, 0)
184
+ nn.init.constant_(self.to_out[0].weight, 0)
185
+ nn.init.constant_(self.to_out[0].bias, 0)
186
+
187
+ def forward(self, x, context=None, mask=None):
188
+ nh = self.heads
189
+ out = x
190
+ q = self.to_q(out)
191
+ context = default(context, x)
192
+ k = self.to_k(context)
193
+ v = self.to_v(context)
194
+
195
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
196
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
197
+
198
+ if self.use_relative_position:
199
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
200
+ k2 = self.relative_position_k(len_q, len_k)
201
+ sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
202
+ sim += sim2
203
+ if exists(self.mask):
204
+ if mask is None:
205
+ mask = self.mask.to(sim.device)
206
+ else:
207
+ # combine the stored mask with the externally provided mask
208
+ mask = self.mask.to(sim.device).bool() & mask
209
+ else:
210
+ mask = mask
211
+ if mask is not None:
212
+ max_neg_value = -1e9
213
+ sim = sim + (1 - mask.float()) * max_neg_value # mask==1 keeps a position, mask==0 masks it out
214
+
215
+ attn = sim.softmax(dim=-1)
216
+
217
+ out = einsum("b i j, b j d -> b i d", attn, v)
218
+
219
+ if self.bidirectional_causal_attn:
220
+ mask_reverse = torch.triu(
221
+ torch.ones(
222
+ [1, self.temporal_length, self.temporal_length], device=sim.device
223
+ )
224
+ )
225
+ sim_reverse = sim.float().masked_fill(mask_reverse == 0, -1e9)
226
+ attn_reverse = sim_reverse.softmax(dim=-1)
227
+ out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
228
+ out += out_reverse
229
+
230
+ if self.use_relative_position:
231
+ v2 = self.relative_position_v(len_q, len_v)
232
+ out2 = einsum("b t s, t s d -> b t d", attn, v2)
233
+ out += out2
234
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=nh)
235
+ return self.to_out(out)
236
+
237
+
238
+ class CrossAttention(nn.Module):
239
+ def __init__(
240
+ self,
241
+ query_dim,
242
+ context_dim=None,
243
+ heads=8,
244
+ dim_head=64,
245
+ dropout=0.0,
246
+ sa_shared_kv=False,
247
+ shared_type="only_first",
248
+ **kwargs,
249
+ ):
250
+ super().__init__()
251
+ inner_dim = dim_head * heads
252
+ context_dim = default(context_dim, query_dim)
253
+ self.sa_shared_kv = sa_shared_kv
254
+ assert shared_type in [
255
+ "only_first",
256
+ "all_frames",
257
+ "first_and_prev",
258
+ "only_prev",
259
+ "full",
260
+ "causal",
261
+ "full_qkv",
262
+ ]
263
+ self.shared_type = shared_type
264
+
265
+ self.scale = dim_head**-0.5
266
+ self.heads = heads
267
+ self.dim_head = dim_head
268
+
269
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
270
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
271
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
272
+
273
+ self.to_out = nn.Sequential(
274
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
275
+ )
276
+ if XFORMERS_IS_AVAILBLE:
277
+ self.forward = self.efficient_forward
278
+
279
+ def forward(self, x, context=None, mask=None):
280
+ h = self.heads
281
+ b = x.shape[0]
282
+
283
+ q = self.to_q(x)
284
+ context = default(context, x)
285
+ k = self.to_k(context)
286
+ v = self.to_v(context)
287
+ if self.sa_shared_kv:
288
+ if self.shared_type == "only_first":
289
+ k, v = map(
290
+ lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
291
+ .unsqueeze(0)
292
+ .repeat(b, 1, 1),
293
+ (k, v),
294
+ )
295
+ else:
296
+ raise NotImplementedError
297
+
298
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
299
+
300
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
301
+
302
+ if exists(mask):
303
+ mask = rearrange(mask, "b ... -> b (...)")
304
+ max_neg_value = -torch.finfo(sim.dtype).max
305
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
306
+ sim.masked_fill_(~mask, max_neg_value)
307
+
308
+ # attention, what we cannot get enough of
309
+ attn = sim.softmax(dim=-1)
310
+
311
+ out = einsum("b i j, b j d -> b i d", attn, v)
312
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
313
+ return self.to_out(out)
314
+
315
+ def efficient_forward(self, x, context=None, mask=None):
316
+ q = self.to_q(x)
317
+ context = default(context, x)
318
+ k = self.to_k(context)
319
+ v = self.to_v(context)
320
+
321
+ b, _, _ = q.shape
322
+ q, k, v = map(
323
+ lambda t: t.unsqueeze(3)
324
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
325
+ .permute(0, 2, 1, 3)
326
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
327
+ .contiguous(),
328
+ (q, k, v),
329
+ )
330
+ # actually compute the attention, what we cannot get enough of
331
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
332
+
333
+ if exists(mask):
334
+ raise NotImplementedError
335
+ out = (
336
+ out.unsqueeze(0)
337
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
338
+ .permute(0, 2, 1, 3)
339
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
340
+ )
341
+ return self.to_out(out)
342
+
343
+
344
+ class VideoSpatialCrossAttention(CrossAttention):
345
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
346
+ super().__init__(query_dim, context_dim, heads, dim_head, dropout)
347
+
348
+ def forward(self, x, context=None, mask=None):
349
+ b, c, t, h, w = x.shape
350
+ if context is not None:
351
+ context = context.repeat(t, 1, 1)
352
+ x = super().forward(spatial_attn_reshape(x), context=context) + x
353
+ return spatial_attn_reshape_back(x, b, h)
354
+
355
+
356
+ class BasicTransformerBlockST(nn.Module):
357
+ def __init__(
358
+ self,
359
+ # Spatial Stuff
360
+ dim,
361
+ n_heads,
362
+ d_head,
363
+ dropout=0.0,
364
+ context_dim=None,
365
+ gated_ff=True,
366
+ checkpoint=True,
367
+ # Temporal Stuff
368
+ temporal_length=None,
369
+ image_length=None,
370
+ use_relative_position=True,
371
+ img_video_joint_train=False,
372
+ cross_attn_on_tempoal=False,
373
+ temporal_crossattn_type="selfattn",
374
+ order="stst",
375
+ temporalcrossfirst=False,
376
+ temporal_context_dim=None,
377
+ split_stcontext=False,
378
+ local_spatial_temporal_attn=False,
379
+ window_size=2,
380
+ **kwargs,
381
+ ):
382
+ super().__init__()
383
+ # Self attention
384
+ self.attn1 = CrossAttention(
385
+ query_dim=dim,
386
+ heads=n_heads,
387
+ dim_head=d_head,
388
+ dropout=dropout,
389
+ **kwargs,
390
+ )
391
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
392
+ # cross attention if context is not None
393
+ self.attn2 = CrossAttention(
394
+ query_dim=dim,
395
+ context_dim=context_dim,
396
+ heads=n_heads,
397
+ dim_head=d_head,
398
+ dropout=dropout,
399
+ **kwargs,
400
+ )
401
+ self.norm1 = nn.LayerNorm(dim)
402
+ self.norm2 = nn.LayerNorm(dim)
403
+ self.norm3 = nn.LayerNorm(dim)
404
+ self.checkpoint = checkpoint
405
+ self.order = order
406
+ assert self.order in ["stst", "sstt", "st_parallel"]
407
+ self.temporalcrossfirst = temporalcrossfirst
408
+ self.split_stcontext = split_stcontext
409
+ self.local_spatial_temporal_attn = local_spatial_temporal_attn
410
+ if self.local_spatial_temporal_attn:
411
+ assert self.order == "stst"
412
+ assert self.order == "stst"
413
+ self.window_size = window_size
414
+ if not split_stcontext:
415
+ temporal_context_dim = context_dim
416
+ # Temporal attention
417
+ assert temporal_crossattn_type in ["selfattn", "crossattn", "skip"]
418
+ self.temporal_crossattn_type = temporal_crossattn_type
419
+ self.attn1_tmp = TemporalCrossAttention(
420
+ query_dim=dim,
421
+ heads=n_heads,
422
+ dim_head=d_head,
423
+ dropout=dropout,
424
+ temporal_length=temporal_length,
425
+ image_length=image_length,
426
+ use_relative_position=use_relative_position,
427
+ img_video_joint_train=img_video_joint_train,
428
+ **kwargs,
429
+ )
430
+ self.attn2_tmp = TemporalCrossAttention(
431
+ query_dim=dim,
432
+ heads=n_heads,
433
+ dim_head=d_head,
434
+ dropout=dropout,
435
+ # cross attn
436
+ context_dim=(
437
+ temporal_context_dim if temporal_crossattn_type == "crossattn" else None
438
+ ),
439
+ # temporal attn
440
+ temporal_length=temporal_length,
441
+ image_length=image_length,
442
+ use_relative_position=use_relative_position,
443
+ img_video_joint_train=img_video_joint_train,
444
+ **kwargs,
445
+ )
446
+ self.norm4 = nn.LayerNorm(dim)
447
+ self.norm5 = nn.LayerNorm(dim)
448
+
449
+ def forward(
450
+ self,
451
+ x,
452
+ context=None,
453
+ temporal_context=None,
454
+ no_temporal_attn=None,
455
+ attn_mask=None,
456
+ **kwargs,
457
+ ):
458
+ if not self.split_stcontext and context is not None:
459
+ # spatial and temporal cross attention share the same context vector
460
+ temporal_context = context.detach().clone()
461
+
462
+ if context is None and temporal_context is None:
463
+ # self-attention models
464
+ if no_temporal_attn:
465
+ raise NotImplementedError
466
+ return gradient_checkpoint(
467
+ self._forward_nocontext, (x,), self.parameters(), self.checkpoint
468
+ )
469
+ else:
470
+ # cross-attention models
471
+ if no_temporal_attn:
472
+ forward_func = self._forward_no_temporal_attn
473
+ else:
474
+ forward_func = self._forward
475
+ inputs = (
476
+ (x, context, temporal_context)
477
+ if temporal_context is not None
478
+ else (x, context)
479
+ )
480
+ return gradient_checkpoint(
481
+ forward_func, inputs, self.parameters(), self.checkpoint
482
+ )
483
+
484
+ def _forward(
485
+ self,
486
+ x,
487
+ context=None,
488
+ temporal_context=None,
489
+ mask=None,
490
+ no_temporal_attn=None,
491
+ ):
492
+ assert x.dim() == 5, f"x shape = {x.shape}"
493
+ b, c, t, h, w = x.shape
494
+
495
+ if self.order in ["stst", "sstt"]:
496
+ x = self._st_cross_attn(
497
+ x,
498
+ context,
499
+ temporal_context=temporal_context,
500
+ order=self.order,
501
+ mask=mask,
502
+ ) # no_temporal_attn=no_temporal_attn,
503
+ elif self.order == "st_parallel":
504
+ x = self._st_cross_attn_parallel(
505
+ x,
506
+ context,
507
+ temporal_context=temporal_context,
508
+ order=self.order,
509
+ ) # no_temporal_attn=no_temporal_attn,
510
+ else:
511
+ raise NotImplementedError
512
+
513
+ x = self.ff(self.norm3(x)) + x
514
+ if (no_temporal_attn is None) or (not no_temporal_attn):
515
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
516
+ elif no_temporal_attn:
517
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
518
+ return x
519
+
520
+ def _forward_no_temporal_attn(
521
+ self,
522
+ x,
523
+ context=None,
524
+ temporal_context=None,
525
+ ):
526
+ assert x.dim() == 5, f"x shape = {x.shape}"
527
+ b, c, t, h, w = x.shape
528
+
529
+ if self.order in ["stst", "sstt"]:
530
+ mask = torch.zeros([1, t, t], device=x.device).bool()
531
+ x = self._st_cross_attn(
532
+ x,
533
+ context,
534
+ temporal_context=temporal_context,
535
+ order=self.order,
536
+ mask=mask,
537
+ )
538
+ elif self.order == "st_parallel":
539
+ x = self._st_cross_attn_parallel(
540
+ x,
541
+ context,
542
+ temporal_context=temporal_context,
543
+ order=self.order,
544
+ no_temporal_attn=True,
545
+ )
546
+ else:
547
+ raise NotImplementedError
548
+
549
+ x = self.ff(self.norm3(x)) + x
550
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
551
+ return x
552
+
553
+ def _forward_nocontext(self, x, no_temporal_attn=None):
554
+ assert x.dim() == 5, f"x shape = {x.shape}"
555
+ b, c, t, h, w = x.shape
556
+
557
+ if self.order in ["stst", "sstt"]:
558
+ x = self._st_cross_attn(
559
+ x, order=self.order, no_temporal_attn=no_temporal_attn
560
+ )
561
+ elif self.order == "st_parallel":
562
+ x = self._st_cross_attn_parallel(
563
+ x, order=self.order, no_temporal_attn=no_temporal_attn
564
+ )
565
+ else:
566
+ raise NotImplementedError
567
+
568
+ x = self.ff(self.norm3(x)) + x
569
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
570
+
571
+ return x
572
+
573
+ def _st_cross_attn(
574
+ self, x, context=None, temporal_context=None, order="stst", mask=None, no_temporal_attn=None
575
+ ):
576
+ b, c, t, h, w = x.shape
577
+ if order == "stst":
578
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
579
+ x = self.attn1(self.norm1(x)) + x
580
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
581
+ if self.local_spatial_temporal_attn:
582
+ x = local_spatial_temporal_attn_reshape(x, window_size=self.window_size)
583
+ else:
584
+ x = rearrange(x, "b c t h w -> (b h w) t c")
585
+ x = self.attn1_tmp(self.norm4(x), mask=mask) + x
586
+
587
+ if self.local_spatial_temporal_attn:
588
+ x = local_spatial_temporal_attn_reshape_back(
589
+ x, window_size=self.window_size, b=b, h=h, w=w, t=t
590
+ )
591
+ else:
592
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
593
+
594
+ # spatial cross attention
595
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
596
+ if context is not None:
597
+ if context.shape[0] == t: # context is already per-frame (e.g., no_temporal_attn or image captions)
598
+ context_ = context
599
+ else:
600
+ context_ = []
601
+ for i in range(context.shape[0]):
602
+ context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
603
+ context_ = torch.cat(context_, dim=0)
604
+ else:
605
+ context_ = None
606
+ x = self.attn2(self.norm2(x), context=context_) + x
607
+
608
+ # temporal cross attention
609
+ # if (no_temporal_attn is None) or (not no_temporal_attn):
610
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
611
+ x = rearrange(x, "b c t h w -> (b h w) t c")
612
+ if self.temporal_crossattn_type == "crossattn":
613
+ # temporal cross attention
614
+ if temporal_context is not None:
615
+ # print(f'STATTN context={context.shape}, temporal_context={temporal_context.shape}')
616
+ temporal_context = torch.cat(
617
+ [context, temporal_context], dim=1
618
+ ) # blc
619
+ # print(f'STATTN after concat temporal_context={temporal_context.shape}')
620
+ temporal_context = temporal_context.repeat(h * w, 1, 1)
621
+ # print(f'after repeat temporal_context={temporal_context.shape}')
622
+ else:
623
+ temporal_context = context[0:1, ...].repeat(h * w, 1, 1)
624
+ # print(f'STATTN after concat x={x.shape}')
625
+ x = (
626
+ self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask)
627
+ + x
628
+ )
629
+ elif self.temporal_crossattn_type == "selfattn":
630
+ # temporal self attention
631
+ x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
632
+ elif self.temporal_crossattn_type == "skip":
633
+ # no temporal cross and self attention
634
+ pass
635
+ else:
636
+ raise NotImplementedError
637
+
638
+ elif order == "sstt":
639
+ # spatial self attention
640
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
641
+ x = self.attn1(self.norm1(x)) + x
642
+
643
+ # spatial cross attention
644
+ context_ = context.repeat(t, 1, 1) if context is not None else None
645
+ x = self.attn2(self.norm2(x), context=context_) + x
646
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
647
+
648
+ if (no_temporal_attn is None) or (not no_temporal_attn):
649
+ if self.temporalcrossfirst:
650
+ # temporal cross attention
651
+ if self.temporal_crossattn_type == "crossattn":
652
+ # if temporal_context is not None:
653
+ temporal_context = context.repeat(h * w, 1, 1)
654
+ x = (
655
+ self.attn2_tmp(
656
+ self.norm5(x), context=temporal_context, mask=mask
657
+ )
658
+ + x
659
+ )
660
+ elif self.temporal_crossattn_type == "selfattn":
661
+ x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
662
+ elif self.temporal_crossattn_type == "skip":
663
+ pass
664
+ else:
665
+ raise NotImplementedError
666
+ # temporal self attention
667
+ x = rearrange(x, "b c t h w -> (b h w) t c")
668
+ x = self.attn1_tmp(self.norm4(x), mask=mask) + x
669
+ else:
670
+ # temporal self attention
671
+ x = rearrange(x, "b c t h w -> (b h w) t c")
672
+ x = self.attn1_tmp(self.norm4(x), mask=mask) + x
673
+ # temporal cross attention
674
+ if self.temporal_crossattn_type == "crossattn":
675
+ if temporal_context is not None:
676
+ temporal_context = context.repeat(h * w, 1, 1)
677
+ x = (
678
+ self.attn2_tmp(
679
+ self.norm5(x), context=temporal_context, mask=mask
680
+ )
681
+ + x
682
+ )
683
+ elif self.temporal_crossattn_type == "selfattn":
684
+ x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
685
+ elif self.temporal_crossattn_type == "skip":
686
+ pass
687
+ else:
688
+ raise NotImplementedError
689
+ else:
690
+ raise NotImplementedError
691
+
692
+ return x
693
+
694
+ def _st_cross_attn_parallel(
695
+ self, x, context=None, temporal_context=None, order="sst", no_temporal_attn=None
696
+ ):
697
+ """order: x -> Self Attn -> Cross Attn -> attn_s
698
+ x -> Temp Self Attn -> attn_t
699
+ x' = x + attn_s + attn_t
700
+ """
701
+ if no_temporal_attn is not None:
702
+ raise NotImplementedError
703
+
704
+ B, C, T, H, W = x.shape
705
+ # spatial self attention
706
+ h = x
707
+ h = rearrange(h, "b c t h w -> (b t) (h w) c")
708
+ h = self.attn1(self.norm1(h)) + h
709
+ # spatial cross
710
+ # context_ = context.repeat(T, 1, 1) if context is not None else None
711
+ if context is not None:
712
+ context_ = []
713
+ for i in range(context.shape[0]):
714
+ context_.append(context[i].unsqueeze(0).repeat(T, 1, 1))
715
+ context_ = torch.cat(context_, dim=0)
716
+ else:
717
+ context_ = None
718
+
719
+ h = self.attn2(self.norm2(h), context=context_) + h
720
+ h = rearrange(h, "(b t) (h w) c -> b c t h w", b=B, h=H)
721
+
722
+ # temporal self
723
+ h2 = x
724
+ h2 = rearrange(h2, "b c t h w -> (b h w) t c")
725
+ h2 = self.attn1_tmp(self.norm4(h2)) # + h2
726
+ h2 = rearrange(h2, "(b h w) t c -> b c t h w", b=B, h=H, w=W)
727
+ out = h + h2
728
+ return rearrange(out, "b c t h w -> (b h w) t c")
729
+
730
+
731
+ def spatial_attn_reshape(x):
732
+ return rearrange(x, "b c t h w -> (b t) (h w) c")
733
+
734
+
735
+ def spatial_attn_reshape_back(x, b, h):
736
+ return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
737
+
738
+
739
+ def temporal_attn_reshape(x):
740
+ return rearrange(x, "b c t h w -> (b h w) t c")
741
+
742
+
743
+ def temporal_attn_reshape_back(x, b, h, w):
744
+ return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
745
+
746
+
747
+ def local_spatial_temporal_attn_reshape(x, window_size):
748
+ B, C, T, H, W = x.shape
749
+ NH = H // window_size
750
+ NW = W // window_size
751
+ # x = x.view(B, C, T, NH, window_size, NW, window_size)
752
+ # tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
753
+ # tokens = tokens.view(-1, window_size, window_size, C)
754
+ x = rearrange(
755
+ x,
756
+ "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
757
+ nh=NH,
758
+ nw=NW,
759
+ wh=window_size,
760
+ # # B, C, T, NH, NW, window_size, window_size
761
+ ww=window_size,
762
+ ).contiguous()
763
+ # (B, NH, NW) (T, window_size, window_size) C
764
+ x = rearrange(x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c")
765
+ return x
766
+
767
+
768
+ def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
769
+ B, L, C = x.shape
770
+ NH = h // window_size
771
+ NW = w // window_size
772
+ x = rearrange(
773
+ x,
774
+ "(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
775
+ b=b,
776
+ nh=NH,
777
+ nw=NW,
778
+ t=t,
779
+ wh=window_size,
780
+ ww=window_size,
781
+ )
782
+ x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
783
+ return x
784
+
785
+
786
+ class SpatialTemporalTransformer(nn.Module):
787
+ """
788
+ Transformer block for video-like data (5D tensor).
789
+ First, project the input (aka embedding) with NO reshape.
790
+ Then apply standard transformer action.
791
+ The 5D -> 3D reshape operation will be done in the specific attention module.
792
+ """
793
+
794
+ def __init__(
795
+ self,
796
+ in_channels,
797
+ n_heads,
798
+ d_head,
799
+ depth=1,
800
+ dropout=0.0,
801
+ context_dim=None,
802
+ # Temporal stuff
803
+ temporal_length=None,
804
+ image_length=None,
805
+ use_relative_position=True,
806
+ img_video_joint_train=False,
807
+ cross_attn_on_tempoal=False,
808
+ temporal_crossattn_type=False,
809
+ order="stst",
810
+ temporalcrossfirst=False,
811
+ split_stcontext=False,
812
+ temporal_context_dim=None,
813
+ **kwargs,
814
+ ):
815
+ super().__init__()
816
+
817
+ self.in_channels = in_channels
818
+ inner_dim = n_heads * d_head
819
+
820
+ self.norm = Normalize(in_channels)
821
+ self.proj_in = nn.Conv3d(
822
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
823
+ )
824
+
825
+ self.transformer_blocks = nn.ModuleList(
826
+ [
827
+ BasicTransformerBlockST(
828
+ inner_dim,
829
+ n_heads,
830
+ d_head,
831
+ dropout=dropout,
832
+ # cross attn
833
+ context_dim=context_dim,
834
+ # temporal attn
835
+ temporal_length=temporal_length,
836
+ image_length=image_length,
837
+ use_relative_position=use_relative_position,
838
+ img_video_joint_train=img_video_joint_train,
839
+ temporal_crossattn_type=temporal_crossattn_type,
840
+ order=order,
841
+ temporalcrossfirst=temporalcrossfirst,
842
+ split_stcontext=split_stcontext,
843
+ temporal_context_dim=temporal_context_dim,
844
+ **kwargs,
845
+ )
846
+ for d in range(depth)
847
+ ]
848
+ )
849
+
850
+ self.proj_out = zero_module(
851
+ nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
852
+ )
853
+
854
+ def forward(self, x, context=None, temporal_context=None, **kwargs):
855
+ # note: if no context is given, cross-attention defaults to self-attention
856
+ assert x.dim() == 5, f"x shape = {x.shape}"
857
+ b, c, t, h, w = x.shape
858
+ x_in = x
859
+
860
+ x = self.norm(x)
861
+ x = self.proj_in(x)
862
+
863
+ for block in self.transformer_blocks:
864
+ x = block(x, context=context, temporal_context=temporal_context, **kwargs)
865
+
866
+ x = self.proj_out(x)
867
+ return x + x_in
868
+
869
+
870
+ class STAttentionBlock2(nn.Module):
871
+ def __init__(
872
+ self,
873
+ channels,
874
+ num_heads=1,
875
+ num_head_channels=-1,
876
+ use_checkpoint=False, # not used, only used in ResBlock
877
+ use_new_attention_order=False, # QKVAttention or QKVAttentionLegacy
878
+ temporal_length=16, # used in relative positional representation.
879
+ image_length=8, # used for image-video joint training.
880
+ # whether use relative positional representation in temporal attention.
881
+ use_relative_position=False,
882
+ img_video_joint_train=False,
883
+ # norm_type="groupnorm",
884
+ attn_norm_type="group",
885
+ use_tempoal_causal_attn=False,
886
+ ):
887
+ """
888
+ version 1: guided_diffusion implemented version
889
+ version 2: remove args input argument
890
+ """
891
+ super().__init__()
892
+
893
+ if num_head_channels == -1:
894
+ self.num_heads = num_heads
895
+ else:
896
+ assert (
897
+ channels % num_head_channels == 0
898
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
899
+ self.num_heads = channels // num_head_channels
900
+ self.use_checkpoint = use_checkpoint
901
+
902
+ self.temporal_length = temporal_length
903
+ self.image_length = image_length
904
+ self.use_relative_position = use_relative_position
905
+ self.img_video_joint_train = img_video_joint_train
906
+ self.attn_norm_type = attn_norm_type
907
+ assert self.attn_norm_type in ["group", "no_norm"]
908
+ self.use_tempoal_causal_attn = use_tempoal_causal_attn
909
+
910
+ if self.attn_norm_type == "group":
911
+ self.norm_s = normalization(channels)
912
+ self.norm_t = normalization(channels)
913
+
914
+ self.qkv_s = conv_nd(1, channels, channels * 3, 1)
915
+ self.qkv_t = conv_nd(1, channels, channels * 3, 1)
916
+
917
+ if self.img_video_joint_train:
918
+ mask = th.ones(
919
+ [1, temporal_length + image_length, temporal_length + image_length]
920
+ )
921
+ mask[:, temporal_length:, :] = 0
922
+ mask[:, :, temporal_length:] = 0
923
+ self.register_buffer("mask", mask)
924
+ else:
925
+ self.mask = None
926
+
927
+ if use_new_attention_order:
928
+ # split qkv before split heads
929
+ self.attention_s = QKVAttention(self.num_heads)
930
+ self.attention_t = QKVAttention(self.num_heads)
931
+ else:
932
+ # split heads before split qkv
933
+ self.attention_s = QKVAttentionLegacy(self.num_heads)
934
+ self.attention_t = QKVAttentionLegacy(self.num_heads)
935
+
936
+ if use_relative_position:
937
+ self.relative_position_k = RelativePosition(
938
+ num_units=channels // self.num_heads,
939
+ max_relative_position=temporal_length,
940
+ )
941
+ self.relative_position_v = RelativePosition(
942
+ num_units=channels // self.num_heads,
943
+ max_relative_position=temporal_length,
944
+ )
945
+
946
+ self.proj_out_s = zero_module(
947
+ # conv_dim, in_channels, out_channels, kernel_size
948
+ conv_nd(1, channels, channels, 1)
949
+ )
950
+ self.proj_out_t = zero_module(
951
+ # conv_dim, in_channels, out_channels, kernel_size
952
+ conv_nd(1, channels, channels, 1)
953
+ )
954
+
955
+ def forward(self, x, mask=None):
956
+ b, c, t, h, w = x.shape
957
+
958
+ # spatial
959
+ out = rearrange(x, "b c t h w -> (b t) c (h w)")
960
+ if self.attn_norm_type == "no_norm":
961
+ qkv = self.qkv_s(out)
962
+ else:
963
+ qkv = self.qkv_s(self.norm_s(out))
964
+ out = self.attention_s(qkv)
965
+ out = self.proj_out_s(out)
966
+ out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
967
+ x += out
968
+
969
+ # temporal
970
+ out = rearrange(x, "b c t h w -> (b h w) c t")
971
+ if self.attn_norm_type == "no_norm":
972
+ qkv = self.qkv_t(out)
973
+ else:
974
+ qkv = self.qkv_t(self.norm_t(out))
975
+
976
+ # relative positional embedding
977
+ if self.use_relative_position:
978
+ len_q = qkv.size()[-1]
979
+ len_k, len_v = len_q, len_q
980
+ k_rp = self.relative_position_k(len_q, len_k)
981
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
982
+ out = self.attention_t(
983
+ qkv,
984
+ rp=(k_rp, v_rp),
985
+ mask=self.mask,
986
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
987
+ )
988
+ else:
989
+ out = self.attention_t(
990
+ qkv,
991
+ rp=None,
992
+ mask=self.mask,
993
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
994
+ )
995
+
996
+ out = self.proj_out_t(out)
997
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
998
+
999
+ return x + out
1000
+
1001
+
1002
+ class QKVAttentionLegacy(nn.Module):
1003
+ """
1004
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
1005
+ """
1006
+
1007
+ def __init__(self, n_heads):
1008
+ super().__init__()
1009
+ self.n_heads = n_heads
1010
+
1011
+ def forward(self, qkv, rp=None, mask=None):
1012
+ """
1013
+ Apply QKV attention.
1014
+
1015
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
1016
+ :return: an [N x (H * C) x T] tensor after attention.
1017
+ """
1018
+ if rp is not None or mask is not None:
1019
+ raise NotImplementedError
1020
+ bs, width, length = qkv.shape
1021
+ assert width % (3 * self.n_heads) == 0
1022
+ ch = width // (3 * self.n_heads)
1023
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
1024
+ scale = 1 / math.sqrt(math.sqrt(ch))
1025
+ weight = th.einsum(
1026
+ "bct,bcs->bts", q * scale, k * scale
1027
+ ) # More stable with f16 than dividing afterwards
1028
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
1029
+ a = th.einsum("bts,bcs->bct", weight, v)
1030
+ return a.reshape(bs, -1, length)
1031
+
1032
+ @staticmethod
1033
+ def count_flops(model, _x, y):
1034
+ return count_flops_attn(model, _x, y)
1035
+
1036
+
1037
+ class QKVAttention(nn.Module):
1038
+ """
1039
+ A module which performs QKV attention and splits in a different order.
1040
+ """
1041
+
1042
+ def __init__(self, n_heads):
1043
+ super().__init__()
1044
+ self.n_heads = n_heads
1045
+
1046
+ def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
1047
+ """
1048
+ Apply QKV attention.
1049
+
1050
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
1051
+ :return: an [N x (H * C) x T] tensor after attention.
1052
+ """
1053
+ bs, width, length = qkv.shape
1054
+ assert width % (3 * self.n_heads) == 0
1055
+ ch = width // (3 * self.n_heads)
1056
+ # print('qkv', qkv.size())
1057
+ q, k, v = qkv.chunk(3, dim=1)
1058
+ scale = 1 / math.sqrt(math.sqrt(ch))
1059
+ # print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
1060
+
1061
+ weight = th.einsum(
1062
+ "bct,bcs->bts",
1063
+ (q * scale).view(bs * self.n_heads, ch, length),
1064
+ (k * scale).view(bs * self.n_heads, ch, length),
1065
+ ) # More stable with f16 than dividing afterwards
1066
+ # weight:[b,t,s] b=bs*n_heads*T
1067
+
1068
+ if rp is not None:
1069
+ k_rp, v_rp = rp # [length, length, head_dim] [8, 8, 48]
1070
+ weight2 = th.einsum(
1071
+ "bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp
1072
+ )
1073
+ weight += weight2
1074
+
1075
+ if use_tempoal_causal_attn:
1076
+ # weight = torch.tril(weight)
1077
+ assert mask is None, "Not implemented for merging two masks!"
1078
+ mask = torch.tril(torch.ones(weight.shape, device=weight.device))
1079
+ else:
1080
+ if mask is not None: # only keep upper-left matrix
1081
+ # process mask
1082
+ c, t, _ = weight.shape
1083
+
1084
+ if mask.shape[-1] > t:
1085
+ mask = mask[:, :t, :t]
1086
+ elif mask.shape[-1] < t: # pad ones
1087
+ mask_ = th.zeros([c, t, t]).to(mask.device)
1088
+ t_ = mask.shape[-1]
1089
+ mask_[:, :t_, :t_] = mask
1090
+ mask = mask_
1091
+ else:
1092
+ assert (
1093
+ weight.shape[-1] == mask.shape[-1]
1094
+ ), f"weight={weight.shape}, mask={mask.shape}"
1095
+
1096
+ if mask is not None:
1097
+ INF = -1e8 # float('-inf')
1098
+ weight = weight.float().masked_fill(mask == 0, INF)
1099
+
1100
+ weight = F.softmax(weight.float(), dim=-1).type(
1101
+ weight.dtype
1102
+ ) # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1103
+ # weight = F.softmax(weight, dim=-1)#[256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1104
+ # [256, 48, 8] [b, head_dim, t]
1105
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
1106
+
1107
+ if rp is not None:
1108
+ a2 = th.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) # btc->bct
1109
+ a += a2
1110
+
1111
+ return a.reshape(bs, -1, length)
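The blocks in `attention_temporal.py` above rely on two token layouts: spatial attention folds frames into the batch axis (`b c t h w -> (b t) (h w) c`), while temporal attention folds spatial positions into the batch axis (`b c t h w -> (b h w) t c`). Below is a minimal, illustrative sketch (not part of this commit) that checks both layouts round-trip with `einops`; the tensor sizes are arbitrary.

```python
# Illustrative sketch only: verifies that the spatial/temporal token layouts
# used by BasicTransformerBlockST round-trip exactly.
import torch
from einops import rearrange

b, c, t, h, w = 2, 8, 4, 16, 16
x = torch.randn(b, c, t, h, w)

# Spatial attention layout: each frame becomes a batch entry of (h*w) tokens.
xs = rearrange(x, "b c t h w -> (b t) (h w) c")
assert xs.shape == (b * t, h * w, c)

# Temporal attention layout: each spatial position becomes a batch entry of t tokens.
xt = rearrange(x, "b c t h w -> (b h w) t c")
assert xt.shape == (b * h * w, t, c)

# Both layouts invert back to the original 5-D video tensor.
assert torch.equal(rearrange(xs, "(b t) (h w) c -> b c t h w", b=b, h=h), x)
assert torch.equal(rearrange(xt, "(b h w) t c -> b c t h w", b=b, h=h, w=w), x)
```

This is also why `_forward` ends with `(b h w) t c -> b c t h w`: in the `stst` path the final (temporal) attention leaves tokens in the `(b h w) t c` layout, so the feed-forward runs on that layout before the block reshapes back to 5-D.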
core/modules/encoders/__init__.py ADDED
File without changes
core/modules/encoders/adapter.py ADDED
@@ -0,0 +1,485 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from collections import OrderedDict
4
+ from extralibs.cond_api import ExtraCondition
5
+ from core.modules.x_transformer import FixedPositionalEmbedding
6
+ from core.basics import zero_module, conv_nd, avg_pool_nd
7
+
8
+
9
+ class Downsample(nn.Module):
10
+ """
11
+ A downsampling layer with an optional convolution.
12
+ :param channels: channels in the inputs and outputs.
13
+ :param use_conv: a bool determining if a convolution is applied.
14
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
15
+ downsampling occurs in the inner-two dimensions.
16
+ """
17
+
18
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
19
+ super().__init__()
20
+ self.channels = channels
21
+ self.out_channels = out_channels or channels
22
+ self.use_conv = use_conv
23
+ self.dims = dims
24
+ stride = 2 if dims != 3 else (1, 2, 2)
25
+ if use_conv:
26
+ self.op = conv_nd(
27
+ dims,
28
+ self.channels,
29
+ self.out_channels,
30
+ 3,
31
+ stride=stride,
32
+ padding=padding,
33
+ )
34
+ else:
35
+ assert self.channels == self.out_channels
36
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
37
+
38
+ def forward(self, x):
39
+ assert x.shape[1] == self.channels
40
+ return self.op(x)
41
+
42
+
43
+ class ResnetBlock(nn.Module):
44
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
45
+ super().__init__()
46
+ ps = ksize // 2
47
+ if in_c != out_c or sk == False:
48
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
49
+ else:
50
+ self.in_conv = None
51
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
52
+ self.act = nn.ReLU()
53
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
54
+ if sk == False:
55
+ self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
56
+ else:
57
+ self.skep = None
58
+
59
+ self.down = down
60
+ if self.down == True:
61
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
62
+
63
+ def forward(self, x):
64
+ if self.down == True:
65
+ x = self.down_opt(x)
66
+ if self.in_conv is not None:
67
+ x = self.in_conv(x)
68
+
69
+ h = self.block1(x)
70
+ h = self.act(h)
71
+ h = self.block2(h)
72
+ if self.skep is not None:
73
+ return h + self.skep(x)
74
+ else:
75
+ return h + x
76
+
77
+
78
+ class Adapter(nn.Module):
79
+ def __init__(
80
+ self,
81
+ channels=[320, 640, 1280, 1280],
82
+ nums_rb=3,
83
+ cin=64,
84
+ ksize=3,
85
+ sk=True,
86
+ use_conv=True,
87
+ stage_downscale=True,
88
+ use_identity=False,
89
+ ):
90
+ super(Adapter, self).__init__()
91
+ if use_identity:
92
+ self.inlayer = nn.Identity()
93
+ else:
94
+ self.inlayer = nn.PixelUnshuffle(8)
95
+
96
+ self.channels = channels
97
+ self.nums_rb = nums_rb
98
+ self.body = []
99
+ for i in range(len(channels)):
100
+ for j in range(nums_rb):
101
+ if (i != 0) and (j == 0):
102
+ self.body.append(
103
+ ResnetBlock(
104
+ channels[i - 1],
105
+ channels[i],
106
+ down=stage_downscale,
107
+ ksize=ksize,
108
+ sk=sk,
109
+ use_conv=use_conv,
110
+ )
111
+ )
112
+ else:
113
+ self.body.append(
114
+ ResnetBlock(
115
+ channels[i],
116
+ channels[i],
117
+ down=False,
118
+ ksize=ksize,
119
+ sk=sk,
120
+ use_conv=use_conv,
121
+ )
122
+ )
123
+ self.body = nn.ModuleList(self.body)
124
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
125
+
126
+ def forward(self, x):
127
+ # unshuffle
128
+ x = self.inlayer(x)
129
+ # extract features
130
+ features = []
131
+ x = self.conv_in(x)
132
+ for i in range(len(self.channels)):
133
+ for j in range(self.nums_rb):
134
+ idx = i * self.nums_rb + j
135
+ x = self.body[idx](x)
136
+ features.append(x)
137
+
138
+ return features
139
+
140
+
141
+ class PositionNet(nn.Module):
142
+ def __init__(self, input_size=(40, 64), cin=4, dim=512, out_dim=1024):
143
+ super().__init__()
144
+ self.input_size = input_size
145
+ self.out_dim = out_dim
146
+ self.down_factor = 8 # three downsampling stages in the Adapter backbone
147
+ feature_dim = dim
148
+ self.backbone = Adapter(
149
+ channels=[64, 128, 256, feature_dim],
150
+ nums_rb=2,
151
+ cin=cin,
152
+ stage_downscale=True,
153
+ use_identity=True,
154
+ )
155
+ self.num_tokens = (self.input_size[0] // self.down_factor) * (
156
+ self.input_size[1] // self.down_factor
157
+ )
158
+
159
+ self.pos_embedding = nn.Parameter(
160
+ torch.empty(1, self.num_tokens, feature_dim).normal_(std=0.02)
161
+ ) # from BERT
162
+
163
+ self.linears = nn.Sequential(
164
+ nn.Linear(feature_dim, 512),
165
+ nn.SiLU(),
166
+ nn.Linear(512, 512),
167
+ nn.SiLU(),
168
+ nn.Linear(512, out_dim),
169
+ )
170
+ # self.null_feature = torch.nn.Parameter(torch.zeros([feature_dim]))
171
+
172
+ def forward(self, x, mask=None):
173
+ B = x.shape[0]
174
+ # token from edge map
175
+ # x = torch.nn.functional.interpolate(x, self.input_size)
176
+ feature = self.backbone(x)[-1]
177
+ objs = feature.reshape(B, -1, self.num_tokens)
178
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
179
+ """
180
+ # expand null token
181
+ null_objs = self.null_feature.view(1,1,-1)
182
+ null_objs = null_objs.repeat(B,self.num_tokens,1)
183
+
184
+ # mask replacing
185
+ mask = mask.view(-1,1,1)
186
+ objs = objs*mask + null_objs*(1-mask)
187
+ """
188
+ # add pos
189
+ objs = objs + self.pos_embedding
190
+
191
+ # fuse them
192
+ objs = self.linears(objs)
193
+
194
+ assert objs.shape == torch.Size([B, self.num_tokens, self.out_dim])
195
+ return objs
196
+
197
+
198
+ class PositionNet2(nn.Module):
199
+ def __init__(self, input_size=(40, 64), cin=4, dim=320, out_dim=1024):
200
+ super().__init__()
201
+ self.input_size = input_size
202
+ self.out_dim = out_dim
203
+ self.down_factor = 8 # three downsampling stages in the Adapter backbone
204
+ self.dim = dim
205
+ self.backbone = Adapter(
206
+ channels=[dim, dim, dim, dim],
207
+ nums_rb=2,
208
+ cin=cin,
209
+ stage_downscale=True,
210
+ use_identity=True,
211
+ )
212
+ self.pos_embedding = FixedPositionalEmbedding(dim=self.dim)
213
+ self.linears = nn.Sequential(
214
+ nn.Linear(dim, 512),
215
+ nn.SiLU(),
216
+ nn.Linear(512, 512),
217
+ nn.SiLU(),
218
+ nn.Linear(512, out_dim),
219
+ )
220
+
221
+ def forward(self, x, mask=None):
222
+ B = x.shape[0]
223
+ features = self.backbone(x)
224
+ token_lists = []
225
+ for feature in features:
226
+ objs = feature.reshape(B, self.dim, -1)
227
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
228
+ # add pos
229
+ objs = objs + self.pos_embedding(objs)
230
+ # fuse them
231
+ objs = self.linears(objs)
232
+ token_lists.append(objs)
233
+
234
+ return token_lists
235
+
236
+
237
+ class LayerNorm(nn.LayerNorm):
238
+ """Subclass torch's LayerNorm to handle fp16."""
239
+
240
+ def forward(self, x: torch.Tensor):
241
+ orig_type = x.dtype
242
+ ret = super().forward(x.type(torch.float32))
243
+ return ret.type(orig_type)
244
+
245
+
246
+ class QuickGELU(nn.Module):
247
+
248
+ def forward(self, x: torch.Tensor):
249
+ return x * torch.sigmoid(1.702 * x)
250
+
251
+
252
+ class ResidualAttentionBlock(nn.Module):
253
+
254
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
255
+ super().__init__()
256
+
257
+ self.attn = nn.MultiheadAttention(d_model, n_head)
258
+ self.ln_1 = LayerNorm(d_model)
259
+ self.mlp = nn.Sequential(
260
+ OrderedDict(
261
+ [
262
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
263
+ ("gelu", QuickGELU()),
264
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
265
+ ]
266
+ )
267
+ )
268
+ self.ln_2 = LayerNorm(d_model)
269
+ self.attn_mask = attn_mask
270
+
271
+ def attention(self, x: torch.Tensor):
272
+ self.attn_mask = (
273
+ self.attn_mask.to(dtype=x.dtype, device=x.device)
274
+ if self.attn_mask is not None
275
+ else None
276
+ )
277
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
278
+
279
+ def forward(self, x: torch.Tensor):
280
+ x = x + self.attention(self.ln_1(x))
281
+ x = x + self.mlp(self.ln_2(x))
282
+ return x
283
+
284
+
285
+ class StyleAdapter(nn.Module):
286
+
287
+ def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
288
+ super().__init__()
289
+
290
+ scale = width**-0.5
291
+ self.transformer_layes = nn.Sequential(
292
+ *[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)]
293
+ )
294
+ self.num_token = num_token
295
+ self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
296
+ self.ln_post = LayerNorm(width)
297
+ self.ln_pre = LayerNorm(width)
298
+ self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
299
+
300
+ def forward(self, x):
301
+ # x shape [N, HW+1, C]
302
+ style_embedding = self.style_embedding + torch.zeros(
303
+ (x.shape[0], self.num_token, self.style_embedding.shape[-1]),
304
+ device=x.device,
305
+ )
306
+ x = torch.cat([x, style_embedding], dim=1)
307
+ x = self.ln_pre(x)
308
+ x = x.permute(1, 0, 2) # NLD -> LND
309
+ x = self.transformer_layes(x)
310
+ x = x.permute(1, 0, 2) # LND -> NLD
311
+
312
+ x = self.ln_post(x[:, -self.num_token :, :])
313
+ x = x @ self.proj
314
+
315
+ return x
316
+
317
+
318
+ class ResnetBlock_light(nn.Module):
319
+ def __init__(self, in_c):
320
+ super().__init__()
321
+ self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
322
+ self.act = nn.ReLU()
323
+ self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)
324
+
325
+ def forward(self, x):
326
+ h = self.block1(x)
327
+ h = self.act(h)
328
+ h = self.block2(h)
329
+
330
+ return h + x
331
+
332
+
333
+ class extractor(nn.Module):
334
+ def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
335
+ super().__init__()
336
+ self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
337
+ self.body = []
338
+ for _ in range(nums_rb):
339
+ self.body.append(ResnetBlock_light(inter_c))
340
+ self.body = nn.Sequential(*self.body)
341
+ self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
342
+ self.down = down
343
+ if self.down == True:
344
+ self.down_opt = Downsample(in_c, use_conv=False)
345
+
346
+ def forward(self, x):
347
+ if self.down == True:
348
+ x = self.down_opt(x)
349
+ x = self.in_conv(x)
350
+ x = self.body(x)
351
+ x = self.out_conv(x)
352
+
353
+ return x
354
+
355
+
356
+ class Adapter_light(nn.Module):
357
+ def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
358
+ super(Adapter_light, self).__init__()
359
+ self.unshuffle = nn.PixelUnshuffle(8)
360
+ self.channels = channels
361
+ self.nums_rb = nums_rb
362
+ self.body = []
363
+ for i in range(len(channels)):
364
+ if i == 0:
365
+ self.body.append(
366
+ extractor(
367
+ in_c=cin,
368
+ inter_c=channels[i] // 4,
369
+ out_c=channels[i],
370
+ nums_rb=nums_rb,
371
+ down=False,
372
+ )
373
+ )
374
+ else:
375
+ self.body.append(
376
+ extractor(
377
+ in_c=channels[i - 1],
378
+ inter_c=channels[i] // 4,
379
+ out_c=channels[i],
380
+ nums_rb=nums_rb,
381
+ down=True,
382
+ )
383
+ )
384
+ self.body = nn.ModuleList(self.body)
385
+
386
+ def forward(self, x):
387
+ # unshuffle
388
+ x = self.unshuffle(x)
389
+ # extract features
390
+ features = []
391
+ for i in range(len(self.channels)):
392
+ x = self.body[i](x)
393
+ features.append(x)
394
+
395
+ return features
396
+
397
+
398
+ class CoAdapterFuser(nn.Module):
399
+ def __init__(
400
+ self, unet_channels=[320, 640, 1280, 1280], width=768, num_head=8, n_layes=3
401
+ ):
402
+ super(CoAdapterFuser, self).__init__()
403
+ scale = width**0.5
404
+ self.task_embedding = nn.Parameter(scale * torch.randn(16, width))
405
+ self.positional_embedding = nn.Parameter(
406
+ scale * torch.randn(len(unet_channels), width)
407
+ )
408
+ self.spatial_feat_mapping = nn.ModuleList()
409
+ for ch in unet_channels:
410
+ self.spatial_feat_mapping.append(
411
+ nn.Sequential(
412
+ nn.SiLU(),
413
+ nn.Linear(ch, width),
414
+ )
415
+ )
416
+ self.transformer_layes = nn.Sequential(
417
+ *[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)]
418
+ )
419
+ self.ln_post = LayerNorm(width)
420
+ self.ln_pre = LayerNorm(width)
421
+ self.spatial_ch_projs = nn.ModuleList()
422
+ for ch in unet_channels:
423
+ self.spatial_ch_projs.append(zero_module(nn.Linear(width, ch)))
424
+ self.seq_proj = nn.Parameter(torch.zeros(width, width))
425
+
426
+ def forward(self, features):
427
+ if len(features) == 0:
428
+ return None, None
429
+ inputs = []
430
+ for cond_name in features.keys():
431
+ task_idx = getattr(ExtraCondition, cond_name).value
432
+ if not isinstance(features[cond_name], list):
433
+ inputs.append(features[cond_name] + self.task_embedding[task_idx])
434
+ continue
435
+
436
+ feat_seq = []
437
+ for idx, feature_map in enumerate(features[cond_name]):
438
+ feature_vec = torch.mean(feature_map, dim=(2, 3))
439
+ feature_vec = self.spatial_feat_mapping[idx](feature_vec)
440
+ feat_seq.append(feature_vec)
441
+ feat_seq = torch.stack(feat_seq, dim=1) # Nx4xC
442
+ feat_seq = feat_seq + self.task_embedding[task_idx]
443
+ feat_seq = feat_seq + self.positional_embedding
444
+ inputs.append(feat_seq)
445
+
446
+ x = torch.cat(inputs, dim=1) # NxLxC
447
+ x = self.ln_pre(x)
448
+ x = x.permute(1, 0, 2) # NLD -> LND
449
+ x = self.transformer_layes(x)
450
+ x = x.permute(1, 0, 2) # LND -> NLD
451
+ x = self.ln_post(x)
452
+
453
+ ret_feat_map = None
454
+ ret_feat_seq = None
455
+ cur_seq_idx = 0
456
+ for cond_name in features.keys():
457
+ if not isinstance(features[cond_name], list):
458
+ length = features[cond_name].size(1)
459
+ transformed_feature = features[cond_name] * (
460
+ (x[:, cur_seq_idx : cur_seq_idx + length] @ self.seq_proj) + 1
461
+ )
462
+ if ret_feat_seq is None:
463
+ ret_feat_seq = transformed_feature
464
+ else:
465
+ ret_feat_seq = torch.cat([ret_feat_seq, transformed_feature], dim=1)
466
+ cur_seq_idx += length
467
+ continue
468
+
469
+ length = len(features[cond_name])
470
+ transformed_feature_list = []
471
+ for idx in range(length):
472
+ alpha = self.spatial_ch_projs[idx](x[:, cur_seq_idx + idx])
473
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1) + 1
474
+ transformed_feature_list.append(features[cond_name][idx] * alpha)
475
+ if ret_feat_map is None:
476
+ ret_feat_map = transformed_feature_list
477
+ else:
478
+ ret_feat_map = list(
479
+ map(lambda x, y: x + y, ret_feat_map, transformed_feature_list)
480
+ )
481
+ cur_seq_idx += length
482
+
483
+ assert cur_seq_idx == x.size(1)
484
+
485
+ return ret_feat_map, ret_feat_seq
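`Adapter` above produces a multi-scale feature pyramid for conditioning: `PixelUnshuffle(8)` trades spatial resolution for channels, `conv_in` projects to `channels[0]`, and each later stage (with `stage_downscale=True`) halves the resolution while widening the channels. A hedged usage sketch follows; it assumes the repository (and its `extralibs` dependency) is on `PYTHONPATH`, and the input size is arbitrary as long as it survives the repeated downscaling.

```python
# Illustrative sketch only, under the assumptions stated in the text above.
import torch
from core.modules.encoders.adapter import Adapter

# Default cin=64 expects a single-channel map: PixelUnshuffle(8) turns
# (B, 1, 512, 512) into (B, 64, 64, 64) before conv_in projects to channels[0].
adapter = Adapter(channels=[320, 640, 1280, 1280], nums_rb=3, cin=64)
cond = torch.randn(1, 1, 512, 512)

with torch.no_grad():
    features = adapter(cond)

# Each returned feature map is meant to be injected at the matching UNet resolution.
for f in features:
    print(tuple(f.shape))
```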
core/modules/encoders/condition.py ADDED
@@ -0,0 +1,511 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import kornia
4
+ from torch.utils.checkpoint import checkpoint
5
+
6
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
7
+ import open_clip
8
+
9
+ from core.common import autocast
10
+ from utils.utils import count_params
11
+
12
+
13
+ class AbstractEncoder(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def encode(self, *args, **kwargs):
18
+ raise NotImplementedError
19
+
20
+
21
+ class IdentityEncoder(AbstractEncoder):
22
+
23
+ def encode(self, x):
24
+ return x
25
+
26
+
27
+ class ClassEmbedder(nn.Module):
28
+ def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
29
+ super().__init__()
30
+ self.key = key
31
+ self.embedding = nn.Embedding(n_classes, embed_dim)
32
+ self.n_classes = n_classes
33
+ self.ucg_rate = ucg_rate
34
+
35
+ def forward(self, batch, key=None, disable_dropout=False):
36
+ if key is None:
37
+ key = self.key
38
+ # this is for use in crossattn
39
+ c = batch[key][:, None]
40
+ if self.ucg_rate > 0.0 and not disable_dropout:
41
+ mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
42
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
43
+ c = c.long()
44
+ c = self.embedding(c)
45
+ return c
46
+
47
+ def get_unconditional_conditioning(self, bs, device="cuda"):
48
+ # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
49
+ uc_class = self.n_classes - 1
50
+ uc = torch.ones((bs,), device=device) * uc_class
51
+ uc = {self.key: uc}
52
+ return uc
53
+
54
+
55
+ def disabled_train(self, mode=True):
56
+ """Overwrite model.train with this function to make sure train/eval mode
57
+ does not change anymore."""
58
+ return self
59
+
60
+
61
+ class FrozenT5Embedder(AbstractEncoder):
62
+ """Uses the T5 transformer encoder for text"""
63
+
64
+ def __init__(
65
+ self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
66
+ ):
67
+ super().__init__()
68
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
69
+ self.transformer = T5EncoderModel.from_pretrained(version)
70
+ self.device = device
71
+ self.max_length = max_length
72
+ if freeze:
73
+ self.freeze()
74
+
75
+ def freeze(self):
76
+ self.transformer = self.transformer.eval()
77
+ # self.train = disabled_train
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, text):
82
+ batch_encoding = self.tokenizer(
83
+ text,
84
+ truncation=True,
85
+ max_length=self.max_length,
86
+ return_length=True,
87
+ return_overflowing_tokens=False,
88
+ padding="max_length",
89
+ return_tensors="pt",
90
+ )
91
+ tokens = batch_encoding["input_ids"].to(self.device)
92
+ outputs = self.transformer(input_ids=tokens)
93
+
94
+ z = outputs.last_hidden_state
95
+ return z
96
+
97
+ def encode(self, text):
98
+ return self(text)
99
+
100
+
101
+ class FrozenCLIPEmbedder(AbstractEncoder):
102
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
103
+
104
+ LAYERS = ["last", "pooled", "hidden"]
105
+
106
+ def __init__(
107
+ self,
108
+ version="openai/clip-vit-large-patch14",
109
+ device="cuda",
110
+ max_length=77,
111
+ freeze=True,
112
+ layer="last",
113
+ layer_idx=None,
114
+ ): # clip-vit-base-patch32
115
+ super().__init__()
116
+ assert layer in self.LAYERS
117
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
118
+ self.transformer = CLIPTextModel.from_pretrained(version)
119
+ self.device = device
120
+ self.max_length = max_length
121
+ if freeze:
122
+ self.freeze()
123
+ self.layer = layer
124
+ self.layer_idx = layer_idx
125
+ if layer == "hidden":
126
+ assert layer_idx is not None
127
+ assert 0 <= abs(layer_idx) <= 12
128
+
129
+ def freeze(self):
130
+ self.transformer = self.transformer.eval()
131
+ # self.train = disabled_train
132
+ for param in self.parameters():
133
+ param.requires_grad = False
134
+
135
+ def forward(self, text):
136
+ batch_encoding = self.tokenizer(
137
+ text,
138
+ truncation=True,
139
+ max_length=self.max_length,
140
+ return_length=True,
141
+ return_overflowing_tokens=False,
142
+ padding="max_length",
143
+ return_tensors="pt",
144
+ )
145
+ tokens = batch_encoding["input_ids"].to(self.device)
146
+ outputs = self.transformer(
147
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
148
+ )
149
+ if self.layer == "last":
150
+ z = outputs.last_hidden_state
151
+ elif self.layer == "pooled":
152
+ z = outputs.pooler_output[:, None, :]
153
+ else:
154
+ z = outputs.hidden_states[self.layer_idx]
155
+ return z
156
+
157
+ def encode(self, text):
158
+ return self(text)
159
+
160
+
161
+ class ClipImageEmbedder(nn.Module):
162
+ def __init__(
163
+ self,
164
+ model,
165
+ jit=False,
166
+ device="cuda" if torch.cuda.is_available() else "cpu",
167
+ antialias=True,
168
+ ucg_rate=0.0,
169
+ ):
170
+ super().__init__()
171
+ from clip import load as load_clip
172
+
173
+ self.model, _ = load_clip(name=model, device=device, jit=jit)
174
+
175
+ self.antialias = antialias
176
+
177
+ self.register_buffer(
178
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
179
+ )
180
+ self.register_buffer(
181
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
182
+ )
183
+ self.ucg_rate = ucg_rate
184
+
185
+ def preprocess(self, x):
186
+ # normalize to [0,1]
187
+ x = kornia.geometry.resize(
188
+ x,
189
+ (224, 224),
190
+ interpolation="bicubic",
191
+ align_corners=True,
192
+ antialias=self.antialias,
193
+ )
194
+ x = (x + 1.0) / 2.0
195
+ # re-normalize according to clip
196
+ x = kornia.enhance.normalize(x, self.mean, self.std)
197
+ return x
198
+
199
+ def forward(self, x, no_dropout=False):
200
+ # x is assumed to be in range [-1,1]
201
+ out = self.model.encode_image(self.preprocess(x))
202
+ out = out.to(x.dtype)
203
+ if self.ucg_rate > 0.0 and not no_dropout:
204
+ out = (
205
+ torch.bernoulli(
206
+ (1.0 - self.ucg_rate) * torch.ones(out.shape[0], device=out.device)
207
+ )[:, None]
208
+ * out
209
+ )
210
+ return out
211
+
212
+
213
+ class FrozenOpenCLIPEmbedder(AbstractEncoder):
214
+ """
215
+ Uses the OpenCLIP transformer encoder for text
216
+ """
217
+
218
+ LAYERS = [
219
+ # "pooled",
220
+ "last",
221
+ "penultimate",
222
+ ]
223
+
224
+ def __init__(
225
+ self,
226
+ arch="ViT-H-14",
227
+ version=None,
228
+ device="cuda",
229
+ max_length=77,
230
+ freeze=True,
231
+ layer="last",
232
+ ):
233
+ super().__init__()
234
+ assert layer in self.LAYERS
235
+ model, _, _ = open_clip.create_model_and_transforms(
236
+ arch, device=torch.device("cpu"), pretrained=version
237
+ )
238
+ del model.visual
239
+ self.model = model
240
+
241
+ self.device = device
242
+ self.max_length = max_length
243
+ if freeze:
244
+ self.freeze()
245
+ self.layer = layer
246
+ if self.layer == "last":
247
+ self.layer_idx = 0
248
+ elif self.layer == "penultimate":
249
+ self.layer_idx = 1
250
+ else:
251
+ raise NotImplementedError()
252
+
253
+ def freeze(self):
254
+ self.model = self.model.eval()
255
+ for param in self.parameters():
256
+ param.requires_grad = False
257
+
258
+ def forward(self, text):
259
+ tokens = open_clip.tokenize(text)
260
+ z = self.encode_with_transformer(tokens.to(self.device))
261
+ return z
262
+
263
+ def encode_with_transformer(self, text):
264
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
265
+ x = x + self.model.positional_embedding
266
+ x = x.permute(1, 0, 2) # NLD -> LND
267
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
268
+ x = x.permute(1, 0, 2) # LND -> NLD
269
+ x = self.model.ln_final(x)
270
+ return x
271
+
272
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
273
+ for i, r in enumerate(self.model.transformer.resblocks):
274
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
275
+ break
276
+ if (
277
+ self.model.transformer.grad_checkpointing
278
+ and not torch.jit.is_scripting()
279
+ ):
280
+ x = checkpoint(r, x, attn_mask)
281
+ else:
282
+ x = r(x, attn_mask=attn_mask)
283
+ return x
284
+
285
+ def encode(self, text):
286
+ return self(text)
287
+
288
+
289
+ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
290
+ """
291
+ Uses the OpenCLIP vision transformer encoder for images
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ arch="ViT-H-14",
297
+ version=None,
298
+ device="cuda",
299
+ max_length=77,
300
+ freeze=True,
301
+ layer="pooled",
302
+ antialias=True,
303
+ ucg_rate=0.0,
304
+ ):
305
+ super().__init__()
306
+ model, _, _ = open_clip.create_model_and_transforms(
307
+ arch, device=torch.device("cpu"), pretrained=version
308
+ )
309
+ del model.transformer
310
+ self.model = model
311
+
312
+ self.device = device
313
+ self.max_length = max_length
314
+ if freeze:
315
+ self.freeze()
316
+ self.layer = layer
317
+ if self.layer == "penultimate":
318
+ raise NotImplementedError()
319
+ self.layer_idx = 1
320
+
321
+ self.antialias = antialias
322
+
323
+ self.register_buffer(
324
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
325
+ )
326
+ self.register_buffer(
327
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
328
+ )
329
+ self.ucg_rate = ucg_rate
330
+
331
+ def preprocess(self, x):
332
+ # normalize to [0,1]
333
+ x = kornia.geometry.resize(
334
+ x,
335
+ (224, 224),
336
+ interpolation="bicubic",
337
+ align_corners=True,
338
+ antialias=self.antialias,
339
+ )
340
+ x = (x + 1.0) / 2.0
341
+ # renormalize according to clip
342
+ x = kornia.enhance.normalize(x, self.mean, self.std)
343
+ return x
344
+
345
+ def freeze(self):
346
+ self.model = self.model.eval()
347
+ for param in self.parameters():
348
+ param.requires_grad = False
349
+
350
+ @autocast
351
+ def forward(self, image, no_dropout=False):
352
+ z = self.encode_with_vision_transformer(image)
353
+ if self.ucg_rate > 0.0 and not no_dropout:
354
+ z = (
355
+ torch.bernoulli(
356
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
357
+ )[:, None]
358
+ * z
359
+ )
360
+ return z
361
+
362
+ def encode_with_vision_transformer(self, img):
363
+ img = self.preprocess(img)
364
+ x = self.model.visual(img)
365
+ return x
366
+
367
+ def encode(self, text):
368
+ return self(text)
369
+
370
+
371
+ class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
372
+ """
373
+ Uses the OpenCLIP vision transformer encoder for images
374
+ """
375
+
376
+ def __init__(
377
+ self,
378
+ arch="ViT-H-14",
379
+ version=None,
380
+ device="cuda",
381
+ freeze=True,
382
+ layer="pooled",
383
+ antialias=True,
384
+ ):
385
+ super().__init__()
386
+ model, _, _ = open_clip.create_model_and_transforms(
387
+ arch,
388
+ device=torch.device("cpu"),
389
+ pretrained=version,
390
+ )
391
+ del model.transformer
392
+ self.model = model
393
+ self.device = device
394
+
395
+ if freeze:
396
+ self.freeze()
397
+ self.layer = layer
398
+ if self.layer == "penultimate":
399
+ raise NotImplementedError()
400
+ self.layer_idx = 1
401
+
402
+ self.antialias = antialias
403
+ self.register_buffer(
404
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
405
+ )
406
+ self.register_buffer(
407
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
408
+ )
409
+
410
+ def preprocess(self, x):
411
+ # normalize to [0,1]
412
+ x = kornia.geometry.resize(
413
+ x,
414
+ (224, 224),
415
+ interpolation="bicubic",
416
+ align_corners=True,
417
+ antialias=self.antialias,
418
+ )
419
+ x = (x + 1.0) / 2.0
420
+ # renormalize according to clip
421
+ x = kornia.enhance.normalize(x, self.mean, self.std)
422
+ return x
423
+
424
+ def freeze(self):
425
+ self.model = self.model.eval()
426
+ for param in self.model.parameters():
427
+ param.requires_grad = False
428
+
429
+ def forward(self, image, no_dropout=False):
430
+ # image: b c h w
431
+ z = self.encode_with_vision_transformer(image)
432
+ return z
433
+
434
+ def encode_with_vision_transformer(self, x):
435
+ x = self.preprocess(x)
436
+
437
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
438
+ if self.model.visual.input_patchnorm:
439
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
440
+ x = x.reshape(
441
+ x.shape[0],
442
+ x.shape[1],
443
+ self.model.visual.grid_size[0],
444
+ self.model.visual.patch_size[0],
445
+ self.model.visual.grid_size[1],
446
+ self.model.visual.patch_size[1],
447
+ )
448
+ x = x.permute(0, 2, 4, 1, 3, 5)
449
+ x = x.reshape(
450
+ x.shape[0],
451
+ self.model.visual.grid_size[0] * self.model.visual.grid_size[1],
452
+ -1,
453
+ )
454
+ x = self.model.visual.patchnorm_pre_ln(x)
455
+ x = self.model.visual.conv1(x)
456
+ else:
457
+ x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
458
+ # shape = [*, width, grid ** 2]
459
+ x = x.reshape(x.shape[0], x.shape[1], -1)
460
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
461
+
462
+ # class embeddings and positional embeddings
463
+ x = torch.cat(
464
+ [
465
+ self.model.visual.class_embedding.to(x.dtype)
466
+ + torch.zeros(
467
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
468
+ ),
469
+ x,
470
+ ],
471
+ dim=1,
472
+ ) # shape = [*, grid ** 2 + 1, width]
473
+ x = x + self.model.visual.positional_embedding.to(x.dtype)
474
+
475
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
476
+ x = self.model.visual.patch_dropout(x)
477
+ x = self.model.visual.ln_pre(x)
478
+
479
+ x = x.permute(1, 0, 2) # NLD -> LND
480
+ x = self.model.visual.transformer(x)
481
+ x = x.permute(1, 0, 2) # LND -> NLD
482
+
483
+ return x
484
+
485
+
486
+ class FrozenCLIPT5Encoder(AbstractEncoder):
487
+ def __init__(
488
+ self,
489
+ clip_version="openai/clip-vit-large-patch14",
490
+ t5_version="google/t5-v1_1-xl",
491
+ device="cuda",
492
+ clip_max_length=77,
493
+ t5_max_length=77,
494
+ ):
495
+ super().__init__()
496
+ self.clip_encoder = FrozenCLIPEmbedder(
497
+ clip_version, device, max_length=clip_max_length
498
+ )
499
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
500
+ print(
501
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
502
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
503
+ )
504
+
505
+ def encode(self, text):
506
+ return self(text)
507
+
508
+ def forward(self, text):
509
+ clip_z = self.clip_encoder.encode(text)
510
+ t5_z = self.t5_encoder.encode(text)
511
+ return [clip_z, t5_z]
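The CLIP image embedders above all expect pixel values in `[-1, 1]` and re-map them inside `preprocess`: a bicubic, antialiased resize to 224x224, then `(x + 1) / 2`, then normalization with CLIP's mean/std. Below is a small illustrative sketch of that value mapping in plain `torch` (standing in for the `kornia` calls, resize omitted); it is not part of the repository.

```python
# Illustrative sketch only: the value mapping performed by
# ClipImageEmbedder.preprocess / FrozenOpenCLIPImageEmbedder.preprocess.
import torch

CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

def clip_normalize(x: torch.Tensor) -> torch.Tensor:
    """x: (B, 3, 224, 224) in [-1, 1], as produced by the diffusion pipeline."""
    x = (x + 1.0) / 2.0                 # [-1, 1] -> [0, 1]
    return (x - CLIP_MEAN) / CLIP_STD   # re-normalize to CLIP statistics

img = torch.rand(1, 3, 224, 224) * 2.0 - 1.0
out = clip_normalize(img)
print(out.shape, float(out.mean()))
```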
core/modules/encoders/resampler.py ADDED
@@ -0,0 +1,264 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class ImageProjModel(nn.Module):
9
+ """Projection Model"""
10
+
11
+ def __init__(
12
+ self,
13
+ cross_attention_dim=1024,
14
+ clip_embeddings_dim=1024,
15
+ clip_extra_context_tokens=4,
16
+ ):
17
+ super().__init__()
18
+ self.cross_attention_dim = cross_attention_dim
19
+ self.clip_extra_context_tokens = clip_extra_context_tokens
20
+ self.proj = nn.Linear(
21
+ clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
22
+ )
23
+ self.norm = nn.LayerNorm(cross_attention_dim)
24
+
25
+ def forward(self, image_embeds):
26
+ # embeds = image_embeds
27
+ embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
28
+ clip_extra_context_tokens = self.proj(embeds).reshape(
29
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
30
+ )
31
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
32
+ return clip_extra_context_tokens
33
+
34
+
35
+ # FFN
36
+ def FeedForward(dim, mult=4):
37
+ inner_dim = int(dim * mult)
38
+ return nn.Sequential(
39
+ nn.LayerNorm(dim),
40
+ nn.Linear(dim, inner_dim, bias=False),
41
+ nn.GELU(),
42
+ nn.Linear(inner_dim, dim, bias=False),
43
+ )
44
+
45
+
46
+ def reshape_tensor(x, heads):
47
+ bs, length, width = x.shape
48
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
49
+ x = x.view(bs, length, heads, -1)
50
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
51
+ x = x.transpose(1, 2)
52
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
53
+ x = x.reshape(bs, heads, length, -1)
54
+ return x
55
+
56
+
57
+ class PerceiverAttention(nn.Module):
58
+ def __init__(self, *, dim, dim_head=64, heads=8):
59
+ super().__init__()
60
+ self.scale = dim_head**-0.5
61
+ self.dim_head = dim_head
62
+ self.heads = heads
63
+ inner_dim = dim_head * heads
64
+
65
+ self.norm1 = nn.LayerNorm(dim)
66
+ self.norm2 = nn.LayerNorm(dim)
67
+
68
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
69
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
70
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
71
+
72
+ def forward(self, x, latents):
73
+ """
74
+ Args:
75
+ x (torch.Tensor): image features
76
+ shape (b, n1, D)
77
+ latent (torch.Tensor): latent features
78
+ shape (b, n2, D)
79
+ """
80
+ x = self.norm1(x)
81
+ latents = self.norm2(latents)
82
+
83
+ b, l, _ = latents.shape
84
+
85
+ q = self.to_q(latents)
86
+ kv_input = torch.cat((x, latents), dim=-2)
87
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
88
+
89
+ q = reshape_tensor(q, self.heads)
90
+ k = reshape_tensor(k, self.heads)
91
+ v = reshape_tensor(v, self.heads)
92
+
93
+ # attention
94
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
95
+ # More stable with f16 than dividing afterwards
96
+ weight = (q * scale) @ (k * scale).transpose(-2, -1)
97
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
98
+ out = weight @ v
99
+
100
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
101
+
102
+ return self.to_out(out)
103
+
104
+
105
+ class Resampler(nn.Module):
106
+ def __init__(
107
+ self,
108
+ dim=1024,
109
+ depth=8,
110
+ dim_head=64,
111
+ heads=16,
112
+ num_queries=8,
113
+ embedding_dim=768,
114
+ output_dim=1024,
115
+ ff_mult=4,
116
+ video_length=None,
117
+ ):
118
+ super().__init__()
119
+ self.num_queries = num_queries
120
+ self.video_length = video_length
121
+ if video_length is not None:
122
+ num_queries = num_queries * video_length
123
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
124
+ self.proj_in = nn.Linear(embedding_dim, dim)
125
+ self.proj_out = nn.Linear(dim, output_dim)
126
+ self.norm_out = nn.LayerNorm(output_dim)
127
+
128
+ self.layers = nn.ModuleList([])
129
+ for _ in range(depth):
130
+ self.layers.append(
131
+ nn.ModuleList(
132
+ [
133
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
134
+ FeedForward(dim=dim, mult=ff_mult),
135
+ ]
136
+ )
137
+ )
138
+
139
+ def forward(self, x):
140
+ latents = self.latents.repeat(x.size(0), 1, 1) # B (T L) C
141
+ x = self.proj_in(x)
142
+
143
+ for attn, ff in self.layers:
144
+ latents = attn(x, latents) + latents
145
+ latents = ff(latents) + latents
146
+
147
+ latents = self.proj_out(latents)
148
+ latents = self.norm_out(latents) # B L C or B (T L) C
149
+
150
+ return latents
151
+
152
+
153
+ class CameraPoseQueryTransformer(nn.Module):
154
+ def __init__(
155
+ self,
156
+ dim=1024,
157
+ depth=8,
158
+ dim_head=64,
159
+ heads=16,
160
+ num_queries=8,
161
+ embedding_dim=768,
162
+ output_dim=1024,
163
+ ff_mult=4,
164
+ num_views=None,
165
+ use_multi_view_attention=True,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.num_queries = num_queries
170
+ self.num_views = num_views
171
+ assert num_views is not None, "num_views must be given."
172
+ self.use_multi_view_attention = use_multi_view_attention
173
+ self.camera_pose_embedding_layers = nn.Sequential(
174
+ nn.Linear(12, dim),
175
+ nn.SiLU(),
176
+ nn.Linear(dim, dim),
177
+ nn.SiLU(),
178
+ nn.Linear(dim, dim),
179
+ )
180
+ nn.init.zeros_(self.camera_pose_embedding_layers[-1].weight)
181
+ nn.init.zeros_(self.camera_pose_embedding_layers[-1].bias)
182
+
183
+ self.latents = nn.Parameter(
184
+ torch.randn(1, num_views * num_queries, dim) / dim**0.5
185
+ )
186
+
187
+ self.proj_in = nn.Linear(embedding_dim, dim)
188
+
189
+ self.proj_out = nn.Linear(dim, output_dim)
190
+ self.norm_out = nn.LayerNorm(output_dim)
191
+
192
+ self.layers = nn.ModuleList([])
193
+ for _ in range(depth):
194
+ self.layers.append(
195
+ nn.ModuleList(
196
+ [
197
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
198
+ FeedForward(dim=dim, mult=ff_mult),
199
+ ]
200
+ )
201
+ )
202
+
203
+ def forward(self, x, camera_poses):
204
+ # camera_poses: (b, t, 12)
205
+ batch_size, num_views, _ = camera_poses.shape
206
+ # latents: (1, t*q, d) -> (b, t*q, d)
207
+ latents = self.latents.repeat(batch_size, 1, 1)
208
+ x = self.proj_in(x)
209
+ # camera_poses: (b*t, 12)
210
+ camera_poses = rearrange(camera_poses, "b t d -> (b t) d", t=num_views)
211
+ camera_poses = self.camera_pose_embedding_layers(
212
+ camera_poses
213
+ ) # camera_poses: (b*t, d)
214
+ # camera_poses: (b, t, d)
215
+ camera_poses = rearrange(camera_poses, "(b t) d -> b t d", t=num_views)
216
+ # camera_poses: (b, t*q, d)
217
+ camera_poses = repeat(camera_poses, "b t d -> b (t q) d", q=self.num_queries)
218
+
219
+ latents = latents + camera_poses # b, t*q, d
220
+
221
+ latents = rearrange(
222
+ latents,
223
+ "b (t q) d -> (b t) q d",
224
+ b=batch_size,
225
+ t=num_views,
226
+ q=self.num_queries,
227
+ ) # (b*t, q, d)
228
+
229
+ _, x_seq_size, _ = x.shape
230
+ for layer_idx, (attn, ff) in enumerate(self.layers):
231
+ if self.use_multi_view_attention and layer_idx % 2 == 1:
232
+ # latents: (b*t, q, d)
233
+ latents = rearrange(
234
+ latents,
235
+ "(b t) q d -> b (t q) d",
236
+ b=batch_size,
237
+ t=num_views,
238
+ q=self.num_queries,
239
+ )
240
+ # x: (b*t, s, d)
241
+ x = rearrange(
242
+ x, "(b t) s d -> b (t s) d", b=batch_size, t=num_views, s=x_seq_size
243
+ )
244
+
245
+ # print("After rearrange: latents.shape=", latents.shape)
246
+ # print("After rearrange: x.shape=", camera_poses.shape)
247
+ latents = attn(x, latents) + latents
248
+ latents = ff(latents) + latents
249
+ if self.use_multi_view_attention and layer_idx % 2 == 1:
250
+ # latents: (b*t, q, d)
251
+ latents = rearrange(
252
+ latents,
253
+ "b (t q) d -> (b t) q d",
254
+ b=batch_size,
255
+ t=num_views,
256
+ q=self.num_queries,
257
+ )
258
+ # x: (b*t, s, d)
259
+ x = rearrange(
260
+ x, "b (t s) d -> (b t) s d", b=batch_size, t=num_views, s=x_seq_size
261
+ )
262
+ latents = self.proj_out(latents)
263
+ latents = self.norm_out(latents) # B L C or B (T L) C
264
+ return latents
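Before moving on to the autoencoder modules, here is a short shape walkthrough of the `Resampler` and `CameraPoseQueryTransformer` defined in `resampler.py` above. The batch size, view count, token count, and hyperparameters below are arbitrary illustrative values; only the input conventions (image features of width `embedding_dim`, camera poses as flattened 3x4 matrices of size 12) come from the code.

```python
# Illustrative shape check for core/modules/encoders/resampler.py.
import torch
from core.modules.encoders.resampler import Resampler, CameraPoseQueryTransformer

b, t, n, emb_dim = 2, 4, 257, 1280            # example batch, views, tokens, feature width

# Resampler: learned latent queries cross-attend to image features.
resampler = Resampler(dim=1024, depth=2, num_queries=16,
                      embedding_dim=emb_dim, output_dim=1024, video_length=t)
tokens = resampler(torch.randn(b, n, emb_dim))        # -> (b, t * 16, 1024)

# CameraPoseQueryTransformer: per-view queries conditioned on camera poses.
pose_former = CameraPoseQueryTransformer(dim=1024, depth=2, num_queries=16,
                                         embedding_dim=emb_dim, output_dim=1024,
                                         num_views=t)
image_feats = torch.randn(b * t, n, emb_dim)          # features flattened over views
poses = torch.randn(b, t, 12)                         # flattened 3x4 camera matrices
pose_tokens = pose_former(image_feats, poses)         # -> (b * t, 16, 1024)

print(tokens.shape, pose_tokens.shape)
```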
core/modules/networks/ae_modules.py ADDED
@@ -0,0 +1,1023 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+
9
+ from utils.utils import instantiate_from_config
10
+ from core.modules.attention import LinearAttention
+ # DiagonalGaussianDistribution is referenced in FirstStagePostProcessor below;
+ # assuming it lives in core.distributions (as in the LDM-style layout).
+ from core.distributions import DiagonalGaussianDistribution
11
+
12
+
13
+ def nonlinearity(x):
14
+ # swish
15
+ return x * torch.sigmoid(x)
16
+
17
+
18
+ def Normalize(in_channels, num_groups=32):
19
+ return torch.nn.GroupNorm(
20
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
21
+ )
22
+
23
+
24
+ class LinAttnBlock(LinearAttention):
25
+ """to match AttnBlock usage"""
26
+
27
+ def __init__(self, in_channels):
28
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
29
+
30
+
31
+ class AttnBlock(nn.Module):
32
+ def __init__(self, in_channels):
33
+ super().__init__()
34
+ self.in_channels = in_channels
35
+
36
+ self.norm = Normalize(in_channels)
37
+ self.q = torch.nn.Conv2d(
38
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
39
+ )
40
+ self.k = torch.nn.Conv2d(
41
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
42
+ )
43
+ self.v = torch.nn.Conv2d(
44
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
45
+ )
46
+ self.proj_out = torch.nn.Conv2d(
47
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
48
+ )
49
+
50
+ def forward(self, x):
51
+ h_ = x
52
+ h_ = self.norm(h_)
53
+ q = self.q(h_)
54
+ k = self.k(h_)
55
+ v = self.v(h_)
56
+
57
+ # compute attention
58
+ b, c, h, w = q.shape
59
+ q = q.reshape(b, c, h * w) # bcl
60
+ q = q.permute(0, 2, 1) # bcl -> blc l=hw
61
+ k = k.reshape(b, c, h * w) # bcl
62
+
63
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
64
+ w_ = w_ * (int(c) ** (-0.5))
65
+ w_ = torch.nn.functional.softmax(w_, dim=2)
66
+
67
+ # attend to values
68
+ v = v.reshape(b, c, h * w)
69
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
70
+ # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
71
+ h_ = torch.bmm(v, w_)
72
+ h_ = h_.reshape(b, c, h, w)
73
+
74
+ h_ = self.proj_out(h_)
75
+
76
+ return x + h_
77
+
78
+
79
+ def make_attn(in_channels, attn_type="vanilla"):
80
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
81
+ if attn_type == "vanilla":
82
+ return AttnBlock(in_channels)
83
+ elif attn_type == "none":
84
+ return nn.Identity(in_channels)
85
+ else:
86
+ return LinAttnBlock(in_channels)
87
+
88
+
89
+ class Downsample(nn.Module):
90
+ def __init__(self, in_channels, with_conv):
91
+ super().__init__()
92
+ self.with_conv = with_conv
93
+ self.in_channels = in_channels
94
+ if self.with_conv:
95
+ # no asymmetric padding in torch conv, must do it ourselves
96
+ self.conv = torch.nn.Conv2d(
97
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
98
+ )
99
+
100
+ def forward(self, x):
101
+ if self.with_conv:
102
+ pad = (0, 1, 0, 1)
103
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
104
+ x = self.conv(x)
105
+ else:
106
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
107
+ return x
108
+
109
+
110
+ class Upsample(nn.Module):
111
+ def __init__(self, in_channels, with_conv):
112
+ super().__init__()
113
+ self.with_conv = with_conv
114
+ self.in_channels = in_channels
115
+ if self.with_conv:
116
+ self.conv = torch.nn.Conv2d(
117
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
118
+ )
119
+
120
+ def forward(self, x):
121
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
122
+ if self.with_conv:
123
+ x = self.conv(x)
124
+ return x
125
+
126
+
127
+ def get_timestep_embedding(time_steps, embedding_dim):
128
+ """
129
+ Build sinusoidal timestep embeddings for diffusion models.
130
+ Adapted from the Denoising Diffusion Probabilistic Models codebase
131
+ (originally from Fairseq). This matches the tensor2tensor implementation
132
+ but differs slightly from the description in Section 3.5 of
133
+ "Attention Is All You Need".
134
+ """
135
+ assert len(time_steps.shape) == 1
136
+
137
+ half_dim = embedding_dim // 2
138
+ emb = math.log(10000) / (half_dim - 1)
139
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
140
+ emb = emb.to(device=time_steps.device)
141
+ emb = time_steps.float()[:, None] * emb[None, :]
142
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
143
+ if embedding_dim % 2 == 1: # zero pad
144
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
145
+ return emb
146
+
147
+
148
+ class ResnetBlock(nn.Module):
149
+ def __init__(
150
+ self,
151
+ *,
152
+ in_channels,
153
+ out_channels=None,
154
+ conv_shortcut=False,
155
+ dropout,
156
+ temb_channels=512,
157
+ ):
158
+ super().__init__()
159
+ self.in_channels = in_channels
160
+ out_channels = in_channels if out_channels is None else out_channels
161
+ self.out_channels = out_channels
162
+ self.use_conv_shortcut = conv_shortcut
163
+
164
+ self.norm1 = Normalize(in_channels)
165
+ self.conv1 = torch.nn.Conv2d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+ if temb_channels > 0:
169
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
170
+ self.norm2 = Normalize(out_channels)
171
+ self.dropout = torch.nn.Dropout(dropout)
172
+ self.conv2 = torch.nn.Conv2d(
173
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
174
+ )
175
+ if self.in_channels != self.out_channels:
176
+ if self.use_conv_shortcut:
177
+ self.conv_shortcut = torch.nn.Conv2d(
178
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
179
+ )
180
+ else:
181
+ self.nin_shortcut = torch.nn.Conv2d(
182
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
183
+ )
184
+
185
+ def forward(self, x, temb):
186
+ h = x
187
+ h = self.norm1(h)
188
+ h = nonlinearity(h)
189
+ h = self.conv1(h)
190
+
191
+ if temb is not None:
192
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
193
+
194
+ h = self.norm2(h)
195
+ h = nonlinearity(h)
196
+ h = self.dropout(h)
197
+ h = self.conv2(h)
198
+
199
+ if self.in_channels != self.out_channels:
200
+ if self.use_conv_shortcut:
201
+ x = self.conv_shortcut(x)
202
+ else:
203
+ x = self.nin_shortcut(x)
204
+
205
+ return x + h
206
+
207
+
208
+ class Model(nn.Module):
209
+ def __init__(
210
+ self,
211
+ *,
212
+ ch,
213
+ out_ch,
214
+ ch_mult=(1, 2, 4, 8),
215
+ num_res_blocks,
216
+ attn_resolutions,
217
+ dropout=0.0,
218
+ resamp_with_conv=True,
219
+ in_channels,
220
+ resolution,
221
+ use_timestep=True,
222
+ use_linear_attn=False,
223
+ attn_type="vanilla",
224
+ ):
225
+ super().__init__()
226
+ if use_linear_attn:
227
+ attn_type = "linear"
228
+ self.ch = ch
229
+ self.temb_ch = self.ch * 4
230
+ self.num_resolutions = len(ch_mult)
231
+ self.num_res_blocks = num_res_blocks
232
+ self.resolution = resolution
233
+ self.in_channels = in_channels
234
+
235
+ self.use_timestep = use_timestep
236
+ if self.use_timestep:
237
+ # timestep embedding
238
+ self.temb = nn.Module()
239
+ self.temb.dense = nn.ModuleList(
240
+ [
241
+ torch.nn.Linear(self.ch, self.temb_ch),
242
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
243
+ ]
244
+ )
245
+
246
+ # downsampling
247
+ self.conv_in = torch.nn.Conv2d(
248
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
249
+ )
250
+
251
+ curr_res = resolution
252
+ in_ch_mult = (1,) + tuple(ch_mult)
253
+ self.down = nn.ModuleList()
254
+ for i_level in range(self.num_resolutions):
255
+ block = nn.ModuleList()
256
+ attn = nn.ModuleList()
257
+ block_in = ch * in_ch_mult[i_level]
258
+ block_out = ch * ch_mult[i_level]
259
+ for i_block in range(self.num_res_blocks):
260
+ block.append(
261
+ ResnetBlock(
262
+ in_channels=block_in,
263
+ out_channels=block_out,
264
+ temb_channels=self.temb_ch,
265
+ dropout=dropout,
266
+ )
267
+ )
268
+ block_in = block_out
269
+ if curr_res in attn_resolutions:
270
+ attn.append(make_attn(block_in, attn_type=attn_type))
271
+ down = nn.Module()
272
+ down.block = block
273
+ down.attn = attn
274
+ if i_level != self.num_resolutions - 1:
275
+ down.downsample = Downsample(block_in, resamp_with_conv)
276
+ curr_res = curr_res // 2
277
+ self.down.append(down)
278
+
279
+ # middle
280
+ self.mid = nn.Module()
281
+ self.mid.block_1 = ResnetBlock(
282
+ in_channels=block_in,
283
+ out_channels=block_in,
284
+ temb_channels=self.temb_ch,
285
+ dropout=dropout,
286
+ )
287
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
288
+ self.mid.block_2 = ResnetBlock(
289
+ in_channels=block_in,
290
+ out_channels=block_in,
291
+ temb_channels=self.temb_ch,
292
+ dropout=dropout,
293
+ )
294
+
295
+ # upsampling
296
+ self.up = nn.ModuleList()
297
+ for i_level in reversed(range(self.num_resolutions)):
298
+ block = nn.ModuleList()
299
+ attn = nn.ModuleList()
300
+ block_out = ch * ch_mult[i_level]
301
+ skip_in = ch * ch_mult[i_level]
302
+ for i_block in range(self.num_res_blocks + 1):
303
+ if i_block == self.num_res_blocks:
304
+ skip_in = ch * in_ch_mult[i_level]
305
+ block.append(
306
+ ResnetBlock(
307
+ in_channels=block_in + skip_in,
308
+ out_channels=block_out,
309
+ temb_channels=self.temb_ch,
310
+ dropout=dropout,
311
+ )
312
+ )
313
+ block_in = block_out
314
+ if curr_res in attn_resolutions:
315
+ attn.append(make_attn(block_in, attn_type=attn_type))
316
+ up = nn.Module()
317
+ up.block = block
318
+ up.attn = attn
319
+ if i_level != 0:
320
+ up.upsample = Upsample(block_in, resamp_with_conv)
321
+ curr_res = curr_res * 2
322
+ self.up.insert(0, up) # prepend to get consistent order
323
+
324
+ # end
325
+ self.norm_out = Normalize(block_in)
326
+ self.conv_out = torch.nn.Conv2d(
327
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
328
+ )
329
+
330
+ def forward(self, x, t=None, context=None):
331
+ # assert x.shape[2] == x.shape[3] == self.resolution
332
+ if context is not None:
333
+ # assume aligned context, cat along channel axis
334
+ x = torch.cat((x, context), dim=1)
335
+ if self.use_timestep:
336
+ # timestep embedding
337
+ assert t is not None
338
+ temb = get_timestep_embedding(t, self.ch)
339
+ temb = self.temb.dense[0](temb)
340
+ temb = nonlinearity(temb)
341
+ temb = self.temb.dense[1](temb)
342
+ else:
343
+ temb = None
344
+
345
+ # downsampling
346
+ hs = [self.conv_in(x)]
347
+ for i_level in range(self.num_resolutions):
348
+ for i_block in range(self.num_res_blocks):
349
+ h = self.down[i_level].block[i_block](hs[-1], temb)
350
+ if len(self.down[i_level].attn) > 0:
351
+ h = self.down[i_level].attn[i_block](h)
352
+ hs.append(h)
353
+ if i_level != self.num_resolutions - 1:
354
+ hs.append(self.down[i_level].downsample(hs[-1]))
355
+
356
+ # middle
357
+ h = hs[-1]
358
+ h = self.mid.block_1(h, temb)
359
+ h = self.mid.attn_1(h)
360
+ h = self.mid.block_2(h, temb)
361
+
362
+ # upsampling
363
+ for i_level in reversed(range(self.num_resolutions)):
364
+ for i_block in range(self.num_res_blocks + 1):
365
+ h = self.up[i_level].block[i_block](
366
+ torch.cat([h, hs.pop()], dim=1), temb
367
+ )
368
+ if len(self.up[i_level].attn) > 0:
369
+ h = self.up[i_level].attn[i_block](h)
370
+ if i_level != 0:
371
+ h = self.up[i_level].upsample(h)
372
+
373
+ # end
374
+ h = self.norm_out(h)
375
+ h = nonlinearity(h)
376
+ h = self.conv_out(h)
377
+ return h
378
+
379
+ def get_last_layer(self):
380
+ return self.conv_out.weight
381
+
382
+
383
+ class Encoder(nn.Module):
384
+ def __init__(
385
+ self,
386
+ *,
387
+ ch,
388
+ out_ch,
389
+ ch_mult=(1, 2, 4, 8),
390
+ num_res_blocks,
391
+ attn_resolutions,
392
+ dropout=0.0,
393
+ resamp_with_conv=True,
394
+ in_channels,
395
+ resolution,
396
+ z_channels,
397
+ double_z=True,
398
+ use_linear_attn=False,
399
+ attn_type="vanilla",
400
+ **ignore_kwargs,
401
+ ):
402
+ super().__init__()
403
+ if use_linear_attn:
404
+ attn_type = "linear"
405
+ self.ch = ch
406
+ self.temb_ch = 0
407
+ self.num_resolutions = len(ch_mult)
408
+ self.num_res_blocks = num_res_blocks
409
+ self.resolution = resolution
410
+ self.in_channels = in_channels
411
+
412
+ # downsampling
413
+ self.conv_in = torch.nn.Conv2d(
414
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
415
+ )
416
+
417
+ curr_res = resolution
418
+ in_ch_mult = (1,) + tuple(ch_mult)
419
+ self.in_ch_mult = in_ch_mult
420
+ self.down = nn.ModuleList()
421
+ for i_level in range(self.num_resolutions):
422
+ block = nn.ModuleList()
423
+ attn = nn.ModuleList()
424
+ block_in = ch * in_ch_mult[i_level]
425
+ block_out = ch * ch_mult[i_level]
426
+ for i_block in range(self.num_res_blocks):
427
+ block.append(
428
+ ResnetBlock(
429
+ in_channels=block_in,
430
+ out_channels=block_out,
431
+ temb_channels=self.temb_ch,
432
+ dropout=dropout,
433
+ )
434
+ )
435
+ block_in = block_out
436
+ if curr_res in attn_resolutions:
437
+ attn.append(make_attn(block_in, attn_type=attn_type))
438
+ down = nn.Module()
439
+ down.block = block
440
+ down.attn = attn
441
+ if i_level != self.num_resolutions - 1:
442
+ down.downsample = Downsample(block_in, resamp_with_conv)
443
+ curr_res = curr_res // 2
444
+ self.down.append(down)
445
+
446
+ # middle
447
+ self.mid = nn.Module()
448
+ self.mid.block_1 = ResnetBlock(
449
+ in_channels=block_in,
450
+ out_channels=block_in,
451
+ temb_channels=self.temb_ch,
452
+ dropout=dropout,
453
+ )
454
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
455
+ self.mid.block_2 = ResnetBlock(
456
+ in_channels=block_in,
457
+ out_channels=block_in,
458
+ temb_channels=self.temb_ch,
459
+ dropout=dropout,
460
+ )
461
+
462
+ # end
463
+ self.norm_out = Normalize(block_in)
464
+ self.conv_out = torch.nn.Conv2d(
465
+ block_in,
466
+ 2 * z_channels if double_z else z_channels,
467
+ kernel_size=3,
468
+ stride=1,
469
+ padding=1,
470
+ )
471
+
472
+ def forward(self, x):
473
+ # timestep embedding
474
+ temb = None
475
+
476
+ # print(f'encoder-input={x.shape}')
477
+ # downsampling
478
+ hs = [self.conv_in(x)]
479
+ # print(f'encoder-conv in feat={hs[0].shape}')
480
+ for i_level in range(self.num_resolutions):
481
+ for i_block in range(self.num_res_blocks):
482
+ h = self.down[i_level].block[i_block](hs[-1], temb)
483
+ # print(f'encoder-down feat={h.shape}')
484
+ if len(self.down[i_level].attn) > 0:
485
+ h = self.down[i_level].attn[i_block](h)
486
+ hs.append(h)
487
+ if i_level != self.num_resolutions - 1:
488
+ # print(f'encoder-downsample (input)={hs[-1].shape}')
489
+ hs.append(self.down[i_level].downsample(hs[-1]))
490
+ # print(f'encoder-downsample (output)={hs[-1].shape}')
491
+
492
+ # middle
493
+ h = hs[-1]
494
+ h = self.mid.block_1(h, temb)
495
+ # print(f'encoder-mid1 feat={h.shape}')
496
+ h = self.mid.attn_1(h)
497
+ h = self.mid.block_2(h, temb)
498
+ # print(f'encoder-mid2 feat={h.shape}')
499
+
500
+ # end
501
+ h = self.norm_out(h)
502
+ h = nonlinearity(h)
503
+ h = self.conv_out(h)
504
+ # print(f'end feat={h.shape}')
505
+ return h
506
+
507
+
508
+ class Decoder(nn.Module):
509
+ def __init__(
510
+ self,
511
+ *,
512
+ ch,
513
+ out_ch,
514
+ ch_mult=(1, 2, 4, 8),
515
+ num_res_blocks,
516
+ attn_resolutions,
517
+ dropout=0.0,
518
+ resamp_with_conv=True,
519
+ in_channels,
520
+ resolution,
521
+ z_channels,
522
+ give_pre_end=False,
523
+ tanh_out=False,
524
+ use_linear_attn=False,
525
+ attn_type="vanilla",
526
+ **ignored_kwargs,
527
+ ):
528
+ super().__init__()
529
+ if use_linear_attn:
530
+ attn_type = "linear"
531
+ self.ch = ch
532
+ self.temb_ch = 0
533
+ self.num_resolutions = len(ch_mult)
534
+ self.num_res_blocks = num_res_blocks
535
+ self.resolution = resolution
536
+ self.in_channels = in_channels
537
+ self.give_pre_end = give_pre_end
538
+ self.tanh_out = tanh_out
539
+
540
+ # compute in_ch_mult, block_in and curr_res at lowest res
541
+ in_ch_mult = (1,) + tuple(ch_mult)
542
+ block_in = ch * ch_mult[self.num_resolutions - 1]
543
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
544
+ self.z_shape = (1, z_channels, curr_res, curr_res)
545
+ # print("AE working on z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
546
+
547
+ # z to block_in
548
+ self.conv_in = torch.nn.Conv2d(
549
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
550
+ )
551
+
552
+ # middle
553
+ self.mid = nn.Module()
554
+ self.mid.block_1 = ResnetBlock(
555
+ in_channels=block_in,
556
+ out_channels=block_in,
557
+ temb_channels=self.temb_ch,
558
+ dropout=dropout,
559
+ )
560
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
561
+ self.mid.block_2 = ResnetBlock(
562
+ in_channels=block_in,
563
+ out_channels=block_in,
564
+ temb_channels=self.temb_ch,
565
+ dropout=dropout,
566
+ )
567
+
568
+ # upsampling
569
+ self.up = nn.ModuleList()
570
+ for i_level in reversed(range(self.num_resolutions)):
571
+ block = nn.ModuleList()
572
+ attn = nn.ModuleList()
573
+ block_out = ch * ch_mult[i_level]
574
+ for i_block in range(self.num_res_blocks + 1):
575
+ block.append(
576
+ ResnetBlock(
577
+ in_channels=block_in,
578
+ out_channels=block_out,
579
+ temb_channels=self.temb_ch,
580
+ dropout=dropout,
581
+ )
582
+ )
583
+ block_in = block_out
584
+ if curr_res in attn_resolutions:
585
+ attn.append(make_attn(block_in, attn_type=attn_type))
586
+ up = nn.Module()
587
+ up.block = block
588
+ up.attn = attn
589
+ if i_level != 0:
590
+ up.upsample = Upsample(block_in, resamp_with_conv)
591
+ curr_res = curr_res * 2
592
+ self.up.insert(0, up) # prepend to get consistent order
593
+
594
+ # end
595
+ self.norm_out = Normalize(block_in)
596
+ self.conv_out = torch.nn.Conv2d(
597
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
598
+ )
599
+
600
+ def forward(self, z):
601
+ # assert z.shape[1:] == self.z_shape[1:]
602
+ self.last_z_shape = z.shape
603
+
604
+ # print(f'decoder-input={z.shape}')
605
+ # timestep embedding
606
+ temb = None
607
+
608
+ # z to block_in
609
+ h = self.conv_in(z)
610
+ # print(f'decoder-conv in feat={h.shape}')
611
+
612
+ # middle
613
+ h = self.mid.block_1(h, temb)
614
+ h = self.mid.attn_1(h)
615
+ h = self.mid.block_2(h, temb)
616
+ # print(f'decoder-mid feat={h.shape}')
617
+
618
+ # upsampling
619
+ for i_level in reversed(range(self.num_resolutions)):
620
+ for i_block in range(self.num_res_blocks + 1):
621
+ h = self.up[i_level].block[i_block](h, temb)
622
+ if len(self.up[i_level].attn) > 0:
623
+ h = self.up[i_level].attn[i_block](h)
624
+ # print(f'decoder-up feat={h.shape}')
625
+ if i_level != 0:
626
+ h = self.up[i_level].upsample(h)
627
+ # print(f'decoder-upsample feat={h.shape}')
628
+
629
+ # end
630
+ if self.give_pre_end:
631
+ return h
632
+
633
+ h = self.norm_out(h)
634
+ h = nonlinearity(h)
635
+ h = self.conv_out(h)
636
+ # print(f'decoder-conv_out feat={h.shape}')
637
+ if self.tanh_out:
638
+ h = torch.tanh(h)
639
+ return h
640
+
641
+
642
+ class SimpleDecoder(nn.Module):
643
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
644
+ super().__init__()
645
+ self.model = nn.ModuleList(
646
+ [
647
+ nn.Conv2d(in_channels, in_channels, 1),
648
+ ResnetBlock(
649
+ in_channels=in_channels,
650
+ out_channels=2 * in_channels,
651
+ temb_channels=0,
652
+ dropout=0.0,
653
+ ),
654
+ ResnetBlock(
655
+ in_channels=2 * in_channels,
656
+ out_channels=4 * in_channels,
657
+ temb_channels=0,
658
+ dropout=0.0,
659
+ ),
660
+ ResnetBlock(
661
+ in_channels=4 * in_channels,
662
+ out_channels=2 * in_channels,
663
+ temb_channels=0,
664
+ dropout=0.0,
665
+ ),
666
+ nn.Conv2d(2 * in_channels, in_channels, 1),
667
+ Upsample(in_channels, with_conv=True),
668
+ ]
669
+ )
670
+ # end
671
+ self.norm_out = Normalize(in_channels)
672
+ self.conv_out = torch.nn.Conv2d(
673
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
674
+ )
675
+
676
+ def forward(self, x):
677
+ for i, layer in enumerate(self.model):
678
+ if i in [1, 2, 3]:
679
+ x = layer(x, None)
680
+ else:
681
+ x = layer(x)
682
+
683
+ h = self.norm_out(x)
684
+ h = nonlinearity(h)
685
+ x = self.conv_out(h)
686
+ return x
687
+
688
+
689
+ class UpsampleDecoder(nn.Module):
690
+ def __init__(
691
+ self,
692
+ in_channels,
693
+ out_channels,
694
+ ch,
695
+ num_res_blocks,
696
+ resolution,
697
+ ch_mult=(2, 2),
698
+ dropout=0.0,
699
+ ):
700
+ super().__init__()
701
+ # upsampling
702
+ self.temb_ch = 0
703
+ self.num_resolutions = len(ch_mult)
704
+ self.num_res_blocks = num_res_blocks
705
+ block_in = in_channels
706
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
707
+ self.res_blocks = nn.ModuleList()
708
+ self.upsample_blocks = nn.ModuleList()
709
+ for i_level in range(self.num_resolutions):
710
+ res_block = []
711
+ block_out = ch * ch_mult[i_level]
712
+ for i_block in range(self.num_res_blocks + 1):
713
+ res_block.append(
714
+ ResnetBlock(
715
+ in_channels=block_in,
716
+ out_channels=block_out,
717
+ temb_channels=self.temb_ch,
718
+ dropout=dropout,
719
+ )
720
+ )
721
+ block_in = block_out
722
+ self.res_blocks.append(nn.ModuleList(res_block))
723
+ if i_level != self.num_resolutions - 1:
724
+ self.upsample_blocks.append(Upsample(block_in, True))
725
+ curr_res = curr_res * 2
726
+
727
+ # end
728
+ self.norm_out = Normalize(block_in)
729
+ self.conv_out = torch.nn.Conv2d(
730
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
731
+ )
732
+
733
+ def forward(self, x):
734
+ # upsampling
735
+ h = x
736
+ for k, i_level in enumerate(range(self.num_resolutions)):
737
+ for i_block in range(self.num_res_blocks + 1):
738
+ h = self.res_blocks[i_level][i_block](h, None)
739
+ if i_level != self.num_resolutions - 1:
740
+ h = self.upsample_blocks[k](h)
741
+ h = self.norm_out(h)
742
+ h = nonlinearity(h)
743
+ h = self.conv_out(h)
744
+ return h
745
+
746
+
747
+ class LatentRescaler(nn.Module):
748
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
749
+ super().__init__()
750
+ # residual block, interpolate, residual block
751
+ self.factor = factor
752
+ self.conv_in = nn.Conv2d(
753
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
754
+ )
755
+ self.res_block1 = nn.ModuleList(
756
+ [
757
+ ResnetBlock(
758
+ in_channels=mid_channels,
759
+ out_channels=mid_channels,
760
+ temb_channels=0,
761
+ dropout=0.0,
762
+ )
763
+ for _ in range(depth)
764
+ ]
765
+ )
766
+ self.attn = AttnBlock(mid_channels)
767
+ self.res_block2 = nn.ModuleList(
768
+ [
769
+ ResnetBlock(
770
+ in_channels=mid_channels,
771
+ out_channels=mid_channels,
772
+ temb_channels=0,
773
+ dropout=0.0,
774
+ )
775
+ for _ in range(depth)
776
+ ]
777
+ )
778
+
779
+ self.conv_out = nn.Conv2d(
780
+ mid_channels,
781
+ out_channels,
782
+ kernel_size=1,
783
+ )
784
+
785
+ def forward(self, x):
786
+ x = self.conv_in(x)
787
+ for block in self.res_block1:
788
+ x = block(x, None)
789
+ x = torch.nn.functional.interpolate(
790
+ x,
791
+ size=(
792
+ int(round(x.shape[2] * self.factor)),
793
+ int(round(x.shape[3] * self.factor)),
794
+ ),
795
+ )
796
+ x = self.attn(x)
797
+ for block in self.res_block2:
798
+ x = block(x, None)
799
+ x = self.conv_out(x)
800
+ return x
801
+
802
+
803
+ class MergedRescaleEncoder(nn.Module):
804
+ def __init__(
805
+ self,
806
+ in_channels,
807
+ ch,
808
+ resolution,
809
+ out_ch,
810
+ num_res_blocks,
811
+ attn_resolutions,
812
+ dropout=0.0,
813
+ resamp_with_conv=True,
814
+ ch_mult=(1, 2, 4, 8),
815
+ rescale_factor=1.0,
816
+ rescale_module_depth=1,
817
+ ):
818
+ super().__init__()
819
+ intermediate_chn = ch * ch_mult[-1]
820
+ self.encoder = Encoder(
821
+ in_channels=in_channels,
822
+ num_res_blocks=num_res_blocks,
823
+ ch=ch,
824
+ ch_mult=ch_mult,
825
+ z_channels=intermediate_chn,
826
+ double_z=False,
827
+ resolution=resolution,
828
+ attn_resolutions=attn_resolutions,
829
+ dropout=dropout,
830
+ resamp_with_conv=resamp_with_conv,
831
+ out_ch=None,
832
+ )
833
+ self.rescaler = LatentRescaler(
834
+ factor=rescale_factor,
835
+ in_channels=intermediate_chn,
836
+ mid_channels=intermediate_chn,
837
+ out_channels=out_ch,
838
+ depth=rescale_module_depth,
839
+ )
840
+
841
+ def forward(self, x):
842
+ x = self.encoder(x)
843
+ x = self.rescaler(x)
844
+ return x
845
+
846
+
847
+ class MergedRescaleDecoder(nn.Module):
848
+ def __init__(
849
+ self,
850
+ z_channels,
851
+ out_ch,
852
+ resolution,
853
+ num_res_blocks,
854
+ attn_resolutions,
855
+ ch,
856
+ ch_mult=(1, 2, 4, 8),
857
+ dropout=0.0,
858
+ resamp_with_conv=True,
859
+ rescale_factor=1.0,
860
+ rescale_module_depth=1,
861
+ ):
862
+ super().__init__()
863
+ tmp_chn = z_channels * ch_mult[-1]
864
+ self.decoder = Decoder(
865
+ out_ch=out_ch,
866
+ z_channels=tmp_chn,
867
+ attn_resolutions=attn_resolutions,
868
+ dropout=dropout,
869
+ resamp_with_conv=resamp_with_conv,
870
+ in_channels=None,
871
+ num_res_blocks=num_res_blocks,
872
+ ch_mult=ch_mult,
873
+ resolution=resolution,
874
+ ch=ch,
875
+ )
876
+ self.rescaler = LatentRescaler(
877
+ factor=rescale_factor,
878
+ in_channels=z_channels,
879
+ mid_channels=tmp_chn,
880
+ out_channels=tmp_chn,
881
+ depth=rescale_module_depth,
882
+ )
883
+
884
+ def forward(self, x):
885
+ x = self.rescaler(x)
886
+ x = self.decoder(x)
887
+ return x
888
+
889
+
890
+ class Upsampler(nn.Module):
891
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
892
+ super().__init__()
893
+ assert out_size >= in_size
894
+ num_blocks = int(np.log2(out_size // in_size)) + 1
895
+ factor_up = 1.0 + (out_size % in_size)
896
+ print(
897
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
898
+ )
899
+ self.rescaler = LatentRescaler(
900
+ factor=factor_up,
901
+ in_channels=in_channels,
902
+ mid_channels=2 * in_channels,
903
+ out_channels=in_channels,
904
+ )
905
+ self.decoder = Decoder(
906
+ out_ch=out_channels,
907
+ resolution=out_size,
908
+ z_channels=in_channels,
909
+ num_res_blocks=2,
910
+ attn_resolutions=[],
911
+ in_channels=None,
912
+ ch=in_channels,
913
+ ch_mult=[ch_mult for _ in range(num_blocks)],
914
+ )
915
+
916
+ def forward(self, x):
917
+ x = self.rescaler(x)
918
+ x = self.decoder(x)
919
+ return x
920
+
921
+
922
+ class Resize(nn.Module):
923
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
924
+ super().__init__()
925
+ self.with_conv = learned
926
+ self.mode = mode
927
+ if self.with_conv:
928
+ print(
929
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
930
+ )
931
+ raise NotImplementedError()
932
+ assert in_channels is not None
933
+ # no asymmetric padding in torch conv, must do it ourselves
934
+ self.conv = torch.nn.Conv2d(
935
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
936
+ )
937
+
938
+ def forward(self, x, scale_factor=1.0):
939
+ if scale_factor == 1.0:
940
+ return x
941
+ else:
942
+ x = torch.nn.functional.interpolate(
943
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
944
+ )
945
+ return x
946
+
947
+
948
+ class FirstStagePostProcessor(nn.Module):
949
+
950
+ def __init__(
951
+ self,
952
+ ch_mult: list,
953
+ in_channels,
954
+ pretrained_model: nn.Module = None,
955
+ reshape=False,
956
+ n_channels=None,
957
+ dropout=0.0,
958
+ pretrained_config=None,
959
+ ):
960
+ super().__init__()
961
+ if pretrained_config is None:
962
+ assert (
963
+ pretrained_model is not None
964
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
965
+ self.pretrained_model = pretrained_model
966
+ else:
967
+ assert (
968
+ pretrained_config is not None
969
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
970
+ self.instantiate_pretrained(pretrained_config)
971
+
972
+ self.do_reshape = reshape
973
+
974
+ if n_channels is None:
975
+ n_channels = self.pretrained_model.encoder.ch
976
+
977
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
978
+ self.proj = nn.Conv2d(
979
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
980
+ )
981
+
982
+ blocks = []
983
+ downs = []
984
+ ch_in = n_channels
985
+ for m in ch_mult:
986
+ blocks.append(
987
+ ResnetBlock(
988
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
989
+ )
990
+ )
991
+ ch_in = m * n_channels
992
+ downs.append(Downsample(ch_in, with_conv=False))
993
+
994
+ self.model = nn.ModuleList(blocks)
995
+ self.downsampler = nn.ModuleList(downs)
996
+
997
+ def instantiate_pretrained(self, config):
998
+ model = instantiate_from_config(config)
999
+ self.pretrained_model = model.eval()
1000
+ # self.pretrained_model.train = False
1001
+ for param in self.pretrained_model.parameters():
1002
+ param.requires_grad = False
1003
+
1004
+ @torch.no_grad()
1005
+ def encode_with_pretrained(self, x):
1006
+ c = self.pretrained_model.encode(x)
1007
+ if isinstance(c, DiagonalGaussianDistribution):
1008
+ c = c.mode()
1009
+ return c
1010
+
1011
+ def forward(self, x):
1012
+ z_fs = self.encode_with_pretrained(x)
1013
+ z = self.proj_norm(z_fs)
1014
+ z = self.proj(z)
1015
+ z = nonlinearity(z)
1016
+
1017
+ for submodel, downmodel in zip(self.model, self.downsampler):
1018
+ z = submodel(z, temb=None)
1019
+ z = downmodel(z)
1020
+
1021
+ if self.do_reshape:
1022
+ z = rearrange(z, "b c h w -> b (h w) c")
1023
+ return z
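To make the `Encoder`/`Decoder` interfaces above concrete, the sketch below pushes a random image through a small configuration and checks the shapes. The hyperparameters (channel width, multipliers, resolution) are illustrative, not the project's actual autoencoder config; taking the first `z_channels` channels of the encoder output stands in for sampling from the posterior, which the full model does via `DiagonalGaussianDistribution`.

```python
# Shape walkthrough for Encoder/Decoder in core/modules/networks/ae_modules.py.
# Hyperparameters below are illustrative, not the repository's actual config.
import torch
from core.modules.networks.ae_modules import Encoder, Decoder

enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4, double_z=True)
dec = Decoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4)

x = torch.randn(1, 3, 64, 64)
moments = enc(x)            # (1, 2 * z_channels, 16, 16): mean and logvar stacked
z = moments[:, :4]          # use the mean channels as a stand-in for sampling
recon = dec(z)              # (1, 3, 64, 64)
print(moments.shape, recon.shape)
```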
core/modules/networks/unet_modules.py ADDED
@@ -0,0 +1,1047 @@
1
+ from functools import partial
2
+ from abc import abstractmethod
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ import torch.nn.functional as F
7
+ from core.models.utils_diffusion import timestep_embedding
8
+ from core.common import gradient_checkpoint
9
+ from core.basics import zero_module, conv_nd, linear, avg_pool_nd, normalization
10
+ from core.modules.attention import SpatialTransformer, TemporalTransformer
11
+
12
+ TASK_IDX_IMAGE = 0
13
+ TASK_IDX_RAY = 1
14
+
15
+
16
+ class TimestepBlock(nn.Module):
17
+ """
18
+ Any module where forward() takes timestep embeddings as a second argument.
19
+ """
20
+
21
+ @abstractmethod
22
+ def forward(self, x, emb):
23
+ """
24
+ Apply the module to `x` given `emb` timestep embeddings.
25
+ """
26
+
27
+
28
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
29
+ """
30
+ A sequential module that passes timestep embeddings to the children that
31
+ support it as an extra input.
32
+ """
33
+
34
+ def forward(
35
+ self, x, emb, context=None, batch_size=None, with_lora=False, time_steps=None
36
+ ):
37
+ for layer in self:
38
+ if isinstance(layer, TimestepBlock):
39
+ x = layer(x, emb, batch_size=batch_size)
40
+ elif isinstance(layer, SpatialTransformer):
41
+ x = layer(x, context, with_lora=with_lora)
42
+ elif isinstance(layer, TemporalTransformer):
43
+ x = rearrange(x, "(b f) c h w -> b c f h w", b=batch_size)
44
+ x = layer(x, context, with_lora=with_lora, time_steps=time_steps)
45
+ x = rearrange(x, "b c f h w -> (b f) c h w")
46
+ else:
47
+ x = layer(x)
48
+ return x
49
+
50
+
51
+ class Downsample(nn.Module):
52
+ """
53
+ A downsampling layer with an optional convolution.
54
+ :param channels: channels in the inputs and outputs.
55
+ :param use_conv: a bool determining if a convolution is applied.
56
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
57
+ downsampling occurs in the inner-two dimensions.
58
+ """
59
+
60
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
61
+ super().__init__()
62
+ self.channels = channels
63
+ self.out_channels = out_channels or channels
64
+ self.use_conv = use_conv
65
+ self.dims = dims
66
+ stride = 2 if dims != 3 else (1, 2, 2)
67
+ if use_conv:
68
+ self.op = conv_nd(
69
+ dims,
70
+ self.channels,
71
+ self.out_channels,
72
+ 3,
73
+ stride=stride,
74
+ padding=padding,
75
+ )
76
+ else:
77
+ assert self.channels == self.out_channels
78
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
79
+
80
+ def forward(self, x):
81
+ assert x.shape[1] == self.channels
82
+ return self.op(x)
83
+
84
+
85
+ class Upsample(nn.Module):
86
+ """
87
+ An upsampling layer with an optional convolution.
88
+ :param channels: channels in the inputs and outputs.
89
+ :param use_conv: a bool determining if a convolution is applied.
90
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
91
+ upsampling occurs in the inner-two dimensions.
92
+ """
93
+
94
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
95
+ super().__init__()
96
+ self.channels = channels
97
+ self.out_channels = out_channels or channels
98
+ self.use_conv = use_conv
99
+ self.dims = dims
100
+ if use_conv:
101
+ self.conv = conv_nd(
102
+ dims, self.channels, self.out_channels, 3, padding=padding
103
+ )
104
+
105
+ def forward(self, x):
106
+ assert x.shape[1] == self.channels
107
+ if self.dims == 3:
108
+ x = F.interpolate(
109
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
110
+ )
111
+ else:
112
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
113
+ if self.use_conv:
114
+ x = self.conv(x)
115
+ return x
116
+
117
+
118
+ class ResBlock(TimestepBlock):
119
+ """
120
+ A residual block that can optionally change the number of channels.
121
+ :param channels: the number of input channels.
122
+ :param emb_channels: the number of timestep embedding channels.
123
+ :param dropout: the rate of dropout.
124
+ :param out_channels: if specified, the number of out channels.
125
+ :param use_conv: if True and out_channels is specified, use a spatial
126
+ convolution instead of a smaller 1x1 convolution to change the
127
+ channels in the skip connection.
128
+ :param dims: determines if the signal is 1D, 2D, or 3D.
129
+ :param up: if True, use this block for upsampling.
130
+ :param down: if True, use this block for downsampling.
131
+ :param use_temporal_conv: if True, use the temporal convolution.
132
+ :param tempspatial_aware: if True, use spatially-aware kernels in the temporal convolution.
133
+ """
134
+
135
+ def __init__(
136
+ self,
137
+ channels,
138
+ emb_channels,
139
+ dropout,
140
+ out_channels=None,
141
+ use_scale_shift_norm=False,
142
+ dims=2,
143
+ use_checkpoint=False,
144
+ use_conv=False,
145
+ up=False,
146
+ down=False,
147
+ use_temporal_conv=False,
148
+ tempspatial_aware=False,
149
+ ):
150
+ super().__init__()
151
+ self.channels = channels
152
+ self.emb_channels = emb_channels
153
+ self.dropout = dropout
154
+ self.out_channels = out_channels or channels
155
+ self.use_conv = use_conv
156
+ self.use_checkpoint = use_checkpoint
157
+ self.use_scale_shift_norm = use_scale_shift_norm
158
+ self.use_temporal_conv = use_temporal_conv
159
+
160
+ self.in_layers = nn.Sequential(
161
+ normalization(channels),
162
+ nn.SiLU(),
163
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
164
+ )
165
+
166
+ self.updown = up or down
167
+
168
+ if up:
169
+ self.h_upd = Upsample(channels, False, dims)
170
+ self.x_upd = Upsample(channels, False, dims)
171
+ elif down:
172
+ self.h_upd = Downsample(channels, False, dims)
173
+ self.x_upd = Downsample(channels, False, dims)
174
+ else:
175
+ self.h_upd = self.x_upd = nn.Identity()
176
+
177
+ self.emb_layers = nn.Sequential(
178
+ nn.SiLU(),
179
+ nn.Linear(
180
+ emb_channels,
181
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
182
+ ),
183
+ )
184
+ self.out_layers = nn.Sequential(
185
+ normalization(self.out_channels),
186
+ nn.SiLU(),
187
+ nn.Dropout(p=dropout),
188
+ zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
189
+ )
190
+
191
+ if self.out_channels == channels:
192
+ self.skip_connection = nn.Identity()
193
+ elif use_conv:
194
+ self.skip_connection = conv_nd(
195
+ dims, channels, self.out_channels, 3, padding=1
196
+ )
197
+ else:
198
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
199
+
200
+ if self.use_temporal_conv:
201
+ self.temopral_conv = TemporalConvBlock(
202
+ self.out_channels,
203
+ self.out_channels,
204
+ dropout=0.1,
205
+ spatial_aware=tempspatial_aware,
206
+ )
207
+
208
+ def forward(self, x, emb, batch_size=None):
209
+ """
210
+ Apply the block to a Tensor, conditioned on a timestep embedding.
211
+ :param x: an [N x C x ...] Tensor of features.
212
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
213
+ :return: an [N x C x ...] Tensor of outputs.
214
+ """
215
+ input_tuple = (x, emb)
216
+ if batch_size:
217
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
218
+ return gradient_checkpoint(
219
+ forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint
220
+ )
221
+ return gradient_checkpoint(
222
+ self._forward, input_tuple, self.parameters(), self.use_checkpoint
223
+ )
224
+
225
+ def _forward(self, x, emb, batch_size=None):
226
+ if self.updown:
227
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
228
+ h = in_rest(x)
229
+ h = self.h_upd(h)
230
+ x = self.x_upd(x)
231
+ h = in_conv(h)
232
+ else:
233
+ h = self.in_layers(x)
234
+ emb_out = self.emb_layers(emb).type(h.dtype)
235
+ while len(emb_out.shape) < len(h.shape):
236
+ emb_out = emb_out[..., None]
237
+ if self.use_scale_shift_norm:
238
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
239
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
240
+ h = out_norm(h) * (1 + scale) + shift
241
+ h = out_rest(h)
242
+ else:
243
+ h = h + emb_out
244
+ h = self.out_layers(h)
245
+ h = self.skip_connection(x) + h
246
+
247
+ if self.use_temporal_conv and batch_size:
248
+ h = rearrange(h, "(b t) c h w -> b c t h w", b=batch_size)
249
+ h = self.temopral_conv(h)
250
+ h = rearrange(h, "b c t h w -> (b t) c h w")
251
+ return h
252
+
253
+
254
+ class TemporalConvBlock(nn.Module):
255
+ def __init__(
256
+ self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False
257
+ ):
258
+ super(TemporalConvBlock, self).__init__()
259
+ if out_channels is None:
260
+ out_channels = in_channels
261
+ self.in_channels = in_channels
262
+ self.out_channels = out_channels
263
+ th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
264
+ th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
265
+ tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
266
+ tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
267
+
268
+ # conv layers
269
+ self.conv1 = nn.Sequential(
270
+ nn.GroupNorm(32, in_channels),
271
+ nn.SiLU(),
272
+ nn.Conv3d(
273
+ in_channels, out_channels, th_kernel_shape, padding=th_padding_shape
274
+ ),
275
+ )
276
+ self.conv2 = nn.Sequential(
277
+ nn.GroupNorm(32, out_channels),
278
+ nn.SiLU(),
279
+ nn.Dropout(dropout),
280
+ nn.Conv3d(
281
+ out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape
282
+ ),
283
+ )
284
+ self.conv3 = nn.Sequential(
285
+ nn.GroupNorm(32, out_channels),
286
+ nn.SiLU(),
287
+ nn.Dropout(dropout),
288
+ nn.Conv3d(
289
+ out_channels, in_channels, th_kernel_shape, padding=th_padding_shape
290
+ ),
291
+ )
292
+ self.conv4 = nn.Sequential(
293
+ nn.GroupNorm(32, out_channels),
294
+ nn.SiLU(),
295
+ nn.Dropout(dropout),
296
+ nn.Conv3d(
297
+ out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape
298
+ ),
299
+ )
300
+
301
+ # zero out the last layer params, so the conv block is initialized as an identity mapping
302
+ nn.init.zeros_(self.conv4[-1].weight)
303
+ nn.init.zeros_(self.conv4[-1].bias)
304
+
305
+ def forward(self, x):
306
+ identity = x
307
+ x = self.conv1(x)
308
+ x = self.conv2(x)
309
+ x = self.conv3(x)
310
+ x = self.conv4(x)
311
+
312
+ return identity + x
313
+
314
+
315
+ class UNetModel(nn.Module):
316
+ """
317
+ The full UNet model with attention and timestep embedding.
318
+ :param in_channels: channels in the input Tensor.
319
+ :param model_channels: base channel count for the model.
320
+ :param out_channels: channels in the output Tensor.
321
+ :param num_res_blocks: number of residual blocks per downsample.
322
+ :param attention_resolutions: a collection of downsample rates at which
323
+ attention will take place. May be a set, list, or tuple.
324
+ For example, if this contains 4, then at 4x downsampling, attention
325
+ will be used.
326
+ :param dropout: the dropout probability.
327
+ :param channel_mult: channel multiplier for each level of the UNet.
328
+ :param conv_resample: if True, use learned convolutions for upsampling and
329
+ downsampling.
330
+ :param dims: determines if the signal is 1D, 2D, or 3D.
331
+ :param num_classes: if specified (as an int), then this model will be
332
+ class-conditional with `num_classes` classes.
333
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
334
+ :param num_heads: the number of attention heads in each attention layer.
335
+ :param num_heads_channels: if specified, ignore num_heads and instead use
336
+ a fixed channel width per attention head.
337
+ :param num_heads_upsample: works with num_heads to set a different number
338
+ of heads for upsampling. Deprecated.
339
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
340
+ :param resblock_updown: use residual blocks for up/downsampling.
341
+ :param use_new_attention_order: use a different attention pattern for potentially
342
+ increased efficiency.
343
+ """
344
+
345
+ def __init__(
346
+ self,
347
+ in_channels,
348
+ model_channels,
349
+ out_channels,
350
+ num_res_blocks,
351
+ attention_resolutions,
352
+ dropout=0.0,
353
+ channel_mult=(1, 2, 4, 8),
354
+ conv_resample=True,
355
+ dims=2,
356
+ context_dim=None,
357
+ use_scale_shift_norm=False,
358
+ resblock_updown=False,
359
+ num_heads=-1,
360
+ num_head_channels=-1,
361
+ transformer_depth=1,
362
+ use_linear=False,
363
+ use_checkpoint=False,
364
+ temporal_conv=False,
365
+ tempspatial_aware=False,
366
+ temporal_attention=True,
367
+ use_relative_position=True,
368
+ use_causal_attention=False,
369
+ temporal_length=None,
370
+ use_fp16=False,
371
+ addition_attention=False,
372
+ temporal_selfatt_only=True,
373
+ image_cross_attention=False,
374
+ image_cross_attention_scale_learnable=False,
375
+ default_fs=4,
376
+ fs_condition=False,
377
+ use_spatial_temporal_attention=False,
378
+ # >>> Extra Ray Options
379
+ use_addition_ray_output_head=False,
380
+ ray_channels=6,
381
+ use_lora_for_rays_in_output_blocks=False,
382
+ use_task_embedding=False,
383
+ use_ray_decoder=False,
384
+ use_ray_decoder_residual=False,
385
+ full_spatial_temporal_attention=False,
386
+ enhance_multi_view_correspondence=False,
387
+ camera_pose_condition=False,
388
+ use_feature_alignment=False,
389
+ ):
390
+ super(UNetModel, self).__init__()
391
+ if num_heads == -1:
392
+ assert (
393
+ num_head_channels != -1
394
+ ), "Either num_heads or num_head_channels has to be set"
395
+ if num_head_channels == -1:
396
+ assert (
397
+ num_heads != -1
398
+ ), "Either num_heads or num_head_channels has to be set"
399
+
400
+ self.in_channels = in_channels
401
+ self.model_channels = model_channels
402
+ self.out_channels = out_channels
403
+ self.num_res_blocks = num_res_blocks
404
+ self.attention_resolutions = attention_resolutions
405
+ self.dropout = dropout
406
+ self.channel_mult = channel_mult
407
+ self.conv_resample = conv_resample
408
+ self.temporal_attention = temporal_attention
409
+ time_embed_dim = model_channels * 4
410
+ self.use_checkpoint = use_checkpoint
411
+ self.dtype = torch.float16 if use_fp16 else torch.float32
412
+ temporal_self_att_only = True
413
+ self.addition_attention = addition_attention
414
+ self.temporal_length = temporal_length
415
+ self.image_cross_attention = image_cross_attention
416
+ self.image_cross_attention_scale_learnable = (
417
+ image_cross_attention_scale_learnable
418
+ )
419
+ self.default_fs = default_fs
420
+ self.fs_condition = fs_condition
421
+ self.use_spatial_temporal_attention = use_spatial_temporal_attention
422
+
423
+ # >>> Extra Ray Options
424
+ self.use_addition_ray_output_head = use_addition_ray_output_head
425
+ self.use_lora_for_rays_in_output_blocks = use_lora_for_rays_in_output_blocks
426
+ if self.use_lora_for_rays_in_output_blocks:
427
+ assert (
428
+ use_addition_ray_output_head
429
+ ), "`use_addition_ray_output_head` is required to be True when using LoRA for rays in output blocks."
430
+ assert (
431
+ not use_task_embedding
432
+ ), "`use_task_embedding` cannot be True when `use_lora_for_rays_in_output_blocks` is enabled."
433
+ if self.use_addition_ray_output_head:
434
+ print("Using additional ray output head...")
435
+ assert (self.out_channels == 4) or (
436
+ 4 + ray_channels == self.out_channels
437
+ ), f"`out_channels`={out_channels} is invalid."
438
+ self.out_channels = 4
439
+ out_channels = 4
440
+ self.ray_channels = ray_channels
441
+ self.use_ray_decoder = use_ray_decoder
442
+ if use_ray_decoder:
443
+ assert (
444
+ not use_task_embedding
445
+ ), "`use_task_embedding` cannot be True when `use_ray_decoder_layers` is enabled."
446
+ assert (
447
+ use_addition_ray_output_head
448
+ ), "`use_addition_ray_output_head` must be True when `use_ray_decoder_layers` is enabled."
449
+ self.use_ray_decoder_residual = use_ray_decoder_residual
450
+
451
+ # >>> Time/Task Embedding Blocks
452
+ self.time_embed = nn.Sequential(
453
+ linear(model_channels, time_embed_dim),
454
+ nn.SiLU(),
455
+ linear(time_embed_dim, time_embed_dim),
456
+ )
457
+ if fs_condition:
458
+ self.fps_embedding = nn.Sequential(
459
+ linear(model_channels, time_embed_dim),
460
+ nn.SiLU(),
461
+ linear(time_embed_dim, time_embed_dim),
462
+ )
463
+ nn.init.zeros_(self.fps_embedding[-1].weight)
464
+ nn.init.zeros_(self.fps_embedding[-1].bias)
465
+
466
+ if camera_pose_condition:
467
+ self.camera_pose_condition = True
468
+ self.camera_pose_embedding = nn.Sequential(
469
+ linear(12, model_channels),
470
+ nn.SiLU(),
471
+ linear(model_channels, time_embed_dim),
472
+ nn.SiLU(),
473
+ linear(time_embed_dim, time_embed_dim),
474
+ )
475
+ nn.init.zeros_(self.camera_pose_embedding[-1].weight)
476
+ nn.init.zeros_(self.camera_pose_embedding[-1].bias)
477
+
478
+ self.use_task_embedding = use_task_embedding
479
+ if use_task_embedding:
480
+ assert (
481
+ not use_lora_for_rays_in_output_blocks
482
+ ), "`use_lora_for_rays_in_output_blocks` and `use_task_embedding` cannot be True at the same time."
483
+ assert (
484
+ use_addition_ray_output_head
485
+ ), "`use_addition_ray_output_head` is required to be True when `use_task_embedding` is enabled."
486
+ self.task_embedding = nn.Sequential(
487
+ linear(model_channels, time_embed_dim),
488
+ nn.SiLU(),
489
+ linear(time_embed_dim, time_embed_dim),
490
+ )
491
+ nn.init.zeros_(self.task_embedding[-1].weight)
492
+ nn.init.zeros_(self.task_embedding[-1].bias)
493
+ self.task_parameters = nn.ParameterList(
494
+ [
495
+ nn.Parameter(
496
+ torch.zeros(size=[model_channels], requires_grad=True)
497
+ ),
498
+ nn.Parameter(
499
+ torch.zeros(size=[model_channels], requires_grad=True)
500
+ ),
501
+ ]
502
+ )
503
+
504
+        # >>> Input Block
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        if self.addition_attention:
+            self.init_attn = TimestepEmbedSequential(
+                TemporalTransformer(
+                    model_channels,
+                    n_heads=8,
+                    d_head=num_head_channels,
+                    depth=transformer_depth,
+                    context_dim=context_dim,
+                    use_checkpoint=use_checkpoint,
+                    only_self_att=temporal_selfatt_only,
+                    causal_attention=False,
+                    relative_position=use_relative_position,
+                    temporal_length=temporal_length,
+                )
+            )
+
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                        tempspatial_aware=tempspatial_aware,
+                        use_temporal_conv=temporal_conv,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    layers.append(
+                        SpatialTransformer(
+                            ch,
+                            num_heads,
+                            dim_head,
+                            depth=transformer_depth,
+                            context_dim=context_dim,
+                            use_linear=use_linear,
+                            use_checkpoint=use_checkpoint,
+                            disable_self_attn=False,
+                            video_length=temporal_length,
+                            image_cross_attention=self.image_cross_attention,
+                            image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+                        )
+                    )
+                    if self.temporal_attention:
+                        layers.append(
+                            TemporalTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim,
+                                use_linear=use_linear,
+                                use_checkpoint=use_checkpoint,
+                                only_self_att=temporal_self_att_only,
+                                causal_attention=use_causal_attention,
+                                relative_position=use_relative_position,
+                                temporal_length=temporal_length,
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        layers = [
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+                tempspatial_aware=tempspatial_aware,
+                use_temporal_conv=temporal_conv,
+            ),
+            SpatialTransformer(
+                ch,
+                num_heads,
+                dim_head,
+                depth=transformer_depth,
+                context_dim=context_dim,
+                use_linear=use_linear,
+                use_checkpoint=use_checkpoint,
+                disable_self_attn=False,
+                video_length=temporal_length,
+                image_cross_attention=self.image_cross_attention,
+                image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+            ),
+        ]
+        if self.temporal_attention:
+            layers.append(
+                TemporalTransformer(
+                    ch,
+                    num_heads,
+                    dim_head,
+                    depth=transformer_depth,
+                    context_dim=context_dim,
+                    use_linear=use_linear,
+                    use_checkpoint=use_checkpoint,
+                    only_self_att=temporal_self_att_only,
+                    causal_attention=use_causal_attention,
+                    relative_position=use_relative_position,
+                    temporal_length=temporal_length,
+                )
+            )
+        layers.append(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+                tempspatial_aware=tempspatial_aware,
+                use_temporal_conv=temporal_conv,
+            )
+        )
+
+        # >>> Middle Block
+        self.middle_block = TimestepEmbedSequential(*layers)
+
+        # >>> Ray Decoder
+        if use_ray_decoder:
+            self.ray_decoder_blocks = nn.ModuleList([])
+
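+        # The ray decoder below mirrors the output blocks at one tenth of their channel
+        # width (ch // 10); `use_ray_decoder_residual` additionally feeds it the image
+        # skip connections as residual inputs.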
+        # >>> Output Block
+        is_first_layer = True
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(num_res_blocks + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                        tempspatial_aware=tempspatial_aware,
+                        use_temporal_conv=temporal_conv,
+                    )
+                ]
+                if use_ray_decoder:
+                    if self.use_ray_decoder_residual:
+                        ray_residual_ch = ich
+                    else:
+                        ray_residual_ch = 0
+                    ray_decoder_layers = [
+                        ResBlock(
+                            (ch if is_first_layer else (ch // 10)) + ray_residual_ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=mult * model_channels // 10,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            tempspatial_aware=tempspatial_aware,
+                            use_temporal_conv=True,
+                        )
+                    ]
+                    is_first_layer = False
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    layers.append(
+                        SpatialTransformer(
+                            ch,
+                            num_heads,
+                            dim_head,
+                            depth=transformer_depth,
+                            context_dim=context_dim,
+                            use_linear=use_linear,
+                            use_checkpoint=use_checkpoint,
+                            disable_self_attn=False,
+                            video_length=temporal_length,
+                            image_cross_attention=self.image_cross_attention,
+                            image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+                            enable_lora=self.use_lora_for_rays_in_output_blocks,
+                        )
+                    )
+                    if self.temporal_attention:
+                        layers.append(
+                            TemporalTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim,
+                                use_linear=use_linear,
+                                use_checkpoint=use_checkpoint,
+                                only_self_att=temporal_self_att_only,
+                                causal_attention=use_causal_attention,
+                                relative_position=use_relative_position,
+                                temporal_length=temporal_length,
+                                use_extra_spatial_temporal_self_attention=use_spatial_temporal_attention,
+                                enable_lora=self.use_lora_for_rays_in_output_blocks,
+                                full_spatial_temporal_attention=full_spatial_temporal_attention,
+                                enhance_multi_view_correspondence=enhance_multi_view_correspondence,
+                            )
+                        )
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    # out_ray_ch = ray_ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    if use_ray_decoder:
+                        ray_decoder_layers.append(
+                            ResBlock(
+                                ch // 10,
+                                time_embed_dim,
+                                dropout,
+                                out_channels=out_ch // 10,
+                                dims=dims,
+                                use_checkpoint=use_checkpoint,
+                                use_scale_shift_norm=use_scale_shift_norm,
+                                up=True,
+                            )
+                            if resblock_updown
+                            else Upsample(
+                                ch // 10,
+                                conv_resample,
+                                dims=dims,
+                                out_channels=out_ch // 10,
+                            )
+                        )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                if use_ray_decoder:
+                    self.ray_decoder_blocks.append(
+                        TimestepEmbedSequential(*ray_decoder_layers)
+                    )
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+
+        if self.use_addition_ray_output_head:
+            ray_model_channels = model_channels // 10
+            self.ray_output_head = nn.Sequential(
+                normalization(ray_model_channels),
+                nn.SiLU(),
+                conv_nd(dims, ray_model_channels, ray_model_channels, 3, padding=1),
+                nn.SiLU(),
+                conv_nd(dims, ray_model_channels, ray_model_channels, 3, padding=1),
+                nn.SiLU(),
+                zero_module(
+                    conv_nd(dims, ray_model_channels, self.ray_channels, 3, padding=1)
+                ),
+            )
+        self.use_feature_alignment = use_feature_alignment
+        if self.use_feature_alignment:
+            self.feature_alignment_adapter = FeatureAlignmentAdapter(
+                time_embed_dim=time_embed_dim, use_checkpoint=use_checkpoint
+            )
+
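+    # Forward pass summary: frames are folded into the batch dimension ((b t) c h w),
+    # conditioning embeddings are added to the per-frame time embedding, and the output
+    # is the image-latent prediction plus ray channels when a ray head/decoder is enabled.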
+    def forward(
+        self,
+        x,
+        time_steps,
+        context=None,
+        features_adapter=None,
+        fs=None,
+        task_idx=None,
+        camera_poses=None,
+        return_input_block_features=False,
+        return_middle_feature=False,
+        return_output_block_features=False,
+        **kwargs,
+    ):
+        intermediate_features = {}
+        if return_input_block_features:
+            intermediate_features["input"] = []
+        if return_output_block_features:
+            intermediate_features["output"] = []
+        b, t, _, _, _ = x.shape
+        t_emb = timestep_embedding(
+            time_steps, self.model_channels, repeat_only=False
+        ).type(x.dtype)
+        emb = self.time_embed(t_emb)
+
+        # repeat t times for context [(b t) 77 768] & time embedding
+        # check if we use per-frame image conditioning
+        _, l_context, _ = context.shape
+        if l_context == 77 + t * 16:  # !!! HARD CODE here
+            context_text, context_img = context[:, :77, :], context[:, 77:, :]
+            context_text = context_text.repeat_interleave(repeats=t, dim=0)
+            context_img = rearrange(context_img, "b (t l) c -> (b t) l c", t=t)
+            context = torch.cat([context_text, context_img], dim=1)
+        else:
+            context = context.repeat_interleave(repeats=t, dim=0)
+        emb = emb.repeat_interleave(repeats=t, dim=0)
+
+        # always in shape (b t) c h w, except for temporal layer
+        x = rearrange(x, "b t c h w -> (b t) c h w")
+
+        # combine emb
+        if self.fs_condition:
+            if fs is None:
+                fs = torch.tensor(
+                    [self.default_fs] * b, dtype=torch.long, device=x.device
+                )
+            fs_emb = timestep_embedding(
+                fs, self.model_channels, repeat_only=False
+            ).type(x.dtype)
+
+            fs_embed = self.fps_embedding(fs_emb)
+            fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
+            emb = emb + fs_embed
+
+        if self.camera_pose_condition:
+            # camera_poses: (b, t, 3, 4) extrinsics, flattened to (b*t, 12)
+            camera_poses = rearrange(camera_poses, "b t x y -> (b t) (x y)")  # x=3, y=4
+            camera_poses_embed = self.camera_pose_embedding(camera_poses)
+            emb = emb + camera_poses_embed
+
+        if self.use_task_embedding:
+            assert (
+                task_idx is not None
+            ), "`task_idx` should not be None when `use_task_embedding` is enabled."
+            task_embed = self.task_embedding(
+                self.task_parameters[task_idx]
+                .reshape(1, self.model_channels)
+                .repeat(b, 1)
+            )
+            task_embed = task_embed.repeat_interleave(repeats=t, dim=0)
+            emb = emb + task_embed
+
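+        # At this point `emb` is the per-frame sum of the time-step embedding and any
+        # enabled fps, camera-pose, and task embeddings.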
+        h = x.type(self.dtype)
+        adapter_idx = 0
+        hs = []
+        for _id, module in enumerate(self.input_blocks):
+            h = module(h, emb, context=context, batch_size=b)
+            if _id == 0 and self.addition_attention:
+                h = self.init_attn(h, emb, context=context, batch_size=b)
+            # plug-in adapter features
+            if ((_id + 1) % 3 == 0) and features_adapter is not None:
+                h = h + features_adapter[adapter_idx]
+                adapter_idx += 1
+            hs.append(h)
+            if return_input_block_features:
+                intermediate_features["input"].append(h)
+        if features_adapter is not None:
+            assert len(features_adapter) == adapter_idx, "Wrong features_adapter"
+
+        h = self.middle_block(h, emb, context=context, batch_size=b)
+
+        if return_middle_feature:
+            intermediate_features["middle"] = h
+
+        if self.use_feature_alignment:
+            feature_alignment_output = self.feature_alignment_adapter(
+                hs[2], hs[5], hs[8], emb=emb
+            )
+
+        # >>> Output Blocks Forward
+        if self.use_ray_decoder:
+            h_original = h
+            h_ray = h
+            for original_module, ray_module in zip(
+                self.output_blocks, self.ray_decoder_blocks
+            ):
+                cur_hs = hs.pop()
+                h_original = torch.cat([h_original, cur_hs], dim=1)
+                h_original = original_module(
+                    h_original,
+                    emb,
+                    context=context,
+                    batch_size=b,
+                    time_steps=time_steps,
+                )
+                if self.use_ray_decoder_residual:
+                    h_ray = torch.cat([h_ray, cur_hs], dim=1)
+                h_ray = ray_module(h_ray, emb, context=context, batch_size=b)
+                if return_output_block_features:
+                    print(
+                        "return_output_block_features: h_original.shape=",
+                        h_original.shape,
+                    )
+                    intermediate_features["output"].append(h_original.detach())
+            h_original = h_original.type(x.dtype)
+            h_ray = h_ray.type(x.dtype)
+            y_original = self.out(h_original)
+            y_ray = self.ray_output_head(h_ray)
+            y = torch.cat([y_original, y_ray], dim=1)
+        else:
+            if self.use_lora_for_rays_in_output_blocks:
+                middle_h = h
+                h_original = middle_h
+                h_lora = middle_h
+                for output_idx, module in enumerate(self.output_blocks):
+                    cur_hs = hs.pop()
+                    h_original = torch.cat([h_original, cur_hs], dim=1)
+                    h_original = module(
+                        h_original, emb, context=context, batch_size=b, with_lora=False
+                    )
+
+                    h_lora = torch.cat([h_lora, cur_hs], dim=1)
+                    h_lora = module(
+                        h_lora, emb, context=context, batch_size=b, with_lora=True
+                    )
+                h_original = h_original.type(x.dtype)
+                h_lora = h_lora.type(x.dtype)
+                y_original = self.out(h_original)
+                y_lora = self.ray_output_head(h_lora)
+                y = torch.cat([y_original, y_lora], dim=1)
+            else:
+                for module in self.output_blocks:
+                    h = torch.cat([h, hs.pop()], dim=1)
+                    h = module(h, emb, context=context, batch_size=b)
+                h = h.type(x.dtype)
+
+                if self.use_task_embedding:
+                    # Separated input (branch control on CPU);
+                    # serial execution (GPU vectorization pending).
+                    if task_idx == TASK_IDX_IMAGE:
+                        y = self.out(h)
+                    elif task_idx == TASK_IDX_RAY:
+                        y = self.ray_output_head(h)
+                    else:
+                        raise NotImplementedError(f"Unsupported `task_idx`: {task_idx}")
+                else:
+                    # Output rays and images in the same forward pass
+                    y = self.out(h)
+
+                    if self.use_addition_ray_output_head:
+                        y_ray = self.ray_output_head(h)
+                        y = torch.cat([y, y_ray], dim=1)
+        # reshape back to (b t c h w)
+        y = rearrange(y, "(b t) c h w -> b t c h w", b=b)
+        if (
+            return_input_block_features
+            or return_output_block_features
+            or return_middle_feature
+        ):
+            return y, intermediate_features
+        # Assume intermediate features are only requested during non-training scenarios (e.g., feature visualization)
+        if self.use_feature_alignment:
+            return y, feature_alignment_output
+        return y
+
+
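+# Fuses three input-block features (320/640/1280 channels at decreasing resolution) into
+# a 6-channel alignment map that is returned alongside the UNet output when
+# `use_feature_alignment` is enabled.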
+class FeatureAlignmentAdapter(torch.nn.Module):
+    def __init__(self, time_embed_dim, use_checkpoint, dropout=0.0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.channel_adapter_conv_16 = torch.nn.Conv2d(
+            in_channels=1280, out_channels=320, kernel_size=1
+        )
+        self.channel_adapter_conv_32 = torch.nn.Conv2d(
+            in_channels=640, out_channels=320, kernel_size=1
+        )
+        self.upsampler_x2 = torch.nn.UpsamplingBilinear2d(scale_factor=2)
+        self.upsampler_x4 = torch.nn.UpsamplingBilinear2d(scale_factor=4)
+        self.res_block = ResBlock(
+            320 * 3,
+            time_embed_dim,
+            dropout,
+            out_channels=32 * 3,
+            dims=2,
+            use_checkpoint=use_checkpoint,
+            use_scale_shift_norm=False,
+        )
+        self.final_conv = conv_nd(
+            dims=2, in_channels=32 * 3, out_channels=6, kernel_size=1
+        )
+
+    def forward(self, feature_64, feature_32, feature_16, emb):
+        feature_16_adapted = self.channel_adapter_conv_16(feature_16)
+        feature_32_adapted = self.channel_adapter_conv_32(feature_32)
+        feature_16_upsampled = self.upsampler_x4(feature_16_adapted)
+        feature_32_upsampled = self.upsampler_x2(feature_32_adapted)
+        feature_all = torch.concat(
+            [feature_16_upsampled, feature_32_upsampled, feature_64], dim=1
+        )
+
+        # (b*t, 6, h, w)
+        return self.final_conv(self.res_block(feature_all, emb=emb))
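A minimal standalone sketch (not part of the diff; the batch/frame sizes and the 1024-dim embedding are illustrative assumptions) of the per-frame conditioning pattern the UNet forward pass above relies on: frames are folded into the batch dimension, per-sample embeddings are repeated once per frame, and 3x4 camera extrinsics are flattened to 12 values for the pose MLP.

    import torch
    from einops import rearrange

    b, t, c, h, w = 2, 4, 4, 32, 32
    time_embed_dim = 1024

    x = torch.randn(b, t, c, h, w)                 # noisy latents, one clip per batch entry
    emb = torch.randn(b, time_embed_dim)           # per-sample time-step embedding
    camera_poses = torch.randn(b, t, 3, 4)         # per-frame 3x4 camera extrinsics

    x = rearrange(x, "b t c h w -> (b t) c h w")   # frames folded into the batch: (b*t, c, h, w)
    emb = emb.repeat_interleave(repeats=t, dim=0)  # one embedding row per frame: (b*t, time_embed_dim)
    poses = rearrange(camera_poses, "b t x y -> (b t) (x y)")  # flattened to (b*t, 12)
    assert x.shape[0] == emb.shape[0] == poses.shape[0] == b * t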
core/modules/position_encoding.py ADDED
@@ -0,0 +1,97 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+
+    def __init__(
+        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+    ):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, token_tensors):
+        # input: (B, C, H, W)
+        x = token_tensors
+        h, w = x.shape[-2:]
+        identity_map = torch.ones((h, w), device=x.device)
+        y_embed = identity_map.cumsum(0, dtype=torch.float32)
+        x_embed = identity_map.cumsum(1, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, None] / dim_t
+        pos_y = y_embed[:, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
+        ).flatten(2)
+        pos_y = torch.stack(
+            (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
+        ).flatten(2)
+        pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
+        batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+        return batch_pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+    """
+    Absolute pos embedding, learned.
+    """
+
+    def __init__(self, n_pos_x=16, n_pos_y=16, num_pos_feats=64):
+        super().__init__()
+        self.row_embed = nn.Embedding(n_pos_y, num_pos_feats)
+        self.col_embed = nn.Embedding(n_pos_x, num_pos_feats)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.uniform_(self.row_embed.weight)
+        nn.init.uniform_(self.col_embed.weight)
+
+    def forward(self, token_tensors):
+        # input: (B, C, H, W)
+        x = token_tensors
+        h, w = x.shape[-2:]
+        i = torch.arange(w, device=x.device)
+        j = torch.arange(h, device=x.device)
+        x_emb = self.col_embed(i)
+        y_emb = self.row_embed(j)
+        pos = torch.cat(
+            [
+                x_emb.unsqueeze(0).repeat(h, 1, 1),
+                y_emb.unsqueeze(1).repeat(1, w, 1),
+            ],
+            dim=-1,
+        ).permute(2, 0, 1)
+        batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+        return batch_pos
+
+
+def build_position_encoding(num_pos_feats=64, n_pos_x=16, n_pos_y=16, is_learned=False):
+    if is_learned:
+        position_embedding = PositionEmbeddingLearned(n_pos_x, n_pos_y, num_pos_feats)
+    else:
+        position_embedding = PositionEmbeddingSine(num_pos_feats, normalize=True)
+
+    return position_embedding
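A minimal usage sketch (not part of the diff; it assumes the repository layout above is on the Python path) showing the shape produced by the sine variant: the y and x halves are concatenated, so the output has 2 * num_pos_feats channels over the input's spatial grid.

    import torch
    from core.modules.position_encoding import build_position_encoding

    pos_enc = build_position_encoding(num_pos_feats=64, is_learned=False)
    feature_map = torch.randn(2, 320, 16, 16)  # (B, C, H, W) feature map
    pos = pos_enc(feature_map)                 # position embedding, independent of C
    print(pos.shape)                           # torch.Size([2, 128, 16, 16])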