diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..37052978977a6418dce6eea529a5353b380fd064 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo2.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fc6039f2c5a753a7ce0b264b019ed26692ec6973 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.pyc +*.pt +!dataset/*/data_stats.pth +dataset \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..3232ed665566ec047ce55a929db1581dbda266a1 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. 
+ +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..32bbaee9d2b29b9d8cd19bee046b93f685db47fd --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to audio2photoreal +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). 
+ +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: <https://code.facebook.com/cla> + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to audio2photoreal, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..320e3396e0f4867fc209f5c7b5cfe94a84f15dad --- /dev/null +++ b/LICENSE @@ -0,0 +1,400 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. 
The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. 
For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. 
Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. 
Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. 
If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. 
if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. 
The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. 
To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. 
+ diff --git a/README.md b/README.md index c6c4c199cc4e9e4b5f9627ee09ba60cdfabc79cb..0e3bfc46f4bbf3932c208d001d0fdbfc5257ae88 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,378 @@ --- -title: Test Virtual -emoji: 📈 -colorFrom: gray -colorTo: indigo +title: test_virtual +app_file: ./demo/demo.py sdk: gradio sdk_version: 4.38.1 -app_file: app.py -pinned: false --- +# From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations +This repository contains a pytorch implementation of ["From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations"](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal/) -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +:hatching_chick: **Try out our demo [here](https://colab.research.google.com/drive/1lnX3d-3T3LaO3nlN6R8s6pPvVNAk5mdK?usp=sharing)** or continue following the steps below to run code locally! +And thanks everyone for the support via contributions/comments/issues! 
+ +https://github.com/facebookresearch/audio2photoreal/assets/17986358/5cba4079-275e-48b6-aecc-f84f3108c810 + +This codebase provides: +- train code +- test code +- pretrained motion models +- access to dataset + +If you use the dataset or code, please cite our [Paper](https://arxiv.org/abs/2401.01885) + +``` +@inproceedings{ng2024audio2photoreal, + title={From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations}, + author={Ng, Evonne and Romero, Javier and Bagautdinov, Timur and Bai, Shaojie and Darrell, Trevor and Kanazawa, Angjoo and Richard, Alexander}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + year={2024} +} +``` + +### Repository Contents + +- [**Quickstart:**](#quickstart) easy gradio demo that lets you record audio and render a video +- [**Installation:**](#installation) environment setup and installation (for more details on the rendering pipeline, please refer to [Codec Avatar Body](https://github.com/facebookresearch/ca_body)) +- [**Download data and models:**](#download-data-and-models) download annotations and pre-trained models + - [Dataset desc.](#dataset): description of dataset annotations + - [Visualize Dataset](#visualize-ground-truth): script for visualizing ground truth annotations + - [model desc.](#pretrained-models): description of pretrained models +- [**Running the pretrained models:**](#running-the-pretrained-models) how to generate results files and visualize the results using the rendering pipeline. + - [Face generation](#face-generation): commands to generate the results file for the faces + - [Body generation](#body-generation): commands to generate the results file for the bodies + - [Visualization](#visualization): how to call into the rendering api. For full details, please refer to [this repo](https://github.com/facebookresearch/ca_body). 
+- [**Training from scratch (4 models):**](#training-from-scratch) scripts to get the training pipeline running from scratch for face, guide poses, and body models. + - [Face diffusion model](#1-face-diffusion-model) + - [Body diffusion](#2-body-diffusion-model) + - [Body vq vae](#3-body-vq-vae) + - [Body guide transformer](#4-body-guide-transformer) + +We annotate code that you can directly copy and paste into your terminal using the :point_down: icon. + +# Quickstart +With this demo, you can record an audio clip and select the number of samples you want to generate. + +Make sure you have CUDA 11.7 and gcc/g++ 9.0 for pytorch3d compatibility + +:point_down: Install necessary components. This will do the environment configuration and install the corresponding rendering assets, prerequisite models, and pretrained models: +``` +conda create --name a2p_env python=3.9 +conda activate a2p_env +sh demo/install.sh +``` +:point_down: Run the demo. You can record your audio and then render corresponding results! +``` +python -m demo.demo +``` + +:microphone: First, record your audio + +![](assets/demo1.gif) + +:hourglass: Hold tight because the rendering can take a while! + +You can change the number of samples (1-10) you want to generate, and download your favorite video by clicking on the download button on the top right of each video. + +![](assets/demo2.gif) + +# Installation +The code has been tested with CUDA 11.7 and python 3.9, gcc/g++ 9.0 + +:point_down: If you haven't done so already via the demo setup, configure the environments and download prerequisite models: +``` +conda create --name a2p_env python=3.9 +conda activate a2p_env +pip install -r scripts/requirements.txt +sh scripts/download_prereq.sh +``` +:point_down: To get the rendering working, please also make sure you install [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md). 
+``` +pip install "git+https://github.com/facebookresearch/pytorch3d.git" +``` +Please see [CA Bodies repo](https://github.com/facebookresearch/ca_body) for more details on the renderer. + +# Download data and models +To download any of the datasets, you can find them at `https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/<person_id>.zip`, where you can replace `<person_id>` with any of `PXB184`, `RLW104`, `TXB805`, or `GQS883`. +Download over the command line can be done with these commands. +``` +curl -L https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/<person_id>.zip -o <person_id>.zip +unzip <person_id>.zip -d dataset/ +rm <person_id>.zip +``` +:point_down: To download *all* of the datasets, you can simply run the following which will download and unpack all the datasets. +``` +sh scripts/download_alldatasets.sh +``` + +Similarly, to download any of the models, you can find them at `http://audio2photoreal_models.berkeleyvision.org/<person_id>_models.tar`. +``` +# download the motion generation +wget http://audio2photoreal_models.berkeleyvision.org/<person_id>_models.tar +tar xvf <person_id>_models.tar +rm <person_id>_models.tar + +# download the body decoder/rendering assets and place them in the right place +mkdir -p checkpoints/ca_body/data/ +wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/<person_id>.tar.gz +tar xvf <person_id>.tar.gz --directory checkpoints/ca_body/data/ +rm <person_id>.tar.gz +``` +:point_down: You can also download all of the models with this script: +``` +sh scripts/download_allmodels.sh +``` +The above model script will download *both* the models for motion generation and the body decoder/rendering models. Please view the script for more details. + +### Dataset +Once the dataset is downloaded and unzipped (via `scripts/download_alldatasets.sh`), it should unfold into the following directory structure: +``` +|-- dataset/ + |-- PXB184/ + |-- data_stats.pth + |-- scene01_audio.wav + |-- scene01_body_pose.npy + |-- scene01_face_expression.npy + |-- scene01_missing_face_frames.npy + |-- ... 
+ |-- scene30_audio.wav + |-- scene30_body_pose.npy + |-- scene30_face_expression.npy + |-- scene30_missing_face_frames.npy + |-- RLW104/ + |-- TXB805/ + |-- GQS883/ +``` +Each of the four participants (`PXB184`, `RLW104`, `TXB805`, `GQS883`) should have independent "scenes" (1 to 26 or so). +For each scene, there are 3 types of data annotations that we save. +``` +*audio.wav: wavefile containing the raw audio (two channels, 1600*T samples) at 48kHz; channel 0 is the audio associated with the current person, channel 1 is the audio associated with their conversational partner. + +*body_pose.npy: (T x 104) array of joint angles in a kinematic skeleton. Not all of the joints are represented with 3DoF. Each 104-d vector can be used to reconstruct a full-body skeleton. + +*face_expression.npy: (T x 256) array of facial codes, where each 256-d vector reconstructs a face mesh. + +*missing_face_frames.npy: List of indices (t) where the facial code is missing or corrupted. + +data_stats.pth: carries the mean and std for each modality of each person. +``` + +For the train/val/test split the indices are defined in `data_loaders/data.py` as: +``` +train_idx = list(range(0, len(data_dict["data"]) - 6)) +val_idx = list(range(len(data_dict["data"]) - 6, len(data_dict["data"]) - 4)) +test_idx = list(range(len(data_dict["data"]) - 4, len(data_dict["data"]))) +``` +for any of the four dataset participants we train on. + +### Visualize ground truth +If you've properly installed the rendering requirements, you can then visualize the full dataset with the following command: +``` +python -m visualize.render_anno + --save_dir <path/to/save/dir> + --data_root <path/to/data/root> + --max_seq_length <num> +``` + +The videos will be chunked into lengths according to the specified `--max_seq_length` arg, which you can specify (the default is 600). + +:point_down: For example, to visualize ground truth annotations for `PXB184`, you can run the following. 
+``` +python -m visualize.render_anno --save_dir vis_anno_test --data_root dataset/PXB184 --max_seq_length 600 +``` + +### Pretrained models +We train person-specific models, so each person should have an associated directory. For instance, for `PXB184`, their complete models should unzip into the following structure. +``` +|-- checkpoints/ + |-- diffusion/ + |-- c1_face/ + |-- args.json + |-- model:09d.pt + |-- c1_pose/ + |-- args.json + |-- model:09d.pt + |-- guide/ + |-- c1_pose/ + |-- args.json + |-- checkpoints/ + |-- iter-:07d.pt + |-- vq/ + |-- c1_pose/ + |-- args.json + |-- net_iter:06d.pth +``` +There are 4 models for each person and each model has an associated `args.json`. +1. a face diffusion model that outputs 256 facial codes conditioned on audio +2. a pose diffusion model that outputs 104 joint rotations conditioned on audio and guide poses +3. a guide vq pose model that outputs vq tokens conditioned on audio at 1 fps +4. a vq encoder-decoder model that vector quantizes the continuous 104-d pose space. + +# Running the pretrained models +To run the actual models, you will need to run the pretrained models and generate the associated results files before visualizing them. + +### Face generation +To generate the results file for the face, run: +``` +python -m sample.generate + --model_path <path/to/model> + --num_samples <xsamples> + --num_repetitions <xreps> + --timestep_respacing ddim500 + --guidance_param 10.0 +``` + +The `<path/to/model>` should be the path to the diffusion model that is associated with generating the face. +E.g. for participant `PXB184`, the path might be `./checkpoints/diffusion/c1_face/model000155000.pt` +The other parameters are: +``` +--num_samples: number of samples to generate. To sample the full dataset, use 56 (except for TXB805, which is 58). +--num_repetitions: number of times to repeat the sampling, such that total number of sequences generated is (num_samples * num_repetitions). +--timestep_respacing: how many diffusion steps to take. Format will always be ddim<number>. 
+--guidance_param: how influential the conditioning is on the results. I usually use range 2.0-10.0, and tend towards higher for the face. +``` + +:point_down: A full example of running the face model for `PXB184` with the provided pretrained models would then be: +``` +python -m sample.generate --model_path checkpoints/diffusion/c1_face/model000155000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 10.0 +``` +This generates 10 samples from the dataset 5 times, i.e. 50 sequences in total. The output results file will be saved to: +`./checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy` + +### Body generation +To generate the corresponding body, it will be very similar to generating the face, except now we have to feed in the model for generating the guide poses as well. +``` +python -m sample.generate + --model_path <path/to/model> + --resume_trans <path/to/guide/model> + --num_samples <xsamples> + --num_repetitions <xreps> + --timestep_respacing ddim500 + --guidance_param 2.0 +``` + +:point_down: Here, `<path/to/guide/model>` should point to the guide transformer. The full command would be: +``` +python -m sample.generate --model_path checkpoints/diffusion/c1_pose/model000340000.pt --resume_trans checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 2.0 +``` +Similarly, the output will be saved to: +`./checkpoints/diffusion/c1_pose/samples_c1_pose_000340000_seed10_guide_iter-0100000.pt/results.npy` + +### Visualization +On the body generation side of things, you can also optionally pass in the `--plot` flag in order to render out the photorealistic avatar. You will also need to pass in the corresponding generated face codes with the `--face_codes` flag. +Optionally, if you already have the poses precomputed, you can also pass in the generated body with the `--pose_codes` flag. +This will save videos in the same directory as where the body's `results.npy` is stored. 
+ +:point_down: An example of the full command with *the three new flags added is*: +``` +python -m sample.generate --model_path checkpoints/diffusion/c1_pose/model000340000.pt --resume_trans checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 2.0 --face_codes ./checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy --pose_codes ./checkpoints/diffusion/c1_pose/samples_c1_pose_000340000_seed10_guide_iter-0100000.pt/results.npy --plot +``` +The remaining flags can be the same as before. For the actual rendering api, please see [Codec Avatar Body](https://github.com/facebookresearch/ca_body) for installation etc. +*Important: in order to visualize the full photorealistic avatar, you will need to run the face codes first, then pass them into the body generation code.* It will not work if you try to call generate with `--plot` for the face codes. + +# Training from scratch +There are four possible models you will need to train: 1) the face diffusion model, 2) the body diffusion model, 3) the body vq vae, 4) the body guide transformer. +The only dependency is that 3) is needed for 4). All other models can be trained in parallel. 
+ +### 1) Face diffusion model +To train the face model, you will need to run the following script: +``` +python -m train.train_diffusion + --save_dir <path/to/save/dir> + --data_root <path/to/data/root> + --batch_size <bs_size> + --dataset social + --data_format face + --layers 8 + --heads 8 + --timestep_respacing '' + --max_seq_length 600 +``` +Importantly, a few of the flags are as follows: +``` +--save_dir: path to directory where all outputs are stored +--data_root: path to the directory of where to load the data from +--dataset: name of dataset to load; right now we only support the 'social' dataset +--data_format: set to 'face' for the face, as opposed to pose +--timestep_respacing: set to '' which does the default spacing of 1k diffusion steps +--max_seq_length: the maximum number of frames for a given sequence to train on +``` +:point_down: A full example for training on person `PXB184` is: +``` +python -m train.train_diffusion --save_dir checkpoints/diffusion/c1_face_test --data_root ./dataset/PXB184/ --batch_size 4 --dataset social --data_format face --layers 8 --heads 8 --timestep_respacing '' --max_seq_length 600 +``` + +### 2) Body diffusion model +Training the body model is similar to the face model, but with the following additional parameters: +``` +python -m train.train_diffusion + --save_dir <path/to/save/dir> + --data_root <path/to/data/root> + --lambda_vel <num> + --batch_size <bs_size> + --dataset social + --add_frame_cond 1 + --data_format pose + --layers 6 + --heads 8 + --timestep_respacing '' + --max_seq_length 600 +``` +The flags that differ from the face training are as follows: +``` +--lambda_vel: additional auxiliary loss for training with velocity +--add_frame_cond: set to '1' for 1 fps. If not specified, it will default to 30 fps. 
+--data_format: set to 'pose' for the body, as opposed to face +``` +:point_down: A full example for training on person `PXB184` is: +``` +python -m train.train_diffusion --save_dir checkpoints/diffusion/c1_pose_test --data_root ./dataset/PXB184/ --lambda_vel 2.0 --batch_size 4 --dataset social --add_frame_cond 1 --data_format pose --layers 6 --heads 8 --timestep_respacing '' --max_seq_length 600 +``` + +### 3) Body VQ VAE +To train a vq encoder-decoder, you will need to run the following script: +``` +python -m train.train_vq + --out_dir <path/to/out/dir> + --data_root <path/to/data/root> + --batch_size <bs_size> + --lr 1e-3 + --code_dim 1024 + --output_emb_width 64 + --depth 4 + --dataname social + --loss_vel 0.0 + --add_frame_cond 1 + --data_format pose + --max_seq_length 600 +``` +:point_down: For person `PXB184`, it would be: +``` +python -m train.train_vq --out_dir checkpoints/vq/c1_vq_test --data_root ./dataset/PXB184/ --lr 1e-3 --code_dim 1024 --output_emb_width 64 --depth 4 --dataname social --loss_vel 0.0 --data_format pose --batch_size 4 --add_frame_cond 1 --max_seq_length 600 +``` + +### 4) Body guide transformer +Once you have the vq trained from 3) you can then pass it in to train the body guide pose transformer: +``` +python -m train.train_guide + --out_dir <path/to/out/dir> + --data_root <path/to/data/root> + --batch_size <bs_size> + --resume_pth <path/to/vq/model> + --add_frame_cond 1 + --layers 6 + --lr 2e-4 + --gn + --dim 64 +``` +:point_down: For person `PXB184`, it would be: +``` +python -m train.train_guide --out_dir checkpoints/guide/c1_trans_test --data_root ./dataset/PXB184/ --batch_size 4 --resume_pth checkpoints/vq/c1_vq_test/net_iter300000.pth --add_frame_cond 1 --layers 6 --lr 2e-4 --gn --dim 64 +``` + +After training these 4 models, you can now follow the ["Running the pretrained models"](#running-the-pretrained-models) section to generate samples and visualize results. + +You can also visualize the corresponding ground truth sequences by passing in the `--render_gt` flag. 
+ + +# License +The code and dataset are released under [CC-NC 4.0 International license](https://github.com/facebookresearch/audio2photoreal/blob/main/LICENSE). diff --git a/assets/demo1.gif b/assets/demo1.gif new file mode 100644 index 0000000000000000000000000000000000000000..f7e52943590a488452809ba1b766485de1f8d9db Binary files /dev/null and b/assets/demo1.gif differ diff --git a/assets/demo2.gif b/assets/demo2.gif new file mode 100644 index 0000000000000000000000000000000000000000..9aef0cbf4e0d07faf55d3a27716bf41aef5243a8 --- /dev/null +++ b/assets/demo2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d07d3817b4a23bdb0a36869a469d051b9b10fe68d9e6f02f6cc8765cd6f5bc3 +size 1313657 diff --git a/assets/render_defaults_GQS883.pth b/assets/render_defaults_GQS883.pth new file mode 100644 index 0000000000000000000000000000000000000000..fbd96b9f605dc9821bd0be76db9dfd09bdba92ae --- /dev/null +++ b/assets/render_defaults_GQS883.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ae7ee73849e258bbb8d8a04aa674960896fc1dff8757fefbd2df1685225dd7d +size 71354547 diff --git a/assets/render_defaults_PXB184.pth b/assets/render_defaults_PXB184.pth new file mode 100644 index 0000000000000000000000000000000000000000..dccf18780e13e5ea8320ce902d5586568c5a0c90 --- /dev/null +++ b/assets/render_defaults_PXB184.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c86ba14a58d4829c8d05428f5e601072dc4bab1bdc60bc53ce6c73990e9b97d7 +size 71354547 diff --git a/assets/render_defaults_RLW104.pth b/assets/render_defaults_RLW104.pth new file mode 100644 index 0000000000000000000000000000000000000000..7a24fb9d4c6029bd2f524af2bedb43c7f45f5fcc --- /dev/null +++ b/assets/render_defaults_RLW104.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:808a3fbf33115d3cc132bad48c2e95bfca29bb1847d912b1f72e5e5b4a081db5 +size 71354547 diff --git a/assets/render_defaults_TXB805.pth 
b/assets/render_defaults_TXB805.pth new file mode 100644 index 0000000000000000000000000000000000000000..4462856fa68ed8d3b872728d7df818b0954908bd --- /dev/null +++ b/assets/render_defaults_TXB805.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7985c79edfba70f83f560859f2ce214d9779a46031aa8ca6a917d8fd4417e24 +size 71354547 diff --git a/checkpoints/ca_body/data/PXB184/body_dec.ckpt b/checkpoints/ca_body/data/PXB184/body_dec.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..b6cd20e6442e36a35c43bd4731c21bfe4b6aa035 --- /dev/null +++ b/checkpoints/ca_body/data/PXB184/body_dec.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26394ae03c1726b7c90b5633696d0eea733a3c5e423893c4e79b490c80c35ddf +size 893279810 diff --git a/checkpoints/ca_body/data/PXB184/config.yml b/checkpoints/ca_body/data/PXB184/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..a9b02cf0fa869cfb48877bb0de43485dec9e4f75 --- /dev/null +++ b/checkpoints/ca_body/data/PXB184/config.yml @@ -0,0 +1,56 @@ + +model: + class_name: ca_body.models.mesh_vae_drivable.AutoEncoder + + encoder: + n_embs: 1024 + noise_std: 1.0 + + encoder_face: + n_embs: 256 + noise_std: 1.0 + + decoder_face: + n_latent: 256 + n_vert_out: 21918 + + decoder: + init_uv_size: 64 + n_init_channels: 64 + n_min_channels: 4 + n_pose_dims: 98 + n_pose_enc_channels: 16 + n_embs: 1024 + n_embs_enc_channels: 32 + n_face_embs: 256 + uv_size: 1024 + + decoder_view: + net_uv_size: 1024 + + upscale_net: + n_ftrs: 4 + + shadow_net: + uv_size: 2048 + shadow_size: 256 + n_dims: 4 + + pose_to_shadow: + n_pose_dims: 104 + uv_size: 2048 + + renderer: + image_height: 2048 + image_width: 1334 + depth_disc_ksize: 3 + + cal: + identity_camera: '400143' + + pixel_cal: + image_height: 2048 + image_width: 1334 + ds_rate: 8 + + learn_blur: true \ No newline at end of file diff --git a/checkpoints/diffusion/c1_face/args.json 
b/checkpoints/diffusion/c1_face/args.json new file mode 100644 index 0000000000000000000000000000000000000000..8f0ec9abe4dcb51500b0dbd2142c305601516230 --- /dev/null +++ b/checkpoints/diffusion/c1_face/args.json @@ -0,0 +1,34 @@ +{ + "add_frame_cond": null, + "batch_size": 4, + "cond_mask_prob": 0.2, + "cuda": true, + "data_format": "face", + "data_root": "./dataset/PXB184/", + "dataset": "social", + "device": 0, + "diffusion_steps": 10, + "heads": 8, + "lambda_vel": 0.0, + "latent_dim": 512, + "layers": 8, + "log_interval": 1000, + "lr": 0.0001, + "lr_anneal_steps": 0, + "max_seq_length": 600, + "noise_schedule": "cosine", + "not_rotary": false, + "num_audio_layers": 3, + "num_steps": 800000, + "overwrite": false, + "resume_checkpoint": "", + "save_dir": "checkpoints/diffusion/c1_face/", + "save_interval": 5000, + "seed": 10, + "sigma_small": true, + "simplify_audio": false, + "timestep_respacing": "", + "train_platform_type": "NoPlatform", + "unconstrained": false, + "weight_decay": 0.0 +} \ No newline at end of file diff --git a/checkpoints/diffusion/c1_pose/args.json b/checkpoints/diffusion/c1_pose/args.json new file mode 100644 index 0000000000000000000000000000000000000000..1b0b1da9c84c010907087ac1996403236a09fa42 --- /dev/null +++ b/checkpoints/diffusion/c1_pose/args.json @@ -0,0 +1,66 @@ +{ + "add_frame_cond": 1.0, + "arch": "trans_enc", + "batch_size": 32, + "clip_body": false, + "clip_use_delta": false, + "clip_use_vae": false, + "cond_mask_prob": 0.1, + "cuda": true, + "data_format": "pose", + "data_root": "./dataset/PXB184/", + "dataset": "social", + "device": 0, + "diffusion_steps": 10, + "emb_trans_dec": false, + "eval_batch_size": 32, + "eval_during_training": false, + "eval_num_samples": 1000, + "eval_rep_times": 3, + "eval_split": "val", + "filter": false, + "heads": 8, + "lambda_fc": 0.0, + "lambda_hands": 0.0, + "lambda_lips": 0.0, + "lambda_rcxyz": 0.0, + "lambda_vel": 2.0, + "lambda_xyz": 0.0, + "lambda_xyz_vel": 0.0, + "latent_dim": 512, + 
"layers": 6, + "log_interval": 1000, + "lr": 0.0001, + "lr_anneal_steps": 0, + "max_seq_length": 600, + "no_split": false, + "noise_schedule": "cosine", + "not_rotary": false, + "num_frames": 60, + "num_steps": 800000, + "overwrite": false, + "partial": false, + "resume_checkpoint": "", + "save_dir": "checkpoints/diffusion/c1_pose/", + "save_interval": 5000, + "seed": 10, + "sigma_small": true, + "simplify_audio": false, + "split_net": false, + "timestep_respacing": "", + "train_platform_type": "NoPlatform", + "unconstrained": false, + "use_clip": false, + "use_cm": true, + "use_full_dataset": false, + "use_kp": false, + "use_mask": true, + "use_mdm": false, + "use_nort": false, + "use_nort_mdm": false, + "use_pose_pos": false, + "use_resnet": true, + "use_vae": null, + "weight_decay": 0.0, + "z_norm": true +} diff --git a/checkpoints/guide/c1_pose/args.json b/checkpoints/guide/c1_pose/args.json new file mode 100644 index 0000000000000000000000000000000000000000..a2055f15c70c28be4823c3fae66b2710165cf00e --- /dev/null +++ b/checkpoints/guide/c1_pose/args.json @@ -0,0 +1,41 @@ +{ + "add_audio_pe": true, + "add_conv": true, + "add_frame_cond": 1, + "batch_size": 16, + "data_format": "pose", + "data_root": "./dataset/PXB184/", + "dataset": "social", + "dec_layers": null, + "dim": 64, + "enc_layers": null, + "eval_interval": 1000, + "filter": false, + "gamma": 0.1, + "gn": true, + "layers": 6, + "log_interval": 1000, + "lr": 0.0001, + "lr_scheduler": [ + 50000, + 400000 + ], + "no_split": false, + "num_audio_layers":2, + "out_dir": "checkpoints/guide/c1_pose", + "partial": false, + "resume_pth": "checkpoints/vq/c1_pose/net_iter300000.pth", + "resume_trans": null, + "save_interval": 5000, + "seed": 10, + "simplify_audio": false, + "total_iter": 1000000, + "use_full_dataset": false, + "use_kp": false, + "use_lstm": false, + "use_nort": false, + "use_nort_mdm": false, + "use_torch": false, + "warm_up_iter": 5000, + "weight_decay": 0.1 +} diff --git 
a/checkpoints/vq/c1_pose/args.json b/checkpoints/vq/c1_pose/args.json new file mode 100644 index 0000000000000000000000000000000000000000..4880e5a255a4ada34a28b8c77bc810cd7d31b58c --- /dev/null +++ b/checkpoints/vq/c1_pose/args.json @@ -0,0 +1,43 @@ +{ + "add_frame_cond": 1.0, + "batch_size": 16, + "code_dim": 1024, + "commit": 0.02, + "data_format": "pose", + "data_root": "./dataset/PXB184/", + "dataname": "social", + "dataset": "social", + "depth": 4, + "eval_iter": 1000, + "exp_name": "c1_pose", + "filter": false, + "gamma": 0.05, + "loss_vel": 0.0, + "lr": 0.001, + "lr_scheduler": [ + 300000 + ], + "max_seq_length": 600, + "nb_joints": 104, + "no_split": true, + "out_dir": "checkpoints/vq/c1_pose", + "output_emb_width": 64, + "partial": false, + "print_iter": 200, + "results_dir": "visual_results/", + "resume_pth": null, + "seed": 123, + "simplify_audio": false, + "total_iter": 10000000, + "use_full_dataset": false, + "use_kp": false, + "use_linear": false, + "use_nort": false, + "use_nort_mdm": false, + "use_quant": true, + "use_vae": false, + "visual_name": "baseline", + "warm_up_iter": 1000, + "weight_decay": 0.0, + "z_norm": true +} \ No newline at end of file diff --git a/checkpoints/vq/c1_pose/net_iter300000.pth b/checkpoints/vq/c1_pose/net_iter300000.pth new file mode 100644 index 0000000000000000000000000000000000000000..a4afe7cb7448cc6ebbac6ef378d0b4ae92d4e923 --- /dev/null +++ b/checkpoints/vq/c1_pose/net_iter300000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5649ad5e49e0e1afcd9a7390f0ee79ee66de275a67ecb1cfe7fc691cb4ceb332 +size 3129275 diff --git a/data_loaders/data.py b/data_loaders/data.py new file mode 100644 index 0000000000000000000000000000000000000000..9312ccf7aef858e43a924954e3a82a94e260c2dc --- /dev/null +++ b/data_loaders/data.py @@ -0,0 +1,253 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. 
class Social(data.Dataset):
    """Per-participant dataset pairing a motion stream with conversational audio.

    ``data_dict`` supplies parallel per-scene sequences under the keys
    ``"data"``, ``"face"``, ``"audio"``, ``"lengths"`` and ``"missing"``.
    ``args.data_format`` selects which stream is served as motion: ``"face"``
    swaps the face codes into ``data_dict["data"]``; ``"pose"`` keeps the
    existing data and replaces the missing-frame mask with all ones.  Note
    that ``data_dict`` is modified in place.

    The last six scenes are held out (two for validation, then four for
    test).  The test split is cut into shuffled, fixed-length chunks of
    ``max_seq_length`` frames and indexed by chunk; the other splits are
    indexed by frame, and each item is a random crop of the scene that owns
    that frame unless ``chunk`` is set.
    """

    def __init__(
        self,
        args,
        data_dict: Dict[str, Iterable],
        split: str = "train",
        chunk: bool = False,
        add_padding: bool = True,
    ) -> None:
        """Select the motion stream, build the requested split and load stats.

        Args:
            args: namespace providing ``data_format``, ``add_frame_cond``,
                ``data_root``, ``max_seq_length`` and optionally
                ``curr_seq_length`` (which overrides ``max_seq_length``).
            data_dict: per-scene sequences; mutated in place (see class doc).
            split: ``"train"``, ``"val"``; any other value selects the test
                indices, but only the literal ``"test"`` triggers chunking.
            chunk: when true, ``__getitem__`` skips the random crop.
            add_padding: when true, non-test crops get a random length in
                ``[min_seq_length, max_seq_length)`` and are zero padded.
        """
        if args.data_format == "face":
            prGreen("[dataset.py] training face only model")
            data_dict["data"] = data_dict["face"]
        elif args.data_format == "pose":
            prGreen("[dataset.py] training pose only model")
            # For pose the mask is replaced by all ones, shaped like the data.
            data_dict["missing"] = [np.ones_like(d) for d in data_dict["data"]]

        # set up variables for dataloader
        self.data_format = args.data_format
        self.add_frame_cond = args.add_frame_cond
        self._register_keyframe_step()
        self.data_root = args.data_root
        self.max_seq_length = args.max_seq_length
        if hasattr(args, "curr_seq_length") and args.curr_seq_length is not None:
            self.max_seq_length = args.curr_seq_length
        prGreen([f"[dataset.py] sequences of {self.max_seq_length}"])
        self.add_padding = add_padding
        self.audio_per_frame = 1600
        self.max_audio_length = self.max_seq_length * self.audio_per_frame
        self.min_seq_length = 400

        # set up training/validation splits: the last 6 scenes are held out
        n_scenes = len(data_dict["data"])
        train_idx = list(range(0, n_scenes - 6))
        val_idx = list(range(n_scenes - 6, n_scenes - 4))
        test_idx = list(range(n_scenes - 4, n_scenes))
        self.split = split
        if split == "train":
            self._pick_sequences(data_dict, train_idx)
        elif split == "val":
            self._pick_sequences(data_dict, val_idx)
        else:
            self._pick_sequences(data_dict, test_idx)
        self.chunk = chunk
        if split == "test":
            print("[dataset.py] chunking data...")
            self._chunk_data()
        self._load_std()
        prGreen(
            f"[dataset.py] {split} | {len(self.data)} sequences ({self.data[0].shape}) | total len {self.total_len}"
        )

    def inv_transform(
        self, data: Union[np.ndarray, torch.Tensor], data_type: str
    ) -> Union[np.ndarray, torch.Tensor]:
        """Undo the z-normalization for ``data_type``.

        Args:
            data: normalized values, as an ndarray or a tensor.
            data_type: one of ``"pose"``, ``"face"`` or ``"audio"``.

        Returns:
            ``data * std + mean`` using the statistics of ``data_type``; for
            tensor input the statistics are moved to ``data.device`` first.

        Raises:
            AssertionError: if ``data_type`` is not one of the three names.
        """
        if data_type == "pose":
            std, mean = self.std, self.mean
        elif data_type == "face":
            std, mean = self.face_std, self.face_mean
        elif data_type == "audio":
            std, mean = self.audio_std, self.audio_mean
        else:
            # Raised explicitly (rather than ``assert False``) so the check
            # survives ``python -O`` and still reports the same exception type.
            raise AssertionError(f"datatype not defined: {data_type}")

        if torch.is_tensor(data):
            return data * torch.tensor(
                std, device=data.device, requires_grad=False
            ) + torch.tensor(mean, device=data.device, requires_grad=False)
        return data * std + mean

    def _pick_sequences(self, data_dict: Dict[str, Iterable], idx: List[int]) -> None:
        """Keep only the scenes listed in ``idx`` and count their total frames."""
        self.data = np.take(data_dict["data"], idx, axis=0)
        self.missing = np.take(data_dict["missing"], idx, axis=0)
        self.audio = np.take(data_dict["audio"], idx, axis=0)
        self.lengths = np.take(data_dict["lengths"], idx, axis=0)
        self.total_len = sum([len(d) for d in self.data])

    def _load_std(self) -> None:
        """Load mean/std statistics from ``<data_root>/data_stats.pth``."""
        stats_path = os.path.join(self.data_root, "data_stats.pth")
        stats = torch.load(stats_path)
        print(f'[dataset.py] loading from... {stats_path}')
        self.mean = stats["pose_mean"].reshape(-1)
        self.std = stats["pose_std"].reshape(-1)
        self.face_mean = stats["code_mean"]
        self.face_std = stats["code_std"]
        self.audio_mean = stats["audio_mean"]
        self.audio_std = stats["audio_std_flat"]

    def _chunk_data(self) -> None:
        """Cut every scene into ``max_seq_length`` windows and shuffle them.

        After this call ``total_len`` counts chunks rather than frames.
        """
        chunk_data = []
        chunk_missing = []
        chunk_lengths = []
        chunk_audio = []
        # create sequences of set lengths
        for d_idx in range(len(self.data)):
            curr_data = self.data[d_idx]
            curr_missing = self.missing[d_idx]
            curr_audio = self.audio[d_idx]
            # NOTE(review): the range end is exclusive, so a scene whose
            # length is an exact multiple of max_seq_length loses its final
            # full window, and the short-chunk guard below can never fire.
            # Left unchanged because published sample counts presumably
            # depend on the current chunking -- confirm before changing.
            end_range = len(self.data[d_idx]) - self.max_seq_length
            for chunk_idx in range(0, end_range, self.max_seq_length):
                chunk_end = chunk_idx + self.max_seq_length
                curr_data_chunk = curr_data[chunk_idx:chunk_end, :]
                curr_missing_chunk = curr_missing[chunk_idx:chunk_end, :]
                curr_audio_chunk = curr_audio[
                    chunk_idx * self.audio_per_frame : chunk_end * self.audio_per_frame,
                    :,
                ]
                if curr_data_chunk.shape[0] < self.max_seq_length:
                    # do not add a short chunk to the list
                    continue
                chunk_lengths.append(curr_data_chunk.shape[0])
                chunk_data.append(curr_data_chunk)
                chunk_missing.append(curr_missing_chunk)
                chunk_audio.append(curr_audio_chunk)
        idx = np.random.permutation(len(chunk_data))
        print("==> shuffle", idx)
        self.data = np.take(chunk_data, idx, axis=0)
        self.missing = np.take(chunk_missing, idx, axis=0)
        self.lengths = np.take(chunk_lengths, idx, axis=0)
        self.audio = np.take(chunk_audio, idx, axis=0)
        self.total_len = len(self.data)

    def _register_keyframe_step(self) -> None:
        """Set ``self.step``, the stride used to subsample keyframes.

        ``add_frame_cond == 1`` keeps every 30th frame; ``None`` keeps every
        frame.

        Raises:
            ValueError: for any other value.  Previously ``self.step`` was
                silently left unset and the first ``__getitem__`` failed with
                an ``AttributeError``.
        """
        if self.add_frame_cond == 1:
            self.step = 30
        elif self.add_frame_cond is None:
            self.step = 1
        else:
            raise ValueError(
                f"unsupported add_frame_cond: {self.add_frame_cond!r} (expected 1 or None)"
            )

    def _pad_sequence(
        self, sequence: np.ndarray, actual_length: int, max_length: int
    ) -> np.ndarray:
        """Append zero rows so ``sequence`` grows from ``actual_length`` to ``max_length`` rows."""
        padding = np.zeros((max_length - actual_length, sequence.shape[-1]))
        return np.concatenate((sequence, padding), axis=0)

    def _get_idx(self, item: int) -> int:
        """Map a global frame index to the index of the sequence holding it.

        Frame ``item`` belongs to the first sequence whose cumulative length
        exceeds it.  The previous ``item > cumulative_len`` comparison was off
        by one: index 0 returned ``-1`` and every frame on a sequence boundary
        was attributed to the preceding sequence.  Indices past the end are
        clamped to the last sequence.
        """
        cumulative_len = 0
        for seq_idx, sequence in enumerate(self.data):
            cumulative_len += len(sequence)
            if item < cumulative_len:
                return seq_idx
        return len(self.data) - 1

    def _get_random_subsection(
        self, data_dict: Dict[str, Iterable]
    ) -> Dict[str, np.ndarray]:
        """Crop a random window from one scene and pad it to the fixed lengths.

        ``data_dict`` is updated in place with the cropped (and, if shorter
        than ``max_seq_length``, zero padded) ``motion``, ``keyframes``,
        ``audio`` and ``missing`` entries and their new lengths, then returned.
        """
        isnonzero = False
        # Resample until the window's mask has at least one non-zero entry.
        while not isnonzero:
            start = np.random.randint(0, data_dict["m_length"] - self.max_seq_length)
            if self.add_padding:
                length = (
                    np.random.randint(self.min_seq_length, self.max_seq_length)
                    if not self.split == "test"
                    else self.max_seq_length
                )
            else:
                length = self.max_seq_length
            curr_missing = data_dict["missing"][start : start + length]
            isnonzero = np.any(curr_missing)
        missing = curr_missing
        motion = data_dict["motion"][start : start + length, :]
        keyframes = motion[:: self.step]
        audio = data_dict["audio"][
            start * self.audio_per_frame : (start + length) * self.audio_per_frame,
            :,
        ]
        data_dict["m_length"] = len(motion)
        data_dict["k_length"] = len(keyframes)
        data_dict["a_length"] = len(audio)

        if data_dict["m_length"] < self.max_seq_length:
            motion = self._pad_sequence(
                motion, data_dict["m_length"], self.max_seq_length
            )
            missing = self._pad_sequence(
                missing, data_dict["m_length"], self.max_seq_length
            )
            audio = self._pad_sequence(
                audio, data_dict["a_length"], self.max_audio_length
            )
            # number of keyframes a full-length sequence would produce
            max_step_length = len(range(0, self.max_seq_length, self.step))
            keyframes = self._pad_sequence(
                keyframes, data_dict["k_length"], max_step_length
            )
        data_dict["motion"] = motion
        data_dict["keyframes"] = keyframes
        data_dict["audio"] = audio
        data_dict["missing"] = missing
        return data_dict

    def __len__(self) -> int:
        """Return ``total_len``: frames for unchunked splits, chunks for test."""
        return self.total_len

    def __getitem__(self, item: int) -> Dict[str, np.ndarray]:
        """Return one normalized sample as a dict.

        Keys: ``motion``, ``m_length``, ``audio``, ``a_length``,
        ``keyframes``, ``k_length`` and ``missing``.
        """
        # figure out which sequence to randomly sample from
        if not self.split == "test":
            item = self._get_idx(item)
        motion = self.data[item]
        audio = self.audio[item]
        m_length = self.lengths[item]
        missing = self.missing[item]
        a_length = len(audio)
        # Z Normalization
        if self.data_format == "pose":
            motion = (motion - self.mean) / self.std
        elif self.data_format == "face":
            motion = (motion - self.face_mean) / self.face_std
        audio = (audio - self.audio_mean) / self.audio_std
        keyframes = motion[:: self.step]
        k_length = len(keyframes)
        data_dict = {
            "motion": motion,
            "m_length": m_length,
            "audio": audio,
            "a_length": a_length,
            "keyframes": keyframes,
            "k_length": k_length,
            "missing": missing,
        }
        if not self.split == "test" and not self.chunk:
            data_dict = self._get_random_subsection(data_dict)
        if self.data_format == "face":
            # multiply by the mask so entries with a zero mask become zero
            data_dict["motion"] *= data_dict["missing"]
        return data_dict
+""" + +import os + +from typing import Dict, List + +import numpy as np +import torch +import torchaudio +from data_loaders.data import Social +from data_loaders.tensors import social_collate +from torch.utils.data import DataLoader +from utils.misc import prGreen + + +def get_dataset_loader( + args, + data_dict: Dict[str, np.ndarray], + split: str = "train", + chunk: bool = False, + add_padding: bool = True, +) -> DataLoader: + dataset = Social( + args=args, + data_dict=data_dict, + split=split, + chunk=chunk, + add_padding=add_padding, + ) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=not split == "test", + num_workers=8, + drop_last=True, + collate_fn=social_collate, + pin_memory=True, + ) + return loader + + +def _load_pose_data( + all_paths: List[str], audio_per_frame: int, flip_person: bool = False +) -> Dict[str, List]: + data = [] + face = [] + audio = [] + lengths = [] + missing = [] + for _, curr_path_name in enumerate(all_paths): + if not curr_path_name.endswith("_body_pose.npy"): + continue + # load face information and deal with missing codes + curr_code = np.load( + curr_path_name.replace("_body_pose.npy", "_face_expression.npy") + ).astype(float) + # curr_code = np.array(curr_face["codes"], dtype=float) + missing_list = np.load( + curr_path_name.replace("_body_pose.npy", "_missing_face_frames.npy") + ) + if len(missing_list) == len(curr_code): + print("skipping", curr_path_name, curr_code.shape) + continue + curr_missing = np.ones_like(curr_code) + curr_missing[missing_list] = 0.0 + + # load pose information and deal with discontinuities + curr_pose = np.load(curr_path_name) + if "PXB184" in curr_path_name or "RLW104" in curr_path_name: # Capture 1 or 2 + curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi) + curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi) + + # load audio information + curr_audio, _ = torchaudio.load( + curr_path_name.replace("_body_pose.npy", "_audio.wav") + ) + curr_audio = curr_audio.T 
+ if flip_person: + prGreen("[get_data.py] flipping the dataset of left right person") + tmp = torch.zeros_like(curr_audio) + tmp[:, 1] = curr_audio[:, 0] + tmp[:, 0] = curr_audio[:, 1] + curr_audio = tmp + + assert len(curr_pose) * audio_per_frame == len( + curr_audio + ), f"motion {curr_pose.shape} vs audio {curr_audio.shape}" + + data.append(curr_pose) + face.append(curr_code) + missing.append(curr_missing) + audio.append(curr_audio) + lengths.append(len(curr_pose)) + + data_dict = { + "data": data, + "face": face, + "audio": audio, + "lengths": lengths, + "missing": missing, + } + return data_dict + + +def load_local_data( + data_root: str, audio_per_frame: int, flip_person: bool = False +) -> Dict[str, List]: + if flip_person: + if "PXB184" in data_root: + data_root = data_root.replace("PXB184", "RLW104") + elif "RLW104" in data_root: + data_root = data_root.replace("RLW104", "PXB184") + elif "TXB805" in data_root: + data_root = data_root.replace("TXB805", "GQS883") + elif "GQS883" in data_root: + data_root = data_root.replace("GQS883", "TXB805") + + all_paths = [os.path.join(data_root, x) for x in os.listdir(data_root)] + all_paths.sort() + return _load_pose_data( + all_paths, + audio_per_frame, + flip_person=flip_person, + ) diff --git a/data_loaders/tensors.py b/data_loaders/tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..a00c495d0bbc1994c7e7ee266c4dc08c94f20d3d --- /dev/null +++ b/data_loaders/tensors.py @@ -0,0 +1,86 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
def lengths_to_mask(lengths, max_len):
    """Return a ``(len(lengths), max_len)`` bool mask that is true below each length."""
    positions = torch.arange(max_len, device=lengths.device)
    return positions.expand(len(lengths), max_len) < lengths[:, None]


def collate_tensors(batch):
    """Stack same-rank tensors into one zero-padded tensor.

    The result has shape ``(len(batch), *max_sizes)`` where ``max_sizes`` is
    the per-dimension maximum over the batch; each tensor occupies the
    leading corner of its slot.
    """
    n_dims = batch[0].dim()
    padded_shape = tuple(
        max(tensor.size(axis) for tensor in batch) for axis in range(n_dims)
    )
    canvas = batch[0].new_zeros(size=(len(batch),) + padded_shape)
    for slot, tensor in enumerate(batch):
        corner = canvas[slot]
        for axis, extent in enumerate(tensor.shape):
            corner = corner.narrow(axis, 0, extent)
        # ``corner`` is a view into ``canvas``, so this writes in place.
        corner.add_(tensor)
    return canvas


## social collate
def collate_v2(batch):
    """Collate adapted samples into ``(motion, cond)``.

    ``None`` samples are dropped.  ``motion`` is the padded ``"inp"`` batch
    and ``cond["y"]`` carries the padded mask, audio, keyframes and
    missing-frame tensors together with their length tensors.
    """
    samples = [sample for sample in batch if sample is not None]

    def padded(key):
        return collate_tensors([sample[key] for sample in samples])

    def as_lengths(key):
        return torch.as_tensor([sample[key] for sample in samples])

    motion = padded("inp")
    missing = padded("missing")
    audio = padded("audio")
    motion_lengths = as_lengths("lengths")
    audio_lengths = as_lengths("audio_lengths")
    keyframes = padded("keyframes")
    keyframe_lengths = as_lengths("key_lengths")

    # two singleton axes are inserted after the batch axis for broadcasting
    mask = lengths_to_mask(motion_lengths, motion.shape[-1])[:, None, None]
    cond = {
        "y": {
            "missing": missing,
            "mask": mask,
            "lengths": motion_lengths,
            "audio": audio,
            "alengths": audio_lengths,
            "keyframes": keyframes,
            "klengths": keyframe_lengths,
        }
    }
    return motion, cond


def social_collate(batch):
    """Adapt dataset samples to the ``collate_v2`` layout and collate them.

    Motion is transposed and given a singleton axis at position 1.  Audio
    that is already a tensor is passed through untouched; every other array
    is converted to a ``float32`` tensor.
    """

    def to_float32(array):
        return torch.tensor(array).to(torch.float32)

    adapted_batch = []
    for sample in batch:
        audio = sample["audio"]
        if not torch.is_tensor(audio):
            audio = to_float32(audio)
        adapted_batch.append(
            {
                "inp": to_float32(sample["motion"].T).unsqueeze(1),
                "lengths": sample["m_length"],
                "audio": audio,
                "keyframes": to_float32(sample["keyframes"]),
                "key_lengths": sample["k_length"],
                "audio_lengths": sample["a_length"],
                "missing": to_float32(sample["missing"]),
            }
        )
    return collate_v2(adapted_batch)
f"./checkpoints/ca_body/data/PXB184" + self.body_renderer = BodyRenderer( + config_base=config_base, + render_rgb=True, + ) + + def _setup_model( + self, + args_path: str, + model_path: str, + ) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion): + with open(args_path) as f: + args = json.load(f) + args = AttrDict(args) + args.device = "cuda:0" if torch.cuda.is_available() else "cpu" + print("running on...", args.device) + args.model_path = model_path + args.output_dir = "/tmp/gradio/" + args.timestep_respacing = "ddim100" + if args.data_format == "pose": + args.resume_trans = "checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt" + + ## create model + model, diffusion = create_model_and_diffusion(args, split_type="test") + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location=args.device) + load_model(model, state_dict) + model = ClassifierFreeSampleModel(model) + model.eval() + model.to(args.device) + return model, diffusion, args.device + + def _replace_keyframes( + self, + model_kwargs: Dict[str, Dict[str, torch.Tensor]], + B: int, + T: int, + top_p: float = 0.97, + ) -> torch.Tensor: + with torch.no_grad(): + tokens = self.pose_model.transformer.generate( + model_kwargs["y"]["audio"], + T, + layers=self.pose_model.tokenizer.residual_depth, + n_sequences=B, + top_p=top_p, + ) + tokens = tokens.reshape((B, -1, self.pose_model.tokenizer.residual_depth)) + pred = self.pose_model.tokenizer.decode(tokens).detach() + return pred + + def _run_single_diffusion( + self, + model_kwargs: Dict[str, Dict[str, torch.Tensor]], + diffusion: SpacedDiffusion, + model: Union[FiLMTransformer, ClassifierFreeSampleModel], + curr_seq_length: int, + num_repetitions: int = 1, + ) -> (torch.Tensor,): + sample_fn = diffusion.ddim_sample_loop + with torch.no_grad(): + sample = sample_fn( + model, + (num_repetitions, model.nfeats, 1, curr_seq_length), + clip_denoised=False, + model_kwargs=model_kwargs, + 
init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False, + ) + return sample + + def generate_sequences( + self, + model_kwargs: Dict[str, Dict[str, torch.Tensor]], + data_format: str, + curr_seq_length: int, + num_repetitions: int = 5, + guidance_param: float = 10.0, + top_p: float = 0.97, + # batch_size: int = 1, + ) -> Dict[str, np.ndarray]: + if data_format == "pose": + model = self.pose_model + diffusion = self.pose_diffusion + else: + model = self.face_model + diffusion = self.face_diffusion + + all_motions = [] + model_kwargs["y"]["scale"] = torch.ones(num_repetitions) * guidance_param + model_kwargs["y"] = { + key: val.to(self.device) if torch.is_tensor(val) else val + for key, val in model_kwargs["y"].items() + } + if data_format == "pose": + model_kwargs["y"]["mask"] = ( + torch.ones((num_repetitions, 1, 1, curr_seq_length)) + .to(self.device) + .bool() + ) + model_kwargs["y"]["keyframes"] = self._replace_keyframes( + model_kwargs, + num_repetitions, + int(curr_seq_length / 30), + top_p=top_p, + ) + sample = self._run_single_diffusion( + model_kwargs, diffusion, model, curr_seq_length, num_repetitions + ) + all_motions.append(sample.cpu().numpy()) + print(f"created {len(all_motions) * num_repetitions} samples") + return np.concatenate(all_motions, axis=0) + + +def generate_results(audio: np.ndarray, num_repetitions: int, top_p: float): + if audio is None: + raise gr.Error("Please record audio to start") + sr, y = audio + # set to mono and perform resampling + y = torch.Tensor(y) + if y.dim() == 2: + dim = 0 if y.shape[0] == 2 else 1 + y = torch.mean(y, dim=dim) + y = torchaudio.functional.resample(torch.Tensor(y), orig_freq=sr, new_freq=48_000) + sr = 48_000 + # make it so that it is 4 seconds long + if len(y) < (sr * 4): + raise gr.Error("Please record at least 4 second of audio") + if num_repetitions is None or num_repetitions <= 0 or num_repetitions > 10: + raise gr.Error( + f"Invalid number of samples: {num_repetitions}. 
Please specify a number between 1-10" + ) + cutoff = int(len(y) / (sr * 4)) + y = y[: cutoff * sr * 4] + curr_seq_length = int(len(y) / sr) * 30 + # create model_kwargs + model_kwargs = {"y": {}} + dual_audio = np.random.normal(0, 0.001, (1, len(y), 2)) + dual_audio[:, :, 0] = y / max(y) + dual_audio = (dual_audio - gradio_model.stats["audio_mean"]) / gradio_model.stats[ + "audio_std_flat" + ] + model_kwargs["y"]["audio"] = ( + torch.Tensor(dual_audio).float().tile(num_repetitions, 1, 1) + ) + face_results = ( + gradio_model.generate_sequences( + model_kwargs, "face", curr_seq_length, num_repetitions=int(num_repetitions) + ) + .squeeze(2) + .transpose(0, 2, 1) + ) + face_results = ( + face_results * gradio_model.stats["code_std"] + gradio_model.stats["code_mean"] + ) + pose_results = ( + gradio_model.generate_sequences( + model_kwargs, + "pose", + curr_seq_length, + num_repetitions=int(num_repetitions), + guidance_param=2.0, + top_p=top_p, + ) + .squeeze(2) + .transpose(0, 2, 1) + ) + pose_results = ( + pose_results * gradio_model.stats["pose_std"] + gradio_model.stats["pose_mean"] + ) + dual_audio = ( + dual_audio * gradio_model.stats["audio_std_flat"] + + gradio_model.stats["audio_mean"] + ) + return face_results, pose_results, dual_audio[0].transpose(1, 0).astype(np.float32) + + +def audio_to_avatar(audio: np.ndarray, num_repetitions: int, top_p: float): + face_results, pose_results, audio = generate_results(audio, num_repetitions, top_p) + # returns: num_rep x T x 104 + B = len(face_results) + results = [] + for i in range(B): + render_data_block = { + "audio": audio, # 2 x T + "body_motion": pose_results[i, ...], # T x 104 + "face_motion": face_results[i, ...], # T x 256 + } + gradio_model.body_renderer.render_full_video( + render_data_block, f"/tmp/sample{i}", audio_sr=48_000 + ) + results += [gr.Video(value=f"/tmp/sample{i}_pred.mp4", visible=True)] + results += [gr.Video(visible=False) for _ in range(B, 10)] + return results + + +gradio_model = GradioModel( 
+ face_args="./checkpoints/diffusion/c1_face/args.json", + pose_args="./checkpoints/diffusion/c1_pose/args.json", +) +demo = gr.Interface( + audio_to_avatar, # function + [ + gr.Audio(sources=["microphone"]), + gr.Number( + value=3, + label="Number of Samples (default = 3)", + precision=0, + minimum=1, + maximum=10, + ), + gr.Number( + value=0.97, + label="Sample Diversity (default = 0.97)", + precision=None, + minimum=0.01, + step=0.01, + maximum=1.00, + ), + ], # input type + [gr.Video(format="mp4", visible=True)] + + [gr.Video(format="mp4", visible=False) for _ in range(9)], # output type + title='"From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations" Demo', + description="You can generate a photorealistic avatar from your voice!
\ + 1) Start by recording your audio.
\ + 2) Specify the number of samples to generate.
\ + 3) Specify how diverse you want the samples to be. This tunes the cumulative probability in nucleus sampling: 0.01 = low diversity, 1.0 = high diversity.
\ + 4) Then, sit back and wait for the rendering to happen! This may take a while (e.g. 30 minutes)
\ + 5) After, you can view the videos and download the ones you like.
", + article="Relevant links: [Project Page](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal)", # TODO: code and arxiv +) + +if __name__ == "__main__": + fixseed(10) + demo.launch(share=True) diff --git a/demo/demo.py b/demo/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..f26f536b3f018f67043a544e7380ae4bf32d29ee --- /dev/null +++ b/demo/demo.py @@ -0,0 +1,276 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +import json +from typing import Dict, Union + +import gradio as gr +import numpy as np +import torch +import torchaudio +from attrdict import AttrDict +from diffusion.respace import SpacedDiffusion +from model.cfg_sampler import ClassifierFreeSampleModel +from model.diffusion import FiLMTransformer +from utils.misc import fixseed +from utils.model_util import create_model_and_diffusion, load_model +from visualize.render_codes import BodyRenderer + + +class GradioModel: + def __init__(self, face_args, pose_args) -> None: + self.face_model, self.face_diffusion, self.device = self._setup_model( + face_args, "checkpoints/diffusion/c1_face/model000155000.pt" + ) + self.pose_model, self.pose_diffusion, _ = self._setup_model( + pose_args, "checkpoints/diffusion/c1_pose/model000340000.pt" + ) + # load standardization stuff + stats = torch.load("dataset/PXB184/data_stats.pth") + stats["pose_mean"] = stats["pose_mean"].reshape(-1) + stats["pose_std"] = stats["pose_std"].reshape(-1) + self.stats = stats + # set up renderer + config_base = f"./checkpoints/ca_body/data/PXB184" + self.body_renderer = BodyRenderer( + config_base=config_base, + render_rgb=True, + ) + + def _setup_model( + self, + args_path: str, + model_path: str, + ) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion): + with open(args_path) as f: + args = json.load(f) + 
args = AttrDict(args) + args.device = "cuda:0" if torch.cuda.is_available() else "cpu" + print("running on...", args.device) + args.model_path = model_path + args.output_dir = "/tmp/gradio/" + args.timestep_respacing = "ddim100" + if args.data_format == "pose": + args.resume_trans = "checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt" + + ## create model + model, diffusion = create_model_and_diffusion(args, split_type="test") + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location=args.device) + load_model(model, state_dict) + model = ClassifierFreeSampleModel(model) + model.eval() + model.to(args.device) + return model, diffusion, args.device + + def _replace_keyframes( + self, + model_kwargs: Dict[str, Dict[str, torch.Tensor]], + B: int, + T: int, + top_p: float = 0.97, + ) -> torch.Tensor: + with torch.no_grad(): + tokens = self.pose_model.transformer.generate( + model_kwargs["y"]["audio"], + T, + layers=self.pose_model.tokenizer.residual_depth, + n_sequences=B, + top_p=top_p, + ) + tokens = tokens.reshape((B, -1, self.pose_model.tokenizer.residual_depth)) + pred = self.pose_model.tokenizer.decode(tokens).detach() + return pred + + def _run_single_diffusion( + self, + model_kwargs: Dict[str, Dict[str, torch.Tensor]], + diffusion: SpacedDiffusion, + model: Union[FiLMTransformer, ClassifierFreeSampleModel], + curr_seq_length: int, + num_repetitions: int = 1, + ) -> (torch.Tensor,): + sample_fn = diffusion.ddim_sample_loop + with torch.no_grad(): + sample = sample_fn( + model, + (num_repetitions, model.nfeats, 1, curr_seq_length), + clip_denoised=False, + model_kwargs=model_kwargs, + init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False, + ) + return sample + + def generate_sequences( + self, + model_kwargs: Dict[str, Dict[str, torch.Tensor]], + data_format: str, + curr_seq_length: int, + num_repetitions: int = 5, + guidance_param: float = 10.0, + top_p: float = 0.97, + 
# batch_size: int = 1, + ) -> Dict[str, np.ndarray]: + if data_format == "pose": + model = self.pose_model + diffusion = self.pose_diffusion + else: + model = self.face_model + diffusion = self.face_diffusion + + all_motions = [] + model_kwargs["y"]["scale"] = torch.ones(num_repetitions) * guidance_param + model_kwargs["y"] = { + key: val.to(self.device) if torch.is_tensor(val) else val + for key, val in model_kwargs["y"].items() + } + if data_format == "pose": + model_kwargs["y"]["mask"] = ( + torch.ones((num_repetitions, 1, 1, curr_seq_length)) + .to(self.device) + .bool() + ) + model_kwargs["y"]["keyframes"] = self._replace_keyframes( + model_kwargs, + num_repetitions, + int(curr_seq_length / 30), + top_p=top_p, + ) + sample = self._run_single_diffusion( + model_kwargs, diffusion, model, curr_seq_length, num_repetitions + ) + all_motions.append(sample.cpu().numpy()) + print(f"created {len(all_motions) * num_repetitions} samples") + return np.concatenate(all_motions, axis=0) + + +def generate_results(audio: np.ndarray, num_repetitions: int, top_p: float): + if audio is None: + raise gr.Error("Please record audio to start") + sr, y = audio + # set to mono and perform resampling + y = torch.Tensor(y) + if y.dim() == 2: + dim = 0 if y.shape[0] == 2 else 1 + y = torch.mean(y, dim=dim) + y = torchaudio.functional.resample(torch.Tensor(y), orig_freq=sr, new_freq=48_000) + sr = 48_000 + # make it so that it is 4 seconds long + if len(y) < (sr * 4): + raise gr.Error("Please record at least 4 second of audio") + if num_repetitions is None or num_repetitions <= 0 or num_repetitions > 10: + raise gr.Error( + f"Invalid number of samples: {num_repetitions}. 
Please specify a number between 1-10" + ) + cutoff = int(len(y) / (sr * 4)) + y = y[: cutoff * sr * 4] + curr_seq_length = int(len(y) / sr) * 30 + # create model_kwargs + model_kwargs = {"y": {}} + dual_audio = np.random.normal(0, 0.001, (1, len(y), 2)) + dual_audio[:, :, 0] = y / max(y) + dual_audio = (dual_audio - gradio_model.stats["audio_mean"]) / gradio_model.stats[ + "audio_std_flat" + ] + model_kwargs["y"]["audio"] = ( + torch.Tensor(dual_audio).float().tile(num_repetitions, 1, 1) + ) + face_results = ( + gradio_model.generate_sequences( + model_kwargs, "face", curr_seq_length, num_repetitions=int(num_repetitions) + ) + .squeeze(2) + .transpose(0, 2, 1) + ) + face_results = ( + face_results * gradio_model.stats["code_std"] + gradio_model.stats["code_mean"] + ) + pose_results = ( + gradio_model.generate_sequences( + model_kwargs, + "pose", + curr_seq_length, + num_repetitions=int(num_repetitions), + guidance_param=2.0, + top_p=top_p, + ) + .squeeze(2) + .transpose(0, 2, 1) + ) + pose_results = ( + pose_results * gradio_model.stats["pose_std"] + gradio_model.stats["pose_mean"] + ) + dual_audio = ( + dual_audio * gradio_model.stats["audio_std_flat"] + + gradio_model.stats["audio_mean"] + ) + return face_results, pose_results, dual_audio[0].transpose(1, 0).astype(np.float32) + + +def audio_to_avatar(audio: np.ndarray, num_repetitions: int, top_p: float): + face_results, pose_results, audio = generate_results(audio, num_repetitions, top_p) + # returns: num_rep x T x 104 + B = len(face_results) + results = [] + for i in range(B): + render_data_block = { + "audio": audio, # 2 x T + "body_motion": pose_results[i, ...], # T x 104 + "face_motion": face_results[i, ...], # T x 256 + } + gradio_model.body_renderer.render_full_video( + render_data_block, f"/tmp/sample{i}", audio_sr=48_000 + ) + results += [gr.Video(value=f"/tmp/sample{i}_pred.mp4", visible=True)] + results += [gr.Video(visible=False) for _ in range(B, 10)] + return results + + +gradio_model = GradioModel( 
+ face_args="./checkpoints/diffusion/c1_face/args.json", + pose_args="./checkpoints/diffusion/c1_pose/args.json", +) +demo = gr.Interface( + audio_to_avatar, # function + [ + gr.Audio(sources=["microphone"]), + gr.Number( + value=3, + label="Number of Samples (default = 3)", + precision=0, + minimum=1, + maximum=10, + ), + gr.Number( + value=0.97, + label="Sample Diversity (default = 0.97)", + precision=None, + minimum=0.01, + step=0.01, + maximum=1.00, + ), + ], # input type + [gr.Video(format="mp4", visible=True)] + + [gr.Video(format="mp4", visible=False) for _ in range(9)], # output type + title='"From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations" Demo', + description="You can generate a photorealistic avatar from your voice!
\ + 1) Start by recording your audio.
\ + 2) Specify the number of samples to generate.
\ + 3) Specify how diverse you want the samples to be. This tunes the cumulative probability in nucleus sampling: 0.01 = low diversity, 1.0 = high diversity.
\ + 4) Then, sit back and wait for the rendering to happen! This may take a while (e.g. 30 minutes)
\ + 5) After, you can view the videos and download the ones you like.
", + article="Relevant links: [Project Page](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal)", # TODO: code and arxiv +) + +if __name__ == "__main__": + fixseed(10) + demo.launch(share=True) diff --git a/demo/install.sh b/demo/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..824249c7728e424603ad3a28b1230c532dafcd73 --- /dev/null +++ b/demo/install.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# make sure to have cuda 11.7 and gcc 9.0 installed +# install environment +pip install -r scripts/requirements.txt +sh scripts/download_prereq.sh + +# download pytorch3d +pip install "git+https://github.com/facebookresearch/pytorch3d.git" + +# download model stuff +wget http://audio2photoreal_models.berkeleyvision.org/PXB184_models.tar || { echo 'downloading model failed' ; exit 1; } +tar xvf PXB184_models.tar +rm PXB184_models.tar + +# install rendering stuff +mkdir -p checkpoints/ca_body/data/ +wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/PXB184.tar.gz || { echo 'downloading ca body model failed' ; exit 1; } +tar xvf PXB184.tar.gz --directory checkpoints/ca_body/data/ +rm PXB184.tar.gz \ No newline at end of file diff --git a/demo/requirements.txt b/demo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1c5759ff4188f6180a82d72feaf596a9b57f0375 --- /dev/null +++ b/demo/requirements.txt @@ -0,0 +1,17 @@ +attrdict +einops==0.7.0 +fairseq==0.12.2 +gradio==4.31.3 +gradio_client==0.7.3 +huggingface-hub==0.19.4 +hydra-core==1.0.7 +mediapy==1.2.0 +numpy==1.26.2 +omegaconf==2.0.6 +opencv-python==4.8.1.78 +protobuf==4.25.1 +tensorboardX==2.6.2.2 +torch==2.0.1 +torchaudio==2.0.2 +torchvision==0.15.2 +tqdm==4.66.3 diff --git a/diffusion/fp16_util.py b/diffusion/fp16_util.py new file mode 100644 index 0000000000000000000000000000000000000000..54556e3bf91cbd69e954075c73d8862a390092cb --- /dev/null +++ b/diffusion/fp16_util.py @@ -0,0 +1,250 @@ +""" +Copyright (c) Meta 
Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +""" +original code from +https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py +under an MIT license +https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE +""" + +""" +Helpers to train with 16-bit precision. +""" + +import numpy as np +import torch as th +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from utils import logger + +INITIAL_LOG_LOSS_SCALE = 20.0 + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def make_master_params(param_groups_and_shapes): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = [] + for param_group, shape in param_groups_and_shapes: + master_param = nn.Parameter( + _flatten_dense_tensors( + [param.detach().float() for (_, param) in param_group] + ).view(shape) + ) + master_param.requires_grad = True + master_params.append(master_param) + return master_params + + +def model_grads_to_master_grads(param_groups_and_shapes, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). 
+ """ + for master_param, (param_group, shape) in zip( + master_params, param_groups_and_shapes + ): + master_param.grad = _flatten_dense_tensors( + [param_grad_or_zeros(param) for (_, param) in param_group] + ).view(shape) + + +def master_params_to_model_params(param_groups_and_shapes, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. + for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): + for (_, param), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + param.detach().copy_(unflat_master_param) + + +def unflatten_master_params(param_group, master_param): + return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) + + +def get_param_groups_and_shapes(named_model_params): + named_model_params = list(named_model_params) + scalar_vector_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim <= 1], + (-1), + ) + matrix_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim > 1], + (1, -1), + ) + return [scalar_vector_named_params, matrix_named_params] + + +def master_params_to_state_dict( + model, param_groups_and_shapes, master_params, use_fp16 +): + if use_fp16: + state_dict = model.state_dict() + for master_param, (param_group, _) in zip( + master_params, param_groups_and_shapes + ): + for (name, _), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + assert name in state_dict + state_dict[name] = unflat_master_param + else: + state_dict = model.state_dict() + for i, (name, _value) in enumerate(model.named_parameters()): + assert name in state_dict + state_dict[name] = master_params[i] + return state_dict + + +def state_dict_to_master_params(model, state_dict, use_fp16): + if use_fp16: + named_model_params = [ + 
(name, state_dict[name]) for name, _ in model.named_parameters() + ] + param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) + master_params = make_master_params(param_groups_and_shapes) + else: + master_params = [state_dict[name] for name, _ in model.named_parameters()] + return master_params + + +def zero_master_grads(master_params): + for param in master_params: + param.grad = None + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() + + +def param_grad_or_zeros(param): + if param.grad is not None: + return param.grad.data.detach() + else: + return th.zeros_like(param) + + +class MixedPrecisionTrainer: + def __init__( + self, + *, + model, + use_fp16=False, + fp16_scale_growth=1e-3, + initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, + ): + self.model = model + self.use_fp16 = use_fp16 + self.fp16_scale_growth = fp16_scale_growth + + self.model_params = list(self.model.parameters()) + self.master_params = self.model_params + self.param_groups_and_shapes = None + self.lg_loss_scale = initial_lg_loss_scale + + if self.use_fp16: + self.param_groups_and_shapes = get_param_groups_and_shapes( + self.model.named_parameters() + ) + self.master_params = make_master_params(self.param_groups_and_shapes) + self.model.convert_to_fp16() + + def zero_grad(self): + zero_grad(self.model_params) + + def backward(self, loss: th.Tensor): + if self.use_fp16: + loss_scale = 2**self.lg_loss_scale + (loss * loss_scale).backward() + else: + loss.backward() + + def optimize(self, opt: th.optim.Optimizer): + if self.use_fp16: + return self._optimize_fp16(opt) + else: + return self._optimize_normal(opt) + + def _optimize_fp16(self, opt: th.optim.Optimizer): + logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) + model_grads_to_master_grads(self.param_groups_and_shapes, 
self.master_params) + grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale) + if check_overflow(grad_norm): + self.lg_loss_scale -= 1 + logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") + zero_master_grads(self.master_params) + return False + + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + + self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale)) + opt.step() + zero_master_grads(self.master_params) + master_params_to_model_params(self.param_groups_and_shapes, self.master_params) + self.lg_loss_scale += self.fp16_scale_growth + return True + + def _optimize_normal(self, opt: th.optim.Optimizer): + grad_norm, param_norm = self._compute_norms() + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + opt.step() + return True + + def _compute_norms(self, grad_scale=1.0): + grad_norm = 0.0 + param_norm = 0.0 + for p in self.master_params: + with th.no_grad(): + param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 + if p.grad is not None: + grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + def master_params_to_state_dict(self, master_params): + return master_params_to_state_dict( + self.model, self.param_groups_and_shapes, master_params, self.use_fp16 + ) + + def state_dict_to_master_params(self, state_dict): + return state_dict_to_master_params(self.model, state_dict, self.use_fp16) + + +def check_overflow(value): + return (value == float("inf")) or (value == -float("inf")) or (value != value) diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..acda506b27d437ba9ade74eb443882815c2629a4 --- /dev/null +++ b/diffusion/gaussian_diffusion.py @@ -0,0 +1,1273 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. 
+This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +""" +original code from +https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py +under an MIT license +https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE +""" + +import enum +import math +from copy import deepcopy + +import numpy as np +import torch +import torch as th +from diffusion.losses import discretized_gaussian_log_likelihood, normal_kl +from diffusion.nn import mean_flat, sum_flat + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.0): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = scale_betas * 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. 
def __init__(
    self,
    *,
    betas,
    model_mean_type,
    model_var_type,
    loss_type,
    rescale_timesteps=False,
    lambda_vel=0.0,
    data_format="pose",
    model_path=None,
):
    """
    Record the diffusion configuration and precompute the noise schedule.

    :param betas: a 1-D sequence of per-timestep betas, each in (0, 1].
    :param model_mean_type: stored unchanged on the instance.
    :param model_var_type: stored unchanged on the instance.
    :param loss_type: stored unchanged; must be LossType.MSE whenever
                      lambda_vel is positive.
    :param rescale_timesteps: stored unchanged on the instance.
    :param lambda_vel: stored unchanged; a positive value enables the
                       MSE-only check below.
    :param data_format: stored unchanged on the instance.
    :param model_path: accepted but not read by this constructor.
    """
    self.model_mean_type = model_mean_type
    self.model_var_type = model_var_type
    self.loss_type = loss_type
    self.rescale_timesteps = rescale_timesteps
    self.data_format = data_format
    self.lambda_vel = lambda_vel
    velocity_term_enabled = self.lambda_vel > 0.0
    if velocity_term_enabled:
        assert (
            self.loss_type == LossType.MSE
        ), "Geometric losses are supported by MSE loss type only!"

    # float64 keeps the cumulative products accurate.
    schedule = np.array(betas, dtype=np.float64)
    self.betas = schedule
    assert schedule.ndim == 1, "betas must be 1-D"
    assert np.all(schedule > 0) and np.all(schedule <= 1)

    self.num_timesteps = int(schedule.shape[0])

    alphas = 1.0 - schedule
    acp = np.cumprod(alphas, axis=0)
    acp_prev = np.append(1.0, acp[:-1])
    acp_next = np.append(acp[1:], 0.0)
    self.alphas_cumprod = acp
    self.alphas_cumprod_prev = acp_prev
    self.alphas_cumprod_next = acp_next
    assert acp_prev.shape == (self.num_timesteps,)

    # Terms of the forward process q(x_t | x_0) and its inversions.
    one_minus_acp = 1.0 - acp
    recip_acp = 1.0 / acp
    self.sqrt_alphas_cumprod = np.sqrt(acp)
    self.sqrt_one_minus_alphas_cumprod = np.sqrt(one_minus_acp)
    self.log_one_minus_alphas_cumprod = np.log(one_minus_acp)
    self.sqrt_recip_alphas_cumprod = np.sqrt(recip_acp)
    self.sqrt_recipm1_alphas_cumprod = np.sqrt(recip_acp - 1)

    # Terms of the posterior q(x_{t-1} | x_t, x_0).
    one_minus_prev = 1.0 - acp_prev
    self.posterior_variance = schedule * one_minus_prev / one_minus_acp
    # The variance at index 0 is exactly zero, so its log is replaced by
    # the log of the entry at index 1.
    second_entry = self.posterior_variance[1]
    self.posterior_log_variance_clipped = np.log(
        np.append(second_entry, self.posterior_variance[1:])
    )
    self.posterior_mean_coef1 = schedule * np.sqrt(acp_prev) / one_minus_acp
    self.posterior_mean_coef2 = one_minus_prev * np.sqrt(alphas) / one_minus_acp

    def squared_error(a, b):
        return (a - b) ** 2

    self.l2_loss = squared_error
def p_mean_variance(
    self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
    """
    Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
    the initial x, x_0.

    The raw model output is treated as the x_0 prediction; the mean is then
    obtained from q_posterior_mean_variance with that prediction.

    :param model: the model, which takes a signal and a batch of timesteps
                  as input.
    :param x: the [N x C x ...] tensor at time t.
    :param t: a 1-D Tensor of timesteps, of shape [N].
    :param clip_denoised: if True, clip the denoised signal into [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample. Applies before
        clip_denoised.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict with the following keys:
             - 'mean': the model mean output.
             - 'variance': the model variance output.
             - 'log_variance': the log of 'variance'.
             - 'pred_xstart': the prediction for x_0.
    :raises KeyError: if self.model_var_type is neither FIXED_LARGE nor
        FIXED_SMALL.
    """
    if model_kwargs is None:
        model_kwargs = {}

    batch_size = x.shape[0]
    assert t.shape == (batch_size,)
    model_output = model(x, self._scale_timesteps(t), **model_kwargs)

    model_variance, model_log_variance = {
        # for fixedlarge, we set the initial (log-)variance like so
        # to get a better decoder log likelihood.
        ModelVarType.FIXED_LARGE: (
            np.append(self.posterior_variance[1], self.betas[1:]),
            np.log(np.append(self.posterior_variance[1], self.betas[1:])),
        ),
        ModelVarType.FIXED_SMALL: (
            self.posterior_variance,
            self.posterior_log_variance_clipped,
        ),
    }[self.model_var_type]

    model_variance = _extract_into_tensor(model_variance, t, x.shape)
    model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

    def process_xstart(x0):
        if denoised_fn is not None:
            x0 = denoised_fn(x0)
        if clip_denoised:
            return x0.clamp(-1, 1)
        return x0

    pred_xstart = process_xstart(model_output)
    # Swap the last two axes and insert a singleton axis at position 2 so the
    # prediction matches x's layout.
    # NOTE(review): presumably the model emits (N, T, C) and x is
    # (N, C, 1, T) -- confirm against the model.
    pred_xstart = pred_xstart.permute(0, 2, 1).unsqueeze(2)
    model_mean, _, _ = self.q_posterior_mean_variance(
        x_start=pred_xstart, x_t=x, t=t
    )

    # The message must be a plain string: the previous print(...) expression
    # evaluated to None and wrote to stdout instead of the exception.
    assert (
        model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
    ), (
        f"{model_mean.shape} == {model_log_variance.shape} "
        f"== {pred_xstart.shape} == {x.shape}"
    )
    return {
        "mean": model_mean,
        "variance": model_variance,
        "log_variance": model_log_variance,
        "pred_xstart": pred_xstart,
    }
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.

    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: callable invoked as cond_fn(x, scaled_t, **model_kwargs).
    :param p_mean_var: dict providing at least 'mean' and 'variance'.
    :param x: the tensor at time t.
    :param t: the batch of timesteps; passed through _scale_timesteps first.
    :param model_kwargs: optional dict of extra keyword arguments for
        cond_fn; None is treated as no extra arguments.
    :return: mean + variance * gradient, computed in float.
    """
    # Unpacking None with ** raises TypeError, so normalise the default.
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    )
    return new_mean
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    const_noise=False,
):
    """
    Sample x_{t-1} from the model at the given timestep.

    :param model: the model to sample from.
    :param x: the current tensor at x_{t-1}.
    :param t: the value of t, starting at 0 for the first diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param const_noise: if True, the noise drawn for the first batch element
        is reused for every element of the batch.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )

    # Previously `noise` was referenced below without ever being drawn.
    noise = th.randn_like(x)
    if const_noise:
        # Broadcast the first element's noise over the whole batch.
        noise = noise[:1].expand_as(x)

    # No noise is added for batch elements whose timestep is 0.
    nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    if cond_fn is not None:
        out["mean"] = self.condition_mean(
            cond_fn, out, x, t, model_kwargs=model_kwargs
        )
    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    skip_timesteps=0,
    init_image=None,
    randomize_class=False,
    cond_fn_with_grad=False,
    dump_steps=None,
    const_noise=False,
):
    """
    Run the full sampling chain and return its result.

    All arguments except dump_steps are forwarded unchanged to
    p_sample_loop_progressive().

    :param model: the model module.
    :param shape: the shape of the samples, (N, C, H, W).
    :param noise: if specified, the noise from the encoder to sample.
                  Should be of the same shape as `shape`.
    :param clip_denoised: if True, clip x_start predictions to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param device: if specified, the device to create the samples on.
                   If not specified, use a model parameter's device.
    :param progress: if True, show a tqdm progress bar.
    :param dump_steps: if not None, a container of iteration indices; the
        return value becomes a list with a deep copy of the 'sample' taken
        at each of those iterations.
    :param const_noise: If True, will noise all samples with the same noise
        throughout sampling
    :return: the 'sample' of the last step, or the list of dumped samples
        when dump_steps is given.
    """
    forwarded = dict(
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
        skip_timesteps=skip_timesteps,
        init_image=init_image,
        randomize_class=randomize_class,
        cond_fn_with_grad=cond_fn_with_grad,
        const_noise=const_noise,
    )
    steps = self.p_sample_loop_progressive(model, shape, **forwarded)

    if dump_steps is None:
        # Exhaust the generator; only the last yielded dict matters.
        last = None
        for last in steps:
            pass
        return last["sample"]

    snapshots = []
    for step_index, step_out in enumerate(steps):
        if step_index in dump_steps:
            snapshots.append(deepcopy(step_out["sample"]))
    return snapshots
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Take one DDIM step, producing x_{t-1} from x_t.

    Arguments are used as in p_sample(); eta scales the injected noise.

    :return: a dict with 'sample' (the new tensor) and 'pred_xstart' (the
        x_0 prediction from before any cond_fn adjustment).
    """
    unconditioned = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    guided = (
        unconditioned
        if cond_fn is None
        else self.condition_score(
            cond_fn, unconditioned, x, t, model_kwargs=model_kwargs
        )
    )
    x0_hat = guided["pred_xstart"]

    # Epsilon is re-derived from the x_0 prediction rather than read from
    # the model output directly.
    eps = self._predict_eps_from_xstart(x, t, x0_hat)

    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
    variance_ratio = (1 - alpha_bar_prev) / (1 - alpha_bar)
    sigma = eta * th.sqrt(variance_ratio) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
    noise = th.randn_like(x)

    direction = th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
    mean_pred = x0_hat * th.sqrt(alpha_bar_prev) + direction

    # Shape [N, 1, ..., 1]; zero where t == 0 so no noise is added there.
    trailing_ones = [1] * (len(x.shape) - 1)
    nonzero_mask = (t != 0).float().view(-1, *trailing_ones)
    sample = mean_pred + nonzero_mask * sigma * noise
    return {"sample": sample, "pred_xstart": unconditioned["pred_xstart"]}
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Take one step of the DDIM reverse ODE, producing x_{t+1} from x_t.

    :param eta: must be 0.0; only the deterministic path is supported.
    :return: a dict with 'sample' (the new tensor) and 'pred_xstart' (the
        x_0 prediction from p_mean_variance).
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    prediction = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    x0_hat = prediction["pred_xstart"]

    # Recover epsilon from the x_0 prediction.
    recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
    recip_m1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
    eps = (recip * x - x0_hat) / recip_m1

    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12. reversed
    next_sample = x0_hat * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

    return {"sample": next_sample, "pred_xstart": x0_hat}
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
    skip_timesteps=0,
    init_image=None,
    randomize_class=False,
    cond_fn_with_grad=False,
):
    """
    Use DDIM to sample from the model and yield intermediate samples from
    each timestep of DDIM.

    Same usage as p_sample_loop_progressive().

    :param skip_timesteps: number of final (noisiest) timesteps to skip; when
        non-zero and init_image is None, a zero tensor is used as init_image.
    :param init_image: if not None, it is diffused to the first sampled
        timestep with q_sample and used as the starting point.
    :param randomize_class: if True and model_kwargs contains "y", replace
        model_kwargs["y"] in place with random class indices at every step.
    :param cond_fn_with_grad: if True, step with ddim_sample_with_grad
        instead of ddim_sample.
    :return: a generator over the dicts returned by the step function.
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    img = noise if noise is not None else th.randn(*shape, device=device)

    if skip_timesteps and init_image is None:
        init_image = th.zeros_like(img)

    indices = list(range(self.num_timesteps - skip_timesteps))[::-1]

    if init_image is not None:
        my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
        img = self.q_sample(init_image, my_t, img)

    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    # The step function does not change between iterations.
    sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample
    # model_kwargs defaults to None; a membership test on None would raise
    # TypeError, so class randomisation is skipped in that case.
    resample_classes = (
        randomize_class and model_kwargs is not None and "y" in model_kwargs
    )

    for i in indices:
        t = th.tensor([i] * shape[0], device=device)
        if resample_classes:
            model_kwargs["y"] = th.randint(
                low=0,
                high=model.num_classes,
                size=model_kwargs["y"].shape,
                device=model_kwargs["y"].device,
            )
        with th.no_grad():
            out = sample_fn(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                eta=eta,
            )
            # NOTE(review): indentation is lost in the reviewed text; the
            # yield is kept inside the no_grad block -- confirm against the
            # original layout.
            yield out
            img = out["sample"]
def plms_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    skip_timesteps=0,
    init_image=None,
    randomize_class=False,
    cond_fn_with_grad=False,
    order=2,
):
    """
    Run the full Pseudo Linear Multistep sampling chain.

    Every argument is forwarded unchanged to plms_sample_loop_progressive();
    usage is otherwise the same as p_sample_loop().

    :return: the 'sample' entry of the last dict the progressive loop yields.
    """
    steps = self.plms_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
        skip_timesteps=skip_timesteps,
        init_image=init_image,
        randomize_class=randomize_class,
        cond_fn_with_grad=cond_fn_with_grad,
        order=order,
    )
    # Exhaust the generator; only the last yielded dict is needed.
    last = None
    for last in steps:
        pass
    return last["sample"]
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
    """
    Compute training losses for a single timestep.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param t: a batch of timestep indices.
    :param model_kwargs: a dict of keyword arguments to pass to the model.
        It must contain model_kwargs["y"]["mask"] and
        model_kwargs["y"]["missing"], which are used to mask the losses.
        The caller's dict is not modified.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :return: a dict with the keys "rot_mse", "loss" and "vb", plus "vel_mse"
        when self.lambda_vel is positive.
    :raises ValueError: if model_kwargs is None.
    """
    # The masks are mandatory, so a missing model_kwargs cannot be defaulted
    # to {}; fail with an explicit message instead of subscripting None.
    if model_kwargs is None:
        raise ValueError(
            'training_losses requires model_kwargs["y"] with "mask" and '
            '"missing" entries'
        )
    # Shallow copy so the cond_drop_prob entry added below does not leak
    # into the caller's dict.
    model_kwargs = dict(model_kwargs)
    mask = model_kwargs["y"]["mask"]

    if noise is None:
        noise = th.randn_like(x_start)
    x_t = self.q_sample(
        x_start, t, noise=noise
    )  # use the formula to diffuse the starting tensor by t steps
    terms = {}

    # set random dropout for conditioning in training
    model_kwargs["cond_drop_prob"] = 0.2
    model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
    target = {
        ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )[0],
        ModelMeanType.START_X: x_start,
        ModelMeanType.EPSILON: noise,
    }[self.model_mean_type]

    # Rearrange the model output to the layout of x_start.
    # NOTE(review): presumably (N, T, C) -> (N, C, 1, T) -- confirm.
    model_output = model_output.permute(0, 2, 1).unsqueeze(2)
    assert model_output.shape == target.shape == x_start.shape

    # Combine the sequence mask with the per-frame "missing" flags.
    missing_mask = model_kwargs["y"]["missing"][..., 0]
    missing_mask = missing_mask.unsqueeze(1).unsqueeze(1)
    missing_mask = mask * missing_mask
    terms["rot_mse"] = self.masked_l2(target, model_output, missing_mask)
    if self.lambda_vel > 0.0:
        # Finite differences along the last axis.
        target_vel = target[..., 1:] - target[..., :-1]
        model_output_vel = model_output[..., 1:] - model_output[..., :-1]
        terms["vel_mse"] = self.masked_l2(
            target_vel,
            model_output_vel,
            mask[:, :, :, 1:],
        )

    terms["loss"] = terms["rot_mse"] + (self.lambda_vel * terms.get("vel_mse", 0.0))

    # The variational-bound term is reported only; no gradient is tracked.
    with th.no_grad():
        terms["vb"] = self._vb_terms_bpd(
            model,
            x_start,
            x_t,
            t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]

    return terms
timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/diffusion/losses.py b/diffusion/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..c3234f4cff570266a826670e80677bbd6ffd0a74 --- /dev/null +++ b/diffusion/losses.py @@ -0,0 +1,83 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +""" +Helpers for various likelihood-based losses. These are ported from the original +Ho et al. diffusion models codebase: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py +""" + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). 
def approx_standard_normal_cdf(x):
    """
    Approximate the standard normal CDF with a tanh-based closed form.

    :param x: a Tensor of evaluation points.
    :return: a Tensor like x holding the approximate CDF values.
    """
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))
    return 0.5 * (1.0 + th.tanh(inner))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    half_bin = 1.0 / 255.0
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)

    # CDF at the upper and lower edge of the bin around each value.
    cdf_plus = approx_standard_normal_cdf(inv_stdv * (centered_x + half_bin))
    cdf_min = approx_standard_normal_cdf(inv_stdv * (centered_x - half_bin))

    # Every log is clamped away from zero so it stays finite.
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    log_bin_mass = th.log((cdf_plus - cdf_min).clamp(min=1e-12))

    # The two extreme bins are open-ended: everything below / above counts.
    upper_or_middle = th.where(x > 0.999, log_one_minus_cdf_min, log_bin_mass)
    log_probs = th.where(x < -0.999, log_cdf_plus, upper_or_middle)
    assert log_probs.shape == x.shape
    return log_probs
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: x * sigmoid(x)."""

    def forward(self, x):
        gate = th.sigmoid(x)
        return x * gate


class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: 1, 2 or 3; remaining arguments are forwarded to the layer.
    :raises ValueError: for any other value of dims.
    """
    for rank, layer_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return layer_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :param dims: 1, 2 or 3; remaining arguments are forwarded to the layer.
    :raises ValueError: for any other value of dims.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    blend = 1 - rate
    for targ, src in zip(target_params, source_params):
        # detach() shares storage, so the in-place ops update the parameter.
        buf = targ.detach()
        buf.mul_(rate)
        buf.add_(src, alpha=blend)
def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for param in module.parameters():
        # detach() shares storage, so the parameter itself is rescaled.
        param.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)


def sum_flat(tensor):
    """
    Take the sum over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.sum(dim=non_batch_dims)


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    layer = GroupNorm32(32, channels)
    return layer


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    exponent = (
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    )
    freqs = th.exp(exponent).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2 == 0:
        return embedding
    # Odd dim: append one zero column so the width equals dim.
    zero_col = th.zeros_like(embedding[:, :1])
    return th.cat([embedding, zero_col], dim=-1)
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + @th.cuda.amp.custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_length = length + ctx.save_for_backward(*args) + with th.no_grad(): + output_tensors = ctx.run_function(*args[:length]) + return output_tensors + + @staticmethod + @th.cuda.amp.custom_bwd + def backward(ctx, *output_grads): + args = list(ctx.saved_tensors) + + # Filter for inputs that require grad. If none, exit early. + input_indices = [i for (i, x) in enumerate(args) if x.requires_grad] + if not input_indices: + return (None, None) + tuple(None for _ in args) + + with th.enable_grad(): + for i in input_indices: + if i < ctx.input_length: + # Not sure why the OAI code does this little + # dance. It might not be necessary. + args[i] = args[i].detach().requires_grad_() + args[i] = args[i].view_as(args[i]) + output_tensors = ctx.run_function(*args[: ctx.input_length]) + + if isinstance(output_tensors, th.Tensor): + output_tensors = [output_tensors] + + # Filter for outputs that require grad. If none, exit early. + out_and_grads = [ + (o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad + ] + if not out_and_grads: + return (None, None) + tuple(None for _ in args) + + # Compute gradients on the filtered tensors. + computed_grads = th.autograd.grad( + [o for (o, g) in out_and_grads], + [args[i] for i in input_indices], + [g for (o, g) in out_and_grads], + ) + + # Reassemble the complete gradient tuple. 
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler ("uniform" or "loss-second-moment").
    :param diffusion: the diffusion object to sample for.
    :raises NotImplementedError: if the name is not a known sampler.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # 1 / (N * p_i): the importance weight relative to uniform sampling.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    """Sampler that gives every diffusion step the same weight."""

    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        device = local_ts.device
        local_bs = len(local_ts)
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([local_bs], dtype=th.int32, device=device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        # The receive buffers below are max_bs long, so the tensors this rank
        # sends must be padded to the same length; otherwise a rank with a
        # smaller batch sends a tensor that does not match its receive slot.
        padded_ts = th.zeros(max_bs).to(local_ts)
        padded_ts[:local_bs] = local_ts
        padded_losses = th.zeros(max_bs).to(local_losses)
        padded_losses[:local_bs] = local_losses.detach()

        timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
        dist.all_gather(timestep_batches, padded_ts)
        dist.all_gather(loss_batches, padded_losses)

        # Drop the padding again using the true per-rank batch sizes.
        timesteps = [
            x for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs].tolist()
        ]
        losses = [
            x for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs].tolist()
        ]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """


class LossSecondMomentResampler(LossAwareSampler):
    """
    Sampler that weights each timestep by the root-mean-square of its most
    recent losses, mixed with a small uniform component.

    :param diffusion: object exposing num_timesteps.
    :param history_per_term: number of recent losses kept per timestep.
    :param uniform_prob: share of the probability mass spread uniformly.
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # `np.int` was only an alias of the builtin and was removed in
        # NumPy 1.24; the builtin `int` selects the same dtype.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        # Until every timestep has a full history, fall back to uniform.
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or a section has fewer steps than requested.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Report the count that was asked for, not the process length.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections absorb the remainder, one step each.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            # Fractional stride so the first and last step of the section
            # are both included.
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
def __init__(self, use_timesteps, **kwargs):
    """
    Build the reduced diffusion process from a subset of the base timesteps.

    :param use_timesteps: timesteps of the original process to retain.
    :param kwargs: the kwargs to create the base diffusion process; its
                   "betas" entry is replaced by the re-derived betas.
    """
    self.use_timesteps = set(use_timesteps)
    self.timestep_map = []
    self.original_num_steps = len(kwargs["betas"])

    base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
    prev_cumprod = 1.0
    kept_betas = []
    for step, cumprod in enumerate(base_diffusion.alphas_cumprod):
        if step not in self.use_timesteps:
            continue
        # Beta that takes the cumulative alpha from the previously kept step
        # to this one.
        kept_betas.append(1 - cumprod / prev_cumprod)
        prev_cumprod = cumprod
        self.timestep_map.append(step)
    kwargs["betas"] = np.array(kept_betas)
    super().__init__(**kwargs)


def p_mean_variance(
    self, model, *args, **kwargs
):  # pylint: disable=signature-differs
    """Delegate to the base class with the model wrapped by _wrap_model."""
    wrapped = self._wrap_model(model)
    return super().p_mean_variance(wrapped, *args, **kwargs)


def training_losses(
    self, model, *args, **kwargs
):  # pylint: disable=signature-differs
    """Delegate to the base class with the model wrapped by _wrap_model."""
    wrapped = self._wrap_model(model)
    return super().training_losses(wrapped, *args, **kwargs)


def condition_mean(self, cond_fn, *args, **kwargs):
    """Delegate to the base class with cond_fn wrapped by _wrap_model."""
    wrapped = self._wrap_model(cond_fn)
    return super().condition_mean(wrapped, *args, **kwargs)


def condition_score(self, cond_fn, *args, **kwargs):
    """Delegate to the base class with cond_fn wrapped by _wrap_model."""
    wrapped = self._wrap_model(cond_fn)
    return super().condition_score(wrapped, *args, **kwargs)


def _wrap_model(self, model):
    """Return model unchanged if already wrapped, else wrap it."""
    if not isinstance(model, _WrappedModel):
        model = _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )
    return model
class _WrappedModel:
    """
    Callable adapter that maps timestep indices through timestep_map before
    calling the wrapped model.

    :param model: the callable to wrap; called as model(x, ts, **kwargs).
    :param timestep_map: sequence indexed by incoming timestep, giving the
                         timestep passed on to the model.
    :param rescale_timesteps: if truthy, mapped timesteps are converted to
                              float and scaled by 1000 / original_num_steps.
    :param original_num_steps: step count used for that rescaling.
    """

    # Optional attributes mirrored from the wrapped model when it has them.
    _MIRRORED_ATTRS = ("step", "add_frame_cond")

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        # Mirror each optional attribute on its own, so a model exposing only
        # one of them neither loses it nor raises AttributeError here.
        for attr in self._MIRRORED_ATTRS:
            if hasattr(model, attr):
                setattr(self, attr, getattr(model, attr))
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        """
        Call the wrapped model with mapped (and optionally rescaled) timesteps.

        :param x: passed through to the model unchanged.
        :param ts: tensor of indices into timestep_map.
        :param kwargs: passed through to the model unchanged.
        :return: whatever the wrapped model returns.
        """
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):
    """
    Run the wrapped model with and without conditioning and blend the two
    outputs with the per-sample guidance scale in y["scale"].

    :param model: the model to run; must expose nfeats, cond_mode and
                  add_frame_cond, and accept a cond_drop_prob keyword.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model  # model is the actual model to run
        self.nfeats = self.model.nfeats
        self.cond_mode = self.model.cond_mode
        self.add_frame_cond = self.model.add_frame_cond
        if self.add_frame_cond is not None:
            # resume_trans / step are only defined by some model
            # configurations, so probe for them rather than assume.
            # NOTE(review): the flattened source does not show whether `step`
            # was mirrored inside the resume_trans branch; it is mirrored
            # whenever frame conditioning is on and the model has it.
            if getattr(self.model, "resume_trans", None) is not None:
                self.transformer = self.model.transformer
                self.tokenizer = self.model.tokenizer
            if hasattr(self.model, "step"):
                self.step = self.model.step

    def forward(self, x, timesteps, y=None):
        """
        :param x: model input, passed to the wrapped model unchanged.
        :param timesteps: timesteps, passed to the wrapped model unchanged.
        :param y: conditioning dict; must contain "scale", a tensor that is
                  reshaped to (-1, 1, 1) and multiplied with the outputs.
        :return: out_uncond + scale * (out_cond - out_uncond).
        :raises TypeError: if y is None (the guidance scale is required).
        """
        if y is None:
            # Same exception type the bare subscript raised, but readable.
            raise TypeError(
                "ClassifierFreeSampleModel.forward requires y with a 'scale' entry"
            )
        out = self.model(x, timesteps, y, cond_drop_prob=0.0)
        out_uncond = self.model(x, timesteps, y, cond_drop_prob=1.0)
        return out_uncond + (y["scale"].view(-1, 1, 1) * (out - out_uncond))
class Audio2LipRegressionTransformer(torch.nn.Module):
    """
    Regress per-frame lip vertices from audio: an audio encoder provides the
    conditioning, a regression transformer runs on an all-zero input sequence,
    and a linear layer projects to n_vertices * 3 values per frame.
    """

    def __init__(
        self,
        n_vertices: int = 338,
        causal: bool = False,
        train_wav2vec: bool = False,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
    ):
        super().__init__()
        self.n_vertices = n_vertices

        encoder = Wav2VecEncoder()
        self.audio_encoder = encoder
        if not train_wav2vec:
            # Freeze the audio encoder: eval mode and no gradients.
            encoder.eval()
            for weight in encoder.parameters():
                weight.requires_grad = False

        self.regression_model = RegressionTransformer(
            transformer_encoder_layers=transformer_encoder_layers,
            transformer_decoder_layers=transformer_decoder_layers,
            d_model=512,
            d_cond=512,
            num_heads=4,
            causal=causal,
        )
        self.project_output = torch.nn.Linear(512, self.n_vertices * 3)

    def forward(self, audio):
        """
        :param audio: tensor of shape B x T x 1600
        :return: tensor of shape B x T x n_vertices x 3 containing reconstructed lip geometry
        """
        batch, frames = audio.shape[0], audio.shape[1]

        cond = self.audio_encoder(audio)

        # The regressor input is all zeros; the signal comes from `cond`.
        queries = torch.zeros(batch, frames, 512, device=audio.device)
        hidden = self.regression_model(queries, cond)
        flat_verts = self.project_output(hidden)

        return flat_verts.view(batch, frames, self.n_vertices, 3)
self, + args, + nfeats: int, + latent_dim: int = 512, + ff_size: int = 1024, + num_layers: int = 4, + num_heads: int = 4, + dropout: float = 0.1, + cond_feature_dim: int = 4800, + activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu, + use_rotary: bool = True, + cond_mode: str = "audio", + split_type: str = "train", + device: str = "cuda", + **kwargs, + ) -> None: + super().__init__() + self.nfeats = nfeats + self.cond_mode = cond_mode + self.cond_feature_dim = cond_feature_dim + self.add_frame_cond = args.add_frame_cond + self.data_format = args.data_format + self.split_type = split_type + self.device = device + + # positional embeddings + self.rotary = None + self.abs_pos_encoding = nn.Identity() + # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity) + if use_rotary: + self.rotary = RotaryEmbedding(dim=latent_dim) + else: + self.abs_pos_encoding = PositionalEncoding( + latent_dim, dropout, batch_first=True + ) + + # time embedding processing + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(latent_dim), + nn.Linear(latent_dim, latent_dim * 4), + nn.Mish(), + ) + self.to_time_cond = nn.Sequential( + nn.Linear(latent_dim * 4, latent_dim), + ) + self.to_time_tokens = nn.Sequential( + nn.Linear(latent_dim * 4, latent_dim * 2), + Rearrange("b (r d) -> b r d", r=2), + ) + + # null embeddings for guidance dropout + self.seq_len = args.max_seq_length + emb_len = 1998 # hardcoded for now + self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, latent_dim)) + self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim)) + self.norm_cond = nn.LayerNorm(latent_dim) + self.setup_audio_models() + + # set up pose/face specific parts of the model + self.input_projection = nn.Linear(self.nfeats, latent_dim) + if self.data_format == "pose": + cond_feature_dim = 1024 + key_feature_dim = 104 + self.step = 30 + self.use_cm = True + self.setup_guide_models(args, latent_dim, key_feature_dim) + self.post_pose_layers = 
def _build_single_pose_conv(self, nfeats: int) -> nn.ModuleList:
    """
    Build the stack of six kernel-size-3 Conv1d layers used after the pose
    decoder.

    :param nfeats: channel count of the stack's input and output.
    :return: ModuleList of the six layers, in application order.
    """
    hidden = max(256, nfeats)
    # (in_channels, out_channels, dilation) for each layer, in order.
    layout = (
        (nfeats, hidden, 1),
        (hidden, nfeats, 2),
        (nfeats, nfeats, 3),
        (nfeats, nfeats, 1),
        (nfeats, nfeats, 2),
        (nfeats, nfeats, 3),
    )
    return torch.nn.ModuleList(
        [
            torch.nn.Conv1d(c_in, c_out, kernel_size=3, dilation=dilation)
            for c_in, c_out, dilation in layout
        ]
    )
def setup_guide_models(self, args, latent_dim: int, key_feature_dim: int) -> None:
    """
    Create the keyframe-conditioning layers and, in test mode, load the guide
    transformer when a checkpoint path is given.

    :param args: options object; its optional resume_trans attribute is the
                 guide-transformer checkpoint path.
    :param latent_dim: width of the keyframe tokens.
    :param key_feature_dim: width of the raw keyframe features.
    """
    # set up conditioning info
    # One keyframe every self.step frames over the full sequence length.
    max_keyframe_len = len(range(0, self.seq_len, self.step))
    self.null_pose_embed = nn.Parameter(
        torch.randn(1, max_keyframe_len, latent_dim)
    )
    prGreen(f"using keyframes: {self.null_pose_embed.shape}")
    self.frame_cond_projection = nn.Linear(key_feature_dim, latent_dim)
    self.frame_norm_cond = nn.LayerNorm(latent_dim)
    # for test time set up keyframe transformer
    self.resume_trans = None
    if self.split_type != "test":
        return
    checkpoint_path = getattr(args, "resume_trans", None)
    if checkpoint_path is not None:
        self.resume_trans = checkpoint_path
        self.setup_guide_predictor(checkpoint_path)
    else:
        prRed("not using transformer, just using ground truth")
def setup_audio_models(self) -> None:
    """Attach the audio model and resampler returned by setup_lip_regressor."""
    audio_model, audio_resampler = setup_lip_regressor()
    self.audio_model = audio_model
    self.audio_resampler = audio_resampler


def setup_lip_models(self) -> None:
    """Load the lip regressor from its fixed checkpoint path and freeze it."""
    self.lip_model = Audio2LipRegressionTransformer()
    cp_path = "./assets/iter-0200000.pt"
    target = torch.device(self.device)
    cp = torch.load(cp_path, map_location=target)
    self.lip_model.load_state_dict(cp["model_state_dict"])
    for weight in self.lip_model.parameters():
        weight.requires_grad = False
    prGreen(f"adding lip conditioning {cp_path}")


def parameters_w_grad(self):
    """Return the list of parameters that have requires_grad set."""
    trainable = []
    for param in self.parameters():
        if param.requires_grad:
            trainable.append(param)
    return trainable


def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
    """
    Resample channels 0 and 1 of the last axis of raw_audio, run each through
    the audio model's feature extractor without gradients, and return the two
    feature maps concatenated along dim 1 with the last two axes swapped.
    """
    device = next(self.parameters()).device
    # Resample both channels first, then extract features, channel 0 first.
    resampled = [
        self.audio_resampler(raw_audio[:, :, channel].to(device))
        for channel in (0, 1)
    ]
    with torch.no_grad():
        features = [self.audio_model.feature_extractor(wave) for wave in resampled]
        emb = torch.cat(features, dim=1).permute(0, 2, 1)
    return emb
def encode_keyframes(
    self, y: dict, cond_drop_prob: float, batch_size: int
) -> torch.Tensor:
    """
    Turn the guide keyframes in y into conditioning tokens, with per-sample
    conditional dropout.

    :param y: conditioning dict; reads y["keyframes"] and y["mask"]. The mask
              is subsampled every self.step entries along its last axis and
              its axes 1 and 2 are squeezed. The caller's tensors are not
              modified.
    :param cond_drop_prob: probability of replacing a sample's tokens by the
                           learned null embedding.
    :param batch_size: number of samples, used to draw the keep mask.
    :return: the keyframe tokens, with dropped samples replaced by
             self.null_pose_embed truncated to the token length.
    """
    # Run on the device the projection lives on instead of assuming CUDA,
    # so CPU (or other accelerator) execution works.
    device = self.frame_cond_projection.weight.device
    # Copy before masking: zeroing in place would overwrite the caller's
    # y["keyframes"].
    pred = y["keyframes"].detach().clone().to(device)
    new_mask = y["mask"][..., :: self.step].squeeze((1, 2)).to(device)
    pred[~new_mask] = 0.0  # pad the unknown
    pose_hidden = self.frame_cond_projection(pred)
    pose_embed = self.abs_pos_encoding(pose_hidden)
    pose_tokens = self.frame_norm_cond(pose_embed)
    # do conditional dropout for guide poses
    keep_mask_pose = prob_mask_like(
        (batch_size,), 1 - cond_drop_prob, device=pose_tokens.device
    )
    keep_mask_pose_embed = rearrange(keep_mask_pose, "b -> b 1 1")
    null_pose_embed = self.null_pose_embed.to(pose_tokens.dtype)
    pose_tokens = torch.where(
        keep_mask_pose_embed,
        pose_tokens,
        null_pose_embed[:, : pose_tokens.shape[1], :],
    )
    return pose_tokens
self.abs_pos_encoding(x) + audio_cond_drop_prob = cond_drop_prob + keep_mask = prob_mask_like( + (batch_size,), 1 - audio_cond_drop_prob, device=device + ) + keep_mask_embed = rearrange(keep_mask, "b -> b 1 1") + keep_mask_hidden = rearrange(keep_mask, "b -> b 1") + cond_tokens = self.cond_projection(cond_embed) + cond_tokens = self.abs_pos_encoding(cond_tokens) + if self.data_format == "face": + cond_tokens = self.cond_encoder(cond_tokens) + null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype) + cond_tokens = torch.where( + keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :] + ) + mean_pooled_cond_tokens = cond_tokens.mean(dim=-2) + cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens) + + # create t conditioning + t_hidden = self.time_mlp(times) + t = self.to_time_cond(t_hidden) + t_tokens = self.to_time_tokens(t_hidden) + null_cond_hidden = self.null_cond_hidden.to(t.dtype) + cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden) + t += cond_hidden + + # cross-attention conditioning + c = torch.cat((cond_tokens, t_tokens), dim=-2) + cond_tokens = self.norm_cond(c) + + # Pass through the transformer decoder + output = self.seqTransDecoder(x, cond_tokens, t, memory2=pose_tokens) + output = self.final_layer(output) + if self.data_format == "pose": + output = output.permute(0, 2, 1) + output = self._run_single_pose_conv(output) + output = self.final_conv(output) + output = output.permute(0, 2, 1) + return output diff --git a/model/guide.py b/model/guide.py new file mode 100644 index 0000000000000000000000000000000000000000..2b362612e238a8ba9f3a4ed78fd2170246eaf3f3 --- /dev/null +++ b/model/guide.py @@ -0,0 +1,222 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from typing import Callable, List + +import torch +import torch as th +import torch.nn as nn +from einops import rearrange +from model.modules.rotary_embedding_torch import RotaryEmbedding + +from model.modules.transformer_modules import ( + DecoderLayerStack, + FiLMTransformerDecoderLayer, + PositionalEncoding, +) +from model.utils import prob_mask_like, setup_lip_regressor +from torch.distributions import Categorical +from torch.nn import functional as F + + +class GuideTransformer(nn.Module): + def __init__( + self, + tokens: int, + num_heads: int = 4, + num_layers: int = 4, + dim: int = 512, + ff_size: int = 1024, + dropout: float = 0.1, + activation: Callable = F.gelu, + use_rotary: bool = True, + cond_feature_dim: int = 1024, + emb_len: int = 798, + num_audio_layers: int = 2, + ): + super().__init__() + self.tokens = tokens + self.token_embedding = th.nn.Embedding( + num_embeddings=tokens + 1, # account for sequence start and end tokens + embedding_dim=dim, + ) + self.abs_pos_encoding = nn.Identity() + # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity) + if use_rotary: + self.rotary = RotaryEmbedding(dim=dim) + else: + self.abs_pos_encoding = PositionalEncoding(dim, dropout, batch_first=True) + self.setup_audio_models(cond_feature_dim, num_audio_layers) + + self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, dim)) + self.null_cond_hidden = nn.Parameter(torch.randn(1, dim)) + self.norm_cond = nn.LayerNorm(dim) + + self.cond_projection = nn.Linear(cond_feature_dim, dim) + self.non_attn_cond_projection = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim), + nn.SiLU(), + nn.Linear(dim, dim), + ) + # decoder + decoderstack = nn.ModuleList([]) + for _ in range(num_layers): + decoderstack.append( + FiLMTransformerDecoderLayer( + dim, + num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation, + batch_first=True, + rotary=self.rotary, + ) + ) + self.seqTransDecoder = 
DecoderLayerStack(decoderstack) + self.final_layer = nn.Linear(dim, tokens) + + def _build_single_audio_conv(self, c: int) -> List[nn.Module]: + return [ + torch.nn.Conv1d(c, max(256, c), kernel_size=3, dilation=1), + torch.nn.LeakyReLU(negative_slope=0.2), + torch.nn.Dropout(0.2), + # + torch.nn.Conv1d(max(256, c), max(256, c), kernel_size=3, dilation=2), + torch.nn.LeakyReLU(negative_slope=0.2), + torch.nn.Dropout(0.2), + # + torch.nn.Conv1d(max(128, c), max(128, c), kernel_size=3, dilation=3), + torch.nn.LeakyReLU(negative_slope=0.2), + torch.nn.Dropout(0.2), + # + torch.nn.Conv1d(max(128, c), c, kernel_size=3, dilation=1), + torch.nn.LeakyReLU(negative_slope=0.2), + torch.nn.Dropout(0.2), + # + torch.nn.Conv1d(c, c, kernel_size=3, dilation=2), + torch.nn.LeakyReLU(negative_slope=0.2), + torch.nn.Dropout(0.2), + # + torch.nn.Conv1d(c, c, kernel_size=3, dilation=3), + torch.nn.LeakyReLU(negative_slope=0.2), + torch.nn.Dropout(0.2), + ] + + def setup_audio_models(self, cond_feature_dim: int, num_audio_layers: int) -> None: + pre_layers = [] + for _ in range(num_audio_layers): + pre_layers += self._build_single_audio_conv(cond_feature_dim) + pre_layers += [ + torch.nn.Conv1d(cond_feature_dim, cond_feature_dim, kernel_size=1) + ] + pre_layers = torch.nn.ModuleList(pre_layers) + self.pre_audio = nn.Sequential(*pre_layers) + self.audio_model, self.audio_resampler = setup_lip_regressor() + + def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor: + device = next(self.parameters()).device + a0 = self.audio_resampler(raw_audio[:, :, 0].to(device)) # B x T + a1 = self.audio_resampler(raw_audio[:, :, 1].to(device)) # B x T + with torch.no_grad(): + z0 = self.audio_model.feature_extractor(a0) + z1 = self.audio_model.feature_extractor(a1) + emb = torch.cat((z0, z1), axis=1).permute(0, 2, 1) + return emb + + def get_tgt_mask(self, size: int, device: str) -> torch.tensor: + mask = torch.tril( + torch.ones((size, size), device=device) == 1 + ) # Lower triangular matrix 
def generate(
    self,
    condition: th.Tensor,
    sequence_length: int,
    layers: int,
    n_sequences: int = 1,
    max_key_len: int = 8,
    max_seq_len: int = 240,
    top_p: float = 0.94,
) -> torch.Tensor:
    """
    Autoregressively sample tokens with nucleus (top-p) sampling.

    :param condition: conditioning input, passed unchanged to ``forward`` at
        every step; generated tokens live on its device
    :param sequence_length: together with ``layers`` sets the number of
        sampled tokens, ``sequence_length * layers``
    :param layers: multiplier on ``sequence_length`` (see above)
    :param n_sequences: number of sequences to generate simultaneously
    :param max_key_len: only used in the sanity check against ``max_seq_len``
    :param max_seq_len: only used in the sanity check (``max_seq_len / 30``)
    :param top_p: cumulative-probability threshold of the nucleus; the most
        likely token is always kept
    :return: n_sequences x (sequence_length * layers) LongTensor containing
        the generated tokens, without the start token
    """
    assert max_key_len == int(max_seq_len / 30), "currently only running for 1fps"
    with th.no_grad():
        # Every sequence starts with the start token, whose id is self.tokens.
        input_tokens = th.full(
            (n_sequences, 1), self.tokens, dtype=th.int64, device=condition.device
        )
        for _ in range(sequence_length * layers):
            logits = self.forward(input_tokens, condition)
            logits = logits[:, -1, :]  # only most recent time step is relevant
            probs = th.nn.functional.softmax(logits, dim=-1)
            sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Shift right by one so the token that crosses top_p is kept and
            # the first (most likely) token is always part of the nucleus.
            nucleus = cumulative_probs < top_p
            nucleus = torch.cat(
                [
                    nucleus.new_ones(nucleus.shape[:-1] + (1,)),
                    nucleus[..., :-1],
                ],
                dim=-1,
            )
            sorted_probs[~nucleus] = 0
            sorted_probs /= sorted_probs.sum(-1, keepdim=True)
            idx = Categorical(sorted_probs).sample()
            # Map the sampled rank back to the vocabulary index.
            tokens = indices.gather(-1, idx.unsqueeze(-1))
            input_tokens = th.cat([input_tokens, tokens], dim=-1)

        # return generated tokens except for sequence start token
        return input_tokens[:, 1:].contiguous()
def weights_init(m):
    """
    Initialise a ``Conv1d`` module in place; every other module is left alone.

    Weights get Xavier-uniform initialisation and the bias, when the layer has
    one, is set to the constant 0.01. Intended for ``module.apply(...)``.

    :param m: module visited by ``apply``
    """
    if not isinstance(m, th.nn.Conv1d):
        return
    th.nn.init.xavier_uniform_(m.weight)
    # Layers built with bias=False have no bias tensor; check for that
    # explicitly instead of swallowing every exception.
    if m.bias is not None:
        th.nn.init.constant_(m.bias, 0.01)
AudioTcn(th.nn.Module): + def __init__( + self, + encoding_dim: int = 128, + use_melspec: bool = True, + use_wav2vec: bool = True, + ): + """ + :param encoding_dim: size of encoding + :param use_melspec: extract mel spectrogram features as input + :param use_wav2vec: extract wav2vec features as input + """ + super().__init__() + self.encoding_dim = encoding_dim + self.use_melspec = use_melspec + self.use_wav2vec = use_wav2vec + + if use_melspec: + # hop_length=400 -> two feature vectors per visual frame (downsampling to 24kHz -> 800 samples per frame) + self.melspec = th.nn.Sequential( + ta.transforms.Resample(orig_freq=48000, new_freq=24000), + ta.transforms.MelSpectrogram( + sample_rate=24000, + n_fft=1024, + win_length=800, + hop_length=400, + n_mels=80, + ), + ) + + if use_wav2vec: + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [wav2vec_model_path] + ) + self.wav2vec_model = model[0] + self.wav2vec_model.eval() + self.wav2vec_postprocess = th.nn.Conv1d(512, 256, kernel_size=3) + self.wav2vec_postprocess.apply(lambda x: weights_init(x)) + + # temporal model + input_dim = 0 + (160 if use_melspec else 0) + (256 if use_wav2vec else 0) + self.layers = th.nn.ModuleList( + [ + th.nn.Conv1d( + input_dim, max(256, encoding_dim), kernel_size=3, dilation=1 + ), # 2 (+1) + th.nn.Conv1d( + max(256, encoding_dim), encoding_dim, kernel_size=3, dilation=2 + ), # 4 (+1) + th.nn.Conv1d( + encoding_dim, encoding_dim, kernel_size=3, dilation=3 + ), # 6 (+1) + th.nn.Conv1d( + encoding_dim, encoding_dim, kernel_size=3, dilation=1 + ), # 2 (+1) + th.nn.Conv1d( + encoding_dim, encoding_dim, kernel_size=3, dilation=2 + ), # 4 (+1) + th.nn.Conv1d( + encoding_dim, encoding_dim, kernel_size=3, dilation=3 + ), # 6 (+1) + ] + ) + self.layers.apply(lambda x: weights_init(x)) + self.receptive_field = 25 + + self.final = th.nn.Conv1d(encoding_dim, encoding_dim, kernel_size=1) + self.final.apply(lambda x: weights_init(x)) + + def forward(self, audio): + """ + 
:param audio: B x T x 1600 tensor containing audio samples for each frame + :return: B x T x encoding_dim tensor containing audio encodings for each frame + """ + B, T = audio.shape[0], audio.shape[1] + + # preprocess raw audio signal to extract feature vectors + audio = audio.view(B, T * 1600) + x_mel, x_w2v = th.zeros(B, 0, T).to(audio.device), th.zeros(B, 0, T).to( + audio.device + ) + if self.use_melspec: + x_mel = self.melspec(audio)[:, :, 1:].contiguous() + x_mel = th.log(x_mel.clamp(min=1e-10, max=None)) + x_mel = ( + x_mel.permute(0, 2, 1) + .contiguous() + .view(x_mel.shape[0], T, 160) + .permute(0, 2, 1) + .contiguous() + ) + if self.use_wav2vec: + with th.no_grad(): + x_w2v = ta.functional.resample(audio, 48000, 16000) + x_w2v = self.wav2vec_model.feature_extractor(x_w2v) + x_w2v = self.wav2vec_model.feature_aggregator(x_w2v) + x_w2v = self.wav2vec_postprocess(th.nn.functional.pad(x_w2v, pad=[2, 0])) + x_w2v = th.nn.functional.interpolate( + x_w2v, size=T, align_corners=True, mode="linear" + ) + x = th.cat([x_mel, x_w2v], dim=1) + + # process signal with TCN + x = th.nn.functional.pad(x, pad=[self.receptive_field - 1, 0]) + for layer_idx, layer in enumerate(self.layers): + y = th.nn.functional.leaky_relu(layer(x), negative_slope=0.2) + if self.training: + y = th.nn.functional.dropout(y, 0.2) + if x.shape[1] == y.shape[1]: + x = (x[:, :, -y.shape[-1] :] + y) / 2.0 # skip connection + else: + x = y + + x = self.final(x) + x = x.permute(0, 2, 1).contiguous() + + return x diff --git a/model/modules/rotary_embedding_torch.py b/model/modules/rotary_embedding_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..517b6460641c1e20fac24ee67d53bb98200b87a3 --- /dev/null +++ b/model/modules/rotary_embedding_torch.py @@ -0,0 +1,139 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
def broadcat(tensors, dim=-1):
    """
    Concatenate tensors along ``dim`` after broadcasting all other axes.

    All tensors must have the same rank. On every axis except the
    concatenation axis they may show at most two distinct sizes; each tensor
    is expanded to the largest size on that axis before concatenation.

    :param tensors: sequence of tensors of identical rank
    :param dim: concatenation axis (negative values count from the end)
    :return: the concatenated tensor
    """
    count = len(tensors)
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, "tensors must all have the same number of dimensions"
    rank = next(iter(ranks))

    axis = dim + rank if dim < 0 else dim
    # per_axis[k] holds the size of axis k for every tensor, in input order
    per_axis = list(zip(*(list(t.shape) for t in tensors)))

    others = [(k, sizes) for k, sizes in enumerate(per_axis) if k != axis]
    assert all(
        [len(set(sizes)) <= 2 for _, sizes in others]
    ), "invalid dimensions for broadcastable concatentation"

    # Broadcast axes take the largest size; the concat axis keeps each own size.
    plan = [(k, (max(sizes),) * count) for k, sizes in others]
    plan.insert(axis, (axis, per_axis[axis]))
    target_shapes = list(zip(*(sizes for _, sizes in plan)))

    expanded = [t.expand(*shape) for t, shape in zip(tensors, target_shapes)]
    return torch.cat(expanded, dim=axis)
class RotaryEmbedding(nn.Module):
    """
    Produces rotary position-embedding frequencies and applies them.

    ``freqs`` is a learnable parameter when ``learned_freq`` is set and a
    buffer otherwise. Computed frequency tables are memoised in ``self.cache``
    by ``cache_key``, but only while ``freqs`` is a fixed buffer.
    """

    def __init__(
        self,
        dim,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
    ):
        """
        :param dim: feature dimension to rotate; ``dim // 2`` base frequencies
        :param custom_freqs: explicit base frequencies; overrides ``freqs_for``
        :param freqs_for: ``"lang"``, ``"pixel"`` or ``"constant"``
        :param theta: base of the geometric progression for ``"lang"``
        :param max_freq: upper frequency bound (before ``/ 2``) for ``"pixel"``
        :param num_freqs: number of unit frequencies for ``"constant"``
        :param learned_freq: register the frequencies as a trainable parameter
        :raises ValueError: if ``freqs_for`` is not one of the known modes
        """
        super().__init__()
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        self.cache = dict()

        if learned_freq:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs)

    def rotate_queries_or_keys(self, t, seq_dim=-2):
        """
        Apply the rotary embedding for positions ``0 .. seq_len - 1`` to ``t``.

        :param t: tensor to rotate; its length along ``seq_dim`` is the
            sequence length
        :param seq_dim: axis of ``t`` that indexes sequence positions
        :return: result of ``apply_rotary_emb`` on ``t``
        """
        device = t.device
        seq_len = t.shape[seq_dim]
        # Positions are passed lazily so a cache hit never builds the arange.
        freqs = self.forward(
            lambda: torch.arange(seq_len, device=device), cache_key=seq_len
        )
        return apply_rotary_emb(freqs, t)

    def forward(self, t, cache_key=None):
        """
        Compute the frequency table for positions ``t``.

        :param t: tensor of positions, or a zero-argument function returning it
        :param cache_key: optional key under which the result is memoised
        :return: outer product of positions and base frequencies with every
            frequency repeated twice along the last axis
        """
        # A cached table would go stale as soon as learned frequencies are
        # updated (and would keep an old autograd graph alive), so caching is
        # only used for fixed frequencies.
        use_cache = cache_key is not None and not isinstance(
            self.freqs, nn.Parameter
        )
        if use_cache:
            cached = self.cache.get(cache_key)
            # Ignore entries created before the module changed device/dtype.
            if (
                cached is not None
                and cached.device == self.freqs.device
                and cached.dtype == self.freqs.dtype
            ):
                return cached

        if isfunction(t):
            t = t()

        freqs = self.freqs

        freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        # [f0, f1, ...] -> [f0, f0, f1, f1, ...]
        freqs = freqs.repeat_interleave(2, dim=-1)

        if use_cache:
            self.cache[cache_key] = freqs

        return freqs
def featurewise_affine(x, scale_shift):
    """
    Apply a FiLM-style affine modulation to ``x``.

    :param x: value to modulate
    :param scale_shift: pair ``(scale, shift)``; the scale is an offset from
        one, so a zero scale and zero shift leave ``x`` unchanged
    :return: ``(scale + 1) * x + shift``
    """
    gamma, beta = scale_shift
    modulated = (gamma + 1) * x
    return modulated + beta
class FiLMTransformerDecoderLayer(nn.Module):
    """
    Transformer decoder layer whose sub-block outputs are FiLM-modulated.

    Each sub-block (self-attention, cross-attention, optional second
    cross-attention, feed-forward) is scaled and shifted by a ``DenseFiLM`` of
    the conditioning vector ``t`` before being added to the residual stream.
    With ``rotary`` set, queries and keys are rotated before attention.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        layer_norm_eps=1e-5,
        batch_first=False,
        norm_first=True,
        rotary=None,
        use_cm=False,
    ):
        """
        :param d_model: feature width of the layer
        :param nhead: number of attention heads
        :param dim_feedforward: hidden width of the feed-forward block
        :param dropout: dropout probability used throughout the layer
        :param activation: feed-forward activation function
        :param layer_norm_eps: epsilon of all layer norms
        :param batch_first: forwarded to ``nn.MultiheadAttention``
        :param norm_first: pre-norm (True) or post-norm (False) residuals
        :param rotary: optional rotary embedding with ``rotate_queries_or_keys``
        :param use_cm: create the second cross-attention branch, which is
            required whenever ``memory2`` is passed to ``forward``
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation

        self.film1 = DenseFiLM(d_model)
        self.film2 = DenseFiLM(d_model)
        self.film3 = DenseFiLM(d_model)

        if use_cm:
            # second cross-attention branch (used for ``memory2``)
            self.multihead_attn2 = nn.MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first
            )
            self.norm2a = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.dropout2a = nn.Dropout(dropout)
            self.film2a = DenseFiLM(d_model)

        self.rotary = rotary
        self.use_rotary = rotary is not None

    # x, cond, t
    def forward(
        self,
        tgt,
        memory,
        t,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        memory2=None,
    ):
        """
        :param tgt: target sequence
        :param memory: conditioning sequence for the first cross-attention
        :param t: conditioning vector fed to the FiLM generators
        :param tgt_mask: attention mask for self-attention
        :param memory_mask: attention mask for cross-attention (shared by both
            cross-attention branches)
        :param tgt_key_padding_mask: key padding mask for self-attention
        :param memory_key_padding_mask: key padding mask for cross-attention
            (shared by both cross-attention branches)
        :param memory2: optional second conditioning sequence; requires the
            layer to have been built with ``use_cm=True``
        :return: tensor with the same shape as ``tgt``
        """
        x = tgt
        if self.norm_first:
            # self-attention -> film -> residual
            x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + featurewise_affine(x_1, self.film1(t))
            # cross-attention -> film -> residual
            x_2 = self._mha_block(
                self.norm2(x),
                memory,
                memory_mask,
                memory_key_padding_mask,
                self.multihead_attn,
                self.dropout2,
            )
            x = x + featurewise_affine(x_2, self.film2(t))
            if memory2 is not None:
                # cross-attention x2 -> film -> residual
                x_2a = self._mha_block(
                    self.norm2a(x),
                    memory2,
                    memory_mask,
                    memory_key_padding_mask,
                    self.multihead_attn2,
                    self.dropout2a,
                )
                x = x + featurewise_affine(x_2a, self.film2a(t))
            # feedforward -> film -> residual
            x_3 = self._ff_block(self.norm3(x))
            x = x + featurewise_affine(x_3, self.film3(t))
        else:
            # Post-norm variant: same sub-blocks, normalisation after the sum.
            x_1 = self._sa_block(x, tgt_mask, tgt_key_padding_mask)
            x = self.norm1(x + featurewise_affine(x_1, self.film1(t)))
            # _mha_block needs the attention module and dropout explicitly.
            x_2 = self._mha_block(
                x,
                memory,
                memory_mask,
                memory_key_padding_mask,
                self.multihead_attn,
                self.dropout2,
            )
            x = self.norm2(x + featurewise_affine(x_2, self.film2(t)))
            if memory2 is not None:
                x_2a = self._mha_block(
                    x,
                    memory2,
                    memory_mask,
                    memory_key_padding_mask,
                    self.multihead_attn2,
                    self.dropout2a,
                )
                x = self.norm2a(x + featurewise_affine(x_2a, self.film2a(t)))
            x_3 = self._ff_block(x)
            x = self.norm3(x + featurewise_affine(x_3, self.film3(t)))
        return x

    # self-attention block
    # qkv
    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention over ``x`` (rotated queries/keys, raw values) + dropout."""
        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        x = self.self_attn(
            qk,
            qk,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    # qkv
    def _mha_block(self, x, mem, attn_mask, key_padding_mask, mha, dropout):
        """
        Cross-attention from ``x`` to ``mem`` using the given attention module.

        :param mha: the ``nn.MultiheadAttention`` to run
        :param dropout: dropout module applied to the attention output
        """
        q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
        x = mha(
            q,
            k,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return dropout(x)

    # feed forward block
    def _ff_block(self, x):
        """Two-layer feed-forward network with dropout after each layer."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
class TimestepEncoding(nn.Module):
    """Sinusoidal (Fourier) timestep features followed by a small MLP."""

    def __init__(self, embedding_dim: int):
        """
        :param embedding_dim: size of the produced encoding; ``embedding_dim // 2``
            frequencies are used for the sine and cosine features
        """
        super().__init__()

        # Fourier embedding
        half_dim = embedding_dim // 2
        # With a single frequency (embedding_dim 2 or 3) ``half_dim - 1`` is
        # zero; clamp the divisor so construction does not divide by zero.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.register_buffer("emb", emb)

        # encoding
        self.encoding = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.Mish(),
            nn.Linear(4 * embedding_dim, embedding_dim),
        )

    def forward(self, t: torch.Tensor):
        """
        :param t: B-dimensional tensor containing timesteps in range [0, 1]
        :return: B x embedding_dim tensor containing timestep encodings
        """
        x = t[:, None] * self.emb[None, :]
        x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
        # sin/cos yield 2 * (embedding_dim // 2) features; for an odd
        # embedding_dim zero-pad the missing column so the MLP input matches.
        missing = self.encoding[0].in_features - x.shape[-1]
        if missing > 0:
            x = F.pad(x, (0, missing))
        x = self.encoding(x)
        return x
class SelfAttention(nn.Module):
    """Batch-first multi-head self-attention followed by dropout."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        """
        :param d_model: feature width of the input and output
        :param num_heads: number of attention heads
        :param dropout: probability used inside attention and on its output
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        :param x: B x T x d_model input tensor, used as query, key and value
        :param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length
            for a float mask: values will be added to attention weight
            for a binary mask: True indicates that the element is not allowed to attend
        :param key_padding_mask: B x S mask
            for a float mask: values will be added directly to the corresponding key values
            for a binary mask: True indicates that the corresponding key value will be ignored
        :return: B x T x d_model output tensor
        """
        # Attention weights are not requested, so only the output is used.
        attended, _ = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout(attended)
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention, then feed-forward."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
    ):
        """
        :param d_model: feature width of the layer
        :param num_heads: number of attention heads
        :param d_feedforward: hidden width of the feed-forward block
        :param dropout: dropout probability for both sub-blocks
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = SelfAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        :param x: input tensor
        :param mask: attention mask handed to the self-attention block
        :param key_padding_mask: key padding mask handed to the self-attention block
        :return: tensor with the same shape as ``x``
        """
        # residual around normalised self-attention
        attended = self.self_attn(self.norm1(x), mask, key_padding_mask)
        x = x + attended
        # residual around normalised feed-forward
        return x + self.feedforward(self.norm2(x))
def forward(
    self,
    x: torch.Tensor,
    cross_cond: torch.Tensor,
    target_mask: torch.Tensor = None,
    target_key_padding_mask: torch.Tensor = None,
    cross_cond_mask: torch.Tensor = None,
    cross_cond_key_padding_mask: torch.Tensor = None,
):
    """
    Apply self-attention, cross-attention and feedforward, each as a pre-norm residual update.

    :param x: B x T x d_model tensor
    :param cross_cond: B x T x d_cond tensor containing the conditioning input to cross attention layers
    :param target_mask: optional mask forwarded to the self-attention block
    :param target_key_padding_mask: optional key padding mask forwarded to the self-attention block
    :param cross_cond_mask: optional mask forwarded to the cross-attention block
    :param cross_cond_key_padding_mask: optional key padding mask forwarded to the cross-attention block
    :return: B x T x d_model tensor
    """
    # 1) attend over the target sequence itself
    self_attended = self.self_attn(
        self.norm1(x), target_mask, target_key_padding_mask
    )
    x = x + self_attended

    # 2) attend over the conditioning sequence
    cross_attended = self.cross_attn(
        self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
    )
    x = x + cross_attended

    # 3) position-wise feedforward
    return x + self.feedforward(self.norm3(x))
class RegressionTransformer(nn.Module):
    """Encoder/decoder transformer mapping a conditioning sequence to a target sequence.

    The conditioning sequence is refined by a stack of encoder layers; the
    target sequence is then passed through a stack of decoder layers that
    cross-attend to the encoded conditioning. When ``causal`` is set, causal
    masks are built for every attention block.
    """

    def __init__(
        self,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
        d_model: int = 512,
        d_cond: int = 512,
        num_heads: int = 4,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
        causal: bool = False,
    ):
        """
        :param transformer_encoder_layers: number of encoder layers applied to the conditioning
        :param transformer_decoder_layers: number of decoder layers applied to the target
        :param d_model: feature width of the target sequence
        :param d_cond: feature width of the conditioning sequence
        :param num_heads: attention heads in every layer
        :param d_feedforward: feedforward width handed to every layer
        :param dropout: dropout rate handed to positional encodings and layers
        :param causal: if True, forward() builds causal masks for all attention blocks
        """
        super().__init__()
        self.causal = causal

        self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
        self.target_positional_encoding = PositionalEncoding(d_model, dropout)

        encoder_stack = [
            TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
            for _ in range(transformer_encoder_layers)
        ]
        self.transformer_encoder = nn.ModuleList(encoder_stack)

        decoder_stack = [
            TransformerDecoderLayer(d_model, d_cond, num_heads, d_feedforward, dropout)
            for _ in range(transformer_decoder_layers)
        ]
        self.transformer_decoder = nn.ModuleList(decoder_stack)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        :param x: B x T x d_model input tensor
        :param cond: B x T x d_cond conditional tensor
        :return: B x T x d_model output tensor
        """
        x = self.target_positional_encoding(x)
        cond = self.cond_positional_encoding(cond)

        # Masks stay None (unrestricted attention) unless the model is causal.
        encoder_mask = None
        decoder_self_attn_mask = None
        decoder_cross_attn_mask = None
        if self.causal:
            cond_len, target_len = cond.shape[1], x.shape[1]
            encoder_mask = generate_causal_mask(
                cond_len, cond_len, device=cond.device
            )
            decoder_self_attn_mask = generate_causal_mask(
                target_len, target_len, device=x.device
            )
            decoder_cross_attn_mask = generate_causal_mask(
                cond_len, target_len, device=x.device
            )

        for encoder_layer in self.transformer_encoder:
            cond = encoder_layer(cond, mask=encoder_mask)

        for decoder_layer in self.transformer_decoder:
            x = decoder_layer(
                x,
                cond,
                target_mask=decoder_self_attn_mask,
                cross_cond_mask=decoder_cross_attn_mask,
            )
        return x
def init_weight(m):
    """Initialise a single module in place; intended for use with ``Module.apply``.

    Conv1d, Linear and ConvTranspose1d modules get Xavier-normal weights and,
    when they have a bias, a zero bias. Any other module is left untouched.

    :param m: the module to initialise
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        return
    nn.init.xavier_normal_(m.weight)
    # Layers built with bias=False expose ``bias`` as None.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """Build a diffusion beta schedule as a float64 NumPy array of length ``n_timestep``.

    Supported ``schedule`` names (named as in the code, not by curve shape):

    * ``"linear"``      - linspace between sqrt(linear_start) and sqrt(linear_end), then squared
    * ``"cosine"``      - betas derived from a squared-cosine alpha curve offset by
      ``cosine_s``, clamped to [0, 0.999]
    * ``"sqrt_linear"`` - plain linspace between linear_start and linear_end
    * ``"sqrt"``        - square root of a linspace between linear_start and linear_end

    :param schedule: one of the names above
    :param n_timestep: number of diffusion steps
    :param linear_start: first beta for the non-cosine schedules
    :param linear_end: last beta for the non-cosine schedules
    :param cosine_s: offset used by the cosine schedule
    :return: numpy array of betas
    :raises ValueError: if ``schedule`` is not a known name
    """
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch: ``betas`` is a tensor here, and np.clip on a tensor
        # only works through NumPy's fallback conversion path.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
def setup_tokenizer(resume_pth: str) -> "TemporalVertexCodec":
    """Rebuild a frozen TemporalVertexCodec from a checkpoint and its ``args.json``.

    The hyper-parameters are read from ``args.json`` located in the same
    directory as the checkpoint; the weights are read from the ``"net"`` entry
    of the checkpoint. All parameters are frozen (``requires_grad = False``).

    :param resume_pth: path to the checkpoint file
    :return: the loaded tokenizer, on the GPU when CUDA is available, otherwise on the CPU
    """
    args_path = os.path.dirname(resume_pth)
    with open(os.path.join(args_path, "args.json")) as f:
        trans_args = json.load(f)

    tokenizer = TemporalVertexCodec(
        n_vertices=trans_args["nb_joints"],
        latent_dim=trans_args["output_emb_width"],
        categories=trans_args["code_dim"],
        residual_depth=trans_args["depth"],
    )

    print(f"loading checkpoint from {resume_pth}")
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    ckpt = torch.load(resume_pth, map_location="cpu")
    tokenizer.load_state_dict(ckpt["net"], strict=True)
    for p in tokenizer.parameters():
        p.requires_grad = False

    # An unconditional .cuda() raises on CPU-only hosts; fall back to the CPU
    # there and keep the GPU placement everywhere else.
    if torch.cuda.is_available():
        tokenizer.cuda()
    return tokenizer
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run a fixed number of k-means iterations on a 2-D ``samples`` tensor.

    Centroids are seeded from the samples: without replacement when there are
    at least ``num_clusters`` samples, with replacement otherwise. A cluster
    that receives no samples in an iteration keeps its previous centroid.

    :param samples: N x D tensor of points
    :param num_clusters: number of centroids
    :param num_iters: number of assignment/update rounds
    :return: (means, bins) - num_clusters x D centroids and the per-cluster
             sample counts from the last assignment step
    """
    feat_dim, feat_dtype = samples.shape[-1], samples.dtype
    n_points, device = samples.shape[0], samples.device

    # Seed the centroids with randomly chosen samples.
    if n_points >= num_clusters:
        seed_idx = torch.randperm(n_points, device=device)[:num_clusters]
    else:
        seed_idx = torch.randint(0, n_points, (num_clusters,), device=device)
    means = samples[seed_idx]

    for _ in range(num_iters):
        # Negated squared distance of every sample to every centroid (N x C);
        # the nearest centroid is therefore the arg-max.
        offsets = samples[:, None, :] - means[None, :, :]
        neg_sq_dists = -(offsets**2).sum(dim=-1)
        buckets = neg_sq_dists.max(dim=-1).indices

        bins = torch.bincount(buckets, minlength=num_clusters)
        empty = bins == 0
        safe_counts = bins.masked_fill(empty, 1)  # avoid division by zero

        # Sum the samples of each cluster, then divide by the cluster size.
        sums = buckets.new_zeros(num_clusters, feat_dim, dtype=feat_dtype)
        sums.scatter_add_(0, buckets[:, None].repeat(1, feat_dim), samples)
        updated = sums / safe_counts[..., None]

        means = torch.where(empty[..., None], means, updated)

    return means, bins
def expire_codes_(self, batch_samples):
    """Replace under-used codebook entries with vectors drawn from the batch.

    Does nothing when the dead-code threshold is 0 or when no entry's
    ``cluster_size`` falls below the threshold. Otherwise the batch is
    flattened to rows, the expired entries are replaced via ``replace_`` and
    the buffers are passed to ``broadcast_tensors``.

    :param batch_samples: tensor whose last dimension is the code dimension
    """
    threshold = self.threshold_ema_dead_code
    if threshold == 0:
        return

    dead = self.cluster_size < threshold
    if not dead.any():
        return

    flat_samples = rearrange(batch_samples, "... d -> (...) d")
    self.replace_(flat_samples, mask=dead)
    broadcast_tensors(self.buffers())
def quantize(self, x):
    """Return, for every row of ``x``, the index of the nearest codebook entry.

    Nearest means smallest squared Euclidean distance, expanded as
    ``|x|^2 - 2 x.e + |e|^2`` so it can be computed with one matrix product.

    :param x: N x D tensor of vectors
    :return: N-dimensional LongTensor of codebook indices
    """
    codebook_t = self.embed.t()  # D x codebook_size
    x_sq = x.pow(2).sum(1, keepdim=True)
    cross = 2 * x @ codebook_t
    code_sq = codebook_t.pow(2).sum(0, keepdim=True)
    # Negated distance, so the closest entry is the arg-max.
    neg_sq_dist = -(x_sq - cross + code_sq)
    return neg_sq_dist.max(dim=-1).indices
def __init__(
    self,
    dim: int,
    codebook_size: int,
    codebook_dim=None,
    decay: float = 0.99,
    epsilon: float = 1e-5,
    kmeans_init: bool = True,
    kmeans_iters: int = 50,
    threshold_ema_dead_code: int = 2,
    commitment_weight: float = 1.0,
):
    """Build the optional in/out projections and the Euclidean codebook.

    When ``codebook_dim`` is given and differs from ``dim``, linear layers
    project into and out of the codebook space; otherwise both projections
    are identities.
    """
    super().__init__()
    # Fall back to the input width when no separate codebook width is given.
    inner_dim: int = dim if codebook_dim is None else codebook_dim

    if inner_dim != dim:
        self.project_in = nn.Linear(dim, inner_dim)
        self.project_out = nn.Linear(inner_dim, dim)
    else:
        self.project_in = nn.Identity()
        self.project_out = nn.Identity()

    self.epsilon = epsilon
    self.commitment_weight = commitment_weight

    codebook_kwargs = dict(
        dim=inner_dim,
        codebook_size=codebook_size,
        kmeans_init=kmeans_init,
        kmeans_iters=kmeans_iters,
        decay=decay,
        epsilon=epsilon,
        threshold_ema_dead_code=threshold_ema_dead_code,
    )
    self._codebook = EuclideanCodebook(**codebook_kwargs)
    self.codebook_size = codebook_size
    # Element-wise squared error, kept as an attribute of the module.
    self.l2_loss = lambda a, b: (a - b) ** 2
def forward(self, x, B, T, mask, n_q=None):
    """Quantize ``x`` with the first ``n_q`` quantizers, each on the previous residual.

    ``B``, ``T`` and ``mask`` are accepted but not used by this method.

    :param x: B x dim tensor
    :param n_q: number of quantizers to apply; all of them when None (or 0)
    :return: quantized_out: B x dim tensor
             out_indices: B x n_q LongTensor containing indices for each quantizer
             out_losses: scalar tensor containing commitment loss
    """
    n_q = n_q or len(self.layers)

    residual = x
    quantized_out = 0.0
    per_layer_indices = []
    per_layer_losses = []

    for quantizer in self.layers[:n_q]:
        quantized, indices, loss = quantizer(residual)
        quantized_out = quantized_out + quantized
        # would need quantizer.detach() to have commitment gradients beyond
        # the first quantizer, but this seems to harm performance
        residual = residual - quantized

        per_layer_indices.append(indices)
        per_layer_losses.append(loss)

    out_indices = torch.stack(per_layer_indices, dim=-1)
    out_losses = torch.mean(torch.stack(per_layer_losses))
    return quantized_out, out_indices, out_losses
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
    """Sum the codebook vectors selected by each quantizer's indices.

    :param q_indices: B x n_q LongTensor containing indices for each quantizer
    :return: B x dim tensor containing reconstruction after quantization
    """
    total = torch.tensor(0.0, device=q_indices.device)
    # Bring the quantizer axis to the front: row i holds the indices for layer i.
    per_quantizer = q_indices.permute(1, 0).contiguous()
    for layer_idx in range(per_quantizer.shape[0]):
        decoded = self.layers[layer_idx].decode(per_quantizer[layer_idx])
        total = total + decoded
    return total
class TemporalVertexDecoder(nn.Module):
    """Stack of dilated 1-D convolutions that maps latent sequences back to vertex features.

    The input is left-padded by ``receptive_field - 1`` frames, so the output
    has the same number of frames as the input and frame ``t`` depends only on
    frames ``<= t``.
    """

    def __init__(
        self,
        n_vertices: int = 338,
        latent_dim: int = 128,
    ):
        """
        :param n_vertices: number of output channels per frame
        :param latent_dim: number of channels of the latent sequence
        """
        super().__init__()
        self.output_dim = n_vertices
        # Not used in forward(); kept so the parameter set (and state_dict keys)
        # stay the same.
        self.project_mean_shape = nn.Linear(self.output_dim, latent_dim)

        # Four kernel-2 convolutions with dilations 1, 2, 3, 1, each followed
        # by a LeakyReLU, then a kernel-1 projection to the output width.
        layers = []
        for dilation in (1, 2, 3, 1):
            layers.append(
                nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=dilation)
            )
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        layers.append(nn.Conv1d(latent_dim, self.output_dim, kernel_size=1))
        self.dec = nn.Sequential(*layers)

        # 1 + (1 + 2 + 3 + 1): frames covered by the dilated stack.
        self.receptive_field = 8

    def forward(self, x):
        """
        :param x: B x T x latent_dim tensor containing batched sequences of vertex encodings
        :return: B x T x n_vertices tensor (n_vertices as passed to the constructor)
        """
        channels_first = x.permute(0, 2, 1).contiguous()
        padded = nn.functional.pad(
            channels_first, pad=[self.receptive_field - 1, 0]
        )
        return self.dec(padded).permute(0, 2, 1)
@torch.no_grad()
def compute_perplexity(self, code_idx):
    """Perplexity of the code usage distribution in ``code_idx``.

    Counts how often each of the ``self.categories`` codes occurs, turns the
    counts into probabilities and returns ``exp`` of their entropy (with a
    1e-7 offset inside the log).

    :param code_idx: 1-D LongTensor of code indices
    :return: scalar tensor containing the perplexity
    """
    n_codes = code_idx.shape[0]

    # One column per index, with a single 1 in the row of its code.
    onehot = torch.zeros(self.categories, n_codes, device=code_idx.device)
    onehot.scatter_(0, code_idx.view(1, n_codes), 1)

    counts = onehot.sum(dim=-1)  # occurrences per category
    prob = counts / torch.sum(counts)
    entropy = -torch.sum(prob * torch.log(prob + 1e-7))
    return torch.exp(entropy)
new file mode 100644 index 0000000000000000000000000000000000000000..c8996ffaf14913448db28e178713934db39a8a6e --- /dev/null +++ b/sample/generate.py @@ -0,0 +1,316 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os + +from typing import Callable, Dict, Union + +import numpy as np +import torch +from data_loaders.get_data import get_dataset_loader, load_local_data +from diffusion.respace import SpacedDiffusion +from model.cfg_sampler import ClassifierFreeSampleModel +from model.diffusion import FiLMTransformer + +from torch.utils.data import DataLoader +from utils.diff_parser_utils import generate_args +from utils.misc import fixseed, prGreen +from utils.model_util import create_model_and_diffusion, get_person_num, load_model + + +def _construct_template_variables(unconstrained: bool) -> (str,): + row_file_template = "sample{:02d}.mp4" + all_file_template = "samples_{:02d}_to_{:02d}.mp4" + if unconstrained: + sample_file_template = "row{:02d}_col{:02d}.mp4" + sample_print_template = "[{} row #{:02d} column #{:02d} | -> {}]" + row_file_template = row_file_template.replace("sample", "row") + row_print_template = "[{} row #{:02d} | all columns | -> {}]" + all_file_template = all_file_template.replace("samples", "rows") + all_print_template = "[rows {:02d} to {:02d} | -> {}]" + else: + sample_file_template = "sample{:02d}_rep{:02d}.mp4" + sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]' + row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]' + all_print_template = "[samples {:02d} to {:02d} | all repetitions | -> {}]" + + return ( + sample_print_template, + row_print_template, + all_print_template, + sample_file_template, + row_file_template, + all_file_template, + ) + + +def _replace_keyframes( + model_kwargs: Dict[str, Dict[str, torch.Tensor]], + model: Union[FiLMTransformer, 
def _run_single_diffusion(
    args,
    model_kwargs: Dict[str, Dict[str, torch.Tensor]],
    diffusion: SpacedDiffusion,
    model: Union[FiLMTransformer, ClassifierFreeSampleModel],
    inv_transform: Callable,
    gt: torch.Tensor,
) -> (torch.Tensor,):
    """Draw one batch of DDIM samples and map every output back through ``inv_transform``.

    For the "pose" data format with ``args.resume_trans`` set, the keyframes
    in ``model_kwargs["y"]`` are first overwritten in place with keyframes
    produced by ``_replace_keyframes``.

    :param args: run configuration; reads data_format, resume_trans, batch_size
        and curr_seq_length
    :param model_kwargs: conditioning dict; reads (and may update) ``["y"]["keyframes"]``
        and reads ``["y"]["audio"]``
    :param diffusion: object providing ``ddim_sample_loop``
    :param model: network passed to the sampling loop; its ``nfeats`` sets the feature size
    :param inv_transform: callable ``(data, data_format)`` undoing the dataset normalisation
    :param gt: ground-truth batch laid out like the generated sample
    :return: (sample, curr_audio, keyframes, gt_seq)
    """
    use_predicted_keyframes = (
        args.data_format == "pose" and args.resume_trans is not None
    )
    if use_predicted_keyframes:
        model_kwargs["y"]["keyframes"] = _replace_keyframes(model_kwargs, model)

    def _denormalise(batch: torch.Tensor) -> torch.Tensor:
        # inv_transform is applied with the feature axis last; restore the
        # original axis order afterwards.
        feature_last = batch.cpu().permute(0, 2, 3, 1)
        return inv_transform(feature_last, args.data_format).permute(0, 3, 1, 2)

    sample_shape = (args.batch_size, model.nfeats, 1, args.curr_seq_length)
    with torch.no_grad():
        raw_sample = diffusion.ddim_sample_loop(
            model,
            sample_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

    sample = _denormalise(raw_sample)
    curr_audio = inv_transform(model_kwargs["y"]["audio"].cpu().numpy(), "audio")
    keyframes = inv_transform(model_kwargs["y"]["keyframes"], args.data_format)
    gt_seq = _denormalise(gt)

    return sample, curr_audio, keyframes, gt_seq
= [] + all_keyframes = [] + + for rep_i in range(args.num_repetitions): + print(f"### Sampling [repetitions #{rep_i}]") + # add CFG scale to batch + if args.guidance_param != 1: + model_kwargs["y"]["scale"] = ( + torch.ones(args.batch_size, device=args.device) * args.guidance_param + ) + model_kwargs["y"] = { + key: val.to(args.device) if torch.is_tensor(val) else val + for key, val in model_kwargs["y"].items() + } + sample, curr_audio, keyframes, gt_seq = _run_single_diffusion( + args, model_kwargs, diffusion, model, test_data.dataset.inv_transform, gt + ) + all_motions.append(sample.cpu().numpy()) + all_audio.append(curr_audio) + all_keyframes.append(keyframes.cpu().numpy()) + all_gt.append(gt_seq.cpu().numpy()) + all_lengths.append(model_kwargs["y"]["lengths"].cpu().numpy()) + + print(f"created {len(all_motions) * args.batch_size} samples") + + return { + "motions": np.concatenate(all_motions, axis=0), + "audio": np.concatenate(all_audio, axis=0), + "gt": np.concatenate(all_gt, axis=0), + "lengths": np.concatenate(all_lengths, axis=0), + "keyframes": np.concatenate(all_keyframes, axis=0), + } + + +def _render_pred( + args, + data_block: Dict[str, torch.Tensor], + sample_file_template: str, + audio_per_frame: int, +) -> None: + from visualize.render_codes import BodyRenderer + + face_codes = None + if args.face_codes is not None: + face_codes = np.load(args.face_codes, allow_pickle=True).item() + face_motions = face_codes["motions"] + face_gts = face_codes["gt"] + face_audio = face_codes["audio"] + + config_base = f"./checkpoints/ca_body/data/{get_person_num(args.data_root)}" + body_renderer = BodyRenderer( + config_base=config_base, + render_rgb=True, + ) + + for sample_i in range(args.num_samples): + for rep_i in range(args.num_repetitions): + idx = rep_i * args.batch_size + sample_i + save_file = sample_file_template.format(sample_i, rep_i) + animation_save_path = os.path.join(args.output_dir, save_file) + # format data + length = data_block["lengths"][idx] + 
body_motion = ( + data_block["motions"][idx].transpose(2, 0, 1)[:length].squeeze(-1) + ) + face_motion = face_motions[idx].transpose(2, 0, 1)[:length].squeeze(-1) + assert np.array_equal( + data_block["audio"][idx], face_audio[idx] + ), "face audio is not the same" + audio = data_block["audio"][idx, : length * audio_per_frame, :].T + # set up render data block to pass into renderer + render_data_block = { + "audio": audio, + "body_motion": body_motion, + "face_motion": face_motion, + } + if args.render_gt: + gt_body = data_block["gt"][idx].transpose(2, 0, 1)[:length].squeeze(-1) + gt_face = face_gts[idx].transpose(2, 0, 1)[:length].squeeze(-1) + render_data_block["gt_body"] = gt_body + render_data_block["gt_face"] = gt_face + body_renderer.render_full_video( + render_data_block, + animation_save_path, + audio_sr=audio_per_frame * 30, + render_gt=args.render_gt, + ) + + +def _reset_sample_args(args) -> None: + # set the sequence length to match the one specified by user + name = os.path.basename(os.path.dirname(args.model_path)) + niter = os.path.basename(args.model_path).replace("model", "").replace(".pt", "") + args.curr_seq_length = ( + args.curr_seq_length + if args.curr_seq_length is not None + else args.max_seq_length + ) + # add the resume predictor model path + resume_trans_name = "" + if args.data_format == "pose" and args.resume_trans is not None: + resume_trans_parts = args.resume_trans.split("/") + resume_trans_name = f"{resume_trans_parts[1]}_{resume_trans_parts[-1]}" + # reformat the output directory + args.output_dir = os.path.join( + os.path.dirname(args.model_path), + "samples_{}_{}_seed{}_{}".format(name, niter, args.seed, resume_trans_name), + ) + assert ( + args.num_samples <= args.batch_size + ), f"Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})" + # set the batch size to match the number of samples to generate + args.batch_size = args.num_samples + + +def _setup_dataset(args) -> DataLoader: + 
data_root = args.data_root + data_dict = load_local_data( + data_root, + audio_per_frame=1600, + flip_person=args.flip_person, + ) + test_data = get_dataset_loader( + args=args, + data_dict=data_dict, + split="test", + chunk=True, + ) + return test_data + + +def _setup_model( + args, +) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion): + model, diffusion = create_model_and_diffusion(args, split_type="test") + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location="cpu") + load_model(model, state_dict) + + if not args.unconstrained: + assert args.guidance_param != 1 + + if args.guidance_param != 1: + prGreen("[CFS] wrapping model in classifier free sample") + model = ClassifierFreeSampleModel(model) + model.to(args.device) + model.eval() + return model, diffusion + + +def main(): + args = generate_args() + fixseed(args.seed) + _reset_sample_args(args) + + print("Loading dataset...") + test_data = _setup_dataset(args) + iterator = iter(test_data) + + print("Creating model and diffusion...") + model, diffusion = _setup_model(args) + + if args.pose_codes is None: + # generate sequences + gt, model_kwargs = next(iterator) + data_block = _generate_sequences( + args, model_kwargs, diffusion, model, test_data, gt + ) + os.makedirs(args.output_dir, exist_ok=True) + npy_path = os.path.join(args.output_dir, "results.npy") + print(f"saving results file to [{npy_path}]") + np.save(npy_path, data_block) + else: + # load the pre generated results + data_block = np.load(args.pose_codes, allow_pickle=True).item() + + # plot function only if face_codes exist and we are on pose prediction + if args.plot: + assert args.face_codes is not None, "need body and faces" + assert ( + args.data_format == "pose" + ), "currently only supporting plot on pose stuff" + print(f"saving visualizations to [{args.output_dir}]...") + _, _, _, sample_file_template, _, _ = _construct_template_variables( + args.unconstrained 
+ ) + _render_pred( + args, + data_block, + sample_file_template, + test_data.dataset.audio_per_frame, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/download_alldatasets.sh b/scripts/download_alldatasets.sh new file mode 100644 index 0000000000000000000000000000000000000000..2efc8ce71aa1b6cbaee7f1a22499ddc609286b4e --- /dev/null +++ b/scripts/download_alldatasets.sh @@ -0,0 +1,6 @@ +for i in "PXB184" "RLW104" "TXB805" "GQS883" +do + curl -L https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/${i}.zip -o ${i}.zip || { echo 'downloading dataset failed' ; exit 1; } + unzip ${i}.zip -d dataset/ + rm ${i}.zip +done diff --git a/scripts/download_allmodels.sh b/scripts/download_allmodels.sh new file mode 100644 index 0000000000000000000000000000000000000000..4e69a2a76a9c9a5366478acb41187100bcc3ef29 --- /dev/null +++ b/scripts/download_allmodels.sh @@ -0,0 +1,13 @@ +for i in "PXB184" "RLW104" "TXB805" "GQS883" +do + # download motion models + wget http://audio2photoreal_models.berkeleyvision.org/${i}_models.tar || { echo 'downloading model failed' ; exit 1; } + tar xvf ${i}_models.tar + rm ${i}_models.tar + + # download ca body rendering checkpoints and assets + mkdir -p checkpoints/ca_body/data/ + wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/${i}.tar.gz || { echo 'downloading ca body model failed' ; exit 1; } + tar xvf ${i}.tar.gz --directory checkpoints/ca_body/data/ + rm ${i}.tar.gz +done \ No newline at end of file diff --git a/scripts/download_prereq.sh b/scripts/download_prereq.sh new file mode 100644 index 0000000000000000000000000000000000000000..65896a8b42c2eceed8f6218a04085b7d13ddc6a8 --- /dev/null +++ b/scripts/download_prereq.sh @@ -0,0 +1,9 @@ + +# install the prerequisite asset models (lip regressor and wav2vec) +wget http://audio2photoreal_models.berkeleyvision.org/asset_models.tar +tar xvf asset_models.tar +rm asset_models.tar + +# we obtained the wav2vec models via these links: 
+# wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt -P ./assets/ +# wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt -P ./assets/ diff --git a/scripts/installation.sh b/scripts/installation.sh new file mode 100644 index 0000000000000000000000000000000000000000..453ffb0f258e6f13247f77910a2d19b2f6084403 --- /dev/null +++ b/scripts/installation.sh @@ -0,0 +1,4 @@ +# download the prerequisite asset models (lip regressor and wav2vec) +wget http://audio2photoreal_models.berkeleyvision.org/asset_models.tar +tar xvf asset_models.tar +rm asset_models.tar diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..741b25894df97dabe013febe39b62ded2e6cb9f1 --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1,17 @@ +attrdict +blobfile +einops +fairseq +gradio +matplotlib +mediapy +numpy==1.23.0 +opencv-python +packaging +scikit-learn +tensorboard +tensorboardX +torch==2.0.1 +torchaudio==2.0.2 +torchvision==0.15.2 +tqdm diff --git a/train/train_diffusion.py b/train/train_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..e2992e2caaf8f4603de4c5297595d5f7f1c46276 --- /dev/null +++ b/train/train_diffusion.py @@ -0,0 +1,83 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import json +import os + +import torch +import torch.multiprocessing as mp + +from data_loaders.get_data import get_dataset_loader, load_local_data +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from train.train_platforms import ClearmlPlatform, NoPlatform, TensorboardPlatform +from train.training_loop import TrainLoop +from utils.diff_parser_utils import train_args +from utils.misc import cleanup, fixseed, setup_dist +from utils.model_util import create_model_and_diffusion + + +def main(rank: int, world_size: int): + args = train_args() + fixseed(args.seed) + train_platform_type = eval(args.train_platform_type) + train_platform = train_platform_type(args.save_dir) + train_platform.report_args(args, name="Args") + setup_dist(args.device) + + if rank == 0: + if args.save_dir is None: + raise FileNotFoundError("save_dir was not specified.") + elif os.path.exists(args.save_dir) and not args.overwrite: + raise FileExistsError("save_dir [{}] already exists.".format(args.save_dir)) + elif not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + args_path = os.path.join(args.save_dir, "args.json") + with open(args_path, "w") as fw: + json.dump(vars(args), fw, indent=4, sort_keys=True) + + if not os.path.exists(args.data_root): + args.data_root = args.data_root.replace("/home/", "/derived/") + + data_dict = load_local_data(args.data_root, audio_per_frame=1600) + print("creating data loader...") + data = get_dataset_loader(args=args, data_dict=data_dict) + + print("creating logger...") + writer = SummaryWriter(args.save_dir) + + print("creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, split_type="train") + model.to(rank) + + if world_size > 1: + model = DDP( + model, device_ids=[rank], output_device=rank, find_unused_parameters=True + ) + + params = ( + model.module.parameters_w_grad() + if world_size > 1 + else model.parameters_w_grad() + ) + 
print("Total params: %.2fM" % (sum(p.numel() for p in params) / 1000000.0)) + print("Training...") + + TrainLoop( + args, train_platform, model, diffusion, data, writer, rank, world_size + ).run_loop() + train_platform.close() + cleanup() + + +if __name__ == "__main__": + world_size = torch.cuda.device_count() + print(f"using {world_size} gpus") + if world_size > 1: + mp.spawn(main, args=(world_size,), nprocs=world_size, join=True) + else: + main(rank=0, world_size=1) diff --git a/train/train_guide.py b/train/train_guide.py new file mode 100644 index 0000000000000000000000000000000000000000..e5e7a4a38e209bd6fdf5ac7fdfa045f9a28d886c --- /dev/null +++ b/train/train_guide.py @@ -0,0 +1,362 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import json +import os +from typing import Any, Dict + +import numpy as np +import torch +import torch.optim as optim + +from data_loaders.get_data import get_dataset_loader, load_local_data +from diffusion.nn import sum_flat +from model.guide import GuideTransformer +from model.vqvae import setup_tokenizer, TemporalVertexCodec +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from utils.guide_parser_utils import train_args +from utils.misc import fixseed + + +class ModelTrainer: + def __init__( + self, args, model: GuideTransformer, tokenizer: TemporalVertexCodec + ) -> None: + self.add_frame_cond = args.add_frame_cond + self.data_format = args.data_format + self.tokenizer = tokenizer + self.model = model.cuda() + self.gn = args.gn + self.max_seq_length = args.max_seq_length + self.optimizer = optim.AdamW( + model.parameters(), + lr=args.lr, + betas=(0.9, 0.99), + weight_decay=args.weight_decay, + ) + self.scheduler = optim.lr_scheduler.MultiStepLR( + self.optimizer, milestones=args.lr_scheduler, gamma=args.gamma 
+ ) + self.l2_loss = lambda a, b: (a - b) ** 2 + self.start_step = 0 + self.warm_up_iter = args.warm_up_iter + self.lr = args.lr + self.ce_loss = torch.nn.CrossEntropyLoss( + ignore_index=self.tokenizer.n_clusters + 1, label_smoothing=0.1 + ) + + if args.resume_trans is not None: + self._load_from_checkpoint() + + def _load_from_checkpoint(self) -> None: + print("loading", args.resume_trans) + ckpt = torch.load(args.resume_trans, map_location="cpu") + self.model.load_state_dict(ckpt["model_state_dict"], strict=True) + self.optimizer.load_state_dict(ckpt["optimizer_state_dict"]) + self.start_step = ckpt["iteration"] + + def _abbreviate( + self, meshes: torch.Tensor, mask: torch.Tensor, step: int + ) -> (torch.Tensor,): + keyframes = meshes[..., ::step] + new_mask = mask[..., ::step] + return keyframes, new_mask + + def _prepare_tokens( + self, meshes: torch.Tensor, mask: torch.Tensor + ) -> (torch.Tensor,): + if self.add_frame_cond == 1: + keyframes, new_mask = self._abbreviate(meshes, mask, 30) + elif self.add_frame_cond is None: + keyframes, new_mask = self._abbreviate(meshes, mask, 1) + + meshes = keyframes.squeeze(2).permute((0, 2, 1)) + B, T, _ = meshes.shape + target_tokens = self.tokenizer.predict(meshes) + target_tokens = target_tokens.reshape(B, -1) + input_tokens = torch.cat( + [ + torch.zeros( + (B, 1), dtype=target_tokens.dtype, device=target_tokens.device + ) + + self.model.tokens, + target_tokens[:, :-1], + ], + axis=-1, + ) + return input_tokens, target_tokens, new_mask, meshes.reshape((B, T, -1)) + + def _run_single_train_step(self, input_tokens, audio, target_tokens): + B, T = input_tokens.shape[0], input_tokens.shape[1] + self.optimizer.zero_grad() + logits = self.model(input_tokens, audio, cond_drop_prob=0.20) + loss = self.ce_loss( + logits.reshape((B * T, -1)), target_tokens.reshape((B * T)).long() + ) + loss.backward() + if self.gn: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + 
return logits, loss + + def _run_single_val_step( + self, motion: torch.Tensor, cond: torch.Tensor + ) -> Dict[str, Any]: + self.model.eval() + with torch.no_grad(): + motion = torch.as_tensor(motion).cuda() + ( + input_tokens, + target_tokens, + new_mask, + downsampled_gt, + ) = self._prepare_tokens(motion, cond["mask"]) + audio = cond["audio"].cuda() + + new_mask = torch.as_tensor(new_mask) + B, T = target_tokens.shape[0], target_tokens.shape[1] + logits = self.model(input_tokens, audio) + tokens = torch.argmax(logits, dim=-1).view( + B, -1, self.tokenizer.residual_depth + ) + pred = self.tokenizer.decode(tokens).detach().cpu() + ce_loss = self.ce_loss( + logits.reshape((B * T, -1)), target_tokens.reshape((B * T)).long() + ) + l2_loss = self._masked_l2( + downsampled_gt.permute(0, 2, 1).unsqueeze(2).detach().cpu(), + pred.permute(0, 2, 1).unsqueeze(2), + new_mask, + ) + acc = self.compute_accuracy(logits, target_tokens, new_mask) + + return { + "pred": pred, + "gt": downsampled_gt, + "metrics": { + "ce_loss": ce_loss.item(), + "l2_loss": l2_loss.item(), + "perplexity": np.exp(ce_loss.item()), + "acc": acc.item(), + }, + } + + def _masked_l2(self, a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor) -> float: + loss = self.l2_loss(a, b) + loss = sum_flat(loss * mask.float()) + n_entries = a.shape[1] * a.shape[2] + non_zero_elements = sum_flat(mask) * n_entries + mse_loss_val = loss / non_zero_elements + return mse_loss_val.mean() + + def compute_ce_loss( + self, logits: torch.Tensor, target_tokens: torch.Tensor, mask: torch.Tensor + ) -> float: + target_tokens[~mask.squeeze().detach().cpu()] = 0 + B = logits.shape[0] + logprobs = torch.log_softmax(logits, dim=-1).view( + B, -1, 1, self.tokenizer.n_clusters + ) + logprobs = logprobs[:, self.mask_left :, :, :].contiguous() + labels = target_tokens.view(B, -1, 1) + labels = labels[:, self.mask_left :, :].contiguous() + loss = torch.nn.functional.nll_loss( + logprobs.view(-1, self.tokenizer.n_clusters), + 
labels.view(-1).long(), + reduction="none", + ).reshape((B, 1, 1, -1)) + mask = mask.float().to(loss.device) + loss = sum_flat(loss * mask) + non_zero_elements = sum_flat(mask) + ce_loss_val = loss / non_zero_elements + return ce_loss_val.mean() + + def compute_accuracy( + self, logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor + ) -> float: + mask = mask.squeeze() + probs = torch.softmax(logits, dim=-1) + _, cls_pred_index = torch.max(probs, dim=-1) + acc = (cls_pred_index.flatten(0) == target.flatten(0)).reshape( + cls_pred_index.shape + ) + acc = sum_flat(acc).detach().cpu() + non_zero_elements = sum_flat(mask) + acc_val = acc / non_zero_elements * 100 + return acc_val.mean() + + def update_lr_warm_up(self, nb_iter: int) -> float: + current_lr = self.lr * (nb_iter + 1) / (self.warm_up_iter + 1) + for param_group in self.optimizer.param_groups: + param_group["lr"] = current_lr + return current_lr + + def train_step(self, motion: torch.Tensor, cond: torch.Tensor) -> Dict[str, Any]: + self.model.train() + motion = torch.as_tensor(motion).cuda() + input_tokens, target_tokens, new_mask, downsampled_gt = self._prepare_tokens( + motion, cond["mask"] + ) + audio = cond["audio"].cuda() + new_mask = torch.as_tensor(new_mask) + + logits, loss = self._run_single_train_step(input_tokens, audio, target_tokens) + with torch.no_grad(): + tokens = torch.argmax(logits, dim=-1).view( + input_tokens.shape[0], -1, self.tokenizer.residual_depth + ) + pred = self.tokenizer.decode(tokens).detach().cpu() + l2_loss = self._masked_l2( + downsampled_gt.permute(0, 2, 1).unsqueeze(2).detach().cpu(), + pred.permute(0, 2, 1).unsqueeze(2), + new_mask, + ) + acc = self.compute_accuracy(logits, target_tokens, new_mask) + + return { + "pred": pred, + "gt": downsampled_gt, + "loss": loss, + "metrics": { + "ce_loss": loss.item(), + "l2_loss": l2_loss.item(), + "perplexity": np.exp(loss.item()), + "acc": acc.item(), + }, + } + + def validate( + self, + val_data: DataLoader, + writer: 
SummaryWriter, + step: int, + save_dir: str, + log_step: int = 100, + max_samples: int = 30, + ) -> None: + val_metrics = {} + pred_values = [] + gt_values = [] + for i, (val_motion, val_cond) in enumerate(val_data): + val_out = self._run_single_val_step(val_motion, val_cond["y"]) + if "metrics" in val_out.keys(): + for k, v in val_out["metrics"].items(): + val_metrics[k] = val_metrics.get(k, 0.0) + v + if "pred" in val_out.keys() and i % log_step == 0: + pred_values.append( + val_data.dataset.inv_transform(val_out["pred"], self.data_format) + ) + gt_values.append( + val_data.dataset.inv_transform(val_out["gt"], self.data_format) + ) + if i % log_step == 0: + print( + f'val_l2_loss at {step} [{i}]: {val_metrics["l2_loss"] / len(val_data):.4f}' + ) + pred_values = torch.concatenate((pred_values), dim=0) + gt_values = torch.concatenate((gt_values), dim=0) + idx = np.random.permutation(len(pred_values))[:max_samples] + pred_values = pred_values[idx] + gt_values = gt_values[idx] + for i, (pred, gt) in enumerate(zip(pred_values, gt_values)): + pred = pred.unsqueeze(0).detach().cpu().numpy() + pose = gt.unsqueeze(0).detach().cpu().numpy() + np.save(os.path.join(save_dir, f"b{i:04d}_pred.npy"), pred) + np.save(os.path.join(save_dir, f"b{i:04d}_gt.npy"), pose) + + msg = "" + for k, v in val_metrics.items(): + writer.add_scalar(f"val_{k}", v / len(val_data), step) + msg += f"val_{k} at {step}: {v / len(val_data):.4f} | " + print(msg) + + +def _save_checkpoint( + args, iteration: int, model: GuideTransformer, optimizer: optim.Optimizer +) -> None: + os.makedirs(f"{args.out_dir}/checkpoints/", exist_ok=True) + filename = f"iter-{iteration:07d}.pt" + torch.save( + { + "iteration": iteration, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + }, + f"{args.out_dir}/checkpoints/{filename}", + ) + + +def _load_data_info(args) -> (DataLoader, DataLoader): + data_dict = load_local_data(args.data_root, audio_per_frame=1600) + train_data = 
get_dataset_loader( + args=args, data_dict=data_dict, split="train", add_padding=False + ) + val_data = get_dataset_loader(args=args, data_dict=data_dict, split="val") + return train_data, val_data + + +def main(args): + fixseed(args.seed) + os.makedirs(args.out_dir, exist_ok=True) + writer = SummaryWriter(f"{args.out_dir}/logs/") + args_path = os.path.join(args.out_dir, "args.json") + with open(args_path, "w") as fw: + json.dump(vars(args), fw, indent=4, sort_keys=True) + tokenizer = setup_tokenizer(args.resume_pth) + + model = GuideTransformer( + tokens=tokenizer.n_clusters, + emb_len=798 if args.max_seq_length == 240 else 1998, + num_layers=args.layers, + dim=args.dim, + ) + train_data, val_data = _load_data_info(args) + trainer = ModelTrainer(args, model, tokenizer) + step = trainer.start_step + + for _ in range(1, args.total_iter + 1): + train_metrics = {} + count = 0 + for motion, cond in tqdm(train_data): + if step < args.warm_up_iter: + current_lr = trainer.update_lr_warm_up(step) + + # rum single train step + train_out = trainer.train_step(motion, cond["y"]) + if "metrics" in train_out.keys(): + for k, v in train_out["metrics"].items(): + train_metrics[k] = train_metrics.get(k, 0.0) + v + count += 1 + + # log all of the metrics + if step % args.log_interval == 0: + msg = "" + for k, v in train_metrics.items(): + writer.add_scalar(f"train_{k}", v / count, step) + msg += f"train_{k} at {step}: {v / count:.4f} | " + train_metrics = {} + count = 0 + writer.add_scalar(f"train_lr", trainer.scheduler.get_lr()[0], step) + if step < args.warm_up_iter: + msg += f"lr: {current_lr} | " + print(msg) + writer.flush() + + # run single evaluation step and save + if step % args.eval_interval == 0: + trainer.validate(val_data, writer, step, args.out_dir) + if step % args.save_interval == 0: + _save_checkpoint(args, step, trainer.model, trainer.optimizer) + step += 1 + + +if __name__ == "__main__": + args = train_args() + main(args) diff --git a/train/train_platforms.py 
b/train/train_platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..83200ee6956a9680f557200f38fc559481c4527a --- /dev/null +++ b/train/train_platforms.py @@ -0,0 +1,59 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os + +class TrainPlatform: + def __init__(self, save_dir): + pass + + def report_scalar(self, name, value, iteration, group_name=None): + pass + + def report_args(self, args, name): + pass + + def close(self): + pass + + +class ClearmlPlatform(TrainPlatform): + def __init__(self, save_dir): + from clearml import Task + path, name = os.path.split(save_dir) + self.task = Task.init(project_name='motion_diffusion', + task_name=name, + output_uri=path) + self.logger = self.task.get_logger() + + def report_scalar(self, name, value, iteration, group_name): + self.logger.report_scalar(title=group_name, series=name, iteration=iteration, value=value) + + def report_args(self, args, name): + self.task.connect(args, name=name) + + def close(self): + self.task.close() + + +class TensorboardPlatform(TrainPlatform): + def __init__(self, save_dir): + from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(log_dir=save_dir) + + def report_scalar(self, name, value, iteration, group_name=None): + self.writer.add_scalar(f'{group_name}/{name}', value, iteration) + + def close(self): + self.writer.close() + + +class NoPlatform(TrainPlatform): + def __init__(self, save_dir): + pass + + diff --git a/train/train_vq.py b/train/train_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..41829ff5785ca667a6ff4b90038c5438df0a8e24 --- /dev/null +++ b/train/train_vq.py @@ -0,0 +1,374 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. 
+This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import copy +import json +import logging +import os +import sys +import warnings +from typing import Any, Dict + +import model.vqvae as vqvae + +import numpy as np +import torch +import torch.optim as optim +from data_loaders.get_data import get_dataset_loader, load_local_data +from diffusion.nn import sum_flat +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from utils.vq_parser_utils import train_args + +warnings.filterwarnings("ignore") + + +def cycle(iterable): + while True: + for x in iterable: + yield x + + +def get_logger(out_dir: str): + logger = logging.getLogger("Exp") + logger.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") + + file_path = os.path.join(out_dir, "run.log") + file_hdlr = logging.FileHandler(file_path) + file_hdlr.setFormatter(formatter) + + strm_hdlr = logging.StreamHandler(sys.stdout) + strm_hdlr.setFormatter(formatter) + + logger.addHandler(file_hdlr) + logger.addHandler(strm_hdlr) + return logger + + +class ModelTrainer: + def __init__(self, args, net: vqvae.TemporalVertexCodec, logger, writer): + self.net = net + self.warm_up_iter = args.warm_up_iter + self.lr = args.lr + self.optimizer = optim.AdamW( + self.net.parameters(), + lr=args.lr, + betas=(0.9, 0.99), + weight_decay=args.weight_decay, + ) + self.scheduler = torch.optim.lr_scheduler.MultiStepLR( + self.optimizer, milestones=args.lr_scheduler, gamma=args.gamma + ) + self.data_format = args.data_format + self.loss = torch.nn.SmoothL1Loss() + self.loss_vel = args.loss_vel + self.commit = args.commit + self.logger = logger + self.writer = writer + self.best_commit = float("inf") + self.best_recons = float("inf") + self.best_perplexity = float("inf") + self.best_iter = 0 + self.out_dir = args.out_dir + + def _masked_l2(self, a, b, mask): + loss = self._l2_loss(a, b) + 
loss = sum_flat(loss * mask.float()) + n_entries = a.shape[1] * a.shape[2] + non_zero_elements = sum_flat(mask) * n_entries + mse_loss_val = loss / non_zero_elements + return mse_loss_val + + def _l2_loss(self, motion_pred, motion_gt, mask=None): + if mask is not None: + return self._masked_l2(motion_pred, motion_gt, mask) + else: + return self.loss(motion_pred, motion_gt) + + def _vel_loss(self, motion_pred, motion_gt): + model_results_vel = motion_pred[..., :-1] - motion_pred[..., 1:] + model_targets_vel = motion_gt[..., :-1] - motion_gt[..., 1:] + return self.loss(model_results_vel, model_targets_vel) + + def _update_lr_warm_up(self, nb_iter): + current_lr = self.lr * (nb_iter + 1) / (self.warm_up_iter + 1) + for param_group in self.optimizer.param_groups: + param_group["lr"] = current_lr + return current_lr + + def run_warmup_steps(self, train_loader_iter, skip_step, logger): + avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0 + for nb_iter in tqdm(range(1, args.warm_up_iter)): + current_lr = self._update_lr_warm_up(nb_iter) + gt_motion, cond = next(train_loader_iter) + loss_dict = self.run_train_step(gt_motion, cond, skip_step) + + avg_recons += loss_dict["loss_motion"] + avg_perplexity += loss_dict["perplexity"] + avg_commit += loss_dict["loss_commit"] + + if nb_iter % args.print_iter == 0: + avg_recons /= args.print_iter + avg_perplexity /= args.print_iter + avg_commit /= args.print_iter + + logger.info( + f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. 
{avg_recons:.5f}" + ) + + avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0 + + def run_train_step( + self, gt_motion: torch.Tensor, cond: torch.Tensor, skip_step: int + ) -> Dict[str, Any]: + self.net.train() + loss_dict = {} + # run model + gt_motion = gt_motion.permute(0, 3, 1, 2).squeeze(-1).cuda().float() + cond["y"] = { + key: val.to(gt_motion.device) if torch.is_tensor(val) else val + for key, val in cond["y"].items() + } + gt_motion = gt_motion[:, ::skip_step, :] + pred_motion, loss_commit, perplexity = self.net(gt_motion, mask=None) + loss_motion = self._l2_loss(pred_motion, gt_motion).mean() + loss_vel = 0.0 + if self.loss_vel > 0: + loss_vel = self._vel_loss(pred_motion, gt_motion) + loss = loss_motion + self.commit * loss_commit + self.loss_vel * loss_vel + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + # record losses + if self.loss_vel > 0: + loss_dict["vel"] = loss_vel.item() + loss_dict["loss"] = loss.item() + loss_dict["loss_motion"] = loss_motion.item() + loss_dict["loss_commit"] = loss_commit.item() + loss_dict["perplexity"] = perplexity.item() + return loss_dict + + def save_model(self, save_path): + torch.save( + { + "net": self.net.state_dict(), + "optimizer": self.optimizer.state_dict(), + "scheduler": self.scheduler, + }, + save_path, + ) + + def _save_predictions(self, name, unstd_pose, unstd_pred): + curr_name = os.path.basename(name) + path = os.path.join(self.out_dir, curr_name) + for j in range(len(path.split("/")) - 1): + if not os.path.exists("/".join(path.split("/")[: j + 1])): + os.system("mkdir " + "/".join(path.split("/")[: j + 1])) + np.save(os.path.join(self.out_dir, curr_name + "_gt.npy"), unstd_pose) + np.save(os.path.join(self.out_dir, curr_name + "_pred.npy"), unstd_pred) + + def _log_losses( + self, + commit_loss: float, + recons_loss: float, + total_perplexity: float, + nb_iter: int, + nb_sample: int, + draw: bool, + save: bool, + ) -> None: + avg_commit = commit_loss / nb_sample + avg_recons = 
recons_loss / nb_sample + avg_perplexity = total_perplexity / nb_sample + self.logger.info( + f"Eval. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}" + ) + + if draw: + self.writer.add_scalar("./Val/Perplexity", avg_perplexity, nb_iter) + self.writer.add_scalar("./Val/Commit", avg_commit, nb_iter) + self.writer.add_scalar("./Val/Recons", avg_recons, nb_iter) + + if avg_perplexity < self.best_perplexity: + msg = f"--> --> \t Perplexity Improved from {self.best_perplexity:.5f} to {avg_perplexity:.5f} !!!" + self.logger.info(msg) + self.best_perplexity = avg_perplexity + if save: + print(f"saving checkpoint net_best.pth") + self.save_model(os.path.join(self.out_dir, "net_best.pth")) + + if avg_commit < self.best_commit: + msg = f"--> --> \t Commit Improved from {self.best_commit:.5f} to {avg_commit:.5f} !!!" + self.logger.info(msg) + self.best_commit = avg_commit + + if avg_recons < self.best_recons: + msg = f"--> --> \t Recons Improved from {self.best_recons:.5f} to {avg_recons:.5f} !!!" 
+ self.logger.info(msg) + self.best_recons = avg_recons + + @torch.no_grad() + def evaluation_vqvae( + self, + val_loader, + nb_iter: int, + draw: bool = True, + save: bool = True, + savenpy: bool = False, + ) -> None: + self.net.eval() + nb_sample = 0 + commit_loss = 0 + recons_loss = 0 + total_perplexity = 0 + for _, batch in enumerate(val_loader): + motion, cond = batch + m_length = cond["y"]["lengths"] + motion = motion.permute(0, 3, 1, 2).squeeze(-1).cuda().float() + cond["y"] = { + key: val.to(motion.device) if torch.is_tensor(val) else val + for key, val in cond["y"].items() + } + motion = motion[:, :: val_loader.dataset.step, :].cuda().float() + bs, seq = motion.shape[0], motion.shape[1] + pred_pose_eval = torch.zeros((bs, seq, motion.shape[-1])).cuda() + for i in range(bs): + curr_gt = motion[i : i + 1, : m_length[i]] + pred, loss_commit, perplexity = self.net(curr_gt) + l2_loss = self._l2_loss(pred, curr_gt) + recons_loss += l2_loss.mean().item() + commit_loss += loss_commit + total_perplexity += perplexity + unstd_pred = val_loader.dataset.inv_transform( + pred.detach().cpu().numpy(), self.data_format + ) + unstd_pose = val_loader.dataset.inv_transform( + curr_gt.detach().cpu().numpy(), self.data_format + ) + if savenpy: + self._save_predictions( + "b{i:04d}", unstd_pose[:, : m_length[i]], unstd_pred + ) + pred_pose_eval[i : i + 1, : m_length[i], :] = pred + nb_sample += bs + + self._log_losses( + commit_loss, recons_loss, total_perplexity, nb_iter, nb_sample, draw, save + ) + if save: + print(f"saving checkpoint net_last.pth") + self.save_model(os.path.join(self.out_dir, "net_last.pth")) + if nb_iter % 100000 == 0: + print(f"saving checkpoint net_iter_x.pth") + self.save_model( + os.path.join(self.out_dir, "net_iter" + str(nb_iter) + ".pth") + ) + + +def _load_data_info(args, logger): + data_dict = load_local_data(args.data_root, audio_per_frame=1600) + train_loader = get_dataset_loader( + args=args, data_dict=data_dict, split="train", add_padding=False 
+ ) + val_loader = get_dataset_loader( + args=args, data_dict=data_dict, split="val", add_padding=False + ) + + logger.info( + f"Training on {args.dataname}, motions are with {args.nb_joints} joints" + ) + train_loader_iter = cycle(train_loader) + skip_step = train_loader.dataset.step + return train_loader_iter, val_loader, skip_step + + +def _load_checkpoint(args, net, logger): + cp_dir = os.path.dirname(args.resume_pth) + with open(f"{cp_dir}/args.json") as f: + trans_args = json.load(f) + assert trans_args["data_root"] == args.data_root, "data_root doesnt match" + logger.info("loading checkpoint from {}".format(args.resume_pth)) + ckpt = torch.load(args.resume_pth, map_location="cpu") + net.load_state_dict(ckpt["net"], strict=True) + return net + + +def main(args): + torch.manual_seed(args.seed) + os.makedirs(args.out_dir, exist_ok=True) + logger = get_logger(args.out_dir) + writer = SummaryWriter(args.out_dir) + logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) + + if args.data_format == "pose": + args.nb_joints = 104 + elif args.data_format == "face": + args.nb_joints = 256 + + args_path = os.path.join(args.out_dir, "args.json") + with open(args_path, "w") as fw: + json.dump(vars(args), fw, indent=4, sort_keys=True) + + if not os.path.exists(args.data_root): + args.data_root = args.data_root.replace("/home/", "/derived/") + + train_loader_iter, val_loader, skip_step = _load_data_info(args, logger) + net = vqvae.TemporalVertexCodec( + n_vertices=args.nb_joints, + latent_dim=args.output_emb_width, + categories=args.code_dim, + residual_depth=args.depth, + ) + if args.resume_pth: + net = _load_checkpoint(args, net, logger) + net.train() + net.cuda() + + trainer = ModelTrainer(args, net, logger, writer) + + trainer.run_warmup_steps(train_loader_iter, skip_step, logger) + avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0 + with torch.no_grad(): + trainer.evaluation_vqvae( + val_loader, 0, save=(args.total_iter > 0), savenpy=True + ) + + for nb_iter 
in range(1, args.total_iter + 1): + gt_motion, cond = next(train_loader_iter) + loss_dict = trainer.run_train_step(gt_motion, cond, skip_step) + trainer.scheduler.step() + + avg_recons += loss_dict["loss_motion"] + avg_perplexity += loss_dict["perplexity"] + avg_commit += loss_dict["loss_commit"] + + if nb_iter % args.print_iter == 0: + avg_recons /= args.print_iter + avg_perplexity /= args.print_iter + avg_commit /= args.print_iter + + writer.add_scalar("./Train/L1", avg_recons, nb_iter) + writer.add_scalar("./Train/PPL", avg_perplexity, nb_iter) + writer.add_scalar("./Train/Commit", avg_commit, nb_iter) + + logger.info( + f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}" + ) + + avg_recons, avg_perplexity, avg_commit = (0.0, 0.0, 0.0) + + if nb_iter % args.eval_iter == 0: + trainer.evaluation_vqvae( + val_loader, nb_iter, save=(args.total_iter > 0), savenpy=True + ) + + +if __name__ == "__main__": + args = train_args() + main(args) diff --git a/train/training_loop.py b/train/training_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..29d9c69db9e0c7e26345f28c48b6b2573f8c7ebf --- /dev/null +++ b/train/training_loop.py @@ -0,0 +1,288 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import cProfile as profile +import functools +import pstats + +import blobfile as bf +import numpy as np +import torch +from torch.optim import AdamW +from tqdm import tqdm + +import utils.logger as logger +from diffusion.fp16_util import MixedPrecisionTrainer +from diffusion.resample import LossAwareSampler, create_named_schedule_sampler +from utils.misc import dev, load_state_dict + +INITIAL_LOG_LOSS_SCALE = 20.0 + + +class TrainLoop: + def __init__( + self, args, train_platform, model, diffusion, data, writer, rank=0, world_size=1 + ): + self.args = args + self.dataset = args.dataset + self.train_platform = train_platform + self.model = model + self.diffusion = diffusion + self.cond_mode = model.module.cond_mode if world_size > 1 else model.cond_mode + self.data = data + self.batch_size = args.batch_size + self.microbatch = args.batch_size # deprecating this option + self.lr = args.lr + self.log_interval = args.log_interval + self.save_interval = args.save_interval + self.resume_checkpoint = args.resume_checkpoint + self.use_fp16 = False # deprecating this option + self.fp16_scale_growth = 1e-3 # deprecating this option + self.weight_decay = args.weight_decay + self.lr_anneal_steps = args.lr_anneal_steps + self.rank = rank + self.world_size = world_size + + self.step = 0 + self.resume_step = 0 + self.global_batch = self.batch_size + self.num_steps = args.num_steps + self.num_epochs = self.num_steps // len(self.data) + 1 + chunks = list(range(self.num_steps)) + num_chunks = int(self.num_steps / 10) + chunks = np.array_split(chunks, num_chunks) + self.chunks = np.reshape(chunks[10_000::10], (-1)) + self.sync_cuda = torch.cuda.is_available() + self.writer = writer + + self._load_and_sync_parameters() + self.mp_trainer = MixedPrecisionTrainer( + model=self.model, + use_fp16=self.use_fp16, + fp16_scale_growth=self.fp16_scale_growth, + ) + + self.save_dir = args.save_dir + self.overwrite = args.overwrite + + self.opt = AdamW( + self.mp_trainer.master_params, 
lr=self.lr, weight_decay=self.weight_decay + ) + if self.resume_step: + self._load_optimizer_state() + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{self.rank}") + + self.schedule_sampler_type = "uniform" + self.schedule_sampler = create_named_schedule_sampler( + self.schedule_sampler_type, diffusion + ) + self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None + self.use_ddp = True + self.ddp_model = self.model + + def _load_and_sync_parameters(self): + resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + + if resume_checkpoint: + self.resume_step = parse_resume_step_from_filename(resume_checkpoint) + logger.log(f"loading model from checkpoint: {resume_checkpoint}...") + self.model.load_state_dict( + load_state_dict(resume_checkpoint, map_location=dev()) + ) + + def _load_optimizer_state(self): + main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + opt_checkpoint = bf.join( + bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt" + ) + if bf.exists(opt_checkpoint): + logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") + state_dict = load_state_dict(opt_checkpoint, map_location=dev()) + self.opt.load_state_dict(state_dict) + + def _print_stats(self, logger): + if (self.step % 100 == 0 and self.step > 0) and self.rank == 0: + v = logger.get_current().name2val + v = v["loss"] + print("step[{}]: loss[{:0.5f}]".format(self.step + self.resume_step, v)) + + def _write_to_logger(self, logger): + if (self.step % self.log_interval == 0) and self.rank == 0: + for k, v in logger.get_current().name2val.items(): + if k == "loss": + print( + "step[{}]: loss[{:0.5f}]".format( + self.step + self.resume_step, v + ) + ) + self.writer.add_scalar(f"./Train/{k}", v, self.step) + if k in ["step", "samples"] or "_q" in k: + continue + else: + self.train_platform.report_scalar( + name=k, value=v, iteration=self.step, group_name="Loss" + ) + self.writer.add_scalar(f"./Train/{k}", v, 
self.step) + + def run_loop(self): + for _ in range(self.num_epochs): + if self.rank == 0: + prof = profile.Profile() + prof.enable() + + for motion, cond in tqdm(self.data, disable=(self.rank != 0)): + if not ( + not self.lr_anneal_steps + or self.step + self.resume_step < self.lr_anneal_steps + ): + break + + motion = motion.to(self.device) + cond["y"] = { + key: val.to(self.device) if torch.is_tensor(val) else val + for key, val in cond["y"].items() + } + self.run_step(motion, cond) + self._print_stats(logger) + self._write_to_logger(logger) + if (self.step % self.save_interval == 0) and self.rank == 0: + self.save() + + self.step += 1 + + if (self.step == 1000) and self.rank == 0: + prof.disable() + stats = pstats.Stats(prof).strip_dirs().sort_stats("cumtime") + stats.print_stats(10) + + if not ( + not self.lr_anneal_steps + or self.step + self.resume_step < self.lr_anneal_steps + ): + break + + # Save the last checkpoint if it wasn't already saved. + if ((self.step - 1) % self.save_interval != 0) and self.rank == 0: + self.save() + + def run_step(self, batch, cond): + self.forward_backward(batch, cond) + self.mp_trainer.optimize(self.opt) + self._anneal_lr() + if self.rank == 0: + self.log_step() + + def forward_backward(self, batch, cond): + self.mp_trainer.zero_grad() + for i in range(0, batch.shape[0], self.microbatch): + # Eliminates the microbatch feature + assert i == 0 + assert self.microbatch == self.batch_size + micro = batch + micro_cond = cond + last_batch = (i + self.microbatch) >= batch.shape[0] + t, weights = self.schedule_sampler.sample(micro.shape[0], batch.device) + + compute_losses = functools.partial( + self.diffusion.training_losses, + self.ddp_model, + micro, + t, + model_kwargs=micro_cond, + ) + + if last_batch or not self.use_ddp: + losses = compute_losses() + else: + with self.ddp_model.no_sync(): + losses = compute_losses() + + if isinstance(self.schedule_sampler, LossAwareSampler): + self.schedule_sampler.update_with_local_losses( + 
t, losses["loss"].detach() + ) + + loss = (losses["loss"] * weights).mean() + log_loss_dict( + self.diffusion, t, {k: v * weights for k, v in losses.items()} + ) + self.mp_trainer.backward(loss) + + def _anneal_lr(self): + if not self.lr_anneal_steps: + return + frac_done = (self.step + self.resume_step) / self.lr_anneal_steps + lr = self.lr * (1 - frac_done) + for param_group in self.opt.param_groups: + param_group["lr"] = lr + + def log_step(self): + logger.logkv("step", self.step + self.resume_step) + logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) + + def ckpt_file_name(self): + return f"model{(self.step+self.resume_step):09d}.pt" + + def save(self): + def save_checkpoint(params): + state_dict = self.mp_trainer.master_params_to_state_dict(params) + + # Do not save CLIP weights + clip_weights = [e for e in state_dict.keys() if e.startswith("clip_model.")] + for e in clip_weights: + del state_dict[e] + + logger.log(f"saving model...") + filename = self.ckpt_file_name() + with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f: + torch.save(state_dict, f) + + save_checkpoint(self.mp_trainer.master_params) + + with bf.BlobFile( + bf.join(self.save_dir, f"opt{(self.step+self.resume_step):09d}.pt"), + "wb", + ) as f: + torch.save(self.opt.state_dict(), f) + + +def parse_resume_step_from_filename(filename): + """ + Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the + checkpoint's number of steps. + """ + split = filename.split("model") + if len(split) < 2: + return 0 + split1 = split[-1].split(".")[0] + try: + return int(split1) + except ValueError: + return 0 + + +def get_blob_logdir(): + # You can change this to be a separate path to save checkpoints to + # a blobstore or some external drive. + return logger.get_dir() + + +def find_resume_checkpoint(): + # On your infrastructure, you may want to override this to automatically + # discover the latest checkpoint on your blob storage, etc. 
+ return None + + +def log_loss_dict(diffusion, ts, losses): + for key, values in losses.items(): + logger.logkv_mean(key, values.mean().item()) + # Log the quantiles (four quartiles, in particular). + for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + quartile = int(4 * sub_t / diffusion.num_timesteps) + logger.logkv_mean(f"{key}_q{quartile}", sub_loss) diff --git a/utils/diff_parser_utils.py b/utils/diff_parser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..10f0e819ff4649a17d6127a4695f9599710a7b27 --- /dev/null +++ b/utils/diff_parser_utils.py @@ -0,0 +1,307 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import json +import os +from argparse import ArgumentParser + + +def parse_and_load_from_model(parser): + # args according to the loaded model + # do not try to specify them from cmd line since they will be overwritten + add_data_options(parser) + add_model_options(parser) + add_diffusion_options(parser) + args = parser.parse_args() + args_to_overwrite = [] + for group_name in ["dataset", "model", "diffusion"]: + args_to_overwrite += get_args_per_group_name(parser, args, group_name) + args_to_overwrite += ["data_root"] + + # load args from model + model_path = get_model_path_from_args() + args_path = os.path.join(os.path.dirname(model_path), "args.json") + print(args_path) + assert os.path.exists(args_path), "Arguments json file was not found!" 
+ with open(args_path, "r") as fr: + model_args = json.load(fr) + + for a in args_to_overwrite: + if a in model_args.keys(): + if a == "timestep_respacing" or a == "partial": + continue + setattr(args, a, model_args[a]) + + elif "cond_mode" in model_args: # backward compitability + unconstrained = model_args["cond_mode"] == "no_cond" + setattr(args, "unconstrained", unconstrained) + + else: + print( + "Warning: was not able to load [{}], using default value [{}] instead.".format( + a, args.__dict__[a] + ) + ) + + if args.cond_mask_prob == 0: + args.guidance_param = 1 + return args + + +def get_args_per_group_name(parser, args, group_name): + for group in parser._action_groups: + if group.title == group_name: + group_dict = { + a.dest: getattr(args, a.dest, None) for a in group._group_actions + } + return list(argparse.Namespace(**group_dict).__dict__.keys()) + return ValueError("group_name was not found.") + + +def get_model_path_from_args(): + try: + dummy_parser = ArgumentParser() + dummy_parser.add_argument("model_path") + dummy_args, _ = dummy_parser.parse_known_args() + return dummy_args.model_path + except: + raise ValueError("model_path argument must be specified.") + + +def add_base_options(parser): + group = parser.add_argument_group("base") + group.add_argument( + "--cuda", default=True, type=bool, help="Use cuda device, otherwise use CPU." + ) + group.add_argument("--device", default=0, type=int, help="Device id to use.") + group.add_argument("--seed", default=10, type=int, help="For fixing random seed.") + group.add_argument( + "--batch_size", default=64, type=int, help="Batch size during training." 
+ ) + + +def add_diffusion_options(parser): + group = parser.add_argument_group("diffusion") + group.add_argument( + "--noise_schedule", + default="cosine", + choices=["linear", "cosine"], + type=str, + help="Noise schedule type", + ) + group.add_argument( + "--diffusion_steps", + default=10, + type=int, + help="Number of diffusion steps (denoted T in the paper)", + ) + group.add_argument( + "--timestep_respacing", + default="ddim100", + type=str, + help="ddimN, else empty string", + ) + group.add_argument( + "--sigma_small", default=True, type=bool, help="Use smaller sigma values." + ) + + +def add_model_options(parser): + group = parser.add_argument_group("model") + group.add_argument("--layers", default=8, type=int, help="Number of layers.") + group.add_argument( + "--num_audio_layers", default=3, type=int, help="Number of audio layers." + ) + group.add_argument("--heads", default=4, type=int, help="Number of heads.") + group.add_argument( + "--latent_dim", default=512, type=int, help="Transformer/GRU width." + ) + group.add_argument( + "--cond_mask_prob", + default=0.20, + type=float, + help="The probability of masking the condition during training." + " For classifier-free guidance learning.", + ) + group.add_argument( + "--lambda_vel", default=0.0, type=float, help="Joint velocity loss." + ) + group.add_argument( + "--unconstrained", + action="store_true", + help="Model is trained unconditionally. That is, it is constrained by neither text nor action. 
" + "Currently tested on HumanAct12 only.", + ) + group.add_argument( + "--data_format", + type=str, + choices=["pose", "face"], + default="pose", + help="whether or not to use vae for diffusion process", + ) + group.add_argument("--not_rotary", action="store_true") + group.add_argument("--simplify_audio", action="store_true") + group.add_argument("--add_frame_cond", type=float, choices=[1], default=None) + + +def add_data_options(parser): + group = parser.add_argument_group("dataset") + group.add_argument( + "--dataset", + default="social", + choices=["social"], + type=str, + help="Dataset name (choose from list).", + ) + group.add_argument("--data_root", type=str, default=None, help="dataset directory") + group.add_argument("--max_seq_length", default=600, type=int) + group.add_argument( + "--split", type=str, default=None, choices=["test", "train", "val"] + ) + + +def add_training_options(parser): + group = parser.add_argument_group("training") + group.add_argument( + "--save_dir", + required=True, + type=str, + help="Path to save checkpoints and results.", + ) + group.add_argument( + "--overwrite", + action="store_true", + help="If True, will enable to use an already existing save_dir.", + ) + group.add_argument( + "--train_platform_type", + default="NoPlatform", + choices=["NoPlatform", "ClearmlPlatform", "TensorboardPlatform"], + type=str, + help="Choose platform to log results. NoPlatform means no logging.", + ) + group.add_argument("--lr", default=1e-4, type=float, help="Learning rate.") + group.add_argument( + "--weight_decay", default=0.0, type=float, help="Optimizer weight decay." 
+ ) + group.add_argument( + "--lr_anneal_steps", + default=0, + type=int, + help="Number of learning rate anneal steps.", + ) + group.add_argument( + "--log_interval", default=1_000, type=int, help="Log losses each N steps" + ) + group.add_argument( + "--save_interval", + default=5_000, + type=int, + help="Save checkpoints and run evaluation each N steps", + ) + group.add_argument( + "--num_steps", + default=800_000, + type=int, + help="Training will stop after the specified number of steps.", + ) + group.add_argument( + "--resume_checkpoint", + default="", + type=str, + help="If not empty, will start from the specified checkpoint (path to model###.pt file).", + ) + + +def add_sampling_options(parser): + group = parser.add_argument_group("sampling") + group.add_argument( + "--model_path", + required=True, + type=str, + help="Path to model####.pt file to be sampled.", + ) + group.add_argument( + "--output_dir", + default="", + type=str, + help="Path to results dir (auto created by the script). 
" + "If empty, will create dir in parallel to checkpoint.", + ) + group.add_argument("--face_codes", default=None, type=str) + group.add_argument("--pose_codes", default=None, type=str) + group.add_argument( + "--num_samples", + default=10, + type=int, + help="Maximal number of prompts to sample, " + "if loading dataset from file, this field will be ignored.", + ) + group.add_argument( + "--num_repetitions", + default=3, + type=int, + help="Number of repetitions, per sample (text prompt/action)", + ) + group.add_argument( + "--guidance_param", + default=2.5, + type=float, + help="For classifier-free sampling - specifies the s parameter, as defined in the paper.", + ) + group.add_argument( + "--curr_seq_length", + default=None, + type=int, + ) + group.add_argument( + "--render_gt", + action="store_true", + help="whether to use pretrained clipmodel for audio encoding", + ) + + +def add_generate_options(parser): + group = parser.add_argument_group("generate") + group.add_argument( + "--plot", + action="store_true", + help="Whether or not to save the renderings as a video.", + ) + group.add_argument( + "--resume_trans", + default=None, + type=str, + help="keyframe prediction network.", + ) + group.add_argument("--flip_person", action="store_true") + + +def get_cond_mode(args): + if args.dataset == "social": + cond_mode = "audio" + return cond_mode + + +def train_args(): + parser = ArgumentParser() + add_base_options(parser) + add_data_options(parser) + add_model_options(parser) + add_diffusion_options(parser) + add_training_options(parser) + return parser.parse_args() + + +def generate_args(): + parser = ArgumentParser() + add_base_options(parser) + add_sampling_options(parser) + add_generate_options(parser) + args = parse_and_load_from_model(parser) + return args diff --git a/utils/eval.py b/utils/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc3360baa465e48c400158adb154fa9c649043f --- /dev/null +++ b/utils/eval.py @@ -0,0 +1,115 @@ 
+""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse + +import numpy as np +from scipy import linalg + + +def calculate_diversity(activation: np.ndarray, diversity_times: int = 10_000) -> float: + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + first_indices = np.random.choice(num_samples, diversity_times, replace=False) + second_indices = np.random.choice(num_samples, diversity_times, replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) + return dist + + +def calculate_activation_statistics( + activations: np.ndarray, +) -> (np.ndarray, np.ndarray): + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +def calculate_frechet_distance( + mu1: np.ndarray, + sigma1: np.ndarray, + mu2: np.ndarray, + sigma2: np.ndarray, + eps: float = 1e-6, +) -> float: + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert ( + sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates" + ) % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + 
raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def main(args): + num_samples = 5 + results = np.load(args.results, allow_pickle=True).item() + pred_reshaped = results["motion"].squeeze().reshape((num_samples, -1, 104, 600)) + gt_reshaped = results["gt"].squeeze().reshape((num_samples, -1, 104, 600)) + + # calulate variance across the different samples generated + cross_sample_var = np.var(pred_reshaped.reshape((num_samples, -1)), axis=0) + print("cross var", cross_sample_var.mean()) + + pred_pose_last = pred_reshaped.transpose((0, 1, 3, 2)).reshape(-1, 104) + gt_pose_last = gt_reshaped.transpose((0, 1, 3, 2)).reshape(-1, 104) + # calculate the static and kinematic diversity + var_g = calculate_diversity(pred_pose_last) + print("var_g", var_g.mean()) + var_k = np.var(pred_reshaped, axis=-1) + print("var_k", var_k.mean()) + + # calculate the static and kinematic fid + pred_mu_g, pred_cov_g = calculate_activation_statistics(pred_pose_last) + gt_mu_g, gt_cov_g = calculate_activation_statistics(gt_pose_last) + fid_g = calculate_frechet_distance(gt_mu_g, gt_cov_g, pred_mu_g, pred_cov_g) + print("fid_g", fid_g) + # reshape for kinematic fid + pred_motion = pred_reshaped[..., 1:] - pred_reshaped[..., :-1] + gt_motion = gt_reshaped[..., 1:] - gt_reshaped[..., :-1] + pred_motion_last = pred_motion.transpose((0, 1, 3, 2)).reshape(-1, 104) + gt_motion_last = gt_motion.transpose((0, 1, 3, 2)).reshape(-1, 104) + pred_mu_k, pred_cov_k = calculate_activation_statistics(pred_motion_last) + gt_mu_k, gt_cov_k = calculate_activation_statistics(gt_motion_last) + fid_k = calculate_frechet_distance(gt_mu_k, gt_cov_k, pred_mu_k, pred_cov_k) + print("fid_k", fid_k) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--results", type=str, required=True) + args = parser.parse_args() + main(args) diff 
--git a/utils/guide_parser_utils.py b/utils/guide_parser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a6279a9d61af0739eb11686e77cd0bd7f2671a --- /dev/null +++ b/utils/guide_parser_utils.py @@ -0,0 +1,56 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse + + +def _add_dataset_args(parser): + parser.add_argument("--dataset", default="social", type=str) + parser.add_argument( + "--data_format", type=str, default="pose", choices=["pose", "face"] + ) + parser.add_argument("--data_root", type=str, default=None, help="dataset directory") + parser.add_argument("--batch_size", default=16, type=int) + parser.add_argument("--add_frame_cond", type=int, default=None, choices=[1]) + parser.add_argument("--max_seq_length", default=600, type=int) + + +def _add_opt_args(parser): + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--warm_up_iter", type=int, default=5_000) + parser.add_argument("--weight_decay", type=float, default=0.1) + parser.add_argument( + "--lr-scheduler", + default=[50000, 400000], + nargs="+", + type=int, + help="learning rate schedule (iterations)", + ) + parser.add_argument("--gamma", default=0.1, type=float) + parser.add_argument("--gn", action="store_true", help="gradient clipping") + + +def _add_model_args(parser): + parser.add_argument("--layers", default=8, type=int) + parser.add_argument("--dim", default=8, type=int) + parser.add_argument("--resume_pth", type=str, required=True) + parser.add_argument("--resume_trans", type=str, default=None) + + +def train_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--seed", default=10, type=int, help="For fixing random seed.") + parser.add_argument("--out_dir", type=str, required=True) + parser.add_argument("--total_iter", default=1_000_000, type=int) + 
parser.add_argument("--log_interval", default=1_000, type=int) + parser.add_argument("--eval_interval", default=1_000, type=int) + parser.add_argument("--save_interval", default=5_000, type=int) + _add_model_args(parser) + _add_opt_args(parser) + _add_dataset_args(parser) + args = parser.parse_args() + return args diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a603e456c40a7aab437455b48792ecb784a35f --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,496 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import sys +import shutil +import os.path as osp +import json +import time +import datetime +import tempfile +import warnings +from collections import defaultdict +from contextlib import contextmanager + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 + +DISABLED = 50 + + +class KVWriter(object): + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter(object): + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "wt") + self.own_file = True + else: + assert hasattr(filename_or_file, "read"), ( + "expected file or str, got %s" % filename_or_file + ) + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for key, val in sorted(kvs.items()): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print("WARNING: tried to write empty key-value dict") + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, 
key2str.values())) + + # Write out the data + dashes = "-" * (keywidth + valwidth + 7) + lines = [dashes] + for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append( + "| %s%s | %s%s |" + % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) + ) + lines.append(dashes) + self.file.write("\n".join(lines) + "\n") + + # Flush the output to the file + self.file.flush() + + def _truncate(self, s): + maxlen = 30 + return s[: maxlen - 3] + "..." if len(s) > maxlen else s + + def writeseq(self, seq): + seq = list(seq) + for i, elem in enumerate(seq): + self.file.write(elem) + if i < len(seq) - 1: # add space unless this is the last one + self.file.write(" ") + self.file.write("\n") + self.file.flush() + + def close(self): + if self.own_file: + self.file.close() + + +class JSONOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "wt") + + def writekvs(self, kvs): + for k, v in sorted(kvs.items()): + if hasattr(v, "dtype"): + kvs[k] = float(v) + self.file.write(json.dumps(kvs) + "\n") + self.file.flush() + + def close(self): + self.file.close() + + +class CSVOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "w+t") + self.keys = [] + self.sep = "," + + def writekvs(self, kvs): + # Add our current row to the history + extra_keys = list(kvs.keys() - self.keys) + extra_keys.sort() + if extra_keys: + self.keys.extend(extra_keys) + self.file.seek(0) + lines = self.file.readlines() + self.file.seek(0) + for i, k in enumerate(self.keys): + if i > 0: + self.file.write(",") + self.file.write(k) + self.file.write("\n") + for line in lines[1:]: + self.file.write(line[:-1]) + self.file.write(self.sep * len(extra_keys)) + self.file.write("\n") + for i, k in enumerate(self.keys): + if i > 0: + self.file.write(",") + v = kvs.get(k) + if v is not None: + self.file.write(str(v)) + self.file.write("\n") + self.file.flush() + + def close(self): + self.file.close() + + +class 
TensorBoardOutputFormat(KVWriter): + """ + Dumps key/value pairs into TensorBoard's numeric format. + """ + + def __init__(self, dir): + os.makedirs(dir, exist_ok=True) + self.dir = dir + self.step = 1 + prefix = "events" + path = osp.join(osp.abspath(dir), prefix) + import tensorflow as tf + from tensorflow.python import pywrap_tensorflow + from tensorflow.core.util import event_pb2 + from tensorflow.python.util import compat + + self.tf = tf + self.event_pb2 = event_pb2 + self.pywrap_tensorflow = pywrap_tensorflow + self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) + + def writekvs(self, kvs): + def summary_val(k, v): + kwargs = {"tag": k, "simple_value": float(v)} + return self.tf.Summary.Value(**kwargs) + + summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) + event = self.event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = ( + self.step + ) # is there any reason why you'd want to specify the step? + self.writer.WriteEvent(event) + self.writer.Flush() + self.step += 1 + + def close(self): + if self.writer: + self.writer.Close() + self.writer = None + + +def make_output_format(format, ev_dir, log_suffix=""): + os.makedirs(ev_dir, exist_ok=True) + if format == "stdout": + return HumanOutputFormat(sys.stdout) + elif format == "log": + return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) + elif format == "json": + return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) + elif format == "csv": + return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) + elif format == "tensorboard": + return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) + else: + raise ValueError("Unknown format specified: %s" % (format,)) + + +# ================================================================ +# API +# ================================================================ + + +def logkv(key, val): + """ + Log a value of some diagnostic + Call this once for each 
diagnostic quantity, each iteration + If called many times, last value will be used. + """ + get_current().logkv(key, val) + + +def logkv_mean(key, val): + """ + The same as logkv(), but if called many times, values averaged. + """ + get_current().logkv_mean(key, val) + + +def logkvs(d): + """ + Log a dictionary of key-value pairs + """ + for k, v in d.items(): + logkv(k, v) + + +def dumpkvs(): + """ + Write all of the diagnostics from the current iteration + """ + return get_current().dumpkvs() + + +def getkvs(): + return get_current().name2val + + +def log(*args, level=INFO): + """ + Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). + """ + get_current().log(*args, level=level) + + +def debug(*args): + log(*args, level=DEBUG) + + +def info(*args): + log(*args, level=INFO) + + +def warn(*args): + log(*args, level=WARN) + + +def error(*args): + log(*args, level=ERROR) + + +def set_level(level): + """ + Set logging threshold on current logger. + """ + get_current().set_level(level) + + +def set_comm(comm): + get_current().set_comm(comm) + + +def get_dir(): + """ + Get directory that log files are being written to. 
+ will be None if there is no output directory (i.e., if you didn't call start) + """ + return get_current().get_dir() + + +record_tabular = logkv +dump_tabular = dumpkvs + + +@contextmanager +def profile_kv(scopename): + logkey = "wait_" + scopename + tstart = time.time() + try: + yield + finally: + get_current().name2val[logkey] += time.time() - tstart + + +def profile(n): + """ + Usage: + @profile("my_func") + def my_func(): code + """ + + def decorator_with_name(func): + def func_wrapper(*args, **kwargs): + with profile_kv(n): + return func(*args, **kwargs) + + return func_wrapper + + return decorator_with_name + + +# ================================================================ +# Backend +# ================================================================ + + +def get_current(): + if Logger.CURRENT is None: + _configure_default_logger() + + return Logger.CURRENT + + +class Logger(object): + DEFAULT = None # A logger with no output files. (See right below class definition) + # So that you can still log to the terminal without setting up any output files + CURRENT = None # Current logger being used by the free functions above + + def __init__(self, dir, output_formats, comm=None): + self.name2val = defaultdict(float) # values this iteration + self.name2cnt = defaultdict(int) + self.level = INFO + self.dir = dir + self.output_formats = output_formats + self.comm = comm + + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val + + def logkv_mean(self, key, val): + oldval, cnt = self.name2val[key], self.name2cnt[key] + self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) + self.name2cnt[key] = cnt + 1 + + def dumpkvs(self): + if self.comm is None: + d = self.name2val + else: + d = mpi_weighted_mean( + self.comm, + { + name: (val, self.name2cnt.get(name, 1)) + for (name, val) in self.name2val.items() + }, + ) + if self.comm.rank != 0: + d["dummy"] = 1 # so we don't get a warning 
def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/EXP/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean

    Only the process whose ``comm.rank`` is 0 computes the means; every
    other rank returns an empty dict. Entries whose value cannot be
    converted with ``float()`` are skipped with a warning instead of
    aborting the whole reduction.
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank != 0:
        return {}

    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for name, (val, count) in n2vc.items():
            try:
                val = float(val)
            except (TypeError, ValueError):
                # float() raises TypeError for objects such as None or lists
                # and ValueError for unparsable strings; both mean "not a
                # number" here.
                warnings.warn(
                    "WARNING: tried to compute mean on non-float {}={}".format(
                        name, val
                    )
                )
            else:
                name2sum[name] += val * count
                name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}
@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    """
    Temporarily swap in a freshly configured logger.

    On entry ``configure`` is called with the given arguments, which
    installs a new ``Logger.CURRENT``. On exit, whether the body finished
    normally or raised, the logger that is current at that moment is closed
    and the logger that was current before entry is put back.
    """
    logger_before = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        # Close first, then restore.
        Logger.CURRENT.close()
        Logger.CURRENT = logger_before
def setup_dist(device=0):
    """
    Record which device index later ``dev()`` calls should use.

    This function only stores ``device`` in the module-level
    ``used_device``; it does not create a process group.

    :param device: device index to remember. ``dev()`` falls back to the
                   CPU when this value is negative.
    """
    global used_device
    used_device = device
+ """ + for p in params: + with torch.no_grad(): + dist.broadcast(p, 0) + + +def _find_free_port(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + finally: + s.close() + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, " + "at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. 
class Timer:
    """Minimal stopwatch that prints the elapsed wall-clock time on stop()."""

    def __init__(self):
        # None while the timer is idle, otherwise the perf_counter() reading
        # taken by start().
        self._start_time = None

    def start(self):
        """Start a new timer.

        Raises TimerError if the timer is already running.
        """
        if self._start_time is not None:
            raise TimerError("Timer is running. Use .stop() to stop it")

        self._start_time = time.perf_counter()

    def stop(self, iter=None):
        """Stop the timer, and report the elapsed time.

        :param iter: optional number of iterations performed while the timer
                     was running. When given, a throughput figure is added
                     to the printed line: iterations per second if ``iter``
                     exceeds the elapsed seconds, seconds per iteration
                     otherwise. The figure is left out when it cannot be
                     computed (``iter`` is 0 or no time has elapsed).

        Raises TimerError if the timer is not running.
        """
        if self._start_time is None:
            raise TimerError("Timer is not running. Use .start() to start it")

        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None
        iter_msg = ""
        # Both rates divide by one of the two quantities, so skip the rate
        # when either is zero rather than raising ZeroDivisionError.
        if iter is not None and iter != 0 and elapsed_time != 0:
            if iter > elapsed_time:
                iter_per_sec = iter / elapsed_time
                iter_msg = f"[iter/s: {iter_per_sec:0.4f}]"
            else:
                sec_per_iter = elapsed_time / iter
                iter_msg = f"[s/iter: {sec_per_iter:0.4f}]"
        print(f"Elapsed time: {elapsed_time:0.4f} seconds {iter_msg}")
def get_model_args(args, split_type):
    """
    Assemble the keyword arguments used to construct the model.

    :param args: namespace-like object. Reads ``data_format``, ``layers``,
                 ``heads``, ``not_rotary``, ``unconstrained`` and ``device``.
                 If ``num_audio_layers`` is missing it is set to 3 on
                 ``args`` (backwards compatibility).
    :param split_type: passed through unchanged under the ``split_type`` key.
    :return: dict of keyword arguments.
    :raises ValueError: if ``args.data_format`` is neither ``"face"`` nor
                        ``"pose"``.
    """
    # data_format -> (nfeats, latent_dim)
    feature_dims = {"face": (256, 512), "pose": (104, 256)}
    if args.data_format not in feature_dims:
        # Fail with a clear message rather than an unbound-variable error
        # further down.
        raise ValueError(
            f"unsupported data_format {args.data_format!r}; "
            f"expected one of {sorted(feature_dims)}"
        )
    nfeat, lfeat = feature_dims[args.data_format]

    if not hasattr(args, "num_audio_layers"):
        args.num_audio_layers = 3  # backwards compat

    model_args = {
        "args": args,
        "nfeats": nfeat,
        "latent_dim": lfeat,
        "ff_size": 1024,
        "num_layers": args.layers,
        "num_heads": args.heads,
        "dropout": 0.1,
        "cond_feature_dim": 512 * 2,
        "activation": F.gelu,
        "use_rotary": not args.not_rotary,
        "cond_mode": "uncond" if args.unconstrained else "audio",
        "split_type": split_type,
        "num_audio_layers": args.num_audio_layers,
        "device": args.device,
    }
    return model_args
gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) + loss_type = gd.LossType.MSE + + if not timestep_respacing: + timestep_respacing = [steps] + + name = args.save_dir if hasattr(args, "save_dir") else args.model_path + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not args.sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + data_format=args.data_format, + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + lambda_vel=args.lambda_vel, + model_path=name, + ) diff --git a/utils/vq_parser_utils.py b/utils/vq_parser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..92f7ad7ba783c73a8dfa82b1358350a8b6a04403 --- /dev/null +++ b/utils/vq_parser_utils.py @@ -0,0 +1,93 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import argparse + + +def _add_dataset_args(parser): + parser.add_argument("--dataname", type=str, default="kit", help="dataset directory") + parser.add_argument("--data_root", type=str, default=None, help="dataset directory") + parser.add_argument("--max_seq_length", default=600, type=int) + parser.add_argument("--add_frame_cond", type=float, choices=[1], default=None) + parser.add_argument( + "--data_format", type=str, default="pose", choices=["pose", "face"] + ) + parser.add_argument("--dataset", default="social", type=str) + parser.add_argument("--batch_size", default=64, type=int, help="batch size") + + +def _add_optim_args(parser): + parser.add_argument( + "--total_iter", + default=300_000, + type=int, + help="number of total iterations to run", + ) + parser.add_argument( + "--warm_up_iter", + default=1000, + type=int, + help="number of total iterations for warmup", + ) + parser.add_argument("--lr", default=2e-4, type=float, help="max learning rate") + parser.add_argument( + "--lr_scheduler", + default=[300_000], + nargs="+", + type=int, + help="learning rate schedule (iterations)", + ) + parser.add_argument("--gamma", default=0.05, type=float, help="learning rate decay") + + parser.add_argument("--weight_decay", default=0.0, type=float, help="weight decay") + parser.add_argument( + "--commit", + type=float, + default=0.02, + help="hyper-parameter for the commitment loss", + ) + parser.add_argument( + "--loss_vel", + type=float, + default=0.1, + help="hyper-parameter for the velocity loss", + ) + + +def _add_model_args(parser): + parser.add_argument("--code_dim", type=int, default=512, help="embedding dimension") + parser.add_argument("--depth", type=int, default=3, help="depth of the network") + parser.add_argument( + "--output_emb_width", type=int, default=512, help="output embedding width" + ) + parser.add_argument( + "--resume_pth", type=str, default=None, help="resume pth for VQ" + ) + + +def train_args(): + parser = argparse.ArgumentParser( + 
description="Optimal Transport AutoEncoder training for AIST", + add_help=True, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + _add_dataset_args(parser) + _add_optim_args(parser) + _add_model_args(parser) + + ## output directory + parser.add_argument("--out_dir", type=str, required=True, help="output directory") + ## other + parser.add_argument("--print_iter", default=200, type=int, help="print frequency") + parser.add_argument( + "--eval_iter", default=1000, type=int, help="evaluation frequency" + ) + parser.add_argument( + "--seed", default=123, type=int, help="seed for initializing training." + ) + args = parser.parse_args() + return args diff --git a/visualize/.ipynb_checkpoints/render_codes-checkpoint.py b/visualize/.ipynb_checkpoints/render_codes-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..16f9cd1c6a4c31bfe4f3ed63379d5d450b3d02cd --- /dev/null +++ b/visualize/.ipynb_checkpoints/render_codes-checkpoint.py @@ -0,0 +1,163 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
def call_ffmpeg(command: str) -> None:
    """
    Echo ``command``, run it through the shell and fail if it exits non-zero.

    :param command: complete shell command line.
    :raises AssertionError: when the command's exit status is not 0; the
                            status is the exception's only argument.
    """
    print(command, "-" * 100)
    # NOTE(review): the command is executed with shell=True, so any path or
    # name interpolated into it is interpreted by the shell. Callers must
    # not pass untrusted text.
    e = subprocess.call(command, shell=True)
    if e != 0:
        # Raised explicitly instead of via an `assert` statement so the
        # check still runs when Python is started with -O. The exception
        # type is kept so existing callers are unaffected.
        raise AssertionError(e)
+ self.default_inputs = th.load(f"assets/render_defaults_{person}.pth") + + def _write_video_stream( + self, motion: np.ndarray, face: np.ndarray, save_name: str + ) -> None: + out = self._render_loop(motion, face) + mediapy.write_video(save_name, out, fps=30) + + def _render_loop(self, body_pose: np.ndarray, face: np.ndarray) -> List[np.ndarray]: + all_rgb = [] + default_inputs_copy = copy.deepcopy(self.default_inputs) + for b in tqdm(range(len(body_pose))): + B = default_inputs_copy["K"].shape[0] + default_inputs_copy["lbs_motion"] = ( + th.tensor(body_pose[b : b + 1, :], device=self.device, dtype=th.float) + .tile(B, 1) + .to(self.device) + ) + geom = ( + self.model.lbs_fn.lbs_fn( + default_inputs_copy["lbs_motion"], + self.model.lbs_fn.lbs_scale.unsqueeze(0).tile(B, 1), + self.model.lbs_fn.lbs_template_verts.unsqueeze(0).tile(B, 1, 1), + ) + * self.model.lbs_fn.global_scaling + ) + default_inputs_copy["geom"] = geom + face_codes = ( + th.from_numpy(face).float().cuda() if not th.is_tensor(face) else face + ) + curr_face = th.tile(face_codes[b : b + 1, ...], (2, 1)) + default_inputs_copy["face_embs"] = curr_face + preds = self.model(**default_inputs_copy) + rgb0 = linear2displayBatch(preds["rgb"])[0] + rgb1 = linear2displayBatch(preds["rgb"])[1] + rgb = th.cat((rgb0, rgb1), axis=-1).permute(1, 2, 0) + rgb = rgb.clip(0, 255).to(th.uint8) + all_rgb.append(rgb.contiguous().detach().byte().cpu().numpy()) + return all_rgb + + def render_full_video( + self, + data_block: Dict[str, np.ndarray], + animation_save_path: str, + audio_sr: int = None, + render_gt: bool = False, + ) -> None: + tag = os.path.basename(os.path.dirname(animation_save_path)) + save_name = os.path.splitext(os.path.basename(animation_save_path))[0] + save_name = f"{tag}_{save_name}" + torchaudio.save( + f"/tmp/audio_{save_name}.wav", + torch.tensor(data_block["audio"]), + audio_sr, + ) + if render_gt: + tag = "gt" + self._write_video_stream( + data_block["gt_body"], + data_block["gt_face"], + 
f"/tmp/{tag}_{save_name}.mp4", + ) + else: + tag = "pred" + self._write_video_stream( + data_block["body_motion"], + data_block["face_motion"], + f"/tmp/{tag}_{save_name}.mp4", + ) + command = f"{ffmpeg_header} -i /tmp/{tag}_{save_name}.mp4 -i /tmp/audio_{save_name}.wav -c:v copy -map 0:v:0 -map 1:a:0 -c:a aac -b:a 192k -pix_fmt yuva420p {animation_save_path}_{tag}.mp4" + call_ffmpeg(command) + subprocess.call( + f"rm /tmp/audio_{save_name}.wav && rm /tmp/{tag}_{save_name}.mp4", + shell=True, + ) diff --git a/visualize/ca_body/LICENSE b/visualize/ca_body/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..872bc82ca7881a1c072b24e4c33783c7fc288c1d --- /dev/null +++ b/visualize/ca_body/LICENSE @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. 
+ + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. 
For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. 
Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. 
Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. 
If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. 
if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. 
The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. 
To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. 
\ No newline at end of file diff --git a/visualize/ca_body/README.md b/visualize/ca_body/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9c5b110923561e5fb5679b11fed7c0bdd8ec927f --- /dev/null +++ b/visualize/ca_body/README.md @@ -0,0 +1,17 @@ +# ca_body + +Codec Avatar Body + +### Dependencies + +See `requirements.txt` + +### Repository structure + +- `ca_body/` - python source + * `models` - standalone models + * `nn` - reusable modules (layers, blocks, learnable, modules, networks) + * `utils` - reusable utils (functions, modules w/o learnable params) + +- `notebooks/` - example notebooks +- `data/` - location of sample data and checkpoints diff --git a/visualize/ca_body/models/mesh_vae_drivable.py b/visualize/ca_body/models/mesh_vae_drivable.py new file mode 100644 index 0000000000000000000000000000000000000000..e236b1ed1108a247a49543220467ca6fdbb2f4ad --- /dev/null +++ b/visualize/ca_body/models/mesh_vae_drivable.py @@ -0,0 +1,765 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import logging +from typing import Dict, Optional, Tuple + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from torchvision.utils import make_grid +from torchvision.transforms.functional import gaussian_blur + +import visualize.ca_body.nn.layers as la + +from visualize.ca_body.nn.blocks import ( + ConvBlock, + ConvDownBlock, + UpConvBlockDeep, + tile2d, + weights_initializer, +) +from visualize.ca_body.nn.dof_cal import LearnableBlur + +from visualize.ca_body.utils.geom import ( + GeometryModule, + compute_view_cos, + depth_discontuity_mask, + depth2normals, +) + +from visualize.ca_body.nn.shadow import ShadowUNet, PoseToShadow +from visualize.ca_body.nn.unet import UNetWB +from visualize.ca_body.nn.color_cal import CalV5 + +from visualize.ca_body.utils.image import linear2displayBatch +from visualize.ca_body.utils.lbs import LBSModule +from visualize.ca_body.utils.render import RenderLayer +from visualize.ca_body.utils.seams import SeamSampler +from visualize.ca_body.utils.render import RenderLayer + +from visualize.ca_body.nn.face import FaceDecoderFrontal + +logger = logging.getLogger(__name__) + + +class CameraPixelBias(nn.Module): + def __init__(self, image_height, image_width, cameras, ds_rate) -> None: + super().__init__() + self.image_height = image_height + self.image_width = image_width + self.cameras = cameras + self.n_cameras = len(cameras) + + bias = th.zeros( + (self.n_cameras, 1, image_width // ds_rate, image_height // ds_rate), dtype=th.float32 + ) + self.register_parameter("bias", nn.Parameter(bias)) + + def forward(self, idxs: th.Tensor): + bias_up = F.interpolate( + self.bias[idxs], size=(self.image_height, self.image_width), mode='bilinear' + ) + return bias_up + + +class AutoEncoder(nn.Module): + def __init__( + self, + encoder, + decoder, + decoder_view, + encoder_face, + # hqlp decoder to get the codes + decoder_face, + shadow_net, + upscale_net, + assets, + pose_to_shadow=None, + 
renderer=None, + cal=None, + pixel_cal=None, + learn_blur: bool = True, + ): + super().__init__() + # TODO: should we have a shared LBS here? + + self.geo_fn = GeometryModule( + assets.topology.vi, + assets.topology.vt, + assets.topology.vti, + assets.topology.v2uv, + uv_size=1024, + impaint=True, + ) + + self.lbs_fn = LBSModule( + assets.lbs_model_json, + assets.lbs_config_dict, + assets.lbs_template_verts, + assets.lbs_scale, + assets.global_scaling, + ) + + self.seam_sampler = SeamSampler(assets.seam_data_1024) + self.seam_sampler_2k = SeamSampler(assets.seam_data_2048) + + # joint tex -> body and clothes + # TODO: why do we have a joint one in the first place? + tex_mean = gaussian_blur(th.as_tensor(assets.tex_mean)[np.newaxis], kernel_size=11) + self.register_buffer("tex_mean", F.interpolate(tex_mean, (2048, 2048), mode='bilinear')) + + # this is shared + self.tex_std = assets.tex_var if 'tex_var' in assets else 64.0 + + face_cond_mask = th.as_tensor(assets.face_cond_mask, dtype=th.float32)[ + np.newaxis, np.newaxis + ] + self.register_buffer("face_cond_mask", face_cond_mask) + + meye_mask = self.geo_fn.to_uv( + th.as_tensor(assets.mouth_eyes_mask_geom[np.newaxis, :, np.newaxis]) + ) + meye_mask = F.interpolate(meye_mask, (2048, 2048), mode='bilinear') + self.register_buffer("meye_mask", meye_mask) + + self.decoder = ConvDecoder( + geo_fn=self.geo_fn, + seam_sampler=self.seam_sampler, + **decoder, + assets=assets, + ) + + # embs for everything but face + non_head_mask = 1.0 - assets.face_mask + self.encoder = Encoder( + geo_fn=self.geo_fn, + mask=non_head_mask, + **encoder, + ) + self.encoder_face = FaceEncoder( + assets=assets, + **encoder_face, + ) + + # using face decoder to generate better conditioning + decoder_face_ckpt_path = None + if 'ckpt' in decoder_face: + decoder_face_ckpt_path = decoder_face.pop('ckpt') + self.decoder_face = FaceDecoderFrontal(assets=assets, **decoder_face) + + if decoder_face_ckpt_path is not None: + 
self.decoder_face.load_state_dict(th.load(decoder_face_ckpt_path), strict=False) + + self.decoder_view = UNetViewDecoder( + self.geo_fn, + seam_sampler=self.seam_sampler, + **decoder_view, + ) + + self.shadow_net = ShadowUNet( + ao_mean=assets.ao_mean, + interp_mode="bilinear", + biases=False, + **shadow_net, + ) + + self.pose_to_shadow_enabled = False + if pose_to_shadow is not None: + self.pose_to_shadow_enabled = True + self.pose_to_shadow = PoseToShadow(**pose_to_shadow) + + self.upscale_net = UpscaleNet( + in_channels=6, size=1024, upscale_factor=2, out_channels=3, **upscale_net + ) + + self.pixel_cal_enabled = False + if pixel_cal is not None: + self.pixel_cal_enabled = True + self.pixel_cal = CameraPixelBias(**pixel_cal, cameras=assets.camera_ids) + + self.learn_blur_enabled = False + if learn_blur: + self.learn_blur_enabled = True + self.learn_blur = LearnableBlur(assets.camera_ids) + + # training-only stuff + self.cal_enabled = False + if cal is not None: + self.cal_enabled = True + self.cal = CalV5(**cal, cameras=assets.camera_ids) + + self.rendering_enabled = False + if renderer is not None: + self.rendering_enabled = True + self.renderer = RenderLayer( + h=renderer.image_height, + w=renderer.image_width, + vt=self.geo_fn.vt, + vi=self.geo_fn.vi, + vti=self.geo_fn.vti, + flip_uvs=False, + ) + + @th.jit.unused + def compute_summaries(self, preds, batch): + # TODO: switch to common summaries? 
+ # return compute_summaries_mesh(preds, batch) + rgb = linear2displayBatch(preds['rgb'][:, :3]) + rgb_gt = linear2displayBatch(batch['image']) + depth = preds['depth'][:, np.newaxis] + mask = depth > 0.0 + normals = ( + 255 * (1.0 - depth2normals(depth, batch['focal'], batch['princpt'])) / 2.0 + ) * mask + grid_rgb = make_grid(rgb, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8) + grid_rgb_gt = make_grid(rgb_gt, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8) + grid_normals = make_grid(normals, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8) + + progress_image = th.cat([grid_rgb, grid_rgb_gt, grid_normals], dim=0) + return { + 'progress_image': (progress_image, 'png'), + } + + def forward_tex(self, tex_mean_rec, tex_view_rec, shadow_map): + x = th.cat([tex_mean_rec, tex_view_rec], dim=1) + tex_rec = tex_mean_rec + tex_view_rec + + tex_rec = self.seam_sampler.impaint(tex_rec) + tex_rec = self.seam_sampler.resample(tex_rec) + + tex_rec = F.interpolate(tex_rec, size=(2048, 2048), mode="bilinear", align_corners=False) + tex_rec = tex_rec + self.upscale_net(x) + + tex_rec = tex_rec * self.tex_std + self.tex_mean + + shadow_map = self.seam_sampler_2k.impaint(shadow_map) + shadow_map = self.seam_sampler_2k.resample(shadow_map) + shadow_map = self.seam_sampler_2k.resample(shadow_map) + + tex_rec = tex_rec * shadow_map + + tex_rec = self.seam_sampler_2k.impaint(tex_rec) + tex_rec = self.seam_sampler_2k.resample(tex_rec) + tex_rec = self.seam_sampler_2k.resample(tex_rec) + + return tex_rec + + def encode(self, geom: th.Tensor, lbs_motion: th.Tensor, face_embs_hqlp: th.Tensor): + + with th.no_grad(): + verts_unposed = self.lbs_fn.unpose(geom, lbs_motion) + verts_unposed_uv = self.geo_fn.to_uv(verts_unposed) + + # extract face region for geom + tex + enc_preds = self.encoder(motion=lbs_motion, verts_unposed=verts_unposed) + # TODO: probably need to rename these to `face_embs_mugsy` or smth + # TODO: we need the same thing for face? 
+ # enc_face_preds = self.encoder_face(verts_unposed_uv) + with th.no_grad(): + face_dec_preds = self.decoder_face(face_embs_hqlp) + enc_face_preds = self.encoder_face(**face_dec_preds) + + preds = { + **enc_preds, + **enc_face_preds, + 'face_dec_preds': face_dec_preds, + } + return preds + + def forward( + self, + # TODO: should we try using this as well for cond? + lbs_motion: th.Tensor, + campos: th.Tensor, + geom: Optional[th.Tensor] = None, + ao: Optional[th.Tensor] = None, + K: Optional[th.Tensor] = None, + Rt: Optional[th.Tensor] = None, + image_bg: Optional[th.Tensor] = None, + image: Optional[th.Tensor] = None, + image_mask: Optional[th.Tensor] = None, + embs: Optional[th.Tensor] = None, + _index: Optional[Dict[str, th.Tensor]] = None, + face_embs: Optional[th.Tensor] = None, + embs_conv: Optional[th.Tensor] = None, + tex_seg: Optional[th.Tensor] = None, + encode=True, + iteration: Optional[int] = None, + **kwargs, + ): + B = lbs_motion.shape[0] + + if not th.jit.is_scripting() and encode: + # NOTE: these are `face_embs_hqlp` + enc_preds = self.encode(geom, lbs_motion, face_embs) + embs = enc_preds['embs'] + # NOTE: these are `face_embs` in body space + face_embs_body = enc_preds['face_embs'] + + dec_preds = self.decoder( + motion=lbs_motion, + embs=embs, + face_embs=face_embs_body, + embs_conv=embs_conv, + ) + + geom_rec = self.lbs_fn.pose(dec_preds['geom_delta_rec'], lbs_motion) + + dec_view_preds = self.decoder_view( + geom_rec=geom_rec, + tex_mean_rec=dec_preds["tex_mean_rec"], + camera_pos=campos, + ) + + # TODO: should we train an AO model? 
+ if self.training and self.pose_to_shadow_enabled: + shadow_preds = self.shadow_net(ao_map=ao) + pose_shadow_preds = self.pose_to_shadow(lbs_motion) + shadow_preds['pose_shadow_map'] = pose_shadow_preds['shadow_map'] + elif self.pose_to_shadow_enabled: + shadow_preds = self.pose_to_shadow(lbs_motion) + else: + shadow_preds = self.shadow_net(ao_map=ao) + + tex_rec = self.forward_tex( + dec_preds["tex_mean_rec"], + dec_view_preds["tex_view_rec"], + shadow_preds["shadow_map"], + ) + + if not th.jit.is_scripting() and self.cal_enabled: + tex_rec = self.cal(tex_rec, self.cal.name_to_idx(_index['camera'])) + + preds = { + 'geom': geom_rec, + 'tex_rec': tex_rec, + **dec_preds, + **shadow_preds, + **dec_view_preds, + } + + if not th.jit.is_scripting() and encode: + preds.update(**enc_preds) + + if not th.jit.is_scripting() and self.rendering_enabled: + + # NOTE: this is a reduced version tested for forward only + renders = self.renderer( + preds['geom'], + tex_rec, + K=K, + Rt=Rt, + ) + + preds.update(rgb=renders['render']) + + if not th.jit.is_scripting() and self.learn_blur_enabled: + preds['rgb'] = self.learn_blur(preds['rgb'], _index['camera']) + preds['learn_blur_weights'] = self.learn_blur.reg(_index['camera']) + + if not th.jit.is_scripting() and self.pixel_cal_enabled: + assert self.cal_enabled + cam_idxs = self.cal.name_to_idx(_index['camera']) + pixel_bias = self.pixel_cal(cam_idxs) + preds['rgb'] = preds['rgb'] + pixel_bias + + return preds + + +class Encoder(nn.Module): + """A joint encoder for tex and geometry.""" + + def __init__( + self, + geo_fn, + n_embs, + noise_std, + mask, + logvar_scale=0.1, + ): + """Fixed-width conv encoder.""" + super().__init__() + + self.noise_std = noise_std + self.n_embs = n_embs + self.geo_fn = geo_fn + self.logvar_scale = logvar_scale + + self.verts_conv = ConvDownBlock(3, 8, 512) + + mask = th.as_tensor(mask[np.newaxis, np.newaxis], dtype=th.float32) + mask = F.interpolate(mask, size=(512, 512), mode='bilinear').to(th.bool) 
+ self.register_buffer("mask", mask) + + self.joint_conv_blocks = nn.Sequential( + ConvDownBlock(8, 16, 256), + ConvDownBlock(16, 32, 128), + ConvDownBlock(32, 32, 64), + ConvDownBlock(32, 64, 32), + ConvDownBlock(64, 128, 16), + ConvDownBlock(128, 128, 8), + # ConvDownBlock(128, 128, 4), + ) + + # TODO: should we put initializer + self.mu = la.LinearWN(4 * 4 * 128, self.n_embs) + self.logvar = la.LinearWN(4 * 4 * 128, self.n_embs) + + self.apply(weights_initializer(0.2)) + self.mu.apply(weights_initializer(1.0)) + self.logvar.apply(weights_initializer(1.0)) + + def forward(self, motion, verts_unposed): + preds = {} + + B = motion.shape[0] + + # converting motion to the unposed + verts_cond = ( + F.interpolate(self.geo_fn.to_uv(verts_unposed), size=(512, 512), mode='bilinear') + * self.mask + ) + verts_cond = self.verts_conv(verts_cond) + + # tex_cond = F.interpolate(tex_avg, size=(512, 512), mode='bilinear') * self.mask + # tex_cond = self.tex_conv(tex_cond) + # joint_cond = th.cat([verts_cond, tex_cond], dim=1) + joint_cond = verts_cond + x = self.joint_conv_blocks(joint_cond) + x = x.reshape(B, -1) + embs_mu = self.mu(x) + embs_logvar = self.logvar_scale * self.logvar(x) + + # NOTE: the noise is only applied to the input-conditioned values + if self.training: + noise = th.randn_like(embs_mu) + embs = embs_mu + th.exp(embs_logvar) * noise * self.noise_std + else: + embs = embs_mu.clone() + + preds.update( + embs=embs, + embs_mu=embs_mu, + embs_logvar=embs_logvar, + ) + + return preds + + +class ConvDecoder(nn.Module): + """Multi-region view-independent decoder.""" + + def __init__( + self, + geo_fn, + uv_size, + seam_sampler, + init_uv_size, + n_pose_dims, + n_pose_enc_channels, + n_embs, + n_embs_enc_channels, + n_face_embs, + n_init_channels, + n_min_channels, + assets, + ): + super().__init__() + + self.geo_fn = geo_fn + + self.uv_size = uv_size + self.init_uv_size = init_uv_size + self.n_pose_dims = n_pose_dims + self.n_pose_enc_channels = n_pose_enc_channels 
+ self.n_embs = n_embs + self.n_embs_enc_channels = n_embs_enc_channels + self.n_face_embs = n_face_embs + + self.n_blocks = int(np.log2(self.uv_size // init_uv_size)) + self.sizes = [init_uv_size * 2**s for s in range(self.n_blocks + 1)] + + # TODO: just specify a sequence? + self.n_channels = [ + max(n_init_channels // 2**b, n_min_channels) for b in range(self.n_blocks + 1) + ] + + logger.info(f"ConvDecoder: n_channels = {self.n_channels}") + + self.local_pose_conv_block = ConvBlock( + n_pose_dims, + n_pose_enc_channels, + init_uv_size, + kernel_size=1, + padding=0, + ) + + self.embs_fc = nn.Sequential( + la.LinearWN(n_embs, 4 * 4 * 128), + nn.LeakyReLU(0.2, inplace=True), + ) + # TODO: should we switch to the basic version? + self.embs_conv_block = nn.Sequential( + UpConvBlockDeep(128, 128, 8), + UpConvBlockDeep(128, 128, 16), + UpConvBlockDeep(128, 64, 32), + UpConvBlockDeep(64, n_embs_enc_channels, 64), + ) + + self.face_embs_fc = nn.Sequential( + la.LinearWN(n_face_embs, 4 * 4 * 32), + nn.LeakyReLU(0.2, inplace=True), + ) + self.face_embs_conv_block = nn.Sequential( + UpConvBlockDeep(32, 64, 8), + UpConvBlockDeep(64, 64, 16), + UpConvBlockDeep(64, n_embs_enc_channels, 32), + ) + + n_groups = 2 + + self.joint_conv_block = ConvBlock( + n_pose_enc_channels + n_embs_enc_channels, + n_init_channels, + self.init_uv_size, + ) + + self.conv_blocks = nn.ModuleList([]) + for b in range(self.n_blocks): + self.conv_blocks.append( + UpConvBlockDeep( + self.n_channels[b] * n_groups, + self.n_channels[b + 1] * n_groups, + self.sizes[b + 1], + groups=n_groups, + ), + ) + + self.verts_conv = la.Conv2dWNUB( + in_channels=self.n_channels[-1], + out_channels=3, + kernel_size=3, + height=self.uv_size, + width=self.uv_size, + padding=1, + ) + self.tex_conv = la.Conv2dWNUB( + in_channels=self.n_channels[-1], + out_channels=3, + kernel_size=3, + height=self.uv_size, + width=self.uv_size, + padding=1, + ) + + self.apply(weights_initializer(0.2)) + 
self.verts_conv.apply(weights_initializer(1.0)) + self.tex_conv.apply(weights_initializer(1.0)) + + self.seam_sampler = seam_sampler + + # NOTE: removing head region from pose completely + pose_cond_mask = th.as_tensor( + assets.pose_cond_mask[np.newaxis] * (1 - assets.head_cond_mask[np.newaxis, np.newaxis]), + dtype=th.int32, + ) + self.register_buffer("pose_cond_mask", pose_cond_mask) + face_cond_mask = th.as_tensor(assets.face_cond_mask, dtype=th.float32)[ + np.newaxis, np.newaxis + ] + self.register_buffer("face_cond_mask", face_cond_mask) + + body_cond_mask = th.as_tensor(assets.body_cond_mask, dtype=th.float32)[ + np.newaxis, np.newaxis + ] + self.register_buffer("body_cond_mask", body_cond_mask) + + def forward(self, motion, embs, face_embs, embs_conv: Optional[th.Tensor] = None): + + # processing pose + pose = motion[:, 6:] + + B = pose.shape[0] + + non_head_mask = (self.body_cond_mask * (1.0 - self.face_cond_mask)).clip(0.0, 1.0) + + pose_masked = tile2d(pose, self.init_uv_size) * self.pose_cond_mask + pose_conv = self.local_pose_conv_block(pose_masked) * non_head_mask + + # TODO: decoding properly? 
+ if embs_conv is None: + embs_conv = self.embs_conv_block(self.embs_fc(embs).reshape(B, 128, 4, 4)) + + face_conv = self.face_embs_conv_block(self.face_embs_fc(face_embs).reshape(B, 32, 4, 4)) + # merging embeddings with spatial masks + embs_conv[:, :, 32:, :32] = ( + face_conv * self.face_cond_mask[:, :, 32:, :32] + + embs_conv[:, :, 32:, :32] * non_head_mask[:, :, 32:, :32] + ) + + joint = th.cat([pose_conv, embs_conv], axis=1) + joint = self.joint_conv_block(joint) + + x = th.cat([joint, joint], axis=1) + for b in range(self.n_blocks): + x = self.conv_blocks[b](x) + + # NOTE: here we do resampling at feature level + x = self.seam_sampler.impaint(x) + x = self.seam_sampler.resample(x) + x = self.seam_sampler.resample(x) + + verts_features, tex_features = th.split(x, self.n_channels[-1], 1) + + verts_uv_delta_rec = self.verts_conv(verts_features) + # TODO: need to get values + verts_delta_rec = self.geo_fn.from_uv(verts_uv_delta_rec) + tex_mean_rec = self.tex_conv(tex_features) + + preds = { + 'geom_delta_rec': verts_delta_rec, + 'geom_uv_delta_rec': verts_uv_delta_rec, + 'tex_mean_rec': tex_mean_rec, + 'embs_conv': embs_conv, + 'pose_conv': pose_conv, + } + + return preds + + +class FaceEncoder(nn.Module): + """A joint encoder for tex and geometry.""" + + def __init__( + self, + noise_std, + assets, + n_embs=256, + uv_size=512, + logvar_scale=0.1, + n_vert_in=7306 * 3, + prefix="face_", + ): + + """Fixed-width conv encoder.""" + super().__init__() + + # TODO: + self.noise_std = noise_std + self.n_embs = n_embs + self.logvar_scale = logvar_scale + self.prefix = prefix + self.uv_size = uv_size + + assert self.uv_size == 512 + + tex_cond_mask = assets.mugsy_face_mask[..., 0] + tex_cond_mask = th.as_tensor(tex_cond_mask, dtype=th.float32)[np.newaxis, np.newaxis] + tex_cond_mask = F.interpolate( + tex_cond_mask, (self.uv_size, self.uv_size), mode="bilinear", align_corners=True + ) + self.register_buffer("tex_cond_mask", tex_cond_mask) + + self.conv_blocks = 
nn.Sequential( + ConvDownBlock(3, 4, 512), + ConvDownBlock(4, 8, 256), + ConvDownBlock(8, 16, 128), + ConvDownBlock(16, 32, 64), + ConvDownBlock(32, 64, 32), + ConvDownBlock(64, 128, 16), + ConvDownBlock(128, 128, 8), + ) + self.geommod = nn.Sequential(la.LinearWN(n_vert_in, 256), nn.LeakyReLU(0.2, inplace=True)) + self.jointmod = nn.Sequential( + la.LinearWN(256 + 128 * 4 * 4, 512), nn.LeakyReLU(0.2, inplace=True) + ) + # TODO: should we put initializer + self.mu = la.LinearWN(512, self.n_embs) + self.logvar = la.LinearWN(512, self.n_embs) + + self.apply(weights_initializer(0.2)) + self.mu.apply(weights_initializer(1.0)) + self.logvar.apply(weights_initializer(1.0)) + + # TODO: compute_losses()? + + def forward(self, face_geom: th.Tensor, face_tex: th.Tensor, **kwargs): + B = face_geom.shape[0] + + tex_cond = F.interpolate( + face_tex, (self.uv_size, self.uv_size), mode="bilinear", align_corners=False + ) + tex_cond = (tex_cond / 255.0 - 0.5) * self.tex_cond_mask + x = self.conv_blocks(tex_cond) + tex_enc = x.reshape(B, 4 * 4 * 128) + + geom_enc = self.geommod(face_geom.reshape(B, -1)) + + x = self.jointmod(th.cat([tex_enc, geom_enc], dim=1)) + embs_mu = self.mu(x) + embs_logvar = self.logvar_scale * self.logvar(x) + + # NOTE: the noise is only applied to the input-conditioned values + if self.training: + noise = th.randn_like(embs_mu) + embs = embs_mu + th.exp(embs_logvar) * noise * self.noise_std + else: + embs = embs_mu.clone() + + preds = {"embs": embs, "embs_mu": embs_mu, "embs_logvar": embs_logvar, "tex_cond": tex_cond} + preds = {f"{self.prefix}{k}": v for k, v in preds.items()} + return preds + + +class UNetViewDecoder(nn.Module): + def __init__(self, geo_fn, net_uv_size, seam_sampler, n_init_ftrs=8): + super().__init__() + self.geo_fn = geo_fn + self.net_uv_size = net_uv_size + self.unet = UNetWB(4, 3, n_init_ftrs=n_init_ftrs, size=net_uv_size) + self.register_buffer("faces", self.geo_fn.vi.to(th.int64), persistent=False) + + def forward(self, geom_rec, 
tex_mean_rec, camera_pos): + + with th.no_grad(): + view_cos = compute_view_cos(geom_rec, self.faces, camera_pos) + view_cos_uv = self.geo_fn.to_uv(view_cos[..., np.newaxis]) + cond_view = th.cat([view_cos_uv, tex_mean_rec], dim=1) + tex_view = self.unet(cond_view) + # TODO: should we try warping here? + return {"tex_view_rec": tex_view, "cond_view": cond_view} + + +class UpscaleNet(nn.Module): + def __init__(self, in_channels, out_channels, n_ftrs, size=1024, upscale_factor=2): + super().__init__() + + self.conv_block = nn.Sequential( + la.Conv2dWNUB(in_channels, n_ftrs, size, size, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.out_block = la.Conv2dWNUB( + n_ftrs, + out_channels * upscale_factor**2, + size, + size, + kernel_size=1, + padding=0, + ) + + self.pixel_shuffle = nn.PixelShuffle(upscale_factor=upscale_factor) + self.apply(weights_initializer(0.2)) + self.out_block.apply(weights_initializer(1.0)) + + def forward(self, x): + x = self.conv_block(x) + x = self.out_block(x) + return self.pixel_shuffle(x) \ No newline at end of file diff --git a/visualize/ca_body/nn/blocks.py b/visualize/ca_body/nn/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1dbb62dd8ffc294eed8bd9472ea6857fd43105 --- /dev/null +++ b/visualize/ca_body/nn/blocks.py @@ -0,0 +1,786 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + + +import logging +from turtle import forward + +import visualize.ca_body.nn.layers as la +from visualize.ca_body.nn.layers import weight_norm_wrapper + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +# pyre-ignore +def weights_initializer(lrelu_slope=0.2): + # pyre-ignore + def init_fn(m): + if isinstance( + m, + ( + nn.Conv2d, + nn.Conv1d, + nn.ConvTranspose2d, + nn.Linear, + ), + ): + gain = nn.init.calculate_gain("leaky_relu", lrelu_slope) + nn.init.kaiming_uniform_(m.weight.data, a=gain) + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias.data) + else: + logger.debug(f"skipping initialization for {m}") + + return init_fn + + +# pyre-ignore +def WeightNorm(x, dim=0): + return nn.utils.weight_norm(x, dim=dim) + + +# pyre-ignore +def np_warp_bias(uv_size): + xgrid, ygrid = np.meshgrid(np.linspace(-1.0, 1.0, uv_size), np.linspace(-1.0, 1.0, uv_size)) + grid = np.concatenate((xgrid[None, :, :], ygrid[None, :, :]), axis=0)[None, ...].astype( + np.float32 + ) + return grid + + +class Conv2dBias(nn.Conv2d): + __annotations__ = {"bias": th.Tensor} + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + size, + stride=1, + padding=1, + bias=True, + *args, + **kwargs, + ): + super().__init__( + in_channels, + out_channels, + bias=False, + kernel_size=kernel_size, + stride=stride, + padding=padding, + *args, + **kwargs, + ) + if not bias: + logger.warning("ignoring bias=False") + self.bias = nn.Parameter(th.zeros(out_channels, size, size)) + + def forward(self, x): + bias = self.bias.clone() + return ( + # pyre-ignore + th.conv2d( + x, + self.weight, + bias=None, + stride=self.stride, + # pyre-ignore + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + + bias[np.newaxis] + ) + + +class Conv1dBias(nn.Conv1d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + size, + stride=1, + padding=0, 
+ bias=True, + *args, + **kwargs, + ): + super().__init__( + in_channels, + out_channels, + bias=False, + kernel_size=kernel_size, + stride=stride, + padding=padding, + *args, + **kwargs, + ) + if not bias: + logger.warning("ignoring bias=False") + self.bias = nn.Parameter(th.zeros(out_channels, size)) + + def forward(self, x): + return ( + # pyre-ignore + th.conv1d( + x, + self.weight, + bias=None, + stride=self.stride, + # pyre-ignore + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + + self.bias + ) + + +class UpConvBlock(nn.Module): + # pyre-ignore + def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2): + super().__init__() + # Integration: this layer did not exist in the GitHub version; we assume the upsample is the same as in the other blocks + self.upsample = nn.UpsamplingBilinear2d(size) + self.conv_resize = la.Conv2dWN( + in_channels=in_channels, out_channels=out_channels, kernel_size=1 + ) + self.conv1 = la.Conv2dWNUB( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + height=size, + width=size, + padding=1, + ) + self.lrelu1 = nn.LeakyReLU(lrelu_slope) + # self.conv2 = nn.utils.weight_norm( + # Conv2dBias(in_channels, out_channels, kernel_size=3, size=size), dim=None, + # ) + # self.lrelu2 = nn.LeakyReLU(lrelu_slope) + + # pyre-ignore + def forward(self, x): + x_up = self.upsample(x) + x_skip = self.conv_resize(x_up) + x = self.conv1(x_up) + x = self.lrelu1(x) + return x + x_skip + + +class ConvBlock1d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + size, + lrelu_slope=0.2, + kernel_size=3, + padding=1, + wnorm_dim=0, + ): + super().__init__() + + self.conv_resize = WeightNorm( + nn.Conv1d(in_channels, out_channels, kernel_size=1), dim=wnorm_dim + ) + self.conv1 = WeightNorm( + Conv1dBias( + in_channels, + in_channels, + kernel_size=kernel_size, + padding=padding, + size=size, + ), + dim=wnorm_dim, + ) + self.lrelu1 = nn.LeakyReLU(lrelu_slope) + self.conv2 = WeightNorm( + Conv1dBias( + in_channels, + 
class ConvBlock(nn.Module):
    """Residual block of two untied-bias convolutions plus a 1x1 skip path.

    Both main-path convolutions are weight-normalised ``la.Conv2dUB`` layers
    built for a ``size`` x ``size`` map; the skip path is a weight-normalised
    1x1 ``Conv2d`` that maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the input (and of the hidden activation).
        out_channels: channels of the output.
        size: value passed as ``height`` and ``width`` to the untied-bias convs.
        lrelu_slope: negative slope of both LeakyReLUs.
        kernel_size: kernel size of the two main-path convolutions.
        padding: padding of the two main-path convolutions.
        wnorm_dim: forwarded as ``g_dim`` to ``weight_norm_wrapper``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        size,
        lrelu_slope=0.2,
        kernel_size=3,
        padding=1,
        wnorm_dim=0,
    ):
        super().__init__()

        # Weight-normalised layer classes are created per instance because the
        # magnitude dimension (g_dim) is a constructor argument.
        wn_conv_ub = weight_norm_wrapper(
            la.Conv2dUB, "Conv2dWNUB", g_dim=wnorm_dim, v_dim=None
        )
        wn_conv = weight_norm_wrapper(
            th.nn.Conv2d, "Conv2dWN", g_dim=wnorm_dim, v_dim=None
        )
        main_kwargs = dict(
            kernel_size=kernel_size, padding=padding, height=size, width=size
        )

        # TODO: do we really need this?
        self.conv_resize = wn_conv(in_channels, out_channels, kernel_size=1)
        self.conv1 = wn_conv_ub(in_channels, in_channels, **main_kwargs)
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = wn_conv_ub(in_channels, out_channels, **main_kwargs)
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        """Return the two-conv main path added to the 1x1 projection of ``x``."""
        skip = self.conv_resize(x)
        hidden = self.lrelu1(self.conv1(x))
        hidden = self.lrelu2(self.conv2(hidden))
        return hidden + skip
class ConvDownBlock(nn.Module):
    """Residual block that halves the spatial resolution.

    Main path: 3x3 conv at the input resolution, LeakyReLU, 3x3 stride-2
    conv, LeakyReLU. Skip path: 1x1 stride-2 conv. The two are summed.
    """

    def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2, groups=1, wnorm_dim=0):
        """Constructor.

        Args:
            in_channels: int, # of input channels
            out_channels: int, # of output channels
            size: the *input* size
            lrelu_slope: negative slope of both LeakyReLUs
            groups: forwarded as ``groups`` to all three convolutions
            wnorm_dim: forwarded as ``g_dim`` to ``weight_norm_wrapper``
        """
        super().__init__()

        Conv2dWNUB = weight_norm_wrapper(la.Conv2dUB, "Conv2dWNUB", g_dim=wnorm_dim, v_dim=None)
        Conv2dWN = weight_norm_wrapper(th.nn.Conv2d, "Conv2dWN", g_dim=wnorm_dim, v_dim=None)

        # A stride-2 conv with kernel 3 / padding 1 (and the stride-2 1x1 skip
        # conv) maps `size` to floor((size - 1) / 2) + 1 == ceil(size / 2).
        # This equals `size // 2` for even sizes, so existing even-sized
        # configurations are unchanged; for odd sizes the previous `size // 2`
        # was one pixel too small for the untied bias.
        # NOTE(review): assumes Conv2dUB adds its (out_channels, height, width)
        # bias to the conv output, so the two must match - confirm.
        out_size = (size + 1) // 2

        self.conv_resize = Conv2dWN(
            in_channels, out_channels, kernel_size=1, stride=2, groups=groups
        )
        self.conv1 = Conv2dWNUB(
            in_channels,
            in_channels,
            kernel_size=3,
            height=size,
            width=size,
            groups=groups,
            padding=1,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)

        self.conv2 = Conv2dWNUB(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            height=out_size,
            width=out_size,
            groups=groups,
            padding=1,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        """Return the downsampled main path plus the strided 1x1 skip path."""
        x_skip = self.conv_resize(x)
        x = self.lrelu1(self.conv1(x))
        x = self.lrelu2(self.conv2(x))
        return x + x_skip
class UpConvBlockPositional(nn.Module):
    """Upsampling residual block that concatenates a fixed positional map."""

    def __init__(
        self,
        in_channels,
        out_channels,
        pos_map,
        lrelu_slope=0.2,
        wnorm_dim=0,
    ):
        """Block with positional encoding.

        Args:
            in_channels: # of input channels (not counting the positional encoding)
            out_channels: # of output channels
            pos_map: tensor [P, size, size]
            lrelu_slope: negative slope of both LeakyReLUs
            wnorm_dim: ``dim`` argument handed to ``WeightNorm``
        """
        super().__init__()
        assert len(pos_map.shape) == 3 and pos_map.shape[1] == pos_map.shape[2]
        self.register_buffer("pos_map", pos_map)
        n_pos = pos_map.shape[0]
        size = pos_map.shape[1]

        self.in_channels = in_channels
        self.out_channels = out_channels

        # The output resolution is dictated by the positional map.
        self.upsample = nn.UpsamplingBilinear2d(size)

        # A projection on the skip path is only created when the channel
        # count changes; otherwise the upsampled input is added directly.
        if in_channels != out_channels:
            self.conv_resize = WeightNorm(nn.Conv2d(in_channels, out_channels, 1), dim=wnorm_dim)

        first_conv = nn.Conv2d(in_channels + n_pos, in_channels, kernel_size=3, padding=1)
        self.conv1 = WeightNorm(first_conv, dim=wnorm_dim)
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        second_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = WeightNorm(second_conv, dim=wnorm_dim)
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        """Upsample ``x``, append the positional map, run two convs, add the skip."""
        upsampled = self.upsample(x)

        if self.in_channels != self.out_channels:
            skip = self.conv_resize(upsampled)
        else:
            skip = upsampled

        # Broadcast the [P, size, size] map over the batch without copying.
        batch = x.shape[0]
        pos = self.pos_map[np.newaxis].expand(batch, -1, -1, -1)

        hidden = th.cat([upsampled, pos], dim=1)
        hidden = self.lrelu1(self.conv1(hidden))
        hidden = self.lrelu2(self.conv2(hidden))

        return hidden + skip
class UpConvBlockXDeep(nn.Module):
    """Upsampling residual block with a three-convolution bottleneck path.

    Main path: ``in_channels -> in_channels // 2 -> in_channels // 2 ->
    out_channels``, each ``Conv2dBias`` wrapped in ``WeightNorm`` and followed
    by a LeakyReLU. Skip path: weight-normalised 1x1 conv on the upsampled
    input. The two are summed.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        size: target size for the upsampler; also passed as ``size`` to every
            ``Conv2dBias``.
        lrelu_slope: negative slope of the LeakyReLUs.
        wnorm_dim: ``dim`` argument handed to ``WeightNorm``.
    """

    def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2, wnorm_dim=0):
        super().__init__()
        self.upsample = nn.UpsamplingBilinear2d(size)
        # TODO: see if this is necessary
        self.conv_resize = WeightNorm(nn.Conv2d(in_channels, out_channels, 1), dim=wnorm_dim)

        mid_channels = in_channels // 2

        self.conv1 = WeightNorm(
            Conv2dBias(in_channels, mid_channels, kernel_size=3, size=size),
            dim=wnorm_dim,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)

        # conv2/lrelu2 used to be constructed twice in a row with identical
        # arguments; the first pair was overwritten immediately and never used.
        # Only one pair is built now. Attribute names (and state-dict keys) are
        # unchanged, but one fewer random initialisation is drawn, so the
        # freshly initialised conv2/conv3 weights under a fixed seed differ
        # from before.
        self.conv2 = WeightNorm(
            Conv2dBias(mid_channels, mid_channels, kernel_size=3, size=size),
            dim=wnorm_dim,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

        self.conv3 = WeightNorm(
            Conv2dBias(mid_channels, out_channels, kernel_size=3, size=size),
            dim=wnorm_dim,
        )
        self.lrelu3 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        """Return the three-conv main path plus the 1x1 skip, both on ``up(x)``."""
        x_up = self.upsample(x)
        x_skip = self.conv_resize(x_up)

        x = self.lrelu1(self.conv1(x_up))
        x = self.lrelu2(self.conv2(x))
        x = self.lrelu3(self.conv3(x))

        return x + x_skip
class UpConvBlockPS(nn.Module):
    """2x upsampling block: untied-bias conv, LeakyReLU, then pixel shuffle.

    The convolution emits ``n_out * 4`` channels, which ``nn.PixelShuffle(2)``
    rearranges into ``n_out`` channels at twice the spatial resolution.

    Args:
        n_in: channels of the input feature map.
        n_out: channels of the output feature map (after the shuffle).
        size: passed as both height and width to ``la.Conv2dWNUB``.
        kernel_size: kernel size of the convolution.
        padding: padding of the convolution.
    """

    # pyre-ignore
    def __init__(self, n_in, n_out, size, kernel_size=3, padding=1):
        super().__init__()
        self.conv1 = la.Conv2dWNUB(
            n_in,
            n_out * 4,
            size,
            size,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.ps = nn.PixelShuffle(2)

    def forward(self, x):
        """Return ``pixel_shuffle(lrelu(conv1(x)))``."""
        # The layer is registered as `conv1`; the previous `self.conv` lookup
        # raised AttributeError on every call.
        x = self.conv1(x)
        x = self.lrelu(x)
        return self.ps(x)
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    """Build an ICNR-style initial weight with the shape of ``x``.

    A reduced tensor with ``int(ni / scale**2)`` leading channels is
    initialised with ``init`` and then repeated ``scale**2`` times, so the
    result has the full shape ``[ni, nf, h, w]``.

    Args:
        x: 4-D tensor ``[ni, nf, h, w]``; only its shape, dtype and device are
            used (via ``x.new_zeros``).
        scale: upscale factor of the pixel shuffle this weight feeds.
        init: in-place initialiser that returns the tensor it filled.

    Returns:
        A tensor of shape ``[ni, nf, h, w]``.
    """
    n_out, n_in, height, width = x.shape
    copies = scale**2
    n_reduced = int(n_out / copies)

    seed = init(x.new_zeros([n_reduced, n_in, height, width]))
    kernel = seed.transpose(0, 1).contiguous().view(n_reduced, n_in, -1)
    kernel = kernel.repeat(1, 1, copies).contiguous()
    return kernel.view([n_in, n_out, height, width]).transpose(0, 1)
def scale_hook(grad: Optional[th.Tensor], scale: float) -> Optional[th.Tensor]:
    """Multiply a gradient by ``scale``; pass ``None`` through unchanged.

    Args:
        grad: incoming gradient, or ``None`` when there is none.
        scale: factor the gradient is multiplied by.

    Returns:
        ``grad * scale``, or ``None`` if ``grad`` is ``None``.
    """
    if grad is None:
        return None
    return grad * scale
class CalV3(CalBase):
    """Per-camera colour calibration with one 1x1 grouped conv per camera.

    Every camera owns a ``Conv2d(3, 3, 1, groups=3)``, i.e. one scale and one
    offset per colour channel, initialised to scale 1 and offset 0. The
    identity camera's parameters are frozen (``requires_grad = False``).
    """

    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, cameras, identity_camera) -> None:
        """Create one calibration conv per camera.

        Args:
            cameras: list of camera names; list position is the camera index.
            identity_camera: camera whose calibration stays frozen. If it is
                not in ``cameras``, the first camera is used and a warning is
                logged.
        """
        # Zero-argument super() follows the full MRO. The previous
        # `super(CalBase, self)` skipped CalBase, which only worked because
        # CalBase defines no __init__ of its own.
        super().__init__()
        # pyre-fixme[4]: Attribute must be annotated.
        self.cameras = cameras

        self.conv = th.nn.ModuleList(
            [th.nn.Conv2d(3, 3, 1, 1, 0, groups=3) for _ in cameras]
        )

        # Start every camera at the identity transform: scale 1, offset 0.
        for conv in self.conv:
            conv.weight.data.fill_(1.0)
            conv.bias.data.zero_()

        if identity_camera not in cameras:
            identity_camera = cameras[0]
            logger.warning(
                f"Requested color-calibration identity camera not present, defaulting to {identity_camera}."
            )

        identity_conv = self.conv[cameras.index(identity_camera)]
        identity_conv.weight.requires_grad = False
        identity_conv.bias.requires_grad = False

    def name_to_idx(self, cam_names: Sequence[str]) -> th.Tensor:
        """Map camera names to a long tensor of indices on the module's device.

        Raises:
            ValueError: if a name is not in ``self.cameras`` (from ``list.index``).
        """
        dev = next(self.parameters()).device
        return th.tensor([self.cameras.index(cn) for cn in cam_names], device=dev, dtype=th.long)

    def forward(self, image: th.Tensor, cam: th.Tensor) -> th.Tensor:
        """Apply each sample's camera-specific conv and re-stack the batch.

        Args:
            image: batch of images; sample ``i`` is sliced as ``image[i : i + 1]``.
            cam: per-sample camera indices, as produced by ``name_to_idx``.
        """
        return th.cat([self.conv[cam[i]](image[i : i + 1, :, :, :]) for i in range(image.size(0))])
+ identity_camera, + gs_lrscale: float = 1e0, + col_lrscale: float = 1e-1, + ) -> None: + super(CalBase, self).__init__() + + if identity_camera not in cameras: + identity_camera = cameras[0] + logger.warning( + f"Requested color-calibration identity camera not present, defaulting to {identity_camera}." + ) + + # pyre-fixme[4]: Attribute must be annotated. + self.identity_camera = identity_camera + # pyre-fixme[4]: Attribute must be annotated. + self.cameras = cameras + self.gs_lrscale = gs_lrscale + self.col_lrscale = col_lrscale + self.holder: ParamHolder = ParamHolder( + # pyre-fixme[6]: For 1st param expected `Tuple[int]` but got `int`. + 3 + 3, + cameras, + init_value=th.FloatTensor([1, 1, 1, 0, 0, 0]), + ) + + # pyre-fixme[4]: Attribute must be annotated. + self.identity_idx = self.holder.to_idx([identity_camera]).item() + # pyre-fixme[4]: Attribute must be annotated. + self.grey_idxs = [self.holder.to_idx([c]).item() for c in cameras if c.startswith("41")] + + s = th.FloatTensor([0.37, 0.52, 0.52]) + self.holder.params.data[th.LongTensor(self.grey_idxs), :3] = s + + def name_to_idx(self, cam_names: Sequence[str]) -> th.Tensor: + return self.holder.to_idx(cam_names) + + # pyre-fixme[2]: Parameter must be annotated. 
+ def initialize_from_texs(self, ds) -> float: + tex_mean = ds.tex_mean.permute(1, 2, 0) + texs = {} + idx = 0 + while ds[idx] is None: + idx += 1 + + for cam in self.cameras: + samp = ds[idx, cam] + if samp is None: + continue + + tex = samp["tex"] + texs[cam] = tex.permute(1, 2, 0) + + stats = {} + for cam in texs.keys(): + t = texs[cam] + mask = (t > 0).all(dim=2) + t = t * ds.tex_std + tex_mean + stats[cam] = (t[mask].mean(dim=0), t[mask].std(dim=0)) + + normstats = {} + for cam in texs.keys(): + mean, std = stats[cam] + imean, istd = stats[self.identity_camera] + scale = istd / std + bias = imean - scale * mean + normstats[cam] = (scale.clamp(max=2), bias) + + for cam, nstats in normstats.items(): + cidx = self.name_to_idx([cam])[0] + if cidx in self.grey_idxs: + nstats = (nstats[0] / 3, nstats[1] / 3) + self.holder.params.data[cidx, 0:3] = nstats[0] + self.holder.params.data[cidx, 3:6] = nstats[1] + return len(stats.keys()) / len(ds.cameras) + + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` + # inconsistently. + def load_state_dict(self, state_dict, strict: bool = True): + state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith("holder.")} + return self.holder.load_state_dict(state_dict, strict=strict) + + # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def state_dict( + self, + # pyre-fixme[2]: Parameter must be annotated. 
+ destination=None, + prefix: str = "", + keep_vars: bool = False, + saving: bool = False, + ): + sd = super(CalBase, self).state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + if saving: + sd[prefix + "holder.key_list"] = self.holder.key_list + return sd + + def forward(self, image: th.Tensor, cam_idxs: th.Tensor) -> th.Tensor: + params = self.holder(cam_idxs) + outs = [] + hook_scales = [] + for i in range(cam_idxs.shape[0]): + idx = cam_idxs[i] + img = image[i : i + 1] + if idx == self.identity_idx: + outs.append(img) + hook_scales.append(1) + continue + + w, b = params[i, :3], params[i, 3:] + if idx in self.grey_idxs: + b = b.sum() + out = (img * w[None, :, None, None]).sum(dim=1, keepdim=True).expand( + -1, 3, -1, -1 + ) + b + else: + out = img * w[None, :, None, None] + b[None, :, None, None] + outs.append(out) + hook_scales.append(self.gs_lrscale if idx in self.grey_idxs else self.col_lrscale) + + hook_scales = th.tensor(hook_scales, device=image.device, dtype=th.float32) + cal_out = th.cat(outs) + + if self.training and params.requires_grad: + params.register_hook(lambda g, hs=hook_scales: scale_hook(g, hs[:, None])) + return cal_out + + +class CalV6(CalBase): + """ + A faster version of CalV5, which also does not cause CUDA synchronization. It does not support gray + cameras. + """ + + def __init__( + self, + cameras: List[str], + identity_camera: str, + ) -> None: + """ + Args: + cameras (List[str]): A list of cameras. + + identity_camera (str): Name of identity camera. + """ + super(CalBase, self).__init__() + + if identity_camera not in cameras: + identity_camera = cameras[0] + logger.warning( + f"Requested color-calibration identity camera not present, defaulting to {identity_camera}." 
+ ) + + if any(c.startswith("41") for c in cameras): + raise ValueError("Gray cameras are not supported") + + self.identity_camera = identity_camera + self.cameras = cameras + self.holder = ParamHolder( + (3 + 3,), cameras, init_value=th.as_tensor([1, 1, 1, 0, 0, 0], dtype=th.float32) + ) + self.identity_idx: int = self.holder.key_list.index(identity_camera) + self.register_buffer( + "identity", + th.as_tensor([1, 1, 1, 0, 0, 0], dtype=th.float32)[None].expand(len(cameras), -1), + persistent=False, + ) + identity_w = th.zeros_like(self.identity) + identity_w[self.identity_idx, :] = 1.0 + self.register_buffer("identity_w", identity_w, persistent=False) + + # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` + # inconsistently. + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True + ) -> th.nn.modules.module._IncompatibleKeys: + state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith("holder.")} + return self.holder.load_state_dict(state_dict, strict=strict) + + def name_to_idx(self, cam_names: Sequence[str]) -> th.Tensor: + dev = next(self.parameters()).device + return th.tensor([self.cameras.index(cn) for cn in cam_names], device=dev, dtype=th.long) + + # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. 
class LearnableBlur(nn.Module):
    """Per-camera learnable mix of an image with two Gaussian-blurred copies.

    Each camera owns three raw weights. In ``forward`` they are turned into a
    softmax distribution that blends the input, a 3x3 blur and a 7x7 blur.
    """

    # TODO: should we make this conditional?
    def __init__(self, cameras: List[str]) -> None:
        """Create the ``[len(cameras), 3]`` weight table, initialised to ones."""
        super().__init__()
        self.cameras = cameras
        initial = th.ones(len(cameras), 3, dtype=th.float32)
        self.register_parameter("weights_raw", nn.Parameter(initial))

    def name_to_idx(self, cameras: List[str]) -> th.Tensor:
        """Map camera names to a long index tensor on the weights' device."""
        positions = [self.cameras.index(name) for name in cameras]
        return th.tensor(positions, device=self.weights_raw.device, dtype=th.long)

    # pyre-ignore
    def reg(self, cameras: List[str]):
        """Return the raw (pre-softmax) weights of the given cameras."""
        # pyre-ignore
        return self.weights_raw[self.name_to_idx(cameras)]

    # pyre-ignore
    def forward(self, img: th.Tensor, cameras: List[str]):
        """Blend ``img`` with its blurred copies using per-camera weights.

        Args:
            img: batch laid out as B, C, H, W.
            cameras: one camera name per batch element.
        """
        batch = img.shape[0]
        # TODO: mask?
        # pyre-ignore
        mix = th.softmax(self.weights_raw[self.name_to_idx(cameras)], dim=-1)
        # [B, 3] -> [B, 3, 1, 1, 1] so that mix[:, k] broadcasts over C, H, W.
        mix = mix.reshape(batch, 3, 1, 1, 1)
        sharp = mix[:, 0] * img
        soft = mix[:, 1] * gaussian_blur(img, [3, 3])
        softer = mix[:, 2] * gaussian_blur(img, [7, 7])
        return sharp + soft + softer
+""" + +from typing import Dict, Tuple + +import numpy as np +import torch as th +import torch.nn as nn + +import visualize.ca_body.nn.layers as la +from attrdict import AttrDict + + +class FaceDecoderFrontal(nn.Module): + def __init__( + self, + assets: AttrDict, + n_latent: int = 256, + n_vert_out: int = 3 * 7306, + tex_out_shp: Tuple[int, int] = (1024, 1024), + tex_roi: Tuple[Tuple[int, int], Tuple[int, int]] = ((0, 0), (1024, 1024)), + ) -> None: + super().__init__() + self.n_latent = n_latent + self.n_vert_out = n_vert_out + self.tex_roi = tex_roi + self.tex_roi_shp: Tuple[int, int] = tuple( + [int(i) for i in np.diff(np.array(tex_roi), axis=0).squeeze()] + ) + self.tex_out_shp = tex_out_shp + + self.encmod = nn.Sequential( + la.LinearWN(n_latent, 256), nn.LeakyReLU(0.2, inplace=True) + ) + self.geommod = nn.Sequential(la.LinearWN(256, n_vert_out)) + + self.viewmod = nn.Sequential(la.LinearWN(3, 8), nn.LeakyReLU(0.2, inplace=True)) + self.texmod2 = nn.Sequential( + la.LinearWN(256 + 8, 256 * 4 * 4), nn.LeakyReLU(0.2, inplace=True) + ) + self.texmod = nn.Sequential( + la.ConvTranspose2dWNUB(256, 256, 8, 8, 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + la.ConvTranspose2dWNUB(256, 128, 16, 16, 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + la.ConvTranspose2dWNUB(128, 128, 32, 32, 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + la.ConvTranspose2dWNUB(128, 64, 64, 64, 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + la.ConvTranspose2dWNUB(64, 64, 128, 128, 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + la.ConvTranspose2dWNUB(64, 32, 256, 256, 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + la.ConvTranspose2dWNUB(32, 8, 512, 512, 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + la.ConvTranspose2dWNUB(8, 3, 1024, 1024, 4, 2, 1), + ) + + self.bias = nn.Parameter(th.zeros(3, self.tex_roi_shp[0], self.tex_roi_shp[1])) + self.bias.data.zero_() + + self.register_buffer( + "frontal_view", th.as_tensor(assets.face_frontal_view, dtype=th.float32) + ) + + self.apply(lambda x: 
def gaussian_kernel(ksize: int, std: Optional[float] = None) -> np.ndarray:
    """Generates numpy array filled in with Gaussian values.

    The function generates Gaussian kernel (values according to the Gauss distribution)
    on the grid according to the kernel size. The kernel is normalized to sum to 1.

    Args:
        ksize (int): The kernel size, must be an odd number; an AssertionError is
            raised otherwise. A size of 1 yields the trivial ``[[1.0]]`` kernel.
        std (float): The standard deviation, could be None, in which case it will be
            calculated according to the kernel size so that the value at distance
            ``ksize // 2`` along an axis is 5% of the central value.

    Returns:
        np.array: The gaussian kernel of shape ``(ksize, ksize)``.

    """

    assert ksize % 2 == 1
    radius = ksize // 2

    if radius == 0:
        # A single-cell grid has nothing to weigh. The default std would be 0
        # here and the formula below would evaluate 0 / 0 (NaN), so return the
        # normalized 1x1 kernel directly. With an explicit positive std the
        # formula already produced exactly this value.
        return np.ones((1, 1))

    if std is None:
        std = np.sqrt(-(radius**2) / (2 * np.log(0.05)))

    coords = np.linspace(-radius, radius, ksize)
    x, y = np.meshgrid(coords, coords)
    gk = np.exp(-(x**2 + y**2) / (2 * std**2))
    gk /= gk.sum()
    return gk
+ self.expected_shape = expected_shape if expected_shape is not None else {} + + def __call__( + self, + # pyre-fixme[2]: Parameter must be annotated. + state_dict, + # pyre-fixme[2]: Parameter must be annotated. + prefix, + # pyre-fixme[2]: Parameter must be annotated. + local_metadata, + # pyre-fixme[2]: Parameter must be annotated. + strict, + # pyre-fixme[2]: Parameter must be annotated. + missing_keys, + # pyre-fixme[2]: Parameter must be annotated. + unexpected_keys, + # pyre-fixme[2]: Parameter must be annotated. + error_msgs, + ) -> None: + for old_name, new_name in self.name_mapping: + if prefix + old_name in state_dict: + tensor = state_dict.pop(prefix + old_name) + if new_name in self.expected_shape: + tensor = tensor.view(*self.expected_shape[new_name]) + state_dict[prefix + new_name] = tensor + + +# pyre-fixme[3]: Return type must be annotated. +def weight_norm_wrapper( + cls: Type[th.nn.Module], + new_cls_name: str, + name: str = "weight", + g_dim: int = 0, + v_dim: Optional[int] = 0, +): + """Wraps a torch.nn.Module class to support weight normalization. The wrapped class + is compatible with the fuse/unfuse syntax and is able to load state dict from previous + implementations. + + Args: + cls: Type[th.nn.Module] + Class to apply the wrapper to. + + new_cls_name: str + Name of the new class created by the wrapper. This should be the name + of whatever variable you assign the result of this function to. Ex: + ``SomeLayerWN = weight_norm_wrapper(SomeLayer, "SomeLayerWN", ...)`` + + name: str + Name of the parameter to apply weight normalization to. + + g_dim: int + Learnable dimension of the magnitude tensor. Set to None or -1 for single scalar magnitude. + Default values for Linear and Conv2d layers are 0s and for ConvTranspose2d layers are 1s. + + v_dim: int + Of which dimension of the direction tensor is calutated independently for the norm. Set to + None or -1 for calculating norm over the entire direction tensor (weight tensor). 
Default + values for most of the WN layers are None to preserve the existing behavior. + """ + + class Wrap(cls): + def __init__(self, *args: Any, name=name, g_dim=g_dim, v_dim=v_dim, **kwargs: Any): + # Check if the extra arguments are overwriting arguments for the wrapped class + check_args_shadowing( + "weight_norm_wrapper", super().__init__, ["name", "g_dim", "v_dim"] + ) + super().__init__(*args, **kwargs) + + # Sanitize v_dim since we are hacking the built-in utility to support + # a non-standard WeightNorm implementation. + if v_dim is None: + v_dim = -1 + self.weight_norm_args = {"name": name, "g_dim": g_dim, "v_dim": v_dim} + self.is_fused = True + self.unfuse() + + # For backward compatibility. + self._register_load_state_dict_pre_hook( + TensorMappingHook( + [(name, name + "_v"), ("g", name + "_g")], + {name + "_g": getattr(self, name + "_g").shape}, + ) + ) + + def fuse(self): + if self.is_fused: + return + # Check if the module is frozen. + param_name = self.weight_norm_args["name"] + "_g" + if hasattr(self, param_name) and param_name not in self._parameters: + raise ValueError("Trying to fuse frozen module.") + remove_weight_norm(self, self.weight_norm_args["name"]) + self.is_fused = True + + def unfuse(self): + if not self.is_fused: + return + # Check if the module is frozen. + param_name = self.weight_norm_args["name"] + if hasattr(self, param_name) and param_name not in self._parameters: + raise ValueError("Trying to unfuse frozen module.") + wn = WeightNorm.apply( + self, self.weight_norm_args["name"], self.weight_norm_args["g_dim"] + ) + # Overwrite the dim property to support mismatched norm calculate for v and g tensor. + if wn.dim != self.weight_norm_args["v_dim"]: + wn.dim = self.weight_norm_args["v_dim"] + # Adjust the norm values. 
class Conv2dUB(th.nn.Conv2d):
    """2D convolution with an "untied" bias: one bias value per output channel and pixel."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        height: int,
        width: int,
        # pyre-fixme[2]: Parameter must be annotated.
        *args,
        bias: bool = True,
        # pyre-fixme[2]: Parameter must be annotated.
        **kwargs,
    ) -> None:
        """Conv2d with untied bias.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            height: Height of the output feature map (second bias dimension).
            width: Width of the output feature map (third bias dimension).
            *args: Remaining positional arguments forwarded to ``th.nn.Conv2d``.
            bias: Whether to create the ``(out_channels, height, width)`` bias.
            **kwargs: Remaining keyword arguments forwarded to ``th.nn.Conv2d``.
        """
        # The parent's per-channel bias is always disabled; ours replaces it below.
        super().__init__(in_channels, out_channels, *args, bias=False, **kwargs)
        # pyre-fixme[4]: Attribute must be annotated.
        if bias:
            self.bias = th.nn.Parameter(th.zeros(out_channels, height, width))
        else:
            self.bias = None

    # TODO: remove this method once upgraded to pytorch 1.8
    # pyre-fixme[3]: Return type must be annotated.
    def _conv_forward(self, input: th.Tensor, weight: th.Tensor, bias: Optional[th.Tensor]):
        # Copied from pt1.8 source code.
        if self.padding_mode == "zeros":
            return thf.conv2d(
                input,
                weight,
                bias,
                self.stride,
                # pyre-fixme[6]: For 5th param expected `Union[List[int], int, Size,
                #  typing.Tuple[int, ...]]` but got `Union[str, typing.Tuple[int, ...]]`.
                self.padding,
                self.dilation,
                self.groups,
            )
        # Non-zero padding modes are applied explicitly, then convolved unpadded.
        padded = thf.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
        return thf.conv2d(
            padded, weight, bias, self.stride, _pair(0), self.dilation, self.groups
        )

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Convolve ``input`` without bias, then add the per-pixel bias if present."""
        result = self._conv_forward(input, self.weight, None)
        untied_bias = self.bias
        if untied_bias is not None:
            # Assertion for jit script.
            assert untied_bias is not None
            # Leading unit axis broadcasts the (C, H, W) bias over the batch.
            result = result + untied_bias.unsqueeze(0)
        return result
+ kernel_size=self.kernel_size, + # This is now required as of D35874490 + num_spatial_dims=input.dim() - 2, + # pyre-fixme[6]: For 6th param expected `Optional[List[int]]` but got + # `Tuple[int, ...]`. + dilation=self.dilation, + ) + + output = thf.conv_transpose2d( + input, + self.weight, + None, + self.stride, + # pyre-fixme[6]: For 5th param expected `Union[List[int], int, Size, + # typing.Tuple[int, ...]]` but got `Union[str, typing.Tuple[int, ...]]`. + self.padding, + output_padding, + self.groups, + self.dilation, + ) + bias = self.bias + if bias is not None: + # Assertion for jit script. + assert bias is not None + output = output + bias[None] + return output + + # NOTE: This function (on super _ConvTransposeNd) was updated in D35874490 with non-optional + # param num_spatial_dims added. Since we need both old/new pytorch versions to work (until those + # changes reach DGX), we're simply copying the updated code here until then. + # TODO remove this function once updated torch code is released to DGX + def _output_padding( + self, + input: th.Tensor, + output_size: Optional[List[int]], + stride: List[int], + padding: List[int], + kernel_size: List[int], + num_spatial_dims: int, + dilation: Optional[List[int]] = None, + ) -> List[int]: + if output_size is None: + # converting to list if was not already + ret = th.nn.modules.utils._single(self.output_padding) + else: + has_batch_dim = input.dim() == num_spatial_dims + 2 + num_non_spatial_dims = 2 if has_batch_dim else 1 + if len(output_size) == num_non_spatial_dims + num_spatial_dims: + output_size = output_size[num_non_spatial_dims:] + if len(output_size) != num_spatial_dims: + raise ValueError( + "ConvTranspose{}D: for {}D input, output_size must have {} or {} elements (got {})".format( + num_spatial_dims, + input.dim(), + num_spatial_dims, + num_non_spatial_dims + num_spatial_dims, + len(output_size), + ) + ) + + min_sizes = th.jit.annotate(List[int], []) + max_sizes = th.jit.annotate(List[int], []) + for 
class InterpolateHook(object):
    """Forward pre-hook that resizes a module's single positional input."""

    # Modes for which ``interpolate`` accepts an explicit ``align_corners`` value;
    # it raises ValueError when the flag is set for any other mode (e.g. "nearest").
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, size=None, scale_factor=None, mode: str = "bilinear") -> None:
        """An object storing options for interpolate function

        Args:
            size: Target spatial size, forwarded to ``interpolate``.
            scale_factor: Spatial scale multiplier, forwarded to ``interpolate``.
            mode: Interpolation algorithm name, forwarded to ``interpolate``.
        """
        # pyre-fixme[4]: Attribute must be annotated.
        self.size = size
        # pyre-fixme[4]: Attribute must be annotated.
        self.scale_factor = scale_factor
        self.mode = mode

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __call__(self, module, x):
        """Return the interpolated replacement for the hooked module's input.

        Args:
            module: The hooked module (unused).
            x: Tuple of positional inputs of the forward call; must hold exactly one.

        Returns:
            The resized tensor, which replaces the module input.
        """
        assert len(x) == 1, "Module should take only one input for the forward method."
        # Keep align_corners=False for the interpolating modes, but leave it unset
        # for modes such as "nearest"/"area" where passing it is an error.
        align_corners = False if self.mode in self._ALIGNABLE_MODES else None
        return thf.interpolate(
            x[0],
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
        )
def glorot(m: th.nn.Module, alpha: float = 1.0) -> None:
    """Initialize ``m`` in place with a Glorot-style uniform distribution.

    Handles ``Conv2d``, ``ConvTranspose2d``, ``ConvTranspose3d`` and ``Linear``
    modules; any other module type is left untouched. Weights are drawn from
    ``U(-std * sqrt(3), std * sqrt(3))`` and an existing bias is zeroed. Modules
    reported by ``is_weight_norm_wrapped`` are fused before and unfused after.

    Args:
        m: Module to initialize.
        alpha: Value used in the gain ``sqrt(2 / (1 + alpha**2))``.
    """
    gain = np.sqrt(2.0 / (1.0 + alpha**2))

    if isinstance(m, th.nn.Conv2d):
        ksize = m.kernel_size[0] * m.kernel_size[1]
        fan_sum = m.in_channels + m.out_channels
    elif isinstance(m, th.nn.ConvTranspose2d):
        # Kernel area is divided by 4; clamp to 1 so kernels smaller than 2x2 do not
        # produce a zero divisor below.
        ksize = max(m.kernel_size[0] * m.kernel_size[1] // 4, 1)
        fan_sum = m.in_channels + m.out_channels
    elif isinstance(m, th.nn.ConvTranspose3d):
        # Same clamp as the 2D case; here the kernel volume is divided by 8.
        ksize = max(m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] // 8, 1)
        fan_sum = m.in_channels + m.out_channels
    elif isinstance(m, th.nn.Linear):
        ksize = 1
        fan_sum = m.in_features + m.out_features
    else:
        return

    std = gain * np.sqrt(2.0 / (fan_sum * ksize))

    is_wnw = is_weight_norm_wrapped(m)
    if is_wnw:
        m.fuse()

    bound = std * np.sqrt(3.0)
    m.weight.data.uniform_(-bound, bound)
    if m.bias is not None:
        m.bias.data.zero_()

    if isinstance(m, th.nn.ConvTranspose2d):
        # hardcoded for stride=2 for now
        # Copy the even/even taps onto the three other stride-2 phases. For odd
        # kernel sizes the other phases have fewer taps, so the source is cropped
        # to each target's shape instead of failing on a shape mismatch.
        weight = m.weight.data
        base = weight[:, :, 0::2, 0::2]
        for row_phase, col_phase in ((0, 1), (1, 0), (1, 1)):
            target = weight[:, :, row_phase::2, col_phase::2]
            target.copy_(base[:, :, : target.shape[2], : target.shape[3]])

    if is_wnw:
        m.unfuse()
+ out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + output_padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + untied: bool = False, + height: int = 1, + width: int = 1, + gain: Optional[float] = None, + transpose: bool = False, + fuse_box_filter: bool = False, + lr_mul: float = 1.0, + bias_lr_mul: Optional[float] = None, + ) -> None: + super().__init__() + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size: Tuple[int, int] = make_tuple(kernel_size, 2) + self.stride: Tuple[int, int] = make_tuple(stride, 2) + self.padding: Tuple[int, int] = make_tuple(padding, 2) + self.output_padding: Tuple[int, int] = make_tuple(output_padding, 2) + self.dilation: Tuple[int, int] = make_tuple(dilation, 2) + self.groups = groups + if gain is None: + self.gain: float = np.sqrt(2.0) + else: + self.gain: float = gain + self.lr_mul = lr_mul + if bias_lr_mul is None: + bias_lr_mul = lr_mul + self.bias_lr_mul = bias_lr_mul + self.transpose = transpose + self.fan_in: float = np.prod(self.kernel_size) * in_channels // groups + self.fuse_box_filter = fuse_box_filter + if transpose: + self.weight: th.nn.Parameter = th.nn.Parameter( + th.zeros(in_channels, out_channels // groups, *self.kernel_size, dtype=th.float32) + ) + else: + self.weight: th.nn.Parameter = th.nn.Parameter( + th.zeros(out_channels, in_channels // groups, *self.kernel_size, dtype=th.float32) + ) + if bias: + if untied: + self.bias: th.nn.Parameter = th.nn.Parameter( + th.zeros(out_channels, height, width, dtype=th.float32) + ) + else: + self.bias: th.nn.Parameter = th.nn.Parameter( + th.zeros(out_channels, dtype=th.float32) + ) 
+ else: + self.register_parameter("bias", None) + self.untied = untied + self.std: float = 0.0 + self.reset_parameters() + + def reset_parameters(self) -> None: + self.std = self.gain / np.sqrt(self.fan_in) * self.lr_mul + init.normal_(self.weight, mean=0, std=1.0 / self.lr_mul) + + if self.bias is not None: + with th.no_grad(): + self.bias.zero_() + + def forward(self, x: th.Tensor) -> th.Tensor: + if self.transpose: + w = self.weight + if self.fuse_box_filter: + w = thf.pad(w, (1, 1, 1, 1), mode="constant") + w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] + bias = self.bias + if bias is not None: + bias = bias * self.bias_lr_mul + out = thf.conv_transpose2d( + x, + w * self.std, + bias if not self.untied else None, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + dilation=self.dilation, + groups=self.groups, + ) + if self.untied and bias is not None: + out = out + bias[None, ...] + return out + else: + w = self.weight + if self.fuse_box_filter: + w = thf.pad(w, (1, 1, 1, 1), mode="constant") + w = ( + w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] + ) * 0.25 + bias = self.bias + if bias is not None: + bias = bias * self.bias_lr_mul + out = thf.conv2d( + x, + w * self.std, + bias if not self.untied else None, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + if self.untied and bias is not None: + out = out + bias[None, ...] + return out + + +class ConcatPyramid(th.nn.Module): + def __init__( + self, + # pyre-fixme[2]: Parameter must be annotated. + branch, + # pyre-fixme[2]: Parameter must be annotated. + n_concat_in, + every_other: bool = True, + ksize: int = 7, + # pyre-fixme[2]: Parameter must be annotated. + kstd=None, + transposed: bool = False, + ) -> None: + """Module which wraps an up/down conv branch taking one input X and + converts it into a branch which takes two inputs X, Y. 
At each layer of + the original branch, we concatenate the previous output and Y, + up/downsampling Y appropriately, before running the layer. + + Args: + branch: th.nn.Sequential or th.nn.ModuleList + A branch containing up/down convs, optionally separated by nonlinearities. + + n_concat_in: int + Number of channels in the to-be-concatenated input (Y). + + every_other: bool + If every other layer is a nonlinearity, set this flag. Default is on. + + ksize: int + Kernel size for the Gaussian blur used to downsample each step of the pyramid. + + kstd: int + Kernel std. dev. for the Gaussian blur used to downsample each step of the pyramid. + If None, it is determined automatically. + + transposed: bool + Whether or not the conv stack contains transposed convolutions or not. + """ + super().__init__() + assert isinstance(branch, (th.nn.Sequential, th.nn.ModuleList)) + + # pyre-fixme[4]: Attribute must be annotated. + self.branch = branch + # pyre-fixme[4]: Attribute must be annotated. + self.n_concat_in = n_concat_in + self.every_other = every_other + self.ksize = ksize + # pyre-fixme[4]: Attribute must be annotated. + self.kstd = kstd + self.transposed = transposed + if every_other: + # pyre-fixme[4]: Attribute must be annotated. + self.levels = int(np.ceil(len(branch) / 2)) + else: + self.levels = len(branch) + + kernel = th.from_numpy(gaussian_kernel(ksize, kstd)).float() + self.register_buffer("blur_kernel", kernel[None, None].expand(n_concat_in, -1, -1, -1)) + + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
# From paper "Making Convolutional Networks Shift-Invariant Again"
# https://richzhang.github.io/antialiased-cnns/
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def get_pad_layer(pad_type):
    """Return the 2D padding layer class selected by ``pad_type``.

    Args:
        pad_type: ``"refl"``/``"reflect"``, ``"repl"``/``"replicate"`` or ``"zero"``.

    Returns:
        ``th.nn.ReflectionPad2d``, ``th.nn.ReplicationPad2d`` or ``th.nn.ZeroPad2d``
        (the class itself, not an instance).

    Raises:
        ValueError: If ``pad_type`` is not one of the recognized names.
    """
    if pad_type in ("refl", "reflect"):
        return th.nn.ReflectionPad2d
    if pad_type in ("repl", "replicate"):
        return th.nn.ReplicationPad2d
    if pad_type == "zero":
        return th.nn.ZeroPad2d
    # Fail with a clear message instead of printing and then hitting an
    # unbound-local error on the return statement.
    raise ValueError(f"Pad type [{pad_type}] not recognized")
+ self.stride = stride + self.off = int((self.stride - 1) / 2.0) + # pyre-fixme[4]: Attribute must be annotated. + self.channels = channels + + # print('Filter size [%i]'%filt_size) + if self.filt_size == 1: + a = np.array( + [ + 1.0, + ] + ) + elif self.filt_size == 2: + a = np.array([1.0, 1.0]) + elif self.filt_size == 3: + a = np.array([1.0, 2.0, 1.0]) + elif self.filt_size == 4: + a = np.array([1.0, 3.0, 3.0, 1.0]) + elif self.filt_size == 5: + a = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) + elif self.filt_size == 6: + a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) + elif self.filt_size == 7: + a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) + + filt = th.Tensor(a[:, None] * a[None, :]) + filt = filt / th.sum(filt) + self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + # pyre-fixme[4]: Attribute must be annotated. + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def forward(self, inp): + if self.filt_size == 1: + if self.pad_off == 0: + return inp[:, :, :: self.stride, :: self.stride] + else: + return self.pad(inp)[:, :, :: self.stride, :: self.stride] + else: + return th.nn.functional.conv2d( + self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1] + ) diff --git a/visualize/ca_body/nn/shadow.py b/visualize/ca_body/nn/shadow.py new file mode 100644 index 0000000000000000000000000000000000000000..e23c098996e503fb6a6baa6d2d9eba6b31393db5 --- /dev/null +++ b/visualize/ca_body/nn/shadow.py @@ -0,0 +1,615 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging + +from typing import Optional, Dict + + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +# TODO: use shared utils here? 
+import visualize.ca_body.nn.layers as la +from visualize.ca_body.nn.blocks import tile2d, weights_initializer + +logger = logging.getLogger(__name__) + + +class ShadowUNet(nn.Module): + def __init__( + self, + uv_size, + ao_mean, + shadow_size, + lrelu_slope=0.2, + beta=1.0, + n_dims=64, + interp_mode="bilinear", + biases=True, + trainable_mean=False, + ): + super().__init__() + + # this is the size of the output + self.uv_size = uv_size + self.shadow_size = shadow_size + + ao_mean = F.interpolate( + th.as_tensor(ao_mean)[np.newaxis], + size=(self.shadow_size, self.shadow_size), + )[0] + if not trainable_mean: + # TODO: + self.register_buffer("ao_mean", ao_mean) + else: + self.register_parameter("ao_mean", th.nn.Parameter(ao_mean)) + + self.depth = 3 + self.lrelu_slope = lrelu_slope + self.interp_mode = interp_mode + self.align_corners = None + if interp_mode == "bilinear": + self.align_corners = False + + # the base number of dimensions for the shadow maps + n_dims = n_dims + + # TODO: generate this? 
+ self.n_enc_dims = [ + (1, n_dims), + (n_dims, n_dims), + (n_dims, n_dims), + (n_dims, n_dims), + ] + + self.sizes = [shadow_size // (2**i) for i in range(len(self.n_enc_dims))] + + logger.debug(f"sizes: {self.sizes}") + + self.enc_layers = nn.ModuleList() + for i, size in enumerate(self.sizes): + n_in, n_out = self.n_enc_dims[i] + logger.debug(f"EncoderLayers({i}): {n_in}, {n_out}, {size}") + self.enc_layers.append( + nn.Sequential( + la.Conv2dWNUB( + n_in, + n_out, + kernel_size=3, + height=size, + width=size, + stride=1, + padding=1, + ), + nn.LeakyReLU(self.lrelu_slope, inplace=True), + ) + ) + + self.n_dec_dims = [ + (n_dims, n_dims), + (n_dims * 2, n_dims), + (n_dims * 2, n_dims), + (n_dims * 2, n_dims), + ] + self.dec_layers = nn.ModuleList() + for i in range(len(self.sizes)): + size = self.sizes[-i - 1] + n_in, n_out = self.n_dec_dims[i] + logger.debug(f"DecoderLayer({i}): {n_in}, {n_out}, {size}") + + self.dec_layers.append( + nn.Sequential( + la.Conv2dWNUB( + n_in, + n_out, + kernel_size=3, + height=size, + width=size, + stride=1, + padding=1, + ), + nn.LeakyReLU(self.lrelu_slope, inplace=True), + ) + ) + + self.apply(weights_initializer(self.lrelu_slope)) + + if biases: + self.shadow_pred = la.Conv2dWNUB( + self.n_dec_dims[-1][-1], + 1, + kernel_size=3, + height=self.sizes[0], + width=self.sizes[0], + stride=1, + padding=1, + ) + else: + self.shadow_pred = la.Conv2dWN( + self.n_dec_dims[-1][-1], + 1, + kernel_size=3, + stride=1, + padding=1, + ) + + self.shadow_pred.apply(weights_initializer(1.0)) + self.beta = beta + + def forward(self, ao_map): + # resizing the inputs if necessary + if ao_map.shape[-2:] != (self.shadow_size, self.shadow_size): + ao_map = F.interpolate(ao_map, size=(self.shadow_size, self.shadow_size)) + + x = ao_map - self.ao_mean + + enc_acts = [] + # unet enc + for i, layer in enumerate(self.enc_layers): + # TODO: try applying a 1D sparse op? + x = layer(x) + enc_acts.append(x) + # TODO: add this layer elsewhere? 
class FloorShadowDecoder(nn.Module):
    """U-Net style decoder turning an ambient-occlusion map into a floor shadow map.

    Five stride-2 untied-bias convolutions halve the resolution, and five stride-2
    transposed convolutions restore it, with additive skip connections.
    """

    def __init__(
        self,
        uv_size,
        beta=1.0,
    ):
        """Build the encoder/decoder stack.

        Args:
            uv_size: Side length the input is resized to and the output is produced
                at. The untied-bias shapes are derived from it; since the network
                halves the resolution five times, it presumably needs to be a
                multiple of 32 for the skip connections to line up (512 reproduces
                the previously hard-coded shapes).
            beta: Stored on the module; not used by ``forward``.
        """
        super().__init__()

        # Feature-map side length after 1..5 stride-2 steps (256..16 for uv_size=512).
        s1, s2, s3, s4, s5 = (uv_size // (2**i) for i in range(1, 6))

        # TODO: can we reduce # dims here?
        self.down1 = nn.Sequential(la.Conv2dWNUB(1, 64, s1, s1, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(la.Conv2dWNUB(64, 64, s2, s2, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down3 = nn.Sequential(la.Conv2dWNUB(64, 128, s3, s3, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down4 = nn.Sequential(la.Conv2dWNUB(128, 256, s4, s4, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down5 = nn.Sequential(la.Conv2dWNUB(256, 512, s5, s5, 4, 2, 1), nn.LeakyReLU(0.2))
        self.up1 = nn.Sequential(
            la.ConvTranspose2dWNUB(512, 256, s4, s4, 4, 2, 1), nn.LeakyReLU(0.2)
        )
        self.up2 = nn.Sequential(
            la.ConvTranspose2dWNUB(256, 128, s3, s3, 4, 2, 1), nn.LeakyReLU(0.2)
        )
        self.up3 = nn.Sequential(
            la.ConvTranspose2dWNUB(128, 64, s2, s2, 4, 2, 1), nn.LeakyReLU(0.2)
        )
        self.up4 = nn.Sequential(
            la.ConvTranspose2dWNUB(64, 64, s1, s1, 4, 2, 1), nn.LeakyReLU(0.2)
        )
        self.up5 = nn.Sequential(la.ConvTranspose2dWNUB(64, 1, uv_size, uv_size, 4, 2, 1))

        self.uv_size = uv_size

        self.apply(lambda x: la.glorot(x, 0.2))
        # glorot() ignores container modules, so pass the conv layer itself; handing
        # it the Sequential left the output layer with the alpha=0.2 init.
        la.glorot(self.up5[0], 1.0)

        self.beta = beta

    def forward(self, aomap: th.Tensor):
        """Predict the shadow map.

        Args:
            aomap: Ambient-occlusion map; it is bilinearly resized to
                ``(uv_size, uv_size)`` before encoding.

        Returns:
            Dict with key ``"shadow_map"`` holding ``(tanh(.) + 1) / 2`` of the
            decoder output plus the resized input, i.e. values in ``[0, 1]``.
        """
        aomap = F.interpolate(
            aomap,
            size=(self.uv_size, self.uv_size),
            mode="bilinear",
            align_corners=True,
        )

        # Encoder; activations are kept for the additive skip connections.
        x2 = self.down1(aomap - 0.5)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        # Decoder; each step doubles the resolution and adds the matching encoder map.
        x = self.up1(x6) + x5
        x = self.up2(x) + x4
        x = self.up3(x) + x3
        x = self.up4(x) + x2
        logits = (th.tanh(self.up5(x) + aomap) + 1.0) / 2.0

        return {"shadow_map": logits}
+ self.n_enc_dims = [ + (1, n_dims), + (n_dims, n_dims), + (n_dims, n_dims), + (n_dims, n_dims), + ] + + self.shadow_size = shadow_size + self.sizes = [shadow_size // (2**i) for i in range(len(self.n_enc_dims))] + + logger.info(f" shadow map size: {self.shadow_size}") + # logger.info(f"sizes: {self.sizes}") + + ##### + ## FC for root pose encoding + self.num_pose_dims = n_pose_dims + self.num_pose_enc_dims = n_pose_enc_dims + self.pose_fc_block = nn.Sequential( + la.LinearWN(self.num_pose_dims, self.num_pose_enc_dims), + nn.LeakyReLU(lrelu_slope), + ) + + self.pose_conv_block = la.Conv2dWNUB( + in_channels=self.num_pose_dims, + out_channels=self.num_pose_enc_dims, + kernel_size=3, + height=self.sizes[-1], + width=self.sizes[-1], + padding=1, + ) + + self.enc_layers = nn.ModuleList() + for i, size in enumerate(self.sizes): + n_in, n_out = self.n_enc_dims[i] + # logger.info(f"EncoderLayers({i}): {n_in}, {n_out}, {size}") + self.enc_layers.append( + nn.Sequential( + la.Conv2dWNUB( + n_in, + n_out, + kernel_size=3, + height=size, + width=size, + stride=1, + padding=1, + ), + nn.LeakyReLU(self.lrelu_slope, inplace=True), + ) + ) + + self.n_dec_dims = [ + (n_dims + self.num_pose_enc_dims, n_dims), + (n_dims * 2, n_dims), + (n_dims * 2, n_dims), + (n_dims * 2, n_dims), + ] + self.dec_layers = nn.ModuleList() + for i in range(len(self.sizes)): + size = self.sizes[-i - 1] + n_in, n_out = self.n_dec_dims[i] + # logger.info(f"DecoderLayer({i}): {n_in}, {n_out}, {size}") + self.dec_layers.append( + nn.Sequential( + la.Conv2dWNUB( + n_in, + n_out, + kernel_size=3, + height=size, + width=size, + stride=1, + padding=1, + ), + nn.LeakyReLU(self.lrelu_slope, inplace=True), + ) + ) + + self.apply(weights_initializer(self.lrelu_slope)) + self.shadow_pred = la.Conv2dWNUB( + self.n_dec_dims[-1][-1], + 1, + kernel_size=3, + height=self.sizes[0], + width=self.sizes[0], + stride=1, + padding=1, + ) + + self.shadow_pred.apply(weights_initializer(1.0)) + self.beta = beta + + def 
class PoseToShadow(nn.Module):
    """Decode a pose vector directly into a shadow map.

    A fully connected block expands the pose to ``256 * 4 * 4`` features,
    which are reshaped to ``[-1, 256, 4, 4]`` and passed through a stack of
    five transposed convolutions. The sigmoid of the result (shifted by
    ``beta``) is bilinearly resized to ``(uv_size, uv_size)``.
    """

    def __init__(
        self,
        n_pose_dims,
        uv_size,
        beta=1.0,
    ) -> None:
        super().__init__()
        self.n_pose_dims = n_pose_dims
        self.uv_size = uv_size

        self.fc_block = nn.Sequential(
            la.LinearWN(self.n_pose_dims, 256 * 4 * 4),
            nn.LeakyReLU(0.2),
        )
        # NOTE(review): the positional arguments look like
        # (in, out, height, width, kernel, stride, padding), i.e. each layer
        # doubles the resolution from 4 up to 128 -- confirm against
        # `la.ConvTranspose2dWNUB`.
        self.conv_block = nn.Sequential(
            la.ConvTranspose2dWNUB(256, 256, 8, 8, 4, 2, 1),
            nn.LeakyReLU(0.2),
            la.ConvTranspose2dWNUB(256, 128, 16, 16, 4, 2, 1),
            nn.LeakyReLU(0.2),
            la.ConvTranspose2dWNUB(128, 128, 32, 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
            la.ConvTranspose2dWNUB(128, 64, 64, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            la.ConvTranspose2dWNUB(64, 1, 128, 128, 4, 2, 1),
        )
        self.beta = beta
        # All layers use gain 0.2, the output layer is re-initialised with 1.0.
        self.apply(lambda x: la.glorot(x, 0.2))
        la.glorot(self.conv_block[-1], 1.0)

    def forward(self, pose: th.Tensor):
        """Return ``{"shadow_map": tensor}`` for a batch of pose vectors.

        Args:
            pose: tensor whose last dimension has ``n_pose_dims`` entries.

        Raises:
            ValueError: if ``pose`` is a scalar or its last dimension does not
                match ``n_pose_dims``.
        """
        # The original `assert pose.shape` was truthy for every non-scalar
        # tensor and therefore checked nothing; validate the feature size
        # explicitly so a wrong input fails with a readable message.
        if pose.ndim == 0 or pose.shape[-1] != self.n_pose_dims:
            raise ValueError(
                f"expected pose with last dimension {self.n_pose_dims}, "
                f"got shape {tuple(pose.shape)}"
            )
        x = self.fc_block(pose)
        x = self.conv_block(x.reshape(-1, 256, 4, 4))
        shadow_map_lowres = th.sigmoid(x + self.beta)

        shadow_map = F.interpolate(
            shadow_map_lowres, size=(self.uv_size, self.uv_size), mode="bilinear"
        )
        return {"shadow_map": shadow_map}
class DistMapShadowUNet(nn.Module):
    """U-Net that maps a multi-channel distance map to a shadow map.

    The input has ``n_dist_joints`` channels. Four encoder stages (with 2x
    bilinear downsampling in between) are mirrored by four decoder stages
    that concatenate the stored encoder activations. The sigmoid of the
    prediction head plus ``beta`` is resized to ``(uv_size, uv_size)``.
    """

    def __init__(
        self,
        uv_size,
        shadow_size,
        n_dist_joints,
        lrelu_slope=0.2,
        beta=1.0,
        n_dims=64,
        interp_mode="bilinear",
        biases=True,
    ):
        super().__init__()

        # this is the size of the output
        self.uv_size = uv_size
        self.shadow_size = shadow_size

        self.depth = 3
        self.lrelu_slope = lrelu_slope
        self.interp_mode = interp_mode
        # `align_corners` is only passed as False for bilinear resizing.
        self.align_corners = False if interp_mode == "bilinear" else None

        # (in_channels, out_channels) of each encoder stage.
        # TODO: generate this?
        self.n_enc_dims = [(n_dist_joints, n_dims)] + [(n_dims, n_dims)] * 3
        self.sizes = [
            shadow_size // (2**level) for level in range(len(self.n_enc_dims))
        ]

        logger.debug(f"sizes: {self.sizes}")

        def conv_lrelu(n_in, n_out, size):
            # One 3x3 conv stage followed by an in-place LeakyReLU.
            return nn.Sequential(
                la.Conv2dWNUB(
                    n_in,
                    n_out,
                    kernel_size=3,
                    height=size,
                    width=size,
                    stride=1,
                    padding=1,
                ),
                nn.LeakyReLU(self.lrelu_slope, inplace=True),
            )

        self.enc_layers = nn.ModuleList()
        for i, ((n_in, n_out), size) in enumerate(zip(self.n_enc_dims, self.sizes)):
            logger.debug(f"EncoderLayers({i}): {n_in}, {n_out}, {size}")
            self.enc_layers.append(conv_lrelu(n_in, n_out, size))

        # Decoder stages after the first one also receive a skip connection.
        self.n_dec_dims = [(n_dims, n_dims)] + [(n_dims * 2, n_dims)] * 3
        self.dec_layers = nn.ModuleList()
        for i, ((n_in, n_out), size) in enumerate(
            zip(self.n_dec_dims, reversed(self.sizes))
        ):
            logger.debug(f"DecoderLayer({i}): {n_in}, {n_out}, {size}")
            self.dec_layers.append(conv_lrelu(n_in, n_out, size))

        self.apply(weights_initializer(self.lrelu_slope))

        # The prediction head is created after the bulk initialisation and
        # gets its own initialiser below.
        head_in = self.n_dec_dims[-1][-1]
        if biases:
            self.shadow_pred = la.Conv2dWNUB(
                head_in,
                1,
                kernel_size=3,
                height=self.sizes[0],
                width=self.sizes[0],
                stride=1,
                padding=1,
            )
        else:
            self.shadow_pred = la.Conv2dWN(
                head_in,
                1,
                kernel_size=3,
                stride=1,
                padding=1,
            )

        self.shadow_pred.apply(weights_initializer(1.0))
        self.beta = beta

    def forward(self, dist_map: th.Tensor) -> Dict[str, th.Tensor]:
        """Return ``shadow_map`` (output size) and ``shadow_map_lowres``."""
        # resizing the inputs if necessary
        target = (self.shadow_size, self.shadow_size)
        if dist_map.shape[-2:] != target:
            dist_map = F.interpolate(dist_map, size=target)

        x = dist_map

        # Encoder: remember each stage output for the skip connections.
        last_level = len(self.sizes) - 1
        skips = []
        for level, encoder in enumerate(self.enc_layers):
            x = encoder(x)
            skips.append(x)
            if level < last_level:
                x = F.interpolate(
                    x,
                    scale_factor=0.5,
                    mode="bilinear",
                    recompute_scale_factor=True,
                    align_corners=True,
                )

        # Decoder: stage 0 works on the bottleneck directly.
        for level, decoder in enumerate(self.dec_layers):
            if level > 0:
                skip = skips[-level - 1]
                x = F.interpolate(
                    x, size=skip.shape[2:4], mode="bilinear", align_corners=True
                )
                x = th.cat([x, skip], dim=1)
            x = decoder(x)

        shadow_map_lowres = th.sigmoid(self.shadow_pred(x) + self.beta)
        shadow_map = F.interpolate(
            shadow_map_lowres,
            (self.uv_size, self.uv_size),
            mode=self.interp_mode,
            align_corners=self.align_corners,
        )

        return {
            "shadow_map": shadow_map,
            "shadow_map_lowres": shadow_map_lowres,
        }
class UNetWB(nn.Module):
    """Five-level U-Net with additive skip connections and a scaled output.

    Every down stage is a stride-2 ``Conv2dWNUB`` + LeakyReLU, every up stage
    a stride-2 ``ConvTranspose2dWNUB`` + LeakyReLU. The network input is
    concatenated to the last up stage before the 1x1 output convolution, and
    the result is multiplied by ``out_scale``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        size: int,
        n_init_ftrs: int = 8,
        out_scale: float = 0.1,
    ):
        super().__init__()

        self.out_scale = out_scale
        self.size = size

        base = n_init_ftrs

        def down(c_in, c_out, res):
            # Stride-2 convolution producing a `res` x `res` feature map.
            return nn.Sequential(
                Conv2dWNUB(c_in, c_out, res, res, 4, 2, 1),
                nn.LeakyReLU(0.2),
            )

        def up(c_in, c_out, res):
            # Stride-2 transposed convolution producing `res` x `res`.
            return nn.Sequential(
                ConvTranspose2dWNUB(c_in, c_out, res, res, 4, 2, 1),
                nn.LeakyReLU(0.2),
            )

        # Attribute names and creation order are kept so that checkpoints and
        # the initialisation order stay the same.
        self.down1 = down(in_channels, base, size // 2)
        self.down2 = down(base, 2 * base, size // 4)
        self.down3 = down(2 * base, 4 * base, size // 8)
        self.down4 = down(4 * base, 8 * base, size // 16)
        self.down5 = down(8 * base, 16 * base, size // 32)
        self.up1 = up(16 * base, 8 * base, size // 16)
        self.up2 = up(8 * base, 4 * base, size // 8)
        self.up3 = up(4 * base, 2 * base, size // 4)
        self.up4 = up(2 * base, base, size // 2)
        self.up5 = up(base, base, size)
        self.out = Conv2dWNUB(
            base + in_channels, out_channels, self.size, self.size, kernel_size=1
        )
        self.apply(lambda x: glorot(x, 0.2))
        glorot(self.out, 1.0)

    def forward(self, x):
        """Run the U-Net; returns ``self.out(...) * out_scale``."""
        # feats = [input, down1, ..., down5]
        feats = [x]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))

        # TODO: switch to concat?
        y = feats.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            y = stage(y) + feats.pop()
        y = self.up5(y)

        # Only the network input is left on the stack at this point.
        y = th.cat([y, feats.pop()], dim=1)
        return self.out(y) * self.out_scale
class UNetWBConcat(nn.Module):
    """Five-level U-Net whose skip connections are channel concatenations.

    Because each up stage (after the first) receives the previous up output
    concatenated with the matching down output, its input width is twice the
    width of the corresponding additive U-Net stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        size: int,
        n_init_ftrs: int = 8,
    ):
        super().__init__()

        self.size = size
        base = n_init_ftrs

        def down(c_in, c_out, res):
            # Stride-2 convolution producing a `res` x `res` feature map.
            return nn.Sequential(
                la.Conv2dWNUB(c_in, c_out, res, res, 4, 2, 1),
                nn.LeakyReLU(0.2),
            )

        def up(c_in, c_out, res):
            # Stride-2 transposed convolution producing `res` x `res`.
            return nn.Sequential(
                la.ConvTranspose2dWNUB(c_in, c_out, res, res, 4, 2, 1),
                nn.LeakyReLU(0.2),
            )

        # Attribute names and creation order are kept so that checkpoints and
        # the initialisation order stay the same.
        self.down1 = down(in_channels, base, size // 2)
        self.down2 = down(base, 2 * base, size // 4)
        self.down3 = down(2 * base, 4 * base, size // 8)
        self.down4 = down(4 * base, 8 * base, size // 16)
        self.down5 = down(8 * base, 16 * base, size // 32)
        self.up1 = up(16 * base, 8 * base, size // 16)
        self.up2 = up(2 * 8 * base, 4 * base, size // 8)
        self.up3 = up(2 * 4 * base, 2 * base, size // 4)
        self.up4 = up(2 * 2 * base, base, size // 2)
        self.up5 = up(2 * base, base, size)
        self.out = la.Conv2dWNUB(
            base + in_channels, out_channels, self.size, self.size, kernel_size=1
        )
        self.apply(lambda x: la.glorot(x, 0.2))
        la.glorot(self.out, 1.0)

    def forward(self, x):
        """Run the U-Net and return the output of the 1x1 convolution."""
        # feats = [input, down1, ..., down5]
        feats = [x]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))

        y = feats.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            y = th.cat([stage(y), feats.pop()], 1)
        y = self.up5(y)

        # Only the network input is left on the stack at this point.
        y = th.cat([y, feats.pop()], dim=1)
        return self.out(y)
class UNetW(nn.Module):
    """Five-level U-Net built from ``la.Conv2dWN`` / ``la.ConvTranspose2dWN``.

    Unlike the ``*WB`` variants the layers take no spatial size arguments.
    Skip connections are additive, the input is concatenated before the 1x1
    output convolution and the result is multiplied by ``out_scale``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_init_ftrs,
        kernel_size=4,
        out_scale=1.0,
    ):
        super().__init__()

        self.out_scale = out_scale
        base = n_init_ftrs

        def down(c_in, c_out):
            # Stride-2 convolution + LeakyReLU.
            return nn.Sequential(
                la.Conv2dWN(c_in, c_out, kernel_size, 2, 1),
                nn.LeakyReLU(0.2),
            )

        def up(c_in, c_out):
            # Stride-2 transposed convolution + LeakyReLU.
            return nn.Sequential(
                la.ConvTranspose2dWN(c_in, c_out, kernel_size, 2, 1),
                nn.LeakyReLU(0.2),
            )

        # Attribute names and creation order are kept so that checkpoints and
        # the initialisation order stay the same.
        self.down1 = down(in_channels, base)
        self.down2 = down(base, 2 * base)
        self.down3 = down(2 * base, 4 * base)
        self.down4 = down(4 * base, 8 * base)
        self.down5 = down(8 * base, 16 * base)
        self.up1 = up(16 * base, 8 * base)
        self.up2 = up(8 * base, 4 * base)
        self.up3 = up(4 * base, 2 * base)
        self.up4 = up(2 * base, base)
        self.up5 = up(base, base)
        self.out = la.Conv2dWN(base + in_channels, out_channels, kernel_size=1)
        self.apply(weights_initializer(0.2))
        self.out.apply(weights_initializer(1.0))

    def forward(self, x):
        """Run the U-Net; returns ``self.out(...) * out_scale``."""
        # feats = [input, down1, ..., down5]
        feats = [x]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))

        # TODO: switch to concat?
        y = feats.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            y = stage(y) + feats.pop()
        y = self.up5(y)

        # Only the network input is left on the stack at this point.
        y = th.cat([y, feats.pop()], dim=1)
        return self.out(y) * self.out_scale
+ x = self.up1(x6) + x5 + x = self.up2(x) + x4 + x = self.up3(x) + x3 + x = self.up4(x) + x2 + x = self.up5(x) + x = th.cat([x, x1], dim=1) + return self.out(x) * self.out_scale diff --git a/visualize/ca_body/notebooks/render_example_cca.ipynb b/visualize/ca_body/notebooks/render_example_cca.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..09211d203fa91c31f683bb754bf62e60951a882b --- /dev/null +++ b/visualize/ca_body/notebooks/render_example_cca.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "43cbd3f0", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import torch as th\n", + "import cv2\n", + "\n", + "# set the right device\n", + "#os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", + "# NOTE: assuming we are in `ca_body/notebooks`\n", + "sys.path.insert(0, '/home/evonneng/audio2photoreal')\n", + "from attrdict import AttrDict\n", + "\n", + "from omegaconf import OmegaConf\n", + "from torchvision.utils import make_grid\n", + "\n", + "from visualize.ca_body.utils.module_loader import load_from_config\n", + "from visualize.ca_body.utils.lbs import LBSModule\n", + "from visualize.ca_body.utils.train import load_checkpoint\n", + "\n", + "device = th.device('cuda:0')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5caf2480", + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: make sure to download the data\n", + "model_dir = '/home/evonneng/audio2photoreal/checkpoints/ca_body/data/PXB184/'\n", + "\n", + "ckpt_path = f'{model_dir}/body_dec.ckpt'\n", + "config_path = f'{model_dir}/config.yml'\n", + "assets_path = f'{model_dir}/static_assets.pt'\n", + "\n", + "# config\n", + "config = OmegaConf.load(config_path)\n", + "# assets\n", + "static_assets = AttrDict(th.load(assets_path))\n", + "# sample batch\n", + "batch = th.load(f'{model_dir}/sample_batch.pt')\n", + "batch = {\n", + " key: val.to(device) if th.is_tensor(val) else val\n", 
+ " for key, val in batch.items()\n", + "}\n", + "# batch = to_device(batch, device)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['image', 'ao', 'seg_fg', 'seg_part', 'lbs_motion', 'geom', 'face_embs', 'camera_ids', 'campos', 'camrot', 'focal', 'princpt', 'K', 'Rt', '_index', 'face_R', 'face_t'])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "73331f2e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-12-21 17:04:11][INFO][visualize.ca_body.utils.geom]:impainting index image might take a while for sizes >= 1024\n", + "[2023-12-21 17:04:13][INFO][visualize.ca_body.models.mesh_vae_drivable]:ConvDecoder: n_channels = [64, 32, 16, 8, 4]\n", + "[2023-12-21 17:04:14][WARNING][visualize.ca_body.nn.color_cal]:Requested color-calibration identity camera not present, defaulting to 400883.\n", + "[2023-12-21 17:04:14][INFO][visualize.ca_body.utils.train]:loading checkpoint /home/evonneng/audio2photoreal/checkpoints/ca_body/data/PXB184//body_dec.ckpt\n", + "[2023-12-21 17:04:15][INFO][visualize.ca_body.utils.train]:skipping: ['lbs_fn.*']\n" + ] + } + ], + "source": [ + "# building the model\n", + "model = load_from_config(\n", + " config.model, \n", + " assets=static_assets,\n", + ").to(device)\n", + "\n", + "# loading model checkpoint\n", + "load_checkpoint(\n", + " ckpt_path, \n", + " modules={'model': model},\n", + " # NOTE: this is accounting for difference in LBS impl\n", + " ignore_names={'model': ['lbs_fn.*']},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "86a2a291", + "metadata": {}, + "outputs": [], + "source": [ + "# disabling training-only stuff\n", + "model.learn_blur_enabled = False\n", + "model.pixel_cal_enabled = False\n", + 
"model.cal_enabled = False\n", + "\n", + "# forward\n", + "with th.no_grad():\n", + " preds = model(**batch)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "9a566533", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-12-21 17:31:18][WARNING][matplotlib.image]:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# visualizing\n", + "import matplotlib.pyplot as plt\n", + "rgb_preds_grid = make_grid(preds['rgb'], nrow=4).permute(1, 2, 0).cpu().numpy() / 255.\n", + "plt.figure(figsize=(15, 15))\n", + "plt.imshow(rgb_preds_grid[::4,::4])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/visualize/ca_body/requirements.txt b/visualize/ca_body/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..292f6db5ae09201412070f663f4560ba2ee53359 --- /dev/null +++ b/visualize/ca_body/requirements.txt @@ -0,0 +1,4 @@ +torch>=2.0.0 +pytorch3d +numpy +torchvision diff --git a/visualize/ca_body/utils/geom.py b/visualize/ca_body/utils/geom.py new file mode 100644 index 0000000000000000000000000000000000000000..7e50c917ea4bb1aa3043888b421b3efe6612e700 --- /dev/null +++ b/visualize/ca_body/utils/geom.py @@ -0,0 +1,659 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
def make_uv_face_index(
    vt: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
    device: Optional[Union[str, th.device]] = None,
):
    """Compute a UV-space face index map identifying which mesh face contains
    each texel. For texels with no assigned triangle, the index will be -1.

    Args:
        vt: UV coordinates; column 1 is the axis affected by `flip_uv`.
        vti: per-face indices into `vt`.
        uv_shape: output resolution, an int is expanded to a square.
        flip_uv: if True, column 1 is mirrored once more after the initial
            `1 - vt` mirroring of both columns.
        device: device to rasterize on. An explicitly passed device must be a
            CUDA device; when omitted, "cuda" is used.
    """
    if isinstance(uv_shape, int):
        uv_shape = (uv_shape, uv_shape)

    if device is None:
        dev = th.device("cuda")
    else:
        dev = th.device(device) if isinstance(device, str) else device
        assert dev.type == "cuda"

    # `1.0 - vt` already allocates a new tensor, so the caller's `vt` is
    # never modified by the in-place flip below.
    uv = 1.0 - vt
    if flip_uv:
        uv[:, 1] = 1 - uv[:, 1]

    # Map [0, 1] to [-1, 1] and append a constant third coordinate so the UV
    # layout can be rasterized as a flat mesh.
    uv_ndc = 2.0 * uv.to(dev) - 1.0
    uv_ndc = th.cat([uv_ndc, th.ones_like(uv_ndc[:, 0:1])], dim=1)
    meshes = Meshes(uv_ndc[np.newaxis], vti[np.newaxis].to(dev))
    with th.no_grad():
        face_index, _, _, _ = rasterize_meshes(
            meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0
        )
    return face_index[0, ..., 0]


def make_uv_vert_index(
    vt: th.Tensor,
    vi: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
):
    """Compute a UV-space vertex index map identifying which mesh vertices
    comprise the triangle containing each texel. For texels with no assigned
    triangle, all indices will be -1.
    """
    face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv).to(vi.device)
    # Clamp so that empty texels (-1) index a valid face, then overwrite them.
    vert_index_map = vi[face_index_map.clamp(min=0)]
    empty = face_index_map < 0
    vert_index_map[empty] = -1
    return vert_index_map.long()
def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6):
    """Computes barycentric coordinates for a set of 2D query points given
    coordinates for the 3 vertices of the enclosing triangle for each point.

    Args:
        points: [N, 2] query points.
        triangles: [3, N, 2] triangle corners, `triangles[k, i]` being corner
            `k` of the triangle that belongs to point `i`.
        eps: minimum magnitude of the denominator.

    Returns:
        [3, N] tensor with one row of weights per triangle corner.
    """
    # Everything is expressed relative to the third corner.
    origin_x = triangles[2, :, 0]
    origin_y = triangles[2, :, 1]
    x = points[:, 0] - origin_x
    y = points[:, 1] - origin_y
    x1 = triangles[0, :, 0] - origin_x
    y1 = triangles[0, :, 1] - origin_y
    x2 = triangles[1, :, 0] - origin_x
    y2 = triangles[1, :, 1] - origin_y

    denom = y2 * x1 - y1 * x2
    n0 = y2 * x - x2 * y
    n1 = x1 * y - y1 * x

    # Small epsilon to prevent divide-by-zero error (sign is preserved).
    denom = th.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps))

    bary_0 = n0 / denom
    bary_1 = n1 / denom
    bary_2 = 1.0 - bary_0 - bary_1

    return th.stack((bary_0, bary_1, bary_2))


def make_uv_barys(
    vt: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
):
    """Compute a UV-space barycentric map where each texel contains barycentric
    coordinates for that texel within its enclosing UV triangle. For texels
    with no assigned triangle, all 3 barycentric coordinates will be 0.

    Returns:
        Tuple `(face_index_map, bary_map)`; `bary_map` has shape
        `[uv_shape[0], uv_shape[1], 3]`.
    """
    if isinstance(uv_shape, int):
        uv_shape = (uv_shape, uv_shape)

    if flip_uv:
        # Flip here because texture coordinates in some of our topo files are
        # stored in OpenGL convention with Y=0 on the bottom of the texture
        # unlike numpy/torch arrays/tensors.
        vt = vt.clone()
        vt[:, 1] = 1 - vt[:, 1]

    # The flip has already been applied above, hence flip_uv=False here.
    face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False).to(vt.device)
    vti_map = vti.long()[face_index_map.clamp(min=0)]
    uv_tri_uvs = vt[vti_map].permute(2, 0, 1, 3)

    # Texel centres in [0, 1]. `indexing="ij"` is what `th.meshgrid` did
    # implicitly before; passing it explicitly avoids the deprecation warning.
    uv_grid = th.meshgrid(
        th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0],
        th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1],
        indexing="ij",
    )
    # Reverse so the last axis is ordered (column, row), matching `vt`.
    uv_grid = th.stack(uv_grid[::-1], dim=2).to(uv_tri_uvs)

    bary_map = bary_coords(uv_grid.view(-1, 2), uv_tri_uvs.view(3, -1, 2))
    bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3)
    bary_map[face_index_map < 0] = 0
    return face_index_map, bary_map
def index_image_impaint(
    index_image: th.Tensor,
    bary_image: Optional[th.Tensor] = None,
    distance_threshold=100.0,
):
    """Fill empty texels of an index image from their nearest valid texel.

    A texel is valid if its value differs from -1 (for [H,W,C] images: if any
    channel differs from -1). Every invalid texel whose nearest valid texel
    is closer than `distance_threshold` (in texel units) receives a copy of
    that texel's value.

    Args:
        index_image: [H,W] or [H,W,C] index image.
        bary_image: optional image that is filled with the same
            source/destination pairs.
        distance_threshold: maximum fill distance.

    Returns:
        The filled index image, or `(index_image, bary_image)` when
        `bary_image` is given. Inputs are not modified.

    Raises:
        ValueError: if `index_image` is neither 2- nor 3-dimensional.
    """
    if index_image.ndim == 3:
        valid_index = (index_image != -1).any(dim=-1)
    elif index_image.ndim == 2:
        valid_index = index_image != -1
    else:
        raise ValueError("`index_image` should be a [H,W] or [H,W,C] image")

    invalid_index = ~valid_index

    index_image_imp = index_image.clone()
    bary_image_imp = bary_image.clone() if bary_image is not None else None

    # Nothing to fill (no holes) or nothing to fill from (no valid texels):
    # the KD-tree cannot be built or queried on an empty point set, so the
    # copies are returned unchanged in these cases.
    if bool(valid_index.any()) and bool(invalid_index.any()):
        device = index_image.device

        valid_ij = th.stack(th.where(valid_index), dim=-1)
        invalid_ij = th.stack(th.where(invalid_index), dim=-1)
        lookup_valid = KDTree(valid_ij.cpu().numpy())

        dists, idxs = lookup_valid.query(invalid_ij.cpu().numpy())

        # TODO: try average?
        idxs = th.as_tensor(idxs, device=device)[..., 0]
        dists = th.as_tensor(dists, device=device)[..., 0]

        dist_mask = dists < distance_threshold

        src_ij = valid_ij[idxs][dist_mask]
        dst_ij = invalid_ij[dist_mask]

        index_image_imp[dst_ij[:, 0], dst_ij[:, 1]] = index_image[
            src_ij[:, 0], src_ij[:, 1]
        ]
        if bary_image_imp is not None:
            bary_image_imp[dst_ij[:, 0], dst_ij[:, 1]] = bary_image[
                src_ij[:, 0], src_ij[:, 1]
            ]

    if bary_image is not None:
        return index_image_imp, bary_image_imp
    return index_image_imp
class GeometryModule(nn.Module):
    """Holds mesh topology buffers and precomputed UV-space lookup images.

    Registered buffers: `vi`, `vt`, `vti`, `v2uv` (int64) plus the images
    `index_image`, `bary_image` and `face_index_image` produced by
    `make_uv_vert_index` / `make_uv_barys` at construction time.
    """

    def __init__(
        self,
        vi,
        vt,
        vti,
        v2uv,
        uv_size,
        flip_uv=False,
        impaint=False,
        impaint_threshold=100.0,
    ):
        super().__init__()

        self.register_buffer("vi", th.as_tensor(vi))
        self.register_buffer("vt", th.as_tensor(vt))
        self.register_buffer("vti", th.as_tensor(vti))
        self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64))

        # TODO: should we just pass topology here?
        # Read the size from the buffer so that plain lists work as well.
        self.n_verts = self.v2uv.shape[0]

        self.uv_size = uv_size

        # TODO: can't we just index face_index?
        index_image = make_uv_vert_index(
            self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        ).cpu()
        face_index, bary_image = make_uv_barys(
            self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )
        if impaint:
            # `uv_size` may be an int or an (h, w) pair.
            largest = max(uv_size) if isinstance(uv_size, (tuple, list)) else uv_size
            if largest >= 1024:
                logger.info(
                    "impainting index image might take a while for sizes >= 1024"
                )

            index_image, bary_image = index_image_impaint(
                index_image, bary_image, impaint_threshold
            )
            # TODO: we can avoid doing this 2x
            face_index = index_image_impaint(
                face_index, distance_threshold=impaint_threshold
            )

        self.register_buffer("index_image", index_image.cpu())
        self.register_buffer("bary_image", bary_image.cpu())
        self.register_buffer("face_index_image", face_index.cpu())

    def render_index_images(self, uv_size, flip_uv=False, impaint=False):
        """Recompute the lookup images at another resolution.

        Returns `(index_image, face_image, bary_image)`. With `impaint=True`
        only the index and barycentric images are filled (using the default
        threshold of `index_image_impaint`); the face image is returned as
        rasterized.
        """
        index_image = make_uv_vert_index(
            self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )
        face_image, bary_image = make_uv_barys(
            self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )

        if impaint:
            index_image, bary_image = index_image_impaint(
                index_image,
                bary_image,
            )

        return index_image, face_image, bary_image

    def vn(self, verts):
        """Per-vertex normals of `verts` for this topology."""
        # `vert_normals` indexes the vertices with an [F, 3] face tensor; the
        # previous `self.vi[np.newaxis]` added a leading axis that made the
        # face/corner axes line up incorrectly inside `face_normals`.
        return vert_normals(verts, self.vi.to(th.long))

    def to_uv(self, values):
        """Resample per-vertex `values` into UV space."""
        return values_to_uv(values, self.index_image, self.bary_image)

    def from_uv(self, values_uv):
        """Sample a UV-space image back to per-vertex values."""
        # TODO: we need to sample this
        return sample_uv(values_uv, self.vt, self.v2uv.to(th.long))
def sample_uv(
    values_uv,
    uv_coords,
    v2uv: Optional[th.Tensor] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
    flip_uvs: bool = False,
):
    """Sample a UV-space image at the given UV coordinates.

    Args:
        values_uv: [B, C, H, W] image.
        uv_coords: [N, 2] coordinates in [0, 1]; they are mapped to [-1, 1]
            for `F.grid_sample`.
        v2uv: optional index tensor; when given, the samples are gathered with
            it along the point axis and averaged over its last axis.
        mode, align_corners: forwarded to `F.grid_sample`.
        flip_uvs: mirror the second coordinate (`1 - v`) before sampling.

    Returns:
        [B, N, C] samples, or the per-row average when `v2uv` is given.
    """
    if flip_uvs:
        # Work on a copy so the caller's coordinates stay untouched.
        uv_coords = uv_coords.clone()
        uv_coords[:, 1] = 1.0 - uv_coords[:, 1]

    n_batch = values_uv.shape[0]
    # [N, 2] -> [B, N, 1, 2] sampling grid in normalized coordinates.
    grid = uv_coords * 2.0 - 1.0
    grid = grid[None, :, None].expand(n_batch, -1, -1, -1)

    sampled = F.grid_sample(values_uv, grid, align_corners=align_corners, mode=mode)
    # [B, C, N, 1] -> [B, N, C]
    values = sampled.squeeze(-1).permute((0, 2, 1))

    if v2uv is None:
        return values
    return values[:, v2uv].mean(2)
def values_to_uv(values, index_img, bary_img):
    """Resample per-vertex values into UV space.

    Args:
        values: [B, n_verts, C] per-vertex values.
        index_img: [H, W, 3] vertex indices per texel, -1 for empty texels.
        bary_img: [H, W, 3] barycentric weights per texel.

    Returns:
        [B, C, H, W] image; texels with any index equal to -1 stay zero.
    """
    # Use both image dimensions so non-square maps work as well.
    height, width = index_img.shape[:2]
    index_mask = th.all(index_img != -1, dim=-1)
    idxs_flat = index_img[index_mask].to(th.int64)
    bary_flat = bary_img[index_mask].to(th.float32)
    # [B, M, 3, C] -> [B, C, M, 3], then blend the three corners.
    values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1)
    values_uv = th.zeros(
        values.shape[0],
        values.shape[-1],
        height,
        width,
        dtype=values.dtype,
        device=values.device,
    )
    values_uv[:, :, index_mask] = values_flat
    return values_uv


def face_normals(v, vi, eps: float = 1e-5):
    """Unit normals of every face.

    Args:
        v: [B, n_verts, 3] vertex positions.
        vi: [F, 3] face vertex indices (extra leading axes are flattened).
        eps: faces whose normal is shorter than this are left unnormalized.

    Returns:
        [B, F, 3] face normals.
    """
    faces = vi.reshape(-1, vi.shape[-1])
    pts = v[:, faces]
    v0 = pts[:, :, 1] - pts[:, :, 0]
    v1 = pts[:, :, 2] - pts[:, :, 0]
    n = th.cross(v0, v1, dim=-1)
    norm = th.norm(n, dim=-1, keepdim=True)
    # Out-of-place on purpose: editing `norm`/`n` in place modifies tensors
    # that the norm computation needs for its gradient.
    norm = th.where(norm < eps, th.ones_like(norm), norm)
    return n / norm


def vert_normals(v, vi, eps: float = 1.0e-5):
    """Unit vertex normals, accumulated from the adjacent face normals.

    Args:
        v: [B, n_verts, 3] vertex positions.
        vi: [F, 3] int64 face vertex indices (extra leading axes are
            flattened, so [1, F, 3] is accepted too).
        eps: vertices whose accumulated normal is shorter than this are left
            unnormalized.

    Returns:
        [B, n_verts, 3] vertex normals.
    """
    faces = vi.reshape(-1, vi.shape[-1])
    fnorms = face_normals(v, faces)
    # Repeat every face normal once per corner: [B, F, 3] -> [B, F * 3, 3].
    fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3)
    # Scatter each copy onto the vertex of the corresponding corner.
    scatter_idx = faces.reshape(1, -1, 1).expand(v.shape[0], -1, 3)
    vnorms = th.zeros_like(v).scatter_add(1, scatter_idx, fnorms)
    norm = th.norm(vnorms, dim=-1, keepdim=True)
    norm = th.where(norm < eps, th.ones_like(norm), norm)
    return vnorms / norm


def compute_view_cos(verts, faces, camera_pos):
    """Dot product between vertex normals and the unit camera-to-vertex ray.

    Args:
        verts: [B, n_verts, 3] vertex positions.
        faces: face vertex indices as accepted by `vert_normals`.
        camera_pos: [B, 3] camera positions.

    Returns:
        [B, n_verts] cosines.
    """
    vn = F.normalize(vert_normals(verts, faces), dim=-1)
    v2c = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1)
    return th.einsum("bnd,bnd->bn", vn, v2c)


def compute_tbn(geom, vt, vi, vti):
    """Computes tangent, bitangent, and normal vectors given a mesh.
    Args:
        geom: [N, n_verts, 3] th.Tensor
        Vertex positions.
        vt: [n_uv_coords, 2] th.Tensor
        UV coordinates.
        vi: [..., 3] th.Tensor
        Face vertex indices.
        vti: [..., 3] th.Tensor
        Face UV indices.
    Returns:
        [..., 3] th.Tensors for T, B, N.
    """

    v0 = geom[:, vi[..., 0]]
    v1 = geom[:, vi[..., 1]]
    v2 = geom[:, vi[..., 2]]
    vt0 = vt[vti[..., 0]]
    vt1 = vt[vti[..., 1]]
    vt2 = vt[vti[..., 2]]

    # Triangle edges in 3D and in UV space.
    v01 = v1 - v0
    v02 = v2 - v0
    vt01 = vt1 - vt0
    vt02 = vt2 - vt0

    du1, dv1 = vt01[None, ..., 0], vt01[None, ..., 1]
    du2, dv2 = vt02[None, ..., 0], vt02[None, ..., 1]
    # NOTE: triangles with zero UV area make this a division by zero.
    f = 1.0 / (du1 * dv2 - dv1 * du2)
    tangent = f[..., None] * (v01 * dv2[..., None] - v02 * dv1[..., None])
    tangent = F.normalize(tangent, dim=-1)
    # dim=-1 (instead of the former dim=3) works for index tensors of any
    # rank, e.g. [F, 3] face lists as well as [H, W, 3] index images.
    normal = F.normalize(th.cross(v01, v02, dim=-1), dim=-1)
    bitangent = F.normalize(th.cross(tangent, normal, dim=-1), dim=-1)

    return tangent, bitangent, normal
def compute_v2uv(n_verts, vi, vti, n_max=4):
    """Computes mapping from vertex indices to texture indices.

    Args:
        n_verts: int, number of mesh vertices
        vi: [F, 3], triangles
        vti: [F, 3], texture triangles
        n_max: int, max number of texture locations

    Returns:
        [n_verts, n_max] int32 array of texture indices. Rows with fewer than
        `n_max` distinct texture indices are padded with the smallest one.

    Raises:
        ValueError: if the faces do not reference exactly the vertices
            `0 .. n_verts - 1`, or a vertex has more than `n_max` texture
            indices.
    """
    # Plain Python ints hash by value, so lists, numpy arrays and CPU tensors
    # can all be used as input.
    vert_ids = np.asarray(vi).reshape(-1).tolist()
    uv_ids = np.asarray(vti).reshape(-1).tolist()

    v2uv_dict = {}
    for i_v, i_uv in zip(vert_ids, uv_ids):
        v2uv_dict.setdefault(i_v, set()).add(i_uv)

    if len(v2uv_dict) != n_verts or any(i not in v2uv_dict for i in range(n_verts)):
        raise ValueError(
            f"faces reference {len(v2uv_dict)} distinct vertices, "
            f"expected exactly the vertices 0..{n_verts - 1}"
        )

    v2uv = np.zeros((n_verts, n_max), dtype=np.int32)
    for i in range(n_verts):
        vals = sorted(v2uv_dict[i])
        if len(vals) > n_max:
            raise ValueError(
                f"vertex {i} has {len(vals)} texture indices, but n_max={n_max}"
            )
        # Pad with the first index, then write the actual ones in front.
        v2uv[i, :] = vals[0]
        v2uv[i, : len(vals)] = vals
    return v2uv


def compute_neighbours(n_verts, vi, n_max_values=10):
    """Computes first-ring neighbours given vertices and faces.

    Args:
        n_verts: int, number of mesh vertices
        vi: [F, 3], triangles
        n_max_values: int, max number of neighbours kept per vertex

    Returns:
        tuple:
        - [n_verts, n_max_values] neighbour indices; unused slots hold the
          vertex's own index
        - [n_verts, n_max_values] float32 weights, `-1 / n` for each of the
          `n` kept neighbours and 0 for unused slots
    """
    faces = np.asarray(vi)

    adj = {i: set() for i in range(n_verts)}
    for face in faces.tolist():
        members = set(face)
        for idx in face:
            adj[idx] |= members - {idx}

    nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values))
    nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32)

    for idx in range(n_verts):
        n_values = min(len(adj[idx]), n_max_values)
        if n_values == 0:
            # Isolated vertex: keep the self index / zero weight defaults
            # instead of dividing by zero below.
            continue
        nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values]
        nbs_weights[idx, :n_values] = -1.0 / n_values

    return nbs_idxs, nbs_weights
def make_postex(v, idxim, barim):
    """Build a UV-space position texture from vertex positions.

    Args:
        v: [B, n_verts, C] per-vertex values.
        idxim: [H, W, 3] vertex indices per texel.
        barim: [H, W, 3] barycentric weights per texel.

    Returns:
        [B, C, H, W] barycentric blend of the three corner values.
    """
    blended = sum(
        barim[None, :, :, k, None] * v[:, idxim[:, :, k]] for k in range(3)
    )
    return blended.permute(0, 3, 1, 2)


def matrix_to_axisangle(r):
    """Convert rotation matrices to angle and unit axis.

    Args:
        r: [..., 3, 3] rotation matrices.

    Returns:
        tuple:
        - [..., 1] rotation angle in radians
        - [..., 3] rotation axis (zero vector for a zero rotation)
    """
    # NOTE: the angle must not be called `th` -- that would shadow the torch
    # module for the whole function and raise UnboundLocalError.
    cos_theta = 0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0)
    # Rounding can push the value slightly outside of arccos' domain.
    theta = th.arccos(cos_theta.clamp(-1.0, 1.0))[..., None]
    skew = th.stack(
        [
            r[..., 2, 1] - r[..., 1, 2],
            r[..., 0, 2] - r[..., 2, 0],
            r[..., 1, 0] - r[..., 0, 1],
        ],
        dim=-1,
    )
    # theta is in [0, pi], so sin(theta) >= 0. The lower bound avoids 0 / 0 at
    # theta == 0; near theta == pi the axis is ill-conditioned in any case.
    vec = 0.5 * skew / th.sin(theta).clamp(min=1e-6)
    return theta, vec


def axisangle_to_matrix(rvec):
    """Convert axis-angle vectors to rotation matrices (Rodrigues' formula).

    Args:
        rvec: [..., 3] rotation axis scaled by the rotation angle.

    Returns:
        [..., 3, 3] rotation matrices. A small constant (1e-5) is added under
        the square root of the angle, so results are accurate to roughly
        that order only.
    """
    theta = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1))
    x, y, z = (rvec / theta[..., None]).unbind(dim=-1)
    costh = th.cos(theta)
    sinth = th.sin(theta)
    omc = 1.0 - costh
    rows = (
        th.stack(
            (
                x**2 + (1.0 - x**2) * costh,
                x * y * omc - z * sinth,
                x * z * omc + y * sinth,
            ),
            dim=-1,
        ),
        th.stack(
            (
                x * y * omc + z * sinth,
                y**2 + (1.0 - y**2) * costh,
                y * z * omc - x * sinth,
            ),
            dim=-1,
        ),
        th.stack(
            (
                x * z * omc - y * sinth,
                y * z * omc + x * sinth,
                z**2 + (1.0 - z**2) * costh,
            ),
            dim=-1,
        ),
    )
    return th.stack(rows, dim=-2)


def rotation_interp(r0, r1, alpha):
    """Interpolate between two rotations along the connecting rotation.

    Args:
        r0: [..., 3, 3] start rotations.
        r1: [..., 3, 3] end rotations.
        alpha: interpolation factor, 0 gives `r0` and 1 gives `r1` (up to the
            accuracy of `axisangle_to_matrix`).

    Returns:
        Rotation matrices with the shape of `r0`.
    """
    r0a = r0.view(-1, 3, 3)
    r1a = r1.view(-1, 3, 3)
    # Relative rotation taking r0 to r1.
    r = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0)

    theta, rvec = matrix_to_axisangle(r)
    rvec = rvec * (alpha * theta)

    r = axisangle_to_matrix(rvec)
    return th.bmm(r0a, r.view(-1, 3, 3)).view_as(r0)


def convert_camera_parameters(Rt, K):
    """Split extrinsics / intrinsics into the individual camera parameters.

    Args:
        Rt: [B, 3+, 4] extrinsics, rotation in `[:, :3, :3]` and translation
            in `[:, :3, 3]`.
        K: [B, 3, 3] intrinsics.

    Returns:
        dict with `campos` (`-R^T t`), `camrot` (`R`), `focal` (`K[:, :2, :2]`)
        and `princpt` (`K[:, :2, 2]`).
    """
    R = Rt[:, :3, :3]
    t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2)
    return dict(
        campos=t,
        camrot=R,
        focal=K[:, :2, :2],
        princpt=K[:, :2, 2],
    )
focal=K[:, :2, :2], + princpt=K[:, :2, 2], + ) + + +def project_points_multi(p, Rt, K, normalize=False, size=None): + """Project a set of 3D points into multiple cameras with a pinhole model. + Args: + p: [B, N, 3], input 3D points in world coordinates + Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to) + K: [B, NC, 3, 3], intrinsics + normalize: bool, whether to normalize coordinates to [-1.0, 1.0] + Returns: + tuple: + - [B, NC, N, 2] - projected points + - [B, NC, N] - their + """ + B, N = p.shape[:2] + NC = Rt.shape[1] + + Rt = Rt.reshape(B * NC, 3, 4) + K = K.reshape(B * NC, 3, 3) + + # [B, N, 3] -> [B * NC, N, 3] + p = p[:, np.newaxis].expand(-1, NC, -1, -1).reshape(B * NC, -1, 3) + p_cam = p @ Rt[:, :3, :3].mT + Rt[:, :3, 3][:, np.newaxis] + p_pix = p_cam @ K.mT + p_depth = p_pix[:, :, 2:] + p_pix = (p_pix[..., :2] / p_depth).reshape(B, NC, N, 2) + p_depth = p_depth.reshape(B, NC, N) + + if normalize: + assert size is not None + h, w = size + p_pix = ( + 2.0 * p_pix / th.as_tensor([w, h], dtype=th.float32, device=p.device) - 1.0 + ) + return p_pix, p_depth + +def xyz2normals(xyz: th.Tensor, eps: float = 1e-8) -> th.Tensor: + """Convert XYZ image to normal image + + Args: + xyz: th.Tensor + [B, 3, H, W] XYZ image + + Returns: + th.Tensor: [B, 3, H, W] image of normals + """ + + nrml = th.zeros_like(xyz) + xyz = th.cat((xyz[:, :, :1, :] * 0, xyz[:, :, :, :], xyz[:, :, :1, :] * 0), dim=2) + xyz = th.cat((xyz[:, :, :, :1] * 0, xyz[:, :, :, :], xyz[:, :, :, :1] * 0), dim=3) + U = (xyz[:, :, 2:, 1:-1] - xyz[:, :, :-2, 1:-1]) / -2 + V = (xyz[:, :, 1:-1, 2:] - xyz[:, :, 1:-1, :-2]) / -2 + + nrml[:, 0, ...] = U[:, 1, ...] * V[:, 2, ...] - U[:, 2, ...] * V[:, 1, ...] + nrml[:, 1, ...] = U[:, 2, ...] * V[:, 0, ...] - U[:, 0, ...] * V[:, 2, ...] + nrml[:, 2, ...] = U[:, 0, ...] * V[:, 1, ...] - U[:, 1, ...] * V[:, 0, ...] 
+ veclen = th.norm(nrml, dim=1, keepdim=True).clamp(min=eps) + return nrml / veclen + + +# pyre-fixme[2]: Parameter must be annotated. +def depth2xyz(depth, focal, princpt) -> th.Tensor: + """Convert depth image to XYZ image using camera intrinsics + + Args: + depth: th.Tensor + [B, 1, H, W] depth image + + focal: th.Tensor + [B, 2, 2] camera focal lengths + + princpt: th.Tensor + [B, 2] camera principal points + + Returns: + th.Tensor: [B, 3, H, W] XYZ image + """ + + b, h, w = depth.shape[0], depth.shape[2], depth.shape[3] + ix = ( + th.arange(w, device=depth.device).float()[None, None, :] - princpt[:, None, None, 0] + ) / focal[:, None, None, 0, 0] + iy = ( + th.arange(h, device=depth.device).float()[None, :, None] - princpt[:, None, None, 1] + ) / focal[:, None, None, 1, 1] + xyz = th.zeros((b, 3, h, w), device=depth.device) + xyz[:, 0, ...] = depth[:, 0, :, :] * ix + xyz[:, 1, ...] = depth[:, 0, :, :] * iy + xyz[:, 2, ...] = depth[:, 0, :, :] + return xyz + + +# pyre-fixme[2]: Parameter must be annotated. +def depth2normals(depth, focal, princpt) -> th.Tensor: + """Convert depth image to normal image using camera intrinsics + + Args: + depth: th.Tensor + [B, 1, H, W] depth image + + focal: th.Tensor + [B, 2, 2] camera focal lengths + + princpt: th.Tensor + [B, 2] camera principal points + + Returns: + th.Tensor: [B, 3, H, W] normal image + """ + + return xyz2normals(depth2xyz(depth, focal, princpt)) + + +def depth_discontuity_mask( + depth: th.Tensor, threshold: float = 40.0, kscale: float = 4.0, pool_ksize: int = 3 +) -> th.Tensor: + device = depth.device + + with th.no_grad(): + # TODO: pass the kernel? 
+ kernel = th.as_tensor( + [ + [[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]], + [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]], + ], + dtype=th.float32, + device=device, + ) + + disc_mask = (th.norm(F.conv2d(depth, kernel, bias=None, padding=1), dim=1) > threshold)[ + :, np.newaxis + ] + disc_mask = ( + F.avg_pool2d(disc_mask.float(), pool_ksize, stride=1, padding=pool_ksize // 2) > 0.0 + ) + + return disc_mask diff --git a/visualize/ca_body/utils/geom_body.py b/visualize/ca_body/utils/geom_body.py new file mode 100644 index 0000000000000000000000000000000000000000..f6c36109e7f12f4900d9effd86d38e52f9c14d90 --- /dev/null +++ b/visualize/ca_body/utils/geom_body.py @@ -0,0 +1,702 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging +from logging import Logger + +from typing import Any, Dict, Optional, Tuple, Union + +import igl + +import numpy as np +import torch as th + +import torch.nn as nn + +import torch.nn.functional as F + +from visualize.ca_body.utils.geom import ( + index_image_impaint, + make_uv_barys, + make_uv_vert_index, +) + +from trimesh import Trimesh +from trimesh.triangles import points_to_barycentric + +logger: Logger = logging.getLogger(__name__) + + +def face_normals_v2(v: th.Tensor, vi: th.Tensor, eps: float = 1e-5) -> th.Tensor: + pts = v[:, vi] + v0 = pts[:, :, 1] - pts[:, :, 0] + v1 = pts[:, :, 2] - pts[:, :, 0] + n = th.cross(v0, v1, dim=-1) + norm = th.norm(n, dim=-1, keepdim=True) + norm[norm < eps] = 1 + n /= norm + return n + + +def vert_normals_v2(v: th.Tensor, vi: th.Tensor, eps: float = 1.0e-5) -> th.Tensor: + fnorms = face_normals_v2(v, vi) + fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3) + vi_flat = vi.view(1, -1).expand(v.shape[0], -1) + vnorms = th.zeros_like(v) + for j in range(3): + vnorms[..., j].scatter_add_(1, vi_flat, fnorms[..., j]) + 
norm = th.norm(vnorms, dim=-1, keepdim=True) + norm[norm < eps] = 1 + vnorms /= norm + return vnorms + + +def compute_neighbours( + n_verts: int, vi: th.Tensor, n_max_values: int = 10 +) -> Tuple[th.Tensor, th.Tensor]: + """Computes first-ring neighbours given vertices and faces.""" + n_vi = vi.shape[0] + + adj = {i: set() for i in range(n_verts)} + for i in range(n_vi): + for idx in vi[i]: + adj[idx] |= set(vi[i]) - {idx} + + nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values)) + nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32) + + for idx in range(n_verts): + n_values = min(len(adj[idx]), n_max_values) + nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values] + nbs_weights[idx, :n_values] = -1.0 / n_values + + return nbs_idxs, nbs_weights + + +def compute_v2uv(n_verts: int, vi: th.Tensor, vti: th.Tensor, n_max: int = 4) -> th.Tensor: + """Computes mapping from vertex indices to texture indices. + + Args: + vi: [F, 3], triangles + vti: [F, 3], texture triangles + n_max: int, max number of texture locations + + Returns: + [n_verts, n_max], texture indices + """ + v2uv_dict = {} + for i_v, i_uv in zip(vi.reshape(-1), vti.reshape(-1)): + v2uv_dict.setdefault(i_v, set()).add(i_uv) + assert len(v2uv_dict) == n_verts + v2uv = np.zeros((n_verts, n_max), dtype=np.int32) + for i in range(n_verts): + vals = sorted(v2uv_dict[i]) + v2uv[i, :] = vals[0] + v2uv[i, : len(vals)] = np.array(vals) + return v2uv + + +def values_to_uv(values: th.Tensor, index_img: th.Tensor, bary_img: th.Tensor) -> th.Tensor: + uv_size = index_img.shape[0] + index_mask = th.all(index_img != -1, dim=-1) + idxs_flat = index_img[index_mask].to(th.int64) + bary_flat = bary_img[index_mask].to(th.float32) + # NOTE: here we assume + values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1) + values_uv = th.zeros( + values.shape[0], + values.shape[-1], + uv_size, + uv_size, + dtype=values.dtype, + device=values.device, + ) + values_uv[:, :, 
index_mask] = values_flat + return values_uv + + +def sample_uv( + values_uv: th.Tensor, + uv_coords: th.Tensor, + v2uv: Optional[th.Tensor] = None, + mode: str = "bilinear", + align_corners: bool = False, + flip_uvs: bool = False, +) -> th.Tensor: + batch_size = values_uv.shape[0] + + if flip_uvs: + uv_coords = uv_coords.clone() + uv_coords[:, 1] = 1.0 - uv_coords[:, 1] + + uv_coords_norm = (uv_coords * 2.0 - 1.0)[np.newaxis, :, np.newaxis].expand( + batch_size, -1, -1, -1 + ) + values = ( + F.grid_sample(values_uv, uv_coords_norm, align_corners=align_corners, mode=mode) + .squeeze(-1) + .permute((0, 2, 1)) + ) + + if v2uv is not None: + values_duplicate = values[:, v2uv] + values = values_duplicate.mean(2) + + # if return_var: + # values_var = values_duplicate.var(2) + # return values, values_var + + return values + + +def compute_tbn_uv( + tri_xyz: th.Tensor, tri_uv: th.Tensor, eps: float = 1e-5 +) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """Compute tangents, bitangents, normals. 
+ + Args: + tri_xyz: [B,N,3,3] vertex coordinates + tri_uv: [N,2] texture coordinates + + Returns: + tangents, bitangents, normals + """ + + tri_uv = tri_uv[np.newaxis] + + v01 = tri_xyz[:, :, 1] - tri_xyz[:, :, 0] + v02 = tri_xyz[:, :, 2] - tri_xyz[:, :, 0] + + normals = th.cross(v01, v02, dim=-1) + normals = normals / th.norm(normals, dim=-1, keepdim=True).clamp(min=eps) + + vt01 = tri_uv[:, :, 1] - tri_uv[:, :, 0] + vt02 = tri_uv[:, :, 2] - tri_uv[:, :, 0] + + f = th.tensor([1.0], device=tri_xyz.device) / ( + vt01[..., 0] * vt02[..., 1] - vt01[..., 1] * vt02[..., 0] + ) + + tangents = f[..., np.newaxis] * ( + v01 * vt02[..., 1][..., np.newaxis] - v02 * vt01[..., 1][..., np.newaxis] + ) + tangents = tangents / th.norm(tangents, dim=-1, keepdim=True).clamp(min=eps) + + bitangents = th.cross(normals, tangents, dim=-1) + bitangents = bitangents / th.norm(bitangents, dim=-1, keepdim=True).clamp(min=eps).clamp( + min=eps + ) + return tangents, bitangents, normals + + +class GeometryModule(nn.Module): + """This module encapsulates uv correspondences and vertex images.""" + + def __init__( + self, + vi: th.Tensor, + vt: th.Tensor, + vti: th.Tensor, + v2uv: th.Tensor, + uv_size: int, + flip_uv: bool = False, + impaint: bool = False, + impaint_threshold: float = 100.0, + device=None, + ) -> None: + super().__init__() + + self.register_buffer("vi", th.as_tensor(vi)) + self.register_buffer("vt", th.as_tensor(vt)) + self.register_buffer("vti", th.as_tensor(vti)) + self.register_buffer("v2uv", th.as_tensor(v2uv)) + + self.uv_size: int = uv_size + + index_image = make_uv_vert_index( + self.vt, + self.vi, + self.vti, + uv_shape=uv_size, + flip_uv=flip_uv, + ).cpu() + face_index, bary_image = make_uv_barys(self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv) + if impaint: + # TODO: have an option to pre-compute this? 
+ assert isinstance(uv_size, int) + if uv_size >= 1024: + logger.info("impainting index image might take a while for sizes >= 1024") + + index_image, bary_image = index_image_impaint( + index_image, bary_image, impaint_threshold + ) + + self.register_buffer("index_image", index_image.cpu()) + self.register_buffer("bary_image", bary_image.cpu()) + self.register_buffer("face_index_image", face_index.cpu()) + + def render_index_images( + self, uv_size: Union[Tuple[int, int], int], flip_uv: bool = False, impaint: bool = False + ) -> Tuple[th.Tensor, th.Tensor]: + index_image = make_uv_vert_index( + self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv + ) + _, bary_image = make_uv_barys(self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv) + + if impaint: + index_image, bary_image = index_image_impaint( + index_image, + bary_image, + ) + + return index_image, bary_image + + def vn(self, verts: th.Tensor) -> th.Tensor: + return vert_normals_v2(verts, self.vi[np.newaxis].to(th.long)) + + def to_uv(self, values: th.Tensor) -> th.Tensor: + return values_to_uv(values, self.index_image, self.bary_image) + + def from_uv(self, values_uv: th.Tensor) -> th.Tensor: + # TODO: we need to sample this + return sample_uv(values_uv, self.vt, self.v2uv.to(th.long)) + + +def compute_view_cos(verts: th.Tensor, faces: th.Tensor, camera_pos: th.Tensor) -> th.Tensor: + vn = F.normalize(vert_normals_v2(verts, faces), dim=-1) + v2c = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1) + return th.einsum("bnd,bnd->bn", vn, v2c) + + +def interpolate_values_mesh( + src_values: th.Tensor, src_faces: th.Tensor, idxs: th.Tensor, weights: th.Tensor +) -> th.Tensor: + """Interpolate values on the mesh.""" + assert src_faces.dtype == th.long, "index should be torch.long" + assert len(src_values.shape) in [2, 3], "supporting [N, F] and [B, N, F] only" + + if src_values.shape == 2: + return (src_values[src_faces[idxs]] * weights[..., np.newaxis]).sum(dim=1) + else: # src.verts.shape == 3: + 
return (src_values[:, src_faces[idxs]] * weights[np.newaxis, ..., np.newaxis]).sum(dim=2) + + +def depth_discontuity_mask( + depth: th.Tensor, threshold: float = 40.0, kscale: float = 4.0, pool_ksize: int = 3 +) -> th.Tensor: + device = depth.device + + with th.no_grad(): + # TODO: pass the kernel? + kernel = th.as_tensor( + [ + [[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]], + [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]], + ], + dtype=th.float32, + device=device, + ) + + disc_mask = (th.norm(F.conv2d(depth, kernel, bias=None, padding=1), dim=1) > threshold)[ + :, np.newaxis + ] + disc_mask = ( + F.avg_pool2d(disc_mask.float(), pool_ksize, stride=1, padding=pool_ksize // 2) > 0.0 + ) + + return disc_mask + + +def convert_camera_parameters(Rt: th.Tensor, K: th.Tensor) -> Dict[str, th.Tensor]: + R = Rt[:, :3, :3] + t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2) + return { + "campos": t, + "camrot": R, + "focal": K[:, :2, :2], + "princpt": K[:, :2, 2], + } + + +def closest_point(mesh: Trimesh, points: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + v = mesh.vertices + vi = mesh.faces + # pyre-ignore + dist, face_idxs, p = igl.point_mesh_squared_distance(points, v, vi) + return p, dist, face_idxs + + +def closest_point_barycentrics( + v: np.ndarray, vi: np.ndarray, points: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Given a 3D mesh and a set of query points, return closest point barycentrics + Args: + v: np.array (float) + [N, 3] mesh vertices + vi: np.array (int) + [N, 3] mesh triangle indices + points: np.array (float) + [M, 3] query points + Returns: + Tuple[approx, barys, interp_idxs, face_idxs] + approx: [M, 3] approximated (closest) points on the mesh + barys: [M, 3] barycentric weights that produce "approx" + interp_idxs: [M, 3] vertex indices for barycentric interpolation + face_idxs: [M] face indices for barycentric interpolation. 
interp_idxs = vi[face_idxs] + """ + mesh = Trimesh(vertices=v, faces=vi) + p, _, face_idxs = closest_point(mesh, points) + barys = points_to_barycentric(mesh.triangles[face_idxs], p) + b0, b1, b2 = np.split(barys, 3, axis=1) + + interp_idxs = vi[face_idxs] + v0 = v[interp_idxs[:, 0]] + v1 = v[interp_idxs[:, 1]] + v2 = v[interp_idxs[:, 2]] + approx = b0 * v0 + b1 * v1 + b2 * v2 + return approx, barys, interp_idxs, face_idxs + + +def make_closest_uv_barys( + vt: np.ndarray, + vti: np.ndarray, + uv_shape: Union[Tuple[int, int], int], + flip_uv: bool = True, + return_approx_dist: bool = False, +) -> Union[Tuple[th.Tensor, th.Tensor], Tuple[th.Tensor, th.Tensor, th.Tensor]]: + """Compute a UV-space barycentric map where each texel contains barycentric + coordinates for the closest point on a UV triangle. + Args: + vt: th.Tensor + Texture coordinates. Shape = [n_texcoords, 2] + vti: th.Tensor + Face texture coordinate indices. Shape = [n_faces, 3] + uv_shape: Tuple[int, int] or int + Shape of the texture map. (HxW) + flip_uv: bool + Whether or not to flip UV coordinates along the V axis (OpenGL -> numpy/pytorch convention). + return_approx_dist: bool + Whether or not to include the distance to the nearest point. + Returns: + th.Tensor: index_img: Face index image, shape [uv_shape[0], uv_shape[1]] + th.Tensor: Barycentric coordinate map, shape [uv_shape[0], uv_shape[1], 3] + """ + + if isinstance(uv_shape, int): + uv_shape = (uv_shape, uv_shape) + + if flip_uv: + # Flip here because texture coordinates in some of our topo files are + # stored in OpenGL convention with Y=0 on the bottom of the texture + # unlike numpy/torch arrays/tensors. + vt = vt.clone() + vt[:, 1] = 1 - vt[:, 1] + + # Texel to UV mapping (as per OpenGL linear filtering) + # https://www.khronos.org/registry/OpenGL/specs/gl/glspec46.core.pdf + # Sect. 
8.14, page 261 + # uv=(0.5,0.5)/w is at the center of texel [0,0] + # uv=(w-0.5, w-0.5)/w is the center of texel [w-1,w-1] + # texel = floor(u*w - 0.5) + # u = (texel+0.5)/w + uv_grid = th.meshgrid( + th.linspace(0.5, uv_shape[0] - 1 + 0.5, uv_shape[0]) / uv_shape[0], + th.linspace(0.5, uv_shape[1] - 1 + 0.5, uv_shape[1]) / uv_shape[1], + ) # HxW, v,u + uv_grid = th.stack(uv_grid[::-1], dim=2) # HxW, u, v + + uv = uv_grid.reshape(-1, 2).data.to("cpu").numpy() + vth = np.hstack((vt, vt[:, 0:1] * 0 + 1)) + uvh = np.hstack((uv, uv[:, 0:1] * 0 + 1)) + approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(vth, vti, uvh) + index_img = th.from_numpy(face_idxs.reshape(uv_shape[0], uv_shape[1])).long() + bary_img = th.from_numpy(barys.reshape(uv_shape[0], uv_shape[1], 3)).float() + + if return_approx_dist: + dist = np.linalg.norm(approx - uvh, axis=1) + dist = th.from_numpy(dist.reshape(uv_shape[0], uv_shape[1])).float() + return index_img, bary_img, dist + else: + return index_img, bary_img + + +def compute_tbn( + geom: th.Tensor, vt: th.Tensor, vi: th.Tensor, vti: th.Tensor +) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """Computes tangent, bitangent, and normal vectors given a mesh. + Args: + geom: [N, n_verts, 3] th.Tensor + Vertex positions. + vt: [n_uv_coords, 2] th.Tensor + UV coordinates. + vi: [..., 3] th.Tensor + Face vertex indices. + vti: [..., 3] th.Tensor + Face UV indices. + Returns: + [..., 3] th.Tensors for T, B, N. 
+ """ + + v0 = geom[:, vi[..., 0]] + v1 = geom[:, vi[..., 1]] + v2 = geom[:, vi[..., 2]] + vt0 = vt[vti[..., 0]] + vt1 = vt[vti[..., 1]] + vt2 = vt[vti[..., 2]] + + v01 = v1 - v0 + v02 = v2 - v0 + vt01 = vt1 - vt0 + vt02 = vt2 - vt0 + f = th.tensor([1.0], device=geom.device) / ( + vt01[None, ..., 0] * vt02[None, ..., 1] - vt01[None, ..., 1] * vt02[None, ..., 0] + ) + tangent = f[..., None] * th.stack( + [ + v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1], + v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1], + v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1], + ], + dim=-1, + ) + tangent = F.normalize(tangent, dim=-1) + normal = F.normalize(th.cross(v01, v02, dim=3), dim=-1) + bitangent = F.normalize(th.cross(tangent, normal, dim=3), dim=-1) + + return tangent, bitangent, normal + + +def make_postex(v: th.Tensor, idxim: th.Tensor, barim: th.Tensor) -> th.Tensor: + return ( + barim[None, :, :, 0, None] * v[:, idxim[:, :, 0]] + + barim[None, :, :, 1, None] * v[:, idxim[:, :, 1]] + + barim[None, :, :, 2, None] * v[:, idxim[:, :, 2]] + ).permute( + 0, 3, 1, 2 + ) # B x 3 x H x W + + +def acos_safe_th(x: th.Tensor, eps: float = 1e-4) -> th.Tensor: + slope = th.arccos(th.as_tensor(1 - eps)) / th.as_tensor(eps) + # TODO: stop doing this allocation once sparse gradients with NaNs (like in + # th.where) are handled differently. 
+ buf = th.empty_like(x) + good = abs(x) <= 1 - eps + bad = ~good + sign = th.sign(x.data[bad]) + buf[good] = th.acos(x[good]) + buf[bad] = th.acos(sign * (1 - eps)) - slope * sign * (abs(x[bad]) - 1 + eps) + return buf + + +def invRodrigues(R: th.Tensor, eps: float = 1e-8) -> th.Tensor: + """Computes the Rodrigues vectors r from the rotation matrices `R`""" + + # t = trace(R) + # theta = rotational angle + # [omega]_x = (R-R^T)/2 + # r = theta/sin(theta)*omega + assert R.shape[-2:] == (3, 3) + + t = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] + theta = acos_safe_th((t - 1) / 2) + omega = ( + th.stack( + ( + R[..., 2, 1] - R[..., 1, 2], + R[..., 0, 2] - R[..., 2, 0], + R[..., 1, 0] - R[..., 0, 1], + ), + -1, + ) + / 2 + ) + + # Edge Case 1: t >= 3 - eps + inv_sinc = theta / th.sin(theta) + inv_sinc_taylor_expansion = ( + 1 + + (1.0 / 6.0) * th.pow(theta, 2) + + (7.0 / 360.0) * th.pow(theta, 4) + + (31.0 / 15120.0) * th.pow(theta, 6) + ) + + # Edge Case 2: t <= -1 + eps + # From: https://math.stackexchange.com/questions/83874/efficient-and-accurate-numerical + # -implementation-of-the-inverse-rodrigues-rotatio + a = th.diagonal(R, 0, -2, -1).argmax(dim=-1) + b = (a + 1) % 3 + c = (a + 2) % 3 + + s = th.sqrt(R[..., a, a] - R[..., b, b] - R[..., c, c] + 1 + 1e-4) + v = th.zeros_like(omega) + v[..., a] = s / 2 + v[..., b] = (R[..., b, a] + R[..., a, b]) / (2 * s) + v[..., c] = (R[..., c, a] + R[..., a, c]) / (2 * s) + norm = th.norm(v, dim=-1, keepdim=True).to(v.dtype).clamp(min=eps) + pi_vnorm = np.pi * (v / norm) + + # use taylor expansion when R is close to a identity matrix (trace(R) ~= 3) + r = th.where( + t[:, None] > (3 - 1e-3), + inv_sinc_taylor_expansion[..., None] * omega, + th.where(t[:, None] < -1 + 1e-3, pi_vnorm, inv_sinc[..., None] * omega), + ) + + return r + + +def EulerXYZ_to_matrix(xyz: th.Tensor) -> th.Tensor: + # R = Rz(φ)Ry(θ)Rx(ψ) = [ + # cos θ cos φ sin ψ sin θ cos φ − cos ψ sin φ cos ψ sin θ cos φ + sin ψ sin φ + # cos θ sin φ sin ψ sin θ sin φ 
+ cos ψ cos φ cos ψ sin θ sin φ − sin ψ cos φ + # − sin θ sin ψ cos θ cos ψ cos θ + # ] + ( + x, + y, + z, + ) = ( + xyz[..., 0:1], + xyz[..., 1:2], + xyz[..., 2:3], + ) + sinx, cosx = th.sin(x), th.cos(x) + siny, cosy = th.sin(y), th.cos(y) + sinz, cosz = th.sin(z), th.cos(z) + + r1 = th.cat( + ( + cosy * cosz, + sinx * siny * cosz + - cosx * sinz, # th.sin(x) * th.sin(y) * th.cos(z) - th.cos(x) * th.sin(z), + cosx * siny * cosz + + sinx * sinz, # th.cos(x) * th.sin(y) * th.cos(z) + th.sin(x) * th.sin(z) + ), + -1, + ) # [..., 3] + r3 = th.cat( + ( + -siny, # -th.sin(y), + sinx * cosy, # th.sin(x) * th.cos(y), + cosx * cosy, # th.cos(x) * th.cos(y) + ), + -1, + ) # [..., 3] + r2 = th.cross(r3, r1, dim=-1) + + R = th.cat((r1.unsqueeze(-2), r2.unsqueeze(-2), r3.unsqueeze(-2)), -2) + return R + + +def axisangle_to_matrix(rvec: th.Tensor) -> th.Tensor: + theta = th.sqrt(1e-5 + th.sum(th.pow(rvec, 2), dim=-1)) + rvec = rvec / theta[..., None] + costh = th.cos(theta) + sinth = th.sin(theta) + return th.stack( + ( + th.stack( + ( + th.pow(rvec[..., 0], 2) + (1.0 - th.pow(rvec[..., 0], 2)) * costh, + rvec[..., 0] * rvec[..., 1] * (1.0 - costh) - rvec[..., 2] * sinth, + rvec[..., 0] * rvec[..., 2] * (1.0 - costh) + rvec[..., 1] * sinth, + ), + dim=-1, + ), + th.stack( + ( + rvec[..., 0] * rvec[..., 1] * (1.0 - costh) + rvec[..., 2] * sinth, + th.pow(rvec[..., 1], 2) + (1.0 - th.pow(rvec[..., 1], 2)) * costh, + rvec[..., 1] * rvec[..., 2] * (1.0 - costh) - rvec[..., 0] * sinth, + ), + dim=-1, + ), + th.stack( + ( + rvec[..., 0] * rvec[..., 2] * (1.0 - costh) - rvec[..., 1] * sinth, + rvec[..., 1] * rvec[..., 2] * (1.0 - costh) + rvec[..., 0] * sinth, + th.pow(rvec[..., 2], 2) + (1.0 - th.pow(rvec[..., 2], 2)) * costh, + ), + dim=-1, + ), + ), + dim=-2, + ) + + +def compute_view_cond_tbnrefl( + geom: th.Tensor, campos: th.Tensor, geo_fn: GeometryModule +) -> th.Tensor: + B = int(geom.shape[0]) + S = geo_fn.uv_size + device = geom.device + + # TODO: this can be pre-computed, 
or we can assume no invalid pixels? + mask = (geo_fn.index_image != -1).any(dim=-1) + idxs = geo_fn.index_image[mask] + tri_uv = geo_fn.vt[geo_fn.v2uv[idxs, 0].to(th.long)] + + tri_xyz = geom[:, idxs] + + t, b, n = compute_tbn_uv(tri_xyz, tri_uv) + + tbn_rot = th.stack((t, -b, n), dim=-2) + + tbn_rot_uv = th.zeros( + (B, S, S, 3, 3), + dtype=th.float32, + device=device, + ) + tbn_rot_uv[:, mask] = tbn_rot + view = F.normalize(campos[:, np.newaxis] - geom, dim=-1) + v_uv = geo_fn.to_uv(values=view) + tbn_uv = th.einsum("bhwij,bjhw->bihw", tbn_rot_uv, v_uv) + + # reflectance vector + n_uv = th.zeros((B, 3, S, S), dtype=th.float32, device=device) + n_uv[..., mask] = n.permute(0, 2, 1) + n_dot_v = (v_uv * n_uv).sum(dim=1, keepdim=True) + + r_uv = 2.0 * n_uv * n_dot_v - v_uv + + return th.cat([tbn_uv, r_uv], dim=1) + + +def get_barys_for_uvs( + topology: Dict[str, Any], uv_correspondences: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """ + Given a topology along with uv correspondences for the topology (eg. keypoints correspondences in uv space), + this function will produce a tuple with the bary coordinates for each uv correspondece along with the vertex index. + + Parameters: + ---------- + topology: Input mesh that contains vertices, faces and texture coordinates info. + uv_correspondences: N X 2 uv locations that describe the uv correspondence to the topology + + Returns: + ------- + bary: (N X 3 float) + For each uv correspondence returns the bary corrdinates for the uv pixel + triangles: (N X 3 int) + For each uv correspondence returns the face (i.e vertices of the faces) for that pixel. 
+ """ + vi: np.ndarray = topology["vi"] + vt: np.ndarray = topology["vt"] + vti: np.ndarray = topology["vti"] + + # # No up-down flip here + # Here we pad the texture cordinates and correspondences with a 0 + vth = np.hstack((vt[:, :2], vt[:, :1] * 0)) + kp_uv_h = np.hstack((uv_correspondences, uv_correspondences[:, :1] * 0)) + + _, kp_barys, _, face_indices = closest_point_barycentrics(vth, vti, kp_uv_h) + + kp_verts = vi[face_indices] + + return kp_barys, kp_verts diff --git a/visualize/ca_body/utils/image.py b/visualize/ca_body/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..2951d57a574696e852da8233ad1331f01b3a7b74 --- /dev/null +++ b/visualize/ca_body/utils/image.py @@ -0,0 +1,977 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import warnings +from typing import Dict, Final, List, Optional, overload, Sequence, Tuple, Union + +import cv2 +import numpy as np +import torch as th +import torch.nn.functional as thf + + +Color = Tuple[np.uint8, np.uint8, np.uint8] + +__DEFAULT_WB_SCALE: np.ndarray = np.array([1.05, 0.95, 1.45], dtype=np.float32) + + +@overload +def linear2srgb(img: th.Tensor, gamma: float = 2.4) -> th.Tensor: + ... + + +@overload +def linear2srgb(img: np.ndarray, gamma: float = 2.4) -> np.ndarray: + ... + + +def linear2srgb( + img: Union[th.Tensor, np.ndarray], gamma: float = 2.4 +) -> Union[th.Tensor, np.ndarray]: + if isinstance(img, th.Tensor): + # Note: The following combines the linear and exponential parts of the sRGB curve without + # causing NaN values or gradients for negative inputs (where the curve would be linear). 
+ linear_part = img * 12.92 # linear part of sRGB curve + exp_part = 1.055 * th.pow(th.clamp(img, min=0.0031308), 1 / gamma) - 0.055 + return th.where(img <= 0.0031308, linear_part, exp_part) + else: + linear_part = img * 12.92 + exp_part = 1.055 * (np.maximum(img, 0.0031308) ** (1 / gamma)) - 0.055 + return np.where(img <= 0.0031308, linear_part, exp_part) + + +@overload +def linear2color_corr(img: th.Tensor, dim: int = -1) -> th.Tensor: + ... + + +@overload +def linear2color_corr(img: np.ndarray, dim: int = -1) -> np.ndarray: + ... + + +def linear2color_corr( + img: Union[th.Tensor, np.ndarray], dim: int = -1 +) -> Union[th.Tensor, np.ndarray]: + """Applies ad-hoc 'color correction' to a linear RGB Mugsy image along + color channel `dim` and returns the gamma-corrected result.""" + + if dim == -1: + dim = len(img.shape) - 1 + + gamma = 2.0 + black = 3.0 / 255.0 + color_scale = [1.4, 1.1, 1.6] + + assert img.shape[dim] == 3 + if dim == -1: + dim = len(img.shape) - 1 + if isinstance(img, th.Tensor): + scale = th.FloatTensor(color_scale).view([3 if i == dim else 1 for i in range(img.dim())]) + img = img * scale.to(img) / 1.1 + return th.clamp( + (((1.0 / (1 - black)) * 0.95 * th.clamp(img - black, 0, 2)).pow(1.0 / gamma)) + - 15.0 / 255.0, + 0, + 2, + ) + else: + scale = np.array(color_scale).reshape([3 if i == dim else 1 for i in range(img.ndim)]) + img = img * scale / 1.1 + return np.clip( + (((1.0 / (1 - black)) * 0.95 * np.clip(img - black, 0, 2)) ** (1.0 / gamma)) + - 15.0 / 255.0, + 0, + 2, + ) + + +def linear2displayBatch( + val: th.Tensor, + gamma: float = 1.5, + wbscale: np.ndarray = __DEFAULT_WB_SCALE, + black: float = 5.0 / 255.0, + mode: str = "srgb", +) -> th.Tensor: + scaling: th.Tensor = th.from_numpy(wbscale).to(val.device) + val = val.float() / 255.0 * scaling[None, :, None, None] - black + if mode == "srgb": + val = linear2srgb(val, gamma=gamma) + else: + val = val ** th.tensor(1.0 / gamma) + return th.clamp(val, 0, 1) * 255.0 + + +def 
DEFAULT_CCM: List[List[float]] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
DEFAULT_DC_OFFSET: List[float] = [0, 0, 0]
DEFAULT_GAMMA: float = 1.0


@overload
def mapped2linear(
    img: th.Tensor,
    dim: int = -1,
    ccm: Union[List[List[float]], th.Tensor, np.ndarray] = DEFAULT_CCM,
    dc_offset: Union[List[float], th.Tensor, np.ndarray] = DEFAULT_DC_OFFSET,
    gamma: float = DEFAULT_GAMMA,
) -> th.Tensor:
    ...


@overload
def mapped2linear(
    img: np.ndarray,
    dim: int = -1,
    ccm: Union[List[List[float]], th.Tensor, np.ndarray] = DEFAULT_CCM,
    dc_offset: Union[List[float], th.Tensor, np.ndarray] = DEFAULT_DC_OFFSET,
    gamma: float = DEFAULT_GAMMA,
) -> np.ndarray:
    ...


def mapped2linear(
    img: Union[th.Tensor, np.ndarray],
    dim: int = -1,
    ccm: Union[List[List[float]], th.Tensor, np.ndarray] = DEFAULT_CCM,
    dc_offset: Union[List[float], th.Tensor, np.ndarray] = DEFAULT_DC_OFFSET,
    gamma: float = DEFAULT_GAMMA,
) -> Union[th.Tensor, np.ndarray]:
    """Maps a previously-characterized camera color space into a linear
    color space. IMPORTANT: This function assumes RGB channel order,
    not BGR.

    The characterization is specified by `ccm`, `dc_offset`, and `gamma`.
    The dimension index of the color channel is specified with `dim`
    (default is -1, i.e. the last dimension). Any valid negative index is
    accepted, not only -1.

    The function accepts both [0, 255] integer and [0, 1] float formats.
    However, the return value is always floating point in [0, 1]-range.
    Pixels that are saturated on input (255 for integer images, > 1 - 1e-7
    for float images) are forced to 1.0 on output.

    FIXME(swirajaya) -
    This is a reimplementation of `RGBMapping::map_to_lin_rgb` in
    `//arvr/projects/codec_avatar/calibration/colorcal:colorspace`. To
    figure out a C++ / Py binding solution that works for both DGX and
    PROD, as well as `np.ndarray` and `th.Tensor`.

    Args:
        img: the image in RGB, as th.Tensor or np.ndarray
        dim: dimension of color channel (may be negative)
        ccm: 3x3 color correction matrix (list, th.Tensor or np.ndarray)
        dc_offset: camera black level/dc offset (list, th.Tensor or np.ndarray)
        gamma: encoding gamma

    Returns:
        The corrected image as a float64 th.Tensor or np.ndarray, matching
        the container type of `img`.
    """

    # `img.shape[dim]` raises IndexError for an out-of-range `dim` before
    # the index is normalised below.
    assert img.shape[dim] == 3

    ndim: int = img.dim() if th.is_tensor(img) else img.ndim
    # Normalise every negative index. Only -1 used to be handled, which made
    # e.g. dim=-3 on a CHW image build a wrong broadcast shape.
    dim = dim % ndim
    pixel_shape: List[int] = [3 if i == dim else 1 for i in range(ndim)]

    # Summation indices for CCM matrix multiplication
    # e.g. [sum_j] CCM_ij * Img_kljnpq -> ImgCorr_klinpq if say, dim == 2
    ein_ccm: List[int] = [0, 1]
    ein_inp: List[int] = [1 if i == dim else i + 2 for i in range(ndim)]
    ein_out: List[int] = [0 if i == dim else i + 2 for i in range(ndim)]

    EPS: float = 1e-7
    if isinstance(img, th.Tensor):
        if th.is_floating_point(img):
            input_saturated = img > (1.0 - EPS)
            imgf = img.double()
        else:
            input_saturated = img == 255
            imgf = img.double() / 255.0
        # `as_tensor` accepts lists, ndarrays and tensors of any dtype; the
        # legacy `th.DoubleTensor(...)` constructor rejects non-double tensors.
        offset_t = th.as_tensor(dc_offset, dtype=th.float64, device=img.device).reshape(
            pixel_shape
        )
        ccm_t = th.as_tensor(ccm, dtype=th.float64, device=img.device)
        # Clamp to EPS so the power below never sees zero/negative values.
        img_linear = th.clamp(imgf - offset_t, min=EPS).pow(1.0 / gamma)
        img_corr = th.clamp(  # CCM * img_linear
            th.einsum(ccm_t, ein_ccm, img_linear, ein_inp, ein_out),
            min=0.0,
            max=1.0,
        )
        return img_corr.masked_fill(input_saturated, 1.0)

    if np.issubdtype(img.dtype, np.floating):
        input_saturated = img > (1.0 - EPS)
        imgf = img.astype(float)
    else:
        input_saturated = img == 255
        imgf = img.astype(float) / 255.0
    offset_np = np.asarray(dc_offset, dtype=float).reshape(pixel_shape)
    ccm_np = np.asarray(ccm, dtype=float)
    img_linear = np.clip(imgf - offset_np, a_min=EPS, a_max=None) ** (1.0 / gamma)
    img_corr_np: np.ndarray = np.clip(  # CCM * img_linear
        np.einsum(ccm_np, ein_ccm, img_linear, ein_inp, ein_out),
        a_min=0.0,
        a_max=1.0,
    )
    return np.where(input_saturated, 1.0, img_corr_np)
def scale_diff_image(diff_img: th.Tensor) -> th.Tensor:
    """Remap a signed difference image into a displayable, non-negative range.

    Values are remapped from [-IMG_MAX, IMG_MAX] -> [0, IMG_MAX], where
    IMG_MAX is 255 if the largest absolute value in `diff_img` exceeds 1 and
    1 otherwise. The image is first normalised by its largest absolute value,
    so zero difference maps to the mid value (128 or 0.5).

    An all-zero input returns a uniform mid-value image. Previously the
    division by the zero maximum produced NaNs.

    Args:
        diff_img: Difference image of any shape.

    Returns:
        Tensor of the same shape as `diff_img` with values clamped to
        [0, IMG_MAX].
    """
    mval = diff_img.abs().max().item()
    if mval > 1:
        mid, upper = 128, 255
    else:
        mid, upper = 0.5, 1

    if mval == 0:
        # Nothing differs: avoid 0/0. Multiplying by 0.0 keeps shape/device
        # and promotes integer inputs to float exactly like the division.
        normalized = diff_img * 0.0
    else:
        normalized = diff_img / mval

    return (mid * normalized + mid).clamp(0, upper)
# Cache of all-ones box kernels, keyed by (kernel size, device), so repeated
# morphology calls do not re-allocate the kernel.
morph_cache: Dict[Tuple[int, th.device], th.Tensor] = {}


def dilate(x: th.Tensor, ks: int) -> th.Tensor:
    """Dilate a mask with a `ks` x `ks` box kernel.

    The input is convolved with an all-ones kernel and every output pixel
    whose sum is > 0 is set; the result is cast back to the input dtype.
    This is a true dilation for non-negative (mask-like) inputs.

    Args:
        x: Mask tensor. A 3-D input gets a singleton dimension inserted at
            index 1 before the convolution, and the result keeps that
            extra dimension. The kernel has one input channel, so 4-D
            inputs presumably need a single channel — TODO confirm.
        ks: Kernel size; must be odd.

    Returns:
        The dilated mask with the same dtype as `x`.
    """
    assert (ks % 2) == 1
    orig_dtype = x.dtype

    # conv2d needs a floating-point input. Cast every non-float dtype, not
    # just bool/int32/int64, so that e.g. uint8 masks work too.
    if not x.is_floating_point():
        x = x.float()
    if x.dim() == 3:
        x = x[:, None]

    key = (ks, x.device)
    w = morph_cache.get(key)
    if w is None:
        w = th.ones(1, 1, ks, ks, device=x.device)
        morph_cache[key] = w

    # The cached kernel is float32; match the input dtype so float64/float16
    # inputs do not fail with a dtype mismatch inside conv2d.
    w = w.to(dtype=x.dtype)

    return (thf.conv2d(x, w, padding=ks // 2) > 0).to(dtype=orig_dtype)
def tensor2image(
    tensor: th.Tensor,
    x_max: Optional[float] = 1.0,
    x_min: Optional[float] = 0.0,
    mode: str = "rgb",
    mask: Optional[th.Tensor] = None,
    label: Optional[str] = None,
) -> np.ndarray:
    """Converts a tensor to an image.

    The input tensor is never modified. Previously `jet` mode swapped the
    channels in place through a view of the caller's tensor.

    Args:
        tensor: Input tensor to be converted.
            The shape of the tensor should be CxHxW or HxW. The channels are assumed to be in RGB format.

        x_max: The output color will be normalized as (x-x_min)/(x_max-x_min)*255.
            x_max = tensor.max() if None is explicitly given.

        x_min: The output color will be normalized as (x-x_min)/(x_max-x_min)*255.
            x_min = tensor.min() if None is explicitly given.

        mode: Can be `rgb` or `jet`. If `jet` is given, cv2.COLORMAP_JET would be applied.

        mask: Optional mask to be applied to the input tensor.

        label: Optional text to be added to the output image.

    Returns:
        The image as an HxWx3 numpy array in RGB channel order.

    Raises:
        ValueError: If the channel count is not 1 or 3, or `mode` is unknown.
    """
    tensor = tensor.detach()

    # Apply mask
    if mask is not None:
        tensor = tensor * mask

    if tensor.dim() == 2:
        tensor = tensor[None]

    # Make three channel image
    assert tensor.dim() == 3, tensor.size()
    n_channels = tensor.shape[0]
    if n_channels == 1:
        tensor = tensor.repeat(3, 1, 1)
    elif n_channels != 3:
        raise ValueError(f"Unsupported number of channels {n_channels}.")

    # Convert to display format (HxWxC). This is a view that may share
    # storage with the caller's tensor, so it must not be written in place.
    img = tensor.permute(1, 2, 0)

    if mode == "rgb":
        img = tensor2rgb(img, x_max=x_max, x_min=x_min)
    elif mode == "jet":
        # `cv2.applyColorMap` assumes input format in BGR. Fancy indexing
        # yields a fresh tensor, leaving the caller's data untouched.
        bgr = img[:, :, [2, 1, 0]]
        img = tensor2rgbjet(bgr, x_max=x_max, x_min=x_min)
        # convert back to rgb
        img = img[:, :, [2, 1, 0]]
    else:
        raise ValueError(f"Unsupported mode {mode}.")

    if label is not None:
        img = add_label_centered(img, label)

    return img
def feature2rgb(x: Union[th.Tensor, np.ndarray], scale: int = -1) -> np.ndarray:
    """Collapse a multi-channel feature map into a 3-channel uint8 image.

    Channels are split into three interleaved groups (0, 3, 6, ... /
    1, 4, 7, ... / 2, 5, 8, ...), each group is summed, and the three sums
    become output channels 0, 1 and 2. The result is min-max normalised
    over the whole image to [0, 255].

    Args:
        x: 3-D feature map with channels first, as a tensor or numpy array.
            The original code only worked for tensors despite the annotation.
        scale: Optional resize factor applied with cubic interpolation;
            -1 (default) disables resizing.

    Returns:
        H x W x 3 uint8 array. A constant input yields an all-zero image
        instead of the NaN-derived garbage the 0/0 normalisation used to
        produce.
    """
    # expect 3 dim tensor
    if isinstance(x, th.Tensor):
        feats = x.detach().cpu().numpy()
    else:
        feats = np.asarray(x)

    rgb = np.stack([feats[c::3].sum(0) for c in range(3)], axis=2)

    lo = rgb.min()
    span = rgb.max() - lo
    if span > 0:
        rgb_norm = (rgb - lo) / span
    else:
        # Constant image: nothing to normalise, avoid dividing by zero.
        rgb_norm = np.zeros(rgb.shape, dtype=np.float64)
    rgb_norm = (rgb_norm * 255).astype(np.uint8)

    if scale != -1:
        rgb_norm = cv2.resize(rgb_norm, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    return rgb_norm
def make_image_grid(
    data: Union[th.Tensor, Dict[str, th.Tensor]],
    keys_to_draw: Optional[List[str]] = None,
    scale_factor: Optional[float] = None,
    draw_labels: bool = True,
    grid_size: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
    """Arranges a tensor of images (or a dict with labeled image tensors) into
    a grid.

    The caller's `data` dict is left untouched. Previously its entries were
    overwritten with the expanded/clamped/resized tensors.

    Params:
        data: Either a single image tensor [N, {1, 3}, H, W] containing images to
            arrange in a grid layout, or a dict with tensors of the same shape.
            If a dict is given, assume each entry in the dict is a batch of
            images, and form a grid where each cell contains one sample from
            each entry in the dict. Images should be in the range [0, 255].

        keys_to_draw: Select which keys in the dict should be included in each
            grid cell. If none are given, draw all keys.

        scale_factor: Optional scale factor applied to each image.

        draw_labels: Whether or not to draw the keys on each image.

        grid_size: Optionally specify the size (rows, columns) of the
            resulting grid.

    Returns:
        The grid as an [H, W, 3] uint8 array.

    Raises:
        ValueError: If an entry has a channel count other than 1 or 3, or
            `grid_size` cannot hold all cells.
    """

    if isinstance(data, th.Tensor):
        data = {"": data}
        keys_to_draw = [""]

    if keys_to_draw is None:
        keys_to_draw = list(data.keys())

    reference = data[keys_to_draw[0]]
    n_cells = reference.shape[0]
    img_h = reference.shape[2]
    img_w = reference.shape[3]

    # Resize all images to match the shape of the first image, and convert
    # Greyscale -> RGB. Results go into a local dict so the caller's `data`
    # is not mutated.
    prepared: Dict[str, th.Tensor] = {}
    for key in keys_to_draw:
        imgs = data[key]
        if imgs.shape[1] == 1:
            imgs = imgs.expand(-1, 3, -1, -1)
        elif imgs.shape[1] != 3:
            # Doubled braces print a literal "{1, 3}"; a bare {1,3} inside an
            # f-string is evaluated as a tuple and rendered as "(1, 3)".
            raise ValueError(
                f"Image data must all be of shape [N, {{1, 3}}, H, W]. Got shape {imgs.shape}."
            )

        imgs = imgs.clamp(min=0, max=255)
        if imgs.shape[2] != img_h or imgs.shape[3] != img_w:
            imgs = thf.interpolate(imgs, size=(img_h, img_w), mode="area")

        if scale_factor is not None:
            imgs = thf.interpolate(imgs, scale_factor=scale_factor, mode="area")

        prepared[key] = imgs

    # Make an image for each grid cell by labeling and concatenating a sample
    # from each key in the data.
    cell_imgs = []
    for i in range(n_cells):
        imgs = [
            np.ascontiguousarray(prepared[key][i].byte().cpu().numpy().transpose(1, 2, 0))
            for key in keys_to_draw
        ]
        if draw_labels:
            for img, label in zip(imgs, keys_to_draw):
                # Black text offset by one pixel acts as a drop shadow under
                # the white label.
                cv2.putText(
                    img, label, (31, 31), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 0), 2, cv2.LINE_AA
                )
                cv2.putText(
                    img,
                    label,
                    (30, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.75,
                    (255, 255, 255),
                    2,
                    cv2.LINE_AA,
                )
        cell_imgs.append(np.concatenate(imgs, axis=1))

    cell_h, cell_w = cell_imgs[0].shape[:2]

    # Find the most-square grid layout.
    if grid_size is not None:
        gh, gw = grid_size
        if gh * gw < n_cells:
            raise ValueError(
                f"Requested grid size ({gh}, {gw}) (H, W) cannot hold {n_cells} images."
            )
    else:
        best_diff = np.inf
        best_side = np.inf
        best_leftover = np.inf
        gw = 0
        for gh_ in range(1, n_cells + 1):
            for gw_ in range(1, n_cells + 1):
                if gh_ * gw_ < n_cells:
                    continue

                h = gh_ * cell_h
                w = gw_ * cell_w
                diff = abs(h - w)
                max_side = max(gh_, gw_)
                leftover = gh_ * gw_ - n_cells

                # A candidate replaces the current best only if it is no
                # worse on all three criteria at once.
                if diff <= best_diff and max_side <= best_side and leftover <= best_leftover:
                    gh = gh_
                    gw = gw_
                    best_diff = diff
                    best_side = max_side
                    best_leftover = leftover

    # Put the images into the grid (row-major order).
    img = np.zeros((gh * cell_h, gw * cell_w, 3), dtype=np.uint8)
    for i in range(n_cells):
        gr = i // gw
        gc = i % gw
        img[gr * cell_h : (gr + 1) * cell_h, gc * cell_w : (gc + 1) * cell_w] = cell_imgs[i]

    return img
def resize_to_match(
    tensors: List[th.Tensor],
    edge: str = "long",
    mode: str = "nearest",
    max_size: Optional[int] = None,
) -> List[th.Tensor]:
    """Resizes a list of image tensors s.t. a chosen edge ("long", "short",
    "vertical", or "horizontal") matches that edge on the largest image in
    the list.

    The aspect ratio of each tensor is preserved (up to integer truncation
    of the other edge). Tensors whose chosen edge already has the target
    size are returned unchanged.

    Args:
        tensors: Image tensors whose last two dimensions are (H, W); they
            are passed to `interpolate`, so presumably [N, C, H, W].
        edge: Which edge to match.
        mode: Interpolation mode passed to `torch.nn.functional.interpolate`.
        max_size: Optional upper bound on the matched edge.

    Returns:
        A new list with the resized tensors, in input order.
    """

    assert edge in {"short", "long", "vertical", "horizontal"}
    max_shape = [max(sizes) for sizes in zip(*[t.shape for t in tensors])]
    align_corners = False if mode in ["bilinear", "bicubic"] else None

    resized_tensors = []
    for tensor in tensors:
        height, width = tensor.shape[-2:]
        if edge == "long":
            edge_idx = int(np.argmax((height, width)))
        elif edge == "short":
            edge_idx = int(np.argmin((height, width)))
        elif edge == "vertical":
            edge_idx = 0
        else:  # edge == "horizontal":
            edge_idx = 1

        target_size = max_shape[-2:][edge_idx]
        if max_size is not None:
            target_size = min(max_size, target_size)

        current = (height, width)[edge_idx]
        if current != target_size:
            ratio = target_size / current
            # Compute the output size explicitly. Going through
            # `scale_factor` truncates `current * ratio`, which can land one
            # below the target because of float rounding (49 * (1 / 49) < 1),
            # so the matched edge is pinned to the exact target instead.
            out_size = [max(1, int(height * ratio)), max(1, int(width * ratio))]
            out_size[edge_idx] = target_size
            tensor = thf.interpolate(
                tensor,
                size=out_size,
                align_corners=align_corners,
                mode=mode,
            )
        resized_tensors.append(tensor)
    return resized_tensors
def process_depth_image(
    depth_img: np.ndarray, depth_min: float, depth_max: float, depth_err_range: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Process the depth image within the range for visualization.

    Pixels with depth in (0, depth_max] are kept; all others are set to 0.
    The error image is the absolute difference between the filtered and the
    original depth, divided by `depth_err_range`. It is therefore non-zero
    exactly where a pixel was rejected.

    Args:
        depth_img: Input depth map.
        depth_min: Currently unused; the lower bound is hard-coded to
            "strictly greater than 0".
            NOTE(review): confirm whether `depth_img >= depth_min` was intended.
        depth_max: Largest depth value that is kept.
        depth_err_range: Normalisation divisor for the error image.

    Returns:
        (filtered depth with the dtype of `depth_img`, float32 error image).
    """
    valid_pixels = np.logical_and(depth_img > 0, depth_img <= depth_max)
    new_depth_img = np.where(valid_pixels, depth_img, np.zeros_like(depth_img))

    # Subtract in float32: with unsigned depth maps (e.g. uint16) `0 - depth`
    # would wrap around to a huge positive value before `abs` is applied.
    diff = new_depth_img.astype(np.float32) - depth_img.astype(np.float32)
    err_image = np.abs(diff) / depth_err_range
    return new_depth_img, err_image
+ """ + for corr in contour_corrs: + mesh_uv = corr[1:3] + seg_uv = corr[3:] + + x, y = int(mesh_uv[0] + 0.5), int(mesh_uv[1] + 0.5) + cv2.circle(img, (x, y), 1, (255, 0, 0), -1) + + cv2.line( + img, + (int(mesh_uv[0]), int(mesh_uv[1])), + (int(seg_uv[0]), int(seg_uv[1])), + (-255, -255, 255), + 1, + ) + + return img diff --git a/visualize/ca_body/utils/lbs.py b/visualize/ca_body/utils/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..4764f54afe7d8e231d565a188944857e19f2c943 --- /dev/null +++ b/visualize/ca_body/utils/lbs.py @@ -0,0 +1,828 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import json +import numpy as np +import re + +import torch +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from typing import Dict, Any + +from visualize.ca_body.utils.quaternion import Quaternion + +from pytorch3d.transforms import matrix_to_euler_angles + + +from typing import Optional, Tuple + +import logging + +logger = logging.getLogger(__name__) + + +class ParameterTransform(nn.Module): + def __init__(self, lbs_cfg_dict: Dict[str, Any]): + super().__init__() + + # self.pose_names = list(lbs_cfg_dict["joint_names"]) + self.channel_names = list(lbs_cfg_dict["channel_names"]) + transform_offsets = torch.FloatTensor(lbs_cfg_dict["transform_offsets"]) + transform = torch.FloatTensor(lbs_cfg_dict["transform"]) + self.limits = lbs_cfg_dict["limits"] + + self.nr_scaling_params = lbs_cfg_dict["nr_scaling_params"] + self.nr_position_params = lbs_cfg_dict["nr_position_params"] + self.nr_total_params = self.nr_scaling_params + self.nr_position_params + + self.register_buffer("transform_offsets", transform_offsets) + self.register_buffer("transform", transform) + + def forward(self, pose: th.Tensor) -> th.Tensor: + """ + :param pose: raw pose inputs, shape (batch_size, 
class LinearBlendSkinning(nn.Module):
    """Linear blend skinning of a mesh driven by a joint hierarchy.

    The skeleton, rest mesh and skinning weights are read from `model_json`
    and stored as buffers. Pose and scale parameters are mapped to per-joint
    parameters by a `ParameterTransform` built from `lbs_config_dict`.
    """

    def __init__(
        self,
        model_json: Dict[str, Any],
        lbs_config_dict: Dict[str, Any],
        num_max_skin_joints: int = 8,
        scale_path: Optional[str] = None,
    ):
        """
        Args:
            model_json: Parsed model description; the "Skeleton" and
                "SkinnedModel" entries are read.
            lbs_config_dict: Configuration passed to `ParameterTransform`.
            num_max_skin_joints: Number of (joint, weight) slots stored per
                vertex; unused slots stay at index 0 / weight 0.
            scale_path: Optional text file (read with `np.loadtxt`) holding
                scale values, registered as the `scale` buffer.
        """
        super().__init__()

        model = model_json
        self.param_transform = ParameterTransform(lbs_config_dict)

        self.joint_names = []

        bones = model["Skeleton"]["Bones"]
        nr_joints = len(bones)
        joint_parents = torch.zeros((nr_joints, 1), dtype=torch.int64)
        joint_rotation = torch.zeros((nr_joints, 4), dtype=torch.float32)
        joint_offset = torch.zeros((nr_joints, 3), dtype=torch.float32)
        for idx, bone in enumerate(bones):
            self.joint_names.append(bone["Name"])
            # A parent index beyond the joint count marks a root joint.
            # NOTE(review): `>` lets Parent == nr_joints through as a
            # (non-existent) parent; confirm whether `>=` was intended.
            if bone["Parent"] > nr_joints:
                joint_parents[idx] = -1
            else:
                joint_parents[idx] = bone["Parent"]
            joint_rotation[idx, :] = torch.FloatTensor(bone["PreRotation"])
            joint_offset[idx, :] = torch.FloatTensor(bone["TranslationOffset"])

        skin_model = model["SkinnedModel"]
        mesh_vertices = torch.FloatTensor(skin_model["RestPositions"])
        mesh_normals = torch.FloatTensor(skin_model["RestVertexNormals"])

        # Skinning weights are stored flattened: `offsets[v]:offsets[v + 1]`
        # is the slice of (joint index, weight) pairs belonging to vertex v.
        weights = torch.FloatTensor([e[1] for e in skin_model["SkinningWeights"]])
        indices = torch.LongTensor([e[0] for e in skin_model["SkinningWeights"]])
        offsets = torch.LongTensor(skin_model["SkinningOffsets"])

        nr_vertices = len(offsets) - 1
        skin_weights = torch.zeros((nr_vertices, num_max_skin_joints), dtype=torch.float32)
        skin_indices = torch.zeros((nr_vertices, num_max_skin_joints), dtype=torch.int64)

        # Unpack the flattened layout into dense [V, num_max_skin_joints]
        # tables, one slot at a time.
        offset_right = offsets[1:]
        for slot in range(num_max_skin_joints):
            offset_left = offsets[:-1] + slot
            has_entry = offset_left < offset_right
            skin_weights[has_entry, slot] = weights[offset_left[has_entry]]
            skin_indices[has_entry, slot] = indices[offset_left[has_entry]]

        mesh_faces = torch.IntTensor(skin_model["Faces"]["Indices"]).view(-1, 3)
        mesh_texture_faces = torch.IntTensor(skin_model["Faces"]["TextureIndices"]).view(-1, 3)
        mesh_texture_coords = torch.FloatTensor(skin_model["TextureCoordinates"]).view(-1, 2)

        # The bind state is the skeleton state for an all-zero parameter vector.
        zero_pose = torch.zeros((1, self.param_transform.nr_total_params), dtype=torch.float32)
        bind_state = solve_skeleton_state(
            self.param_transform(zero_pose), joint_offset, joint_rotation, joint_parents
        )

        self.register_buffer("mesh_vertices", mesh_vertices)

        self.register_buffer("joint_parents", joint_parents)
        self.register_buffer("joint_rotation", joint_rotation)
        self.register_buffer("joint_offset", joint_offset)
        self.register_buffer("mesh_normals", mesh_normals)
        self.register_buffer("mesh_faces", mesh_faces)
        self.register_buffer("mesh_texture_faces", mesh_texture_faces)
        self.register_buffer("mesh_texture_coords", mesh_texture_coords)
        self.register_buffer("skin_weights", skin_weights)
        self.register_buffer("skin_indices", skin_indices)
        self.register_buffer("bind_state", bind_state)
        self.register_buffer("rest_vertices", mesh_vertices)

        # pre-compute joint weights
        self.register_buffer("joints_weights", self.compute_joints_weights())

        if scale_path is not None:
            scale = np.loadtxt(scale_path).astype(np.float32)[np.newaxis]
            scale = scale[:, 0, :] if len(scale.shape) == 3 else scale
            self.register_buffer("scale", torch.tensor(scale))

    @property
    def num_verts(self):
        """Number of mesh vertices."""
        return self.mesh_vertices.size(0)

    @property
    def num_joints(self):
        """Number of skeleton joints."""
        return self.joint_offset.size(0)

    @property
    def num_params(self):
        """Number of skinning slots per vertex (last dim of `skin_weights`)."""
        return self.skin_weights.shape[-1]

    def compute_rigid_transforms(
        self, global_pose: th.Tensor, local_pose: th.Tensor, scale: th.Tensor
    ):
        """Returns rigid transforms (skeleton states) for the given parameters.

        The three inputs are concatenated along the last dimension, in this
        order, before being passed through the parameter transform.
        """
        params = torch.cat([global_pose, local_pose, scale], dim=-1)
        params = self.param_transform(params)
        return solve_skeleton_state(
            params, self.joint_offset, self.joint_rotation, self.joint_parents
        )

    def compute_rigid_transforms_matrix(
        self, global_pose: th.Tensor, local_pose: th.Tensor, scale: th.Tensor
    ):
        """Like `compute_rigid_transforms`, but returns the result of
        `states_to_matrix` relative to the bind state."""
        params = torch.cat([global_pose, local_pose, scale], dim=-1)
        params = self.param_transform(params)
        states = solve_skeleton_state(
            params, self.joint_offset, self.joint_rotation, self.joint_parents
        )
        return states_to_matrix(self.bind_state, states)

    def compute_joints_weights(self, drop_empty=False):
        """Compute weights per joint given flattened weights-indices.

        Args:
            drop_empty: If True, rows of joints whose weights sum to zero
                are removed.

        Returns:
            [num_joints, num_verts] tensor of skinning weights (fewer rows
            when `drop_empty` is set).
        """
        device = self.skin_weights.device
        idxs_verts = torch.arange(self.num_verts, device=device)[:, np.newaxis].expand(
            -1, self.num_params
        )
        weights_joints = torch.zeros(
            (self.num_joints, self.num_verts),
            dtype=torch.float32,
            device=device,
        )
        # Unused slots are padded with (joint 0, weight 0), so the pair
        # (0, vertex) can occur several times. Accumulating keeps the real
        # joint-0 weight; a plain indexed assignment could overwrite it with
        # a padding zero.
        weights_joints.index_put_(
            (self.skin_indices, idxs_verts), self.skin_weights, accumulate=True
        )

        if drop_empty:
            weights_joints = weights_joints[weights_joints.sum(dim=-1).abs() > 0]

        return weights_joints

    def compute_root_rigid_transform(self, poses: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Get a transform of the root joint.

        The pose is padded with zeros up to the total parameter count.

        Returns:
            (column 3, columns 0..2) of the matrix of joint index 1.
            NOTE(review): presumably translation and rotation of the root,
            with the root at index 1 — confirm against the skeleton.
        """
        # `nr_total_params` lives on the parameter transform, not on this
        # module; reading it from `self` raised AttributeError.
        nr_total_params = self.param_transform.nr_total_params
        scales = torch.zeros(
            (poses.shape[0], nr_total_params - poses.shape[1]),
            dtype=poses.dtype,
            device=poses.device,
        )
        params = torch.cat((poses, scales), 1)
        states = solve_skeleton_state(
            self.param_transform(params),
            self.joint_offset,
            self.joint_rotation,
            self.joint_parents,
        )
        mat = states_to_matrix(self.bind_state, states)
        return mat[:, 1, :, 3], mat[:, 1, :, :3]

    def compute_relative_rigid_transforms(
        self, global_pose: th.Tensor, local_pose: th.Tensor, scale: th.Tensor
    ):
        """Return per-joint local transforms, i.e. without composing along
        the joint hierarchy: translation (3) concatenated with rotation
        quaternion (4) for every joint."""
        params = torch.cat([global_pose, local_pose, scale], dim=-1)
        params = self.param_transform(params)

        batch_size = params.shape[0]

        joint_offset = self.joint_offset
        joint_rotation = self.joint_rotation

        # batch processing for parameters: 7 values per joint
        jp = params.view((batch_size, -1, 7))
        lt = jp[:, :, 0:3] + joint_offset.unsqueeze(0)
        lr = Quaternion.batchMul(
            joint_rotation.unsqueeze(0), Quaternion.batchFromXYZ(jp[:, :, 3:6])
        )
        return torch.cat([lt, lr], dim=-1)

    def skinning(self, bind_state: th.Tensor, vertices: th.Tensor, target_states: th.Tensor):
        """
        Apply skinning to a set of states

        Args:
            b/bind_state: 1 x nr_joint x 8 bind state
            v/vertices: 1 x nr_vertices x 3 vertices
            t/target_states: batch_size x nr_joint x 8 current states

        Returns:
            batch_size x nr_vertices x 3 skinned vertices
        """
        assert target_states.size()[1:] == bind_state.size()[1:]

        mat = states_to_matrix(bind_state, target_states)

        # apply skinning to vertices: transform each homogeneous vertex by
        # every joint it is bound to, then blend with the skinning weights.
        vertices_h = torch.cat(
            (vertices, torch.ones_like(vertices[:, :, 0]).unsqueeze(2)), dim=2
        )
        vs = torch.matmul(
            mat[:, self.skin_indices],
            vertices_h.unsqueeze(2).unsqueeze(4),
        )
        ws = self.skin_weights.unsqueeze(2).unsqueeze(3)
        res = (vs * ws).sum(dim=2).squeeze(3)

        return res

    def unpose(self, poses: th.Tensor, scales: th.Tensor, verts: th.Tensor):
        """Map posed vertices back to the unposed space.

        :param poses: 100 (tx ty tz rx ry rz) params in blueman
        :param scales: 29 (s) params in blueman
        :param verts: posed vertices, [B, V, 3]
        :return: unposed vertices, [B, V, 3]
        """
        # check shape of poses and scales
        params = torch.cat((poses, scales), 1)
        states = solve_skeleton_state(
            self.param_transform(params),
            self.joint_offset,
            self.joint_rotation,
            self.joint_parents,
        )

        return self.unskinning(self.bind_state, states, verts)

    def unskinning(self, bind_state: th.Tensor, target_states: th.Tensor, verts: th.Tensor):
        """Invert skinning: undo the blended joint transform on each vertex.

        Args:
            bind_state: [B, NJ, 8] - bind state
            target_states: [B, NJ, 8] - current states
            verts: [B, V, 3] - skinned vertices

        Returns:
            batch_size x nr_vertices x 3 unskinned vertices
        """
        assert target_states.size()[1:] == bind_state.size()[1:]

        mat = states_to_matrix(bind_state, target_states)

        # Blend the per-joint matrices with the skinning weights per vertex.
        ws = self.skin_weights[None, :, :, None, None]
        sum_mat = (mat[:, self.skin_indices] * ws).sum(dim=2)

        # Append a [0, 0, 0, 1] row to make the blended matrices invertible 4x4.
        sum_mat4x4 = torch.cat((sum_mat, torch.zeros_like(sum_mat[:, :, :1, :])), dim=2)
        sum_mat4x4[:, :, 3, 3] = 1.0

        verts_4d = torch.cat((verts, torch.ones_like(verts[:, :, :1])), dim=2).unsqueeze(3)

        # `inverse` and `matmul` operate over the leading batch dimensions,
        # so no per-sample Python loop is needed.
        resmesh = sum_mat4x4.contiguous().inverse().matmul(verts_4d)

        return resmesh.squeeze(3)[..., :3].contiguous()

    def forward(
        self, poses: th.Tensor, scales: th.Tensor, verts_unposed: Optional[th.Tensor] = None
    ) -> th.Tensor:
        """
        Args:
            poses: [B, NP] - pose parameters
            scales: [B, NS] - additional scaling params
            verts_unposed: [B, N, 3] - unposed vertices; the stored rest
                mesh is used when omitted

        Returns:
            [B, N, 3] - posed vertices
        """
        params = torch.cat((poses, scales), 1)
        params_transformed = self.param_transform(params)
        states = solve_skeleton_state(
            params_transformed,
            self.joint_offset,
            self.joint_rotation,
            self.joint_parents,
        )
        if verts_unposed is None:
            verts_unposed = self.mesh_vertices.unsqueeze(0)
        return self.skinning(self.bind_state, verts_unposed, states)
+ :return: batch_size x nr_skeleton_joints x 8 Skeleton States + 8 stands form 3 translation + 4 rotation (quat) + 1 scale + """ + batch_size = param.shape[0] + # batch processing for parameters + jp = param.view((batch_size, -1, 7)) + lt = jp[:, :, 0:3] + joint_offset.unsqueeze(0) + lr = Quaternion.batchMul(joint_rotation.unsqueeze(0), Quaternion.batchFromXYZ(jp[:, :, 3:6])) + ls = torch.pow( + torch.tensor([2.0], dtype=torch.float32, device=param.device), + jp[:, :, 6].unsqueeze(2), + ) + + state = [] + for index, parent in enumerate(joint_parents): + if int(parent) != -1: + gr = Quaternion.batchMul(state[parent][:, :, 3:7], lr[:, index, :].unsqueeze(1)) + gt = ( + Quaternion.batchRot( + state[parent][:, :, 3:7], + lt[:, index, :].unsqueeze(1) * state[parent][:, :, 7].unsqueeze(2), + ) + + state[parent][:, :, 0:3] + ) + gs = state[parent][:, :, 7].unsqueeze(2) * ls[:, index, :].unsqueeze(1) + state.append(torch.cat((gt, gr, gs), dim=2)) + else: + state.append( + torch.cat((lt[:, index, :], lr[:, index, :], ls[:, index, :]), dim=1).view( + (batch_size, 1, 8) + ) + ) + + return torch.cat(state, dim=1) + + +def states_to_matrix(bind_state: th.Tensor, target_states: th.Tensor, return_transform: bool=False): + # multiply bind inverse with states + br = Quaternion.batchInvert(bind_state[:, :, 3:7]) + bs = bind_state[:, :, 7].unsqueeze(2).reciprocal() + bt = Quaternion.batchRot(br, -bind_state[:, :, 0:3]) * bs + + # applying rotation + tr = Quaternion.batchMul(target_states[:, :, 3:7], br) + # applying scaling + ts = target_states[:, :, 7].unsqueeze(2) * bs + # applying transformation + tt = ( + Quaternion.batchRot(target_states[:, :, 3:7], bt * target_states[:, :, 7].unsqueeze(2)) + + target_states[:, :, 0:3] + ) + + # convert to matrices + twx = 2.0 * tr[:, :, 0] * tr[:, :, 3] + twy = 2.0 * tr[:, :, 1] * tr[:, :, 3] + twz = 2.0 * tr[:, :, 2] * tr[:, :, 3] + txx = 2.0 * tr[:, :, 0] * tr[:, :, 0] + txy = 2.0 * tr[:, :, 1] * tr[:, :, 0] + txz = 2.0 * tr[:, :, 2] * tr[:, 
def load_momentum_cfg(model, lbs_config_txt_fh, nr_scaling_params=None):
    """Load a parameter configuration file.

    After ``#`` comments are stripped, every line is handled as one of:

    * a line containing ``limit``: ``limit <name> <type> [<min>, <max>] <weight>``.
      Only the ``minmax`` and ``minmax_passive`` types are parsed. ``<name>``
      is either ``<joint>.<channel>`` or the name of a model parameter that
      was already introduced by an earlier definition line.
    * a line containing ``parameterset``: ignored.
    * a definition ``<joint>.<channel> = <w> * <param> + <w> * <param> ...``
      which contributes entries to the parameter transform.

    Args:
        model: dict; joint names are read from
            ``model["Skeleton"]["Bones"][i]["Name"]``.
        lbs_config_txt_fh: text file-like object providing ``readlines()``.
        nr_scaling_params: number of scaling parameters. When ``None`` it is
            set to the number of parameters whose name starts with ``"scale"``.

    Returns:
        dict with keys ``model_param_names``, ``joint_names``,
        ``channel_names``, ``limits``, ``transform`` (dense float32 array of
        shape ``[7 * nr_joints, nr_params]``), ``transform_offsets``,
        ``nr_scaling_params`` and ``nr_position_params``.
    """
    channelNames = ["tx", "ty", "tz", "rx", "ry", "rz", "sc"]
    num_channels = len(channelNames)
    joint_names = [bone["Name"] for bone in model["Skeleton"]["Bones"]]
    paramNames = []
    limits = []
    transform_triplets = []

    # compiled once; raw strings avoid invalid-escape warnings
    limit_re = re.compile(r"limit ([\w.]+) (\w+) (.*)")
    minmax_re = re.compile(
        r"\[\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\]"
        r"(\s*[-+]?[0-9]*\.?[0-9]+)?"
    )
    # NOTE: the "." between joint and channel is unescaped (matches any
    # character); kept as-is to preserve which lines are accepted.
    definition_re = re.compile(r"(\w+).(\w+)\s*=\s*(.*)")
    term_re = re.compile(r"\s*([+-]?[0-9]*\.?[0-9]*)\s\*\s(\w+)\s*")

    def find(seq, x):
        """Return the index of ``x`` in ``seq`` or ``None`` when it is absent."""
        try:
            return seq.index(x)
        except ValueError:
            return None

    for raw_line in lbs_config_txt_fh.readlines():
        # strip comments and the line terminator; slicing with find() would
        # drop the last character of a final line that has no newline
        line = raw_line.partition("#")[0].rstrip("\r\n")

        if "limit" in line:
            match = limit_re.search(line)
            if match is None:
                continue

            # find parameter and/or joint index
            fullname, limit_type, remaining = match.groups()
            parameterIndex = find(paramNames, fullname)
            jointName = fullname.split(".")
            jointIndex = find(joint_names, jointName[0])
            channelIndex = -1

            if jointIndex is not None and len(jointName) == 2:
                # find matching channel name
                channelIndex = find(channelNames, jointName[1])
                if channelIndex is None:
                    logger.info(
                        "Unknown joint channel name "
                        + jointName[1]
                        + " in parameter configuration line :\n "
                        + line
                    )
                    continue

            # only parse passive limits for now
            if limit_type in ("minmax_passive", "minmax"):
                # match [ <min> , <max> ] <optional weight>
                rp = minmax_re.search(remaining)
                if rp is None:
                    logger.info(f"Failed to parse passive limit configuration line :\n {line}")
                    continue

                minVal = float(rp.group(1))
                maxVal = float(rp.group(2))
                weightVal = 1.0 if rp.group(3) is None else float(rp.group(3))

                if channelIndex >= 0:
                    limits.append(
                        {
                            "type": "LimitMinMaxJointValue",
                            "str": fullname,
                            "valueIndex": jointIndex * num_channels + channelIndex,
                            "limits": [minVal, maxVal],
                            "weight": weightVal,
                        }
                    )
                else:
                    if parameterIndex is None:
                        logger.info(f"Unknown parameterIndex : {fullname}\n {line} {paramNames} ")
                        continue
                    limits.append(
                        {
                            "type": "LimitMinMaxParameter",
                            "str": fullname,
                            "parameterIndex": parameterIndex,
                            "limits": [minVal, maxVal],
                            "weight": weightVal,
                        }
                    )
            # continue the remaining file
            continue

        # check for parameterset definitions and ignore
        if "parameterset" in line:
            continue

        # use regex to parse definition
        match = definition_re.search(line)
        if match is None:
            continue
        joint_name, channel_name, expression = match.groups()

        # find joint name and parameter
        jointIndex = find(joint_names, joint_name)
        if jointIndex is None:
            logger.info(
                "Unknown joint name "
                + joint_name
                + " in parameter configuration line :\n "
                + line
            )
            continue

        # find matching channel name
        channelIndex = find(channelNames, channel_name)
        if channelIndex is None:
            logger.info(
                "Unknown joint channel name "
                + channel_name
                + " in parameter configuration line :\n "
                + line
            )
            continue

        valueIndex = jointIndex * num_channels + channelIndex

        # parse parameters
        for parameterPair in expression.split("+"):
            parameterPair = parameterPair.strip()

            term = term_re.search(parameterPair)
            val = None
            if term is not None:
                try:
                    val = float(term.group(1))
                except ValueError:
                    # the pattern also accepts "" and "." as a coefficient
                    val = None
            if val is None:
                logger.info(
                    "Malformed parameter description "
                    + parameterPair
                    + " in parameter configuration line :\n "
                    + line
                )
                continue

            parameter = term.group(2)

            # check if parameter exists
            parameterIndex = find(paramNames, parameter)
            if parameterIndex is None:
                # no, create new parameter entry
                parameterIndex = len(paramNames)
                paramNames.append(parameter)
            transform_triplets.append((valueIndex, parameterIndex, val))

    # set (dense) parameter_transformation matrix
    transform = np.zeros((num_channels * len(joint_names), len(paramNames)), dtype=np.float32)
    for i, j, v in transform_triplets:
        transform[i, j] = v

    # set number of scales automatically
    if nr_scaling_params is None:
        nr_scaling_params = len([s for s in paramNames if s.startswith("scale")])

    return {
        "model_param_names": paramNames,
        "joint_names": joint_names,
        "channel_names": channelNames,
        "limits": limits,
        "transform": transform,
        "transform_offsets": np.zeros((1, num_channels * len(joint_names)), dtype=np.float32),
        "nr_scaling_params": nr_scaling_params,
        "nr_position_params": len(paramNames) - nr_scaling_params,
    }
+ _, _, _, state_t, state_r, state_s = lbs_fn(poses, vertices=verts) + + bind_r = lbs_fn.joint_state_r_zero[np.newaxis, 1].expand(B, -1, -1) + bind_t = lbs_fn.joint_state_t_zero[np.newaxis, 1].expand(B, -1) + + R_root = th.matmul(state_r[:, 1], bind_r) + t_root = ( + th.matmul(state_r[:, 1], (bind_t * state_s[:, 1])[..., np.newaxis])[..., 0] + state_t[:, 1] + ) + + return R_root, t_root + + +# def compute_joints_weights(lbs_fn: LinearBlendSkinningCuda, drop_empty: bool = False) -> th.Tensor: +# device = lbs_fn.skin_indices.device +# idxs_verts = th.arange(lbs_fn.nr_vertices)[:, np.newaxis].to(device) +# weights_joints = th.zeros( +# (lbs_fn.nr_joints, lbs_fn.nr_vertices), +# dtype=th.float32, +# device=lbs_fn.skin_indices.device, +# ) +# weights_joints[lbs_fn.skin_indices, idxs_verts] = lbs_fn.skin_weights +# if drop_empty: +# weights_joints = weights_joints[weights_joints.sum(axis=-1).abs() > 0] +# return weights_joints + + +# def compute_pose_regions(lbs_fn: LinearBlendSkinningCuda) -> np.ndarray: +# """Computes pose regions given a linear blend skinning function. 
+ +# Returns: +# np.ndarray of boolean masks of shape [nr_params, n_rvertices] +# """ + +# weights = compute_joints_weights(lbs_fn).cpu().numpy() + +# n_pos = lbs_fn.nr_position_params + +# param_masks = np.zeros((n_pos, weights.shape[-1])) + +# children = {j: [] for j in range(lbs_fn.nr_joints)} +# parents = {j: None for j in range(lbs_fn.nr_joints)} +# prec = {j: [] for j in range(lbs_fn.nr_joints)} +# for j in range(lbs_fn.nr_joints): +# parent_index = int(lbs_fn.joint_parents[j]) +# if parent_index == -1: +# continue +# children[parent_index].append(j) +# parents[j] = parent_index +# prec[j] = [parent_index, int(lbs_fn.joint_parents[parent_index])] + +# # get parameters for each joint +# # j_to_p = get_influence_map(lbs_fn.param_transform.transform, n_pos) +# j_to_p = get_influence_map(lbs_fn.param_transform, n_pos) + +# # get all the joints +# p_to_j = [[] for i in range(n_pos)] +# for j, pidx in enumerate(j_to_p): +# for p in pidx: +# if j not in p_to_j[p]: +# p_to_j[p].append(j) + +# for p, jidx in enumerate(p_to_j): +# param_masks[p] = weights[jidx].sum(axis=0) +# if not np.any(param_masks[p]): +# assert len(jidx) == 1 +# jidx_c = children[jidx[0]][:] +# for jc in jidx_c[:]: +# jidx_c += children[jc] +# param_masks[p] = weights[jidx_c].sum(axis=0) +# return param_masks > 0.0 + + +def compute_pose_regions_legacy(lbs_fn) -> np.ndarray: + """Computes pose regions given a linear blend skinning function.""" + weights = lbs_fn.joints_weights.cpu().numpy() + + n_pos = lbs_fn.param_transform.nr_position_params + + param_masks = np.zeros((n_pos, lbs_fn.joints_weights.shape[-1])) + + children = {j: [] for j in range(lbs_fn.num_joints)} + parents = {j: None for j in range(lbs_fn.num_joints)} + prec = {j: [] for j in range(lbs_fn.num_joints)} + for j in range(lbs_fn.num_joints): + parent_index = int(lbs_fn.joint_parents[j, 0]) + if parent_index == -1: + continue + children[parent_index].append(j) + parents[j] = parent_index + prec[j] = [parent_index, 
def parent_chain(joint_parents, idx, depth):
    """Return up to ``depth`` ancestors of joint ``idx``, nearest first.

    The walk stops at joint 0, after ``depth`` steps, or when a joint has no
    parent (a negative parent index, the same ``-1`` marker that
    ``joint_connectivity`` treats as "no parent"). A negative ``depth`` means
    "no depth limit".

    Args:
        joint_parents: indexable of parent indices, one per joint.
        idx: index of the joint whose ancestors are requested.
        depth: maximum number of ancestors to return.

    Returns:
        list of ancestor joint indices.
    """
    chain = []
    while depth != 0 and idx != 0:
        parent_idx = int(joint_parents[idx])
        if parent_idx < 0:
            # reached a root other than joint 0; indexing with -1 would
            # silently wrap around to the last joint
            break
        chain.append(parent_idx)
        idx = parent_idx
        depth -= 1
    return chain


def joint_connectivity(nr_joints, joint_parents, chain_depth=2, pad_ancestors=False):
    """Build children / parent / ancestor lookup tables for a skeleton.

    Args:
        nr_joints: number of joints.
        joint_parents: indexable of parent indices; ``-1`` marks a joint
            without a parent.
        chain_depth: maximum number of ancestors collected per joint.
        pad_ancestors: when true, ancestor lists shorter than ``chain_depth``
            are padded with the joint's own index.

    Returns:
        dict with keys ``children`` (joint -> list of child joints),
        ``parents`` (joint -> parent index or ``None``) and ``ancestors``
        (joint -> list of ancestor indices, nearest first).
    """
    children = {j: [] for j in range(nr_joints)}
    parents = {j: None for j in range(nr_joints)}
    ancestors = {}

    for j in range(nr_joints):
        chain = parent_chain(joint_parents, j, depth=chain_depth)
        if pad_ancestors:
            # adding itself
            chain += [j] * (chain_depth - len(chain))
        ancestors[j] = chain

        parent_index = int(joint_parents[j])
        if parent_index != -1:
            children[parent_index].append(j)
            parents[j] = parent_index

    return {
        'children': children,
        'parents': parents,
        'ancestors': ancestors,
    }
class LBSModule(nn.Module):
    """Linear blend skinning with a fixed scale, template mesh and global scaling.

    Wraps a ``LinearBlendSkinning`` instance and stores ``lbs_scale``,
    ``lbs_template_verts`` and ``global_scaling`` as buffers.
    """

    def __init__(
        self, lbs_model_json, lbs_config_dict, lbs_template_verts, lbs_scale, global_scaling
    ):
        super().__init__()
        self.lbs_fn = LinearBlendSkinning(lbs_model_json, lbs_config_dict)

        # registration order is kept: lbs_scale, lbs_template_verts, global_scaling
        float_buffers = (
            ("lbs_scale", lbs_scale),
            ("lbs_template_verts", lbs_template_verts),
        )
        for buffer_name, value in float_buffers:
            self.register_buffer(buffer_name, th.as_tensor(value, dtype=th.float32))
        self.register_buffer("global_scaling", th.as_tensor(global_scaling))

    def _batched_scale(self, motion):
        """Expand the stored scale along the batch dimension of ``motion``."""
        return self.lbs_scale.expand(motion.shape[0], -1)

    def pose(self, verts_unposed, motion, template: Optional[th.Tensor] = None):
        """Pose ``verts_unposed`` (offsets added to the template) with ``motion``.

        ``template`` replaces the stored template vertices when given.
        """
        base = self.lbs_template_verts if template is None else template
        posed = self.lbs_fn(motion, self._batched_scale(motion), verts_unposed + base)
        return posed * self.global_scaling

    def unpose(self, verts, motion):
        """Invert ``pose``: undo global scaling and skinning, subtract the template."""
        unposed = self.lbs_fn.unpose(
            motion, self._batched_scale(motion), verts / self.global_scaling
        )
        return unposed - self.lbs_template_verts

    def template_pose(self, motion):
        """Pose the stored template vertices with ``motion``."""
        batch_size = motion.shape[0]
        template = self.lbs_template_verts[None].expand(batch_size, -1, -1)
        posed = self.lbs_fn(motion, self._batched_scale(motion), template)
        return posed * self.global_scaling[None]
def load_module(
    module_name: str, class_name: Optional[str] = None, silent: bool = False
):
    """
    Load a module or class given the module/class name.

    Example:
        .. code-block:: python

            eye_geo = load_module("path.to.module", "ClassName")

    Args:
        module_name: str
            The dotted path of the module relative to the ``visualize``
            package (``visualize.`` is prepended before importing).
            Ex: ``ca_body.utils.module_loader``

        class_name: str
            The name of the class within the module to load.

        silent: bool
            If set to True, return None instead of raising an exception if module/class is missing

    Returns:
        object:
            The loaded module or class object.
    """
    try:
        module = importlib.import_module(f"visualize.{module_name}")
        if class_name:
            return getattr(module, class_name)
        return module
    except ModuleNotFoundError:
        if silent:
            return None
        logger.error(f"Module not found: {module_name}", exc_info=True)
        raise
    except AttributeError:
        if silent:
            return None
        logger.error(
            f"Can not locate class: {class_name} in {module_name}.", exc_info=True
        )
        raise


# pyre-ignore[3]
def make_module(mod_config: "AttrDict", *args: Any, **kwargs: Any) -> Any:
    """
    A shortcut for making an object given the config and arguments

    Args:
        mod_config: AttrDict
            Config. Should contain keys: module_name, class_name, and optionally args

        *args
            Positional arguments.

        **kwargs
            Default keyword arguments. Overwritten by content from mod_config.args

    Returns:
        object:
            The instantiated object.
    """
    mod_config_dict = dict(mod_config)
    # Work on a copy: ``dict(mod_config)`` is shallow, so updating the popped
    # ``args`` in place would write the defaults back into the caller's
    # config and make them win over different defaults on the next call.
    mod_args = dict(mod_config_dict.pop("args", None) or {})
    for key, value in kwargs.items():
        mod_args.setdefault(key, value)
    mod_class = load_module(**mod_config_dict)
    return mod_class(*args, **mod_args)
# From DaaT merge. Fix here T145981161
# pyre-fixme[2]: parameter must be annotated.
# pyre-fixme[3]: Return type must be annotated.
def forward_parameter_names(module):
    """Get the names of the arguments of the forward pass for the module.

    Args:
        module: a class or an instance with a `forward()` method

    Returns:
        list of parameter names of ``forward``, without ``self``

    Raises:
        ValueError: if ``forward`` takes ``*args`` or ``**kwargs``
    """
    forward = module.forward
    params = list(inspect.signature(forward).parameters.values())
    # The signature of a bound method already omits ``self``; only a plain
    # function looked up on the class still carries it as first parameter.
    if not inspect.ismethod(forward):
        params = params[1:]

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    names = []
    for p in params:
        # ``p.name`` never contains the stars, so the kind has to be checked
        if p.kind in variadic:
            raise ValueError("*args and **kwargs are not supported")
        names.append(p.name)
    return names


# From DaaT merge. Fix here T145981161
def build_optimizer(config, model):
    """Build an optimizer given optimizer config and a model.

    When ``config`` has a ``per_module`` entry, one parameter group is built
    per listed submodule, with that entry's values as group options.
    ``per_module`` is removed before the remaining config is passed to
    ``load_from_config`` together with ``params``.

    Args:
        config: DictConfig
        model: nn.Module|Dict[str,nn.Module]

    Returns:
        the object created by ``load_from_config(config, params=...)``
    """
    config = copy.deepcopy(config)

    if isinstance(model, nn.Module):
        if "per_module" not in config:
            return load_from_config(config, params=model.parameters())

        per_module = config.pop("per_module")
        params = []
        for name, value in per_module.items():
            if not hasattr(model, name):
                logger.warning(
                    f"model {model.__class__} does not have a submodule {name}, skipping"
                )
                continue
            params.append(dict(params=getattr(model, name).parameters(), **value))

        # report submodules that own parameters but are not optimized
        defined_names = set(per_module.keys())
        for name, module in model.named_children():
            n_params = len(list(module.named_parameters()))
            if name not in defined_names and n_params:
                logger.warning(
                    f"not going to optimize module {name} which has {n_params} parameters"
                )
        return load_from_config(config, params=params)

    assert "per_module" in config
    assert isinstance(model, dict)
    # ``per_module`` is not an optimizer argument: drop it here too, as the
    # nn.Module branch does, so it is not forwarded to the constructor.
    per_module = config.pop("per_module")
    params = []
    for name, value in per_module.items():
        if name not in model:
            logger.warning(f"not aware of {name}, skipping")
            continue
        params.append(dict(params=model[name].parameters(), **value))

    return load_from_config(config, params=params)
class Quaternion:
    """Torch Tensor based Quaternion class.

    Quaternions are stored as ``(x, y, z, w)``: the vector part comes first
    and the scalar part is the last component.
    """

    @staticmethod
    def identity(dtype=th.double):
        """
        Create identity quaternion
        """
        return th.tensor([0.0, 0.0, 0.0, 1.0], dtype=dtype)

    @staticmethod
    def mul(q, r):
        """
        Multiply two quaternions, expects those to be tensors of length 4
        """

        def signs(*values):
            # created on q's device so the product also works off-CPU
            return th.tensor(values, dtype=q.dtype, device=q.device)

        return th.stack(
            [
                (q * signs(1.0, 1.0, -1.0, 1.0)).dot(r[[3, 2, 1, 0]]),
                (q * signs(-1.0, 1.0, 1.0, 1.0)).dot(r[[2, 3, 0, 1]]),
                (q * signs(1.0, -1.0, 1.0, 1.0)).dot(r[[1, 0, 3, 2]]),
                (q * signs(-1.0, -1.0, -1.0, 1.0)).dot(r[[0, 1, 2, 3]]),
            ]
        )

    @staticmethod
    def rot(q, v):
        """
        Rotate 3d-vector v given with quaternion q
        """
        axis = q[:3]
        # explicit dim: calling cross without it is deprecated
        av = th.cross(axis, v, dim=-1)
        aav = th.cross(axis, av, dim=-1)
        return v + 2 * (av * q[3] + aav)

    @staticmethod
    def invert(q):
        """
        Get the inverse of quaternion q
        """
        conjugate_signs = th.tensor([-1.0, -1.0, -1.0, 1.0], dtype=q.dtype, device=q.device)
        return q * conjugate_signs * (1.0 / q.dot(q))

    @staticmethod
    def fromAxisAngle(axis, angle):
        """
        Generate a quaternion representing a rotation around axis by angle
        """
        s = th.sin(angle * 0.5)
        c = th.cos(angle * 0.5).view([1])
        return th.cat((axis * s, c), 0)

    @staticmethod
    def fromXYZ(angles):
        """
        Generate a quaternion representing a rotation defined by a XYZ-Euler
        rotation.
        This is faster than creating three separate quaternions and
        multiplying them.
        """
        half = angles * th.tensor([-0.5, 0.5, 0.5], dtype=angles.dtype, device=angles.device)
        rc = th.cos(half)
        rs = th.sin(half)

        return th.stack(
            [
                -rs[0] * rc[1] * rc[2] - rc[0] * rs[1] * rs[2],
                rc[0] * rs[1] * rc[2] - rs[0] * rc[1] * rs[2],
                rc[0] * rc[1] * rs[2] + rs[0] * rs[1] * rc[2],
                rc[0] * rc[1] * rc[2] - rs[0] * rs[1] * rs[2],
            ]
        )

    @staticmethod
    def _rotation_columns(q):
        """Return the three columns of the rotation matrix of ``q``.

        ``q`` has the quaternion components on its last dimension; every
        returned entry has the shape of ``q[..., 0]``.
        """
        tx = q[..., 0] * 2.0
        ty = q[..., 1] * 2.0
        tz = q[..., 2] * 2.0
        twx = tx * q[..., 3]
        twy = ty * q[..., 3]
        twz = tz * q[..., 3]
        txx = tx * q[..., 0]
        txy = ty * q[..., 0]
        txz = tz * q[..., 0]
        tyy = ty * q[..., 1]
        tyz = tz * q[..., 1]
        tzz = tz * q[..., 2]

        return (
            (1.0 - (tyy + tzz), txy + twz, txz - twy),
            (txy - twz, 1.0 - (txx + tzz), tyz + twx),
            (txz + twy, tyz - twx, 1.0 - (txx + tyy)),
        )

    @staticmethod
    def toMatrix(q):
        """
        Convert quaternion q (length 4) to 3x3 rotation matrix
        """
        # stacking keeps the result on q's device and dtype
        columns = Quaternion._rotation_columns(q)
        return th.stack([th.stack(column, dim=0) for column in columns], dim=1)

    @staticmethod
    def toMatrixBatch(q):
        """
        Convert N x K x 4 quaternions to N x K x 3 x 3 rotation matrices
        """
        columns = Quaternion._rotation_columns(q)
        return th.stack([th.stack(column, dim=2) for column in columns], dim=3)

    @staticmethod
    def toMatrixBatchDim1(q):
        """
        Convert N x 4 quaternions to N x 3 x 3 rotation matrices
        """
        columns = Quaternion._rotation_columns(q)
        return th.stack([th.stack(column, dim=1) for column in columns], dim=2)

    @staticmethod
    def batchMul(q, r):
        """
        Multiply two batches of quaternions

        Args:
            q: N x K x 4 quaternions
            r: N x K x 4 quaternions

        Returns:
            N x K x 4 multiplied quaternions
        """
        # (signs applied to q, permutation applied to r) per output component
        terms = (
            ((1.0, 1.0, -1.0, 1.0), [3, 2, 1, 0]),
            ((-1.0, 1.0, 1.0, 1.0), [2, 3, 0, 1]),
            ((1.0, -1.0, 1.0, 1.0), [1, 0, 3, 2]),
            ((-1.0, -1.0, -1.0, 1.0), [0, 1, 2, 3]),
        )
        components = []
        for signs, order in terms:
            signed_q = th.mul(q, th.tensor([[signs]], dtype=q.dtype, device=q.device))
            components.append(th.sum(th.mul(signed_q, r[:, :, order]), dim=-1))
        return th.stack(components, dim=2)

    @staticmethod
    def batchRot(q, v):
        """
        Rotate 3d-vector v given with quaternion q

        Args:
            q: N x K x 4 quaternions
            v: N x K x 3 vectors

        Returns:
            N x K x 3 rotated vectors
        """
        av = th.cross(q[:, :, :3], v, dim=2)
        aav = th.cross(q[:, :, :3], av, dim=2)
        return th.add(v, 2 * th.add(th.mul(av, q[:, :, 3].unsqueeze(2)), aav))

    @staticmethod
    def batchInvert(q):
        """
        Get the inverse of quaternion q

        Args:
            q: N x K x 4 quaternions

        Returns:
            N x K x 4 inverted quaternions
        """
        return (
            q
            * th.tensor([-1.0, -1.0, -1.0, 1.0], dtype=q.dtype, device=q.device)
            * (th.reciprocal(th.sum(q * q, dim=2).unsqueeze(2)))
        )

    @staticmethod
    def batchFromXYZ(r):
        """
        Generate a quaternion representing a rotation defined by a XYZ-Euler
        rotation.

        Args:
            r: N x K x 3 rotation vectors

        Returns:
            N x K x 4 quaternions
        """
        rm = r * th.tensor([[[-0.5, 0.5, 0.5]]], dtype=r.dtype, device=r.device)
        rc = th.cos(rm)
        rs = th.sin(rm)
        rc0, rc1, rc2 = rc[:, :, 0], rc[:, :, 1], rc[:, :, 2]
        rs0, rs1, rs2 = rs[:, :, 0], rs[:, :, 1], rs[:, :, 2]

        return th.stack(
            [
                (-rs0) * (rc1 * rc2) - rc0 * (rs1 * rs2),
                rc0 * (rs1 * rc2) - rs0 * (rc1 * rs2),
                rc0 * (rc1 * rs2) + rs0 * (rs1 * rc2),
                rc0 * (rc1 * rc2) - rs0 * (rs1 * rs2),
            ],
            dim=2,
        )

    @staticmethod
    def batchMatrixFromXYZ(r):
        """
        Generate a matrix representing a rotation defined by a XYZ-Euler
        rotation.

        Args:
            r: N x 3 rotation vectors

        Returns:
            N x 3 x 3 rotation matrices
        """
        rc = th.cos(r)
        rs = th.sin(r)
        cx = rc[:, 0]
        cy = rc[:, 1]
        cz = rc[:, 2]
        sx = rs[:, 0]
        sy = rs[:, 1]
        sz = rs[:, 2]

        result = th.stack(
            (
                cy * cz,
                -cx * sz + sx * sy * cz,
                sx * sz + cx * sy * cz,
                cy * sz,
                cx * cz + sx * sy * sz,
                -sx * cz + cx * sy * sz,
                -sy,
                sx * cy,
                cx * cy,
            ),
            dim=1,
        ).view(-1, 3, 3)
        return result

    @staticmethod
    def batchQuatFromMatrix(m):
        """
        :param m: B*3*3
        :return: B*4, order xyzw
        """
        assert len(m.shape) == 3
        b, j, k = m.shape
        assert j == 3
        assert k == 3
        result = th.zeros((b, 4), dtype=m.dtype, device=m.device)

        m00 = m[:, 0, 0]
        m01 = m[:, 0, 1]
        m02 = m[:, 0, 2]
        m10 = m[:, 1, 0]
        m11 = m[:, 1, 1]
        m12 = m[:, 1, 2]
        m20 = m[:, 2, 0]
        m21 = m[:, 2, 1]
        m22 = m[:, 2, 2]

        tr = m00 + m11 + m22

        # The four cases must be mutually exclusive: each row is handled by
        # the first case that applies. Chaining ``~flag`` only excluded the
        # previous case, so later cases overwrote earlier results and
        # divided by zero (e.g. NaN for the identity matrix).
        use_w = tr > 0
        use_x = ~use_w & (m00 > m11) & (m00 > m22)
        use_y = ~use_w & ~use_x & (m11 > m22)
        use_z = ~(use_w | use_x | use_y)

        flag = use_w
        s = 2 * th.sqrt(1 + tr[flag])
        result[flag, 0] = (m21 - m12)[flag] / s
        result[flag, 1] = (m02 - m20)[flag] / s
        result[flag, 2] = (m10 - m01)[flag] / s
        result[flag, 3] = 0.25 * s

        flag = use_x
        s = 2 * th.sqrt(1.0 + m00[flag] - m11[flag] - m22[flag])
        result[flag, 0] = 0.25 * s
        result[flag, 1] = (m01 + m10)[flag] / s
        result[flag, 2] = (m02 + m20)[flag] / s
        result[flag, 3] = (m21 - m12)[flag] / s

        flag = use_y
        s = 2 * th.sqrt(1.0 + m11[flag] - m00[flag] - m22[flag])
        result[flag, 0] = (m01 + m10)[flag] / s
        result[flag, 1] = 0.25 * s
        result[flag, 2] = (m12 + m21)[flag] / s
        result[flag, 3] = (m02 - m20)[flag] / s

        flag = use_z
        s = 2 * th.sqrt(1.0 + m22[flag] - m00[flag] - m11[flag])
        result[flag, 0] = (m02 + m20)[flag] / s
        result[flag, 1] = (m12 + m21)[flag] / s
        result[flag, 2] = 0.25 * s
        result[flag, 3] = (m10 - m01)[flag] / s

        return result


def _rodrigues_matrix(rn, cosn, sinn, eye):
    """Assemble ``cos * I + (1 - cos) * rn rn^T + sin * K`` per batch element.

    ``K`` is the antisymmetric matrix ``[[0, z, -y], [-z, 0, x], [y, -x, 0]]``
    of the axis ``rn = (x, y, z)``.

    Args:
        rn: B x 3 x 1 rotation axes
        cosn: B x 1 x 1 cosines of the rotation angle
        sinn: B x 1 x 1 sines of the rotation angle
        eye: 3 x 3 identity matrix

    Returns:
        B x 3 x 3 matrices
    """
    x = rn[:, 0, 0]
    y = rn[:, 1, 0]
    z = rn[:, 2, 0]
    zero = th.zeros_like(x)
    skew = th.stack(
        (
            th.stack((zero, z, -y), dim=1),
            th.stack((-z, zero, x), dim=1),
            th.stack((y, -x, zero), dim=1),
        ),
        dim=1,
    )
    R = cosn * eye.unsqueeze(0) + (1.0 - cosn) * rn.bmm(rn.permute(0, 2, 1))
    # Built in one step instead of patching entries in place: the in-place
    # version derived R[:, 1, 0] from the already-updated R[:, 0, 1], which
    # cancelled the sine term for that entry.
    return R + sinn * skew


class RodriguesVecBatch(nn.Module):
    """Computes, per batch element, the matrix aligning one unit vector to another."""

    def __init__(self):
        super().__init__()
        self.register_buffer("eye", th.eye(3))
        self.register_buffer("zero", th.zeros(1))

    def forward(self, v0, v1):
        """
        Args:
            v0: B x 3 vectors, assumed to be already normalized
            v1: B x 3 vectors, assumed to be already normalized

        Returns:
            B x 3 x 3 matrices aligning v0 to v1
        """
        cosn = (v0 * v1).sum(dim=1, keepdim=True).unsqueeze(2)
        r = v1.cross(v0, dim=1)
        sinn = r.pow(2).sum(1, keepdim=True).sqrt().unsqueeze(2)
        # epsilon avoids a division by zero for parallel vectors
        rn = r.unsqueeze(2) / (sinn + 1e-10)
        return _rodrigues_matrix(rn, cosn, sinn, self.eye)


class RodriguesBatch(nn.Module):
    """Converts a batch of axis-angle vectors into 3 x 3 matrices."""

    def __init__(self):
        super().__init__()
        self.register_buffer("eye", th.eye(3))
        self.register_buffer("zero", th.zeros(1))

    def forward(self, r):
        """
        Args:
            r: B x 3 axis-angle vectors (direction is the axis, norm the angle)

        Returns:
            B x 3 x 3 matrices
        """
        # epsilon keeps the norm differentiable and non-zero for r == 0
        n = ((r * r).sum(dim=1, keepdim=True) + 1e-10).sqrt()
        rn = th.div(r, n).unsqueeze(2)

        cosn = th.cos(n).unsqueeze(2)
        sinn = th.sin(n).unsqueeze(2)
        return _rodrigues_matrix(rn, cosn, sinn, self.eye)
geometry_in = index_selection_nd( + # t_georecon.view(t_georecon.size(0), t_georecon.size(1), -1), + # self.tensor_patch_geoindicemap1d, + # 2, + # ).permute(0, 2, 3, 4, 1) + + geometry_in = th.index_select( + t_georecon.view(t_georecon.size(0), t_georecon.size(1), -1), + self.tensor_patch_geoindicemap1d, + 2, + ).permute(0, 2, 3, 4, 1) + + normal = (geometry_in[..., 0, :] - geometry_in[..., 4, :]).cross( + geometry_in[..., 1, :] - geometry_in[..., 4, :], dim=3 + ) + normal = normal + (geometry_in[..., 1, :] - geometry_in[..., 4, :]).cross( + geometry_in[..., 2, :] - geometry_in[..., 4, :], dim=3 + ) + normal = normal + (geometry_in[..., 2, :] - geometry_in[..., 4, :]).cross( + geometry_in[..., 3, :] - geometry_in[..., 4, :], dim=3 + ) + normal = normal + (geometry_in[..., 3, :] - geometry_in[..., 4, :]).cross( + geometry_in[..., 0, :] - geometry_in[..., 4, :], dim=3 + ) + normal = normal / th.clamp(normal.pow(2).sum(3, keepdim=True).sqrt(), min=1e-6) + return normal.permute(0, 3, 1, 2) + + +def pointcloud_rigid_registration(src_pointcloud, dst_pointcloud, reduce_loss: bool = True): + """ + Calculate RT and residual L2 loss for two pointclouds + :param src_pointcloud: x (b, v, 3) + :param dst_pointcloud: y (b, v, 3) + :return: loss, R, t s.t. ||Rx+t-y||_2^2 minimal. 
+ """ + if len(src_pointcloud.shape) == 2: + src_pointcloud = src_pointcloud.unsqueeze(0) + if len(dst_pointcloud.shape) == 2: + dst_pointcloud = dst_pointcloud.unsqueeze(0) + bn = src_pointcloud.shape[0] + + assert src_pointcloud.shape == dst_pointcloud.shape + assert src_pointcloud.shape[2] == 3 + + X = src_pointcloud - src_pointcloud.mean(dim=1, keepdim=True) + Y = dst_pointcloud - dst_pointcloud.mean(dim=1, keepdim=True) + + XYT = th.einsum("nji,njk->nik", X, Y) + muX = src_pointcloud.mean(dim=1) + muY = dst_pointcloud.mean(dim=1) + + R = th.zeros((bn, 3, 3), dtype=src_pointcloud.dtype).to(src_pointcloud.device) + t = th.zeros((bn, 1, 3), dtype=src_pointcloud.dtype).to(src_pointcloud.device) + loss = th.zeros((bn,), dtype=src_pointcloud.dtype).to(src_pointcloud.device) + + for i in range(bn): + u_, s_, v_ = th.svd(XYT[i, :, :]) + detvut = th.det(v_.mm(u_.t())) + diag_m = th.ones_like(s_) + diag_m[-1] = detvut + + r_ = v_.mm(th.diag(diag_m)).mm(u_.t()) + t_ = muY[i, :] - r_.mm(muX[i, :, None])[:, 0] + + R[i, :, :] = r_ + t[i, 0, :] = t_ + loss[i] = (th.einsum("ij,nj->ni", r_, X[i]) - Y[i]).pow(2).sum(1).mean(0) + + loss = loss.mean(0) if reduce_loss else loss + return loss, R, t + + +def pointcloud_rigid_registration_balanced(src_pointcloud, dst_pointcloud, weight): + """ + Calculate RT and residual L2 loss for two pointclouds + :param src_pointcloud: x (b, v, 3) + :param dst_pointcloud: y (b, v, 3) + :param weight: (v, ), duplication of vertices + :return: loss, R, t s.t. ||w(Rx+t-y)||_2^2 minimal. 
+ """ + if len(src_pointcloud.shape) == 2: + src_pointcloud = src_pointcloud.unsqueeze(0) + if len(dst_pointcloud.shape) == 2: + dst_pointcloud = dst_pointcloud.unsqueeze(0) + bn = src_pointcloud.shape[0] + + assert src_pointcloud.shape == dst_pointcloud.shape + assert src_pointcloud.shape[2] == 3 + assert src_pointcloud.shape[1] == weight.shape[0] + assert len(weight.shape) == 1 + w = weight[None, :, None] + + def s1(a): + return a.sum(dim=1, keepdim=True) + + w2 = w.pow(2) + sw2 = s1(w2) + X = src_pointcloud + Y = dst_pointcloud + + wXYT = th.einsum("nji,njk->nik", w2 * (sw2 - w2) * X, Y) + U, s, V = batch_svd(wXYT) + UT = U.permute(0, 2, 1).contiguous() + det = batch_det(V.bmm(UT)) + diag = th.ones_like(s).to(s.device) + diag[:, -1] = det + + R = V.bmm(batch_diag(diag)).bmm(UT) + RX = th.einsum("bij,bnj->bni", R, X) + t = th.sum(w * (Y - RX), dim=1, keepdim=True) / sw2 + loss = w * (RX + t - Y) + loss = F.mse_loss(loss, th.zeros_like(loss)) * 3 + + return loss, R, t + + +def batch_dot(x, y): + assert x.shape == y.shape + assert len(x.shape) == 2 + return th.einsum("ni,ni->n", x, y) + + +def batch_svd(x): + assert len(x.shape) == 3 + bn, m, n = x.shape + U = th.zeros((bn, m, m), dtype=th.float32).to(x.device) + s = th.zeros((bn, min(n, m)), dtype=th.float32).to(x.device) + V = th.zeros((bn, n, n), dtype=th.float32).to(x.device) + for i in range(bn): + u_, s_, v_ = th.svd(x[i, :, :]) + U[i] = u_ + s[i] = s_ + V[i] = v_ + return U, s, V + + +def batch_diag(x): + if len(x.shape) == 2: + bn, n = x.shape + res = th.zeros((bn, n, n), dtype=th.float32).to(x.device) + res[:, range(n), range(n)] = x + return res + elif len(x.shape) == 3: + assert x.shape[1] == x.shape[2] + n = x.shape[1] + return x[:, range(n), range(n)] + else: + raise ValueError("dim of batch_diag should be 2 or 3") + + +def batch_det(x): + assert len(x.shape) == 3 + assert x.shape[1] == x.shape[2] + bn, _, _ = x.shape + res = th.zeros((bn,), dtype=th.float32).to(x.device) + for i in range(bn): + res[i] 
= th.det(x[i]) + return res diff --git a/visualize/ca_body/utils/render.py b/visualize/ca_body/utils/render.py new file mode 100644 index 0000000000000000000000000000000000000000..0a89e1563e36b9d5adb5216b307eb80b26d84901 --- /dev/null +++ b/visualize/ca_body/utils/render.py @@ -0,0 +1,65 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import List, Dict +import torch as th +import torch.nn as nn + +from pytorch3d.renderer import ( + RasterizationSettings, + MeshRasterizer, +) + +from pytorch3d.structures import Meshes +from pytorch3d.renderer.mesh.textures import TexturesUV +from pytorch3d.utils import cameras_from_opencv_projection + +class RenderLayer(nn.Module): + + def __init__(self, h, w, vi, vt, vti, flip_uvs=False): + super().__init__() + self.register_buffer("vi", vi, persistent=False) + self.register_buffer("vt", vt, persistent=False) + self.register_buffer("vti", vti, persistent=False) + raster_settings = RasterizationSettings(image_size=(h, w)) + self.rasterizer = MeshRasterizer(raster_settings=raster_settings) + self.flip_uvs = flip_uvs + image_size = th.as_tensor([h, w], dtype=th.int32) + self.register_buffer("image_size", image_size) + + def forward(self, verts: th.Tensor, tex: th.Tensor, K: th.Tensor, Rt: th.Tensor, background: th.Tensor = None, output_filters: List[str] = None): + + assert output_filters is None + assert background is None + + device = verts.device # Get device info + B = verts.shape[0] + + image_size = th.repeat_interleave(self.image_size[None], B, dim=0).to(device) + + cameras = cameras_from_opencv_projection(Rt[:,:,:3], Rt[:,:3,3], K, image_size) + + faces = self.vi[None].repeat(B, 1, 1).to(device) + faces_uvs = self.vti[None].repeat(B, 1, 1).to(device) + verts_uvs = self.vt[None].repeat(B, 1, 1).to(device) + + # In-place operation for flipping and permuting 
tensor + if not self.flip_uvs: + tex = tex.permute(0, 2, 3, 1).flip((1,)).to(device) + + textures = TexturesUV( + maps=tex, + faces_uvs=faces_uvs, + verts_uvs=verts_uvs, + ) + meshes = Meshes(verts.to(device), faces, textures=textures) + + fragments = self.rasterizer(meshes, cameras=cameras) + rgb = meshes.sample_textures(fragments)[:,:,:,0] + rgb[fragments.pix_to_face[...,0] == -1] = 0.0 + + return {'render': rgb.permute(0, 3, 1, 2)} \ No newline at end of file diff --git a/visualize/ca_body/utils/seams.py b/visualize/ca_body/utils/seams.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d22890c76f4f3b8a84d8cdb9d7f1c33c543692 --- /dev/null +++ b/visualize/ca_body/utils/seams.py @@ -0,0 +1,52 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Any, Dict + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + + +def impaint_batch(value: th.Tensor, dst_ij: th.Tensor, src_ij: th.Tensor) -> th.Tensor: + assert len(value.shape) == 4, "expecting a 4D tensor" + preds = value[:] + preds[:, :, dst_ij[:, 0], dst_ij[:, 1]] = value[:, :, src_ij[:, 0], src_ij[:, 1]] + return preds + + +def resample_tex(tex: th.Tensor, uvs: th.Tensor, weights: th.Tensor) -> th.Tensor: + B = tex.shape[0] + grid = 2.0 * (uvs[np.newaxis].expand(B, -1, -1, -1) - 0.5) + tex_resampled = F.grid_sample(tex, grid, align_corners=False, padding_mode="border") + return (1.0 - weights) * tex + weights * tex_resampled + + +class SeamSampler(nn.Module): + def __init__(self, seamless_data: Dict[str, Any]) -> None: + super().__init__() + + self.register_buffer("dst_ij", seamless_data["dst_ij"]) + self.register_buffer("src_ij", seamless_data["src_ij"]) + self.register_buffer("uvs", seamless_data["uvs"]) + self.register_buffer("weights", seamless_data["weights"]) + + def 
impaint(self, value: th.Tensor) -> th.Tensor: + return impaint_batch(value, self.dst_ij, self.src_ij) + + def resample(self, tex: th.Tensor) -> th.Tensor: + return resample_tex(tex, self.uvs, self.weights) + + def resample_border_only(self, tex: th.Tensor) -> th.Tensor: + tex = resample_tex(tex, self.uvs, self.weights) + return tex + + def forward(self, tex: th.Tensor) -> th.Tensor: + x = self.impaint(tex) + x = self.resample(x) + return x diff --git a/visualize/ca_body/utils/torch.py b/visualize/ca_body/utils/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2e15aacfd190bc6a977b5180bcb217d628eba4 --- /dev/null +++ b/visualize/ca_body/utils/torch.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Optional, Tuple, Sequence, TypeVar, Union, Mapping, Any, List, Dict + +import torch as th +import numpy as np + +TensorOrContainer = Union[ + th.Tensor, str, int, Sequence["TensorOrContainer"], Mapping[str, "TensorOrContainer"] +] +NdarrayOrContainer = Union[ + np.ndarray, + str, + int, + Sequence["NdarrayOrContainer"], + Mapping[str, "NdarrayOrContainer"], +] +TensorNdarrayOrContainer = Union[ + th.Tensor, + np.ndarray, + str, + int, + Sequence["TensorNdarrayOrContainer"], + Mapping[str, "TensorNdarrayOrContainer"], +] +TensorNdarrayModuleOrContainer = Union[ + th.Tensor, + np.ndarray, + th.nn.Module, + str, + int, + Sequence["TensorNdarrayModuleOrContainer"], + Mapping[str, "TensorNdarrayModuleOrContainer"], +] +TTensorOrContainer = TypeVar("TTensorOrContainer", bound=TensorOrContainer) +TNdarrayOrContainer = TypeVar("TNdarrayOrContainer", bound=NdarrayOrContainer) +TTensorNdarrayOrContainer = TypeVar("TTensorNdarrayOrContainer", bound=TensorNdarrayOrContainer) +TTensorNdarrayModuleOrContainer = TypeVar( + "TTensorNdarrayModuleOrContainer", 
bound=TensorNdarrayModuleOrContainer +) + + +import torch as th + +import logging + +logger = logging.getLogger(__name__) + + +class ParamHolder(th.nn.Module): + def __init__( + self, + param_shape: Tuple[int, ...], + key_list: Sequence[str], + init_value: Union[None, bool, float, int, th.Tensor] = None, + ) -> None: + super().__init__() + + if isinstance(param_shape, int): + param_shape = (param_shape,) + self.key_list: Sequence[str] = sorted(key_list) + shp = (len(self.key_list),) + param_shape + self.params = th.nn.Parameter(th.zeros(*shp)) + + if init_value is not None: + self.params.data[:] = init_value + + def state_dict(self, *args: Any, saving: bool = False, **kwargs: Any) -> Dict[str, Any]: + sd = super().state_dict(*args, **kwargs) + if saving: + assert "key_list" not in sd + sd["key_list"] = self.key_list + return sd + + # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` + # inconsistently. + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any + ) -> th.nn.modules.module._IncompatibleKeys: + # Note: Mapping is immutable while Dict is mutable. According to pyre ErrorCode[14], + # the type of state_dict must be Mapping or supertype of Mapping to keep consistent + # with the overrided function in its superclass. 
+ sd = dict(state_dict) + if "key_list" not in sd: + logger.warning("Missing key list list in state dict, only checking params shape.") + assert sd["params"].shape == self.params.shape + sd["key_list"] = self.key_list + + matching_kl = sd["key_list"] == self.key_list + if strict: + logger.warning("Attempting to load from mismatched key lists.") + assert sd["params"].shape[1:] == self.params.shape[1:] + + if not matching_kl: + src_kl = sd["key_list"] + new_kl = sorted(set(self.key_list) | set(src_kl)) + new_shp = (len(new_kl),) + tuple(self.params.shape[1:]) + new_params = th.zeros(*new_shp, device=self.params.device) + for f in self.key_list: + new_params[new_kl.index(f)] = self.params[self.key_list.index(f)] + upd = 0 + new = 0 + for f in src_kl: + new_params[new_kl.index(f)] = sd["params"][src_kl.index(f)] + if f in self.key_list: + upd += 1 + else: + new += 1 + logger.info( + f"Updated {upd} keys ({100*upd/len(self.key_list):0.2f}%), added {new} new keys." + ) + + self.key_list = new_kl + sd["params"] = new_params + self.params = th.nn.Parameter(new_params) + del sd["key_list"] + return super().load_state_dict(sd, strict=strict, **kwargs) + + def to_idx(self, *args: Any) -> th.Tensor: + if len(args) == 1: + keys = args[0] + else: + keys = zip(*args) + + return th.tensor( + [self.key_list.index(k) for k in keys], + dtype=th.long, + device=self.params.device, + ) + + def from_idx(self, idxs: th.Tensor) -> List[str]: + return [self.key_list[idx] for idx in idxs] + + def forward(self, idxs: th.Tensor) -> th.Tensor: + return self.params[idxs] + + + +def to_device( + things: TTensorNdarrayModuleOrContainer, + device: th.device, + cache: Optional[Dict[str, th.Tensor]] = None, + key: Optional[str] = None, + verbose: bool = False, + max_bs: Optional[int] = None, + non_blocking: bool = False, +) -> TTensorNdarrayModuleOrContainer: + """Sends a potentially nested container of Tensors to the specified + device. Non-tensors are preserved as-is. 
+ + Args: + things: Container with tensors or other containers of tensors to send + to a GPU. + + device: Device to send the tensors to. + + cache: Optional dictionary to use as a cache for CUDAfied tensors. If + passed, use this cache to allocate a tensor once and then resize / + refill it on future calls to to_device() instead of reallocating + it. + + key: If using the cache, store the tensor in this key, only for + internal use. + + verbose: Print some info when a cached tensor is resized. + + max_bs: Maximum batch size allowed for tensors in cache + + non_blocking: if True and this copy is between CPU and GPU, the copy + may occur asynchronously with respect to the host. For other cases, + this argument has no effect. + + Returns: + collection: The input collection with all tensors transferred to the given device. + """ + device = th.device(device) + + pr = print if verbose else lambda *args, **kwargs: None + + if isinstance(things, th.Tensor) and things.device != device: + if cache is not None: + assert key is not None + batch_size = things.shape[0] + if key in cache: + assert things.shape[1:] == cache[key].shape[1:] + if batch_size > cache[key].shape[0]: + pr("Resized:", key, "from", cache[key].shape[0], "to", batch_size) + cache[key].resize_as_(things) + else: + buf_shape = list(things.shape) + if max_bs is not None: + assert max_bs >= batch_size + buf_shape[0] = max_bs + cache[key] = th.zeros(*buf_shape, dtype=things.dtype, device=device) + pr("Allocated:", key, buf_shape) + cache[key][:batch_size].copy_(things, non_blocking=non_blocking) + + return cache[key][:batch_size] + else: + return things.to(device, non_blocking=non_blocking) + elif isinstance(things, th.nn.Module): + return things.to(device, non_blocking=non_blocking) + elif isinstance(things, dict): + key = key + "." 
if key is not None else "" + return { + k: to_device(v, device, cache, key + k, verbose, max_bs, non_blocking) + for k, v in things.items() + } + elif isinstance(things, Sequence) and not isinstance(things, str): + key = key if key is not None else "" + out = [ + to_device(v, device, cache, key + f"_{i}", verbose, max_bs, non_blocking) + for i, v in enumerate(things) + ] + if isinstance(things, tuple): + out = tuple(out) + return out + elif isinstance(things, np.ndarray): + return to_device(th.from_numpy(things), device, cache, key, verbose, max_bs, non_blocking) + else: + return things + + + diff --git a/visualize/ca_body/utils/train.py b/visualize/ca_body/utils/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c7a5492ebee145037c9fb390182baf215defb4 --- /dev/null +++ b/visualize/ca_body/utils/train.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import torch as th +import os +import re +import glob +import copy +from typing import Dict, Any, Iterator, Mapping, Optional, Union, Tuple, List + + +from collections import OrderedDict +from torch.utils.tensorboard import SummaryWriter +from omegaconf import OmegaConf, DictConfig + +from torch.optim.lr_scheduler import LRScheduler + +from visualize.ca_body.utils.torch import to_device +from visualize.ca_body.utils.module_loader import load_class, build_optimizer + +import torch.nn as nn + +import logging + +logging.basicConfig( + format="[%(asctime)s][%(levelname)s][%(name)s]:%(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + +logger = logging.getLogger(__name__) + + +def process_losses( + loss_dict: Dict[str, Any], reduce: bool = True, detach: bool = True +) -> Dict[str, th.Tensor]: + """Preprocess the dict of losses outputs.""" + result = {k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")} + if detach: + result = {k: v.detach() for k, v in result.items()} + if reduce: + result = {k: float(v.mean().item()) for k, v in result.items()} + return result + + + +def load_config(path: str) -> DictConfig: + # NOTE: THIS IS THE ONLY PLACE WHERE WE MODIFY CONFIG + config = OmegaConf.load(path) + + # TODO: we should need to get rid of this in favor of DB + assert 'CARE_ROOT' in os.environ + config.CARE_ROOT = os.environ['CARE_ROOT'] + logger.info(f'{config.CARE_ROOT=}') + + if not os.path.isabs(config.train.run_dir): + config.train.run_dir = os.path.join(os.environ['CARE_ROOT'], config.train.run_dir) + logger.info(f'{config.train.run_dir=}') + os.makedirs(config.train.run_dir, exist_ok=True) + return config + + +def load_from_config(config: Mapping[str, Any], **kwargs): + """Instantiate an object given a config and arguments.""" + assert 'class_name' in config and 'module_name' not in config + config = copy.deepcopy(config) + ckpt = None if 'ckpt' not in config else config.pop('ckpt') + class_name = 
config.pop('class_name') + object_class = load_class(class_name) + instance = object_class(**config, **kwargs) + if ckpt is not None: + load_checkpoint( + ckpt_path=ckpt.path, + modules={ckpt.get('module_name', 'model'): instance}, + ignore_names=ckpt.get('ignore_names', []), + strict=ckpt.get('strict', False), + ) + return instance + + +def save_checkpoint(ckpt_path, modules: Dict[str, Any], iteration=None, keep_last_k=None): + if keep_last_k is not None: + raise NotImplementedError() + ckpt_dict = {} + if os.path.isdir(ckpt_path): + assert iteration is not None + ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt") + ckpt_dict["iteration"] = iteration + for name, mod in modules.items(): + if hasattr(mod, "module"): + mod = mod.module + ckpt_dict[name] = mod.state_dict() + th.save(ckpt_dict, ckpt_path) + + +def filter_params(params, ignore_names): + return OrderedDict( + [ + (k, v) + for k, v in params.items() + if not any([re.match(n, k) is not None for n in ignore_names]) + ] + ) + + +def save_file_summaries(path: str, summaries: Dict[str, Tuple[str, Any]]): + """Saving regular summaries for monitoring purposes.""" + for name, (value, ext) in summaries.items(): + #save(f'{path}/{name}.{ext}', value) + raise NotImplementedError() + + +def load_checkpoint( + ckpt_path: str, + modules: Dict[str, Any], + iteration: int =None, + strict: bool =False, + map_location: Optional[str] =None, + ignore_names: Optional[Dict[str, List[str]]]=None, +): + """Load a checkpoint. 
+ Args: + ckpt_path: directory or the full path to the checkpoint + """ + if map_location is None: + map_location = "cpu" + # adding + if os.path.isdir(ckpt_path): + if iteration is None: + # lookup latest iteration + iteration = max( + [ + int(os.path.splitext(os.path.basename(p))[0]) + for p in glob.glob(os.path.join(ckpt_path, "*.pt")) + ] + ) + ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt") + logger.info(f"loading checkpoint {ckpt_path}") + ckpt_dict = th.load(ckpt_path, map_location=map_location) + for name, mod in modules.items(): + params = ckpt_dict[name] + if ignore_names is not None and name in ignore_names: + logger.info(f"skipping: {ignore_names[name]}") + params = filter_params(params, ignore_names[name]) + mod.load_state_dict(params, strict=strict) + + +def train( + model: nn.Module, + loss_fn: nn.Module, + optimizer: th.optim.Optimizer, + train_data: Iterator, + config: Mapping[str, Any], + lr_scheduler: Optional[LRScheduler] = None, + train_writer: Optional[SummaryWriter] = None, + saving_enabled: bool = True, + logging_enabled: bool = True, + iteration: int = 0, + device: Optional[Union[th.device, str]] = "cuda:0", +) -> None: + + for batch in train_data: + if batch is None: + logger.info("skipping empty batch") + continue + batch = to_device(batch, device) + batch["iteration"] = iteration + + # leaving only inputs acutally used by the model + preds = model(**filter_inputs(batch, model, required_only=False)) + + # TODO: switch to the old-school loss computation + loss, loss_dict = loss_fn(preds, batch, iteration=iteration) + assert not th.isnan(loss), "loss is NaN" + + if th.isnan(loss): + _loss_dict = process_losses(loss_dict) + loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()]) + logger.info(f"iter={iteration}: {loss_str}") + raise ValueError("loss is NaN") + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if logging_enabled and iteration % config.train.log_every_n_steps == 0: + _loss_dict = 
process_losses(loss_dict) + loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()]) + logger.info(f"iter={iteration}: {loss_str}") + + if logging_enabled and train_writer and iteration % config.train.log_every_n_steps == 0: + for name, value in _loss_dict.items(): + train_writer.add_scalar(f"Losses/{name}", value, global_step=iteration) + train_writer.flush() + + if saving_enabled and iteration % config.train.ckpt_every_n_steps == 0: + logger.info(f"iter={iteration}: saving checkpoint to `{config.train.ckpt_dir}`") + save_checkpoint( + config.train.ckpt_dir, + {"model": model, "optimizer": optimizer}, + iteration=iteration, + ) + + if logging_enabled and iteration % config.train.summary_every_n_steps == 0: + summaries = model.compute_summaries(preds, batch) + save_file_summaries(config.train.run_dir, summaries, prefix="train") + + if lr_scheduler is not None and iteration and iteration % config.train.update_lr_every == 0: + lr_scheduler.step() + + iteration += 1 + if iteration >= config.train.n_max_iters: + logger.info(f"reached max number of iters ({config.train.n_max_iters})") + break + + if saving_enabled: + logger.info(f"saving the final checkpoint to `{config.train.run_dir}/model.pt`") + save_checkpoint(f"{config.train.run_dir}/model.pt", {"model": model}) + diff --git a/visualize/render_anno.py b/visualize/render_anno.py new file mode 100644 index 0000000000000000000000000000000000000000..5c02d19bc1f4fdf0a959d836194a2cf225ce9a95 --- /dev/null +++ b/visualize/render_anno.py @@ -0,0 +1,58 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import os + +import torch + +from data_loaders.get_data import load_local_data + +from tqdm import tqdm + +from utils.diff_parser_utils import train_args +from utils.misc import fixseed +from utils.model_util import get_person_num +from visualize.render_codes import BodyRenderer + + +def main(): + args = train_args() + fixseed(args.seed) + args.num_repetitions = 1 + config_base = f"./checkpoints/ca_body/data/{get_person_num(args.data_root)}" + body_renderer = BodyRenderer( + config_base=config_base, + render_rgb=True, + ).to(args.device) + data_root = args.data_root + data_dict = load_local_data(data_root, audio_per_frame=1600) + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir, exist_ok=True) + + for i in range(len(data_dict["data"])): + end_range = len(data_dict["data"][i]) - args.max_seq_length + for chunk_idx in tqdm(range(0, end_range, args.max_seq_length)): + chunk_end = chunk_idx + args.max_seq_length + curr_data_chunk = data_dict["data"][i][chunk_idx:chunk_end, :] + curr_face_chunk = data_dict["face"][i][chunk_idx:chunk_end, :] + curr_audio_chunk = data_dict["audio"][i][ + chunk_idx * 1600 : chunk_end * 1600, : + ].T + render_data_block = { + "audio": curr_audio_chunk, # 2 x T + "body_motion": curr_data_chunk, # T x 104 + "face_motion": curr_face_chunk, # T x 256 + } + body_renderer.render_full_video( + render_data_block, + f"{args.save_dir}/scene{i}_{chunk_idx:04d}.mp4", + audio_sr=48_000, + ) + + +if __name__ == "__main__": + main() diff --git a/visualize/render_codes.py b/visualize/render_codes.py new file mode 100644 index 0000000000000000000000000000000000000000..16f9cd1c6a4c31bfe4f3ed63379d5d450b3d02cd --- /dev/null +++ b/visualize/render_codes.py @@ -0,0 +1,163 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import copy +import glob +import os +import re +import subprocess +from collections import OrderedDict +from typing import Dict, List + +import mediapy + +import numpy as np + +import torch +import torch as th +import torchaudio +from attrdict import AttrDict + +from omegaconf import OmegaConf +from tqdm import tqdm +from utils.model_util import get_person_num +from visualize.ca_body.utils.image import linear2displayBatch +from visualize.ca_body.utils.train import load_checkpoint, load_from_config + +ffmpeg_header = "ffmpeg -y " # -hide_banner -loglevel error " + + +def filter_params(params, ignore_names): + return OrderedDict( + [ + (k, v) + for k, v in params.items() + if not any([re.match(n, k) is not None for n in ignore_names]) + ] + ) + + +def call_ffmpeg(command: str) -> None: + print(command, "-" * 100) + e = subprocess.call(command, shell=True) + if e != 0: + assert False, e + + +class BodyRenderer(th.nn.Module): + def __init__( + self, + config_base: str, + render_rgb: bool, + ): + super().__init__() + self.config_base = config_base + ckpt_path = f"{config_base}/body_dec.ckpt" + config_path = f"{config_base}/config.yml" + assets_path = f"{config_base}/static_assets.pt" + # config + config = OmegaConf.load(config_path) + gpu = config.get("gpu", 0) + self.device = th.device(f"cuda:{gpu}") + # assets + static_assets = AttrDict(torch.load(assets_path)) + # build model + self.model = load_from_config(config.model, assets=static_assets).to( + self.device + ) + self.model.cal_enabled = False + self.model.pixel_cal_enabled = False + self.model.learn_blur_enabled = False + self.render_rgb = render_rgb + if not self.render_rgb: + self.model.rendering_enabled = None + # load model checkpoints + print("loading...", ckpt_path) + load_checkpoint( + ckpt_path, + modules={"model": self.model}, + ignore_names={"model": ["lbs_fn.*"]}, + ) + self.model.eval() + self.model.to(self.device) + # load default parameters for renderer + person = get_person_num(config_path) 
+ self.default_inputs = th.load(f"assets/render_defaults_{person}.pth") + + def _write_video_stream( + self, motion: np.ndarray, face: np.ndarray, save_name: str + ) -> None: + out = self._render_loop(motion, face) + mediapy.write_video(save_name, out, fps=30) + + def _render_loop(self, body_pose: np.ndarray, face: np.ndarray) -> List[np.ndarray]: + all_rgb = [] + default_inputs_copy = copy.deepcopy(self.default_inputs) + for b in tqdm(range(len(body_pose))): + B = default_inputs_copy["K"].shape[0] + default_inputs_copy["lbs_motion"] = ( + th.tensor(body_pose[b : b + 1, :], device=self.device, dtype=th.float) + .tile(B, 1) + .to(self.device) + ) + geom = ( + self.model.lbs_fn.lbs_fn( + default_inputs_copy["lbs_motion"], + self.model.lbs_fn.lbs_scale.unsqueeze(0).tile(B, 1), + self.model.lbs_fn.lbs_template_verts.unsqueeze(0).tile(B, 1, 1), + ) + * self.model.lbs_fn.global_scaling + ) + default_inputs_copy["geom"] = geom + face_codes = ( + th.from_numpy(face).float().cuda() if not th.is_tensor(face) else face + ) + curr_face = th.tile(face_codes[b : b + 1, ...], (2, 1)) + default_inputs_copy["face_embs"] = curr_face + preds = self.model(**default_inputs_copy) + rgb0 = linear2displayBatch(preds["rgb"])[0] + rgb1 = linear2displayBatch(preds["rgb"])[1] + rgb = th.cat((rgb0, rgb1), axis=-1).permute(1, 2, 0) + rgb = rgb.clip(0, 255).to(th.uint8) + all_rgb.append(rgb.contiguous().detach().byte().cpu().numpy()) + return all_rgb + + def render_full_video( + self, + data_block: Dict[str, np.ndarray], + animation_save_path: str, + audio_sr: int = None, + render_gt: bool = False, + ) -> None: + tag = os.path.basename(os.path.dirname(animation_save_path)) + save_name = os.path.splitext(os.path.basename(animation_save_path))[0] + save_name = f"{tag}_{save_name}" + torchaudio.save( + f"/tmp/audio_{save_name}.wav", + torch.tensor(data_block["audio"]), + audio_sr, + ) + if render_gt: + tag = "gt" + self._write_video_stream( + data_block["gt_body"], + data_block["gt_face"], + 
f"/tmp/{tag}_{save_name}.mp4", + ) + else: + tag = "pred" + self._write_video_stream( + data_block["body_motion"], + data_block["face_motion"], + f"/tmp/{tag}_{save_name}.mp4", + ) + command = f"{ffmpeg_header} -i /tmp/{tag}_{save_name}.mp4 -i /tmp/audio_{save_name}.wav -c:v copy -map 0:v:0 -map 1:a:0 -c:a aac -b:a 192k -pix_fmt yuva420p {animation_save_path}_{tag}.mp4" + call_ffmpeg(command) + subprocess.call( + f"rm /tmp/audio_{save_name}.wav && rm /tmp/{tag}_{save_name}.mp4", + shell=True, + )