diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..21f0dde4a18d380134c83213b3b3de7ec8cba6a3
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,58 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+__asset__/sample3.jpg filter=lfs diff=lfs merge=lfs -text
+__asset__/sample1-1.png filter=lfs diff=lfs merge=lfs -text
+__asset__/sample1-2.png filter=lfs diff=lfs merge=lfs -text
+__asset__/sample2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1-1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1-2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample3.jpg filter=lfs diff=lfs merge=lfs -text
+assets/sample1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/sample2.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/sample3.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/sample3-2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1.jpg filter=lfs diff=lfs merge=lfs -text
+assets/sample2.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/sample3-1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample4.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/sample5-1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample5-2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample5.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/sample2.jpg filter=lfs diff=lfs merge=lfs -text
+assets/sample3.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/sample4.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/sample4.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000000000000000000000000000000000000..be7f7227f4b12a4a84a948b03b6ea5e7c9d2abb8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+.idea
+__pycache__
+.git
+*.pyc
+.DS_Store
+._*
+cache
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..fe0e2114fd54d47ab8fcf964766a509e445031c5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,302 @@
+Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+
+License Terms of the NVComposer:
+--------------------------------------------------------------------
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+- You agree to use the NVComposer only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
+
+- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+For avoidance of doubts, "Software" means the NVComposer model inference-enabling code, parameters and weights made available under this license.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+Other dependencies and licenses:
+
+
+Open Source Model Licensed under the CreativeML OpenRAIL M license:
+The below model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"), as model weights provided for the NVComposer Project hereunder is fine-tuned with the assistance of below model.
+
+All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+--------------------------------------------------------------------
+1. stable-diffusion-v1-5
+This stable-diffusion-v1-5 is licensed under the CreativeML OpenRAIL M license, Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
+The original model is available at: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
+
+
+Terms of the CreativeML OpenRAIL M license:
+--------------------------------------------------------------------
+Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
+
+CreativeML Open RAIL-M
+dated August 22, 2022
+
+Section I: PREAMBLE
+
+Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
+
+Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
+
+In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
+
+Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
+
+This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
+
+NOW THEREFORE, You and Licensor agree as follows:
+
+1. Definitions
+
+- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
+- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
+- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
+- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
+- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
+- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
+- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
+- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
+- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
+- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
+- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
+
+Section II: INTELLECTUAL PROPERTY RIGHTS
+
+Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
+3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
+
+Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
+
+4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
+Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
+You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
+You must cause any modified files to carry prominent notices stating that You changed the files;
+You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
+5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
+6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
+
+Section IV: OTHER PROVISIONS
+
+7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
+8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
+9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
+10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
+
+END OF TERMS AND CONDITIONS
+
+
+
+
+Attachment A
+
+Use Restrictions
+
+You agree not to use the Model or Derivatives of the Model:
+- In any way that violates any applicable national, federal, state, local or international law or regulation;
+- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
+- To generate or disseminate personal identifiable information that can be used to harm an individual;
+- To defame, disparage or otherwise harass others;
+- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
+- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
+- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
+- To provide medical advice and medical results interpretation;
+- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
+
+
+
+Open Source Software Licensed under the Apache License Version 2.0:
+--------------------------------------------------------------------
+1. pytorch_lightning
+Copyright 2018-2021 William Falcon
+
+2. gradio
+Copyright (c) gradio original author and authors
+
+
+Terms of the Apache License Version 2.0:
+--------------------------------------------------------------------
+Apache License
+
+Version 2.0, January 2004
+
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+You must give any other recipients of the Work or Derivative Works a copy of this License; and
+
+You must cause any modified files to carry prominent notices stating that You changed the files; and
+
+You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+
+If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+
+
+Open Source Software Licensed under the BSD 3-Clause License:
+--------------------------------------------------------------------
+1. torchvision
+Copyright (c) Soumith Chintala 2016,
+All rights reserved.
+
+2. scikit-learn
+Copyright (c) 2007-2024 The scikit-learn developers.
+All rights reserved.
+
+
+Terms of the BSD 3-Clause License:
+--------------------------------------------------------------------
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
+--------------------------------------------------------------------
+1. torch
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+
+
+A copy of the BSD 3-Clause is included in this file.
+
+For the license of other third party components, please refer to the following URL:
+https://github.com/pytorch/pytorch/tree/v2.1.2/third_party
+
+
+Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
+--------------------------------------------------------------------
+1. numpy
+Copyright (c) 2005-2023, NumPy Developers.
+All rights reserved.
+
+
+A copy of the BSD 3-Clause is included in this file.
+
+For the license of other third party components, please refer to the following URL:
+https://github.com/numpy/numpy/blob/v1.26.3/LICENSES_bundled.txt
+
+
+Open Source Software Licensed under the HPND License:
+--------------------------------------------------------------------
+1. Pillow
+Copyright © 2010-2024 by Jeffrey A. Clark (Alex) and contributors.
+
+
+Terms of the HPND License:
+--------------------------------------------------------------------
+By obtaining, using, and/or copying this software and/or its associated
+documentation, you agree that you have read, understood, and will comply
+with the following terms and conditions:
+
+Permission to use, copy, modify and distribute this software and its
+documentation for any purpose and without fee is hereby granted,
+provided that the above copyright notice appears in all copies, and that
+both that copyright notice and this permission notice appear in supporting
+documentation, and that the name of Secret Labs AB or the author not be
+used in advertising or publicity pertaining to distribution of the software
+without specific, written prior permission.
+
+SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
+SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.
+IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,
+INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+PERFORMANCE OF THIS SOFTWARE.
+
+
+
+Open Source Software Licensed under the MIT License:
+--------------------------------------------------------------------
+1. einops
+Copyright (c) 2018 Alex Rogozhnikov
+
+
+Terms of the MIT License:
+--------------------------------------------------------------------
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+
+Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
+--------------------------------------------------------------------
+1. opencv-python
+Copyright (c) Olli-Pekka Heinisuo
+
+
+A copy of the MIT is included in this file.
+
+For the license of other third party components, please refer to the following URL:
+https://github.com/opencv/opencv-python/blob/4.x/LICENSE-3RD-PARTY.txt
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..766cdf0114969ed71bdaa46c2ba126c28e45d0be
--- /dev/null
+++ b/README.md
@@ -0,0 +1,11 @@
+---
+title: NVComposer
+emoji: 📸
+colorFrom: indigo
+colorTo: gray
+sdk: gradio
+sdk_version: 4.38.1
+app_file: app.py
+pinned: false
+python_version: 3.1
+---
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..375b24689044db51b88c4226634030579c8bc24f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,809 @@
+import datetime
+import json
+import os
+
+import gradio as gr
+from huggingface_hub import hf_hub_download
+import spaces
+import PIL.Image
+import numpy as np
+import torch
+import torchvision.transforms.functional
+from numpy import deg2rad
+from omegaconf import OmegaConf
+
+from core.data.camera_pose_utils import convert_w2c_between_c2w
+from core.data.combined_multi_view_dataset import (
+ get_ray_embeddings,
+ normalize_w2c_camera_pose_sequence,
+ crop_and_resize,
+)
+from main.evaluation.funcs import load_model_checkpoint
+from main.evaluation.pose_interpolation import (
+ move_pose,
+ interpolate_camera_poses,
+ generate_spherical_trajectory,
+)
+from main.evaluation.utils_eval import process_inference_batch
+from utils.utils import instantiate_from_config
+from core.models.samplers.ddim import DDIMSampler
+
+torch.set_float32_matmul_precision("medium")
+
+gpu_no = 0
+config = "./configs/dual_stream/nvcomposer.yaml"
+ckpt = hf_hub_download(
+ repo_id="TencentARC/NVComposer", filename="NVComposer-V0.1.ckpt", repo_type="model"
+)
+
+model_resolution_height, model_resolution_width = 576, 1024
+num_views = 16
+dtype = torch.float16
+config = OmegaConf.load(config)
+model_config = config.pop("model", OmegaConf.create())
+model_config.params.train_with_multi_view_feature_alignment = False
+model = instantiate_from_config(model_config).cuda(gpu_no).to(dtype=dtype)
+assert os.path.exists(ckpt), f"Error: checkpoint [{ckpt}] Not Found!"
+print(f"Loading checkpoint from {ckpt}...")
+model = load_model_checkpoint(model, ckpt)
+model.eval()
+latent_h, latent_w = (
+ model_resolution_height // 8,
+ model_resolution_width // 8,
+)
+channels = model.channels
+sampler = DDIMSampler(model)
+
+EXAMPLES = [
+ [
+ "./assets/sample1.jpg",
+ None,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.2,
+ 3,
+ 1.5,
+ 20,
+ "./assets/sample1.mp4",
+ 1,
+ ],
+ [
+ "./assets/sample2.jpg",
+ None,
+ 0,
+ 0,
+ 25,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 3,
+ 1.5,
+ 20,
+ "./assets/sample2.mp4",
+ 1,
+ ],
+ [
+ "./assets/sample3.jpg",
+ None,
+ 0,
+ 0,
+ 15,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 3,
+ 1.5,
+ 20,
+ "./assets/sample3.mp4",
+ 1,
+ ],
+ [
+ "./assets/sample4.jpg",
+ None,
+ 0,
+ 0,
+ -15,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 3,
+ 1.5,
+ 20,
+ "./assets/sample4.mp4",
+ 1,
+ ],
+ [
+ "./assets/sample5-1.png",
+ "./assets/sample5-2.png",
+ 0,
+ 0,
+ -30,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 3,
+ 1.5,
+ 20,
+ "./assets/sample5.mp4",
+ 2,
+ ],
+]
+
+
+def compose_data_item(
+ num_views,
+ cond_pil_image_list,
+ caption="",
+ camera_mode=False,
+ input_pose_format="c2w",
+ model_pose_format="c2w",
+ x_rotation_angle=10,
+ y_rotation_angle=10,
+ z_rotation_angle=10,
+ x_translation=0.5,
+ y_translation=0.5,
+ z_translation=0.5,
+ image_size=None,
+ spherical_angle_x=10,
+ spherical_angle_y=10,
+ spherical_radius=10,
+):
+ if image_size is None:
+ image_size = [512, 512]
+ latent_size = [image_size[0] // 8, image_size[1] // 8]
+
+ def image_processing_function(x):
+ return (
+ torch.from_numpy(
+ np.array(
+ crop_and_resize(
+ x, target_height=image_size[0], target_width=image_size[1]
+ )
+ ).transpose((2, 0, 1))
+ ).float()
+ / 255.0
+ )
+
+ resizer_image_to_latent_size = torchvision.transforms.Resize(
+ size=latent_size,
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
+ antialias=True,
+ )
+ num_cond_views = len(cond_pil_image_list)
+ print(f"Number of received condition images: {num_cond_views}.")
+ num_target_views = num_views - num_cond_views
+ if camera_mode == 1:
+ print("Camera Mode: Movement with Rotation and Translation.")
+ start_pose = torch.tensor(
+ [
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ ]
+ ).float()
+ end_pose = move_pose(
+ start_pose,
+ x_angle=torch.tensor(deg2rad(x_rotation_angle)),
+ y_angle=torch.tensor(deg2rad(y_rotation_angle)),
+ z_angle=torch.tensor(deg2rad(z_rotation_angle)),
+ translation=torch.tensor([x_translation, y_translation, z_translation]),
+ )
+ target_poses = interpolate_camera_poses(
+ start_pose, end_pose, num_steps=num_target_views
+ )
+ elif camera_mode == 0:
+ print("Camera Mode: Spherical Movement.")
+ target_poses = generate_spherical_trajectory(
+ end_angles=(spherical_angle_x, spherical_angle_y),
+ radius=spherical_radius,
+ num_steps=num_target_views,
+ )
+ print("Target pose sequence (before normalization): \n ", target_poses)
+ cond_poses = [
+ torch.tensor(
+ [
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ ]
+ ).float()
+ ] * num_cond_views
+ target_poses = torch.stack(target_poses, dim=0).float()
+ cond_poses = torch.stack(cond_poses, dim=0).float()
+ if not camera_mode != 0 and (input_pose_format != "w2c"):
+ # c2w to w2c. Input for normalize_camera_pose_sequence() should be w2c
+ target_poses = convert_w2c_between_c2w(target_poses)
+ cond_poses = convert_w2c_between_c2w(cond_poses)
+ target_poses, cond_poses = normalize_w2c_camera_pose_sequence(
+ target_poses,
+ cond_poses,
+ output_c2w=model_pose_format == "c2w",
+ translation_norm_mode="disabled",
+ )
+ target_and_condition_camera_poses = torch.cat([target_poses, cond_poses], dim=0)
+
+ print("Target pose sequence (after normalization): \n ", target_poses)
+ fov_xy = [80, 45]
+ target_rays = get_ray_embeddings(
+ target_poses,
+ size_h=image_size[0],
+ size_w=image_size[1],
+ fov_xy_list=[fov_xy for _ in range(num_target_views)],
+ )
+ condition_rays = get_ray_embeddings(
+ cond_poses,
+ size_h=image_size[0],
+ size_w=image_size[1],
+ fov_xy_list=[fov_xy for _ in range(num_cond_views)],
+ )
+ target_images_tensor = torch.zeros(
+ num_target_views, 3, image_size[0], image_size[1]
+ )
+ condition_images = [image_processing_function(x) for x in cond_pil_image_list]
+ condition_images_tensor = torch.stack(condition_images, dim=0) * 2.0 - 1.0
+ target_images_tensor[0, :, :, :] = condition_images_tensor[0, :, :, :]
+ target_and_condition_images_tensor = torch.cat(
+ [target_images_tensor, condition_images_tensor], dim=0
+ )
+ target_and_condition_rays_tensor = torch.cat([target_rays, condition_rays], dim=0)
+ target_and_condition_rays_tensor = resizer_image_to_latent_size(
+ target_and_condition_rays_tensor * 5.0
+ )
+ mask_preserving_target = torch.ones(size=[num_views, 1], dtype=torch.float16)
+ mask_preserving_target[num_target_views:] = 0.0
+ combined_fovs = torch.stack([torch.tensor(fov_xy)] * num_views, dim=0)
+
+ mask_only_preserving_first_target = torch.zeros_like(mask_preserving_target)
+ mask_only_preserving_first_target[0] = 1.0
+ mask_only_preserving_first_condition = torch.zeros_like(mask_preserving_target)
+ mask_only_preserving_first_condition[num_target_views] = 1.0
+ test_data = {
+ # T, C, H, W
+ "combined_images": target_and_condition_images_tensor.unsqueeze(0),
+ "mask_preserving_target": mask_preserving_target.unsqueeze(0), # T, 1
+ # T, 1
+ "mask_only_preserving_first_target": mask_only_preserving_first_target.unsqueeze(
+ 0
+ ),
+ # T, 1
+ "mask_only_preserving_first_condition": mask_only_preserving_first_condition.unsqueeze(
+ 0
+ ),
+ # T, C, H//8, W//8
+ "combined_rays": target_and_condition_rays_tensor.unsqueeze(0),
+ "combined_fovs": combined_fovs.unsqueeze(0),
+ "target_and_condition_camera_poses": target_and_condition_camera_poses.unsqueeze(
+ 0
+ ),
+ "num_target_images": torch.tensor([num_target_views]),
+ "num_cond_images": torch.tensor([num_cond_views]),
+ "num_cond_images_str": [str(num_cond_views)],
+ "item_idx": [0],
+ "subset_key": ["evaluation"],
+ "caption": [caption],
+ "fov_xy": torch.tensor(fov_xy).float().unsqueeze(0),
+ }
+ return test_data
+
+
+def tensor_to_mp4(video, savepath, fps, nrow=None):
+ """
+ video: torch.Tensor, b,t,c,h,w, value range: 0-1
+ """
+ n = video.shape[0]
+ print("Video shape=", video.shape)
+ video = video.permute(1, 0, 2, 3, 4) # t,n,c,h,w
+ nrow = int(np.sqrt(n)) if nrow is None else nrow
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video
+ ] # [3, grid_h, grid_w]
+ # stack in temporal dim [T, 3, grid_h, grid_w]
+ grid = torch.stack(frame_grids, dim=0)
+ grid = torch.clamp(grid.float(), -1.0, 1.0)
+ # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ # print(f'Save video to {savepath}')
+ torchvision.io.write_video(
+ savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
+ )
+
+
+def parse_to_np_array(input_string):
+ try:
+ # Try to parse the input as JSON first
+ data = json.loads(input_string)
+ arr = np.array(data)
+ except json.JSONDecodeError:
+ # If JSON parsing fails, assume it's a multi-line string and handle accordingly
+ lines = input_string.strip().splitlines()
+ data = []
+ for line in lines:
+ # Split the line by spaces and convert to floats
+ data.append([float(x) for x in line.split()])
+ arr = np.array(data)
+
+ # Check if the resulting array is 3x4
+ if arr.shape != (3, 4):
+ raise ValueError(f"Expected array shape (3, 4), but got {arr.shape}")
+
+ return arr
+
+
+@spaces.GPU(duration=180)
+def run_inference(
+ camera_mode,
+ input_cond_image1=None,
+ input_cond_image2=None,
+ input_cond_image3=None,
+ input_cond_image4=None,
+ input_pose_format="c2w",
+ model_pose_format="c2w",
+ x_rotation_angle=None,
+ y_rotation_angle=None,
+ z_rotation_angle=None,
+ x_translation=None,
+ y_translation=None,
+ z_translation=None,
+ trajectory_extension_factor=1,
+ cfg_scale=1.0,
+ cfg_scale_extra=1.0,
+ sample_steps=50,
+ num_images_slider=None,
+ spherical_angle_x=10,
+ spherical_angle_y=10,
+ spherical_radius=10,
+ random_seed=1,
+):
+ cfg_scale_extra = 1.0 # Disable Extra CFG due to time limit of ZeroGPU
+ os.makedirs("./cache/", exist_ok=True)
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(dtype=dtype):
+ torch.manual_seed(random_seed)
+ input_cond_images = []
+ for _cond_image in [
+ input_cond_image1,
+ input_cond_image2,
+ input_cond_image3,
+ input_cond_image4,
+ ]:
+ if _cond_image is not None:
+ if isinstance(_cond_image, np.ndarray):
+ _cond_image = PIL.Image.fromarray(_cond_image)
+ input_cond_images.append(_cond_image)
+ num_condition_views = len(input_cond_images)
+ assert (
+ num_images_slider == num_condition_views
+ ), f"The `num_condition_views`={num_condition_views} while got `num_images_slider`={num_images_slider}."
+ input_caption = ""
+ num_target_views = num_views - num_condition_views
+ data_item = compose_data_item(
+ num_views=num_views,
+ cond_pil_image_list=input_cond_images,
+ caption=input_caption,
+ camera_mode=camera_mode,
+ input_pose_format=input_pose_format,
+ model_pose_format=model_pose_format,
+ x_rotation_angle=x_rotation_angle,
+ y_rotation_angle=y_rotation_angle,
+ z_rotation_angle=z_rotation_angle,
+ x_translation=x_translation,
+ y_translation=y_translation,
+ z_translation=z_translation,
+ image_size=[model_resolution_height, model_resolution_width],
+ spherical_angle_x=spherical_angle_x,
+ spherical_angle_y=spherical_angle_y,
+ spherical_radius=spherical_radius,
+ )
+ batch = data_item
+ if trajectory_extension_factor == 1:
+ print("No trajectory extension.")
+ else:
+ print(f"Trajectory is enabled: {trajectory_extension_factor}.")
+ full_x_samples = []
+ for repeat_idx in range(int(trajectory_extension_factor)):
+ if repeat_idx != 0:
+ batch["combined_images"][:, 0, :, :, :] = full_x_samples[-1][
+ :, -1, :, :, :
+ ]
+ batch["combined_images"][:, num_target_views, :, :, :] = (
+ full_x_samples[-1][:, -1, :, :, :]
+ )
+ cond, uc, uc_extra, x_rec = process_inference_batch(
+ cfg_scale, batch, model, with_uncondition_extra=True
+ )
+
+ batch_size = x_rec.shape[0]
+ shape_without_batch = (num_views, channels, latent_h, latent_w)
+ samples, _ = sampler.sample(
+ sample_steps,
+ batch_size=batch_size,
+ shape=shape_without_batch,
+ conditioning=cond,
+ verbose=True,
+ unconditional_conditioning=uc,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning_extra=uc_extra,
+ unconditional_guidance_scale_extra=cfg_scale_extra,
+ x_T=None,
+ expand_mode=False,
+ num_target_views=num_views - num_condition_views,
+ num_condition_views=num_condition_views,
+ dense_expansion_ratio=None,
+ pred_x0_post_process_function=None,
+ pred_x0_post_process_function_kwargs=None,
+ )
+
+ if samples.size(2) > 4:
+ image_samples = samples[:, :num_target_views, :4, :, :]
+ else:
+ image_samples = samples
+ per_instance_decoding = False
+ if per_instance_decoding:
+ x_samples = []
+ for item_idx in range(image_samples.shape[0]):
+ image_samples = image_samples[
+ item_idx : item_idx + 1, :, :, :, :
+ ]
+ x_sample = model.decode_first_stage(image_samples)
+ x_samples.append(x_sample)
+ x_samples = torch.cat(x_samples, dim=0)
+ else:
+ x_samples = model.decode_first_stage(image_samples)
+ full_x_samples.append(x_samples[:, :num_target_views, ...])
+
+ full_x_samples = torch.concat(full_x_samples, dim=1)
+ x_samples = full_x_samples
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, 0.0, 1.0)
+ video_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".mp4"
+ video_path = "./cache/" + video_name
+ tensor_to_mp4(x_samples.detach().cpu(), fps=6, savepath=video_path)
+ return video_path
+
+
+with gr.Blocks() as demo:
+ gr.HTML(
+ """
+
+
+
+ - Choose camera movement mode: Spherical Mode or Rotation & Translation Mode.
+ - Customize the camera trajectory: Adjust the spherical parameters or rotation/translations along the X, Y,
+ and Z axes.
+ - Upload images: You can upload up to 4 images as input conditions.
+ - Set sampling parameters (optional): Tweak the settings and click the Generate button.
+
+
+ ⏱️ ZeroGPU Time Limit: Hugging Face ZeroGPU has a inference time limit of 180 seconds.
+ You may need to log in with a free account to use this demo.
+ Large sampling steps might lead to timeout (GPU Abort).
+ In that case, please consider log in with a Pro account or run it on your local machine.
+
+
🤗 Please 🌟 star our GitHub repo
+ and click on the ❤️ like button above if you find our work helpful.
+
+
+ """
+ )
+ with gr.Row():
+ with gr.Column(scale=1):
+ with gr.Accordion("Camera Movement Settings", open=True):
+ camera_mode = gr.Radio(
+ choices=[("Spherical Mode", 0), ("Rotation & Translation Mode", 1)],
+ label="Camera Mode",
+ value=0,
+ interactive=True,
+ )
+
+ with gr.Group(visible=True) as group_spherical:
+ # This tab can be left blank for now as per your request
+ # Add extra options manually here in the future
+ gr.HTML(
+ """
+ Spherical Mode allows you to control the camera's movement by specifying its position on a sphere centered around the scene.
+ Adjust the Polar Angle (vertical rotation), Azimuth Angle (horizontal rotation), and Radius (distance from the center of the anchor view) to define the camera's viewpoint.
+ The anchor view is considered located on the sphere at the specified radius, aligned with a zero polar angle and zero azimuth angle, oriented toward the origin.
+
+ """
+ )
+ spherical_angle_x = gr.Slider(
+ minimum=-30,
+ maximum=30,
+ step=1,
+ value=0,
+ label="Polar Angle (Theta)",
+ )
+ spherical_angle_y = gr.Slider(
+ minimum=-30,
+ maximum=30,
+ step=1,
+ value=5,
+ label="Azimuth Angle (Phi)",
+ )
+ spherical_radius = gr.Slider(
+ minimum=0.5, maximum=1.5, step=0.1, value=1, label="Radius"
+ )
+
+ with gr.Group(visible=False) as group_move_rotation_translation:
+ gr.HTML(
+ """
+ Rotation & Translation Mode lets you directly define how the camera moves and rotates in the 3D space.
+ Use Rotation X/Y/Z to control the camera's orientation and Translation X/Y/Z to shift its position.
+ The anchor view serves as the starting point, with no initial rotation or translation applied.
+
+ """
+ )
+ rotation_x = gr.Slider(
+ minimum=-20, maximum=20, step=1, value=0, label="Rotation X"
+ )
+ rotation_y = gr.Slider(
+ minimum=-20, maximum=20, step=1, value=0, label="Rotation Y"
+ )
+ rotation_z = gr.Slider(
+ minimum=-20, maximum=20, step=1, value=0, label="Rotation Z"
+ )
+ translation_x = gr.Slider(
+ minimum=-1, maximum=1, step=0.1, value=0, label="Translation X"
+ )
+ translation_y = gr.Slider(
+ minimum=-1, maximum=1, step=0.1, value=0, label="Translation Y"
+ )
+ translation_z = gr.Slider(
+ minimum=-1,
+ maximum=1,
+ step=0.1,
+ value=-0.2,
+ label="Translation Z",
+ )
+
+ input_camera_pose_format = gr.Radio(
+ choices=["W2C", "C2W"],
+ value="C2W",
+ label="Input Camera Pose Format",
+ visible=False,
+ )
+ model_camera_pose_format = gr.Radio(
+ choices=["W2C", "C2W"],
+ value="C2W",
+ label="Model Camera Pose Format",
+ visible=False,
+ )
+
+ def on_change_selected_camera_settings(_id):
+ return [gr.update(visible=_id == 0), gr.update(visible=_id == 1)]
+
+ camera_mode.change(
+ fn=on_change_selected_camera_settings,
+ inputs=camera_mode,
+ outputs=[group_spherical, group_move_rotation_translation],
+ )
+
+ with gr.Accordion("Advanced Sampling Settings"):
+ cfg_scale = gr.Slider(
+ value=3.0,
+ label="Classifier-Free Guidance Scale",
+ minimum=1,
+ maximum=10,
+ step=0.1,
+ )
+ extra_cfg_scale = gr.Slider(
+ value=1.0,
+ label="Extra Classifier-Free Guidance Scale",
+ minimum=1,
+ maximum=10,
+ step=0.1,
+ visible=False,
+ )
+ sample_steps = gr.Slider(
+ value=18, label="DDIM Sample Steps", minimum=0, maximum=25, step=1
+ )
+ trajectory_extension_factor = gr.Slider(
+ value=1,
+ label="Trajectory Extension (proportional to runtime)",
+ minimum=1,
+ maximum=3,
+ step=1,
+ )
+ random_seed = gr.Slider(
+ value=1024, minimum=1, maximum=9999, step=1, label="Random Seed"
+ )
+
+ def on_change_trajectory_extension_factor(_val):
+ if _val == 1:
+ return [
+ gr.update(minimum=-30, maximum=30),
+ gr.update(minimum=-30, maximum=30),
+ gr.update(minimum=0.5, maximum=1.5),
+ gr.update(minimum=-20, maximum=20),
+ gr.update(minimum=-20, maximum=20),
+ gr.update(minimum=-20, maximum=20),
+ gr.update(minimum=-1, maximum=1),
+ gr.update(minimum=-1, maximum=1),
+ gr.update(minimum=-1, maximum=1),
+ ]
+ elif _val == 2:
+ return [
+ gr.update(minimum=-15, maximum=15),
+ gr.update(minimum=-15, maximum=15),
+ gr.update(minimum=0.5, maximum=1.5),
+ gr.update(minimum=-10, maximum=10),
+ gr.update(minimum=-10, maximum=10),
+ gr.update(minimum=-10, maximum=10),
+ gr.update(minimum=-0.5, maximum=0.5),
+ gr.update(minimum=-0.5, maximum=0.5),
+ gr.update(minimum=-0.5, maximum=0.5),
+ ]
+ elif _val == 3:
+ return [
+ gr.update(minimum=-10, maximum=10),
+ gr.update(minimum=-10, maximum=10),
+ gr.update(minimum=0.5, maximum=1.5),
+ gr.update(minimum=-6, maximum=6),
+ gr.update(minimum=-6, maximum=6),
+ gr.update(minimum=-6, maximum=6),
+ gr.update(minimum=-0.3, maximum=0.3),
+ gr.update(minimum=-0.3, maximum=0.3),
+ gr.update(minimum=-0.3, maximum=0.3),
+ ]
+
+ trajectory_extension_factor.change(
+ fn=on_change_trajectory_extension_factor,
+ inputs=trajectory_extension_factor,
+ outputs=[
+ spherical_angle_x,
+ spherical_angle_y,
+ spherical_radius,
+ rotation_x,
+ rotation_y,
+ rotation_z,
+ translation_x,
+ translation_y,
+ translation_z,
+ ],
+ )
+
+ with gr.Column(scale=1):
+ with gr.Accordion("Input Image(s)", open=True):
+ num_images_slider = gr.Slider(
+ minimum=1,
+ maximum=4,
+ step=1,
+ value=1,
+ label="Number of Input Image(s)",
+ )
+ condition_image_1 = gr.Image(label="Input Image 1 (Anchor View)")
+ condition_image_2 = gr.Image(label="Input Image 2", visible=False)
+ condition_image_3 = gr.Image(label="Input Image 3", visible=False)
+ condition_image_4 = gr.Image(label="Input Image 4", visible=False)
+
+ with gr.Column(scale=1):
+ with gr.Accordion("Output Video", open=True):
+ output_video = gr.Video(label="Output Video")
+ run_btn = gr.Button("Generate")
+ with gr.Accordion("Notes", open=True):
+ gr.HTML(
+ """
+
+🧐 Reminder:
+ As a generative model, NVComposer may occasionally produce unexpected outputs.
+ Try adjusting the random seed, sampling steps, or CFG scales to explore different results.
+
+🤔 Longer Generation:
+ If you need longer video, you can increase the trajectory extension value in the advanced sampling settings and run with your own GPU.
+ This extends the defined camera trajectory by repeating it, allowing for a longer output.
+ This also requires using smaller rotation or translation scales to maintain smooth transitions and will increase the generation time.
+🤗 Limitation:
+ This is the initial beta version of NVComposer.
+ Its generalizability may be limited in certain scenarios, and artifacts can appear with large camera motions due to the current foundation model's constraints.
+ We’re actively working on an improved version with enhanced datasets and a more powerful foundation model,
+ and we are looking for collaboration opportunities from the community.
+✨ We welcome your feedback and questions. Thank you!
+ """
+ )
+
+ with gr.Row():
+ gr.Examples(
+ label="Quick Examples",
+ examples=EXAMPLES,
+ inputs=[
+ condition_image_1,
+ condition_image_2,
+ camera_mode,
+ spherical_angle_x,
+ spherical_angle_y,
+ spherical_radius,
+ rotation_x,
+ rotation_y,
+ rotation_z,
+ translation_x,
+ translation_y,
+ translation_z,
+ cfg_scale,
+ extra_cfg_scale,
+ sample_steps,
+ output_video,
+ num_images_slider,
+ ],
+ examples_per_page=5,
+ cache_examples=False,
+ )
+
+ # Update visibility of condition images based on the slider
+ def update_visible_images(num_images):
+ return [
+ gr.update(visible=num_images >= 2),
+ gr.update(visible=num_images >= 3),
+ gr.update(visible=num_images >= 4),
+ ]
+
+ # Trigger visibility update when the slider value changes
+ num_images_slider.change(
+ fn=update_visible_images,
+ inputs=num_images_slider,
+ outputs=[condition_image_2, condition_image_3, condition_image_4],
+ )
+
+ run_btn.click(
+ fn=run_inference,
+ inputs=[
+ camera_mode,
+ condition_image_1,
+ condition_image_2,
+ condition_image_3,
+ condition_image_4,
+ input_camera_pose_format,
+ model_camera_pose_format,
+ rotation_x,
+ rotation_y,
+ rotation_z,
+ translation_x,
+ translation_y,
+ translation_z,
+ trajectory_extension_factor,
+ cfg_scale,
+ extra_cfg_scale,
+ sample_steps,
+ num_images_slider,
+ spherical_angle_x,
+ spherical_angle_y,
+ spherical_radius,
+ random_seed,
+ ],
+ outputs=output_video,
+ )
+
+demo.launch()
diff --git a/assets/sample1.jpg b/assets/sample1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..37d221e001cd8f419498922000d67bb038f37136
--- /dev/null
+++ b/assets/sample1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:821bdc48093db88c75b89e124de2d3511ee3d6f17617ffc94bcc5b30ebe7d295
+size 1496521
diff --git a/assets/sample1.mp4 b/assets/sample1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5b8d73a713ac51dca6f8803a99881efba0debbc8
--- /dev/null
+++ b/assets/sample1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:242b3617d2c50a9f175619827974b90dd665fa51ae06b2cc7bb9373248f5f8d1
+size 2513866
diff --git a/assets/sample2.jpg b/assets/sample2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9f8fc25c3abe826cb6330acd09a1d4088e68f0ec
--- /dev/null
+++ b/assets/sample2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:085b781f0330692c746e6f9e2d28f24fbfe0285db1b5ec94383037200b673b0a
+size 153749
diff --git a/assets/sample2.mp4 b/assets/sample2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..76b9897bb9845d6e4a02a2fad317e37ac90be461
--- /dev/null
+++ b/assets/sample2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:754a3246734802261d445878fa7ad5f5b860deceb597d3bb439547078b7f0281
+size 2369420
diff --git a/assets/sample3.jpg b/assets/sample3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..351e3c3c3a012a495ed052c14792eb2465bcc4dc
--- /dev/null
+++ b/assets/sample3.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d55565e60ea8a80e7c09ef8f1f3ee4e64b507571174cf79f0c65c3d8cdcb1de
+size 756562
diff --git a/assets/sample3.mp4 b/assets/sample3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..42ed428b04ac72e475b3212ab2bb8ad4e83f8638
--- /dev/null
+++ b/assets/sample3.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:047c45a0e93627e63464fe939484acbfbfc9087c43d366618c8b34dd331ba3f5
+size 4129878
diff --git a/assets/sample4.jpg b/assets/sample4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..351e3c3c3a012a495ed052c14792eb2465bcc4dc
--- /dev/null
+++ b/assets/sample4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d55565e60ea8a80e7c09ef8f1f3ee4e64b507571174cf79f0c65c3d8cdcb1de
+size 756562
diff --git a/assets/sample4.mp4 b/assets/sample4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b763f4df3d10eb41739e0c2304021cb4b45d2de9
--- /dev/null
+++ b/assets/sample4.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a76a2950e6b7cc82bc2e36c04d312a1e5babc25cdd489b25a542862912b9f62
+size 4118935
diff --git a/assets/sample5-1.png b/assets/sample5-1.png
new file mode 100644
index 0000000000000000000000000000000000000000..cd60d981c1aec4b6d146a3b33a81f64ce4e446d7
--- /dev/null
+++ b/assets/sample5-1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c41016f7cc5acd012ab0251d4e4a2a698b9160aaacf017e0aa6053786d87f58
+size 1193293
diff --git a/assets/sample5-2.png b/assets/sample5-2.png
new file mode 100644
index 0000000000000000000000000000000000000000..1a8712b2e58b694c4216fc514a82be9e5e5526d4
--- /dev/null
+++ b/assets/sample5-2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c06371aa6dd628f733adec128bb650a4c2aa710f26f9c8e266f18b1bf9b536a2
+size 1187684
diff --git a/assets/sample5.mp4 b/assets/sample5.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9f264cac3e4a0ae0558d996e60d065a2e5dda4dc
--- /dev/null
+++ b/assets/sample5.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4846600b8e47774729e3bc199d4f08399d983cd7213b724ede7a5ed9057a3d5
+size 4124063
diff --git a/configs/dual_stream/nvcomposer.yaml b/configs/dual_stream/nvcomposer.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..dcd7c020ad347b1944c9ff2acf5ae618e73e5e83
--- /dev/null
+++ b/configs/dual_stream/nvcomposer.yaml
@@ -0,0 +1,139 @@
+num_frames: &num_frames 16
+resolution: &resolution [576, 1024]
+model:
+ base_learning_rate: 1.0e-5
+ scale_lr: false
+ target: core.models.diffusion.DualStreamMultiViewDiffusionModel
+ params:
+ use_task_embedding: false
+ ray_as_image: false
+ apply_condition_mask_in_training_loss: true
+ separate_noise_and_condition: true
+ condition_padding_with_anchor: false
+ use_ray_decoder_loss_high_frequency_isolation: false
+ train_with_multi_view_feature_alignment: true
+ use_text_cross_attention_condition: false
+
+ linear_start: 0.00085
+ linear_end: 0.012
+ num_time_steps_cond: 1
+ log_every_t: 200
+ time_steps: 1000
+
+ data_key_images: combined_images
+ data_key_rays: combined_rays
+ data_key_text_condition: caption
+ cond_stage_trainable: false
+ image_size: [72, 128]
+
+ channels: 10
+ monitor: global_step
+ scale_by_std: false
+ scale_factor: 0.18215
+ use_dynamic_rescale: true
+ base_scale: 0.3
+
+ use_ema: false
+ uncond_prob: 0.05
+ uncond_type: 'empty_seq'
+
+ use_camera_pose_query_transformer: false
+ random_cond: false
+ cond_concat: true
+ frame_mask: false
+ padding: true
+ per_frame_auto_encoding: true
+ parameterization: "v"
+ rescale_betas_zero_snr: true
+ use_noise_offset: false
+ scheduler_config:
+ target: utils.lr_scheduler.LambdaLRScheduler
+ interval: 'step'
+ frequency: 100
+ params:
+ start_step: 0
+ final_decay_ratio: 0.1
+ decay_steps: 100
+ bd_noise: false
+
+ unet_config:
+ target: core.modules.networks.unet_modules.UNetModel
+ params:
+ in_channels: 20
+ out_channels: 10
+ model_channels: 320
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ dropout: 0.1
+ num_head_channels: 64
+ transformer_depth: 1
+ context_dim: 1024
+ use_linear: true
+ use_checkpoint: true
+ temporal_conv: true
+ temporal_attention: true
+ temporal_selfatt_only: true
+ use_relative_position: false
+ use_causal_attention: false
+ temporal_length: *num_frames
+ addition_attention: true
+ image_cross_attention: true
+ image_cross_attention_scale_learnable: true
+ default_fs: 3
+ fs_condition: false
+ use_spatial_temporal_attention: true
+ use_addition_ray_output_head: true
+ ray_channels: 6
+ use_lora_for_rays_in_output_blocks: false
+ use_task_embedding: false
+ use_ray_decoder: true
+ use_ray_decoder_residual: true
+ full_spatial_temporal_attention: true
+ enhance_multi_view_correspondence: false
+ camera_pose_condition: true
+ use_feature_alignment: true
+
+ first_stage_config:
+ target: core.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_img_config:
+ target: core.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+ params:
+ freeze: true
+
+ image_proj_model_config:
+ target: core.modules.encoders.resampler.Resampler
+ params:
+ dim: 1024
+ depth: 4
+ dim_head: 64
+ heads: 12
+ num_queries: 16
+ embedding_dim: 1280
+ output_dim: 1024
+ ff_mult: 4
+ video_length: *num_frames
diff --git a/core/basics.py b/core/basics.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc1380ab7cab2d0a311e4855bbf74d6474e9a035
--- /dev/null
+++ b/core/basics.py
@@ -0,0 +1,95 @@
+import torch.nn as nn
+
+from utils.utils import instantiate_from_config
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def nonlinearity(type="silu"):
+ if type == "silu":
+ return nn.SiLU()
+ elif type == "leaky_relu":
+ return nn.LeakyReLU()
+
+
+class GroupNormSpecific(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels, num_groups=32):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :param num_groups: number of groupseg.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNormSpecific(num_groups, channels)
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
diff --git a/core/common.py b/core/common.py
new file mode 100755
index 0000000000000000000000000000000000000000..2f9a973af7e2051e7c219f9743edae29fe1e0d8c
--- /dev/null
+++ b/core/common.py
@@ -0,0 +1,167 @@
+import math
+from inspect import isfunction
+
+import torch
+import torch.distributed as dist
+from torch import nn
+
+
+def gather_data(data, return_np=True):
+ """gather data from multiple processes to one list"""
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
+ dist.all_gather(data_list, data) # gather not supported with NCCL
+ if return_np:
+ data_list = [data.cpu().numpy() for data in data_list]
+ return data_list
+
+
+def autocast(f):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=True,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+ def repeat_noise():
+ return torch.randn((1, *shape[1:]), device=device).repeat(
+ shape[0], *((1,) * (len(shape) - 1))
+ )
+
+ def noise():
+ return torch.randn(shape, device=device)
+
+ return repeat_noise() if repeat else noise()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def exists(val):
+ return val is not None
+
+
+def identity(*args, **kwargs):
+ return nn.Identity()
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def shape_to_str(x):
+ shape_str = "x".join([str(x) for x in x.shape])
+ return shape_str
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# USE_DEEP_SPEED_CHECKPOINTING = False
+# if USE_DEEP_SPEED_CHECKPOINTING:
+# import deepspeed
+#
+# _gradient_checkpoint_function = deepspeed.checkpointing.checkpoint
+# else:
+_gradient_checkpoint_function = torch.utils.checkpoint.checkpoint
+
+
+def gradient_checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ # args = tuple(inputs) + tuple(params)
+ # return CheckpointFunction.apply(func, len(inputs), *args)
+ if isinstance(inputs, tuple):
+ return _gradient_checkpoint_function(func, *inputs, use_reentrant=False)
+ else:
+ return _gradient_checkpoint_function(func, inputs, use_reentrant=False)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd # add this
+ def backward(ctx, *output_grads):
+ """
+ for x in ctx.input_tensors:
+ if isinstance(x, int):
+ print('-----------------', ctx.run_function)
+ """
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
diff --git a/core/data/__init__.py b/core/data/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/core/data/camera_pose_utils.py b/core/data/camera_pose_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..adc98677ee7a0a6b0c4ef5c3af22b8fb909e5f14
--- /dev/null
+++ b/core/data/camera_pose_utils.py
@@ -0,0 +1,277 @@
+import copy
+import numpy as np
+import torch
+from scipy.spatial.transform import Rotation as R
+
+
+def get_opencv_from_blender(matrix_world, fov, image_size):
+ # convert matrix_world to opencv format extrinsics
+ opencv_world_to_cam = matrix_world.inverse()
+ opencv_world_to_cam[1, :] *= -1
+ opencv_world_to_cam[2, :] *= -1
+ R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
+ R, T = R.unsqueeze(0), T.unsqueeze(0)
+
+ # convert fov to opencv format intrinsics
+ focal = 1 / np.tan(fov / 2)
+ intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
+ opencv_cam_matrix = torch.from_numpy(intrinsics).unsqueeze(0).float()
+ opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2])
+ opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
+
+ return R, T, opencv_cam_matrix
+
+
+def cartesian_to_spherical(xyz):
+ xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
+ z = np.sqrt(xy + xyz[:, 2] ** 2)
+ # for elevation angle defined from z-axis down
+ theta = np.arctan2(np.sqrt(xy), xyz[:, 2])
+ azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
+ return np.stack([theta, azimuth, z], axis=-1)
+
+
+def spherical_to_cartesian(spherical_coords):
+ # convert from spherical to cartesian coordinates
+ theta, azimuth, radius = spherical_coords.T
+ x = radius * np.sin(theta) * np.cos(azimuth)
+ y = radius * np.sin(theta) * np.sin(azimuth)
+ z = radius * np.cos(theta)
+ return np.stack([x, y, z], axis=-1)
+
+
+def look_at(eye, center, up):
+ # Create a normalized direction vector from eye to center
+ f = np.array(center) - np.array(eye)
+ f /= np.linalg.norm(f)
+
+ # Create a normalized right vector
+ up_norm = np.array(up) / np.linalg.norm(up)
+ s = np.cross(f, up_norm)
+ s /= np.linalg.norm(s)
+
+ # Recompute the up vector
+ u = np.cross(s, f)
+
+ # Create rotation matrix R
+ R = np.array([[s[0], s[1], s[2]], [u[0], u[1], u[2]], [-f[0], -f[1], -f[2]]])
+
+ # Create translation vector T
+ T = -np.dot(R, np.array(eye))
+
+ return R, T
+
+
+def get_blender_from_spherical(elevation, azimuth):
+ """Generates blender camera from spherical coordinates."""
+
+ cartesian_coords = spherical_to_cartesian(np.array([[elevation, azimuth, 3.5]]))
+
+ # get camera rotation
+ center = np.array([0, 0, 0])
+ eye = cartesian_coords[0]
+ up = np.array([0, 0, 1])
+
+ R, T = look_at(eye, center, up)
+ R = R.T
+ T = -np.dot(R, T)
+ RT = np.concatenate([R, T.reshape(3, 1)], axis=-1)
+
+ blender_cam = torch.from_numpy(RT).float()
+ blender_cam = torch.cat([blender_cam, torch.tensor([[0, 0, 0, 1]])], dim=0)
+ print(blender_cam)
+ return blender_cam
+
+
+def invert_pose(r, t):
+ r_inv = r.T
+ t_inv = -np.dot(r_inv, t)
+ return r_inv, t_inv
+
+
+def transform_pose_sequence_to_relative(poses, as_z_up=False):
+ """
+ poses: a sequence of 3*4 C2W camera pose matrices
+ as_z_up: output in z-up format. If False, the output is in y-up format
+ """
+ r0, t0 = poses[0][:3, :3], poses[0][:3, 3]
+ # r0_inv, t0_inv = invert_pose(r0, t0)
+ r0_inv = r0.T
+ new_rt0 = np.hstack([np.eye(3, 3), np.zeros((3, 1))])
+ if as_z_up:
+ new_rt0 = c2w_y_up_to_z_up(new_rt0)
+ transformed_poses = [new_rt0]
+ for pose in poses[1:]:
+ r, t = pose[:3, :3], pose[:3, 3]
+ new_r = np.dot(r0_inv, r)
+ new_t = np.dot(r0_inv, t - t0)
+ new_rt = np.hstack([new_r, new_t[:, None]])
+ if as_z_up:
+ new_rt = c2w_y_up_to_z_up(new_rt)
+ transformed_poses.append(new_rt)
+ return transformed_poses
+
+
+def c2w_y_up_to_z_up(c2w_3x4):
+ R_y_up_to_z_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
+
+ R = c2w_3x4[:, :3]
+ t = c2w_3x4[:, 3]
+
+ R_z_up = R_y_up_to_z_up @ R
+ t_z_up = R_y_up_to_z_up @ t
+
+ T_z_up = np.hstack((R_z_up, t_z_up.reshape(3, 1)))
+
+ return T_z_up
+
+
+def transform_pose_sequence_to_relative_w2c(poses):
+ new_rt_list = []
+ first_frame_rt = copy.deepcopy(poses[0])
+ first_frame_r_inv = first_frame_rt[:, :3].T
+ first_frame_t = first_frame_rt[:, -1]
+ for rt in poses:
+ rt[:, :3] = np.matmul(rt[:, :3], first_frame_r_inv)
+ rt[:, -1] = rt[:, -1] - np.matmul(rt[:, :3], first_frame_t)
+ new_rt_list.append(copy.deepcopy(rt))
+ return new_rt_list
+
+
+def transform_pose_sequence_to_relative_c2w(poses):
+ first_frame_rt = poses[0]
+ first_frame_r_inv = first_frame_rt[:, :3].T
+ first_frame_t = first_frame_rt[:, -1]
+ rotations = poses[:, :, :3]
+ translations = poses[:, :, 3]
+
+ # Compute new rotations and translations in batch
+ new_rotations = torch.matmul(first_frame_r_inv, rotations)
+ new_translations = torch.matmul(
+ first_frame_r_inv, (translations - first_frame_t.unsqueeze(0)).unsqueeze(-1)
+ )
+ # Concatenate new rotations and translations
+ new_rt = torch.cat([new_rotations, new_translations], dim=-1)
+
+ return new_rt
+
+
+def convert_w2c_between_c2w(poses):
+ rotations = poses[:, :, :3]
+ translations = poses[:, :, 3]
+ new_rotations = rotations.transpose(-1, -2)
+ new_translations = torch.matmul(-new_rotations, translations.unsqueeze(-1))
+ new_rt = torch.cat([new_rotations, new_translations], dim=-1)
+ return new_rt
+
+
+def slerp(q1, q2, t):
+ """
+ Performs spherical linear interpolation (SLERP) between two quaternions.
+
+ Args:
+ q1 (torch.Tensor): Start quaternion (4,).
+ q2 (torch.Tensor): End quaternion (4,).
+ t (float or torch.Tensor): Interpolation parameter in [0, 1].
+
+ Returns:
+ torch.Tensor: Interpolated quaternion (4,).
+ """
+ q1 = q1 / torch.linalg.norm(q1) # Normalize q1
+ q2 = q2 / torch.linalg.norm(q2) # Normalize q2
+
+ dot = torch.dot(q1, q2)
+
+ # Ensure shortest path (flip q2 if needed)
+ if dot < 0.0:
+ q2 = -q2
+ dot = -dot
+
+ # Avoid numerical precision issues
+ dot = torch.clamp(dot, -1.0, 1.0)
+
+ theta = torch.acos(dot) # Angle between q1 and q2
+
+ if theta < 1e-6: # If very close, use linear interpolation
+ return (1 - t) * q1 + t * q2
+
+ sin_theta = torch.sin(theta)
+
+ return (torch.sin((1 - t) * theta) / sin_theta) * q1 + (
+ torch.sin(t * theta) / sin_theta
+ ) * q2
+
+
+def interpolate_camera_poses(c2w: torch.Tensor, factor: int) -> torch.Tensor:
+ """
+ Interpolates a sequence of camera c2w poses to N times the length of the original sequence.
+
+ Args:
+ c2w (torch.Tensor): Input camera poses of shape (N, 3, 4).
+ factor (int): The upsampling factor (e.g., 2 for doubling the length).
+
+ Returns:
+ torch.Tensor: Interpolated camera poses of shape (N * factor, 3, 4).
+ """
+ assert c2w.ndim == 3 and c2w.shape[1:] == (
+ 3,
+ 4,
+ ), "Input tensor must have shape (N, 3, 4)."
+ assert factor > 1, "Upsampling factor must be greater than 1."
+
+ N = c2w.shape[0]
+ new_length = N * factor
+
+ # Extract rotations (R) and translations (T)
+ rotations = c2w[:, :3, :3] # Shape (N, 3, 3)
+ translations = c2w[:, :3, 3] # Shape (N, 3)
+
+ # Convert rotations to quaternions for interpolation
+ quaternions = torch.tensor(
+ R.from_matrix(rotations.numpy()).as_quat()
+ ) # Shape (N, 4)
+
+ # Initialize interpolated quaternions and translations
+ interpolated_quats = []
+ interpolated_translations = []
+
+ # Perform interpolation
+ for i in range(N - 1):
+ # Start and end quaternions and translations for this segment
+ q1, q2 = quaternions[i], quaternions[i + 1]
+ t1, t2 = translations[i], translations[i + 1]
+
+ # Time steps for interpolation within this segment
+ t_values = torch.linspace(0, 1, factor, dtype=torch.float32)
+
+ # Interpolate quaternions using SLERP
+ for t in t_values:
+ interpolated_quats.append(slerp(q1, q2, t))
+
+ # Interpolate translations linearly
+ interp_t = t1 * (1 - t_values[:, None]) + t2 * t_values[:, None]
+ interpolated_translations.append(interp_t)
+
+ interpolated_quats.append(quaternions[0])
+ interpolated_translations.append(translations[0].unsqueeze(0))
+ # Add the last pose (end of sequence)
+ interpolated_quats.append(quaternions[-1])
+ interpolated_translations.append(translations[-1].unsqueeze(0)) # Add as 2D tensor
+
+ # Combine interpolated results
+ interpolated_quats = torch.stack(interpolated_quats, dim=0) # Shape (new_length, 4)
+ interpolated_translations = torch.cat(
+ interpolated_translations, dim=0
+ ) # Shape (new_length, 3)
+
+ # Convert quaternions back to rotation matrices
+ interpolated_rotations = torch.tensor(
+ R.from_quat(interpolated_quats.numpy()).as_matrix()
+ ) # Shape (new_length, 3, 3)
+
+ # Form final c2w matrix
+ interpolated_c2w = torch.zeros((new_length, 3, 4), dtype=torch.float32)
+ interpolated_c2w[:, :3, :3] = interpolated_rotations
+ interpolated_c2w[:, :3, 3] = interpolated_translations
+
+ return interpolated_c2w
diff --git a/core/data/combined_multi_view_dataset.py b/core/data/combined_multi_view_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..12af9fe7252d4165a92f14c89a45965414df8187
--- /dev/null
+++ b/core/data/combined_multi_view_dataset.py
@@ -0,0 +1,341 @@
+import PIL
+import numpy as np
+import torch
+from PIL import Image
+
+from .camera_pose_utils import (
+ convert_w2c_between_c2w,
+ transform_pose_sequence_to_relative_c2w,
+)
+
+
+def get_ray_embeddings(
+ poses, size_h=256, size_w=256, fov_xy_list=None, focal_xy_list=None
+):
+ """
+ poses: sequence of cameras poses (y-up format)
+ """
+ use_focal = False
+ if fov_xy_list is None or fov_xy_list[0] is None or fov_xy_list[0][0] is None:
+ assert focal_xy_list is not None
+ use_focal = True
+
+ rays_embeddings = []
+ for i in range(poses.shape[0]):
+ cur_pose = poses[i]
+ if use_focal:
+ rays_o, rays_d = get_rays(
+ # [h, w, 3]
+ cur_pose,
+ size_h,
+ size_w,
+ focal_xy=focal_xy_list[i],
+ )
+ else:
+ rays_o, rays_d = get_rays(
+ cur_pose, size_h, size_w, fov_xy=fov_xy_list[i]
+ ) # [h, w, 3]
+
+ rays_plucker = torch.cat(
+ [torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1
+ ) # [h, w, 6]
+ rays_embeddings.append(rays_plucker)
+
+ rays_embeddings = (
+ torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()
+ ) # [V, 6, h, w]
+ return rays_embeddings
+
+
+def get_rays(pose, h, w, fov_xy=None, focal_xy=None, opengl=True):
+ x, y = torch.meshgrid(
+ torch.arange(w, device=pose.device),
+ torch.arange(h, device=pose.device),
+ indexing="xy",
+ )
+ x = x.flatten()
+ y = y.flatten()
+
+ cx = w * 0.5
+ cy = h * 0.5
+
+ # print("fov_xy=", fov_xy)
+ # print("focal_xy=", focal_xy)
+
+ if focal_xy is None:
+ assert fov_xy is not None, "fov_x/y and focal_x/y cannot both be None."
+ focal_x = w * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[0]))
+ focal_y = h * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[1]))
+ else:
+ assert (
+ len(focal_xy) == 2
+ ), "focal_xy should be a list-like object containing only two elements (focal length in x and y direction)."
+ focal_x = w * focal_xy[0]
+ focal_y = h * focal_xy[1]
+
+ camera_dirs = torch.nn.functional.pad(
+ torch.stack(
+ [
+ (x - cx + 0.5) / focal_x,
+ (y - cy + 0.5) / focal_y * (-1.0 if opengl else 1.0),
+ ],
+ dim=-1,
+ ),
+ (0, 1),
+ value=(-1.0 if opengl else 1.0),
+ ) # [hw, 3]
+
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
+
+ rays_o = rays_o.view(h, w, 3)
+ rays_d = safe_normalize(rays_d).view(h, w, 3)
+
+ return rays_o, rays_d
+
+
+def safe_normalize(x, eps=1e-20):
+ return x / length(x, eps)
+
+
+def length(x, eps=1e-20):
+ if isinstance(x, np.ndarray):
+ return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
+ else:
+ return torch.sqrt(torch.clamp(dot(x, x), min=eps))
+
+
+def dot(x, y):
+ if isinstance(x, np.ndarray):
+ return np.sum(x * y, -1, keepdims=True)
+ else:
+ return torch.sum(x * y, -1, keepdim=True)
+
+
+def extend_list_by_repeating(original_list, target_length, repeat_idx, at_front):
+ if not original_list:
+ raise ValueError("The original list cannot be empty.")
+
+ extended_list = []
+ original_length = len(original_list)
+ for i in range(target_length - original_length):
+ extended_list.append(original_list[repeat_idx])
+
+ if at_front:
+ extended_list.extend(original_list)
+ return extended_list
+ else:
+ original_list.extend(extended_list)
+ return original_list
+
+
+def select_evenly_spaced_elements(arr, x):
+ if x <= 0 or len(arr) == 0:
+ return []
+
+ # Calculate step size as the ratio of length of the list and x
+ step = len(arr) / x
+
+ # Pick elements at indices that are multiples of step (round them to nearest integer)
+ selected_elements = [arr[round(i * step)] for i in range(x)]
+
+ return selected_elements
+
+
+def convert_co3d_annotation_to_opengl_pose_and_intrinsics(frame_annotation):
+ p = frame_annotation.viewpoint.principal_point
+ f = frame_annotation.viewpoint.focal_length
+ h, w = frame_annotation.image.size
+ K = np.eye(3)
+ s = (min(h, w) - 1) / 2
+ if frame_annotation.viewpoint.intrinsics_format == "ndc_norm_image_bounds":
+ K[0, 0] = f[0] * (w - 1) / 2
+ K[1, 1] = f[1] * (h - 1) / 2
+ elif frame_annotation.viewpoint.intrinsics_format == "ndc_isotropic":
+ K[0, 0] = f[0] * s / 2
+ K[1, 1] = f[1] * s / 2
+ else:
+ assert (
+ False
+ ), f"Invalid intrinsics_format: {frame_annotation.viewpoint.intrinsics_format}"
+ K[0, 2] = -p[0] * s + (w - 1) / 2
+ K[1, 2] = -p[1] * s + (h - 1) / 2
+
+ R = np.array(frame_annotation.viewpoint.R).T # note the transpose here
+ T = np.array(frame_annotation.viewpoint.T)
+ pose = np.concatenate([R, T[:, None]], 1)
+ # Need to be converted into OpenGL format. Flip the direction of x, z axis
+ pose = np.diag([-1, 1, -1]).astype(np.float32) @ pose
+ return pose, K
+
+
+def normalize_w2c_camera_pose_sequence(
+ target_camera_poses,
+ condition_camera_poses=None,
+ output_c2w=False,
+ translation_norm_mode="div_by_max",
+):
+ """
+ Normalize camera pose sequence so that the first frame is identity rotation and zero translation,
+ and the translation scale is normalized by the farest point from the first frame (to one).
+ :param target_camera_poses: W2C poses tensor in [N, 3, 4]
+ :param condition_camera_poses: W2C poses tensor in [N, 3, 4]
+ :return: Tuple(Tensor, Tensor), the normalized `target_camera_poses` and `condition_camera_poses`
+ """
+ # Normalize at w2c, all poses should be in w2c in UnifiedFrame
+ num_target_views = target_camera_poses.size(0)
+ if condition_camera_poses is not None:
+ all_poses = torch.concat([target_camera_poses, condition_camera_poses], dim=0)
+ else:
+ all_poses = target_camera_poses
+ # Convert W2C to C2W
+ normalized_poses = transform_pose_sequence_to_relative_c2w(
+ convert_w2c_between_c2w(all_poses)
+ )
+ # Here normalized_poses is C2W
+ if not output_c2w:
+ # Convert from C2W back to W2C if output_c2w is False.
+ normalized_poses = convert_w2c_between_c2w(normalized_poses)
+
+ t_norms = torch.linalg.norm(normalized_poses[:, :, 3], ord=2, dim=-1)
+ # print("t_norms=", t_norms)
+ largest_t_norm = torch.max(t_norms)
+
+ # print("largest_t_norm=", largest_t_norm)
+ # normalized_poses[:, :, 3] -= first_t.unsqueeze(0).repeat(normalized_poses.size(0), 1)
+ if translation_norm_mode == "div_by_max_plus_one":
+ # Always add a constant component to the translation norm
+ largest_t_norm = largest_t_norm + 1.0
+ elif translation_norm_mode == "div_by_max":
+ largest_t_norm = largest_t_norm
+ if largest_t_norm <= 0.05:
+ largest_t_norm = 0.05
+ elif translation_norm_mode == "disabled":
+ largest_t_norm = 1.0
+ else:
+ assert False, f"Invalid translation_norm_mode: {translation_norm_mode}."
+ normalized_poses[:, :, 3] /= largest_t_norm
+
+ target_camera_poses = normalized_poses[:num_target_views]
+ if condition_camera_poses is not None:
+ condition_camera_poses = normalized_poses[num_target_views:]
+ else:
+ condition_camera_poses = None
+ # print("After First condition:", condition_camera_poses[0])
+ # print("After First target:", target_camera_poses[0])
+ return target_camera_poses, condition_camera_poses
+
+
+def central_crop_pil_image(_image, crop_size, use_central_padding=False):
+ if use_central_padding:
+ # Determine the new size
+ _w, _h = _image.size
+ new_size = max(_w, _h)
+ # Create a new image with white background
+ new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
+ # Calculate the position to paste the original image
+ paste_position = ((new_size - _w) // 2, (new_size - _h) // 2)
+ # Paste the original image onto the new image
+ new_image.paste(_image, paste_position)
+ _image = new_image
+ # get the new size again if padded
+ _w, _h = _image.size
+ scale = crop_size / min(_h, _w)
+ # resize shortest side to crop_size
+ _w_out, _h_out = int(scale * _w), int(scale * _h)
+ _image = _image.resize(
+ (_w_out, _h_out),
+ resample=(
+ PIL.Image.Resampling.LANCZOS if scale < 1 else PIL.Image.Resampling.BICUBIC
+ ),
+ )
+ # center crop
+ margin_w = (_image.size[0] - crop_size) // 2
+ margin_h = (_image.size[1] - crop_size) // 2
+ _image = _image.crop(
+ (margin_w, margin_h, margin_w + crop_size, margin_h + crop_size)
+ )
+ return _image
+
+
+def crop_and_resize(
+ image: Image.Image, target_width: int, target_height: int
+) -> Image.Image:
+ """
+ Crops and resizes an image while preserving the aspect ratio.
+
+ Args:
+ image (Image.Image): Input PIL image to be cropped and resized.
+ target_width (int): Target width of the output image.
+ target_height (int): Target height of the output image.
+
+ Returns:
+ Image.Image: Cropped and resized image.
+ """
+ # Original dimensions
+ original_width, original_height = image.size
+ original_aspect = original_width / original_height
+ target_aspect = target_width / target_height
+
+ # Calculate crop box to maintain aspect ratio
+ if original_aspect > target_aspect:
+ # Crop horizontally
+ new_width = int(original_height * target_aspect)
+ new_height = original_height
+ left = (original_width - new_width) / 2
+ top = 0
+ right = left + new_width
+ bottom = original_height
+ else:
+ # Crop vertically
+ new_width = original_width
+ new_height = int(original_width / target_aspect)
+ left = 0
+ top = (original_height - new_height) / 2
+ right = original_width
+ bottom = top + new_height
+
+ # Crop and resize
+ cropped_image = image.crop((left, top, right, bottom))
+ resized_image = cropped_image.resize((target_width, target_height), Image.LANCZOS)
+
+ return resized_image
+
+
+def calculate_fov_after_resize(
+ fov_x: float,
+ fov_y: float,
+ original_width: int,
+ original_height: int,
+ target_width: int,
+ target_height: int,
+) -> (float, float):
+ """
+ Calculates the new field of view after cropping and resizing an image.
+
+ Args:
+ fov_x (float): Original field of view in the x-direction (horizontal).
+ fov_y (float): Original field of view in the y-direction (vertical).
+ original_width (int): Original width of the image.
+ original_height (int): Original height of the image.
+ target_width (int): Target width of the output image.
+ target_height (int): Target height of the output image.
+
+ Returns:
+ (float, float): New field of view (fov_x, fov_y) after cropping and resizing.
+ """
+ original_aspect = original_width / original_height
+ target_aspect = target_width / target_height
+
+ if original_aspect > target_aspect:
+ # Crop horizontally
+ new_width = int(original_height * target_aspect)
+ new_fov_x = fov_x * (new_width / original_width)
+ new_fov_y = fov_y
+ else:
+ # Crop vertically
+ new_height = int(original_width / target_aspect)
+ new_fov_y = fov_y * (new_height / original_height)
+ new_fov_x = fov_x
+
+ return new_fov_x, new_fov_y
diff --git a/core/data/utils.py b/core/data/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..2747bf822452a7b5c544cec651ae8365ae28963d
--- /dev/null
+++ b/core/data/utils.py
@@ -0,0 +1,184 @@
+import copy
+import random
+from PIL import Image
+
+import numpy as np
+
+
+def create_relative(RT_list, K_1=4.7, dataset="syn"):
+ if dataset == "realestate":
+ scale_T = 1
+ RT_list = [RT.reshape(3, 4) for RT in RT_list]
+ elif dataset == "syn":
+ scale_T = (470 / K_1) / 7.5
+ """
+ 4.694746736956946052e+02 0.000000000000000000e+00 4.800000000000000000e+02
+ 0.000000000000000000e+00 4.694746736956946052e+02 2.700000000000000000e+02
+ 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00
+ """
+ elif dataset == "zero123":
+ scale_T = 0.5
+ else:
+ raise Exception("invalid dataset type")
+
+ # convert x y z to x -y -z
+ if dataset == "zero123":
+ flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
+ for i in range(len(RT_list)):
+ RT_list[i] = np.dot(flip_matrix, RT_list[i])
+
+ temp = []
+ first_frame_RT = copy.deepcopy(RT_list[0])
+ # first_frame_R_inv = np.linalg.inv(first_frame_RT[:,:3])
+ first_frame_R_inv = first_frame_RT[:, :3].T
+ first_frame_T = first_frame_RT[:, -1]
+ for RT in RT_list:
+ RT[:, :3] = np.dot(RT[:, :3], first_frame_R_inv)
+ RT[:, -1] = RT[:, -1] - np.dot(RT[:, :3], first_frame_T)
+ RT[:, -1] = RT[:, -1] * scale_T
+ temp.append(RT)
+ RT_list = temp
+
+ if dataset == "realestate":
+ RT_list = [RT.reshape(-1) for RT in RT_list]
+
+ return RT_list
+
+
+def sigma_matrix2(sig_x, sig_y, theta):
+ """Calculate the rotated sigma matrix (two dimensional matrix).
+ Args:
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ Returns:
+ ndarray: Rotated sigma matrix.
+ """
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
+ u_matrix = np.array(
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
+ )
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+
+def mesh_grid(kernel_size):
+ """Generate the mesh grid, centering at zero.
+ Args:
+ kernel_size (int):
+ Returns:
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+ xx (ndarray): with the shape (kernel_size, kernel_size)
+ yy (ndarray): with the shape (kernel_size, kernel_size)
+ """
+ ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
+ xx, yy = np.meshgrid(ax, ax)
+ xy = np.hstack(
+ (
+ xx.reshape((kernel_size * kernel_size, 1)),
+ yy.reshape(kernel_size * kernel_size, 1),
+ )
+ ).reshape(kernel_size, kernel_size, 2)
+ return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+ """Calculate PDF of the bivariate Gaussian distribution.
+ Args:
+ sigma_matrix (ndarray): with the shape (2, 2)
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+ Returns:
+ kernel (ndarrray): un-normalized kernel.
+ """
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+ return kernel
+
+
+def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ isotropic (bool):
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def rgba_to_rgb_with_bg(rgba_image, bg_color=(255, 255, 255)):
+ """
+ Convert a PIL RGBA Image to an RGB Image with a white background.
+
+ Args:
+ rgba_image (Image): A PIL Image object in RGBA mode.
+
+ Returns:
+ Image: A PIL Image object in RGB mode with white background.
+ """
+ # Ensure the image is in RGBA mode
+ # Ensure the image is in RGBA mode
+ if rgba_image.mode != "RGBA":
+ return rgba_image
+ # raise ValueError("The image must be in RGBA mode")
+
+ # Create a white background image
+ white_bg_rgb = Image.new("RGB", rgba_image.size, bg_color)
+ # Paste the RGBA image onto the white background using alpha channel as mask
+ white_bg_rgb.paste(
+ rgba_image, mask=rgba_image.split()[3]
+ ) # 3 is the alpha channel index
+ return white_bg_rgb
+
+
+def random_order_preserving_selection(items, num):
+ if num > len(items):
+ print("WARNING: Item list is shorter than `num` given.")
+ return items
+ selected_indices = sorted(random.sample(range(len(items)), num))
+ selected_items = [items[i] for i in selected_indices]
+ return selected_items
+
+
+def pad_pil_image_to_square(image, fill_color=(255, 255, 255)):
+ """
+ Pad an image to make it square with the given fill color.
+
+ Args:
+ image (PIL.Image): The original image.
+ fill_color (tuple): The color to use for padding (default is black).
+
+ Returns:
+ PIL.Image: A new image that is padded to be square.
+ """
+ width, height = image.size
+
+ # Determine the new size, which will be the maximum of width or height
+ new_size = max(width, height)
+
+ # Create a new image with the new size and fill color
+ new_image = Image.new("RGB", (new_size, new_size), fill_color)
+
+ # Calculate the position to paste the original image onto the new image
+ # This calculation centers the original image in the new square canvas
+ left = (new_size - width) // 2
+ top = (new_size - height) // 2
+
+ # Paste the original image into the new image
+ new_image.paste(image, (left, top))
+
+ return new_image
diff --git a/core/distributions.py b/core/distributions.py
new file mode 100755
index 0000000000000000000000000000000000000000..593dafdbd2bb7422bb8fcd397d8cf05ee86d591b
--- /dev/null
+++ b/core/distributions.py
@@ -0,0 +1,102 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(
+ device=self.parameters.device
+ )
+
+ def sample(self, noise=None):
+ if noise is None:
+ noise = torch.randn(self.mean.shape)
+
+ x = self.mean + self.std * noise.to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/core/ema.py b/core/ema.py
new file mode 100755
index 0000000000000000000000000000000000000000..0e1447b06b710151e769fc820049db54fe132510
--- /dev/null
+++ b/core/ema.py
@@ -0,0 +1,84 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ (
+ torch.tensor(0, dtype=torch.int)
+ if use_num_upates
+ else torch.tensor(-1, dtype=torch.int)
+ ),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay * (shadow_params[sname] - m_param[key])
+ )
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/core/losses/__init__.py b/core/losses/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..a1a0bd9f7bcd37e54cba7cbe0cf413bf799ff515
--- /dev/null
+++ b/core/losses/__init__.py
@@ -0,0 +1 @@
+from core.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
diff --git a/core/losses/contperceptual.py b/core/losses/contperceptual.py
new file mode 100755
index 0000000000000000000000000000000000000000..5396eaacd3cf213ad5eb4696c1adbe834cdabff1
--- /dev/null
+++ b/core/losses/contperceptual.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+from taming.modules.losses.vqperceptual import *
+
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start,
+ logvar_init=0.0,
+ kl_weight=1.0,
+ pixelloss_weight=1.0,
+ disc_num_layers=3,
+ disc_in_channels=3,
+ disc_factor=1.0,
+ disc_weight=1.0,
+ perceptual_weight=1.0,
+ use_actnorm=False,
+ disc_conditional=False,
+ disc_loss="hinge",
+ max_bs=None,
+ ):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.max_bs = max_bs
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(
+ nll_loss, self.last_layer[0], retain_graph=True
+ )[0]
+ g_grads = torch.autograd.grad(
+ g_loss, self.last_layer[0], retain_graph=True
+ )[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ inputs,
+ reconstructions,
+ posteriors,
+ optimizer_idx,
+ global_step,
+ last_layer=None,
+ cond=None,
+ split="train",
+ weights=None,
+ ):
+ if inputs.dim() == 5:
+ inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
+ if reconstructions.dim() == 5:
+ reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ if self.max_bs is not None and self.max_bs < inputs.shape[0]:
+ input_list = torch.split(inputs, self.max_bs, dim=0)
+ reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0)
+ p_losses = [
+ self.perceptual_loss(
+ inputs.contiguous(), reconstructions.contiguous()
+ )
+ for inputs, reconstructions in zip(input_list, reconstruction_list)
+ ]
+ p_loss = torch.cat(p_losses, dim=0)
+ else:
+ p_loss = self.perceptual_loss(
+ inputs.contiguous(), reconstructions.contiguous()
+ )
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(
+ torch.cat((reconstructions.contiguous(), cond), dim=1)
+ )
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(
+ nll_loss, g_loss, last_layer=last_layer
+ )
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ loss = (
+ weighted_nll_loss
+ + self.kl_weight * kl_loss
+ + d_weight * disc_factor * g_loss
+ )
+
+ log = {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(
+ torch.cat((inputs.contiguous().detach(), cond), dim=1)
+ )
+ logits_fake = self.discriminator(
+ torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
+ )
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
+ }
+ return d_loss, log
diff --git a/core/losses/vqperceptual.py b/core/losses/vqperceptual.py
new file mode 100755
index 0000000000000000000000000000000000000000..304482ac3bcd972bdce497cb49638d59d119eacc
--- /dev/null
+++ b/core/losses/vqperceptual.py
@@ -0,0 +1,217 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+ loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
+ loss_real = (weights * loss_real).sum() / weights.sum()
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.0):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def measure_perplexity(predicted_indices, n_embed):
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+
+def l1(x, y):
+ return torch.abs(x - y)
+
+
+def l2(x, y):
+ return torch.pow((x - y), 2)
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start,
+ codebook_weight=1.0,
+ pixelloss_weight=1.0,
+ disc_num_layers=3,
+ disc_in_channels=3,
+ disc_factor=1.0,
+ disc_weight=1.0,
+ perceptual_weight=1.0,
+ use_actnorm=False,
+ disc_conditional=False,
+ disc_ndf=64,
+ disc_loss="hinge",
+ n_classes=None,
+ perceptual_loss="lpips",
+ pixel_loss="l1",
+ ):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ assert perceptual_loss in ["lpips", "clips", "dists"]
+ assert pixel_loss in ["l1", "l2"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ if perceptual_loss == "lpips":
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
+ self.perceptual_loss = LPIPS().eval()
+ else:
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+ self.perceptual_weight = perceptual_weight
+
+ if pixel_loss == "l1":
+ self.pixel_loss = l1
+ else:
+ self.pixel_loss = l2
+
+ self.discriminator = NLayerDiscriminator(
+ input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf,
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.n_classes = n_classes
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(
+ nll_loss, self.last_layer[0], retain_graph=True
+ )[0]
+ g_grads = torch.autograd.grad(
+ g_loss, self.last_layer[0], retain_graph=True
+ )[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ codebook_loss,
+ inputs,
+ reconstructions,
+ optimizer_idx,
+ global_step,
+ last_layer=None,
+ cond=None,
+ split="train",
+ predicted_indices=None,
+ ):
+ if not exists(codebook_loss):
+ codebook_loss = torch.tensor([0.0]).to(inputs.device)
+ # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(
+ inputs.contiguous(), reconstructions.contiguous()
+ )
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(
+ torch.cat((reconstructions.contiguous(), cond), dim=1)
+ )
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(
+ nll_loss, g_loss, last_layer=last_layer
+ )
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ loss = (
+ nll_loss
+ + d_weight * disc_factor * g_loss
+ + self.codebook_weight * codebook_loss.mean()
+ )
+
+ log = {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if predicted_indices is not None:
+ assert self.n_classes is not None
+ with torch.no_grad():
+ perplexity, cluster_usage = measure_perplexity(
+ predicted_indices, self.n_classes
+ )
+ log[f"{split}/perplexity"] = perplexity
+ log[f"{split}/cluster_usage"] = cluster_usage
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(
+ torch.cat((inputs.contiguous().detach(), cond), dim=1)
+ )
+ logits_fake = self.discriminator(
+ torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
+ )
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
+ }
+ return d_loss, log
diff --git a/core/models/autoencoder.py b/core/models/autoencoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..7b17d390a648357afb589c5eea7953b019aceb54
--- /dev/null
+++ b/core/models/autoencoder.py
@@ -0,0 +1,395 @@
+import os
+import json
+from contextlib import contextmanager
+
+import torch
+import numpy as np
+from einops import rearrange
+
+import torch.nn.functional as F
+import torch.distributed as dist
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+
+from taming.modules.vqvae.quantize import VectorQuantizer as VectorQuantizer
+
+from core.modules.networks.ae_modules import Encoder, Decoder
+from core.distributions import DiagonalGaussianDistribution
+from utils.utils import instantiate_from_config
+from utils.save_video import tensor2videogrids
+from core.common import shape_to_str, gather_data
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(
+ self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ test=False,
+ logdir=None,
+ input_dim=4,
+ test_args=None,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ self.input_dim = input_dim
+ self.test = test
+ self.test_args = test_args
+ self.logdir = logdir
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels) == int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ if self.test:
+ self.init_test()
+
+ def init_test(
+ self,
+ ):
+ self.test = True
+ save_dir = os.path.join(self.logdir, "test")
+ if "ckpt" in self.test_args:
+ ckpt_name = (
+ os.path.basename(self.test_args.ckpt).split(".ckpt")[0]
+ + f"_epoch{self._cur_epoch}"
+ )
+ self.root = os.path.join(save_dir, ckpt_name)
+ else:
+ self.root = save_dir
+ if "test_subdir" in self.test_args:
+ self.root = os.path.join(save_dir, self.test_args.test_subdir)
+
+ self.root_zs = os.path.join(self.root, "zs")
+ self.root_dec = os.path.join(self.root, "reconstructions")
+ self.root_inputs = os.path.join(self.root, "inputs")
+ os.makedirs(self.root, exist_ok=True)
+
+ if self.test_args.save_z:
+ os.makedirs(self.root_zs, exist_ok=True)
+ if self.test_args.save_reconstruction:
+ os.makedirs(self.root_dec, exist_ok=True)
+ if self.test_args.save_input:
+ os.makedirs(self.root_inputs, exist_ok=True)
+ assert self.test_args is not None
+ self.test_maximum = getattr(
+ self.test_args, "test_maximum", None
+ ) # 1500 # 12000/8
+ self.count = 0
+ self.eval_metrics = {}
+ self.decodes = []
+ self.save_decode_samples = 2048
+ if getattr(self.test_args, "cal_metrics", False):
+ self.EvalLpips = EvalLpips()
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")
+ try:
+ self._cur_epoch = sd["epoch"]
+ sd = sd["state_dict"]
+ except:
+ self._cur_epoch = "null"
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ # self.load_state_dict(sd, strict=True)
+ print(f"Restored from {path}")
+
+ def encode(self, x, **kwargs):
+
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, **kwargs):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ # if len(x.shape) == 3:
+ # x = x[..., None]
+ # if x.dim() == 4:
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ if x.dim() == 5 and self.input_dim == 4:
+ b, c, t, h, w = x.shape
+ self.b = b
+ self.t = t
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+ self.log(
+ "aeloss",
+ aeloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
+ )
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+
+ self.log(
+ "discloss",
+ discloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
+ )
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val",
+ )
+
+ discloss, log_dict_disc = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val",
+ )
+
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def test_step(self, batch, batch_idx):
+ # save z, dec
+ inputs = self.get_input(batch, self.image_key)
+ # forward
+ sample_posterior = True
+ posterior = self.encode(inputs)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+
+ # logs
+ if self.test_args.save_z:
+ torch.save(
+ z,
+ os.path.join(
+ self.root_zs,
+ f"zs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.pt",
+ ),
+ )
+ if self.test_args.save_reconstruction:
+ tensor2videogrids(
+ dec,
+ self.root_dec,
+ f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
+ fps=10,
+ )
+ if self.test_args.save_input:
+ tensor2videogrids(
+ inputs,
+ self.root_inputs,
+ f"inputs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
+ fps=10,
+ )
+
+ if "save_z" in self.test_args and self.test_args.save_z:
+ dec_np = (dec.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) + 1) / 2 * 255
+ dec_np = dec_np.astype(np.uint8)
+ self.root_dec_np = os.path.join(self.root, "reconstructions_np")
+ os.makedirs(self.root_dec_np, exist_ok=True)
+ np.savez(
+ os.path.join(
+ self.root_dec_np,
+ f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(dec_np)}.npz",
+ ),
+ dec_np,
+ )
+
+ self.count += z.shape[0]
+
+ # misc
+ self.log("batch_idx", batch_idx, prog_bar=True)
+ self.log_dict(self.eval_metrics, prog_bar=True, logger=True)
+ torch.cuda.empty_cache()
+ if self.test_maximum is not None:
+ if self.count > self.test_maximum:
+ import sys
+
+ sys.exit()
+ else:
+ prog = self.count / self.test_maximum * 100
+ print(f"Test progress: {prog:.2f}% [{self.count}/{self.test_maximum}]")
+
+ @rank_zero_only
+ def on_test_end(self):
+ if self.test_args.cal_metrics:
+ psnrs, ssims, ms_ssims, lpipses = [], [], [], []
+ n_batches = 0
+ n_samples = 0
+ overall = {}
+ for k, v in self.eval_metrics.items():
+ psnrs.append(v["psnr"])
+ ssims.append(v["ssim"])
+ lpipses.append(v["lpips"])
+ n_batches += 1
+ n_samples += v["n_samples"]
+
+ mean_psnr = sum(psnrs) / len(psnrs)
+ mean_ssim = sum(ssims) / len(ssims)
+ # overall['ms_ssim'] = min(ms_ssims)
+ mean_lpips = sum(lpipses) / len(lpipses)
+
+ overall = {
+ "psnr": mean_psnr,
+ "ssim": mean_ssim,
+ "lpips": mean_lpips,
+ "n_batches": n_batches,
+ "n_samples": n_samples,
+ }
+ overall_t = torch.tensor([mean_psnr, mean_ssim, mean_lpips])
+ # dump
+ for k, v in overall.items():
+ if isinstance(v, torch.Tensor):
+ overall[k] = float(v)
+ with open(
+ os.path.join(self.root, f"reconstruction_metrics.json"), "w"
+ ) as f:
+ json.dump(overall, f)
+ f.close()
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(
+ list(self.encoder.parameters())
+ + list(self.decoder.parameters())
+ + list(self.quant_conv.parameters())
+ + list(self.post_quant_conv.parameters()),
+ lr=lr,
+ betas=(0.5, 0.9),
+ )
+ opt_disc = torch.optim.Adam(
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
+ )
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
diff --git a/core/models/diffusion.py b/core/models/diffusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..2b80441b2b8dec83a4bcbc8149c792f6dc88dcf3
--- /dev/null
+++ b/core/models/diffusion.py
@@ -0,0 +1,1679 @@
+import logging
+from collections import OrderedDict
+from contextlib import contextmanager
+from functools import partial
+
+import numpy as np
+from einops import rearrange
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+from torchvision.utils import make_grid
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+
+from core.modules.networks.unet_modules import TASK_IDX_IMAGE, TASK_IDX_RAY
+from utils.utils import instantiate_from_config
+from core.ema import LitEma
+from core.distributions import DiagonalGaussianDistribution
+from core.models.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
+from core.models.samplers.ddim import DDIMSampler
+from core.basics import disabled_train
+from core.common import extract_into_tensor, noise_like, exists, default
+
+main_logger = logging.getLogger("main_logger")
+
+
+class BD(nn.Module):
+ def __init__(self, G=10):
+ super(BD, self).__init__()
+
+ self.momentum = 0.9
+ self.register_buffer("running_wm", torch.eye(G).expand(G, G))
+ self.running_wm = None
+
+ def forward(self, x, T=5, eps=1e-5):
+ N, C, G, H, W = x.size()
+ x = torch.permute(x, [0, 2, 1, 3, 4])
+ x_in = x.transpose(0, 1).contiguous().view(G, -1)
+ if self.training:
+ mean = x_in.mean(-1, keepdim=True)
+ xc = x_in - mean
+ d, m = x_in.size()
+ P = [None] * (T + 1)
+ P[0] = torch.eye(G, device=x.device)
+ Sigma = (torch.matmul(xc, xc.transpose(0, 1))) / float(m) + P[0] * eps
+ rTr = (Sigma * P[0]).sum([0, 1], keepdim=True).reciprocal()
+ Sigma_N = Sigma * rTr
+ wm = torch.linalg.solve_triangular(
+ torch.linalg.cholesky(Sigma_N), P[0], upper=False
+ )
+ self.running_wm = self.momentum * self.running_wm + (1 - self.momentum) * wm
+ else:
+ wm = self.running_wm
+
+ x_out = wm @ x_in
+ x_out = x_out.view(G, N, C, H, W).permute([1, 2, 0, 3, 4]).contiguous()
+
+ return x_out
+
+
+class AbstractDDPM(pl.LightningModule):
+
+ def __init__(
+ self,
+ unet_config,
+ time_steps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ monitor=None,
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.0,
+ # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ v_posterior=0.0,
+ l_simple_weight=1.0,
+ conditioning_key=None,
+ parameterization="eps",
+ rescale_betas_zero_snr=False,
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.0,
+ bd_noise=False,
+ ):
+ super().__init__()
+ assert parameterization in [
+ "eps",
+ "x0",
+ "v",
+ ], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ main_logger.info(
+ f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
+ )
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.channels = channels
+ self.cond_channels = unet_config.params.in_channels - channels
+ self.temporal_length = unet_config.params.temporal_length
+ self.image_size = image_size
+ self.bd_noise = bd_noise
+
+ if self.bd_noise:
+ self.bd = BD(G=self.temporal_length)
+
+ if isinstance(self.image_size, int):
+ self.image_size = [self.image_size, self.image_size]
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ main_logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.rescale_betas_zero_snr = rescale_betas_zero_snr
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ self.linear_end = None
+ self.linear_start = None
+ self.num_time_steps: int = 1000
+
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.register_schedule(
+ given_betas=given_betas,
+ beta_schedule=beta_schedule,
+ time_steps=time_steps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+
+ self.given_betas = given_betas
+ self.beta_schedule = beta_schedule
+ self.time_steps = time_steps
+ self.cosine_s = cosine_s
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_time_steps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+ * x_t
+ )
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ main_logger.info(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ main_logger.info(f"{context}: Restored training weights")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == "l1":
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == "l2":
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+
+class DualStreamMultiViewDiffusionModel(AbstractDDPM):
+
+ def __init__(
+ self,
+ first_stage_config,
+ data_key_images,
+ data_key_rays,
+ data_key_text_condition=None,
+ ckpt_path=None,
+ cond_stage_config=None,
+ num_time_steps_cond=None,
+ cond_stage_trainable=False,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ uncond_prob=0.2,
+ uncond_type="empty_seq",
+ scale_factor=1.0,
+ scale_by_std=False,
+ use_noise_offset=False,
+ use_dynamic_rescale=False,
+ base_scale=0.3,
+ turning_step=400,
+ per_frame_auto_encoding=False,
+ # added for LVDM
+ encoder_type="2d",
+ cond_frames=None,
+ logdir=None,
+ empty_params_only=False,
+ # Image Condition
+ cond_img_config=None,
+ image_proj_model_config=None,
+ random_cond=False,
+ padding=False,
+ cond_concat=False,
+ frame_mask=False,
+ use_camera_pose_query_transformer=False,
+ with_cond_binary_mask=False,
+ apply_condition_mask_in_training_loss=True,
+ separate_noise_and_condition=False,
+ condition_padding_with_anchor=False,
+ ray_as_image=False,
+ use_task_embedding=False,
+ use_ray_decoder_loss_high_frequency_isolation=False,
+ disable_ray_stream=False,
+ ray_loss_weight=1.0,
+ train_with_multi_view_feature_alignment=False,
+ use_text_cross_attention_condition=True,
+ *args,
+ **kwargs,
+ ):
+
+ self.image_proj_model = None
+ self.apply_condition_mask_in_training_loss = (
+ apply_condition_mask_in_training_loss
+ )
+ self.separate_noise_and_condition = separate_noise_and_condition
+ self.condition_padding_with_anchor = condition_padding_with_anchor
+ self.use_text_cross_attention_condition = use_text_cross_attention_condition
+
+ self.data_key_images = data_key_images
+ self.data_key_rays = data_key_rays
+ self.data_key_text_condition = data_key_text_condition
+
+ self.num_time_steps_cond = default(num_time_steps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_time_steps_cond <= kwargs["time_steps"]
+ self.shorten_cond_schedule = self.num_time_steps_cond > 1
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+
+ self.cond_stage_trainable = cond_stage_trainable
+ self.empty_params_only = empty_params_only
+ self.per_frame_auto_encoding = per_frame_auto_encoding
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
+ self.use_noise_offset = use_noise_offset
+ self.use_dynamic_rescale = use_dynamic_rescale
+ if use_dynamic_rescale:
+ scale_arr1 = np.linspace(1.0, base_scale, turning_step)
+ scale_arr2 = np.full(self.num_time_steps, base_scale)
+ scale_arr = np.concatenate((scale_arr1, scale_arr2))
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+ self.register_buffer("scale_arr", to_torch(scale_arr))
+ self.instantiate_first_stage(first_stage_config)
+
+ if self.use_text_cross_attention_condition and cond_stage_config is not None:
+ self.instantiate_cond_stage(cond_stage_config)
+
+ self.first_stage_config = first_stage_config
+ self.cond_stage_config = cond_stage_config
+ self.clip_denoised = False
+
+ self.cond_stage_forward = cond_stage_forward
+ self.encoder_type = encoder_type
+ assert encoder_type in ["2d", "3d"]
+ self.uncond_prob = uncond_prob
+ self.classifier_free_guidance = True if uncond_prob > 0 else False
+ assert uncond_type in ["zero_embed", "empty_seq"]
+ self.uncond_type = uncond_type
+
+ if cond_frames is not None:
+ frame_len = self.temporal_length
+ assert cond_frames[-1] < frame_len, main_logger.info(
+ f"Error: conditioning frame index must not be greater than {frame_len}!"
+ )
+ cond_mask = torch.zeros(frame_len, dtype=torch.float32)
+ cond_mask[cond_frames] = 1.0
+ self.cond_mask = cond_mask[None, None, :, None, None]
+ else:
+ self.cond_mask = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+ self.restarted_from_ckpt = True
+
+ self.logdir = logdir
+ self.with_cond_binary_mask = with_cond_binary_mask
+ self.random_cond = random_cond
+ self.padding = padding
+ self.cond_concat = cond_concat
+ self.frame_mask = frame_mask
+ self.use_img_context = True if cond_img_config is not None else False
+ self.use_camera_pose_query_transformer = use_camera_pose_query_transformer
+ if self.use_img_context:
+ self.init_img_embedder(cond_img_config, freeze=True)
+ self.init_projector(image_proj_model_config, trainable=True)
+
+ self.ray_as_image = ray_as_image
+ self.use_task_embedding = use_task_embedding
+ self.use_ray_decoder_loss_high_frequency_isolation = (
+ use_ray_decoder_loss_high_frequency_isolation
+ )
+ self.disable_ray_stream = disable_ray_stream
+ if disable_ray_stream:
+ assert (
+ not ray_as_image
+ and not self.model.diffusion_model.use_ray_decoder
+ and not self.model.diffusion_model.use_ray_decoder_residual
+ ), "Options related to ray decoder should not be enabled when disabling ray stream."
+ assert (
+ not use_task_embedding
+ and not self.model.diffusion_model.use_task_embedding
+ ), "Task embedding should not be enabled when disabling ray stream."
+ assert (
+ not self.model.diffusion_model.use_addition_ray_output_head
+ ), "Additional ray output head should not be enabled when disabling ray stream."
+ assert (
+ not self.model.diffusion_model.use_lora_for_rays_in_output_blocks
+ ), "LoRA for rays should not be enabled when disabling ray stream."
+ self.ray_loss_weight = ray_loss_weight
+ self.train_with_multi_view_feature_alignment = False
+ if train_with_multi_view_feature_alignment:
+ print(f"MultiViewFeatureExtractor is ignored during inference.")
+
+ def init_from_ckpt(self, checkpoint_path):
+ main_logger.info(f"Initializing model from checkpoint {checkpoint_path}...")
+
+ def grab_ipa_weight(state_dict):
+ ipa_state_dict = OrderedDict()
+ for n in list(state_dict.keys()):
+ if "to_k_ip" in n or "to_v_ip" in n:
+ ipa_state_dict[n] = state_dict[n]
+ elif "image_proj_model" in n:
+ if (
+ self.use_camera_pose_query_transformer
+ and "image_proj_model.latents" in n
+ ):
+ ipa_state_dict[n] = torch.cat(
+ [state_dict[n] for i in range(16)], dim=1
+ )
+ else:
+ ipa_state_dict[n] = state_dict[n]
+ return ipa_state_dict
+
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
+ if "module" in state_dict.keys():
+ # deepspeed
+ target_state_dict = OrderedDict()
+ for key in state_dict["module"].keys():
+ target_state_dict[key[16:]] = state_dict["module"][key]
+ elif "state_dict" in list(state_dict.keys()):
+ target_state_dict = state_dict["state_dict"]
+ else:
+ raise KeyError("Weight key is not found in the state dict.")
+ ipa_state_dict = grab_ipa_weight(target_state_dict)
+ self.load_state_dict(ipa_state_dict, strict=False)
+ main_logger.info("Checkpoint loaded.")
+
+ def init_img_embedder(self, config, freeze=True):
+ embedder = instantiate_from_config(config)
+ if freeze:
+ self.embedder = embedder.eval()
+ self.embedder.train = disabled_train
+ for param in self.embedder.parameters():
+ param.requires_grad = False
+
+ def make_cond_schedule(
+ self,
+ ):
+ self.cond_ids = torch.full(
+ size=(self.num_time_steps,),
+ fill_value=self.num_time_steps - 1,
+ dtype=torch.long,
+ )
+ ids = torch.round(
+ torch.linspace(0, self.num_time_steps - 1, self.num_time_steps_cond)
+ ).long()
+ self.cond_ids[: self.num_time_steps_cond] = ids
+
+ def init_projector(self, config, trainable):
+ self.image_proj_model = instantiate_from_config(config)
+ if not trainable:
+ self.image_proj_model.eval()
+ self.image_proj_model.train = disabled_train
+ for param in self.image_proj_model.parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def pad_cond_images(batch_images):
+ h, w = batch_images.shape[-2:]
+ border = (w - h) // 2
+ # use padding at (W_t,W_b,H_t,H_b)
+ batch_images = torch.nn.functional.pad(
+ batch_images, (0, 0, border, border), "constant", 0
+ )
+ return batch_images
+
+ # Never delete this func: it is used in log_images() and inference stage
+ def get_image_embeds(self, batch_images, batch=None):
+ # input shape: b c h w
+ if self.padding:
+ batch_images = self.pad_cond_images(batch_images)
+ img_token = self.embedder(batch_images)
+ if self.use_camera_pose_query_transformer:
+ batch_size, num_views, _ = batch["target_poses"].shape
+ img_emb = self.image_proj_model(
+ img_token, batch["target_poses"].reshape(batch_size, num_views, 12)
+ )
+ else:
+ img_emb = self.image_proj_model(img_token)
+
+ return img_emb
+
+ @staticmethod
+ def get_input(batch, k):
+ x = batch[k]
+ """
+ # for image batch from image loader
+ if len(x.shape) == 4:
+ x = rearrange(x, 'b h w c -> b c h w')
+ """
+ x = x.to(memory_format=torch.contiguous_format) # .float()
+ return x
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
+ # only for very first batch, reset the self.scale_factor
+ if (
+ self.scale_by_std
+ and self.current_epoch == 0
+ and self.global_step == 0
+ and batch_idx == 0
+ and not self.restarted_from_ckpt
+ ):
+ assert (
+ self.scale_factor == 1.0
+ ), "rather not use custom rescaling and std-rescaling simultaneously"
+ # set rescale weight to 1./std of encodings
+ main_logger.info("## USING STD-RESCALING ###")
+ x = self.get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer("scale_factor", 1.0 / z.flatten().std())
+ main_logger.info(f"setting self.scale_factor to {self.scale_factor}")
+ main_logger.info("## USING STD-RESCALING ###")
+ main_logger.info(f"std={z.flatten().std()}")
+
+ def register_schedule(
+ self,
+ given_betas=None,
+ beta_schedule="linear",
+ time_steps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(
+ beta_schedule,
+ time_steps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+
+ if self.rescale_betas_zero_snr:
+ betas = rescale_zero_terminal_snr(betas)
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (time_steps,) = betas.shape
+ self.num_time_steps = int(time_steps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert (
+ alphas_cumprod.shape[0] == self.num_time_steps
+ ), "alphas have to be defined for each timestep"
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / (alphas_cumprod + 1e-5))),
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / (alphas_cumprod + 1e-5) - 1)),
+ )
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (
+ 1.0 - alphas_cumprod_prev
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer(
+ "posterior_log_variance_clipped",
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+ )
+ self.register_buffer(
+ "posterior_mean_coef1",
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+ )
+ self.register_buffer(
+ "posterior_mean_coef2",
+ to_torch(
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+ ),
+ )
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas**2 / (
+ 2
+ * self.posterior_variance
+ * to_torch(alphas)
+ * (1 - self.alphas_cumprod)
+ )
+ elif self.parameterization == "x0":
+ lvlb_weights = (
+ 0.5
+ * np.sqrt(torch.Tensor(alphas_cumprod))
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+ )
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(
+ self.betas**2
+ / (
+ 2
+ * self.posterior_variance
+ * to_torch(alphas)
+ * (1 - self.alphas_cumprod)
+ )
+ )
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, "encode") and callable(
+ self.cond_stage_model.encode
+ ):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def get_first_stage_encoding(self, encoder_posterior, noise=None):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample(noise=noise)
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+ )
+ return self.scale_factor * z
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ assert x.dim() == 5 or x.dim() == 4, (
+ "Images should be a either 5-dimensional (batched image sequence) "
+ "or 4-dimensional (batched images)."
+ )
+ if (
+ self.encoder_type == "2d"
+ and x.dim() == 5
+ and not self.per_frame_auto_encoding
+ ):
+ b, t, _, _, _ = x.shape
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+ reshape_back = True
+ else:
+ b, _, _, _, _ = x.shape
+ t = 1
+ reshape_back = False
+
+ if not self.per_frame_auto_encoding:
+ encoder_posterior = self.first_stage_model.encode(x)
+ results = self.get_first_stage_encoding(encoder_posterior).detach()
+ else:
+ results = []
+ for index in range(x.shape[1]):
+ frame_batch = self.first_stage_model.encode(x[:, index, :, :, :])
+ frame_result = self.get_first_stage_encoding(frame_batch).detach()
+ results.append(frame_result)
+ results = torch.stack(results, dim=1)
+
+ if reshape_back:
+ results = rearrange(results, "(b t) c h w -> b t c h w", b=b, t=t)
+
+ return results
+
+ def decode_core(self, z, **kwargs):
+ assert z.dim() == 5 or z.dim() == 4, (
+ "Latents should be a either 5-dimensional (batched latent sequence) "
+ "or 4-dimensional (batched latents)."
+ )
+
+ if (
+ self.encoder_type == "2d"
+ and z.dim() == 5
+ and not self.per_frame_auto_encoding
+ ):
+ b, t, _, _, _ = z.shape
+ z = rearrange(z, "b t c h w -> (b t) c h w")
+ reshape_back = True
+ else:
+ b, _, _, _, _ = z.shape
+ t = 1
+ reshape_back = False
+
+ if not self.per_frame_auto_encoding:
+ z = 1.0 / self.scale_factor * z
+ results = self.first_stage_model.decode(z, **kwargs)
+ else:
+ results = []
+ for index in range(z.shape[1]):
+ frame_z = 1.0 / self.scale_factor * z[:, index, :, :, :]
+ frame_result = self.first_stage_model.decode(frame_z, **kwargs)
+ results.append(frame_result)
+ results = torch.stack(results, dim=1)
+
+ if reshape_back:
+ results = rearrange(results, "(b t) c h w -> b t c h w", b=b, t=t)
+ return results
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, **kwargs):
+ return self.decode_core(z, **kwargs)
+
+ def differentiable_decode_first_stage(self, z, **kwargs):
+ return self.decode_core(z, **kwargs)
+
+ def get_batch_input(
+ self,
+ batch,
+ random_drop_training_conditions,
+ return_reconstructed_target_images=False,
+ ):
+ combined_images = batch[self.data_key_images]
+ clean_combined_image_latents = self.encode_first_stage(combined_images)
+ mask_preserving_target = batch["mask_preserving_target"].reshape(
+ batch["mask_preserving_target"].size(0),
+ batch["mask_preserving_target"].size(1),
+ 1,
+ 1,
+ 1,
+ )
+ mask_preserving_condition = 1.0 - mask_preserving_target
+ if self.ray_as_image:
+ clean_combined_ray_images = batch[self.data_key_rays]
+ clean_combined_ray_o_latents = self.encode_first_stage(
+ clean_combined_ray_images[:, :, :3, :, :]
+ )
+ clean_combined_ray_d_latents = self.encode_first_stage(
+ clean_combined_ray_images[:, :, 3:, :, :]
+ )
+ clean_combined_rays = torch.concat(
+ [clean_combined_ray_o_latents, clean_combined_ray_d_latents], dim=2
+ )
+
+ if self.condition_padding_with_anchor:
+ condition_ray_images = batch["condition_rays"]
+ condition_ray_o_images = self.encode_first_stage(
+ condition_ray_images[:, :, :3, :, :]
+ )
+ condition_ray_d_images = self.encode_first_stage(
+ condition_ray_images[:, :, 3:, :, :]
+ )
+ condition_rays = torch.concat(
+ [condition_ray_o_images, condition_ray_d_images], dim=2
+ )
+ else:
+ condition_rays = clean_combined_rays * mask_preserving_target
+ else:
+ clean_combined_rays = batch[self.data_key_rays]
+
+ if self.condition_padding_with_anchor:
+ condition_rays = batch["condition_rays"]
+ else:
+ condition_rays = clean_combined_rays * mask_preserving_target
+
+ if self.condition_padding_with_anchor:
+ condition_images_latents = self.encode_first_stage(
+ batch["condition_images"]
+ )
+ else:
+ condition_images_latents = (
+ clean_combined_image_latents * mask_preserving_condition
+ )
+
+ if random_drop_training_conditions:
+ random_num = torch.rand(
+ combined_images.size(0), device=combined_images.device
+ )
+ else:
+ random_num = torch.ones(
+ combined_images.size(0), device=combined_images.device
+ )
+
+ text_feature_condition_mask = rearrange(
+ random_num < 2 * self.uncond_prob, "n -> n 1 1"
+ )
+ image_feature_condition_mask = 1 - rearrange(
+ (random_num >= self.uncond_prob).float()
+ * (random_num < 3 * self.uncond_prob).float(),
+ "n -> n 1 1 1 1",
+ )
+ ray_condition_mask = 1 - rearrange(
+ (random_num >= 1.5 * self.uncond_prob).float()
+ * (random_num < 3.5 * self.uncond_prob).float(),
+ "n -> n 1 1 1 1",
+ )
+ mask_preserving_first_target = batch[
+ "mask_only_preserving_first_target"
+ ].reshape(
+ batch["mask_only_preserving_first_target"].size(0),
+ batch["mask_only_preserving_first_target"].size(1),
+ 1,
+ 1,
+ 1,
+ )
+ mask_preserving_first_condition = batch[
+ "mask_only_preserving_first_condition"
+ ].reshape(
+ batch["mask_only_preserving_first_condition"].size(0),
+ batch["mask_only_preserving_first_condition"].size(1),
+ 1,
+ 1,
+ 1,
+ )
+ mask_preserving_anchors = (
+ mask_preserving_first_target + mask_preserving_first_condition
+ )
+ mask_randomly_preserving_first_target = torch.where(
+ ray_condition_mask.repeat(1, mask_preserving_first_target.size(1), 1, 1, 1)
+ == 1.0,
+ 1.0,
+ mask_preserving_first_target,
+ )
+ mask_randomly_preserving_first_condition = torch.where(
+ image_feature_condition_mask.repeat(
+ 1, mask_preserving_first_condition.size(1), 1, 1, 1
+ )
+ == 1.0,
+ 1.0,
+ mask_preserving_first_condition,
+ )
+
+ if self.use_text_cross_attention_condition:
+ text_cond_key = self.data_key_text_condition
+ text_cond = batch[text_cond_key]
+ if isinstance(text_cond, dict) or isinstance(text_cond, list):
+ full_text_cond_emb = self.get_learned_conditioning(text_cond)
+ else:
+ full_text_cond_emb = self.get_learned_conditioning(
+ text_cond.to(self.device)
+ )
+ null_text_cond_emb = self.get_learned_conditioning([""])
+ text_cond_emb = torch.where(
+ text_feature_condition_mask,
+ null_text_cond_emb,
+ full_text_cond_emb.detach(),
+ )
+
+ batch_size, num_views, _, _, _ = batch[self.data_key_images].shape
+ if self.condition_padding_with_anchor:
+ condition_images = batch["condition_images"]
+ else:
+ condition_images = combined_images * mask_preserving_condition
+ if random_drop_training_conditions:
+ condition_image_for_embedder = rearrange(
+ condition_images * image_feature_condition_mask,
+ "b t c h w -> (b t) c h w",
+ )
+ else:
+ condition_image_for_embedder = rearrange(
+ condition_images, "b t c h w -> (b t) c h w"
+ )
+ img_token = self.embedder(condition_image_for_embedder)
+ if self.use_camera_pose_query_transformer:
+ img_emb = self.image_proj_model(
+ img_token, batch["target_poses"].reshape(batch_size, num_views, 12)
+ )
+ else:
+ img_emb = self.image_proj_model(img_token)
+
+ img_emb = rearrange(
+ img_emb, "(b t) s d -> b (t s) d", b=batch_size, t=num_views
+ )
+ if self.use_text_cross_attention_condition:
+ c_crossattn = [torch.cat([text_cond_emb, img_emb], dim=1)]
+ else:
+ c_crossattn = [img_emb]
+
+ cond_dict = {
+ "c_crossattn": c_crossattn,
+ "target_camera_poses": batch["target_and_condition_camera_poses"]
+ * batch["mask_preserving_target"].unsqueeze(-1),
+ }
+
+ if self.disable_ray_stream:
+ clean_gt = torch.cat([clean_combined_image_latents], dim=2)
+ else:
+ clean_gt = torch.cat(
+ [clean_combined_image_latents, clean_combined_rays], dim=2
+ )
+ if random_drop_training_conditions:
+ combined_condition = torch.cat(
+ [
+ condition_images_latents * mask_randomly_preserving_first_condition,
+ condition_rays * mask_randomly_preserving_first_target,
+ ],
+ dim=2,
+ )
+ else:
+ combined_condition = torch.cat(
+ [condition_images_latents, condition_rays], dim=2
+ )
+
+ uncond_combined_condition = torch.cat(
+ [
+ condition_images_latents * mask_preserving_anchors,
+ condition_rays * mask_preserving_anchors,
+ ],
+ dim=2,
+ )
+
+ mask_full_for_input = torch.cat(
+ [
+ mask_preserving_condition.repeat(
+ 1, 1, condition_images_latents.size(2), 1, 1
+ ),
+ mask_preserving_target.repeat(1, 1, condition_rays.size(2), 1, 1),
+ ],
+ dim=2,
+ )
+ cond_dict.update(
+ {
+ "mask_preserving_target": mask_preserving_target,
+ "mask_preserving_condition": mask_preserving_condition,
+ "combined_condition": combined_condition,
+ "uncond_combined_condition": uncond_combined_condition,
+ "clean_combined_rays": clean_combined_rays,
+ "mask_full_for_input": mask_full_for_input,
+ "num_cond_images": rearrange(
+ batch["num_cond_images"].float(), "b -> b 1 1 1 1"
+ ),
+ "num_target_images": rearrange(
+ batch["num_target_images"].float(), "b -> b 1 1 1 1"
+ ),
+ }
+ )
+
+ out = [clean_gt, cond_dict]
+ if return_reconstructed_target_images:
+ target_images_reconstructed = self.decode_first_stage(
+ clean_combined_image_latents
+ )
+ out.append(target_images_reconstructed)
+ return out
+
+ def get_dynamic_scales(self, t, spin_step=400):
+ base_scale = self.base_scale
+ scale_t = torch.where(
+ t < spin_step,
+ t * (base_scale - 1.0) / spin_step + 1.0,
+ base_scale * torch.ones_like(t),
+ )
+ return scale_t
+
+ def forward(self, x, c, **kwargs):
+ t = torch.randint(
+ 0, self.num_time_steps, (x.shape[0],), device=self.device
+ ).long()
+ if self.use_dynamic_rescale:
+ x = x * extract_into_tensor(self.scale_arr, t, x.shape)
+ return self.p_losses(x, c, t, **kwargs)
+
+ def extract_feature(self, batch, t, **kwargs):
+ z, cond = self.get_batch_input(
+ batch,
+ random_drop_training_conditions=False,
+ return_reconstructed_target_images=False,
+ )
+ if self.use_dynamic_rescale:
+ z = z * extract_into_tensor(self.scale_arr, t, z.shape)
+ noise = torch.randn_like(z)
+ if self.use_noise_offset:
+ noise = noise + 0.1 * torch.randn(
+ noise.shape[0], noise.shape[1], 1, 1, 1
+ ).to(self.device)
+ x_noisy = self.q_sample(x_start=z, t=t, noise=noise)
+ x_noisy = self.process_x_with_condition(x_noisy, condition_dict=cond)
+ c_crossattn = torch.cat(cond["c_crossattn"], 1)
+ target_camera_poses = cond["target_camera_poses"]
+ x_pred, features = self.model(
+ x_noisy,
+ t,
+ context=c_crossattn,
+ return_output_block_features=True,
+ camera_poses=target_camera_poses,
+ **kwargs,
+ )
+ return x_pred, features, z
+
+ def apply_model(self, x_noisy, t, cond, features_to_return=None, **kwargs):
+ if not isinstance(cond, dict):
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = (
+ "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
+ )
+ cond = {key: cond}
+
+ c_crossattn = torch.cat(cond["c_crossattn"], 1)
+ x_noisy = self.process_x_with_condition(x_noisy, condition_dict=cond)
+ target_camera_poses = cond["target_camera_poses"]
+ if self.use_task_embedding:
+ x_pred_images = self.model(
+ x_noisy,
+ t,
+ context=c_crossattn,
+ task_idx=TASK_IDX_IMAGE,
+ camera_poses=target_camera_poses,
+ **kwargs,
+ )
+ x_pred_rays = self.model(
+ x_noisy,
+ t,
+ context=c_crossattn,
+ task_idx=TASK_IDX_RAY,
+ camera_poses=target_camera_poses,
+ **kwargs,
+ )
+ x_pred = torch.concat([x_pred_images, x_pred_rays], dim=2)
+ elif features_to_return is not None:
+ x_pred, features = self.model(
+ x_noisy,
+ t,
+ context=c_crossattn,
+ return_input_block_features="input" in features_to_return,
+ return_middle_feature="middle" in features_to_return,
+ return_output_block_features="output" in features_to_return,
+ camera_poses=target_camera_poses,
+ **kwargs,
+ )
+ return x_pred, features
+ elif self.train_with_multi_view_feature_alignment:
+ x_pred, aligned_features = self.model(
+ x_noisy,
+ t,
+ context=c_crossattn,
+ camera_poses=target_camera_poses,
+ **kwargs,
+ )
+ return x_pred, aligned_features
+ else:
+ x_pred = self.model(
+ x_noisy,
+ t,
+ context=c_crossattn,
+ camera_poses=target_camera_poses,
+ **kwargs,
+ )
+ return x_pred
+
+ def process_x_with_condition(self, x_noisy, condition_dict):
+ combined_condition = condition_dict["combined_condition"]
+ if self.separate_noise_and_condition:
+ if self.disable_ray_stream:
+ x_noisy = torch.concat([x_noisy, combined_condition], dim=2)
+ else:
+ x_noisy = torch.concat(
+ [
+ x_noisy[:, :, :4, :, :],
+ combined_condition[:, :, :4, :, :],
+ x_noisy[:, :, 4:, :, :],
+ combined_condition[:, :, 4:, :, :],
+ ],
+ dim=2,
+ )
+ else:
+ assert (
+ not self.use_ray_decoder_regression
+ ), "`separate_noise_and_condition` must be True when enabling `use_ray_decoder_regression`."
+ mask_preserving_target = condition_dict["mask_preserving_target"]
+ mask_preserving_condition = condition_dict["mask_preserving_condition"]
+ mask_for_combined_condition = torch.cat(
+ [
+ mask_preserving_target.repeat(1, 1, 4, 1, 1),
+ mask_preserving_condition.repeat(1, 1, 6, 1, 1),
+ ]
+ )
+ mask_for_x_noisy = torch.cat(
+ [
+ mask_preserving_target.repeat(1, 1, 4, 1, 1),
+ mask_preserving_condition.repeat(1, 1, 6, 1, 1),
+ ]
+ )
+ x_noisy = (
+ x_noisy * mask_for_x_noisy
+ + combined_condition * mask_for_combined_condition
+ )
+
+ return x_noisy
+
+ def p_losses(self, x_start, cond, t, noise=None, **kwargs):
+
+ noise = default(noise, lambda: torch.randn_like(x_start))
+
+ if self.use_noise_offset:
+ noise = noise + 0.1 * torch.randn(
+ noise.shape[0], noise.shape[1], 1, 1, 1
+ ).to(self.device)
+
+ # noise em !!!
+ if self.bd_noise:
+ noise_decor = self.bd(noise)
+ noise_decor = (noise_decor - noise_decor.mean()) / (
+ noise_decor.std() + 1e-5
+ )
+ noise_f = noise_decor[:, :, 0:1, :, :]
+ noise = (
+ np.sqrt(self.bd_ratio) * noise_decor[:, :, 1:]
+ + np.sqrt(1 - self.bd_ratio) * noise_f
+ )
+ noise = torch.cat([noise_f, noise], dim=2)
+
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ if self.train_with_multi_view_feature_alignment:
+ model_output, aligned_features = self.apply_model(
+ x_noisy, t, cond, **kwargs
+ )
+
+ aligned_middle_feature = rearrange(
+ aligned_features,
+ "(b t) c h w -> b (t c h w)",
+ b=cond["pts_anchor_to_all"].size(0),
+ t=cond["pts_anchor_to_all"].size(1),
+ )
+ target_multi_view_feature = rearrange(
+ torch.concat(
+ [cond["pts_anchor_to_all"], cond["pts_all_to_anchor"]], dim=2
+ ),
+ "b t c h w -> b (t c h w)",
+ ).to(aligned_middle_feature.device)
+ else:
+ model_output = self.apply_model(x_noisy, t, cond, **kwargs)
+
+ loss_dict = {}
+ prefix = "train" if self.training else "val"
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ if self.apply_condition_mask_in_training_loss:
+ mask_full_for_output = 1.0 - cond["mask_full_for_input"]
+ model_output = model_output * mask_full_for_output
+ target = target * mask_full_for_output
+ loss_simple = self.get_loss(model_output, target, mean=False)
+ if self.ray_loss_weight != 1.0:
+ loss_simple[:, :, 4:, :, :] = (
+ loss_simple[:, :, 4:, :, :] * self.ray_loss_weight
+ )
+ if self.apply_condition_mask_in_training_loss:
+ # Ray loss: predicted items = # of condition images
+ num_total_images = cond["num_cond_images"] + cond["num_target_images"]
+ weight_for_image_loss = num_total_images / cond["num_target_images"]
+ weight_for_ray_loss = num_total_images / cond["num_cond_images"]
+ loss_simple[:, :, :4, :, :] = (
+ loss_simple[:, :, :4, :, :] * weight_for_image_loss
+ )
+ # Ray loss: predicted items = # of condition images
+ loss_simple[:, :, 4:, :, :] = (
+ loss_simple[:, :, 4:, :, :] * weight_for_ray_loss
+ )
+
+ loss_dict.update({f"{prefix}/loss_images": loss_simple[:, :, 0:4, :, :].mean()})
+ if not self.disable_ray_stream:
+ loss_dict.update(
+ {f"{prefix}/loss_rays": loss_simple[:, :, 4:, :, :].mean()}
+ )
+ loss_simple = loss_simple.mean([1, 2, 3, 4])
+ loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})
+
+ if self.logvar.device is not self.device:
+ self.logvar = self.logvar.to(self.device)
+ logvar_t = self.logvar[t]
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ if self.learn_logvar:
+ loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
+ loss_dict.update({"logvar": self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ if self.train_with_multi_view_feature_alignment:
+ multi_view_feature_alignment_loss = 0.25 * torch.nn.functional.mse_loss(
+ aligned_middle_feature, target_multi_view_feature
+ )
+ loss += multi_view_feature_alignment_loss
+ loss_dict.update(
+ {f"{prefix}/loss_mv_feat_align": multi_view_feature_alignment_loss}
+ )
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(
+ dim=(1, 2, 3, 4)
+ )
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
+ loss += self.original_elbo_weight * loss_vlb
+ loss_dict.update({f"{prefix}/loss": loss})
+
+ return loss, loss_dict
+
+ def _get_denoise_row_from_list(self, samples, desc=""):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device)))
+ n_log_time_steps = len(denoise_row)
+
+ denoise_row = torch.stack(denoise_row) # n_log_time_steps, b, C, H, W
+
+ if denoise_row.dim() == 5:
+ denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+ denoise_grid = make_grid(denoise_grid, nrow=n_log_time_steps)
+ elif denoise_row.dim() == 6:
+ video_length = denoise_row.shape[3]
+ denoise_grid = rearrange(denoise_row, "n b c t h w -> b n c t h w")
+ denoise_grid = rearrange(denoise_grid, "b n c t h w -> (b n) c t h w")
+ denoise_grid = rearrange(denoise_grid, "n c t h w -> (n t) c h w")
+ denoise_grid = make_grid(denoise_grid, nrow=video_length)
+ else:
+ raise ValueError
+
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch,
+ sample=True,
+ ddim_steps=50,
+ ddim_eta=1.0,
+ plot_denoise_rows=False,
+ unconditional_guidance_scale=1.0,
+ **kwargs,
+ ):
+ """log images for LatentDiffusion"""
+ use_ddim = ddim_steps is not None
+ log = dict()
+ z, cond, x_rec = self.get_batch_input(
+ batch,
+ random_drop_training_conditions=False,
+ return_reconstructed_target_images=True,
+ )
+ b, t, c, h, w = x_rec.shape
+ log["num_cond_images_str"] = batch["num_cond_images_str"]
+ log["caption"] = batch["caption"]
+ if "condition_images" in batch:
+ log["input_condition_images_all"] = batch["condition_images"]
+ log["input_condition_image_latents_masked"] = cond["combined_condition"][
+ :, :, 0:3, :, :
+ ]
+ log["input_condition_rays_o_masked"] = (
+ cond["combined_condition"][:, :, 4:7, :, :] / 5.0
+ )
+ log["input_condition_rays_d_masked"] = (
+ cond["combined_condition"][:, :, 7:, :, :] / 5.0
+ )
+ log["gt_images_after_vae"] = x_rec
+ if self.train_with_multi_view_feature_alignment:
+ log["pts_anchor_to_all"] = cond["pts_anchor_to_all"]
+ log["pts_all_to_anchor"] = cond["pts_all_to_anchor"]
+ log["pts_anchor_to_all"] = (
+ log["pts_anchor_to_all"] - torch.min(log["pts_anchor_to_all"])
+ ) / torch.max(log["pts_anchor_to_all"])
+ log["pts_all_to_anchor"] = (
+ log["pts_all_to_anchor"] - torch.min(log["pts_all_to_anchor"])
+ ) / torch.max(log["pts_all_to_anchor"])
+
+ if self.ray_as_image:
+ log["gt_rays_o"] = batch["combined_rays"][:, :, 0:3, :, :]
+ log["gt_rays_d"] = batch["combined_rays"][:, :, 3:, :, :]
+ else:
+ log["gt_rays_o"] = batch["combined_rays"][:, :, 0:3, :, :] / 5.0
+ log["gt_rays_d"] = batch["combined_rays"][:, :, 3:, :, :] / 5.0
+
+ if sample:
+ # get uncond embedding for classifier-free guidance sampling
+ if unconditional_guidance_scale != 1.0:
+ uc = self.get_unconditional_dict_for_sampling(batch, cond, x_rec)
+ else:
+ uc = None
+
+ with self.ema_scope("Plotting"):
+ out = self.sample_log(
+ cond=cond,
+ batch_size=b,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ mask=self.cond_mask,
+ x0=z,
+ with_extra_returned_data=False,
+ **kwargs,
+ )
+ samples, z_denoise_row = out
+ per_instance_decoding = False
+
+ if per_instance_decoding:
+ x_sample_images = []
+ for idx in range(b):
+ sample_image = samples[idx : idx + 1, :, 0:4, :, :]
+ x_sample_image = self.decode_first_stage(sample_image)
+ x_sample_images.append(x_sample_image)
+ x_sample_images = torch.cat(x_sample_images, dim=0)
+ else:
+ x_sample_images = self.decode_first_stage(samples[:, :, 0:4, :, :])
+ log["sample_images"] = x_sample_images
+
+ if not self.disable_ray_stream:
+ if self.ray_as_image:
+ log["sample_rays_o"] = self.decode_first_stage(
+ samples[:, :, 4:8, :, :]
+ )
+ log["sample_rays_d"] = self.decode_first_stage(
+ samples[:, :, 8:, :, :]
+ )
+ else:
+ log["sample_rays_o"] = samples[:, :, 4:7, :, :] / 5.0
+ log["sample_rays_d"] = samples[:, :, 7:, :, :] / 5.0
+
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ return log
+
+ def get_unconditional_dict_for_sampling(self, batch, cond, x_rec, is_extra=False):
+ b, t, c, h, w = x_rec.shape
+ if self.use_text_cross_attention_condition:
+ if self.uncond_type == "empty_seq":
+ # NVComposer's cross attention layers accept multi-view images
+ prompts = b * [""]
+ # prompts = b * t * [""] # if is_image_batch=True
+ uc_emb = self.get_learned_conditioning(prompts)
+ elif self.uncond_type == "zero_embed":
+ c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
+ uc_emb = torch.zeros_like(c_emb)
+ else:
+ uc_emb = None
+
+ # process image condition
+ if not is_extra:
+ if hasattr(self, "embedder"):
+ # uc_img = torch.zeros_like(x[:, :, 0, ...]) # b c h w
+ uc_img = torch.zeros(
+ # b c h w
+ size=(b * t, c, h, w),
+ dtype=x_rec.dtype,
+ device=x_rec.device,
+ )
+ # img: b c h w >> b l c
+ uc_img = self.get_image_embeds(uc_img, batch)
+
+ # Modified: The uc embeddings should be reshaped for valid post-processing
+ uc_img = rearrange(
+ uc_img, "(b t) s d -> b (t s) d", b=b, t=uc_img.shape[0] // b
+ )
+ if uc_emb is None:
+ uc_emb = uc_img
+ else:
+ uc_emb = torch.cat([uc_emb, uc_img], dim=1)
+ uc = {key: cond[key] for key in cond.keys()}
+ uc.update({"c_crossattn": [uc_emb]})
+ else:
+ uc = {key: cond[key] for key in cond.keys()}
+ uc.update({"combined_condition": uc["uncond_combined_condition"]})
+
+ return uc
+
+ def p_mean_variance(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised: bool,
+ return_x0=False,
+ score_corrector=None,
+ corrector_kwargs=None,
+ **kwargs,
+ ):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, **kwargs)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(
+ self, model_out, x, t, c, **corrector_kwargs
+ )
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1.0, 1.0)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+
+ if return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised=False,
+ repeat_noise=False,
+ return_x0=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ **kwargs,
+ ):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(
+ x=x,
+ c=c,
+ t=t,
+ clip_denoised=clip_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ **kwargs,
+ )
+ if return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_x0:
+ return (
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
+ x0,
+ )
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(
+ self,
+ cond,
+ shape,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ callback=None,
+ time_steps=None,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ start_T=None,
+ log_every_t=None,
+ **kwargs,
+ ):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if time_steps is None:
+ time_steps = self.num_time_steps
+ if start_T is not None:
+ time_steps = min(time_steps, start_T)
+
+ iterator = (
+ tqdm(reversed(range(0, time_steps)), desc="Sampling t", total=time_steps)
+ if verbose
+ else reversed(range(0, time_steps))
+ )
+
+ if mask is not None:
+ assert x0 is not None
+ # spatial size has to match
+ assert x0.shape[2:3] == mask.shape[2:3]
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != "hybrid"
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(
+ img, cond, ts, clip_denoised=self.clip_denoised, **kwargs
+ )
+
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if i % log_every_t == 0 or i == time_steps - 1:
+ intermediates.append(img)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond,
+ batch_size=16,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ time_steps=None,
+ mask=None,
+ x0=None,
+ shape=None,
+ **kwargs,
+ ):
+ if shape is None:
+ shape = (batch_size, self.channels, self.temporal_length, *self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key: (
+ cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ )
+ for key in cond
+ }
+ else:
+ cond = (
+ [c[:batch_size] for c in cond]
+ if isinstance(cond, list)
+ else cond[:batch_size]
+ )
+ return self.p_sample_loop(
+ cond,
+ shape,
+ return_intermediates=return_intermediates,
+ x_T=x_T,
+ verbose=verbose,
+ time_steps=time_steps,
+ mask=mask,
+ x0=x0,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def sample_log(
+ self,
+ cond,
+ batch_size,
+ ddim,
+ ddim_steps,
+ with_extra_returned_data=False,
+ **kwargs,
+ ):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.temporal_length, self.channels, *self.image_size)
+ out = ddim_sampler.sample(
+ ddim_steps,
+ batch_size,
+ shape,
+ cond,
+ verbose=True,
+ with_extra_returned_data=with_extra_returned_data,
+ **kwargs,
+ )
+ if with_extra_returned_data:
+ samples, intermediates, extra_returned_data = out
+ return samples, intermediates, extra_returned_data
+ else:
+ samples, intermediates = out
+ return samples, intermediates
+
+ else:
+ samples, intermediates = self.sample(
+ cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
+ )
+
+ return samples, intermediates
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+
+ def forward(self, x, c, **kwargs):
+ return self.diffusion_model(x, c, **kwargs)
diff --git a/core/models/samplers/__init__.py b/core/models/samplers/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/core/models/samplers/ddim.py b/core/models/samplers/ddim.py
new file mode 100755
index 0000000000000000000000000000000000000000..ea4dbd0f9eb64a6696f9d53e99df998286d63266
--- /dev/null
+++ b/core/models/samplers/ddim.py
@@ -0,0 +1,546 @@
+"""SAMPLING ONLY."""
+
+import numpy as np
+import torch
+from einops import rearrange
+from tqdm import tqdm
+
+from core.common import noise_like
+from core.models.utils_diffusion import (
+ make_ddim_sampling_parameters,
+ make_ddim_time_steps,
+ rescale_noise_cfg,
+)
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_time_steps = model.num_time_steps
+ self.schedule = schedule
+ self.counter = 0
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+ ):
+ self.ddim_time_steps = make_ddim_time_steps(
+ ddim_discr_method=ddim_discretize,
+ num_ddim_time_steps=ddim_num_steps,
+ num_ddpm_time_steps=self.ddpm_num_time_steps,
+ verbose=verbose,
+ )
+ alphas_cumprod = self.model.alphas_cumprod
+ assert (
+ alphas_cumprod.shape[0] == self.ddpm_num_time_steps
+ ), "alphas have to be defined for each timestep"
+
+ def to_torch(x):
+ return x.clone().detach().to(torch.float32).to(self.model.device)
+
+ if self.model.use_dynamic_rescale:
+ self.ddim_scale_arr = self.model.scale_arr[self.ddim_time_steps]
+ self.ddim_scale_arr_prev = torch.cat(
+ [self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]
+ )
+
+ self.register_buffer("betas", to_torch(self.model.betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer(
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+ )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer(
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+ )
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_time_steps=self.ddim_time_steps,
+ eta=ddim_eta,
+ verbose=verbose,
+ )
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
+ self.register_buffer("ddim_alphas", ddim_alphas)
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev)
+ / (1 - self.alphas_cumprod)
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+ )
+ self.register_buffer(
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+ )
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.0,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ schedule_verbose=False,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ unconditional_guidance_scale_extra=1.0,
+ unconditional_conditioning_extra=None,
+ with_extra_returned_data=False,
+ **kwargs,
+ ):
+
+ # check condition bs
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ try:
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ except:
+ cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
+
+ if cbs != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+ )
+
+ self.skip_step = self.ddpm_num_time_steps // S
+ discr_method = (
+ "uniform_trailing" if self.model.rescale_betas_zero_snr else "uniform"
+ )
+ self.make_schedule(
+ ddim_num_steps=S,
+ ddim_discretize=discr_method,
+ ddim_eta=eta,
+ verbose=schedule_verbose,
+ )
+
+ # make shape
+ if len(shape) == 3:
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ elif len(shape) == 4:
+ T, C, H, W = shape
+ size = (batch_size, T, C, H, W)
+ else:
+ assert False, f"Invalid shape: {shape}."
+ out = self.ddim_sampling(
+ conditioning,
+ size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask,
+ x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale_extra=unconditional_guidance_scale_extra,
+ unconditional_conditioning_extra=unconditional_conditioning_extra,
+ verbose=verbose,
+ with_extra_returned_data=with_extra_returned_data,
+ **kwargs,
+ )
+ if with_extra_returned_data:
+ samples, intermediates, extra_returned_data = out
+ return samples, intermediates, extra_returned_data
+ else:
+ samples, intermediates = out
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(
+ self,
+ cond,
+ shape,
+ x_T=None,
+ ddim_use_original_steps=False,
+ callback=None,
+ time_steps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ log_every_t=100,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ unconditional_guidance_scale_extra=1.0,
+ unconditional_conditioning_extra=None,
+ verbose=True,
+ with_extra_returned_data=False,
+ **kwargs,
+ ):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device, dtype=self.model.dtype)
+ if self.model.bd_noise:
+ noise_decor = self.model.bd(img)
+ noise_decor = (noise_decor - noise_decor.mean()) / (
+ noise_decor.std() + 1e-5
+ )
+ noise_f = noise_decor[:, :, 0:1, :, :]
+ noise = (
+ np.sqrt(self.model.bd_ratio) * noise_decor[:, :, 1:]
+ + np.sqrt(1 - self.model.bd_ratio) * noise_f
+ )
+ img = torch.cat([noise_f, noise], dim=2)
+ else:
+ img = x_T
+
+ if time_steps is None:
+ time_steps = (
+ self.ddpm_num_time_steps
+ if ddim_use_original_steps
+ else self.ddim_time_steps
+ )
+ elif time_steps is not None and not ddim_use_original_steps:
+ subset_end = (
+ int(
+ min(time_steps / self.ddim_time_steps.shape[0], 1)
+ * self.ddim_time_steps.shape[0]
+ )
+ - 1
+ )
+ time_steps = self.ddim_time_steps[:subset_end]
+
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
+ time_range = (
+ reversed(range(0, time_steps))
+ if ddim_use_original_steps
+ else np.flip(time_steps)
+ )
+ total_steps = time_steps if ddim_use_original_steps else time_steps.shape[0]
+ if verbose:
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+ else:
+ iterator = time_range
+ # Sampling Loop
+ for i, step in enumerate(iterator):
+ print(f"Sample: i={i}, step={step}.")
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ print("ts=", ts)
+ # use mask to blend noised original latent (img_orig) & new sampled latent (img)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = x0
+ # keep original & modify use img
+ img = img_orig * mask + (1.0 - mask) * img
+ outs = self.p_sample_ddim(
+ img,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ unconditional_guidance_scale_extra=unconditional_guidance_scale_extra,
+ unconditional_conditioning_extra=unconditional_conditioning_extra,
+ with_extra_returned_data=with_extra_returned_data,
+ **kwargs,
+ )
+ if with_extra_returned_data:
+ img, pred_x0, extra_returned_data = outs
+ else:
+ img, pred_x0 = outs
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(pred_x0, i)
+ # log_every_t = 1
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates["x_inter"].append(img)
+ intermediates["pred_x0"].append(pred_x0)
+ # intermediates['extra_returned_data'].append(extra_returned_data)
+ if with_extra_returned_data:
+ return img, intermediates, extra_returned_data
+ return img, intermediates
+
+ def batch_time_transpose(
+ self, batch_time_tensor, num_target_views, num_condition_views
+ ):
+ # Input: N*N; N = T+C
+ assert num_target_views + num_condition_views == batch_time_tensor.shape[1]
+ target_tensor = batch_time_tensor[:, :num_target_views, ...] # T*T
+ condition_tensor = batch_time_tensor[:, num_target_views:, ...] # N*C
+ target_tensor = target_tensor.transpose(0, 1) # T*T
+ return torch.concat([target_tensor, condition_tensor], dim=1)
+
+ def ddim_batch_shard_step(
+ self,
+ pred_x0_post_process_function,
+ pred_x0_post_process_function_kwargs,
+ cond,
+ corrector_kwargs,
+ ddim_use_original_steps,
+ device,
+ img,
+ index,
+ kwargs,
+ noise_dropout,
+ quantize_denoised,
+ score_corrector,
+ step,
+ temperature,
+ with_extra_returned_data,
+ ):
+ img_list = []
+ pred_x0_list = []
+ shard_step = 5
+ shard_start = 0
+ while shard_start < img.shape[0]:
+ shard_end = shard_start + shard_step
+ if shard_start >= img.shape[0]:
+ break
+ if shard_end > img.shape[0]:
+ shard_end = img.shape[0]
+ print(
+ f"Sampling Batch Shard: From #{shard_start} to #{shard_end}. Total: {img.shape[0]}."
+ )
+ sub_img = img[shard_start:shard_end]
+ sub_cond = {
+ "combined_condition": cond["combined_condition"][shard_start:shard_end],
+ "c_crossattn": [
+ cond["c_crossattn"][0][0:1].expand(shard_end - shard_start, -1, -1)
+ ],
+ }
+ ts = torch.full((sub_img.shape[0],), step, device=device, dtype=torch.long)
+
+ _img, _pred_x0 = self.p_sample_ddim(
+ sub_img,
+ sub_cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ unconditional_guidance_scale_extra=1.0,
+ unconditional_conditioning_extra=None,
+ pred_x0_post_process_function=pred_x0_post_process_function,
+ pred_x0_post_process_function_kwargs=pred_x0_post_process_function_kwargs,
+ with_extra_returned_data=with_extra_returned_data,
+ **kwargs,
+ )
+ img_list.append(_img)
+ pred_x0_list.append(_pred_x0)
+ shard_start += shard_step
+ img = torch.concat(img_list, dim=0)
+ pred_x0 = torch.concat(pred_x0_list, dim=0)
+ return img, pred_x0
+
+ @torch.no_grad()
+ def p_sample_ddim(
+ self,
+ x,
+ c,
+ t,
+ index,
+ repeat_noise=False,
+ use_original_steps=False,
+ quantize_denoised=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ unconditional_guidance_scale_extra=1.0,
+ unconditional_conditioning_extra=None,
+ with_extra_returned_data=False,
+ **kwargs,
+ ):
+ b, *_, device = *x.shape, x.device
+ if x.dim() == 5:
+ is_video = True
+ else:
+ is_video = False
+
+ extra_returned_data = None
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+ e_t_cfg = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
+ if isinstance(e_t_cfg, tuple):
+ e_t_cfg = e_t_cfg[0]
+ extra_returned_data = e_t_cfg[1:]
+ else:
+ # with unconditional condition
+ if isinstance(c, torch.Tensor) or isinstance(c, dict):
+ e_t = self.model.apply_model(x, t, c, **kwargs)
+ e_t_uncond = self.model.apply_model(
+ x, t, unconditional_conditioning, **kwargs
+ )
+ if (
+ unconditional_guidance_scale_extra != 1.0
+ and unconditional_conditioning_extra is not None
+ ):
+ print(f"Using extra CFG: {unconditional_guidance_scale_extra}...")
+ e_t_uncond_extra = self.model.apply_model(
+ x, t, unconditional_conditioning_extra, **kwargs
+ )
+ else:
+ e_t_uncond_extra = None
+ else:
+ raise NotImplementedError
+
+ if isinstance(e_t, tuple):
+ e_t = e_t[0]
+ extra_returned_data = e_t[1:]
+
+ if isinstance(e_t_uncond, tuple):
+ e_t_uncond = e_t_uncond[0]
+ if isinstance(e_t_uncond_extra, tuple):
+ e_t_uncond_extra = e_t_uncond_extra[0]
+
+ # text cfg
+ if (
+ unconditional_guidance_scale_extra != 1.0
+ and unconditional_conditioning_extra is not None
+ ):
+ e_t_cfg = (
+ e_t_uncond
+ + unconditional_guidance_scale * (e_t - e_t_uncond)
+ + unconditional_guidance_scale_extra * (e_t - e_t_uncond_extra)
+ )
+ else:
+ e_t_cfg = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if self.model.rescale_betas_zero_snr:
+ e_t_cfg = rescale_noise_cfg(e_t_cfg, e_t, guidance_rescale=0.7)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, e_t_cfg)
+ else:
+ e_t = e_t_cfg
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", "not implemented"
+ e_t = score_corrector.modify_score(
+ self.model, e_t, x, t, c, **corrector_kwargs
+ )
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = (
+ self.model.alphas_cumprod_prev
+ if use_original_steps
+ else self.ddim_alphas_prev
+ )
+ sqrt_one_minus_alphas = (
+ self.model.sqrt_one_minus_alphas_cumprod
+ if use_original_steps
+ else self.ddim_sqrt_one_minus_alphas
+ )
+ sigmas = (
+ self.model.ddim_sigmas_for_original_num_steps
+ if use_original_steps
+ else self.ddim_sigmas
+ )
+ # select parameters corresponding to the currently considered timestep
+
+ if is_video:
+ size = (b, 1, 1, 1, 1)
+ else:
+ size = (b, 1, 1, 1)
+ a_t = torch.full(size, alphas[index], device=device)
+ a_prev = torch.full(size, alphas_prev[index], device=device)
+ sigma_t = torch.full(size, sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full(
+ size, sqrt_one_minus_alphas[index], device=device
+ )
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, e_t_cfg)
+
+ if self.model.use_dynamic_rescale:
+ scale_t = torch.full(size, self.ddim_scale_arr[index], device=device)
+ prev_scale_t = torch.full(
+ size, self.ddim_scale_arr_prev[index], device=device
+ )
+ rescale = prev_scale_t / scale_t
+ pred_x0 *= rescale
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+
+ noise = noise_like(x.shape, device, repeat_noise)
+ if self.model.bd_noise:
+ noise_decor = self.model.bd(noise)
+ noise_decor = (noise_decor - noise_decor.mean()) / (
+ noise_decor.std() + 1e-5
+ )
+ noise_f = noise_decor[:, :, 0:1, :, :]
+ noise = (
+ np.sqrt(self.model.bd_ratio) * noise_decor[:, :, 1:]
+ + np.sqrt(1 - self.model.bd_ratio) * noise_f
+ )
+ noise = torch.cat([noise_f, noise], dim=2)
+ noise = sigma_t * noise * temperature
+
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ if with_extra_returned_data:
+ return x_prev, pred_x0, extra_returned_data
+ return x_prev, pred_x0
diff --git a/core/models/samplers/dpm_solver/__init__.py b/core/models/samplers/dpm_solver/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/core/models/samplers/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/core/models/samplers/dpm_solver/dpm_solver.py b/core/models/samplers/dpm_solver/dpm_solver.py
new file mode 100755
index 0000000000000000000000000000000000000000..03cce2b6a4f41fa166f61f4c34f8cf63489c56ec
--- /dev/null
+++ b/core/models/samplers/dpm_solver/dpm_solver.py
@@ -0,0 +1,1298 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(
+ 0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(
+ math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ def log_alpha_fn(s): return torch.log(
+ torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * \
+ torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * \
+ torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * \
+ torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ def t_fn(log_alpha_t): return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+
+ if isinstance(output, tuple):
+ output = output[0]
+
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
+ t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
+ t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(
+ x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ # x_in = torch.cat([x] * 2)
+ # t_in = torch.cat([t_continuous] * 2)
+ x_in = x
+ t_in = t_continuous
+ # c_in = torch.cat([unconditional_condition, condition])
+ noise = noise_pred_fn(x_in, t_in, cond=condition)
+ noise_uncond = noise_pred_fn(
+ x_in, t_in, cond=unconditional_condition)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
+ t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / \
+ expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape(
+ (x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val *
+ torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(
+ torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(
+ torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(
+ lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 **
+ (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_time_steps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3, ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3, ] * (K - 1) + [1]
+ else:
+ orders = [3, ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2, ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2, ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1, ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ time_steps_outer = self.get_time_steps(
+ skip_type, t_T, t_0, K, device)
+ else:
+ time_steps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+ torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
+ return time_steps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError(
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(
+ s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t *
+ phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t *
+ phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) -
+ 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError(
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(
+ log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22,
+ dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t *
+ phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(
+ torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 *
+ phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t *
+ phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError(
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
+ t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.),
+ dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t *
+ (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.),
+ dims) * model_prev_0
+ + expand_dims(alpha_t *
+ ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(
+ torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.),
+ dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t *
+ (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(
+ torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.),
+ dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) -
+ 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
+ t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.),
+ dims) * model_prev_0
+ + expand_dims(alpha_t *
+ ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) -
+ 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(
+ torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.),
+ dims) * model_prev_0
+ - expand_dims(sigma_t *
+ ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) -
+ 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError(
+ "Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError(
+ "Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+ solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+
+ def lower_update(x, s, t): return self.dpm_solver_first_update(
+ x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ solver_type=solver_type,
+ **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+
+ def lower_update(x, s, t): return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ return_intermediate=True,
+ solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+ solver_type=solver_type,
+ **kwargs)
+ else:
+ raise ValueError(
+ "For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(
+ x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+
+ def norm_fn(v): return torch.sqrt(torch.square(
+ v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -
+ 1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+
+ =====================================================
+
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+
+ =====================================================
+
+ Some advices for choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+ - 'time_quadratic': quadratic time for the time steps.
+
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: A `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: A `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
+ (such as CIFAR-10). However, we observed that such trick does not matter for
+ high-resolutional images. As it needs an additional NFE, we do not recommend
+ it for high-resolutional images.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+ (especially for steps <= 10). So we recommend to set it to be `True`.
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+ solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ time_steps = self.get_time_steps(
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert time_steps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = time_steps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
+ vec_t = time_steps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+ solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+ vec_t = time_steps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+ solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ time_steps_outer, orders = self.get_orders_and_time_steps_for_singlestep_solver(steps=steps, order=order,
+ skip_type=skip_type,
+ t_T=t_T, t_0=t_0,
+ device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order, ] * K
+ time_steps_outer = self.get_time_steps(
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = time_steps_outer[i], time_steps_outer[i + 1]
+ time_steps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+ N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(
+ time_steps_inner)
+ vec_s, vec_t = t_T_inner.tile(
+ x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (
+ lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (
+ lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(
+ x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(
+ x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat(
+ [x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(
+ K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx),
+ start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2,
+ index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2,
+ index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(
+ K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2,
+ index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(
+ start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,) * (dims - 1)]
diff --git a/core/models/samplers/dpm_solver/sampler.py b/core/models/samplers/dpm_solver/sampler.py
new file mode 100755
index 0000000000000000000000000000000000000000..c37699aeaa110d19fd1859b4b95acfe1d5ff0944
--- /dev/null
+++ b/core/models/samplers/dpm_solver/sampler.py
@@ -0,0 +1,91 @@
+"""SAMPLING ONLY."""
+
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {"eps": "noise", "v": "v"}
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+
+ def to_torch(x):
+ return x.clone().detach().to(torch.float32).to(model.device)
+
+ self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ x_T=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs,
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ try:
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ except:
+ cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
+
+ if cbs != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+ )
+
+ # sampling
+ T, C, H, W = shape
+ size = (batch_size, T, C, H, W)
+
+ print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}")
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=MODEL_TYPES[self.model.parameterization],
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(
+ img,
+ steps=S,
+ skip_type="time_uniform",
+ method="multistep",
+ order=2,
+ lower_order_final=True,
+ )
+
+ return x.to(device), None
diff --git a/core/models/samplers/plms.py b/core/models/samplers/plms.py
new file mode 100755
index 0000000000000000000000000000000000000000..ff0512528602dd38dccc3a11a392e77ec0178b9c
--- /dev/null
+++ b/core/models/samplers/plms.py
@@ -0,0 +1,358 @@
+"""SAMPLING ONLY."""
+
+import numpy as np
+from tqdm import tqdm
+
+import torch
+from core.models.utils_diffusion import (
+ make_ddim_sampling_parameters,
+ make_ddim_time_steps,
+)
+from core.common import noise_like
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_time_steps = model.num_time_steps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+ ):
+ if ddim_eta != 0:
+ raise ValueError("ddim_eta must be 0 for PLMS")
+ self.ddim_time_steps = make_ddim_time_steps(
+ ddim_discr_method=ddim_discretize,
+ num_ddim_time_steps=ddim_num_steps,
+ num_ddpm_time_steps=self.ddpm_num_time_steps,
+ verbose=verbose,
+ )
+ alphas_cumprod = self.model.alphas_cumprod
+ assert (
+ alphas_cumprod.shape[0] == self.ddpm_num_time_steps
+ ), "alphas have to be defined for each timestep"
+
+ def to_torch(x):
+ return x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer("betas", to_torch(self.model.betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer(
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+ )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer(
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+ )
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_time_steps=self.ddim_time_steps,
+ eta=ddim_eta,
+ verbose=verbose,
+ )
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
+ self.register_buffer("ddim_alphas", ddim_alphas)
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev)
+ / (1 - self.alphas_cumprod)
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+ )
+ self.register_buffer(
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+ )
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.0,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs,
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+ )
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f"Data shape for PLMS sampling is {size}")
+
+ samples, intermediates = self.plms_sampling(
+ conditioning,
+ size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask,
+ x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(
+ self,
+ cond,
+ shape,
+ x_T=None,
+ ddim_use_original_steps=False,
+ callback=None,
+ time_steps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ log_every_t=100,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ ):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if time_steps is None:
+ time_steps = (
+ self.ddpm_num_time_steps
+ if ddim_use_original_steps
+ else self.ddim_time_steps
+ )
+ elif time_steps is not None and not ddim_use_original_steps:
+ subset_end = (
+ int(
+ min(time_steps / self.ddim_time_steps.shape[0], 1)
+ * self.ddim_time_steps.shape[0]
+ )
+ - 1
+ )
+ time_steps = self.ddim_time_steps[:subset_end]
+
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
+ time_range = (
+ list(reversed(range(0, time_steps)))
+ if ddim_use_original_steps
+ else np.flip(time_steps)
+ )
+ total_steps = time_steps if ddim_use_original_steps else time_steps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} time_steps")
+
+ iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full(
+ (b,),
+ time_range[min(i + 1, len(time_range) - 1)],
+ device=device,
+ dtype=torch.long,
+ )
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ outs = self.p_sample_plms(
+ img,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps,
+ t_next=ts_next,
+ )
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates["x_inter"].append(img)
+ intermediates["pred_x0"].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(
+ self,
+ x,
+ c,
+ t,
+ index,
+ repeat_noise=False,
+ use_original_steps=False,
+ quantize_denoised=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ old_eps=None,
+ t_next=None,
+ ):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if (
+ unconditional_conditioning is None
+ or unconditional_guidance_scale == 1.0
+ ):
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(
+ self.model, e_t, x, t, c, **corrector_kwargs
+ )
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = (
+ self.model.alphas_cumprod_prev
+ if use_original_steps
+ else self.ddim_alphas_prev
+ )
+ sqrt_one_minus_alphas = (
+ self.model.sqrt_one_minus_alphas_cumprod
+ if use_original_steps
+ else self.ddim_sqrt_one_minus_alphas
+ )
+ sigmas = (
+ self.model.ddim_sigmas_for_original_num_steps
+ if use_original_steps
+ else self.ddim_sigmas
+ )
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full(
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+ )
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (
+ 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
+ ) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/core/models/samplers/uni_pc/__init__.py b/core/models/samplers/uni_pc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/core/models/samplers/uni_pc/sampler.py b/core/models/samplers/uni_pc/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a652cebae6c746e62b29a183637d4b95b8928f
--- /dev/null
+++ b/core/models/samplers/uni_pc/sampler.py
@@ -0,0 +1,67 @@
+"""SAMPLING ONLY."""
+
+import torch
+
+from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
+
+
+class UniPCSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+
+ def to_torch(x):
+ return x.clone().detach().to(torch.float32).to(model.device)
+
+ self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ x_T=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ ):
+ # sampling
+ T, C, H, W = shape
+ size = (batch_size, T, C, H, W)
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type="v",
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
+ x = uni_pc.sample(
+ img,
+ steps=S,
+ skip_type="time_uniform",
+ method="multistep",
+ order=2,
+ lower_order_final=True,
+ )
+
+ return x.to(device), None
diff --git a/core/models/samplers/uni_pc/uni_pc.py b/core/models/samplers/uni_pc/uni_pc.py
new file mode 100644
index 0000000000000000000000000000000000000000..616ea23380738f0cae2dd054e5e3d87ffa50d567
--- /dev/null
+++ b/core/models/samplers/uni_pc/uni_pc.py
@@ -0,0 +1,998 @@
+import torch
+import torch.nn.functional as F
+import math
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule="discrete",
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.0,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ["discrete", "linear", "cosine"]:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule
+ )
+ )
+
+ self.schedule = schedule
+ if schedule == "discrete":
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.0
+ self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape(
+ (1, -1)
+ )
+ self.log_alpha_array = log_alphas.reshape(
+ (
+ 1,
+ -1,
+ )
+ )
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.0
+ self.cosine_t_max = (
+ math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
+ * 2.0
+ * (1.0 + self.cosine_s)
+ / math.pi
+ - self.cosine_s
+ )
+ self.cosine_log_alpha_0 = math.log(
+ math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
+ )
+ self.schedule = schedule
+ if schedule == "cosine":
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.0
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == "discrete":
+ return interpolate_fn(
+ t.reshape((-1, 1)),
+ self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device),
+ ).reshape((-1))
+ elif self.schedule == "linear":
+ return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == "cosine":
+
+ def log_alpha_fn(s):
+ return torch.log(
+ torch.cos(
+ (s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
+ )
+ )
+
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == "linear":
+ tmp = (
+ 2.0
+ * (self.beta_1 - self.beta_0)
+ * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
+ )
+ Delta = self.beta_0**2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == "discrete":
+ log_alpha = -0.5 * torch.logaddexp(
+ torch.zeros((1,)).to(lamb.device), -2.0 * lamb
+ )
+ t = interpolate_fn(
+ log_alpha.reshape((-1, 1)),
+ torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]),
+ )
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
+
+ def t_fn(log_alpha_t):
+ return (
+ torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
+ * 2.0
+ * (1.0 + self.cosine_s)
+ / math.pi
+ - self.cosine_s
+ )
+
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.0,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == "discrete":
+ return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, None, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if isinstance(output, tuple):
+ output = output[0]
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
+ t_continuous
+ ), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(
+ sigma_t, dims
+ )
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
+ t_continuous
+ ), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ print("alpha_t.shape", alpha_t.shape)
+ print("sigma_t.shape", sigma_t.shape)
+ print("dims", dims)
+ print("x.shape", x.shape)
+ # x: b, t, c, h, w
+ alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ print("alpha_t.shape", alpha_t.shape)
+ print("sigma_t.shape", sigma_t.shape)
+ print("output.shape", output.shape)
+ return alpha_t * output + sigma_t * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return (
+ noise
+ - guidance_scale
+ * expand_dims(sigma_t, dims=cond_grad.dim())
+ * cond_grad
+ )
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1.0 or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = x
+ t_in = t_continuous
+ print("x_in.shape=", x_in.shape)
+ print("t_in.shape=", t_in.shape)
+ noise = noise_pred_fn(x_in, t_in, cond=condition)
+
+ noise_uncond = noise_pred_fn(x_in, t_in, cond=unconditional_condition)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class UniPC:
+ def __init__(
+ self,
+ model_fn,
+ noise_schedule,
+ predict_x0=True,
+ thresholding=False,
+ max_val=1.0,
+ variant="bh1",
+ ):
+ """Construct a UniPC.
+
+ We support both data_prediction and noise_prediction.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.variant = variant
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def dynamic_thresholding_fn(self, x0, t=None):
+ """
+ The dynamic thresholding method.
+ """
+ dims = x0.dim()
+ p = self.dynamic_thresholding_ratio
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(
+ torch.maximum(
+ s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
+ ),
+ dims,
+ )
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
+ t
+ ), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(
+ torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims
+ )
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling."""
+ if skip_type == "logSNR":
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(
+ lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
+ ).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == "time_uniform":
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == "time_quadratic":
+ t_order = 2
+ t = (
+ torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
+ .pow(t_order)
+ .to(device)
+ )
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
+ skip_type
+ )
+ )
+
+ def get_orders_and_timesteps_for_singlestep_solver(
+ self, steps, order, skip_type, t_T, t_0, device
+ ):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [
+ 3,
+ ] * (
+ K - 2
+ ) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [
+ 3,
+ ] * (
+ K - 1
+ ) + [1]
+ else:
+ orders = [
+ 3,
+ ] * (
+ K - 1
+ ) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [
+ 2,
+ ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [
+ 2,
+ ] * (
+ K - 1
+ ) + [1]
+ elif order == 1:
+ K = steps
+ orders = [
+ 1,
+ ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == "logSNR":
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+ torch.cumsum(
+ torch.tensor(
+ [
+ 0,
+ ]
+ + orders
+ ),
+ 0,
+ ).to(device)
+ ]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def multistep_uni_pc_update(
+ self, x, model_prev_list, t_prev_list, t, order, **kwargs
+ ):
+ if len(t.shape) == 0:
+ t = t.view(-1)
+ if "bh" in self.variant:
+ return self.multistep_uni_pc_bh_update(
+ x, model_prev_list, t_prev_list, t, order, **kwargs
+ )
+ else:
+ assert self.variant == "vary_coeff"
+ return self.multistep_uni_pc_vary_update(
+ x, model_prev_list, t_prev_list, t, order, **kwargs
+ )
+
+ def multistep_uni_pc_vary_update(
+ self, x, model_prev_list, t_prev_list, t, order, use_corrector=True
+ ):
+ print(
+ f"using unified predictor-corrector with order {order} (solver type: vary coeff)"
+ )
+ ns = self.noise_schedule
+ assert order <= len(model_prev_list)
+
+ # first compute rks
+ t_prev_0 = t_prev_list[-1]
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+ lambda_t = ns.marginal_lambda(t)
+ model_prev_0 = model_prev_list[-1]
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h = lambda_t - lambda_prev_0
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ t_prev_i = t_prev_list[-(i + 1)]
+ model_prev_i = model_prev_list[-(i + 1)]
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
+ rk = (lambda_prev_i - lambda_prev_0) / h
+ rks.append(rk)
+ D1s.append((model_prev_i - model_prev_0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=x.device)
+
+ K = len(rks)
+ # build C matrix
+ C = []
+
+ col = torch.ones_like(rks)
+ for k in range(1, K + 1):
+ C.append(col)
+ col = col * rks / (k + 1)
+ C = torch.stack(C, dim=1)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
+ A_p = C_inv_p
+
+ if use_corrector:
+ print("using corrector")
+ C_inv = torch.linalg.inv(C)
+ A_c = C_inv
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh)
+ h_phi_ks = []
+ factorial_k = 1
+ h_phi_k = h_phi_1
+ for k in range(1, K + 2):
+ h_phi_ks.append(h_phi_k)
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
+ factorial_k *= k + 1
+
+ model_t = None
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_prev_0 * x - alpha_t * h_phi_1 * model_prev_0
+ # now predictor
+ x_t = x_t_
+ if len(D1s) > 0:
+ # compute the residuals for predictor
+ for k in range(K - 1):
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum(
+ "bktchw,k->btchw", D1s, A_p[k]
+ )
+ # now corrector
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ D1_t = model_t - model_prev_0
+ x_t = x_t_
+ k = 0
+ for k in range(K - 1):
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum(
+ "bktchw,k->btchw", D1s, A_c[k][:-1]
+ )
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+ else:
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
+ t_prev_0
+ ), ns.marginal_log_mean_coeff(t)
+ x_t_ = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (
+ sigma_t * h_phi_1
+ ) * model_prev_0
+ # now predictor
+ x_t = x_t_
+ if len(D1s) > 0:
+ # compute the residuals for predictor
+ for k in range(K - 1):
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum(
+ "bktchw,k->btchw", D1s, A_p[k]
+ )
+ # now corrector
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ D1_t = model_t - model_prev_0
+ x_t = x_t_
+ k = 0
+ for k in range(K - 1):
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum(
+ "bktchw,k->btchw", D1s, A_c[k][:-1]
+ )
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+ return x_t, model_t
+
+ def multistep_uni_pc_bh_update(
+ self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True
+ ):
+ print(
+ f"using unified predictor-corrector with order {order} (solver type: B(h))"
+ )
+ ns = self.noise_schedule
+ assert order <= len(model_prev_list)
+ dims = x.dim()
+
+ # first compute rks
+ t_prev_0 = t_prev_list[-1]
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+ lambda_t = ns.marginal_lambda(t)
+ model_prev_0 = model_prev_list[-1]
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
+ t_prev_0
+ ), ns.marginal_log_mean_coeff(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h = lambda_t - lambda_prev_0
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ t_prev_i = t_prev_list[-(i + 1)]
+ model_prev_i = model_prev_list[-(i + 1)]
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
+ rks.append(rk)
+ D1s.append((model_prev_i - model_prev_0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=x.device)
+
+ R = []
+ b = []
+
+ hh = -h[0] if self.predict_x0 else h[0]
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.variant == "bh1":
+ B_h = hh
+ elif self.variant == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=x.device)
+
+ # now predictor
+ use_predictor = len(D1s) > 0 and x_t is None
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ if x_t is None:
+ # for order 2, we use a simplified version
+ if order == 2:
+ rhos_p = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+ else:
+ D1s = None
+
+ if use_corrector:
+ print("using corrector")
+ # for order 1, we use a simplified version
+ if order == 1:
+ rhos_c = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_c = torch.linalg.solve(R, b)
+
+ model_t = None
+ if self.predict_x0:
+ x_t_ = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * h_phi_1, dims) * model_prev_0
+ )
+
+ if x_t is None:
+ if use_predictor:
+ pred_res = torch.einsum("k,bktchw->btchw", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
+
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ if D1s is not None:
+ corr_res = torch.einsum("k,bktchw->btchw", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - model_prev_0
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (
+ corr_res + rhos_c[-1] * D1_t
+ )
+ else:
+ x_t_ = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
+ )
+ if x_t is None:
+ if use_predictor:
+ pred_res = torch.einsum("k,bktchw->btchw", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
+
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ if D1s is not None:
+ corr_res = torch.einsum("k,bktchw->btchw", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - model_prev_0
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (
+ corr_res + rhos_c[-1] * D1_t
+ )
+ return x_t, model_t
+
+ def sample(
+ self,
+ x,
+ steps=20,
+ t_start=None,
+ t_end=None,
+ order=3,
+ skip_type="time_uniform",
+ method="singlestep",
+ lower_order_final=True,
+ denoise_to_zero=False,
+ solver_type="dpm_solver",
+ atol=0.0078,
+ rtol=0.05,
+ corrector=False,
+ ):
+ t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == "multistep":
+ assert steps >= order
+ timesteps = self.get_time_steps(
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
+ )
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in range(1, order):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x, model_x = self.multistep_uni_pc_update(
+ x,
+ model_prev_list,
+ t_prev_list,
+ vec_t,
+ init_order,
+ use_corrector=True,
+ )
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ model_prev_list.append(model_x)
+ t_prev_list.append(vec_t)
+ for step in range(order, steps + 1):
+ vec_t = timesteps[step].expand(x.shape[0])
+ print(f"Current step={step}; vec_t={vec_t}.")
+ if lower_order_final:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ print("this step order:", step_order)
+ if step == steps:
+ print("do not run corrector at the last step")
+ use_corrector = False
+ else:
+ use_corrector = True
+ x, model_x = self.multistep_uni_pc_update(
+ x,
+ model_prev_list,
+ t_prev_list,
+ vec_t,
+ step_order,
+ use_corrector=use_corrector,
+ )
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ model_prev_list[-1] = model_x
+ else:
+ raise NotImplementedError()
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K),
+ torch.tensor(K - 2, device=x.device),
+ cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(
+ torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
+ )
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K),
+ torch.tensor(K - 2, device=x.device),
+ cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(
+ y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
+ ).squeeze(2)
+ end_y = torch.gather(
+ y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
+ ).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,) * (dims - 1)]
diff --git a/core/models/utils_diffusion.py b/core/models/utils_diffusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..a4b509cef136b742516c75f0d71e56d8f571fbcc
--- /dev/null
+++ b/core/models/utils_diffusion.py
@@ -0,0 +1,186 @@
+import math
+
+import numpy as np
+import torch
+from einops import repeat
+
+
+def timestep_embedding(time_steps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param time_steps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=time_steps.device)
+ args = time_steps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(time_steps, "b -> b d", d=dim)
+ return embedding
+
+
+def make_beta_schedule(
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+ )
+ ** 2
+ )
+
+ elif schedule == "cosine":
+ time_steps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = time_steps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(
+ linear_start, linear_end, n_timestep, dtype=torch.float64
+ )
+ elif schedule == "sqrt":
+ betas = (
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ ** 0.5
+ )
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_time_steps(
+ ddim_discr_method, num_ddim_time_steps, num_ddpm_time_steps, verbose=True
+):
+ if ddim_discr_method == "uniform":
+ c = num_ddpm_time_steps // num_ddim_time_steps
+ ddim_time_steps = np.asarray(list(range(0, num_ddpm_time_steps, c)))
+ steps_out = ddim_time_steps + 1
+ elif ddim_discr_method == "quad":
+ ddim_time_steps = (
+ (np.linspace(0, np.sqrt(num_ddpm_time_steps * 0.8), num_ddim_time_steps))
+ ** 2
+ ).astype(int)
+ steps_out = ddim_time_steps + 1
+ elif ddim_discr_method == "uniform_trailing":
+ c = num_ddpm_time_steps / num_ddim_time_steps
+ ddim_time_steps = np.flip(
+ np.round(np.arange(num_ddpm_time_steps, 0, -c))
+ ).astype(np.int64)
+ steps_out = ddim_time_steps - 1
+ else:
+ raise NotImplementedError(
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
+ )
+
+ # assert ddim_time_steps.shape[0] == num_ddim_time_steps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ if verbose:
+ print(f"Selected time_steps for ddim sampler: {steps_out}")
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_time_steps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ # print(f'ddim_time_steps={ddim_time_steps}, len_alphacums={len(alphacums)}')
+ alphas = alphacums[ddim_time_steps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_time_steps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt(
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+ )
+ if verbose:
+ print(
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
+ )
+ print(
+ f"For the chosen value of eta, which is {eta}, "
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+ )
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_time_steps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_time_steps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_time_steps):
+ t1 = i / num_diffusion_time_steps
+ t2 = (i + 1) / num_diffusion_time_steps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+ Args:
+ betas (`numpy.ndarray`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `numpy.ndarray`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_bar_sqrt = np.sqrt(alphas_cumprod)
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = np.concatenate([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
+ )
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ factor = guidance_rescale * (std_text / std_cfg) + (1 - guidance_rescale)
+ return noise_cfg * factor
diff --git a/core/modules/attention.py b/core/modules/attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..2813b5f2d819a45f33eec1a00e7cf28a874c8ee7
--- /dev/null
+++ b/core/modules/attention.py
@@ -0,0 +1,710 @@
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from functools import partial
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+from core.common import (
+ gradient_checkpoint,
+ exists,
+ default,
+)
+from core.basics import zero_module
+
+
+class RelativePosition(nn.Module):
+
+ def __init__(self, num_units, max_relative_position):
+ super().__init__()
+ self.num_units = num_units
+ self.max_relative_position = max_relative_position
+ self.embeddings_table = nn.Parameter(
+ torch.Tensor(max_relative_position * 2 + 1, num_units)
+ )
+ nn.init.xavier_uniform_(self.embeddings_table)
+
+ def forward(self, length_q, length_k):
+ device = self.embeddings_table.device
+ range_vec_q = torch.arange(length_q, device=device)
+ range_vec_k = torch.arange(length_k, device=device)
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
+ distance_mat_clipped = torch.clamp(
+ distance_mat, -self.max_relative_position, self.max_relative_position
+ )
+ final_mat = distance_mat_clipped + self.max_relative_position
+ final_mat = final_mat.long()
+ embeddings = self.embeddings_table[final_mat]
+ return embeddings
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ relative_position=False,
+ temporal_length=None,
+ video_length=None,
+ image_cross_attention=False,
+ image_cross_attention_scale=1.0,
+ image_cross_attention_scale_learnable=False,
+ text_context_len=77,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ self.dim_head = dim_head
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+
+ self.relative_position = relative_position
+ if self.relative_position:
+ assert temporal_length is not None
+ self.relative_position_k = RelativePosition(
+ num_units=dim_head, max_relative_position=temporal_length
+ )
+ self.relative_position_v = RelativePosition(
+ num_units=dim_head, max_relative_position=temporal_length
+ )
+ else:
+ # only used for spatial attention, while NOT for temporal attention
+ if XFORMERS_IS_AVAILBLE and temporal_length is None:
+ self.forward = self.efficient_forward
+
+ self.video_length = video_length
+ self.image_cross_attention = image_cross_attention
+ self.image_cross_attention_scale = image_cross_attention_scale
+ self.text_context_len = text_context_len
+ self.image_cross_attention_scale_learnable = (
+ image_cross_attention_scale_learnable
+ )
+ if self.image_cross_attention:
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
+ if image_cross_attention_scale_learnable:
+ self.register_parameter("alpha", nn.Parameter(torch.tensor(0.0)))
+
+ def forward(self, x, context=None, mask=None):
+ spatial_self_attn = context is None
+ k_ip, v_ip, out_ip = None, None, None
+
+ h = self.heads
+ q = self.to_q(x)
+ context = default(context, x)
+
+ if self.image_cross_attention and not spatial_self_attn:
+ context, context_image = (
+ context[:, : self.text_context_len, :],
+ context[:, self.text_context_len :, :],
+ )
+ k = self.to_k(context)
+ v = self.to_v(context)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ else:
+ if not spatial_self_attn:
+ context = context[:, : self.text_context_len, :]
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+ if self.relative_position:
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
+ k2 = self.relative_position_k(len_q, len_k)
+ sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
+ sim += sim2
+ del k
+
+ if exists(mask):
+ # feasible for causal attention mask only
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b i j -> (b h) i j", h=h)
+ sim.masked_fill_(~(mask > 0.5), max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = torch.einsum("b i j, b j d -> b i d", sim, v)
+ if self.relative_position:
+ v2 = self.relative_position_v(len_q, len_v)
+ out2 = einsum("b t s, t s d -> b t d", sim, v2)
+ out += out2
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+
+ # for image cross-attention
+ if k_ip is not None:
+ k_ip, v_ip = map(
+ lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k_ip, v_ip)
+ )
+ sim_ip = torch.einsum("b i d, b j d -> b i j", q, k_ip) * self.scale
+ del k_ip
+ sim_ip = sim_ip.softmax(dim=-1)
+ out_ip = torch.einsum("b i j, b j d -> b i d", sim_ip, v_ip)
+ out_ip = rearrange(out_ip, "(b h) n d -> b n (h d)", h=h)
+
+ if out_ip is not None:
+ if self.image_cross_attention_scale_learnable:
+ out = out + self.image_cross_attention_scale * out_ip * (
+ torch.tanh(self.alpha) + 1
+ )
+ else:
+ out = out + self.image_cross_attention_scale * out_ip
+
+ return self.to_out(out)
+
+ def efficient_forward(self, x, context=None, mask=None):
+ spatial_self_attn = context is None
+ k_ip, v_ip, out_ip = None, None, None
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ if self.image_cross_attention and not spatial_self_attn:
+ context, context_image = (
+ context[:, : self.text_context_len, :],
+ context[:, self.text_context_len :, :],
+ )
+ k = self.to_k(context)
+ v = self.to_v(context)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ else:
+ if not spatial_self_attn:
+ context = context[:, : self.text_context_len, :]
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
+
+ # for image cross-attention
+ if k_ip is not None:
+ k_ip, v_ip = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (k_ip, v_ip),
+ )
+ out_ip = xformers.ops.memory_efficient_attention(
+ q, k_ip, v_ip, attn_bias=None, op=None
+ )
+ out_ip = (
+ out_ip.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if out_ip is not None:
+ if self.image_cross_attention_scale_learnable:
+ out = out + self.image_cross_attention_scale * out_ip * (
+ torch.tanh(self.alpha) + 1
+ )
+ else:
+ out = out + self.image_cross_attention_scale * out_ip
+
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attention_cls=None,
+ video_length=None,
+ image_cross_attention=False,
+ image_cross_attention_scale=1.0,
+ image_cross_attention_scale_learnable=False,
+ text_context_len=77,
+ enable_lora=False,
+ ):
+ super().__init__()
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ )
+ self.ff = FeedForward(
+ dim, dropout=dropout, glu=gated_ff, enable_lora=enable_lora
+ )
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ video_length=video_length,
+ image_cross_attention=image_cross_attention,
+ image_cross_attention_scale=image_cross_attention_scale,
+ image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
+ text_context_len=text_context_len,
+ )
+ self.image_cross_attention = image_cross_attention
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ self.enable_lora = enable_lora
+
+ def forward(self, x, context=None, mask=None, with_lora=False, **kwargs):
+ # implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
+ # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
+ input_tuple = (x,)
+ if context is not None:
+ input_tuple = (x, context)
+ if mask is not None:
+ _forward = partial(self._forward, mask=None, with_lora=with_lora)
+ else:
+ _forward = partial(self._forward, mask=mask, with_lora=with_lora)
+ return gradient_checkpoint(
+ _forward, input_tuple, self.parameters(), self.checkpoint
+ )
+
+ def _forward(self, x, context=None, mask=None, with_lora=False):
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ mask=mask,
+ )
+ + x
+ )
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
+ x = self.ff(self.norm3(x), with_lora=with_lora) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data in spatial axis.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ use_checkpoint=True,
+ disable_self_attn=False,
+ use_linear=False,
+ video_length=None,
+ image_cross_attention=False,
+ image_cross_attention_scale_learnable=False,
+ enable_lora=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.enable_lora = enable_lora
+
+ attention_cls = None
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ disable_self_attn=disable_self_attn,
+ checkpoint=use_checkpoint,
+ attention_cls=attention_cls,
+ video_length=video_length,
+ image_cross_attention=image_cross_attention,
+ image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
+ enable_lora=self.enable_lora,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None, with_lora=False, **kwargs):
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context, with_lora=with_lora, **kwargs)
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class TemporalTransformer(nn.Module):
+ """
+ Transformer block for image-like data in temporal axis.
+ First, reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ use_checkpoint=True,
+ use_linear=False,
+ only_self_att=True,
+ causal_attention=False,
+ causal_block_size=1,
+ relative_position=False,
+ temporal_length=None,
+ use_extra_spatial_temporal_self_attention=False,
+ enable_lora=False,
+ full_spatial_temporal_attention=False,
+ enhance_multi_view_correspondence=False,
+ ):
+ super().__init__()
+ self.only_self_att = only_self_att
+ self.relative_position = relative_position
+ self.causal_attention = causal_attention
+ self.causal_block_size = causal_block_size
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+ self.proj_in = nn.Conv1d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ if not use_linear:
+ self.proj_in = nn.Conv1d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ if relative_position:
+ assert temporal_length is not None
+ attention_cls = partial(
+ CrossAttention, relative_position=True, temporal_length=temporal_length
+ )
+ else:
+ attention_cls = partial(CrossAttention, temporal_length=temporal_length)
+ if self.causal_attention:
+ assert temporal_length is not None
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
+
+ if self.only_self_att:
+ context_dim = None
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ attention_cls=attention_cls,
+ checkpoint=use_checkpoint,
+ enable_lora=enable_lora,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ self.use_extra_spatial_temporal_self_attention = (
+ use_extra_spatial_temporal_self_attention
+ )
+ if use_extra_spatial_temporal_self_attention:
+ from core.modules.attention_mv import MultiViewSelfAttentionTransformer
+
+ self.extra_spatial_time_self_attention = MultiViewSelfAttentionTransformer(
+ in_channels=in_channels,
+ n_heads=n_heads,
+ d_head=d_head,
+ num_views=temporal_length,
+ depth=depth,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ full_spatial_temporal_attention=full_spatial_temporal_attention,
+ enhance_multi_view_correspondence=enhance_multi_view_correspondence,
+ )
+
+ def forward(self, x, context=None, with_lora=False, time_steps=None):
+ b, c, t, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = rearrange(x, "b c t h w -> (b h w) c t").contiguous()
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "bhw c t -> bhw t c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+
+ temp_mask = None
+ if self.causal_attention:
+ # slice the from mask map
+ temp_mask = self.mask[:, :t, :t].to(x.device)
+
+ if temp_mask is not None:
+ mask = temp_mask.to(x.device)
+ mask = repeat(mask, "l i j -> (l bhw) i j", bhw=b * h * w)
+ else:
+ mask = None
+
+ if self.only_self_att:
+ # note: if no context is given, cross-attention defaults to self-attention
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, mask=mask, with_lora=with_lora)
+ x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
+ else:
+ x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
+ context = rearrange(context, "(b t) l con -> b t l con", t=t).contiguous()
+ for i, block in enumerate(self.transformer_blocks):
+ # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
+ for j in range(b):
+ context_j = repeat(
+ context[j], "t l con -> (t r) l con", r=(h * w) // t, t=t
+ ).contiguous()
+ # note: causal mask will not applied in cross-attention case
+ x[j] = block(x[j], context=context_j, with_lora=with_lora)
+
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) t c -> b c t h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = rearrange(x, "b hw t c -> (b hw) c t").contiguous()
+ x = self.proj_out(x)
+ x = rearrange(x, "(b h w) c t -> b c t h w", b=b, h=h, w=w).contiguous()
+
+ res = x + x_in
+
+ if self.use_extra_spatial_temporal_self_attention:
+ res = rearrange(res, "b c t h w -> (b t) c h w", b=b, h=h, w=w).contiguous()
+ res = self.extra_spatial_time_self_attention(res, time_steps=time_steps)
+ res = rearrange(res, "(b t) c h w -> b c t h w", b=b, h=h, w=w).contiguous()
+
+ return res
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out=None,
+ mult=4,
+ glu=False,
+ dropout=0.0,
+ enable_lora=False,
+ lora_rank=32,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+ self.enable_lora = enable_lora
+ self.lora_rank = lora_rank
+ self.lora_alpha = 16
+ if self.enable_lora:
+ assert (
+ self.lora_rank is not None
+ ), "`lora_rank` must be given when `enable_lora` is True."
+ assert (
+ 0 < self.lora_rank < min(dim, dim_out)
+ ), f"`lora_rank` must be range [0, min(inner_dim={inner_dim}, dim_out={dim_out})], but got {self.lora_rank}."
+ self.lora_a = nn.Parameter(
+ torch.zeros((inner_dim, self.lora_rank), requires_grad=True)
+ )
+ self.lora_b = nn.Parameter(
+ torch.zeros((self.lora_rank, dim_out), requires_grad=True)
+ )
+ self.scaling = self.lora_alpha / self.lora_rank
+
+ def forward(self, x, with_lora=False):
+ if with_lora:
+ projected_x = self.net[1](self.net[0](x))
+ lora_x = (
+ torch.matmul(projected_x, torch.matmul(self.lora_a, self.lora_b))
+ * self.scaling
+ )
+ original_x = self.net[2](projected_x)
+ return original_x + lora_x
+ else:
+ return self.net(x)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
diff --git a/core/modules/attention_mv.py b/core/modules/attention_mv.py
new file mode 100755
index 0000000000000000000000000000000000000000..9d20230d89f600f78d0ca26ac1567a568fc38cb3
--- /dev/null
+++ b/core/modules/attention_mv.py
@@ -0,0 +1,316 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from core.common import gradient_checkpoint
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")
+
+
+def get_group_norm_layer(in_channels):
+ if in_channels < 32:
+ if in_channels % 2 == 0:
+ num_groups = in_channels // 2
+ else:
+ num_groups = in_channels
+ else:
+ num_groups = 32
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ if dim_out is None:
+ dim_out = dim
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class SpatialTemporalAttention(nn.Module):
+ """Uses xformers to implement efficient epipolar masking for cross-attention between views."""
+
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ if context_dim is None:
+ context_dim = query_dim
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.attention_op = None
+
+ def forward(self, x, context=None, enhance_multi_view_correspondence=False):
+ q = self.to_q(x)
+ if context is None:
+ context = x
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ if enhance_multi_view_correspondence:
+ with torch.no_grad():
+ normalized_x = torch.nn.functional.normalize(x.detach(), p=2, dim=-1)
+ cosine_sim_map = torch.bmm(normalized_x, normalized_x.transpose(-1, -2))
+ attn_bias = torch.where(cosine_sim_map > 0.0, 0.0, -1e9).to(
+ dtype=q.dtype
+ )
+ attn_bias = rearrange(
+ attn_bias.unsqueeze(1).expand(-1, self.heads, -1, -1),
+ "b h d1 d2 -> (b h) d1 d2",
+ ).detach()
+ else:
+ attn_bias = None
+
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=attn_bias, op=self.attention_op
+ )
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ del q, k, v, attn_bias
+ return self.to_out(out)
+
+
+class MultiViewSelfAttentionTransformerBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ gated_ff=True,
+ use_checkpoint=True,
+ full_spatial_temporal_attention=False,
+ enhance_multi_view_correspondence=False,
+ ):
+ super().__init__()
+ attn_cls = SpatialTemporalAttention
+ # self.self_attention_only = self_attention_only
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=None,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+
+ if enhance_multi_view_correspondence:
+ # Zero initalization when MVCorr is enabled.
+ zero_module_fn = zero_module
+ else:
+
+ def zero_module_fn(x):
+ return x
+
+ self.attn2 = zero_module_fn(
+ attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=None,
+ )
+ ) # is self-attn if context is none
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.use_checkpoint = use_checkpoint
+ self.full_spatial_temporal_attention = full_spatial_temporal_attention
+ self.enhance_multi_view_correspondence = enhance_multi_view_correspondence
+
+ def forward(self, x, time_steps=None):
+ return gradient_checkpoint(
+ self.many_stream_forward, (x, time_steps), None, flag=self.use_checkpoint
+ )
+
+ def many_stream_forward(self, x, time_steps=None):
+ n, v, hw = x.shape[:3]
+ x = rearrange(x, "n v hw c -> n (v hw) c")
+ x = (
+ self.attn1(
+ self.norm1(x), context=None, enhance_multi_view_correspondence=False
+ )
+ + x
+ )
+ if not self.full_spatial_temporal_attention:
+ x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
+ x = rearrange(x, "n v hw c -> (n v) hw c")
+ x = (
+ self.attn2(
+ self.norm2(x),
+ context=None,
+ enhance_multi_view_correspondence=self.enhance_multi_view_correspondence
+ and hw <= 256,
+ )
+ + x
+ )
+ x = self.ff(self.norm3(x)) + x
+ if self.full_spatial_temporal_attention:
+ x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
+ else:
+ x = rearrange(x, "(n v) hw c -> n v hw c", v=v)
+ return x
+
+
+class MultiViewSelfAttentionTransformer(nn.Module):
+ """Spatial Transformer block with post init to add cross attn."""
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ num_views,
+ depth=1,
+ dropout=0.0,
+ use_linear=True,
+ use_checkpoint=True,
+ zero_out_initialization=True,
+ full_spatial_temporal_attention=False,
+ enhance_multi_view_correspondence=False,
+ ):
+ super().__init__()
+ self.num_views = num_views
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = get_group_norm_layer(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ MultiViewSelfAttentionTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ use_checkpoint=use_checkpoint,
+ full_spatial_temporal_attention=full_spatial_temporal_attention,
+ enhance_multi_view_correspondence=enhance_multi_view_correspondence,
+ )
+ for d in range(depth)
+ ]
+ )
+ self.zero_out_initialization = zero_out_initialization
+
+ if zero_out_initialization:
+ _zero_func = zero_module
+ else:
+
+ def _zero_func(x):
+ return x
+
+ if not use_linear:
+ self.proj_out = _zero_func(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ self.proj_out = _zero_func(nn.Linear(inner_dim, in_channels))
+
+ self.use_linear = use_linear
+
+ def forward(self, x, time_steps=None):
+ # x: bt c h w
+ _, c, h, w = x.shape
+ n_views = self.num_views
+ x_in = x
+ x = self.norm(x)
+ x = rearrange(x, "(n v) c h w -> n v (h w) c", v=n_views)
+
+ if self.use_linear:
+ x = rearrange(x, "n v x c -> (n v) x c")
+ x = self.proj_in(x)
+ x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, time_steps=time_steps)
+ if self.use_linear:
+ x = rearrange(x, "n v x c -> (n v) x c")
+ x = self.proj_out(x)
+ x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
+
+ x = rearrange(x, "n v (h w) c -> (n v) c h w", h=h, w=w).contiguous()
+
+ return x + x_in
diff --git a/core/modules/attention_temporal.py b/core/modules/attention_temporal.py
new file mode 100755
index 0000000000000000000000000000000000000000..690c77a7795f0cfd3fa59ee3db7d15d7b62c0840
--- /dev/null
+++ b/core/modules/attention_temporal.py
@@ -0,0 +1,1111 @@
+import math
+
+import torch
+import torch as th
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn, einsum
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+from core.common import gradient_checkpoint, exists, default
+from core.basics import conv_nd, zero_module, normalization
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class RelativePosition(nn.Module):
+
+ def __init__(self, num_units, max_relative_position):
+ super().__init__()
+ self.num_units = num_units
+ self.max_relative_position = max_relative_position
+ self.embeddings_table = nn.Parameter(
+ th.Tensor(max_relative_position * 2 + 1, num_units)
+ )
+ nn.init.xavier_uniform_(self.embeddings_table)
+
+ def forward(self, length_q, length_k):
+ device = self.embeddings_table.device
+ range_vec_q = th.arange(length_q, device=device)
+ range_vec_k = th.arange(length_k, device=device)
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
+ distance_mat_clipped = th.clamp(
+ distance_mat, -self.max_relative_position, self.max_relative_position
+ )
+ final_mat = distance_mat_clipped + self.max_relative_position
+ final_mat = final_mat.long()
+ embeddings = self.embeddings_table[final_mat]
+ return embeddings
+
+
+class TemporalCrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ # For relative positional representation and image-video joint training.
+ temporal_length=None,
+ image_length=None, # For image-video joint training.
+ # whether use relative positional representation in temporal attention.
+ use_relative_position=False,
+ # For image-video joint training.
+ img_video_joint_train=False,
+ use_tempoal_causal_attn=False,
+ bidirectional_causal_attn=False,
+ tempoal_attn_type=None,
+ joint_train_mode="same_batch",
+ **kwargs,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+ self.context_dim = context_dim
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ self.temporal_length = temporal_length
+ self.use_relative_position = use_relative_position
+ self.img_video_joint_train = img_video_joint_train
+ self.bidirectional_causal_attn = bidirectional_causal_attn
+ self.joint_train_mode = joint_train_mode
+ assert joint_train_mode in ["same_batch", "diff_batch"]
+ self.tempoal_attn_type = tempoal_attn_type
+
+ if bidirectional_causal_attn:
+ assert use_tempoal_causal_attn
+ if tempoal_attn_type:
+ assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
+ assert not use_tempoal_causal_attn
+ assert not (
+ img_video_joint_train and (self.joint_train_mode == "same_batch")
+ )
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ assert not (
+ img_video_joint_train
+ and (self.joint_train_mode == "same_batch")
+ and use_tempoal_causal_attn
+ )
+ if img_video_joint_train:
+ if self.joint_train_mode == "same_batch":
+ mask = torch.ones(
+ [1, temporal_length + image_length, temporal_length + image_length]
+ )
+ mask[:, temporal_length:, :] = 0
+ mask[:, :, temporal_length:] = 0
+ self.mask = mask
+ else:
+ self.mask = None
+ elif use_tempoal_causal_attn:
+ # normal causal attn
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
+ elif tempoal_attn_type == "sparse_causal":
+ # true indicates keeping
+ mask1 = torch.tril(torch.ones([1, temporal_length, temporal_length])).bool()
+ # initialize to same shape with mask1
+ mask2 = torch.zeros([1, temporal_length, temporal_length])
+ mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
+ )
+ mask2 = (1 - mask2).bool() # false indicates masking
+ self.mask = mask1 & mask2
+ elif tempoal_attn_type == "sparse_causal_first":
+ # true indicates keeping
+ mask1 = torch.tril(torch.ones([1, temporal_length, temporal_length])).bool()
+ mask2 = torch.zeros([1, temporal_length, temporal_length])
+ mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
+ )
+ mask2 = (1 - mask2).bool() # false indicates masking
+ self.mask = mask1 & mask2
+ else:
+ self.mask = None
+
+ if use_relative_position:
+ assert temporal_length is not None
+ self.relative_position_k = RelativePosition(
+ num_units=dim_head, max_relative_position=temporal_length
+ )
+ self.relative_position_v = RelativePosition(
+ num_units=dim_head, max_relative_position=temporal_length
+ )
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+
+ nn.init.constant_(self.to_q.weight, 0)
+ nn.init.constant_(self.to_k.weight, 0)
+ nn.init.constant_(self.to_v.weight, 0)
+ nn.init.constant_(self.to_out[0].weight, 0)
+ nn.init.constant_(self.to_out[0].bias, 0)
+
+ def forward(self, x, context=None, mask=None):
+ nh = self.heads
+ out = x
+ q = self.to_q(out)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if self.use_relative_position:
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
+ k2 = self.relative_position_k(len_q, len_k)
+ sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
+ sim += sim2
+ if exists(self.mask):
+ if mask is None:
+ mask = self.mask.to(sim.device)
+ else:
+ # .to(sim.device)
+ mask = self.mask.to(sim.device).bool() & mask
+ else:
+ mask = mask
+ if mask is not None:
+ max_neg_value = -1e9
+ sim = sim + (1 - mask.float()) * max_neg_value # 1=masking,0=no masking
+
+ attn = sim.softmax(dim=-1)
+
+ out = einsum("b i j, b j d -> b i d", attn, v)
+
+ if self.bidirectional_causal_attn:
+ mask_reverse = torch.triu(
+ torch.ones(
+ [1, self.temporal_length, self.temporal_length], device=sim.device
+ )
+ )
+ sim_reverse = sim.float().masked_fill(mask_reverse == 0, max_neg_value)
+ attn_reverse = sim_reverse.softmax(dim=-1)
+ out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
+ out += out_reverse
+
+ if self.use_relative_position:
+ v2 = self.relative_position_v(len_q, len_v)
+ out2 = einsum("b t s, t s d -> b t d", attn, v2)
+ out += out2
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=nh)
+ return self.to_out(out)
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ sa_shared_kv=False,
+ shared_type="only_first",
+ **kwargs,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+ self.sa_shared_kv = sa_shared_kv
+ assert shared_type in [
+ "only_first",
+ "all_frames",
+ "first_and_prev",
+ "only_prev",
+ "full",
+ "causal",
+ "full_qkv",
+ ]
+ self.shared_type = shared_type
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ if XFORMERS_IS_AVAILBLE:
+ self.forward = self.efficient_forward
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+ b = x.shape[0]
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ if self.sa_shared_kv:
+ if self.shared_type == "only_first":
+ k, v = map(
+ lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
+ .unsqueeze(0)
+ .repeat(b, 1, 1),
+ (k, v),
+ )
+ else:
+ raise NotImplementedError
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+ return self.to_out(out)
+
+ def efficient_forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class VideoSpatialCrossAttention(CrossAttention):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
+ super().__init__(query_dim, context_dim, heads, dim_head, dropout)
+
+ def forward(self, x, context=None, mask=None):
+ b, c, t, h, w = x.shape
+ if context is not None:
+ context = context.repeat(t, 1, 1)
+ x = super.forward(spatial_attn_reshape(x), context=context) + x
+ return spatial_attn_reshape_back(x, b, h)
+
+
+class BasicTransformerBlockST(nn.Module):
+ def __init__(
+ self,
+ # Spatial Stuff
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ # Temporal Stuff
+ temporal_length=None,
+ image_length=None,
+ use_relative_position=True,
+ img_video_joint_train=False,
+ cross_attn_on_tempoal=False,
+ temporal_crossattn_type="selfattn",
+ order="stst",
+ temporalcrossfirst=False,
+ temporal_context_dim=None,
+ split_stcontext=False,
+ local_spatial_temporal_attn=False,
+ window_size=2,
+ **kwargs,
+ ):
+ super().__init__()
+ # Self attention
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ **kwargs,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ # cross attention if context is not None
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ **kwargs,
+ )
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+ self.order = order
+ assert self.order in ["stst", "sstt", "st_parallel"]
+ self.temporalcrossfirst = temporalcrossfirst
+ self.split_stcontext = split_stcontext
+ self.local_spatial_temporal_attn = local_spatial_temporal_attn
+ if self.local_spatial_temporal_attn:
+ assert self.order == "stst"
+ assert self.order == "stst"
+ self.window_size = window_size
+ if not split_stcontext:
+ temporal_context_dim = context_dim
+ # Temporal attention
+ assert temporal_crossattn_type in ["selfattn", "crossattn", "skip"]
+ self.temporal_crossattn_type = temporal_crossattn_type
+ self.attn1_tmp = TemporalCrossAttention(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ temporal_length=temporal_length,
+ image_length=image_length,
+ use_relative_position=use_relative_position,
+ img_video_joint_train=img_video_joint_train,
+ **kwargs,
+ )
+ self.attn2_tmp = TemporalCrossAttention(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ # cross attn
+ context_dim=(
+ temporal_context_dim if temporal_crossattn_type == "crossattn" else None
+ ),
+ # temporal attn
+ temporal_length=temporal_length,
+ image_length=image_length,
+ use_relative_position=use_relative_position,
+ img_video_joint_train=img_video_joint_train,
+ **kwargs,
+ )
+ self.norm4 = nn.LayerNorm(dim)
+ self.norm5 = nn.LayerNorm(dim)
+
+ def forward(
+ self,
+ x,
+ context=None,
+ temporal_context=None,
+ no_temporal_attn=None,
+ attn_mask=None,
+ **kwargs,
+ ):
+ if not self.split_stcontext:
+ # st cross attention use the same context vector
+ temporal_context = context.detach().clone()
+
+ if context is None and temporal_context is None:
+ # self-attention models
+ if no_temporal_attn:
+ raise NotImplementedError
+ return gradient_checkpoint(
+ self._forward_nocontext, (x), self.parameters(), self.checkpoint
+ )
+ else:
+ # cross-attention models
+ if no_temporal_attn:
+ forward_func = self._forward_no_temporal_attn
+ else:
+ forward_func = self._forward
+ inputs = (
+ (x, context, temporal_context)
+ if temporal_context is not None
+ else (x, context)
+ )
+ return gradient_checkpoint(
+ forward_func, inputs, self.parameters(), self.checkpoint
+ )
+
+ def _forward(
+ self,
+ x,
+ context=None,
+ temporal_context=None,
+ mask=None,
+ no_temporal_attn=None,
+ ):
+ assert x.dim() == 5, f"x shape = {x.shape}"
+ b, c, t, h, w = x.shape
+
+ if self.order in ["stst", "sstt"]:
+ x = self._st_cross_attn(
+ x,
+ context,
+ temporal_context=temporal_context,
+ order=self.order,
+ mask=mask,
+ ) # no_temporal_attn=no_temporal_attn,
+ elif self.order == "st_parallel":
+ x = self._st_cross_attn_parallel(
+ x,
+ context,
+ temporal_context=temporal_context,
+ order=self.order,
+ ) # no_temporal_attn=no_temporal_attn,
+ else:
+ raise NotImplementedError
+
+ x = self.ff(self.norm3(x)) + x
+ if (no_temporal_attn is None) or (not no_temporal_attn):
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
+ elif no_temporal_attn:
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
+ return x
+
+ def _forward_no_temporal_attn(
+ self,
+ x,
+ context=None,
+ temporal_context=None,
+ ):
+ assert x.dim() == 5, f"x shape = {x.shape}"
+ b, c, t, h, w = x.shape
+
+ if self.order in ["stst", "sstt"]:
+ mask = torch.zeros([1, t, t], device=x.device).bool()
+ x = self._st_cross_attn(
+ x,
+ context,
+ temporal_context=temporal_context,
+ order=self.order,
+ mask=mask,
+ )
+ elif self.order == "st_parallel":
+ x = self._st_cross_attn_parallel(
+ x,
+ context,
+ temporal_context=temporal_context,
+ order=self.order,
+ no_temporal_attn=True,
+ )
+ else:
+ raise NotImplementedError
+
+ x = self.ff(self.norm3(x)) + x
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
+ return x
+
+ def _forward_nocontext(self, x, no_temporal_attn=None):
+ assert x.dim() == 5, f"x shape = {x.shape}"
+ b, c, t, h, w = x.shape
+
+ if self.order in ["stst", "sstt"]:
+ x = self._st_cross_attn(
+ x, order=self.order, no_temporal_attn=no_temporal_attn
+ )
+ elif self.order == "st_parallel":
+ x = self._st_cross_attn_parallel(
+ x, order=self.order, no_temporal_attn=no_temporal_attn
+ )
+ else:
+ raise NotImplementedError
+
+ x = self.ff(self.norm3(x)) + x
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
+
+ return x
+
+ def _st_cross_attn(
+ self, x, context=None, temporal_context=None, order="stst", mask=None
+ ):
+ b, c, t, h, w = x.shape
+ if order == "stst":
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
+ x = self.attn1(self.norm1(x)) + x
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
+ if self.local_spatial_temporal_attn:
+ x = local_spatial_temporal_attn_reshape(x, window_size=self.window_size)
+ else:
+ x = rearrange(x, "b c t h w -> (b h w) t c")
+ x = self.attn1_tmp(self.norm4(x), mask=mask) + x
+
+ if self.local_spatial_temporal_attn:
+ x = local_spatial_temporal_attn_reshape_back(
+ x, window_size=self.window_size, b=b, h=h, w=w, t=t
+ )
+ else:
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
+
+ # spatial cross attention
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
+ if context is not None:
+ if context.shape[0] == t: # img captions no_temporal_attn or
+ context_ = context
+ else:
+ context_ = []
+ for i in range(context.shape[0]):
+ context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
+ context_ = torch.cat(context_, dim=0)
+ else:
+ context_ = None
+ x = self.attn2(self.norm2(x), context=context_) + x
+
+ # temporal cross attention
+ # if (no_temporal_attn is None) or (not no_temporal_attn):
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
+ x = rearrange(x, "b c t h w -> (b h w) t c")
+ if self.temporal_crossattn_type == "crossattn":
+ # tmporal cross attention
+ if temporal_context is not None:
+ # print(f'STATTN context={context.shape}, temporal_context={temporal_context.shape}')
+ temporal_context = torch.cat(
+ [context, temporal_context], dim=1
+ ) # blc
+ # print(f'STATTN after concat temporal_context={temporal_context.shape}')
+ temporal_context = temporal_context.repeat(h * w, 1, 1)
+ # print(f'after repeat temporal_context={temporal_context.shape}')
+ else:
+ temporal_context = context[0:1, ...].repeat(h * w, 1, 1)
+ # print(f'STATTN after concat x={x.shape}')
+ x = (
+ self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask)
+ + x
+ )
+ elif self.temporal_crossattn_type == "selfattn":
+ # temporal self attention
+ x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
+ elif self.temporal_crossattn_type == "skip":
+ # no temporal cross and self attention
+ pass
+ else:
+ raise NotImplementedError
+
+ elif order == "sstt":
+ # spatial self attention
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
+ x = self.attn1(self.norm1(x)) + x
+
+ # spatial cross attention
+ context_ = context.repeat(t, 1, 1) if context is not None else None
+ x = self.attn2(self.norm2(x), context=context_) + x
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
+
+ if (no_temporal_attn is None) or (not no_temporal_attn):
+ if self.temporalcrossfirst:
+ # temporal cross attention
+ if self.temporal_crossattn_type == "crossattn":
+ # if temporal_context is not None:
+ temporal_context = context.repeat(h * w, 1, 1)
+ x = (
+ self.attn2_tmp(
+ self.norm5(x), context=temporal_context, mask=mask
+ )
+ + x
+ )
+ elif self.temporal_crossattn_type == "selfattn":
+ x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
+ elif self.temporal_crossattn_type == "skip":
+ pass
+ else:
+ raise NotImplementedError
+ # temporal self attention
+ x = rearrange(x, "b c t h w -> (b h w) t c")
+ x = self.attn1_tmp(self.norm4(x), mask=mask) + x
+ else:
+ # temporal self attention
+ x = rearrange(x, "b c t h w -> (b h w) t c")
+ x = self.attn1_tmp(self.norm4(x), mask=mask) + x
+ # temporal cross attention
+ if self.temporal_crossattn_type == "crossattn":
+ if temporal_context is not None:
+ temporal_context = context.repeat(h * w, 1, 1)
+ x = (
+ self.attn2_tmp(
+ self.norm5(x), context=temporal_context, mask=mask
+ )
+ + x
+ )
+ elif self.temporal_crossattn_type == "selfattn":
+ x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
+ elif self.temporal_crossattn_type == "skip":
+ pass
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ return x
+
+ def _st_cross_attn_parallel(
+ self, x, context=None, temporal_context=None, order="sst", no_temporal_attn=None
+ ):
+ """order: x -> Self Attn -> Cross Attn -> attn_s
+ x -> Temp Self Attn -> attn_t
+ x' = x + attn_s + attn_t
+ """
+ if no_temporal_attn is not None:
+ raise NotImplementedError
+
+ B, C, T, H, W = x.shape
+ # spatial self attention
+ h = x
+ h = rearrange(h, "b c t h w -> (b t) (h w) c")
+ h = self.attn1(self.norm1(h)) + h
+ # spatial cross
+ # context_ = context.repeat(T, 1, 1) if context is not None else None
+ if context is not None:
+ context_ = []
+ for i in range(context.shape[0]):
+ context_.append(context[i].unsqueeze(0).repeat(T, 1, 1))
+ context_ = torch.cat(context_, dim=0)
+ else:
+ context_ = None
+
+ h = self.attn2(self.norm2(h), context=context_) + h
+ h = rearrange(h, "(b t) (h w) c -> b c t h w", b=B, h=H)
+
+ # temporal self
+ h2 = x
+ h2 = rearrange(h2, "b c t h w -> (b h w) t c")
+ h2 = self.attn1_tmp(self.norm4(h2)) # + h2
+ h2 = rearrange(h2, "(b h w) t c -> b c t h w", b=B, h=H, w=W)
+ out = h + h2
+ return rearrange(out, "b c t h w -> (b h w) t c")
+
+
+def spatial_attn_reshape(x):
+ return rearrange(x, "b c t h w -> (b t) (h w) c")
+
+
+def spatial_attn_reshape_back(x, b, h):
+ return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
+
+
+def temporal_attn_reshape(x):
+ return rearrange(x, "b c t h w -> (b h w) t c")
+
+
+def temporal_attn_reshape_back(x, b, h, w):
+ return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
+
+
+def local_spatial_temporal_attn_reshape(x, window_size):
+ B, C, T, H, W = x.shape
+ NH = H // window_size
+ NW = W // window_size
+ # x = x.view(B, C, T, NH, window_size, NW, window_size)
+ # tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
+ # tokens = tokens.view(-1, window_size, window_size, C)
+ x = rearrange(
+ x,
+ "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
+ nh=NH,
+ nw=NW,
+ wh=window_size,
+ # # B, C, T, NH, NW, window_size, window_size
+ ww=window_size,
+ ).contiguous()
+ # (B, NH, NW) (T, window_size, window_size) C
+ x = rearrange(x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c")
+ return x
+
+
+def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
+ B, L, C = x.shape
+ NH = h // window_size
+ NW = w // window_size
+ x = rearrange(
+ x,
+ "(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
+ b=b,
+ nh=NH,
+ nw=NW,
+ t=t,
+ wh=window_size,
+ ww=window_size,
+ )
+ x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
+ return x
+
+
+class SpatialTemporalTransformer(nn.Module):
+ """
+ Transformer block for video-like data (5D tensor).
+ First, project the input (aka embedding) with NO reshape.
+ Then apply standard transformer action.
+ The 5D -> 3D reshape operation will be done in the specific attention module.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ # Temporal stuff
+ temporal_length=None,
+ image_length=None,
+ use_relative_position=True,
+ img_video_joint_train=False,
+ cross_attn_on_tempoal=False,
+ temporal_crossattn_type=False,
+ order="stst",
+ temporalcrossfirst=False,
+ split_stcontext=False,
+ temporal_context_dim=None,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+
+ self.norm = Normalize(in_channels)
+ self.proj_in = nn.Conv3d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlockST(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ # cross attn
+ context_dim=context_dim,
+ # temporal attn
+ temporal_length=temporal_length,
+ image_length=image_length,
+ use_relative_position=use_relative_position,
+ img_video_joint_train=img_video_joint_train,
+ temporal_crossattn_type=temporal_crossattn_type,
+ order=order,
+ temporalcrossfirst=temporalcrossfirst,
+ split_stcontext=split_stcontext,
+ temporal_context_dim=temporal_context_dim,
+ **kwargs,
+ )
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = zero_module(
+ nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+
+ def forward(self, x, context=None, temporal_context=None, **kwargs):
+ # note: if no context is given, cross-attention defaults to self-attention
+ assert x.dim() == 5, f"x shape = {x.shape}"
+ b, c, t, h, w = x.shape
+ x_in = x
+
+ x = self.norm(x)
+ x = self.proj_in(x)
+
+ for block in self.transformer_blocks:
+ x = block(x, context=context, temporal_context=temporal_context, **kwargs)
+
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class STAttentionBlock2(nn.Module):
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False, # not used, only used in ResBlock
+ use_new_attention_order=False, # QKVAttention or QKVAttentionLegacy
+ temporal_length=16, # used in relative positional representation.
+ image_length=8, # used for image-video joint training.
+ # whether use relative positional representation in temporal attention.
+ use_relative_position=False,
+ img_video_joint_train=False,
+ # norm_type="groupnorm",
+ attn_norm_type="group",
+ use_tempoal_causal_attn=False,
+ ):
+ """
+ version 1: guided_diffusion implemented version
+ version 2: remove args input argument
+ """
+ super().__init__()
+
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+
+ self.temporal_length = temporal_length
+ self.image_length = image_length
+ self.use_relative_position = use_relative_position
+ self.img_video_joint_train = img_video_joint_train
+ self.attn_norm_type = attn_norm_type
+ assert self.attn_norm_type in ["group", "no_norm"]
+ self.use_tempoal_causal_attn = use_tempoal_causal_attn
+
+ if self.attn_norm_type == "group":
+ self.norm_s = normalization(channels)
+ self.norm_t = normalization(channels)
+
+ self.qkv_s = conv_nd(1, channels, channels * 3, 1)
+ self.qkv_t = conv_nd(1, channels, channels * 3, 1)
+
+ if self.img_video_joint_train:
+ mask = th.ones(
+ [1, temporal_length + image_length, temporal_length + image_length]
+ )
+ mask[:, temporal_length:, :] = 0
+ mask[:, :, temporal_length:] = 0
+ self.register_buffer("mask", mask)
+ else:
+ self.mask = None
+
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention_s = QKVAttention(self.num_heads)
+ self.attention_t = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention_s = QKVAttentionLegacy(self.num_heads)
+ self.attention_t = QKVAttentionLegacy(self.num_heads)
+
+ if use_relative_position:
+ self.relative_position_k = RelativePosition(
+ num_units=channels // self.num_heads,
+ max_relative_position=temporal_length,
+ )
+ self.relative_position_v = RelativePosition(
+ num_units=channels // self.num_heads,
+ max_relative_position=temporal_length,
+ )
+
+ self.proj_out_s = zero_module(
+ # conv_dim, in_channels, out_channels, kernel_size
+ conv_nd(1, channels, channels, 1)
+ )
+ self.proj_out_t = zero_module(
+ # conv_dim, in_channels, out_channels, kernel_size
+ conv_nd(1, channels, channels, 1)
+ )
+
+ def forward(self, x, mask=None):
+ b, c, t, h, w = x.shape
+
+ # spatial
+ out = rearrange(x, "b c t h w -> (b t) c (h w)")
+ if self.attn_norm_type == "no_norm":
+ qkv = self.qkv_s(out)
+ else:
+ qkv = self.qkv_s(self.norm_s(out))
+ out = self.attention_s(qkv)
+ out = self.proj_out_s(out)
+ out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
+ x += out
+
+ # temporal
+ out = rearrange(x, "b c t h w -> (b h w) c t")
+ if self.attn_norm_type == "no_norm":
+ qkv = self.qkv_t(out)
+ else:
+ qkv = self.qkv_t(self.norm_t(out))
+
+ # relative positional embedding
+ if self.use_relative_position:
+ len_q = qkv.size()[-1]
+ len_k, len_v = len_q, len_q
+ k_rp = self.relative_position_k(len_q, len_k)
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
+ out = self.attention_t(
+ qkv,
+ rp=(k_rp, v_rp),
+ mask=self.mask,
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
+ )
+ else:
+ out = self.attention_t(
+ qkv,
+ rp=None,
+ mask=self.mask,
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
+ )
+
+ out = self.proj_out_t(out)
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
+
+ return x + out
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv, rp=None, mask=None):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ if rp is not None or mask is not None:
+ raise NotImplementedError
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ # print('qkv', qkv.size())
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ # print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
+
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ # weight:[b,t,s] b=bs*n_heads*T
+
+ if rp is not None:
+ k_rp, v_rp = rp # [length, length, head_dim] [8, 8, 48]
+ weight2 = th.einsum(
+ "bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp
+ )
+ weight += weight2
+
+ if use_tempoal_causal_attn:
+ # weight = torch.tril(weight)
+ assert mask is None, f"Not implemented for merging two masks!"
+ mask = torch.tril(torch.ones(weight.shape))
+ else:
+ if mask is not None: # only keep upper-left matrix
+ # process mask
+ c, t, _ = weight.shape
+
+ if mask.shape[-1] > t:
+ mask = mask[:, :t, :t]
+ elif mask.shape[-1] < t: # pad ones
+ mask_ = th.zeros([c, t, t]).to(mask.device)
+ t_ = mask.shape[-1]
+ mask_[:, :t_, :t_] = mask
+ mask = mask_
+ else:
+ assert (
+ weight.shape[-1] == mask.shape[-1]
+ ), f"weight={weight.shape}, mask={mask.shape}"
+
+ if mask is not None:
+ INF = -1e8 # float('-inf')
+ weight = weight.float().masked_fill(mask == 0, INF)
+
+ weight = F.softmax(weight.float(), dim=-1).type(
+ weight.dtype
+ ) # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
+ # weight = F.softmax(weight, dim=-1)#[256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
+ # [256, 48, 8] [b, head_dim, t]
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+
+ if rp is not None:
+ a2 = th.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) # btc->bct
+ a += a2
+
+ return a.reshape(bs, -1, length)
diff --git a/core/modules/encoders/__init__.py b/core/modules/encoders/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/core/modules/encoders/adapter.py b/core/modules/encoders/adapter.py
new file mode 100755
index 0000000000000000000000000000000000000000..37ec694254bb44b9efc0cc10c55a9986bc31def4
--- /dev/null
+++ b/core/modules/encoders/adapter.py
@@ -0,0 +1,485 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+from extralibs.cond_api import ExtraCondition
+from core.modules.x_transformer import FixedPositionalEmbedding
+from core.basics import zero_module, conv_nd, avg_pool_nd
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
+ super().__init__()
+ ps = ksize // 2
+ if in_c != out_c or sk == False:
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+ else:
+ self.in_conv = None
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
+ if sk == False:
+ self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+ else:
+ self.skep = None
+
+ self.down = down
+ if self.down == True:
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
+
+ def forward(self, x):
+ if self.down == True:
+ x = self.down_opt(x)
+ if self.in_conv is not None:
+ x = self.in_conv(x)
+
+ h = self.block1(x)
+ h = self.act(h)
+ h = self.block2(h)
+ if self.skep is not None:
+ return h + self.skep(x)
+ else:
+ return h + x
+
+
+class Adapter(nn.Module):
+ def __init__(
+ self,
+ channels=[320, 640, 1280, 1280],
+ nums_rb=3,
+ cin=64,
+ ksize=3,
+ sk=True,
+ use_conv=True,
+ stage_downscale=True,
+ use_identity=False,
+ ):
+ super(Adapter, self).__init__()
+ if use_identity:
+ self.inlayer = nn.Identity()
+ else:
+ self.inlayer = nn.PixelUnshuffle(8)
+
+ self.channels = channels
+ self.nums_rb = nums_rb
+ self.body = []
+ for i in range(len(channels)):
+ for j in range(nums_rb):
+ if (i != 0) and (j == 0):
+ self.body.append(
+ ResnetBlock(
+ channels[i - 1],
+ channels[i],
+ down=stage_downscale,
+ ksize=ksize,
+ sk=sk,
+ use_conv=use_conv,
+ )
+ )
+ else:
+ self.body.append(
+ ResnetBlock(
+ channels[i],
+ channels[i],
+ down=False,
+ ksize=ksize,
+ sk=sk,
+ use_conv=use_conv,
+ )
+ )
+ self.body = nn.ModuleList(self.body)
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
+
+ def forward(self, x):
+ # unshuffle
+ x = self.inlayer(x)
+ # extract features
+ features = []
+ x = self.conv_in(x)
+ for i in range(len(self.channels)):
+ for j in range(self.nums_rb):
+ idx = i * self.nums_rb + j
+ x = self.body[idx](x)
+ features.append(x)
+
+ return features
+
+
+class PositionNet(nn.Module):
+ def __init__(self, input_size=(40, 64), cin=4, dim=512, out_dim=1024):
+ super().__init__()
+ self.input_size = input_size
+ self.out_dim = out_dim
+ self.down_factor = 8 # determined by the convnext backbone
+ feature_dim = dim
+ self.backbone = Adapter(
+ channels=[64, 128, 256, feature_dim],
+ nums_rb=2,
+ cin=cin,
+ stage_downscale=True,
+ use_identity=True,
+ )
+ self.num_tokens = (self.input_size[0] // self.down_factor) * (
+ self.input_size[1] // self.down_factor
+ )
+
+ self.pos_embedding = nn.Parameter(
+ torch.empty(1, self.num_tokens, feature_dim).normal_(std=0.02)
+ ) # from BERT
+
+ self.linears = nn.Sequential(
+ nn.Linear(feature_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ # self.null_feature = torch.nn.Parameter(torch.zeros([feature_dim]))
+
+ def forward(self, x, mask=None):
+ B = x.shape[0]
+ # token from edge map
+ # x = torch.nn.functional.interpolate(x, self.input_size)
+ feature = self.backbone(x)[-1]
+ objs = feature.reshape(B, -1, self.num_tokens)
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
+ """
+ # expand null token
+ null_objs = self.null_feature.view(1,1,-1)
+ null_objs = null_objs.repeat(B,self.num_tokens,1)
+
+ # mask replacing
+ mask = mask.view(-1,1,1)
+ objs = objs*mask + null_objs*(1-mask)
+ """
+ # add pos
+ objs = objs + self.pos_embedding
+
+ # fuse them
+ objs = self.linears(objs)
+
+ assert objs.shape == torch.Size([B, self.num_tokens, self.out_dim])
+ return objs
+
+
+class PositionNet2(nn.Module):
+ def __init__(self, input_size=(40, 64), cin=4, dim=320, out_dim=1024):
+ super().__init__()
+ self.input_size = input_size
+ self.out_dim = out_dim
+ self.down_factor = 8 # determined by the convnext backbone
+ self.dim = dim
+ self.backbone = Adapter(
+ channels=[dim, dim, dim, dim],
+ nums_rb=2,
+ cin=cin,
+ stage_downscale=True,
+ use_identity=True,
+ )
+ self.pos_embedding = FixedPositionalEmbedding(dim=self.dim)
+ self.linears = nn.Sequential(
+ nn.Linear(dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ def forward(self, x, mask=None):
+ B = x.shape[0]
+ features = self.backbone(x)
+ token_lists = []
+ for feature in features:
+ objs = feature.reshape(B, self.dim, -1)
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
+ # add pos
+ objs = objs + self.pos_embedding(objs)
+ # fuse them
+ objs = self.linears(objs)
+ token_lists.append(objs)
+
+ return token_lists
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict(
+ [
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
+ ]
+ )
+ )
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = (
+ self.attn_mask.to(dtype=x.dtype, device=x.device)
+ if self.attn_mask is not None
+ else None
+ )
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class StyleAdapter(nn.Module):
+
+ def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
+ super().__init__()
+
+ scale = width**-0.5
+ self.transformer_layes = nn.Sequential(
+ *[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)]
+ )
+ self.num_token = num_token
+ self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
+ self.ln_post = LayerNorm(width)
+ self.ln_pre = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
+
+ def forward(self, x):
+ # x shape [N, HW+1, C]
+ style_embedding = self.style_embedding + torch.zeros(
+ (x.shape[0], self.num_token, self.style_embedding.shape[-1]),
+ device=x.device,
+ )
+ x = torch.cat([x, style_embedding], dim=1)
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer_layes(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, -self.num_token :, :])
+ x = x @ self.proj
+
+ return x
+
+
+class ResnetBlock_light(nn.Module):
+ def __init__(self, in_c):
+ super().__init__()
+ self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+
+ def forward(self, x):
+ h = self.block1(x)
+ h = self.act(h)
+ h = self.block2(h)
+
+ return h + x
+
+
+class extractor(nn.Module):
+ def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
+ super().__init__()
+ self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
+ self.body = []
+ for _ in range(nums_rb):
+ self.body.append(ResnetBlock_light(inter_c))
+ self.body = nn.Sequential(*self.body)
+ self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
+ self.down = down
+ if self.down == True:
+ self.down_opt = Downsample(in_c, use_conv=False)
+
+ def forward(self, x):
+ if self.down == True:
+ x = self.down_opt(x)
+ x = self.in_conv(x)
+ x = self.body(x)
+ x = self.out_conv(x)
+
+ return x
+
+
+class Adapter_light(nn.Module):
+ def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
+ super(Adapter_light, self).__init__()
+ self.unshuffle = nn.PixelUnshuffle(8)
+ self.channels = channels
+ self.nums_rb = nums_rb
+ self.body = []
+ for i in range(len(channels)):
+ if i == 0:
+ self.body.append(
+ extractor(
+ in_c=cin,
+ inter_c=channels[i] // 4,
+ out_c=channels[i],
+ nums_rb=nums_rb,
+ down=False,
+ )
+ )
+ else:
+ self.body.append(
+ extractor(
+ in_c=channels[i - 1],
+ inter_c=channels[i] // 4,
+ out_c=channels[i],
+ nums_rb=nums_rb,
+ down=True,
+ )
+ )
+ self.body = nn.ModuleList(self.body)
+
+ def forward(self, x):
+ # unshuffle
+ x = self.unshuffle(x)
+ # extract features
+ features = []
+ for i in range(len(self.channels)):
+ x = self.body[i](x)
+ features.append(x)
+
+ return features
+
+
+class CoAdapterFuser(nn.Module):
+ def __init__(
+ self, unet_channels=[320, 640, 1280, 1280], width=768, num_head=8, n_layes=3
+ ):
+ super(CoAdapterFuser, self).__init__()
+ scale = width**0.5
+ self.task_embedding = nn.Parameter(scale * torch.randn(16, width))
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn(len(unet_channels), width)
+ )
+ self.spatial_feat_mapping = nn.ModuleList()
+ for ch in unet_channels:
+ self.spatial_feat_mapping.append(
+ nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(ch, width),
+ )
+ )
+ self.transformer_layes = nn.Sequential(
+ *[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)]
+ )
+ self.ln_post = LayerNorm(width)
+ self.ln_pre = LayerNorm(width)
+ self.spatial_ch_projs = nn.ModuleList()
+ for ch in unet_channels:
+ self.spatial_ch_projs.append(zero_module(nn.Linear(width, ch)))
+ self.seq_proj = nn.Parameter(torch.zeros(width, width))
+
+ def forward(self, features):
+ if len(features) == 0:
+ return None, None
+ inputs = []
+ for cond_name in features.keys():
+ task_idx = getattr(ExtraCondition, cond_name).value
+ if not isinstance(features[cond_name], list):
+ inputs.append(features[cond_name] + self.task_embedding[task_idx])
+ continue
+
+ feat_seq = []
+ for idx, feature_map in enumerate(features[cond_name]):
+ feature_vec = torch.mean(feature_map, dim=(2, 3))
+ feature_vec = self.spatial_feat_mapping[idx](feature_vec)
+ feat_seq.append(feature_vec)
+ feat_seq = torch.stack(feat_seq, dim=1) # Nx4xC
+ feat_seq = feat_seq + self.task_embedding[task_idx]
+ feat_seq = feat_seq + self.positional_embedding
+ inputs.append(feat_seq)
+
+ x = torch.cat(inputs, dim=1) # NxLxC
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer_layes(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_post(x)
+
+ ret_feat_map = None
+ ret_feat_seq = None
+ cur_seq_idx = 0
+ for cond_name in features.keys():
+ if not isinstance(features[cond_name], list):
+ length = features[cond_name].size(1)
+ transformed_feature = features[cond_name] * (
+ (x[:, cur_seq_idx : cur_seq_idx + length] @ self.seq_proj) + 1
+ )
+ if ret_feat_seq is None:
+ ret_feat_seq = transformed_feature
+ else:
+ ret_feat_seq = torch.cat([ret_feat_seq, transformed_feature], dim=1)
+ cur_seq_idx += length
+ continue
+
+ length = len(features[cond_name])
+ transformed_feature_list = []
+ for idx in range(length):
+ alpha = self.spatial_ch_projs[idx](x[:, cur_seq_idx + idx])
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1) + 1
+ transformed_feature_list.append(features[cond_name][idx] * alpha)
+ if ret_feat_map is None:
+ ret_feat_map = transformed_feature_list
+ else:
+ ret_feat_map = list(
+ map(lambda x, y: x + y, ret_feat_map, transformed_feature_list)
+ )
+ cur_seq_idx += length
+
+ assert cur_seq_idx == x.size(1)
+
+ return ret_feat_map, ret_feat_seq
diff --git a/core/modules/encoders/condition.py b/core/modules/encoders/condition.py
new file mode 100755
index 0000000000000000000000000000000000000000..66e6cafff6fb934e1fb26caff8e97235ae668bd1
--- /dev/null
+++ b/core/modules/encoders/condition.py
@@ -0,0 +1,511 @@
+import torch
+import torch.nn as nn
+import kornia
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+import open_clip
+
+from core.common import autocast
+from utils.utils import count_params
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.ucg_rate = ucg_rate
+
+ def forward(self, batch, key=None, disable_dropout=False):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ if self.ucg_rate > 0.0 and not disable_dropout:
+ mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
+ c = c.long()
+ c = self.embedding(c)
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc_class = self.n_classes - 1
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc}
+ return uc
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(
+ self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
+ ):
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+
+ LAYERS = ["last", "pooled", "hidden"]
+
+ def __init__(
+ self,
+ version="openai/clip-vit-large-patch14",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ layer_idx=None,
+ ): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
+ )
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class ClipImageEmbedder(nn.Module):
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ antialias=True,
+ ucg_rate=0.0,
+ ):
+ super().__init__()
+ from clip import load as load_clip
+
+ self.model, _ = load_clip(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
+ # re-normalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x, no_dropout=False):
+ # x is assumed to be in range [-1,1]
+ out = self.model.encode_image(self.preprocess(x))
+ out = out.to(x.dtype)
+ if self.ucg_rate > 0.0 and not no_dropout:
+ out = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(out.shape[0], device=out.device)
+ )[:, None]
+ * out
+ )
+ return out
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate",
+ ]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version=None,
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device("cpu"), pretrained=version
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version=None,
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="pooled",
+ antialias=True,
+ ucg_rate=0.0,
+ ):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device("cpu"), pretrained=version
+ )
+ del model.transformer
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "penultimate":
+ raise NotImplementedError()
+ self.layer_idx = 1
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ if self.ucg_rate > 0.0 and not no_dropout:
+ z = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+ )[:, None]
+ * z
+ )
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ img = self.preprocess(img)
+ x = self.model.visual(img)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version=None,
+ device="cuda",
+ freeze=True,
+ layer="pooled",
+ antialias=True,
+ ):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ self.device = device
+
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "penultimate":
+ raise NotImplementedError()
+ self.layer_idx = 1
+
+ self.antialias = antialias
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ def forward(self, image, no_dropout=False):
+ # image: b c h w
+ z = self.encode_with_vision_transformer(image)
+ return z
+
+ def encode_with_vision_transformer(self, x):
+ x = self.preprocess(x)
+
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+ if self.model.visual.input_patchnorm:
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+ x = x.reshape(
+ x.shape[0],
+ x.shape[1],
+ self.model.visual.grid_size[0],
+ self.model.visual.patch_size[0],
+ self.model.visual.grid_size[1],
+ self.model.visual.patch_size[1],
+ )
+ x = x.permute(0, 2, 4, 1, 3, 5)
+ x = x.reshape(
+ x.shape[0],
+ self.model.visual.grid_size[0] * self.model.visual.grid_size[1],
+ -1,
+ )
+ x = self.model.visual.patchnorm_pre_ln(x)
+ x = self.model.visual.conv1(x)
+ else:
+ x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
+ # shape = [*, width, grid ** 2]
+ x = x.reshape(x.shape[0], x.shape[1], -1)
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ # class embeddings and positional embeddings
+ x = torch.cat(
+ [
+ self.model.visual.class_embedding.to(x.dtype)
+ + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
+ ),
+ x,
+ ],
+ dim=1,
+ ) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.model.visual.positional_embedding.to(x.dtype)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.model.visual.patch_dropout(x)
+ x = self.model.visual.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.model.visual.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ return x
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+ def __init__(
+ self,
+ clip_version="openai/clip-vit-large-patch14",
+ t5_version="google/t5-v1_1-xl",
+ device="cuda",
+ clip_max_length=77,
+ t5_max_length=77,
+ ):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(
+ clip_version, device, max_length=clip_max_length
+ )
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
+ )
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
diff --git a/core/modules/encoders/resampler.py b/core/modules/encoders/resampler.py
new file mode 100755
index 0000000000000000000000000000000000000000..8002f509304b255064da0f371d634358d57b79ac
--- /dev/null
+++ b/core/modules/encoders/resampler.py
@@ -0,0 +1,264 @@
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+class ImageProjModel(nn.Module):
+ """Projection Model"""
+
+ def __init__(
+ self,
+ cross_attention_dim=1024,
+ clip_embeddings_dim=1024,
+ clip_extra_context_tokens=4,
+ ):
+ super().__init__()
+ self.cross_attention_dim = cross_attention_dim
+ self.clip_extra_context_tokens = clip_extra_context_tokens
+ self.proj = nn.Linear(
+ clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
+ )
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds):
+ # embeds = image_embeds
+ embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
+ clip_extra_context_tokens = self.proj(embeds).reshape(
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
+ )
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+ return clip_extra_context_tokens
+
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ # More stable with f16 than dividing afterwards
+ weight = (q * scale) @ (k * scale).transpose(-2, -1)
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ video_length=None,
+ ):
+ super().__init__()
+ self.num_queries = num_queries
+ self.video_length = video_length
+ if video_length is not None:
+ num_queries = num_queries * video_length
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+ self.proj_in = nn.Linear(embedding_dim, dim)
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1) # B (T L) C
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents) # B L C or B (T L) C
+
+ return latents
+
+
+class CameraPoseQueryTransformer(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ num_views=None,
+ use_multi_view_attention=True,
+ ):
+ super().__init__()
+
+ self.num_queries = num_queries
+ self.num_views = num_views
+ assert num_views is not None, "video_length must be given."
+ self.use_multi_view_attention = use_multi_view_attention
+ self.camera_pose_embedding_layers = nn.Sequential(
+ nn.Linear(12, dim),
+ nn.SiLU(),
+ nn.Linear(dim, dim),
+ nn.SiLU(),
+ nn.Linear(dim, dim),
+ )
+ nn.init.zeros_(self.camera_pose_embedding_layers[-1].weight)
+ nn.init.zeros_(self.camera_pose_embedding_layers[-1].bias)
+
+ self.latents = nn.Parameter(
+ torch.randn(1, num_views * num_queries, dim) / dim**0.5
+ )
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x, camera_poses):
+ # camera_poses: (b, t, 12)
+ batch_size, num_views, _ = camera_poses.shape
+ # latents: (1, t*q, d) -> (b, t*q, d)
+ latents = self.latents.repeat(batch_size, 1, 1)
+ x = self.proj_in(x)
+ # camera_poses: (b*t, 12)
+ camera_poses = rearrange(camera_poses, "b t d -> (b t) d", t=num_views)
+ camera_poses = self.camera_pose_embedding_layers(
+ camera_poses
+ ) # camera_poses: (b*t, d)
+ # camera_poses: (b, t, d)
+ camera_poses = rearrange(camera_poses, "(b t) d -> b t d", t=num_views)
+ # camera_poses: (b, t*q, d)
+ camera_poses = repeat(camera_poses, "b t d -> b (t q) d", q=self.num_queries)
+
+ latents = latents + camera_poses # b, t*q, d
+
+ latents = rearrange(
+ latents,
+ "b (t q) d -> (b t) q d",
+ b=batch_size,
+ t=num_views,
+ q=self.num_queries,
+ ) # (b*t, q, d)
+
+ _, x_seq_size, _ = x.shape
+ for layer_idx, (attn, ff) in enumerate(self.layers):
+ if self.use_multi_view_attention and layer_idx % 2 == 1:
+ # latents: (b*t, q, d)
+ latents = rearrange(
+ latents,
+ "(b t) q d -> b (t q) d",
+ b=batch_size,
+ t=num_views,
+ q=self.num_queries,
+ )
+ # x: (b*t, s, d)
+ x = rearrange(
+ x, "(b t) s d -> b (t s) d", b=batch_size, t=num_views, s=x_seq_size
+ )
+
+ # print("After rearrange: latents.shape=", latents.shape)
+ # print("After rearrange: x.shape=", camera_poses.shape)
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+ if self.use_multi_view_attention and layer_idx % 2 == 1:
+ # latents: (b*q, t, d)
+ latents = rearrange(
+ latents,
+ "b (t q) d -> (b t) q d",
+ b=batch_size,
+ t=num_views,
+ q=self.num_queries,
+ )
+ # x: (b*s, t, d)
+ x = rearrange(
+ x, "b (t s) d -> (b t) s d", b=batch_size, t=num_views, s=x_seq_size
+ )
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents) # B L C or B (T L) C
+ return latents
diff --git a/core/modules/networks/ae_modules.py b/core/modules/networks/ae_modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..2a2849020fb6cc186883b66ce988b1dc04e52035
--- /dev/null
+++ b/core/modules/networks/ae_modules.py
@@ -0,0 +1,1023 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+
+import torch
+import numpy as np
+import torch.nn as nn
+from einops import rearrange
+
+from utils.utils import instantiate_from_config
+from core.modules.attention import LinearAttention
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w) # bcl
+ q = q.permute(0, 2, 1) # bcl -> blc l=hw
+ k = k.reshape(b, c, h * w) # bcl
+
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ self.in_channels = in_channels
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ self.in_channels = in_channels
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+def get_timestep_embedding(time_steps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(time_steps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=time_steps.device)
+ emb = time_steps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb
+ )
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # print(f'encoder-input={x.shape}')
+ # downsampling
+ hs = [self.conv_in(x)]
+ # print(f'encoder-conv in feat={hs[0].shape}')
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ # print(f'encoder-down feat={h.shape}')
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ # print(f'encoder-downsample (input)={hs[-1].shape}')
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ # print(f'encoder-downsample (output)={hs[-1].shape}')
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ # print(f'encoder-mid1 feat={h.shape}')
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # print(f'encoder-mid2 feat={h.shape}')
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ # print(f'end feat={h.shape}')
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignored_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ # print("AE working on z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, z):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # print(f'decoder-input={z.shape}')
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+ # print(f'decoder-conv in feat={h.shape}')
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # print(f'decoder-mid feat={h.shape}')
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ # print(f'decoder-up feat={h.shape}')
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ # print(f'decoder-upsample feat={h.shape}')
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ # print(f'decoder-conv_out feat={h.shape}')
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList(
+ [
+ nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(
+ in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ nn.Conv2d(2 * in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True),
+ ]
+ )
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1, 2, 3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ ch,
+ num_res_blocks,
+ resolution,
+ ch_mult=(2, 2),
+ dropout=0.0,
+ ):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.res_block1 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.conv_out = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(
+ x,
+ size=(
+ int(round(x.shape[2] * self.factor)),
+ int(round(x.shape[3] * self.factor)),
+ ),
+ )
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ ch,
+ resolution,
+ out_ch,
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ ch_mult=(1, 2, 4, 8),
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ num_res_blocks=num_res_blocks,
+ ch=ch,
+ ch_mult=ch_mult,
+ z_channels=intermediate_chn,
+ double_z=False,
+ resolution=resolution,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ out_ch=None,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=intermediate_chn,
+ mid_channels=intermediate_chn,
+ out_channels=out_ch,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(
+ self,
+ z_channels,
+ out_ch,
+ resolution,
+ num_res_blocks,
+ attn_resolutions,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ tmp_chn = z_channels * ch_mult[-1]
+ self.decoder = Decoder(
+ out_ch=out_ch,
+ z_channels=tmp_chn,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ in_channels=None,
+ num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult,
+ resolution=resolution,
+ ch=ch,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=z_channels,
+ mid_channels=tmp_chn,
+ out_channels=tmp_chn,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size // in_size)) + 1
+ factor_up = 1.0 + (out_size % in_size)
+ print(
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+ )
+ self.rescaler = LatentRescaler(
+ factor=factor_up,
+ in_channels=in_channels,
+ mid_channels=2 * in_channels,
+ out_channels=in_channels,
+ )
+ self.decoder = Decoder(
+ out_ch=out_channels,
+ resolution=out_size,
+ z_channels=in_channels,
+ num_res_blocks=2,
+ attn_resolutions=[],
+ in_channels=None,
+ ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)],
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
+ )
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
+ )
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor == 1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+ )
+ return x
+
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(
+ self,
+ ch_mult: list,
+ in_channels,
+ pretrained_model: nn.Module = None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.0,
+ pretrained_config=None,
+ ):
+ super().__init__()
+ if pretrained_config is None:
+ assert (
+ pretrained_model is not None
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert (
+ pretrained_config is not None
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
+ self.proj = nn.Conv2d(
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(
+ ResnetBlock(
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
+ )
+ )
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def encode_with_pretrained(self, x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self, x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model, self.downsampler):
+ z = submodel(z, temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z, "b c h w -> b (h w) c")
+ return z
diff --git a/core/modules/networks/unet_modules.py b/core/modules/networks/unet_modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..199b9d7202381c329274ae16be48beb18d578354
--- /dev/null
+++ b/core/modules/networks/unet_modules.py
@@ -0,0 +1,1047 @@
+from functools import partial
+from abc import abstractmethod
+import torch
+import torch.nn as nn
+from einops import rearrange
+import torch.nn.functional as F
+from core.models.utils_diffusion import timestep_embedding
+from core.common import gradient_checkpoint
+from core.basics import zero_module, conv_nd, linear, avg_pool_nd, normalization
+from core.modules.attention import SpatialTransformer, TemporalTransformer
+
+TASK_IDX_IMAGE = 0
+TASK_IDX_RAY = 1
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(
+ self, x, emb, context=None, batch_size=None, with_lora=False, time_steps=None
+ ):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb, batch_size=batch_size)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context, with_lora=with_lora)
+ elif isinstance(layer, TemporalTransformer):
+ x = rearrange(x, "(b f) c h w -> b c f h w", b=batch_size)
+ x = layer(x, context, with_lora=with_lora, time_steps=time_steps)
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ else:
+ x = layer(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=padding
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ :param use_temporal_conv: if True, use the temporal convolution.
+ :param use_image_dataset: if True, the temporal parameters will not be optimized.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ use_temporal_conv=False,
+ tempspatial_aware=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.use_temporal_conv = use_temporal_conv
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ if self.use_temporal_conv:
+ self.temopral_conv = TemporalConvBlock(
+ self.out_channels,
+ self.out_channels,
+ dropout=0.1,
+ spatial_aware=tempspatial_aware,
+ )
+
+ def forward(self, x, emb, batch_size=None):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ input_tuple = (x, emb)
+ if batch_size:
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
+ return gradient_checkpoint(
+ forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint
+ )
+ return gradient_checkpoint(
+ self._forward, input_tuple, self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb, batch_size=None):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ h = self.skip_connection(x) + h
+
+ if self.use_temporal_conv and batch_size:
+ h = rearrange(h, "(b t) c h w -> b c t h w", b=batch_size)
+ h = self.temopral_conv(h)
+ h = rearrange(h, "b c t h w -> (b t) c h w")
+ return h
+
+
+class TemporalConvBlock(nn.Module):
+ def __init__(
+ self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False
+ ):
+ super(TemporalConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
+ th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
+ tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
+ tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
+
+ # conv layers
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(32, in_channels),
+ nn.SiLU(),
+ nn.Conv3d(
+ in_channels, out_channels, th_kernel_shape, padding=th_padding_shape
+ ),
+ )
+ self.conv2 = nn.Sequential(
+ nn.GroupNorm(32, out_channels),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(
+ out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape
+ ),
+ )
+ self.conv3 = nn.Sequential(
+ nn.GroupNorm(32, out_channels),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(
+ out_channels, in_channels, th_kernel_shape, padding=th_padding_shape
+ ),
+ )
+ self.conv4 = nn.Sequential(
+ nn.GroupNorm(32, out_channels),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(
+ out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape
+ ),
+ )
+
+ # zero out the last layer params,so the conv block is identity
+ nn.init.zeros_(self.conv4[-1].weight)
+ nn.init.zeros_(self.conv4[-1].bias)
+
+ def forward(self, x):
+ identity = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+
+ return identity + x
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: in_channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0.0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ context_dim=None,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ transformer_depth=1,
+ use_linear=False,
+ use_checkpoint=False,
+ temporal_conv=False,
+ tempspatial_aware=False,
+ temporal_attention=True,
+ use_relative_position=True,
+ use_causal_attention=False,
+ temporal_length=None,
+ use_fp16=False,
+ addition_attention=False,
+ temporal_selfatt_only=True,
+ image_cross_attention=False,
+ image_cross_attention_scale_learnable=False,
+ default_fs=4,
+ fs_condition=False,
+ use_spatial_temporal_attention=False,
+ # >>> Extra Ray Options
+ use_addition_ray_output_head=False,
+ ray_channels=6,
+ use_lora_for_rays_in_output_blocks=False,
+ use_task_embedding=False,
+ use_ray_decoder=False,
+ use_ray_decoder_residual=False,
+ full_spatial_temporal_attention=False,
+ enhance_multi_view_correspondence=False,
+ camera_pose_condition=False,
+ use_feature_alignment=False,
+ ):
+ super(UNetModel, self).__init__()
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.temporal_attention = temporal_attention
+ time_embed_dim = model_channels * 4
+ self.use_checkpoint = use_checkpoint
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ temporal_self_att_only = True
+ self.addition_attention = addition_attention
+ self.temporal_length = temporal_length
+ self.image_cross_attention = image_cross_attention
+ self.image_cross_attention_scale_learnable = (
+ image_cross_attention_scale_learnable
+ )
+ self.default_fs = default_fs
+ self.fs_condition = fs_condition
+ self.use_spatial_temporal_attention = use_spatial_temporal_attention
+
+ # >>> Extra Ray Options
+ self.use_addition_ray_output_head = use_addition_ray_output_head
+ self.use_lora_for_rays_in_output_blocks = use_lora_for_rays_in_output_blocks
+ if self.use_lora_for_rays_in_output_blocks:
+ assert (
+ use_addition_ray_output_head
+ ), "`use_addition_ray_output_head` is required to be True when using LoRA for rays in output blocks."
+ assert (
+ not use_task_embedding
+ ), "`use_task_embedding` cannot be True when `use_lora_for_rays_in_output_blocks` is enabled."
+ if self.use_addition_ray_output_head:
+ print("Using additional ray output head...")
+ assert (self.out_channels == 4) or (
+ 4 + ray_channels == self.out_channels
+ ), f"`out_channels`={out_channels} is invalid."
+ self.out_channels = 4
+ out_channels = 4
+ self.ray_channels = ray_channels
+ self.use_ray_decoder = use_ray_decoder
+ if use_ray_decoder:
+ assert (
+ not use_task_embedding
+ ), "`use_task_embedding` cannot be True when `use_ray_decoder_layers` is enabled."
+ assert (
+ use_addition_ray_output_head
+ ), "`use_addition_ray_output_head` must be True when `use_ray_decoder_layers` is enabled."
+ self.use_ray_decoder_residual = use_ray_decoder_residual
+
+ # >>> Time/Task Embedding Blocks
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ if fs_condition:
+ self.fps_embedding = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ nn.init.zeros_(self.fps_embedding[-1].weight)
+ nn.init.zeros_(self.fps_embedding[-1].bias)
+
+ if camera_pose_condition:
+ self.camera_pose_condition = True
+ self.camera_pose_embedding = nn.Sequential(
+ linear(12, model_channels),
+ nn.SiLU(),
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ nn.init.zeros_(self.camera_pose_embedding[-1].weight)
+ nn.init.zeros_(self.camera_pose_embedding[-1].bias)
+
+ self.use_task_embedding = use_task_embedding
+ if use_task_embedding:
+ assert (
+ not use_lora_for_rays_in_output_blocks
+ ), "`use_lora_for_rays_in_output_blocks` and `use_task_embedding` cannot be True at the same time."
+ assert (
+ use_addition_ray_output_head
+ ), "`use_addition_ray_output_head` is required to be True when `use_task_embedding` is enabled."
+ self.task_embedding = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ nn.init.zeros_(self.task_embedding[-1].weight)
+ nn.init.zeros_(self.task_embedding[-1].bias)
+ self.task_parameters = nn.ParameterList(
+ [
+ nn.Parameter(
+ torch.zeros(size=[model_channels], requires_grad=True)
+ ),
+ nn.Parameter(
+ torch.zeros(size=[model_channels], requires_grad=True)
+ ),
+ ]
+ )
+
+ # >>> Input Block
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ if self.addition_attention:
+ self.init_attn = TimestepEmbedSequential(
+ TemporalTransformer(
+ model_channels,
+ n_heads=8,
+ d_head=num_head_channels,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_selfatt_only,
+ causal_attention=False,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length,
+ )
+ )
+
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ disable_self_attn=False,
+ video_length=temporal_length,
+ image_cross_attention=self.image_cross_attention,
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+ )
+ )
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv,
+ ),
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ disable_self_attn=False,
+ video_length=temporal_length,
+ image_cross_attention=self.image_cross_attention,
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+ ),
+ ]
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length,
+ )
+ )
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv,
+ )
+ )
+
+ # >>> Middle Block
+ self.middle_block = TimestepEmbedSequential(*layers)
+
+ # >>> Ray Decoder
+ if use_ray_decoder:
+ self.ray_decoder_blocks = nn.ModuleList([])
+
+ # >>> Output Block
+ is_first_layer = True
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv,
+ )
+ ]
+ if use_ray_decoder:
+ if self.use_ray_decoder_residual:
+ ray_residual_ch = ich
+ else:
+ ray_residual_ch = 0
+ ray_decoder_layers = [
+ ResBlock(
+ (ch if is_first_layer else (ch // 10)) + ray_residual_ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels // 10,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=True,
+ )
+ ]
+ is_first_layer = False
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ disable_self_attn=False,
+ video_length=temporal_length,
+ image_cross_attention=self.image_cross_attention,
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+ enable_lora=self.use_lora_for_rays_in_output_blocks,
+ )
+ )
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length,
+ use_extra_spatial_temporal_self_attention=use_spatial_temporal_attention,
+ enable_lora=self.use_lora_for_rays_in_output_blocks,
+ full_spatial_temporal_attention=full_spatial_temporal_attention,
+ enhance_multi_view_correspondence=enhance_multi_view_correspondence,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ # out_ray_ch = ray_ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ if use_ray_decoder:
+ ray_decoder_layers.append(
+ ResBlock(
+ ch // 10,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch // 10,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(
+ ch // 10,
+ conv_resample,
+ dims=dims,
+ out_channels=out_ch // 10,
+ )
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ if use_ray_decoder:
+ self.ray_decoder_blocks.append(
+ TimestepEmbedSequential(*ray_decoder_layers)
+ )
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ if self.use_addition_ray_output_head:
+ ray_model_channels = model_channels // 10
+ self.ray_output_head = nn.Sequential(
+ normalization(ray_model_channels),
+ nn.SiLU(),
+ conv_nd(dims, ray_model_channels, ray_model_channels, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, ray_model_channels, ray_model_channels, 3, padding=1),
+ nn.SiLU(),
+ zero_module(
+ conv_nd(dims, ray_model_channels, self.ray_channels, 3, padding=1)
+ ),
+ )
+ self.use_feature_alignment = use_feature_alignment
+ if self.use_feature_alignment:
+ self.feature_alignment_adapter = FeatureAlignmentAdapter(
+ time_embed_dim=time_embed_dim, use_checkpoint=use_checkpoint
+ )
+
+ def forward(
+ self,
+ x,
+ time_steps,
+ context=None,
+ features_adapter=None,
+ fs=None,
+ task_idx=None,
+ camera_poses=None,
+ return_input_block_features=False,
+ return_middle_feature=False,
+ return_output_block_features=False,
+ **kwargs,
+ ):
+ intermediate_features = {}
+ if return_input_block_features:
+ intermediate_features["input"] = []
+ if return_output_block_features:
+ intermediate_features["output"] = []
+ b, t, _, _, _ = x.shape
+ t_emb = timestep_embedding(
+ time_steps, self.model_channels, repeat_only=False
+ ).type(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ # repeat t times for context [(b t) 77 768] & time embedding
+ # check if we use per-frame image conditioning
+ _, l_context, _ = context.shape
+ if l_context == 77 + t * 16: # !!! HARD CODE here
+ context_text, context_img = context[:, :77, :], context[:, 77:, :]
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
+ context_img = rearrange(context_img, "b (t l) c -> (b t) l c", t=t)
+ context = torch.cat([context_text, context_img], dim=1)
+ else:
+ context = context.repeat_interleave(repeats=t, dim=0)
+ emb = emb.repeat_interleave(repeats=t, dim=0)
+
+ # always in shape (b t) c h w, except for temporal layer
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+
+ # combine emb
+ if self.fs_condition:
+ if fs is None:
+ fs = torch.tensor(
+ [self.default_fs] * b, dtype=torch.long, device=x.device
+ )
+ fs_emb = timestep_embedding(
+ fs, self.model_channels, repeat_only=False
+ ).type(x.dtype)
+
+ fs_embed = self.fps_embedding(fs_emb)
+ fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
+ emb = emb + fs_embed
+
+ if self.camera_pose_condition:
+ # camera_poses: (b, t, 12)
+ camera_poses = rearrange(camera_poses, "b t x y -> (b t) (x y)") # x=3, y=4
+ camera_poses_embed = self.camera_pose_embedding(camera_poses)
+ emb = emb + camera_poses_embed
+
+ if self.use_task_embedding:
+ assert (
+ task_idx is not None
+ ), "`task_idx` should not be None when `use_task_embedding` is enabled."
+ task_embed = self.task_embedding(
+ self.task_parameters[task_idx]
+ .reshape(1, self.model_channels)
+ .repeat(b, 1)
+ )
+ task_embed = task_embed.repeat_interleave(repeats=t, dim=0)
+ emb = emb + task_embed
+
+ h = x.type(self.dtype)
+ adapter_idx = 0
+ hs = []
+ for _id, module in enumerate(self.input_blocks):
+
+ h = module(h, emb, context=context, batch_size=b)
+ if _id == 0 and self.addition_attention:
+ h = self.init_attn(h, emb, context=context, batch_size=b)
+ # plug-in adapter features
+ if ((_id + 1) % 3 == 0) and features_adapter is not None:
+ h = h + features_adapter[adapter_idx]
+ adapter_idx += 1
+ hs.append(h)
+ if return_input_block_features:
+ intermediate_features["input"].append(h)
+ if features_adapter is not None:
+ assert len(features_adapter) == adapter_idx, "Wrong features_adapter"
+
+ h = self.middle_block(h, emb, context=context, batch_size=b)
+
+ if return_middle_feature:
+ intermediate_features["middle"] = h
+
+ if self.use_feature_alignment:
+ feature_alignment_output = self.feature_alignment_adapter(
+ hs[2], hs[5], hs[8], emb=emb
+ )
+
+ # >>> Output Blocks Forward
+ if self.use_ray_decoder:
+ h_original = h
+ h_ray = h
+ for original_module, ray_module in zip(
+ self.output_blocks, self.ray_decoder_blocks
+ ):
+ cur_hs = hs.pop()
+ h_original = torch.cat([h_original, cur_hs], dim=1)
+ h_original = original_module(
+ h_original,
+ emb,
+ context=context,
+ batch_size=b,
+ time_steps=time_steps,
+ )
+ if self.use_ray_decoder_residual:
+ h_ray = torch.cat([h_ray, cur_hs], dim=1)
+ h_ray = ray_module(h_ray, emb, context=context, batch_size=b)
+ if return_output_block_features:
+ print(
+ "return_output_block_features: h_original.shape=",
+ h_original.shape,
+ )
+ intermediate_features["output"].append(h_original.detach())
+ h_original = h_original.type(x.dtype)
+ h_ray = h_ray.type(x.dtype)
+ y_original = self.out(h_original)
+ y_ray = self.ray_output_head(h_ray)
+ y = torch.cat([y_original, y_ray], dim=1)
+ else:
+ if self.use_lora_for_rays_in_output_blocks:
+ middle_h = h
+ h_original = middle_h
+ h_lora = middle_h
+ for output_idx, module in enumerate(self.output_blocks):
+ cur_hs = hs.pop()
+ h_original = torch.cat([h_original, cur_hs], dim=1)
+ h_original = module(
+ h_original, emb, context=context, batch_size=b, with_lora=False
+ )
+
+ h_lora = torch.cat([h_lora, cur_hs], dim=1)
+ h_lora = module(
+ h_lora, emb, context=context, batch_size=b, with_lora=True
+ )
+ h_original = h_original.type(x.dtype)
+ h_lora = h_lora.type(x.dtype)
+ y_original = self.out(h_original)
+ y_lora = self.ray_output_head(h_lora)
+ y = torch.cat([y_original, y_lora], dim=1)
+ else:
+ for module in self.output_blocks:
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context=context, batch_size=b)
+ h = h.type(x.dtype)
+
+ if self.use_task_embedding:
+ # Seperated Input (Branch Control in CPU)
+ # Serial Execution (GPU Vectorization Pending)
+ if task_idx == TASK_IDX_IMAGE:
+ y = self.out(h)
+ elif task_idx == TASK_IDX_RAY:
+ y = self.ray_output_head(h)
+ else:
+ raise NotImplementedError(f"Unsupported `task_idx`: {task_idx}")
+ else:
+ # Output ray and images at the same forward
+ y = self.out(h)
+
+ if self.use_addition_ray_output_head:
+ y_ray = self.ray_output_head(h)
+ y = torch.cat([y, y_ray], dim=1)
+ # reshape back to (b c t h w)
+ y = rearrange(y, "(b t) c h w -> b t c h w", b=b)
+ if (
+ return_input_block_features
+ or return_output_block_features
+ or return_middle_feature
+ ):
+ return y, intermediate_features
+ # Assume intermediate features are only request during non-training scenarios (e.g., feature visualization)
+ if self.use_feature_alignment:
+ return y, feature_alignment_output
+ return y
+
+
+class FeatureAlignmentAdapter(torch.nn.Module):
+ def __init__(self, time_embed_dim, use_checkpoint, dropout=0.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.channel_adapter_conv_16 = torch.nn.Conv2d(
+ in_channels=1280, out_channels=320, kernel_size=1
+ )
+ self.channel_adapter_conv_32 = torch.nn.Conv2d(
+ in_channels=640, out_channels=320, kernel_size=1
+ )
+ self.upsampler_x2 = torch.nn.UpsamplingBilinear2d(scale_factor=2)
+ self.upsampler_x4 = torch.nn.UpsamplingBilinear2d(scale_factor=4)
+ self.res_block = ResBlock(
+ 320 * 3,
+ time_embed_dim,
+ dropout,
+ out_channels=32 * 3,
+ dims=2,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=False,
+ )
+ self.final_conv = conv_nd(
+ dims=2, in_channels=32 * 3, out_channels=6, kernel_size=1
+ )
+
+ def forward(self, feature_64, feature_32, feature_16, emb):
+ feature_16_adapted = self.channel_adapter_conv_16(feature_16)
+ feature_32_adapted = self.channel_adapter_conv_32(feature_32)
+ feature_16_upsampled = self.upsampler_x4(feature_16_adapted)
+ feature_32_upsampled = self.upsampler_x2(feature_32_adapted)
+ feature_all = torch.concat(
+ [feature_16_upsampled, feature_32_upsampled, feature_64], dim=1
+ )
+
+ # bt, 3, h, w
+ return self.final_conv(self.res_block(feature_all, emb=emb))
diff --git a/core/modules/position_encoding.py b/core/modules/position_encoding.py
new file mode 100755
index 0000000000000000000000000000000000000000..8954466a55e147b9c46f453ac1482cf05be87ebb
--- /dev/null
+++ b/core/modules/position_encoding.py
@@ -0,0 +1,97 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, token_tensors):
+ # input: (B,C,H,W)
+ x = token_tensors
+ h, w = x.shape[-2:]
+ identity_map = torch.ones((h, w), device=x.device)
+ y_embed = identity_map.cumsum(0, dtype=torch.float32)
+ x_embed = identity_map.cumsum(1, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
+ ).flatten(2)
+ pos_y = torch.stack(
+ (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
+ ).flatten(2)
+ pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
+ batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return batch_pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+
+ def __init__(self, n_pos_x=16, n_pos_y=16, num_pos_feats=64):
+ super().__init__()
+ self.row_embed = nn.Embedding(n_pos_y, num_pos_feats)
+ self.col_embed = nn.Embedding(n_pos_x, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, token_tensors):
+ # input: (B,C,H,W)
+ x = token_tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat(
+ [
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ],
+ dim=-1,
+ ).permute(2, 0, 1)
+ batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return batch_pos
+
+
+def build_position_encoding(num_pos_feats=64, n_pos_x=16, n_pos_y=16, is_learned=False):
+ if is_learned:
+ position_embedding = PositionEmbeddingLearned(n_pos_x, n_pos_y, num_pos_feats)
+ else:
+ position_embedding = PositionEmbeddingSine(num_pos_feats, normalize=True)
+
+ return position_embedding
diff --git a/core/modules/x_transformer.py b/core/modules/x_transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..878a93eb2a3ce7cff355d08c3a745cf9c86b4d40
--- /dev/null
+++ b/core/modules/x_transformer.py
@@ -0,0 +1,679 @@
+from functools import partial
+from inspect import isfunction
+from collections import namedtuple
+from einops import rearrange, repeat
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+DEFAULT_DIM_HEAD = 64
+
+Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
+
+LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"])
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = (
+ torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ + offset
+ )
+ sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return emb[None, :, :]
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def always(val):
+ def inner(*args, **kwargs):
+ return val
+
+ return inner
+
+
+def not_equals(val):
+ def inner(x):
+ return x != val
+
+ return inner
+
+
+def equals(val):
+ def inner(x):
+ return x == val
+
+ return inner
+
+
+def max_neg_value(tensor):
+ return -torch.finfo(tensor.dtype).max
+
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(), dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+
+def string_begins_with(prefix, str):
+ return str.startswith(prefix)
+
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(
+ partial(string_begins_with, prefix), d
+ )
+ kwargs_without_prefix = dict(
+ map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
+ )
+ return kwargs_without_prefix, kwargs
+
+
+class Scale(nn.Module):
+ def __init__(self, value, fn):
+ super().__init__()
+ self.value = value
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.value, *rest)
+
+
+class Rezero(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+ self.g = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim**-0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-8):
+ super().__init__()
+ self.scale = dim**-0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class Residual(nn.Module):
+ def forward(self, x, residual):
+ return x + residual
+
+
+class GRUGating(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gru = nn.GRUCell(dim, dim)
+
+ def forward(self, x, residual):
+ gated_output = self.gru(
+ rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")
+ )
+
+ return gated_output.reshape_as(x)
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ mask=None,
+ talking_heads=False,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.0,
+ on_attn=False,
+ ):
+ super().__init__()
+ if use_entmax15:
+ raise NotImplementedError(
+ "Check out entmax activation instead of softmax activation!"
+ )
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ self.causal = causal
+ self.mask = mask
+
+ inner_dim = dim_head * heads
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ self.talking_heads = talking_heads
+ if talking_heads:
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+
+ self.sparse_topk = sparse_topk
+ self.attn_fn = F.softmax
+
+ self.num_mem_kv = num_mem_kv
+ if num_mem_kv > 0:
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
+ self.attn_on_attn = on_attn
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
+ if on_attn
+ else nn.Linear(inner_dim, dim)
+ )
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ rel_pos=None,
+ sinusoidal_emb=None,
+ prev_attn=None,
+ mem=None,
+ ):
+ b, n, _, h, talking_heads, device = (
+ *x.shape,
+ self.heads,
+ self.talking_heads,
+ x.device,
+ )
+ kv_input = default(context, x)
+
+ q_input = x
+ k_input = kv_input
+ v_input = kv_input
+
+ if exists(mem):
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
+
+ if exists(sinusoidal_emb):
+ offset = k_input.shape[-2] - q_input.shape[-2]
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
+ k_input = k_input + sinusoidal_emb(k_input)
+
+ q = self.to_q(q_input)
+ k = self.to_k(k_input)
+ v = self.to_v(v_input)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+ input_mask = None
+ if any(map(exists, (mask, context_mask))):
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
+ k_mask = q_mask if not exists(context) else context_mask
+ k_mask = default(
+ k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()
+ )
+ q_mask = rearrange(q_mask, "b i -> b () i ()")
+ k_mask = rearrange(k_mask, "b j -> b () () j")
+ input_mask = q_mask * k_mask
+
+ if self.num_mem_kv > 0:
+ mem_k, mem_v = map(
+ lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)
+ )
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
+ if exists(input_mask):
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
+
+ dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+ mask_value = max_neg_value(dots)
+
+ if exists(prev_attn):
+ dots = dots + prev_attn
+
+ pre_softmax_attn = dots
+
+ if talking_heads:
+ dots = einsum(
+ "b h i j, h k -> b k i j", dots, self.pre_softmax_proj
+ ).contiguous()
+
+ if exists(rel_pos):
+ dots = rel_pos(dots)
+
+ if exists(input_mask):
+ dots.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ i, j = dots.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
+ mask = F.pad(mask, (j - i, 0), value=False)
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
+ mask = dots < vk
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ attn = self.attn_fn(dots, dim=-1)
+ post_softmax_attn = attn
+
+ attn = self.dropout(attn)
+
+ if talking_heads:
+ attn = einsum(
+ "b h i j, h k -> b k i j", attn, self.post_softmax_proj
+ ).contiguous()
+
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+
+ intermediates = Intermediates(
+ pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
+ )
+
+ return self.to_out(out), intermediates
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
+ position_infused_attn=False,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ **kwargs,
+ ):
+ super().__init__()
+ ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
+ attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
+
+ dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
+
+ self.dim = dim
+ self.depth = depth
+ self.layers = nn.ModuleList([])
+
+ self.has_pos_emb = position_infused_attn
+ self.pia_pos_emb = (
+ FixedPositionalEmbedding(dim) if position_infused_attn else None
+ )
+ self.rotary_pos_emb = always(None)
+
+ assert (
+ rel_pos_num_buckets <= rel_pos_max_distance
+ ), "number of relative position buckets must be less than the relative position max distance"
+ self.rel_pos = None
+
+ self.pre_norm = pre_norm
+
+ self.residual_attn = residual_attn
+ self.cross_residual_attn = cross_residual_attn
+
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+ norm_class = RMSNorm if use_rmsnorm else norm_class
+ norm_fn = partial(norm_class, dim)
+
+ norm_fn = nn.Identity if use_rezero else norm_fn
+ branch_fn = Rezero if use_rezero else None
+
+ if cross_attend and not only_cross:
+ default_block = ("a", "c", "f")
+ elif cross_attend and only_cross:
+ default_block = ("c", "f")
+ else:
+ default_block = ("a", "f")
+
+ if macaron:
+ default_block = ("f",) + default_block
+
+ if exists(custom_layers):
+ layer_types = custom_layers
+ elif exists(par_ratio):
+ par_depth = depth * len(default_block)
+ assert 1 < par_ratio <= par_depth, "par ratio out of range"
+ default_block = tuple(filter(not_equals("f"), default_block))
+ par_attn = par_depth // par_ratio
+ depth_cut = par_depth * 2 // 3
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
+ assert (
+ len(default_block) <= par_width
+ ), "default block is too large for par_ratio"
+ par_block = default_block + ("f",) * (par_width - len(default_block))
+ par_head = par_block * par_attn
+ layer_types = par_head + ("f",) * (par_depth - len(par_head))
+ elif exists(sandwich_coef):
+ assert (
+ sandwich_coef > 0 and sandwich_coef <= depth
+ ), "sandwich coefficient should be less than the depth"
+ layer_types = (
+ ("a",) * sandwich_coef
+ + default_block * (depth - sandwich_coef)
+ + ("f",) * sandwich_coef
+ )
+ else:
+ layer_types = default_block * depth
+
+ self.layer_types = layer_types
+ self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
+
+ for layer_type in self.layer_types:
+ if layer_type == "a":
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+ elif layer_type == "c":
+ layer = Attention(dim, heads=heads, **attn_kwargs)
+ elif layer_type == "f":
+ layer = FeedForward(dim, **ff_kwargs)
+ layer = layer if not macaron else Scale(0.5, layer)
+ else:
+ raise Exception(f"invalid layer type {layer_type}")
+
+ if isinstance(layer, Attention) and exists(branch_fn):
+ layer = branch_fn(layer)
+
+ if gate_residual:
+ residual_fn = GRUGating(dim)
+ else:
+ residual_fn = Residual()
+
+ self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ mems=None,
+ return_hiddens=False,
+ ):
+ hiddens = []
+ intermediates = []
+ prev_attn = None
+ prev_cross_attn = None
+
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
+ zip(self.layer_types, self.layers)
+ ):
+ is_last = ind == (len(self.layers) - 1)
+
+ if layer_type == "a":
+ hiddens.append(x)
+ layer_mem = mems.pop(0)
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == "a":
+ out, inter = block(
+ x,
+ mask=mask,
+ sinusoidal_emb=self.pia_pos_emb,
+ rel_pos=self.rel_pos,
+ prev_attn=prev_attn,
+ mem=layer_mem,
+ )
+ elif layer_type == "c":
+ out, inter = block(
+ x,
+ context=context,
+ mask=mask,
+ context_mask=context_mask,
+ prev_attn=prev_cross_attn,
+ )
+ elif layer_type == "f":
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if layer_type in ("a", "c"):
+ intermediates.append(inter)
+
+ if layer_type == "a" and self.residual_attn:
+ prev_attn = inter.pre_softmax_attn
+ elif layer_type == "c" and self.cross_residual_attn:
+ prev_cross_attn = inter.pre_softmax_attn
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ if return_hiddens:
+ intermediates = LayerIntermediates(
+ hiddens=hiddens, attn_intermediates=intermediates
+ )
+
+ return x, intermediates
+
+ return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs):
+ assert "causal" not in kwargs, "cannot set causality on encoder"
+ super().__init__(causal=False, **kwargs)
+
+
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.0,
+ emb_dropout=0.0,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True,
+ ):
+ super().__init__()
+ assert isinstance(
+ attn_layers, AttentionLayers
+ ), "attention layers must be one of Encoder or Decoder"
+
+ dim = attn_layers.dim
+ emb_dim = default(emb_dim, dim)
+
+ self.max_seq_len = max_seq_len
+ self.max_mem_len = max_mem_len
+ self.num_tokens = num_tokens
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.pos_emb = (
+ AbsolutePositionalEmbedding(emb_dim, max_seq_len)
+ if (use_pos_emb and not attn_layers.has_pos_emb)
+ else always(0)
+ )
+ self.emb_dropout = nn.Dropout(emb_dropout)
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self.init_()
+
+ self.to_logits = (
+ nn.Linear(dim, num_tokens)
+ if not tie_embedding
+ else lambda t: t @ self.token_emb.weight.t()
+ )
+
+ num_memory_tokens = default(num_memory_tokens, 0)
+ self.num_memory_tokens = num_memory_tokens
+ if num_memory_tokens > 0:
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+ if hasattr(attn_layers, "num_memory_tokens"):
+ attn_layers.num_memory_tokens = num_memory_tokens
+
+ def init_(self):
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_mems=False,
+ return_attn=False,
+ mems=None,
+ **kwargs,
+ ):
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
+ x = self.token_emb(x)
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+
+ if num_mem > 0:
+ mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
+ x = torch.cat((mem, x), dim=1)
+
+ # auto-handle masking after appending memory tokens
+ if exists(mask):
+ mask = F.pad(mask, (num_mem, 0), value=True)
+
+ x, intermediates = self.attn_layers(
+ x, mask=mask, mems=mems, return_hiddens=True, **kwargs
+ )
+ x = self.norm(x)
+
+ mem, x = x[:, :num_mem], x[:, num_mem:]
+
+ out = self.to_logits(x) if not return_embeddings else x
+
+ if return_mems:
+ hiddens = intermediates.hiddens
+ new_mems = (
+ list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens)))
+ if exists(mems)
+ else hiddens
+ )
+ new_mems = list(
+ map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
+ )
+ return out, new_mems
+
+ if return_attn:
+ attn_maps = list(
+ map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
+ )
+ return out, attn_maps
+
+ return out
diff --git a/main/evaluation/funcs.py b/main/evaluation/funcs.py
new file mode 100755
index 0000000000000000000000000000000000000000..258c3cb2a61cd5a961d6ad0f27e486e57e52b581
--- /dev/null
+++ b/main/evaluation/funcs.py
@@ -0,0 +1,295 @@
+from core.models.samplers.ddim import DDIMSampler
+import glob
+import json
+import os
+import sys
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+
+sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
+
+
+def batch_ddim_sampling(
+ model,
+ cond,
+ noise_shape,
+ n_samples=1,
+ ddim_steps=50,
+ ddim_eta=1.0,
+ cfg_scale=1.0,
+ temporal_cfg_scale=None,
+ use_cat_ucg=False,
+ **kwargs,
+):
+ ddim_sampler = DDIMSampler(model)
+ uncond_type = model.uncond_type
+ batch_size = noise_shape[0]
+
+ # construct unconditional guidance
+ if cfg_scale != 1.0:
+ if uncond_type == "empty_seq":
+ prompts = batch_size * [""]
+ # prompts = N * T * [""] # if is_image_batch=True
+ uc_emb = model.get_learned_conditioning(prompts)
+ elif uncond_type == "zero_embed":
+ c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
+ uc_emb = torch.zeros_like(c_emb)
+
+ # process image condition
+ if hasattr(model, "embedder"):
+ uc_img = torch.zeros(noise_shape[0], 3, 224, 224).to(model.device)
+ # img: b c h w >> b l c
+ uc_img = model.get_image_embeds(uc_img)
+ uc_emb = torch.cat([uc_emb, uc_img], dim=1)
+
+ if isinstance(cond, dict):
+ uc = {key: cond[key] for key in cond.keys()}
+ uc.update({"c_crossattn": [uc_emb]})
+ # special CFG for frame concatenation
+ if use_cat_ucg and hasattr(model, "cond_concat") and model.cond_concat:
+ uc_cat = torch.zeros(
+ noise_shape[0], model.cond_channels, *noise_shape[2:]
+ ).to(model.device)
+ uc.update({"c_concat": [uc_cat]})
+ else:
+ uc = [uc_emb]
+ else:
+ uc = None
+ # uc.update({'fps': torch.tensor([-4]*batch_size).to(model.device).long()})
+ # sampling
+ noise = torch.randn(noise_shape, device=model.device)
+ # x_T = repeat(noise[:,:,:1,:,:], 'b c l h w -> b c (l t) h w', t=noise_shape[2])
+ # x_T = 0.2 * x_T + 0.8 * torch.randn(noise_shape, device=model.device)
+ x_T = None
+ batch_variants = []
+ # batch_variants1, batch_variants2 = [], []
+ for _ in range(n_samples):
+ if ddim_sampler is not None:
+ samples, _ = ddim_sampler.sample(
+ S=ddim_steps,
+ conditioning=cond,
+ batch_size=noise_shape[0],
+ shape=noise_shape[1:],
+ verbose=False,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=uc,
+ eta=ddim_eta,
+ temporal_length=noise_shape[2],
+ conditional_guidance_scale_temporal=temporal_cfg_scale,
+ x_T=x_T,
+ **kwargs,
+ )
+ # reconstruct from latent to pixel space
+ batch_images = model.decode_first_stage(samples)
+ batch_variants.append(batch_images)
+ """
+ pred_x0_list, x_iter_list = _['pred_x0'], _['x_inter']
+ steps = [0, 15, 25, 30, 35, 40, 43, 46, 49, 50]
+ for nn in steps:
+ pred_x0 = pred_x0_list[nn]
+ x_iter = x_iter_list[nn]
+ batch_images_x0 = model.decode_first_stage(pred_x0)
+ batch_variants1.append(batch_images_x0)
+ batch_images_xt = model.decode_first_stage(x_iter)
+ batch_variants2.append(batch_images_xt)
+ """
+ # batch, , c, t, h, w
+ batch_variants = torch.stack(batch_variants, dim=1)
+ # batch_variants1 = torch.stack(batch_variants1, dim=1)
+ # batch_variants2 = torch.stack(batch_variants2, dim=1)
+ # return batch_variants1, batch_variants2
+ return batch_variants
+
+
+def batch_sliding_interpolation(
+ model,
+ cond,
+ base_videos,
+ base_stride,
+ noise_shape,
+ n_samples=1,
+ ddim_steps=50,
+ ddim_eta=1.0,
+ cfg_scale=1.0,
+ temporal_cfg_scale=None,
+ **kwargs,
+):
+ """
+ Current implementation has a flaw: the inter-episode keyframe is used as pre-last and cur-first, so keyframe repeated.
+ For example, cond_frames=[0,4,7], model.temporal_length=8, base_stride=4, then
+ base frame : 0 4 8 12 16 20 24 28
+ interplation: (0~7) (8~15) (16~23) (20~27)
+ """
+ b, c, t, h, w = noise_shape
+ base_z0 = model.encode_first_stage(base_videos)
+ unit_length = model.temporal_length
+ n_base_frames = base_videos.shape[2]
+ n_refs = len(model.cond_frames)
+ sliding_steps = (n_base_frames - 1) // (n_refs - 1)
+ sliding_steps = (
+ sliding_steps + 1 if (n_base_frames - 1) % (n_refs - 1) > 0 else sliding_steps
+ )
+
+ cond_mask = model.cond_mask.to("cuda")
+ proxy_z0 = torch.zeros((b, c, unit_length, h, w), dtype=torch.float32).to("cuda")
+ batch_samples = None
+ last_offset = None
+ for idx in range(sliding_steps):
+ base_idx = idx * (n_refs - 1)
+ # check index overflow
+ if base_idx + n_refs > n_base_frames:
+ last_offset = base_idx - (n_base_frames - n_refs)
+ base_idx = n_base_frames - n_refs
+ cond_z0 = base_z0[:, :, base_idx : base_idx + n_refs, :, :]
+ proxy_z0[:, :, model.cond_frames, :, :] = cond_z0
+
+ if "c_concat" in cond:
+ c_cat, text_emb = cond["c_concat"][0], cond["c_crossattn"][0]
+ episode_idx = idx * unit_length
+ if last_offset is not None:
+ episode_idx = episode_idx - last_offset * base_stride
+ cond_idx = {
+ "c_concat": [
+ c_cat[:, :, episode_idx : episode_idx + unit_length, :, :]
+ ],
+ "c_crossattn": [text_emb],
+ }
+ else:
+ cond_idx = cond
+ noise_shape_idx = [b, c, unit_length, h, w]
+ # batch, , c, t, h, w
+ batch_idx = batch_ddim_sampling(
+ model,
+ cond_idx,
+ noise_shape_idx,
+ n_samples,
+ ddim_steps,
+ ddim_eta,
+ cfg_scale,
+ temporal_cfg_scale,
+ mask=cond_mask,
+ x0=proxy_z0,
+ **kwargs,
+ )
+
+ if batch_samples is None:
+ batch_samples = batch_idx
+ else:
+ # b,s,c,t,h,w
+ if last_offset is None:
+ batch_samples = torch.cat(
+ [batch_samples[:, :, :, :-1, :, :], batch_idx], dim=3
+ )
+ else:
+ batch_samples = torch.cat(
+ [
+ batch_samples[:, :, :, :-1, :, :],
+ batch_idx[:, :, :, last_offset * base_stride :, :, :],
+ ],
+ dim=3,
+ )
+
+ return batch_samples
+
+
+def get_filelist(data_dir, ext="*"):
+ file_list = glob.glob(os.path.join(data_dir, "*.%s" % ext))
+ file_list.sort()
+ return file_list
+
+
+def get_dirlist(path):
+ list = []
+ if os.path.exists(path):
+ files = os.listdir(path)
+ for file in files:
+ m = os.path.join(path, file)
+ if os.path.isdir(m):
+ list.append(m)
+ list.sort()
+ return list
+
+
+def load_model_checkpoint(model, ckpt, adapter_ckpt=None):
+ def load_checkpoint(model, ckpt, full_strict):
+ state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
+ try:
+ # deepspeed
+ new_pl_sd = OrderedDict()
+ for key in state_dict["module"].keys():
+ new_pl_sd[key[16:]] = state_dict["module"][key]
+ model.load_state_dict(new_pl_sd, strict=full_strict)
+ except:
+ if "state_dict" in list(state_dict.keys()):
+ state_dict = state_dict["state_dict"]
+ model.load_state_dict(state_dict, strict=full_strict)
+ return model
+
+ if adapter_ckpt:
+ # main model
+ load_checkpoint(model, ckpt, full_strict=False)
+ print(">>> model checkpoint loaded.")
+ # adapter
+ state_dict = torch.load(adapter_ckpt, map_location="cpu")
+ if "state_dict" in list(state_dict.keys()):
+ state_dict = state_dict["state_dict"]
+ model.adapter.load_state_dict(state_dict, strict=True)
+ print(">>> adapter checkpoint loaded.")
+ else:
+ load_checkpoint(model, ckpt, full_strict=False)
+ print(">>> model checkpoint loaded.")
+ return model
+
+
+def load_prompts(prompt_file):
+ f = open(prompt_file, "r")
+ prompt_list = []
+ for idx, line in enumerate(f.readlines()):
+ l = line.strip()
+ if len(l) != 0:
+ prompt_list.append(l)
+ f.close()
+ return prompt_list
+
+
+def load_camera_poses(filepath_list, video_frames=16):
+ pose_list = []
+ for filepath in filepath_list:
+ with open(filepath, "r") as f:
+ pose = json.load(f)
+ pose = np.array(pose) # [t, 12]
+ pose = torch.tensor(pose).float() # [t, 12]
+ assert (
+ pose.shape[0] == video_frames
+ ), f"conditional pose frames Not matching the target frames [{video_frames}]."
+ pose_list.append(pose)
+ batch_poses = torch.stack(pose_list, dim=0)
+ # shape [b,t,12,1]
+ return batch_poses[..., None]
+
+
+def save_videos(
+ batch_tensors: torch.Tensor, save_dir: str, filenames: list[str], fps: int = 10
+):
+ # b,samples,t,c,h,w
+ n_samples = batch_tensors.shape[1]
+ for idx, vid_tensor in enumerate(batch_tensors):
+ video = vid_tensor.detach().cpu()
+ video = torch.clamp(video.float(), -1.0, 1.0)
+ video = video.permute(1, 0, 2, 3, 4) # t,n,c,h,w
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
+ for framesheet in video
+ ] # [3, 1*h, n*w]
+ # stack in temporal dim [t, 3, n*h, w]
+ grid = torch.stack(frame_grids, dim=0)
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ savepath = os.path.join(save_dir, f"{filenames[idx]}.mp4")
+ torchvision.io.write_video(
+ savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
+ )
diff --git a/main/evaluation/pose_interpolation.py b/main/evaluation/pose_interpolation.py
new file mode 100755
index 0000000000000000000000000000000000000000..12bdf1aef21d6a374280b1f3df76b10ba14723c3
--- /dev/null
+++ b/main/evaluation/pose_interpolation.py
@@ -0,0 +1,215 @@
+import torch
+import math
+
+
+def slerp(R1, R2, alpha):
+ """
+ Perform Spherical Linear Interpolation (SLERP) between two rotation matrices.
+ R1, R2: (3x3) rotation matrices.
+ alpha: interpolation factor, ranging from 0 to 1.
+ """
+
+ # Convert the rotation matrices to quaternions
+ def rotation_matrix_to_quaternion(R):
+ w = torch.sqrt(1.0 + R[0, 0] + R[1, 1] + R[2, 2]) / 2.0
+ w4 = 4.0 * w
+ x = (R[2, 1] - R[1, 2]) / w4
+ y = (R[0, 2] - R[2, 0]) / w4
+ z = (R[1, 0] - R[0, 1]) / w4
+ return torch.tensor([w, x, y, z]).float()
+
+ def quaternion_to_rotation_matrix(q):
+ w, x, y, z = q
+ return torch.tensor(
+ [
+ [
+ 1 - 2 * y * y - 2 * z * z,
+ 2 * x * y - 2 * w * z,
+ 2 * x * z + 2 * w * y,
+ ],
+ [
+ 2 * x * y + 2 * w * z,
+ 1 - 2 * x * x - 2 * z * z,
+ 2 * y * z - 2 * w * x,
+ ],
+ [
+ 2 * x * z - 2 * w * y,
+ 2 * y * z + 2 * w * x,
+ 1 - 2 * x * x - 2 * y * y,
+ ],
+ ]
+ ).float()
+
+ q1 = rotation_matrix_to_quaternion(R1)
+ q2 = rotation_matrix_to_quaternion(R2)
+
+ # Dot product of the quaternions
+ dot = torch.dot(q1, q2)
+
+ # If the dot product is negative, negate one quaternion to ensure the shortest path is taken
+ if dot < 0.0:
+ q2 = -q2
+ dot = -dot
+
+ # SLERP formula
+ if (
+ dot > 0.9995
+ ): # If the quaternions are nearly identical, use linear interpolation
+ q_interp = (1 - alpha) * q1 + alpha * q2
+ else:
+ theta_0 = torch.acos(dot) # Angle between q1 and q2
+ sin_theta_0 = torch.sin(theta_0)
+ theta = theta_0 * alpha # Angle between q1 and interpolated quaternion
+ sin_theta = torch.sin(theta)
+ s1 = torch.sin((1 - alpha) * theta_0) / sin_theta_0
+ s2 = sin_theta / sin_theta_0
+ q_interp = s1 * q1 + s2 * q2
+
+ # Convert the interpolated quaternion back to a rotation matrix
+ R_interp = quaternion_to_rotation_matrix(q_interp)
+ return R_interp
+
+
+def interpolate_camera_poses(pose1, pose2, num_steps):
+ """
+ Interpolate between two camera poses (3x4 matrices) over a number of steps.
+
+ pose1, pose2: (3x4) camera pose matrices (R|t), where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ num_steps: number of interpolation steps.
+
+ Returns:
+ A list of interpolated poses as (3x4) matrices.
+ """
+ R1, t1 = pose1[:, :3], pose1[:, 3]
+ R2, t2 = pose2[:, :3], pose2[:, 3]
+
+ interpolated_poses = []
+ for i in range(num_steps):
+ alpha = i / (num_steps - 1) # Interpolation factor ranging from 0 to 1
+ # Interpolate rotation using SLERP
+ R_interp = slerp(R1, R2, alpha)
+ # Interpolate translation using linear interpolation (LERP)
+ t_interp = (1 - alpha) * t1 + alpha * t2
+ # Combine interpolated rotation and translation into a (3x4) pose matrix
+ pose_interp = torch.cat([R_interp, t_interp.unsqueeze(1)], dim=1)
+ interpolated_poses.append(pose_interp)
+
+ return interpolated_poses
+
+
+def rotation_matrix_from_xyz_angles(x_angle, y_angle, z_angle):
+ """
+ Compute the rotation matrix from given x, y, z angles (in radians).
+
+ x_angle: Rotation around the x-axis (pitch).
+ y_angle: Rotation around the y-axis (yaw).
+ z_angle: Rotation around the z-axis (roll).
+
+ Returns:
+ A 3x3 rotation matrix.
+ """
+ # Rotation matrices around each axis
+ Rx = torch.tensor(
+ [
+ [1, 0, 0],
+ [0, torch.cos(x_angle), -torch.sin(x_angle)],
+ [0, torch.sin(x_angle), torch.cos(x_angle)],
+ ]
+ ).float()
+ Ry = torch.tensor(
+ [
+ [torch.cos(y_angle), 0, torch.sin(y_angle)],
+ [0, 1, 0],
+ [-torch.sin(y_angle), 0, torch.cos(y_angle)],
+ ]
+ ).float()
+ Rz = torch.tensor(
+ [
+ [torch.cos(z_angle), -torch.sin(z_angle), 0],
+ [torch.sin(z_angle), torch.cos(z_angle), 0],
+ [0, 0, 1],
+ ]
+ ).float()
+ # Combined rotation matrix R = Rz * Ry * Rx
+ R_combined = Rz @ Ry @ Rx
+ return R_combined.float()
+
+
+def move_pose(pose1, x_angle, y_angle, z_angle, translation):
+ """
+ Calculate the second camera pose based on the first pose and given rotations (x, y, z) and translation.
+
+ pose1: The first camera pose (3x4 matrix).
+ x_angle, y_angle, z_angle: Rotation angles around the x, y, and z axes, in radians.
+ translation: Translation vector (3,).
+
+ Returns:
+ pose2: The second camera pose as a (3x4) matrix.
+ """
+ # Extract the rotation (R1) and translation (t1) from the first pose
+ R1 = pose1[:, :3]
+ t1 = pose1[:, 3]
+ # Calculate the new rotation matrix from the given angles
+ R_delta = rotation_matrix_from_xyz_angles(x_angle, y_angle, z_angle)
+ # New rotation = R1 * R_delta
+ R2 = R1 @ R_delta
+ # New translation = t1 + translation
+ t2 = t1 + translation
+ # Combine R2 and t2 into the new pose (3x4 matrix)
+ pose2 = torch.cat([R2, t2.unsqueeze(1)], dim=1)
+
+ return pose2
+
+
+def deg2rad(degrees):
+ """Convert degrees to radians."""
+ return degrees * math.pi / 180.0
+
+
+def generate_spherical_trajectory(end_angles, radius=1.0, num_steps=36):
+ """
+ Generate a camera-to-world (C2W) trajectory interpolating angles on a sphere.
+
+ Args:
+ end_angles (tuple): The endpoint rotation angles in degrees (x, y, z).
+ (start is assumed to be (0, 0, 0)).
+ radius (float): Radius of the sphere.
+ num_steps (int): Number of steps in the trajectory.
+
+ Returns:
+ torch.Tensor: A tensor of shape [num_steps, 3, 4] with the C2W transformations.
+ """
+ # Convert angles to radians
+ end_angles_rad = torch.tensor(
+ [deg2rad(angle) for angle in end_angles], dtype=torch.float32
+ )
+ # Interpolate angles linearly
+ interpolated_angles = (
+ torch.linspace(0, 1, num_steps).view(-1, 1) * end_angles_rad
+ ) # Shape: [num_steps, 3]
+ poses = []
+ for angles in interpolated_angles:
+ # Extract interpolated angles
+ x_angle, y_angle = angles
+ # Compute camera position on the sphere
+ x = radius * math.sin(y_angle) * math.cos(x_angle)
+ y = radius * math.sin(x_angle)
+ z = radius * math.cos(y_angle) * math.cos(x_angle)
+ cam_position = torch.tensor([x, y, z], dtype=torch.float32)
+ # Camera's forward direction (looking at the origin)
+ look_at_dir = -cam_position / torch.norm(cam_position)
+ # Define the "up" vector
+ up = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32)
+ # Compute the right vector
+ right = torch.cross(up, look_at_dir)
+ right = right / torch.norm(right)
+ # Recompute the orthogonal up vector
+ up = torch.cross(look_at_dir, right)
+ # Build the rotation matrix
+ rotation_matrix = torch.stack([right, up, look_at_dir], dim=0) # [3, 3]
+ # Combine the rotation matrix with the translation (camera position)
+ c2w = torch.cat([rotation_matrix, cam_position.view(3, 1)], dim=1) # [3, 4]
+ # Append the pose
+ poses.append(c2w)
+
+ return poses
diff --git a/main/evaluation/utils_eval.py b/main/evaluation/utils_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..183497927b66062f188254277235be5f4d304376
--- /dev/null
+++ b/main/evaluation/utils_eval.py
@@ -0,0 +1,26 @@
+import torch
+
+
+def process_inference_batch(cfg_scale, batch, model, with_uncondition_extra=False):
+ for k in batch.keys():
+ if isinstance(batch[k], torch.Tensor):
+ batch[k] = batch[k].to(model.device, dtype=model.dtype)
+ z, cond, x_rec = model.get_batch_input(
+ batch,
+ random_drop_training_conditions=False,
+ return_reconstructed_target_images=True,
+ )
+ # batch_size = x_rec.shape[0]
+ # Get unconditioned embedding for classifier-free guidance sampling
+ if cfg_scale != 1.0:
+ uc = model.get_unconditional_dict_for_sampling(batch, cond, x_rec)
+ else:
+ uc = None
+
+ if with_uncondition_extra:
+ uc_extra = model.get_unconditional_dict_for_sampling(
+ batch, cond, x_rec, is_extra=True
+ )
+ return cond, uc, uc_extra, x_rec
+ else:
+ return cond, uc, x_rec
diff --git a/main/utils_data.py b/main/utils_data.py
new file mode 100755
index 0000000000000000000000000000000000000000..2d19e34eb937b2f0e4d2f15ba6de8624eecd527d
--- /dev/null
+++ b/main/utils_data.py
@@ -0,0 +1,164 @@
+from utils.utils import instantiate_from_config
+import os
+import sys
+from functools import partial
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+os.chdir(sys.path[0])
+sys.path.append("..")
+
+
+def t_range(name, tensor):
+ print(
+ f"{name}: shape={tensor.shape}, max={torch.max(tensor)}, min={torch.min(tensor)}."
+ )
+
+
+def worker_init_fn(_):
+ worker_info = torch.utils.data.get_worker_info()
+ worker_id = worker_info.id
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+class WrappedDataset(Dataset):
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
+
+ def __init__(self, dataset):
+ self.data = dataset
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+ def __init__(
+ self,
+ batch_size,
+ train=None,
+ validation=None,
+ test=None,
+ predict=None,
+ train_img=None,
+ wrap=False,
+ num_workers=None,
+ shuffle_test_loader=False,
+ use_worker_init_fn=False,
+ shuffle_val_dataloader=False,
+ test_max_n_samples=None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.batch_size = batch_size
+ self.dataset_configs = dict()
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
+ self.use_worker_init_fn = use_worker_init_fn
+ if train is not None:
+ self.dataset_configs["train"] = train
+ self.train_dataloader = self._train_dataloader
+ if validation is not None:
+ self.dataset_configs["validation"] = validation
+ self.val_dataloader = partial(
+ self._val_dataloader, shuffle=shuffle_val_dataloader
+ )
+ if test is not None:
+ self.dataset_configs["test"] = test
+ self.test_dataloader = partial(
+ self._test_dataloader, shuffle=shuffle_test_loader
+ )
+ if predict is not None:
+ self.dataset_configs["predict"] = predict
+ self.predict_dataloader = self._predict_dataloader
+ # train image dataset
+ if train_img is not None:
+ img_data = instantiate_from_config(train_img)
+ self.img_loader = img_data.train_dataloader()
+ else:
+ self.img_loader = None
+ self.wrap = wrap
+ self.test_max_n_samples = test_max_n_samples
+ self.collate_fn = None
+
+ def prepare_data(self):
+ # for data_cfg in self.dataset_configs.values():
+ # instantiate_from_config(data_cfg)
+ pass
+
+ def setup(self, stage=None):
+ self.datasets = dict(
+ (k, instantiate_from_config(self.dataset_configs[k]))
+ for k in self.dataset_configs
+ )
+ if self.wrap:
+ for k in self.datasets:
+ self.datasets[k] = WrappedDataset(self.datasets[k])
+
+ def _train_dataloader(self):
+ is_iterable_dataset = False
+ if is_iterable_dataset or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ loader = DataLoader(
+ self.datasets["train"],
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False if is_iterable_dataset else True,
+ worker_init_fn=init_fn,
+ collate_fn=self.collate_fn,
+ )
+ if self.img_loader is not None:
+ return {"loader_video": loader, "loader_img": self.img_loader}
+ else:
+ return loader
+
+ def _val_dataloader(self, shuffle=False):
+ init_fn = None
+ return DataLoader(
+ self.datasets["validation"],
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ shuffle=shuffle,
+ collate_fn=self.collate_fn,
+ )
+
+ def _test_dataloader(self, shuffle=False):
+ is_iterable_dataset = False
+ if is_iterable_dataset or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+
+ # do not shuffle dataloader for iterable dataset
+ shuffle = shuffle and (not is_iterable_dataset)
+ if self.test_max_n_samples is not None:
+ dataset = torch.utils.data.Subset(
+ self.datasets["test"], list(range(self.test_max_n_samples))
+ )
+ else:
+ dataset = self.datasets["test"]
+ return DataLoader(
+ dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ shuffle=shuffle,
+ collate_fn=self.collate_fn,
+ )
+
+ def _predict_dataloader(self, shuffle=False):
+ init_fn = None
+ return DataLoader(
+ self.datasets["predict"],
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ collate_fn=self.collate_fn,
+ )
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000000000000000000000000000000000000..cdf144c8a80e4c6de3f3846958b429f9c8912d81
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+torch==2.1.2
+pytorch_lightning==2.1.2
+deepspeed==0.15.0
+taming-transformers==0.0.1
+diffusers==0.30.1
+transformers
+torchvision==0.16.2
+scipy==1.14.1
+einops==0.8.0
+kornia==0.7.3
+open_clip_torch==2.20.0
+openai-clip==1.0.1
+numpy==1.26.3
+xformers==0.0.23.post1
+timm==1.0.9
+av==12.3.0
+gradio==5.8.0
+huggingface_hub
\ No newline at end of file
diff --git a/utils/constants.py b/utils/constants.py
new file mode 100755
index 0000000000000000000000000000000000000000..24cd39e0616524895e441afdbeb98b2bdbe77163
--- /dev/null
+++ b/utils/constants.py
@@ -0,0 +1,2 @@
+FLAG_RUN_DEBUG = False
+PATH_DIR_DEBUG = "./debug/"
diff --git a/utils/load_weigths.py b/utils/load_weigths.py
new file mode 100755
index 0000000000000000000000000000000000000000..c3e16bcaa14476f7fdbb95ca8453d76b8de90fae
--- /dev/null
+++ b/utils/load_weigths.py
@@ -0,0 +1,252 @@
+from utils.utils import instantiate_from_config
+import torch
+import copy
+from omegaconf import OmegaConf
+import logging
+
+main_logger = logging.getLogger("main_logger")
+
+
+def expand_conv_kernel(pretrained_dict):
+ """expand 2d conv parameters from 4D -> 5D"""
+ for k, v in pretrained_dict.items():
+ if v.dim() == 4 and not k.startswith("first_stage_model"):
+ v = v.unsqueeze(2)
+ pretrained_dict[k] = v
+ return pretrained_dict
+
+
+def print_state_dict(state_dict):
+ print("====== Dumping State Dict ======")
+ for k, v in state_dict.items():
+ print(k, v.shape)
+
+
+def load_from_pretrainedSD_checkpoint(
+ model,
+ pretained_ckpt,
+ expand_to_3d=True,
+ adapt_keyname=False,
+ echo_empty_params=False,
+):
+ sd_state_dict = torch.load(pretained_ckpt, map_location="cpu")
+ if "state_dict" in list(sd_state_dict.keys()):
+ sd_state_dict = sd_state_dict["state_dict"]
+ model_state_dict = model.state_dict()
+ # delete ema_weights just for
+ for k in list(sd_state_dict.keys()):
+ if k.startswith("model_ema"):
+ del sd_state_dict[k]
+ main_logger.info(
+ f"Num of model params of Source:{len(sd_state_dict.keys())} VS. Target:{len(model_state_dict.keys())}"
+ )
+ # print_state_dict(model_state_dict)
+ # print_state_dict(sd_state_dict)
+
+ if adapt_keyname:
+ # adapting to standard 2d network: modify the key name because of the add of temporal-attention
+ mapping_dict = {
+ "middle_block.2": "middle_block.3",
+ "output_blocks.5.2": "output_blocks.5.3",
+ "output_blocks.8.2": "output_blocks.8.3",
+ }
+ cnt = 0
+ for k in list(sd_state_dict.keys()):
+ for src_word, dst_word in mapping_dict.items():
+ if src_word in k:
+ new_key = k.replace(src_word, dst_word)
+ sd_state_dict[new_key] = sd_state_dict[k]
+ del sd_state_dict[k]
+ cnt += 1
+ main_logger.info(f"[renamed {cnt} Source keys to match Target model]")
+
+ pretrained_dict = {
+ k: v for k, v in sd_state_dict.items() if k in model_state_dict
+ } # drop extra keys
+ empty_paras = [
+ k for k, v in model_state_dict.items() if k not in pretrained_dict
+ ] # log no pretrained keys
+ assert len(empty_paras) + len(pretrained_dict.keys()) == len(
+ model_state_dict.keys()
+ )
+
+ if expand_to_3d:
+ # adapting to 2d inflated network
+ pretrained_dict = expand_conv_kernel(pretrained_dict)
+
+ # overwrite entries in the existing state dict
+ model_state_dict.update(pretrained_dict)
+
+ # load the new state dict
+ try:
+ model.load_state_dict(model_state_dict)
+ except:
+ skipped = []
+ model_dict_ori = model.state_dict()
+ for n, p in model_state_dict.items():
+ if p.shape != model_dict_ori[n].shape:
+ # skip by using original empty paras
+ model_state_dict[n] = model_dict_ori[n]
+ main_logger.info(
+ f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_state_dict[n].shape} in current model"
+ )
+ skipped.append(n)
+ main_logger.info(
+ f"[INFO] Skip {len(skipped)} parameters becasuse of size mismatch!"
+ )
+ model.load_state_dict(model_state_dict)
+ empty_paras += skipped
+
+ # only count Unet part of depth estimation model
+ unet_empty_paras = [
+ name for name in empty_paras if name.startswith("model.diffusion_model")
+ ]
+ main_logger.info(
+ f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)} [Unet:{len(unet_empty_paras)}]"
+ )
+ if echo_empty_params:
+ print("Printing empty parameters:")
+ for k in empty_paras:
+ print(k)
+ return model, empty_paras
+
+
+# Below: written by Yingqing --------------------------------------------------------
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ main_logger.info("missing keys:")
+ main_logger.info(m)
+ if len(u) > 0 and verbose:
+ main_logger.info("unexpected keys:")
+ main_logger.info(u)
+
+ model.eval()
+ return model
+
+
+def init_and_load_ldm_model(config_path, ckpt_path, device=None):
+ assert config_path.endswith(".yaml"), f"config_path = {config_path}"
+ assert ckpt_path.endswith(".ckpt"), f"ckpt_path = {ckpt_path}"
+ config = OmegaConf.load(config_path)
+ model = load_model_from_config(config, ckpt_path)
+ if device is not None:
+ model = model.to(device)
+ return model
+
+
+def load_img_model_to_video_model(
+ model,
+ device=None,
+ expand_to_3d=True,
+ adapt_keyname=False,
+ config_path="configs/latent-diffusion/txt2img-1p4B-eval.yaml",
+ ckpt_path="models/ldm/text2img-large/model.ckpt",
+):
+ pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
+ model, empty_paras = load_partial_weights(
+ model,
+ pretrained_ldm.state_dict(),
+ expand_to_3d=expand_to_3d,
+ adapt_keyname=adapt_keyname,
+ )
+ return model, empty_paras
+
+
+def load_partial_weights(
+ model, pretrained_dict, expand_to_3d=True, adapt_keyname=False
+):
+ model2 = copy.deepcopy(model)
+ model_dict = model.state_dict()
+ model_dict_ori = copy.deepcopy(model_dict)
+
+ main_logger.info(f"[Load pretrained LDM weights]")
+ main_logger.info(
+ f"Num of parameters of source model:{len(pretrained_dict.keys())} VS. target model:{len(model_dict.keys())}"
+ )
+
+ if adapt_keyname:
+ # adapting to menghan's standard 2d network: modify the key name because of the add of temporal-attention
+ mapping_dict = {
+ "middle_block.2": "middle_block.3",
+ "output_blocks.5.2": "output_blocks.5.3",
+ "output_blocks.8.2": "output_blocks.8.3",
+ }
+ cnt = 0
+ newpretrained_dict = copy.deepcopy(pretrained_dict)
+ for k, v in newpretrained_dict.items():
+ for src_word, dst_word in mapping_dict.items():
+ if src_word in k:
+ new_key = k.replace(src_word, dst_word)
+ pretrained_dict[new_key] = v
+ pretrained_dict.pop(k)
+ cnt += 1
+ main_logger.info(f"--renamed {cnt} source keys to match target model.")
+ pretrained_dict = {
+ k: v for k, v in pretrained_dict.items() if k in model_dict
+ } # drop extra keys
+ empty_paras = [
+ k for k, v in model_dict.items() if k not in pretrained_dict
+ ] # log no pretrained keys
+ main_logger.info(
+ f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)}"
+ )
+ # disable info
+ # main_logger.info(f'Empty parameters: {empty_paras} ')
+ assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_dict.keys())
+
+ if expand_to_3d:
+ # adapting to yingqing's 2d inflation network
+ pretrained_dict = expand_conv_kernel(pretrained_dict)
+
+ # overwrite entries in the existing state dict
+ model_dict.update(pretrained_dict)
+
+ # load the new state dict
+ try:
+ model2.load_state_dict(model_dict)
+ except:
+ # if parameter size mismatch, skip them
+ skipped = []
+ for n, p in model_dict.items():
+ if p.shape != model_dict_ori[n].shape:
+ # skip by using original empty paras
+ model_dict[n] = model_dict_ori[n]
+ main_logger.info(
+ f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_dict[n].shape} in current model"
+ )
+ skipped.append(n)
+ main_logger.info(
+ f"[INFO] Skip {len(skipped)} parameters becasuse of size mismatch!"
+ )
+ model2.load_state_dict(model_dict)
+ empty_paras += skipped
+ main_logger.info(f"Empty parameters: {len(empty_paras)} ")
+
+ main_logger.info(f"Finished.")
+ return model2, empty_paras
+
+
+def load_autoencoder(model, config_path=None, ckpt_path=None, device=None):
+ if config_path is None:
+ config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
+ if ckpt_path is None:
+ ckpt_path = "models/ldm/text2img-large/model.ckpt"
+
+ pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
+ autoencoder_dict = {}
+ for n, p in pretrained_ldm.state_dict().items():
+ if n.startswith("first_stage_model"):
+ autoencoder_dict[n] = p
+ model_dict = model.state_dict()
+ model_dict.update(autoencoder_dict)
+ main_logger.info(f"Load [{len(autoencoder_dict)}] autoencoder parameters!")
+
+ model.load_state_dict(model_dict)
+
+ return model
diff --git a/utils/lr_scheduler.py b/utils/lr_scheduler.py
new file mode 100755
index 0000000000000000000000000000000000000000..2be15ed6a36b7cf3c5cfda3958868ab093462b9f
--- /dev/null
+++ b/utils/lr_scheduler.py
@@ -0,0 +1,199 @@
+import numpy as np
+import torch
+import torch.optim as optim
+
+
+def build_LR_scheduler(
+ optimizer, scheduler_name, lr_decay_ratio, max_epochs, start_epoch=0
+):
+ # print("-LR scheduler:%s"%scheduler_name)
+ if scheduler_name == "LambdaLR":
+ decay_ratio = lr_decay_ratio
+ decay_epochs = max_epochs
+
+ def polynomial_decay(epoch):
+ return (
+ 1 + (decay_ratio - 1) * ((epoch + start_epoch) / decay_epochs)
+ if (epoch + start_epoch) < decay_epochs
+ else decay_ratio
+ )
+
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+ optimizer, lr_lambda=polynomial_decay
+ )
+ elif scheduler_name == "CosineAnnealingLR":
+ last_epoch = -1 if start_epoch == 0 else start_epoch
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer, T_max=max_epochs, last_epoch=last_epoch
+ )
+ elif scheduler_name == "ReduceLROnPlateau":
+ lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ optimizer, mode="min", factor=0.5, threshold=0.01, patience=5
+ )
+ else:
+ raise NotImplementedError
+ return lr_scheduler
+
+
+class LambdaLRScheduler:
+ # target: torch.optim.lr_scheduler.LambdaLR
+ def __init__(self, start_step, final_decay_ratio, decay_steps):
+ self.final_decay_ratio = final_decay_ratio
+ self.decay_steps = decay_steps
+ self.start_step = start_step
+
+ def schedule(self, step):
+ if step + self.start_step < self.decay_steps:
+ return 1.0 + (self.final_decay_ratio - 1) * (
+ (step + self.start_step) / self.decay_steps
+ )
+ else:
+ return self.final_decay_ratio
+
+ def __call__(self, step):
+ return self.scheduler(step)
+
+
+class CosineAnnealingLRScheduler:
+ # target: torch.optim.lr_scheduler.CosineAnnealingLR
+ def __init__(self, start_step, decay_steps):
+ self.decay_steps = decay_steps
+ self.start_step = start_step
+
+ def __call__(self, step):
+ pass
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+
+ def __init__(
+ self,
+ warm_up_steps,
+ lr_min,
+ lr_max,
+ lr_start,
+ max_decay_steps,
+ verbosity_interval=0,
+ ):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (
+ self.lr_max - self.lr_start
+ ) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (
+ self.lr_max_decay_steps - self.lr_warm_up_steps
+ )
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+
+ def __init__(
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
+ ):
+ assert (
+ len(warm_up_steps)
+ == len(f_min)
+ == len(f_max)
+ == len(f_start)
+ == len(cycle_lengths)
+ )
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
+ )
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
+ self.cycle_lengths[cycle] - n
+ ) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
diff --git a/utils/save_video.py b/utils/save_video.py
new file mode 100755
index 0000000000000000000000000000000000000000..386022ae5372c2030afae70c5b4aac241784c183
--- /dev/null
+++ b/utils/save_video.py
@@ -0,0 +1,258 @@
+import os
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+
+import torch
+import torchvision
+from torch import Tensor
+from torchvision.utils import make_grid
+from torchvision.transforms.functional import to_tensor
+
+
+def frames_to_mp4(frame_dir, output_path, fps):
+ def read_first_n_frames(d: os.PathLike, num_frames: int):
+ if num_frames:
+ images = [
+ Image.open(os.path.join(d, f))
+ for f in sorted(os.listdir(d))[:num_frames]
+ ]
+ else:
+ images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))]
+ images = [to_tensor(x) for x in images]
+ return torch.stack(images)
+
+ videos = read_first_n_frames(frame_dir, num_frames=None)
+ videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
+ torchvision.io.write_video(
+ output_path, videos, fps=fps, video_codec="h264", options={"crf": "10"}
+ )
+
+
+def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
+ """
+ video: torch.Tensor, b,c,t,h,w, 0-1
+ if -1~1, enable rescale=True
+ """
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
+ nrow = int(np.sqrt(n)) if nrow is None else nrow
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video
+ ] # [3, grid_h, grid_w]
+ # stack in temporal dim [T, 3, grid_h, grid_w]
+ grid = torch.stack(frame_grids, dim=0)
+ grid = torch.clamp(grid.float(), -1.0, 1.0)
+ if rescale:
+ grid = (grid + 1.0) / 2.0
+ # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ # print(f'Save video to {savepath}')
+ torchvision.io.write_video(
+ savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
+ )
+
+
+def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
+ assert video.dim() == 5 # b,c,t,h,w
+ assert isinstance(video, torch.Tensor)
+
+ video = video.detach().cpu()
+ if clamp:
+ video = torch.clamp(video, -1.0, 1.0)
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n)))
+ for framesheet in video
+ ] # [3, grid_h, grid_w]
+ # stack in temporal dim [T, 3, grid_h, grid_w]
+ grid = torch.stack(frame_grids, dim=0)
+ if rescale:
+ grid = (grid + 1.0) / 2.0
+ # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ path = os.path.join(root, filename)
+ # print('Save video ...')
+ torchvision.io.write_video(
+ path, grid, fps=fps, video_codec="h264", options={"crf": "10"}
+ )
+ # print('Finish!')
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(
+ xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
+ )
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
+ if batch_logs is None:
+ return None
+ """ save images and videos from images dict """
+
+ def save_img_grid(grid, path, rescale):
+ if rescale:
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ Image.fromarray(grid).save(path)
+
+ for key in batch_logs:
+ value = batch_logs[key]
+ if isinstance(value, list) and isinstance(value[0], str):
+ # a batch of captions
+ path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
+ with open(path, "w") as f:
+ for i, txt in enumerate(value):
+ f.write(f"idx={i}, txt={txt}\n")
+ f.close()
+ elif isinstance(value, torch.Tensor) and value.dim() == 5:
+ # save video grids
+ video = value # b,c,t,h,w
+ # only save grayscale or rgb mode
+ if video.shape[1] != 1 and video.shape[1] != 3:
+ continue
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=int(1))
+ for framesheet in video
+ ] # [3, n*h, 1*w]
+ # stack in temporal dim [t, 3, n*h, w]
+ grid = torch.stack(frame_grids, dim=0)
+ if rescale:
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
+ torchvision.io.write_video(
+ path, grid, fps=save_fps, video_codec="h264", options={"crf": "10"}
+ )
+ elif isinstance(value, torch.Tensor) and value.dim() == 4:
+ img = value
+ if img.shape[1] != 1 and img.shape[1] != 3:
+ continue
+ n = img.shape[0]
+ grid = torchvision.utils.make_grid(img, nrow=1)
+ path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
+ save_img_grid(grid, path, rescale)
+ else:
+ pass
+
+
+def prepare_to_log(batch_logs, max_images=100000, clamp=True):
+ if batch_logs is None:
+ return None
+ # process
+ for key in batch_logs:
+ N = (
+ batch_logs[key].shape[0]
+ if hasattr(batch_logs[key], "shape")
+ else len(batch_logs[key])
+ )
+ N = min(N, max_images)
+ batch_logs[key] = batch_logs[key][:N]
+ if isinstance(batch_logs[key], torch.Tensor):
+ batch_logs[key] = batch_logs[key].detach().cpu()
+ if clamp:
+ try:
+ batch_logs[key] = torch.clamp(batch_logs[key].float(), -1.0, 1.0)
+ except RuntimeError:
+ print("clamp_scalar_cpu not implemented for Half")
+ return batch_logs
+
+
+def fill_with_black_squares(video, desired_len: int) -> Tensor:
+ if len(video) >= desired_len:
+ return video
+
+ return torch.cat(
+ [
+ video,
+ torch.zeros_like(video[0])
+ .unsqueeze(0)
+ .repeat(desired_len - len(video), 1, 1, 1),
+ ],
+ dim=0,
+ )
+
+
+def load_num_videos(data_path, num_videos):
+ # first argument can be either data_path of np array
+ if isinstance(data_path, str):
+ videos = np.load(data_path)["arr_0"] # NTHWC
+ elif isinstance(data_path, np.ndarray):
+ videos = data_path
+ else:
+ raise Exception
+
+ if num_videos is not None:
+ videos = videos[:num_videos, :, :, :, :]
+ return videos
+
+
+def npz_to_video_grid(
+ data_path, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True
+):
+ if isinstance(data_path, str):
+ videos = load_num_videos(data_path, num_videos)
+ elif isinstance(data_path, np.ndarray):
+ videos = data_path
+ else:
+ raise Exception
+ n, t, h, w, c = videos.shape
+ videos_th = []
+ for i in range(n):
+ video = videos[i, :, :, :, :]
+ images = [video[j, :, :, :] for j in range(t)]
+ images = [to_tensor(img) for img in images]
+ video = torch.stack(images)
+ videos_th.append(video)
+ if verbose:
+ videos = [
+ fill_with_black_squares(v, num_frames)
+ for v in tqdm(videos_th, desc="Adding empty frames")
+ ] # NTCHW
+ else:
+ videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW
+
+ frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W]
+ if nrow is None:
+ nrow = int(np.ceil(np.sqrt(n)))
+ if verbose:
+ frame_grids = [
+ make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc="Making grids")
+ ]
+ else:
+ frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
+
+ if os.path.dirname(out_path) != "":
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ frame_grids = (
+ (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ ) # [T, H, W, C]
+ torchvision.io.write_video(
+ out_path, frame_grids, fps=fps, video_codec="h264", options={"crf": "10"}
+ )
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..cf12fda8ce892a4cb68a010c63d37bf8abf16f27
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,64 @@
+import importlib
+import numpy as np
+
+import torch
+import torch.distributed as dist
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def check_istarget(name, para_list):
+ """
+ name: full name of source para
+ para_list: partial name of target para
+ """
+ istarget = False
+ for para in para_list:
+ if para in name:
+ return True
+ return istarget
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def load_npz_from_dir(data_dir):
+ data = [
+ np.load(os.path.join(data_dir, data_name))["arr_0"]
+ for data_name in os.listdir(data_dir)
+ ]
+ data = np.concatenate(data, axis=0)
+ return data
+
+
+def load_npz_from_paths(data_paths):
+ data = [np.load(data_path)["arr_0"] for data_path in data_paths]
+ data = np.concatenate(data, axis=0)
+ return data
+
+
+def setup_dist(args):
+ if dist.is_initialized():
+ return
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group("nccl", init_method="env://")