Spaces: Runtime error
Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- .gitignore +4 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE +400 -0
- README.md +373 -7
- assets/demo1.gif +0 -0
- assets/demo2.gif +3 -0
- assets/render_defaults_GQS883.pth +3 -0
- assets/render_defaults_PXB184.pth +3 -0
- assets/render_defaults_RLW104.pth +3 -0
- assets/render_defaults_TXB805.pth +3 -0
- checkpoints/ca_body/data/PXB184/body_dec.ckpt +3 -0
- checkpoints/ca_body/data/PXB184/config.yml +56 -0
- checkpoints/diffusion/c1_face/args.json +34 -0
- checkpoints/diffusion/c1_pose/args.json +66 -0
- checkpoints/guide/c1_pose/args.json +41 -0
- checkpoints/vq/c1_pose/args.json +43 -0
- checkpoints/vq/c1_pose/net_iter300000.pth +3 -0
- data_loaders/data.py +253 -0
- data_loaders/get_data.py +129 -0
- data_loaders/tensors.py +86 -0
- demo/.ipynb_checkpoints/demo-checkpoint.py +276 -0
- demo/demo.py +276 -0
- demo/install.sh +20 -0
- demo/requirements.txt +17 -0
- diffusion/fp16_util.py +250 -0
- diffusion/gaussian_diffusion.py +1273 -0
- diffusion/losses.py +83 -0
- diffusion/nn.py +213 -0
- diffusion/resample.py +168 -0
- diffusion/respace.py +145 -0
- flagged/audio/b90d90dbca93f47e8d01/audio.wav +0 -0
- flagged/audio/d8e03e2e6deae2f981b1/audio.wav +0 -0
- flagged/log.csv +4 -0
- model/cfg_sampler.py +33 -0
- model/diffusion.py +403 -0
- model/guide.py +222 -0
- model/modules/audio_encoder.py +194 -0
- model/modules/rotary_embedding_torch.py +139 -0
- model/modules/transformer_modules.py +702 -0
- model/utils.py +130 -0
- model/vqvae.py +550 -0
- sample/generate.py +316 -0
- scripts/download_alldatasets.sh +6 -0
- scripts/download_allmodels.sh +13 -0
- scripts/download_prereq.sh +9 -0
- scripts/installation.sh +4 -0
- scripts/requirements.txt +17 -0
- train/train_diffusion.py +83 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo2.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
*.pyc
*.pt
!dataset/*/data_stats.pth
dataset
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at <[email protected]>. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,31 @@
# Contributing to audio2photoreal

We want to make contributing to this project as easy and transparent as possible.

## Pull Requests

We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues

We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue.

## License

By contributing to audio2photoreal, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree.
LICENSE
ADDED
@@ -0,0 +1,400 @@
Attribution-NonCommercial 4.0 International

=======================================================================

This file contains the full, unmodified text of the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC BY-NC 4.0): the Creative Commons preamble and usage considerations, followed by Section 1 -- Definitions, Section 2 -- Scope, Section 3 -- License Conditions, Section 4 -- Sui Generis Database Rights, Section 5 -- Disclaimer of Warranties and Limitation of Liability, Section 6 -- Term and Termination, Section 7 -- Other Terms and Conditions, Section 8 -- Interpretation, and the closing Creative Commons notice.

Creative Commons may be contacted at creativecommons.org.
README.md
CHANGED
@@ -1,12 +1,378 @@
 ---
-title:
-
-colorFrom: gray
-colorTo: indigo
+title: test_virtual
+app_file: ./demo/demo.py
 sdk: gradio
 sdk_version: 4.38.1
-app_file: app.py
-pinned: false
 ---

# From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations
This repository contains a pytorch implementation of ["From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations"](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal/).

:hatching_chick: **Try out our demo [here](https://colab.research.google.com/drive/1lnX3d-3T3LaO3nlN6R8s6pPvVNAk5mdK?usp=sharing)** or continue following the steps below to run the code locally!
And thanks everyone for the support via contributions/comments/issues!

https://github.com/facebookresearch/audio2photoreal/assets/17986358/5cba4079-275e-48b6-aecc-f84f3108c810

This codebase provides:
- train code
- test code
- pretrained motion models
- access to the dataset

If you use the dataset or code, please cite our [paper](https://arxiv.org/abs/2401.01885):

```
@inproceedings{ng2024audio2photoreal,
  title={From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations},
  author={Ng, Evonne and Romero, Javier and Bagautdinov, Timur and Bai, Shaojie and Darrell, Trevor and Kanazawa, Angjoo and Richard, Alexander},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  year={2024}
}
```

### Repository Contents

- [**Quickstart:**](#quickstart) easy gradio demo that lets you record audio and render a video
- [**Installation:**](#installation) environment setup and installation (for more details on the rendering pipeline, please refer to [Codec Avatar Body](https://github.com/facebookresearch/ca_body))
- [**Download data and models:**](#download-data-and-models) download annotations and pre-trained models
  - [Dataset desc.](#dataset): description of dataset annotations
  - [Visualize dataset](#visualize-ground-truth): script for visualizing ground truth annotations
  - [Model desc.](#pretrained-models): description of pretrained models
- [**Running the pretrained models:**](#running-the-pretrained-models) how to generate results files and visualize the results using the rendering pipeline
  - [Face generation](#face-generation): commands to generate the results file for the faces
  - [Body generation](#body-generation): commands to generate the results file for the bodies
  - [Visualization](#visualization): how to call into the rendering api. For full details, please refer to [this repo](https://github.com/facebookresearch/ca_body).
- [**Training from scratch (4 models):**](#training-from-scratch) scripts to get the training pipeline running from scratch for the face, guide pose, and body models
  - [Face diffusion model](#1-face-diffusion-model)
  - [Body diffusion model](#2-body-diffusion-model)
  - [Body vq vae](#3-body-vq-vae)
  - [Body guide transformer](#4-body-guide-transformer)

We annotate code that you can directly copy and paste into your terminal with the :point_down: icon.

# Quickstart
With this demo, you can record an audio clip and select the number of samples you want to generate.

Make sure you have CUDA 11.7 and gcc/g++ 9.0 for pytorch3d compatibility.

:point_down: Install the necessary components. This configures the environment and installs the corresponding rendering assets, prerequisite models, and pretrained models:
```
conda create --name a2p_env python=3.9
conda activate a2p_env
sh demo/install.sh
```
:point_down: Run the demo. You can record your audio and then render the corresponding results!
```
python -m demo.demo
```

:microphone: First, record your audio.



:hourglass: Hold tight because the rendering can take a while!

You can change the number of samples (1-10) you want to generate, and download your favorite video by clicking the download button at the top right of each video.



# Installation
The code has been tested with CUDA 11.7, python 3.9, and gcc/g++ 9.0.

:point_down: If you haven't done so already via the demo setup, configure the environment and download the prerequisite models:
```
conda create --name a2p_env python=3.9
conda activate a2p_env
pip install -r scripts/requirements.txt
sh scripts/download_prereq.sh
```
:point_down: To get the rendering working, please also make sure you install [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md):
```
pip install "git+https://github.com/facebookresearch/pytorch3d.git"
```
Please see the [CA Bodies repo](https://github.com/facebookresearch/ca_body) for more details on the renderer.
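If you want to confirm the environment before moving on, a quick check like the one below can catch a broken CUDA or pytorch3d build early. This is only an illustrative sketch (not a script shipped with the repo), assuming `torch` and `pytorch3d` were installed as described above:

```
# sanity_check.py -- hypothetical helper, not part of the repository
import torch

# The renderer needs a working CUDA build of pytorch.
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# An ImportError here usually means pytorch3d was built against a
# different torch/CUDA combination than the one installed above.
import pytorch3d
print("pytorch3d", pytorch3d.__version__)
```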
# Download data and models
To download any of the datasets, you can find them at `https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/<person_id>.zip`, where you can replace `<person_id>` with any of `PXB184`, `RLW104`, `TXB805`, or `GQS883`.
Download over the command line can be done with the following commands:
```
curl -L https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/<person_id>.zip -o <person_id>.zip
unzip <person_id>.zip -d dataset/
rm <person_id>.zip
```
:point_down: To download *all* of the datasets, you can simply run the following, which will download and unpack all of them:
```
sh scripts/download_alldatasets.sh
```

Similarly, to download any of the models, you can find them at `http://audio2photoreal_models.berkeleyvision.org/<person_id>_models.tar`.
```
# download the motion generation models
wget http://audio2photoreal_models.berkeleyvision.org/<person_id>_models.tar
tar xvf <person_id>_models.tar
rm <person_id>_models.tar

# download the body decoder/rendering assets and place them in the right place
mkdir -p checkpoints/ca_body/data/
wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/<person_id>.tar.gz
tar xvf <person_id>.tar.gz --directory checkpoints/ca_body/data/
rm <person_id>.tar.gz
```
:point_down: You can also download all of the models with this script:
```
sh scripts/download_allmodels.sh
```
The above script will download *both* the motion generation models and the body decoder/rendering models. Please view the script for more details.

### Dataset
Once the dataset is downloaded and unzipped (via `scripts/download_alldatasets.sh`), it should unfold into the following directory structure:
```
|-- dataset/
    |-- PXB184/
        |-- data_stats.pth
        |-- scene01_audio.wav
        |-- scene01_body_pose.npy
        |-- scene01_face_expression.npy
        |-- scene01_missing_face_frames.npy
        |-- ...
        |-- scene30_audio.wav
        |-- scene30_body_pose.npy
        |-- scene30_face_expression.npy
        |-- scene30_missing_face_frames.npy
    |-- RLW104/
    |-- TXB805/
    |-- GQS883/
```
Each of the four participants (`PXB184`, `RLW104`, `TXB805`, `GQS883`) has independent "scenes" (roughly 1 to 26).
For each scene, we save the following data annotations:
```
*audio.wav: wavefile containing the raw audio (two channels, 1600*T samples) at 48kHz; channel 0 is the audio associated with the current person, channel 1 is the audio associated with their conversational partner.

*body_pose.npy: (T x 104) array of joint angles in a kinematic skeleton. Not all of the joints are represented with 3DoF. Each 104-d vector can be used to reconstruct a full-body skeleton.

*face_expression.npy: (T x 256) array of facial codes, where each 256-d vector reconstructs a face mesh.

*missing_face_frames.npy: list of indices (t) where the facial code is missing or corrupted.

data_stats.pth: carries the mean and std for each modality of each person.
```

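As a quick integrity check after downloading, the sketch below (an illustration only, not a script from the repo; it assumes `numpy` and `torch` from the requirements) loads one scene and prints the shapes described above:

```
# inspect_scene.py -- hypothetical helper for sanity-checking one scene
import numpy as np
import torch

root = "dataset/PXB184"  # any of the four participant directories

pose = np.load(f"{root}/scene01_body_pose.npy")        # expected (T, 104)
face = np.load(f"{root}/scene01_face_expression.npy")  # expected (T, 256)
missing = np.load(f"{root}/scene01_missing_face_frames.npy")
stats = torch.load(f"{root}/data_stats.pth", map_location="cpu")  # per-modality mean/std

print("body pose:", pose.shape)
print("face codes:", face.shape)
print("missing face frames:", missing.shape)
print("stats:", list(stats.keys()) if hasattr(stats, "keys") else type(stats))
```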
For the train/val/test split, the indices are defined in `data_loaders/data.py` as:
```
train_idx = list(range(0, len(data_dict["data"]) - 6))
val_idx = list(range(len(data_dict["data"]) - 6, len(data_dict["data"]) - 4))
test_idx = list(range(len(data_dict["data"]) - 4, len(data_dict["data"])))
```
for any of the four dataset participants we train on.

### Visualize ground truth
If you've properly installed the rendering requirements, you can visualize the full dataset with the following command:
```
python -m visualize.render_anno
    --save_dir <path/to/save/dir>
    --data_root <path/to/data/root>
    --max_seq_length <num>
```

The videos will be chunked into lengths according to the specified `--max_seq_length` arg (the default is 600).

:point_down: For example, to visualize ground truth annotations for `PXB184`, you can run the following:
```
python -m visualize.render_anno --save_dir vis_anno_test --data_root dataset/PXB184 --max_seq_length 600
```

### Pretrained models
We train person-specific models, so each person should have an associated directory. For instance, for `PXB184`, their complete models should unzip into the following structure:
```
|-- checkpoints/
    |-- diffusion/
        |-- c1_face/
            |-- args.json
            |-- model:09d.pt
        |-- c1_pose/
            |-- args.json
            |-- model:09d.pt
    |-- guide/
        |-- c1_pose/
            |-- args.json
            |-- checkpoints/
                |-- iter-:07d.pt
    |-- vq/
        |-- c1_pose/
            |-- args.json
            |-- net_iter:06d.pth
```
There are 4 models for each person, and each model has an associated `args.json`:
1. a face diffusion model that outputs 256 facial codes conditioned on audio
2. a pose diffusion model that outputs 104 joint rotations conditioned on audio and guide poses
3. a guide vq pose model that outputs vq tokens conditioned on audio at 1 fps
4. a vq encoder-decoder model that vector quantizes the continuous 104-d pose space

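If you want to verify that everything was unpacked into the layout above before sampling, a small listing like this can help (a hypothetical helper, not part of the repo; the glob patterns simply mirror the directory structure shown above):

```
# list_checkpoints.py -- hypothetical helper mirroring the layout above
from pathlib import Path

ckpt_root = Path("checkpoints")
patterns = [
    "diffusion/c1_face/model*.pt",
    "diffusion/c1_pose/model*.pt",
    "guide/c1_pose/checkpoints/iter-*.pt",
    "vq/c1_pose/net_iter*.pth",
]
for pattern in patterns:
    matches = sorted(ckpt_root.glob(pattern))
    print(f"{pattern}: {[m.name for m in matches] or 'MISSING'}")
```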
# Running the pretrained models
To run the actual models, you will need to run the pretrained models and generate the associated results files before visualizing them.

### Face generation
To generate the results file for the face:
```
python -m sample.generate
    --model_path <path/to/model>
    --num_samples <xsamples>
    --num_repetitions <xreps>
    --timestep_respacing ddim500
    --guidance_param 10.0
```

`<path/to/model>` should be the path to the diffusion model associated with generating the face.
E.g. for participant `PXB184`, the path might be `./checkpoints/diffusion/c1_face/model000155000.pt`.
The other parameters are:
```
--num_samples: number of samples to generate. To sample the full dataset, use 56 (except for TXB805, which is 58).
--num_repetitions: number of times to repeat the sampling, such that the total number of sequences generated is (num_samples * num_repetitions).
--timestep_respacing: how many diffusion steps to take. The format will always be ddim<number>.
--guidance_param: how influential the conditioning is on the results. I usually use the range 2.0-10.0, and tend towards higher values for the face.
```

:point_down: A full example of running the face model for `PXB184` with the provided pretrained models would then be:
```
python -m sample.generate --model_path checkpoints/diffusion/c1_face/model000155000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 10.0
```
This generates 10 samples from the dataset 5 times. The output results file will be saved to:
`./checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy`

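The internal layout of `results.npy` is not documented here, so if you want to inspect a generated file, the sketch below only reports what it finds rather than assuming specific keys (illustrative only; `allow_pickle=True` is needed if the file stores a pickled dictionary, which is an assumption):

```
# peek_results.py -- hypothetical helper; the dict layout is an assumption
import numpy as np

path = "checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy"
data = np.load(path, allow_pickle=True)

# np.save-ed dicts come back as 0-d object arrays; unwrap in that case.
obj = data.item() if data.dtype == object and data.ndim == 0 else data
if isinstance(obj, dict):
    for key, value in obj.items():
        print(key, getattr(value, "shape", type(value)))
else:
    print(type(obj), getattr(obj, "shape", None))
```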
### Body generation
Generating the corresponding body is very similar to generating the face, except that we now also have to feed in the model for generating the guide poses:
```
python -m sample.generate
    --model_path <path/to/model>
    --resume_trans <path/to/guide/model>
    --num_samples <xsamples>
    --num_repetitions <xreps>
    --timestep_respacing ddim500
    --guidance_param 2.0
```

:point_down: Here, `<path/to/guide/model>` should point to the guide transformer. The full command would be:
```
python -m sample.generate --model_path checkpoints/diffusion/c1_pose/model000340000.pt --resume_trans checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 2.0
```
Similarly, the output will be saved to:
`./checkpoints/diffusion/c1_pose/samples_c1_pose_000340000_seed10_guide_iter-0100000.pt/results.npy`

### Visualization
On the body generation side of things, you can also optionally pass in the `--plot` flag in order to render out the photorealistic avatar. You will also need to pass in the corresponding generated face codes with the `--face_codes` flag.
Optionally, if you already have the poses precomputed, you can also pass in the generated body with the `--pose_codes` flag.
This will save videos in the same directory as where the body's `results.npy` is stored.

:point_down: An example of the full command with *the three new flags added* is:
```
python -m sample.generate --model_path checkpoints/diffusion/c1_pose/model000340000.pt --resume_trans checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 2.0 --face_codes ./checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy --pose_codes ./checkpoints/diffusion/c1_pose/samples_c1_pose_000340000_seed10_guide_iter-0100000.pt/results.npy --plot
```
The remaining flags can be the same as before. For the actual rendering api, please see [Codec Avatar Body](https://github.com/facebookresearch/ca_body) for installation etc.
*Important: in order to visualize the full photorealistic avatar, you will need to run the face codes first, then pass them into the body generation code.* It will not work if you try to call generate with `--plot` for the face codes.

# Training from scratch
There are four possible models you will need to train: 1) the face diffusion model, 2) the body diffusion model, 3) the body vq vae, and 4) the body guide transformer.
The only dependency is that 3) is needed for 4). All other models can be trained in parallel.

### 1) Face diffusion model
To train the face model, you will need to run the following script:
```
python -m train.train_diffusion
    --save_dir <path/to/save/dir>
    --data_root <path/to/data/root>
    --batch_size <bs>
    --dataset social
    --data_format face
    --layers 8
    --heads 8
    --timestep_respacing ''
    --max_seq_length 600
```
Importantly, a few of the flags are as follows:
```
--save_dir: path to the directory where all outputs are stored
--data_root: path to the directory to load the data from
--dataset: name of the dataset to load; right now we only support the 'social' dataset
--data_format: set to 'face' for the face, as opposed to pose
--timestep_respacing: set to '' for the default spacing of 1k diffusion steps
--max_seq_length: the maximum number of frames for a given sequence to train on
```
:point_down: A full example for training on person `PXB184` is:
```
python -m train.train_diffusion --save_dir checkpoints/diffusion/c1_face_test --data_root ./dataset/PXB184/ --batch_size 4 --dataset social --data_format face --layers 8 --heads 8 --timestep_respacing '' --max_seq_length 600
```

### 2) Body diffusion model
Training the body model is similar to training the face model, but with the following additional parameters:
```
python -m train.train_diffusion
    --save_dir <path/to/save/dir>
    --data_root <path/to/data/root>
    --lambda_vel <num>
    --batch_size <bs>
    --dataset social
    --add_frame_cond 1
    --data_format pose
    --layers 6
    --heads 8
    --timestep_respacing ''
    --max_seq_length 600
```
The flags that differ from the face training are as follows:
```
--lambda_vel: additional auxiliary loss for training with velocity
--add_frame_cond: set to '1' for 1 fps. If not specified, it will default to 30 fps.
--data_format: set to 'pose' for the body, as opposed to face
```
:point_down: A full example for training on person `PXB184` is:
```
python -m train.train_diffusion --save_dir checkpoints/diffusion/c1_pose_test --data_root ./dataset/PXB184/ --lambda_vel 2.0 --batch_size 4 --dataset social --add_frame_cond 1 --data_format pose --layers 6 --heads 8 --timestep_respacing '' --max_seq_length 600
```

### 3) Body VQ VAE
To train a vq encoder-decoder, you will need to run the following script:
```
python -m train.train_vq
    --out_dir <path/to/out/dir>
    --data_root <path/to/data/root>
    --batch_size <bs>
    --lr 1e-3
    --code_dim 1024
    --output_emb_width 64
    --depth 4
    --dataname social
    --loss_vel 0.0
    --add_frame_cond 1
    --data_format pose
    --max_seq_length 600
```
:point_down: For person `PXB184`, it would be:
```
python -m train.train_vq --out_dir checkpoints/vq/c1_vq_test --data_root ./dataset/PXB184/ --lr 1e-3 --code_dim 1024 --output_emb_width 64 --depth 4 --dataname social --loss_vel 0.0 --data_format pose --batch_size 4 --add_frame_cond 1 --max_seq_length 600
```

### 4) Body guide transformer
Once you have the vq trained from 3), you can then pass it in to train the body guide pose transformer:
```
python -m train.train_guide
    --out_dir <path/to/out/dir>
    --data_root <path/to/data/root>
    --batch_size <bs>
    --resume_pth <path/to/vq/model>
    --add_frame_cond 1
    --layers 6
    --lr 2e-4
    --gn
    --dim 64
```
:point_down: For person `PXB184`, it would be:
```
python -m train.train_guide --out_dir checkpoints/guide/c1_trans_test --data_root ./dataset/PXB184/ --batch_size 4 --resume_pth checkpoints/vq/c1_vq_test/net_iter300000.pth --add_frame_cond 1 --layers 6 --lr 2e-4 --gn --dim 64
```

After training these 4 models, you can follow the ["Running the pretrained models"](#running-the-pretrained-models) section to generate samples and visualize results.

You can also visualize the corresponding ground truth sequences by passing in the `--render_gt` flag.

# License
The code and dataset are released under the [CC-NC 4.0 International license](https://github.com/facebookresearch/audio2photoreal/blob/main/LICENSE).
assets/demo1.gif
ADDED
assets/demo2.gif
ADDED
assets/render_defaults_GQS883.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ae7ee73849e258bbb8d8a04aa674960896fc1dff8757fefbd2df1685225dd7d
size 71354547
assets/render_defaults_PXB184.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c86ba14a58d4829c8d05428f5e601072dc4bab1bdc60bc53ce6c73990e9b97d7
size 71354547
assets/render_defaults_RLW104.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:808a3fbf33115d3cc132bad48c2e95bfca29bb1847d912b1f72e5e5b4a081db5
size 71354547
assets/render_defaults_TXB805.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7985c79edfba70f83f560859f2ce214d9779a46031aa8ca6a917d8fd4417e24
size 71354547
checkpoints/ca_body/data/PXB184/body_dec.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26394ae03c1726b7c90b5633696d0eea733a3c5e423893c4e79b490c80c35ddf
size 893279810
checkpoints/ca_body/data/PXB184/config.yml
ADDED
@@ -0,0 +1,56 @@
model:
  class_name: ca_body.models.mesh_vae_drivable.AutoEncoder

  encoder:
    n_embs: 1024
    noise_std: 1.0

  encoder_face:
    n_embs: 256
    noise_std: 1.0

  decoder_face:
    n_latent: 256
    n_vert_out: 21918

  decoder:
    init_uv_size: 64
    n_init_channels: 64
    n_min_channels: 4
    n_pose_dims: 98
    n_pose_enc_channels: 16
    n_embs: 1024
    n_embs_enc_channels: 32
    n_face_embs: 256
    uv_size: 1024

  decoder_view:
    net_uv_size: 1024

  upscale_net:
    n_ftrs: 4

  shadow_net:
    uv_size: 2048
    shadow_size: 256
    n_dims: 4

  pose_to_shadow:
    n_pose_dims: 104
    uv_size: 2048

  renderer:
    image_height: 2048
    image_width: 1334
    depth_disc_ksize: 3

  cal:
    identity_camera: '400143'

  pixel_cal:
    image_height: 2048
    image_width: 1334
    ds_rate: 8

  learn_blur: true
checkpoints/diffusion/c1_face/args.json
ADDED
@@ -0,0 +1,34 @@
{
    "add_frame_cond": null,
    "batch_size": 4,
    "cond_mask_prob": 0.2,
    "cuda": true,
    "data_format": "face",
    "data_root": "./dataset/PXB184/",
    "dataset": "social",
    "device": 0,
    "diffusion_steps": 10,
    "heads": 8,
    "lambda_vel": 0.0,
    "latent_dim": 512,
    "layers": 8,
    "log_interval": 1000,
    "lr": 0.0001,
    "lr_anneal_steps": 0,
    "max_seq_length": 600,
    "noise_schedule": "cosine",
    "not_rotary": false,
    "num_audio_layers": 3,
    "num_steps": 800000,
    "overwrite": false,
    "resume_checkpoint": "",
    "save_dir": "checkpoints/diffusion/c1_face/",
    "save_interval": 5000,
    "seed": 10,
    "sigma_small": true,
    "simplify_audio": false,
    "timestep_respacing": "",
    "train_platform_type": "NoPlatform",
    "unconstrained": false,
    "weight_decay": 0.0
}
checkpoints/diffusion/c1_pose/args.json
ADDED
@@ -0,0 +1,66 @@
{
    "add_frame_cond": 1.0,
    "arch": "trans_enc",
    "batch_size": 32,
    "clip_body": false,
    "clip_use_delta": false,
    "clip_use_vae": false,
    "cond_mask_prob": 0.1,
    "cuda": true,
    "data_format": "pose",
    "data_root": "./dataset/PXB184/",
    "dataset": "social",
    "device": 0,
    "diffusion_steps": 10,
    "emb_trans_dec": false,
    "eval_batch_size": 32,
    "eval_during_training": false,
    "eval_num_samples": 1000,
    "eval_rep_times": 3,
    "eval_split": "val",
    "filter": false,
    "heads": 8,
    "lambda_fc": 0.0,
    "lambda_hands": 0.0,
    "lambda_lips": 0.0,
    "lambda_rcxyz": 0.0,
    "lambda_vel": 2.0,
    "lambda_xyz": 0.0,
    "lambda_xyz_vel": 0.0,
    "latent_dim": 512,
    "layers": 6,
    "log_interval": 1000,
    "lr": 0.0001,
    "lr_anneal_steps": 0,
    "max_seq_length": 600,
    "no_split": false,
    "noise_schedule": "cosine",
    "not_rotary": false,
    "num_frames": 60,
    "num_steps": 800000,
    "overwrite": false,
    "partial": false,
    "resume_checkpoint": "",
    "save_dir": "checkpoints/diffusion/c1_pose/",
    "save_interval": 5000,
    "seed": 10,
    "sigma_small": true,
    "simplify_audio": false,
    "split_net": false,
    "timestep_respacing": "",
    "train_platform_type": "NoPlatform",
    "unconstrained": false,
    "use_clip": false,
    "use_cm": true,
    "use_full_dataset": false,
    "use_kp": false,
    "use_mask": true,
    "use_mdm": false,
    "use_nort": false,
    "use_nort_mdm": false,
    "use_pose_pos": false,
    "use_resnet": true,
    "use_vae": null,
    "weight_decay": 0.0,
    "z_norm": true
}
checkpoints/guide/c1_pose/args.json
ADDED
@@ -0,0 +1,41 @@
{
    "add_audio_pe": true,
    "add_conv": true,
    "add_frame_cond": 1,
    "batch_size": 16,
    "data_format": "pose",
    "data_root": "./dataset/PXB184/",
    "dataset": "social",
    "dec_layers": null,
    "dim": 64,
    "enc_layers": null,
    "eval_interval": 1000,
    "filter": false,
    "gamma": 0.1,
    "gn": true,
    "layers": 6,
    "log_interval": 1000,
    "lr": 0.0001,
    "lr_scheduler": [
        50000,
        400000
    ],
    "no_split": false,
    "num_audio_layers": 2,
    "out_dir": "checkpoints/guide/c1_pose",
    "partial": false,
    "resume_pth": "checkpoints/vq/c1_pose/net_iter300000.pth",
    "resume_trans": null,
    "save_interval": 5000,
    "seed": 10,
    "simplify_audio": false,
    "total_iter": 1000000,
    "use_full_dataset": false,
    "use_kp": false,
    "use_lstm": false,
    "use_nort": false,
    "use_nort_mdm": false,
    "use_torch": false,
    "warm_up_iter": 5000,
    "weight_decay": 0.1
}
checkpoints/vq/c1_pose/args.json
ADDED
@@ -0,0 +1,43 @@
{
    "add_frame_cond": 1.0,
    "batch_size": 16,
    "code_dim": 1024,
    "commit": 0.02,
    "data_format": "pose",
    "data_root": "./dataset/PXB184/",
    "dataname": "social",
    "dataset": "social",
    "depth": 4,
    "eval_iter": 1000,
    "exp_name": "c1_pose",
    "filter": false,
    "gamma": 0.05,
    "loss_vel": 0.0,
    "lr": 0.001,
    "lr_scheduler": [
        300000
    ],
    "max_seq_length": 600,
    "nb_joints": 104,
    "no_split": true,
    "out_dir": "checkpoints/vq/c1_pose",
    "output_emb_width": 64,
    "partial": false,
    "print_iter": 200,
    "results_dir": "visual_results/",
    "resume_pth": null,
    "seed": 123,
    "simplify_audio": false,
    "total_iter": 10000000,
    "use_full_dataset": false,
    "use_kp": false,
    "use_linear": false,
    "use_nort": false,
    "use_nort_mdm": false,
    "use_quant": true,
    "use_vae": false,
    "visual_name": "baseline",
    "warm_up_iter": 1000,
    "weight_decay": 0.0,
    "z_norm": true
}
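These args.json files are not read again by the training scripts; they are consumed at inference time. A minimal sketch of that pattern, following `GradioModel._setup_model` in demo/demo.py further down in this upload (the checkpoint filename is the one demo.py hard-codes; the rest is illustrative):
```
import json

import torch
from attrdict import AttrDict

# Load the recorded training arguments and override the inference-only fields,
# mirroring GradioModel._setup_model in demo/demo.py.
with open("checkpoints/diffusion/c1_pose/args.json") as f:
    args = AttrDict(json.load(f))

args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
args.model_path = "checkpoints/diffusion/c1_pose/model000340000.pt"
args.timestep_respacing = "ddim100"
print(args.layers, args.latent_dim, args.cond_mask_prob)
```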
checkpoints/vq/c1_pose/net_iter300000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5649ad5e49e0e1afcd9a7390f0ee79ee66de275a67ecb1cfe7fc691cb4ceb332
size 3129275
data_loaders/data.py
ADDED
@@ -0,0 +1,253 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
from typing import Dict, Iterable, List, Union
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from torch.utils import data
|
14 |
+
|
15 |
+
from utils.misc import prGreen
|
16 |
+
|
17 |
+
|
18 |
+
class Social(data.Dataset):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
args,
|
22 |
+
data_dict: Dict[str, Iterable],
|
23 |
+
split: str = "train",
|
24 |
+
chunk: bool = False,
|
25 |
+
add_padding: bool = True,
|
26 |
+
) -> None:
|
27 |
+
if args.data_format == "face":
|
28 |
+
prGreen("[dataset.py] training face only model")
|
29 |
+
data_dict["data"] = data_dict["face"]
|
30 |
+
elif args.data_format == "pose":
|
31 |
+
prGreen("[dataset.py] training pose only model")
|
32 |
+
missing = []
|
33 |
+
for d in data_dict["data"]:
|
34 |
+
missing.append(np.ones_like(d))
|
35 |
+
data_dict["missing"] = missing
|
36 |
+
|
37 |
+
# set up variables for dataloader
|
38 |
+
self.data_format = args.data_format
|
39 |
+
self.add_frame_cond = args.add_frame_cond
|
40 |
+
self._register_keyframe_step()
|
41 |
+
self.data_root = args.data_root
|
42 |
+
self.max_seq_length = args.max_seq_length
|
43 |
+
if hasattr(args, "curr_seq_length") and args.curr_seq_length is not None:
|
44 |
+
self.max_seq_length = args.curr_seq_length
|
45 |
+
prGreen([f"[dataset.py] sequences of {self.max_seq_length}"])
|
46 |
+
self.add_padding = add_padding
|
47 |
+
self.audio_per_frame = 1600
|
48 |
+
self.max_audio_length = self.max_seq_length * self.audio_per_frame
|
49 |
+
self.min_seq_length = 400
|
50 |
+
|
51 |
+
# set up training/validation splits
|
52 |
+
train_idx = list(range(0, len(data_dict["data"]) - 6))
|
53 |
+
val_idx = list(range(len(data_dict["data"]) - 6, len(data_dict["data"]) - 4))
|
54 |
+
test_idx = list(range(len(data_dict["data"]) - 4, len(data_dict["data"])))
|
55 |
+
self.split = split
|
56 |
+
if split == "train":
|
57 |
+
self._pick_sequences(data_dict, train_idx)
|
58 |
+
elif split == "val":
|
59 |
+
self._pick_sequences(data_dict, val_idx)
|
60 |
+
else:
|
61 |
+
self._pick_sequences(data_dict, test_idx)
|
62 |
+
self.chunk = chunk
|
63 |
+
if split == "test":
|
64 |
+
print("[dataset.py] chunking data...")
|
65 |
+
self._chunk_data()
|
66 |
+
self._load_std()
|
67 |
+
prGreen(
|
68 |
+
f"[dataset.py] {split} | {len(self.data)} sequences ({self.data[0].shape}) | total len {self.total_len}"
|
69 |
+
)
|
70 |
+
|
71 |
+
def inv_transform(
|
72 |
+
self, data: Union[np.ndarray, torch.Tensor], data_type: str
|
73 |
+
) -> Union[np.ndarray, torch.Tensor]:
|
74 |
+
if data_type == "pose":
|
75 |
+
std = self.std
|
76 |
+
mean = self.mean
|
77 |
+
elif data_type == "face":
|
78 |
+
std = self.face_std
|
79 |
+
mean = self.face_mean
|
80 |
+
elif data_type == "audio":
|
81 |
+
std = self.audio_std
|
82 |
+
mean = self.audio_mean
|
83 |
+
else:
|
84 |
+
assert False, f"datatype not defined: {data_type}"
|
85 |
+
|
86 |
+
if torch.is_tensor(data):
|
87 |
+
return data * torch.tensor(
|
88 |
+
std, device=data.device, requires_grad=False
|
89 |
+
) + torch.tensor(mean, device=data.device, requires_grad=False)
|
90 |
+
else:
|
91 |
+
return data * std + mean
|
92 |
+
|
93 |
+
def _pick_sequences(self, data_dict: Dict[str, Iterable], idx: List[int]) -> None:
|
94 |
+
self.data = np.take(data_dict["data"], idx, axis=0)
|
95 |
+
self.missing = np.take(data_dict["missing"], idx, axis=0)
|
96 |
+
self.audio = np.take(data_dict["audio"], idx, axis=0)
|
97 |
+
self.lengths = np.take(data_dict["lengths"], idx, axis=0)
|
98 |
+
self.total_len = sum([len(d) for d in self.data])
|
99 |
+
|
100 |
+
def _load_std(self) -> None:
|
101 |
+
stats = torch.load(os.path.join(self.data_root, "data_stats.pth"))
|
102 |
+
print(
|
103 |
+
f'[dataset.py] loading from... {os.path.join(self.data_root, "data_stats.pth")}'
|
104 |
+
)
|
105 |
+
self.mean = stats["pose_mean"].reshape(-1)
|
106 |
+
self.std = stats["pose_std"].reshape(-1)
|
107 |
+
self.face_mean = stats["code_mean"]
|
108 |
+
self.face_std = stats["code_std"]
|
109 |
+
self.audio_mean = stats["audio_mean"]
|
110 |
+
self.audio_std = stats["audio_std_flat"]
|
111 |
+
|
112 |
+
def _chunk_data(self) -> None:
|
113 |
+
chunk_data = []
|
114 |
+
chunk_missing = []
|
115 |
+
chunk_lengths = []
|
116 |
+
chunk_audio = []
|
117 |
+
# create sequences of set lengths
|
118 |
+
for d_idx in range(len(self.data)):
|
119 |
+
curr_data = self.data[d_idx]
|
120 |
+
curr_missing = self.missing[d_idx]
|
121 |
+
curr_audio = self.audio[d_idx]
|
122 |
+
end_range = len(self.data[d_idx]) - self.max_seq_length
|
123 |
+
for chunk_idx in range(0, end_range, self.max_seq_length):
|
124 |
+
chunk_end = chunk_idx + self.max_seq_length
|
125 |
+
curr_data_chunk = curr_data[chunk_idx:chunk_end, :]
|
126 |
+
curr_missing_chunk = curr_missing[chunk_idx:chunk_end, :]
|
127 |
+
curr_audio_chunk = curr_audio[
|
128 |
+
chunk_idx * self.audio_per_frame : chunk_end * self.audio_per_frame,
|
129 |
+
:,
|
130 |
+
]
|
131 |
+
if curr_data_chunk.shape[0] < self.max_seq_length:
|
132 |
+
# do not add a short chunk to the list
|
133 |
+
continue
|
134 |
+
chunk_lengths.append(curr_data_chunk.shape[0])
|
135 |
+
chunk_data.append(curr_data_chunk)
|
136 |
+
chunk_missing.append(curr_missing_chunk)
|
137 |
+
chunk_audio.append(curr_audio_chunk)
|
138 |
+
idx = np.random.permutation(len(chunk_data))
|
139 |
+
print("==> shuffle", idx)
|
140 |
+
self.data = np.take(chunk_data, idx, axis=0)
|
141 |
+
self.missing = np.take(chunk_missing, idx, axis=0)
|
142 |
+
self.lengths = np.take(chunk_lengths, idx, axis=0)
|
143 |
+
self.audio = np.take(chunk_audio, idx, axis=0)
|
144 |
+
self.total_len = len(self.data)
|
145 |
+
|
146 |
+
def _register_keyframe_step(self) -> None:
|
147 |
+
if self.add_frame_cond == 1:
|
148 |
+
self.step = 30
|
149 |
+
if self.add_frame_cond is None:
|
150 |
+
self.step = 1
|
151 |
+
|
152 |
+
def _pad_sequence(
|
153 |
+
self, sequence: np.ndarray, actual_length: int, max_length: int
|
154 |
+
) -> np.ndarray:
|
155 |
+
sequence = np.concatenate(
|
156 |
+
(
|
157 |
+
sequence,
|
158 |
+
np.zeros((max_length - actual_length, sequence.shape[-1])),
|
159 |
+
),
|
160 |
+
axis=0,
|
161 |
+
)
|
162 |
+
return sequence
|
163 |
+
|
164 |
+
def _get_idx(self, item: int) -> int:
|
165 |
+
cumulative_len = 0
|
166 |
+
seq_idx = 0
|
167 |
+
while item > cumulative_len:
|
168 |
+
cumulative_len += len(self.data[seq_idx])
|
169 |
+
seq_idx += 1
|
170 |
+
item = seq_idx - 1
|
171 |
+
return item
|
172 |
+
|
173 |
+
def _get_random_subsection(
|
174 |
+
self, data_dict: Dict[str, Iterable]
|
175 |
+
) -> Dict[str, np.ndarray]:
|
176 |
+
isnonzero = False
|
177 |
+
while not isnonzero:
|
178 |
+
start = np.random.randint(0, data_dict["m_length"] - self.max_seq_length)
|
179 |
+
if self.add_padding:
|
180 |
+
length = (
|
181 |
+
np.random.randint(self.min_seq_length, self.max_seq_length)
|
182 |
+
if not self.split == "test"
|
183 |
+
else self.max_seq_length
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
length = self.max_seq_length
|
187 |
+
curr_missing = data_dict["missing"][start : start + length]
|
188 |
+
isnonzero = np.any(curr_missing)
|
189 |
+
missing = curr_missing
|
190 |
+
motion = data_dict["motion"][start : start + length, :]
|
191 |
+
keyframes = motion[:: self.step]
|
192 |
+
audio = data_dict["audio"][
|
193 |
+
start * self.audio_per_frame : (start + length) * self.audio_per_frame,
|
194 |
+
:,
|
195 |
+
]
|
196 |
+
data_dict["m_length"] = len(motion)
|
197 |
+
data_dict["k_length"] = len(keyframes)
|
198 |
+
data_dict["a_length"] = len(audio)
|
199 |
+
|
200 |
+
if data_dict["m_length"] < self.max_seq_length:
|
201 |
+
motion = self._pad_sequence(
|
202 |
+
motion, data_dict["m_length"], self.max_seq_length
|
203 |
+
)
|
204 |
+
missing = self._pad_sequence(
|
205 |
+
missing, data_dict["m_length"], self.max_seq_length
|
206 |
+
)
|
207 |
+
audio = self._pad_sequence(
|
208 |
+
audio, data_dict["a_length"], self.max_audio_length
|
209 |
+
)
|
210 |
+
max_step_length = len(np.zeros(self.max_seq_length)[:: self.step])
|
211 |
+
keyframes = self._pad_sequence(
|
212 |
+
keyframes, data_dict["k_length"], max_step_length
|
213 |
+
)
|
214 |
+
data_dict["motion"] = motion
|
215 |
+
data_dict["keyframes"] = keyframes
|
216 |
+
data_dict["audio"] = audio
|
217 |
+
data_dict["missing"] = missing
|
218 |
+
return data_dict
|
219 |
+
|
220 |
+
def __len__(self) -> int:
|
221 |
+
return self.total_len
|
222 |
+
|
223 |
+
def __getitem__(self, item: int) -> Dict[str, np.ndarray]:
|
224 |
+
# figure out which sequence to randomly sample from
|
225 |
+
if not self.split == "test":
|
226 |
+
item = self._get_idx(item)
|
227 |
+
motion = self.data[item]
|
228 |
+
audio = self.audio[item]
|
229 |
+
m_length = self.lengths[item]
|
230 |
+
missing = self.missing[item]
|
231 |
+
a_length = len(audio)
|
232 |
+
# Z Normalization
|
233 |
+
if self.data_format == "pose":
|
234 |
+
motion = (motion - self.mean) / self.std
|
235 |
+
elif self.data_format == "face":
|
236 |
+
motion = (motion - self.face_mean) / self.face_std
|
237 |
+
audio = (audio - self.audio_mean) / self.audio_std
|
238 |
+
keyframes = motion[:: self.step]
|
239 |
+
k_length = len(keyframes)
|
240 |
+
data_dict = {
|
241 |
+
"motion": motion,
|
242 |
+
"m_length": m_length,
|
243 |
+
"audio": audio,
|
244 |
+
"a_length": a_length,
|
245 |
+
"keyframes": keyframes,
|
246 |
+
"k_length": k_length,
|
247 |
+
"missing": missing,
|
248 |
+
}
|
249 |
+
if not self.split == "test" and not self.chunk:
|
250 |
+
data_dict = self._get_random_subsection(data_dict)
|
251 |
+
if self.data_format == "face":
|
252 |
+
data_dict["motion"] *= data_dict["missing"]
|
253 |
+
return data_dict
|
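Two transforms in `Social.__getitem__` above do most of the work downstream: z-normalization with the statistics from `data_stats.pth`, and keyframe subsampling (every 30th frame when `add_frame_cond == 1`, per `_register_keyframe_step`). A condensed, standalone restatement as a sketch (array shapes are illustrative):
```
from typing import Optional, Tuple

import numpy as np

def normalize_and_keyframe(
    motion: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    add_frame_cond: Optional[int] = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    # Same math as Social.__getitem__: z-normalize, then subsample keyframes.
    motion = (motion - mean) / std
    step = 30 if add_frame_cond == 1 else 1
    return motion, motion[::step]

# e.g. a 600-frame pose sequence with 104 joint dims -> 20 keyframes
motion, keyframes = normalize_and_keyframe(
    np.zeros((600, 104)), np.zeros(104), np.ones(104)
)
print(motion.shape, keyframes.shape)  # (600, 104) (20, 104)
```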
data_loaders/get_data.py
ADDED
@@ -0,0 +1,129 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import os

from typing import Dict, List

import numpy as np
import torch
import torchaudio
from data_loaders.data import Social
from data_loaders.tensors import social_collate
from torch.utils.data import DataLoader
from utils.misc import prGreen


def get_dataset_loader(
    args,
    data_dict: Dict[str, np.ndarray],
    split: str = "train",
    chunk: bool = False,
    add_padding: bool = True,
) -> DataLoader:
    dataset = Social(
        args=args,
        data_dict=data_dict,
        split=split,
        chunk=chunk,
        add_padding=add_padding,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=not split == "test",
        num_workers=8,
        drop_last=True,
        collate_fn=social_collate,
        pin_memory=True,
    )
    return loader


def _load_pose_data(
    all_paths: List[str], audio_per_frame: int, flip_person: bool = False
) -> Dict[str, List]:
    data = []
    face = []
    audio = []
    lengths = []
    missing = []
    for _, curr_path_name in enumerate(all_paths):
        if not curr_path_name.endswith("_body_pose.npy"):
            continue
        # load face information and deal with missing codes
        curr_code = np.load(
            curr_path_name.replace("_body_pose.npy", "_face_expression.npy")
        ).astype(float)
        # curr_code = np.array(curr_face["codes"], dtype=float)
        missing_list = np.load(
            curr_path_name.replace("_body_pose.npy", "_missing_face_frames.npy")
        )
        if len(missing_list) == len(curr_code):
            print("skipping", curr_path_name, curr_code.shape)
            continue
        curr_missing = np.ones_like(curr_code)
        curr_missing[missing_list] = 0.0

        # load pose information and deal with discontinuities
        curr_pose = np.load(curr_path_name)
        if "PXB184" in curr_path_name or "RLW104" in curr_path_name:  # Capture 1 or 2
            curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)
            curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)

        # load audio information
        curr_audio, _ = torchaudio.load(
            curr_path_name.replace("_body_pose.npy", "_audio.wav")
        )
        curr_audio = curr_audio.T
        if flip_person:
            prGreen("[get_data.py] flipping the dataset of left right person")
            tmp = torch.zeros_like(curr_audio)
            tmp[:, 1] = curr_audio[:, 0]
            tmp[:, 0] = curr_audio[:, 1]
            curr_audio = tmp

        assert len(curr_pose) * audio_per_frame == len(
            curr_audio
        ), f"motion {curr_pose.shape} vs audio {curr_audio.shape}"

        data.append(curr_pose)
        face.append(curr_code)
        missing.append(curr_missing)
        audio.append(curr_audio)
        lengths.append(len(curr_pose))

    data_dict = {
        "data": data,
        "face": face,
        "audio": audio,
        "lengths": lengths,
        "missing": missing,
    }
    return data_dict


def load_local_data(
    data_root: str, audio_per_frame: int, flip_person: bool = False
) -> Dict[str, List]:
    if flip_person:
        if "PXB184" in data_root:
            data_root = data_root.replace("PXB184", "RLW104")
        elif "RLW104" in data_root:
            data_root = data_root.replace("RLW104", "PXB184")
        elif "TXB805" in data_root:
            data_root = data_root.replace("TXB805", "GQS883")
        elif "GQS883" in data_root:
            data_root = data_root.replace("GQS883", "TXB805")

    all_paths = [os.path.join(data_root, x) for x in os.listdir(data_root)]
    all_paths.sort()
    return _load_pose_data(
        all_paths,
        audio_per_frame,
        flip_person=flip_person,
    )
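A minimal usage sketch for the two entry points above; the Namespace fields are the ones Social and the DataLoader actually read, and the values shown are assumptions rather than project defaults:
```
from argparse import Namespace

from data_loaders.get_data import get_dataset_loader, load_local_data

args = Namespace(
    data_root="./dataset/PXB184/",
    data_format="pose",
    add_frame_cond=1,
    max_seq_length=600,
    batch_size=4,
)
# 1600 audio samples per 30 fps motion frame (48 kHz audio), as assumed in data.py
data_dict = load_local_data(args.data_root, audio_per_frame=1600)
loader = get_dataset_loader(args, data_dict, split="train")
motion, cond = next(iter(loader))
print(motion.shape, cond["y"]["audio"].shape)
```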
data_loaders/tensors.py
ADDED
@@ -0,0 +1,86 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import torch
from torch.utils.data._utils.collate import default_collate


def lengths_to_mask(lengths, max_len):
    mask = torch.arange(max_len, device=lengths.device).expand(
        len(lengths), max_len
    ) < lengths.unsqueeze(1)
    return mask


def collate_tensors(batch):
    dims = batch[0].dim()
    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
    size = (len(batch),) + tuple(max_size)
    canvas = batch[0].new_zeros(size=size)
    for i, b in enumerate(batch):
        sub_tensor = canvas[i]
        for d in range(dims):
            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
        sub_tensor.add_(b)
    return canvas


## social collate
def collate_v2(batch):
    notnone_batches = [b for b in batch if b is not None]
    databatch = [b["inp"] for b in notnone_batches]
    missingbatch = [b["missing"] for b in notnone_batches]
    audiobatch = [b["audio"] for b in notnone_batches]
    lenbatch = [b["lengths"] for b in notnone_batches]
    alenbatch = [b["audio_lengths"] for b in notnone_batches]
    keyframebatch = [b["keyframes"] for b in notnone_batches]
    klenbatch = [b["key_lengths"] for b in notnone_batches]

    databatchTensor = collate_tensors(databatch)
    missingbatchTensor = collate_tensors(missingbatch)
    audiobatchTensor = collate_tensors(audiobatch)
    lenbatchTensor = torch.as_tensor(lenbatch)
    alenbatchTensor = torch.as_tensor(alenbatch)
    keyframeTensor = collate_tensors(keyframebatch)
    klenbatchTensor = torch.as_tensor(klenbatch)

    maskbatchTensor = (
        lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1])
        .unsqueeze(1)
        .unsqueeze(1)
    )  # unsqueeze for broadcasting
    motion = databatchTensor
    cond = {
        "y": {
            "missing": missingbatchTensor,
            "mask": maskbatchTensor,
            "lengths": lenbatchTensor,
            "audio": audiobatchTensor,
            "alengths": alenbatchTensor,
            "keyframes": keyframeTensor,
            "klengths": klenbatchTensor,
        }
    }
    return motion, cond


def social_collate(batch):
    adapted_batch = [
        {
            "inp": torch.tensor(b["motion"].T).to(torch.float32).unsqueeze(1),
            "lengths": b["m_length"],
            "audio": b["audio"]
            if torch.is_tensor(b["audio"])
            else torch.tensor(b["audio"]).to(torch.float32),
            "keyframes": torch.tensor(b["keyframes"]).to(torch.float32),
            "key_lengths": b["k_length"],
            "audio_lengths": b["a_length"],
            "missing": torch.tensor(b["missing"]).to(torch.float32),
        }
        for b in batch
    ]
    return collate_v2(adapted_batch)
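For reference, a tiny illustration of `lengths_to_mask`, which produces the `mask` entry of the collated `cond["y"]` dict (the lengths here are made up):
```
import torch

from data_loaders.tensors import lengths_to_mask

lengths = torch.tensor([3, 5])             # two sequences of length 3 and 5
mask = lengths_to_mask(lengths, max_len=5)
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```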
demo/.ipynb_checkpoints/demo-checkpoint.py
ADDED
(Jupyter autosave copy; its 276 lines are identical to demo/demo.py below.)
demo/demo.py
ADDED
@@ -0,0 +1,276 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import copy
|
9 |
+
import json
|
10 |
+
from typing import Dict, Union
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torchaudio
|
16 |
+
from attrdict import AttrDict
|
17 |
+
from diffusion.respace import SpacedDiffusion
|
18 |
+
from model.cfg_sampler import ClassifierFreeSampleModel
|
19 |
+
from model.diffusion import FiLMTransformer
|
20 |
+
from utils.misc import fixseed
|
21 |
+
from utils.model_util import create_model_and_diffusion, load_model
|
22 |
+
from visualize.render_codes import BodyRenderer
|
23 |
+
|
24 |
+
|
25 |
+
class GradioModel:
|
26 |
+
def __init__(self, face_args, pose_args) -> None:
|
27 |
+
self.face_model, self.face_diffusion, self.device = self._setup_model(
|
28 |
+
face_args, "checkpoints/diffusion/c1_face/model000155000.pt"
|
29 |
+
)
|
30 |
+
self.pose_model, self.pose_diffusion, _ = self._setup_model(
|
31 |
+
pose_args, "checkpoints/diffusion/c1_pose/model000340000.pt"
|
32 |
+
)
|
33 |
+
# load standardization stuff
|
34 |
+
stats = torch.load("dataset/PXB184/data_stats.pth")
|
35 |
+
stats["pose_mean"] = stats["pose_mean"].reshape(-1)
|
36 |
+
stats["pose_std"] = stats["pose_std"].reshape(-1)
|
37 |
+
self.stats = stats
|
38 |
+
# set up renderer
|
39 |
+
config_base = f"./checkpoints/ca_body/data/PXB184"
|
40 |
+
self.body_renderer = BodyRenderer(
|
41 |
+
config_base=config_base,
|
42 |
+
render_rgb=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
def _setup_model(
|
46 |
+
self,
|
47 |
+
args_path: str,
|
48 |
+
model_path: str,
|
49 |
+
) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion):
|
50 |
+
with open(args_path) as f:
|
51 |
+
args = json.load(f)
|
52 |
+
args = AttrDict(args)
|
53 |
+
args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
54 |
+
print("running on...", args.device)
|
55 |
+
args.model_path = model_path
|
56 |
+
args.output_dir = "/tmp/gradio/"
|
57 |
+
args.timestep_respacing = "ddim100"
|
58 |
+
if args.data_format == "pose":
|
59 |
+
args.resume_trans = "checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt"
|
60 |
+
|
61 |
+
## create model
|
62 |
+
model, diffusion = create_model_and_diffusion(args, split_type="test")
|
63 |
+
print(f"Loading checkpoints from [{args.model_path}]...")
|
64 |
+
state_dict = torch.load(args.model_path, map_location=args.device)
|
65 |
+
load_model(model, state_dict)
|
66 |
+
model = ClassifierFreeSampleModel(model)
|
67 |
+
model.eval()
|
68 |
+
model.to(args.device)
|
69 |
+
return model, diffusion, args.device
|
70 |
+
|
71 |
+
def _replace_keyframes(
|
72 |
+
self,
|
73 |
+
model_kwargs: Dict[str, Dict[str, torch.Tensor]],
|
74 |
+
B: int,
|
75 |
+
T: int,
|
76 |
+
top_p: float = 0.97,
|
77 |
+
) -> torch.Tensor:
|
78 |
+
with torch.no_grad():
|
79 |
+
tokens = self.pose_model.transformer.generate(
|
80 |
+
model_kwargs["y"]["audio"],
|
81 |
+
T,
|
82 |
+
layers=self.pose_model.tokenizer.residual_depth,
|
83 |
+
n_sequences=B,
|
84 |
+
top_p=top_p,
|
85 |
+
)
|
86 |
+
tokens = tokens.reshape((B, -1, self.pose_model.tokenizer.residual_depth))
|
87 |
+
pred = self.pose_model.tokenizer.decode(tokens).detach()
|
88 |
+
return pred
|
89 |
+
|
90 |
+
def _run_single_diffusion(
|
91 |
+
self,
|
92 |
+
model_kwargs: Dict[str, Dict[str, torch.Tensor]],
|
93 |
+
diffusion: SpacedDiffusion,
|
94 |
+
model: Union[FiLMTransformer, ClassifierFreeSampleModel],
|
95 |
+
curr_seq_length: int,
|
96 |
+
num_repetitions: int = 1,
|
97 |
+
) -> (torch.Tensor,):
|
98 |
+
sample_fn = diffusion.ddim_sample_loop
|
99 |
+
with torch.no_grad():
|
100 |
+
sample = sample_fn(
|
101 |
+
model,
|
102 |
+
(num_repetitions, model.nfeats, 1, curr_seq_length),
|
103 |
+
clip_denoised=False,
|
104 |
+
model_kwargs=model_kwargs,
|
105 |
+
init_image=None,
|
106 |
+
progress=True,
|
107 |
+
dump_steps=None,
|
108 |
+
noise=None,
|
109 |
+
const_noise=False,
|
110 |
+
)
|
111 |
+
return sample
|
112 |
+
|
113 |
+
def generate_sequences(
|
114 |
+
self,
|
115 |
+
model_kwargs: Dict[str, Dict[str, torch.Tensor]],
|
116 |
+
data_format: str,
|
117 |
+
curr_seq_length: int,
|
118 |
+
num_repetitions: int = 5,
|
119 |
+
guidance_param: float = 10.0,
|
120 |
+
top_p: float = 0.97,
|
121 |
+
# batch_size: int = 1,
|
122 |
+
) -> Dict[str, np.ndarray]:
|
123 |
+
if data_format == "pose":
|
124 |
+
model = self.pose_model
|
125 |
+
diffusion = self.pose_diffusion
|
126 |
+
else:
|
127 |
+
model = self.face_model
|
128 |
+
diffusion = self.face_diffusion
|
129 |
+
|
130 |
+
all_motions = []
|
131 |
+
model_kwargs["y"]["scale"] = torch.ones(num_repetitions) * guidance_param
|
132 |
+
model_kwargs["y"] = {
|
133 |
+
key: val.to(self.device) if torch.is_tensor(val) else val
|
134 |
+
for key, val in model_kwargs["y"].items()
|
135 |
+
}
|
136 |
+
if data_format == "pose":
|
137 |
+
model_kwargs["y"]["mask"] = (
|
138 |
+
torch.ones((num_repetitions, 1, 1, curr_seq_length))
|
139 |
+
.to(self.device)
|
140 |
+
.bool()
|
141 |
+
)
|
142 |
+
model_kwargs["y"]["keyframes"] = self._replace_keyframes(
|
143 |
+
model_kwargs,
|
144 |
+
num_repetitions,
|
145 |
+
int(curr_seq_length / 30),
|
146 |
+
top_p=top_p,
|
147 |
+
)
|
148 |
+
sample = self._run_single_diffusion(
|
149 |
+
model_kwargs, diffusion, model, curr_seq_length, num_repetitions
|
150 |
+
)
|
151 |
+
all_motions.append(sample.cpu().numpy())
|
152 |
+
print(f"created {len(all_motions) * num_repetitions} samples")
|
153 |
+
return np.concatenate(all_motions, axis=0)
|
154 |
+
|
155 |
+
|
156 |
+
def generate_results(audio: np.ndarray, num_repetitions: int, top_p: float):
|
157 |
+
if audio is None:
|
158 |
+
raise gr.Error("Please record audio to start")
|
159 |
+
sr, y = audio
|
160 |
+
# set to mono and perform resampling
|
161 |
+
y = torch.Tensor(y)
|
162 |
+
if y.dim() == 2:
|
163 |
+
dim = 0 if y.shape[0] == 2 else 1
|
164 |
+
y = torch.mean(y, dim=dim)
|
165 |
+
y = torchaudio.functional.resample(torch.Tensor(y), orig_freq=sr, new_freq=48_000)
|
166 |
+
sr = 48_000
|
167 |
+
# make it so that it is 4 seconds long
|
168 |
+
if len(y) < (sr * 4):
|
169 |
+
raise gr.Error("Please record at least 4 second of audio")
|
170 |
+
if num_repetitions is None or num_repetitions <= 0 or num_repetitions > 10:
|
171 |
+
raise gr.Error(
|
172 |
+
f"Invalid number of samples: {num_repetitions}. Please specify a number between 1-10"
|
173 |
+
)
|
174 |
+
cutoff = int(len(y) / (sr * 4))
|
175 |
+
y = y[: cutoff * sr * 4]
|
176 |
+
curr_seq_length = int(len(y) / sr) * 30
|
177 |
+
# create model_kwargs
|
178 |
+
model_kwargs = {"y": {}}
|
179 |
+
dual_audio = np.random.normal(0, 0.001, (1, len(y), 2))
|
180 |
+
dual_audio[:, :, 0] = y / max(y)
|
181 |
+
dual_audio = (dual_audio - gradio_model.stats["audio_mean"]) / gradio_model.stats[
|
182 |
+
"audio_std_flat"
|
183 |
+
]
|
184 |
+
model_kwargs["y"]["audio"] = (
|
185 |
+
torch.Tensor(dual_audio).float().tile(num_repetitions, 1, 1)
|
186 |
+
)
|
187 |
+
face_results = (
|
188 |
+
gradio_model.generate_sequences(
|
189 |
+
model_kwargs, "face", curr_seq_length, num_repetitions=int(num_repetitions)
|
190 |
+
)
|
191 |
+
.squeeze(2)
|
192 |
+
.transpose(0, 2, 1)
|
193 |
+
)
|
194 |
+
face_results = (
|
195 |
+
face_results * gradio_model.stats["code_std"] + gradio_model.stats["code_mean"]
|
196 |
+
)
|
197 |
+
pose_results = (
|
198 |
+
gradio_model.generate_sequences(
|
199 |
+
model_kwargs,
|
200 |
+
"pose",
|
201 |
+
curr_seq_length,
|
202 |
+
num_repetitions=int(num_repetitions),
|
203 |
+
guidance_param=2.0,
|
204 |
+
top_p=top_p,
|
205 |
+
)
|
206 |
+
.squeeze(2)
|
207 |
+
.transpose(0, 2, 1)
|
208 |
+
)
|
209 |
+
pose_results = (
|
210 |
+
pose_results * gradio_model.stats["pose_std"] + gradio_model.stats["pose_mean"]
|
211 |
+
)
|
212 |
+
dual_audio = (
|
213 |
+
dual_audio * gradio_model.stats["audio_std_flat"]
|
214 |
+
+ gradio_model.stats["audio_mean"]
|
215 |
+
)
|
216 |
+
return face_results, pose_results, dual_audio[0].transpose(1, 0).astype(np.float32)
|
217 |
+
|
218 |
+
|
219 |
+
def audio_to_avatar(audio: np.ndarray, num_repetitions: int, top_p: float):
|
220 |
+
face_results, pose_results, audio = generate_results(audio, num_repetitions, top_p)
|
221 |
+
# returns: num_rep x T x 104
|
222 |
+
B = len(face_results)
|
223 |
+
results = []
|
224 |
+
for i in range(B):
|
225 |
+
render_data_block = {
|
226 |
+
"audio": audio, # 2 x T
|
227 |
+
"body_motion": pose_results[i, ...], # T x 104
|
228 |
+
"face_motion": face_results[i, ...], # T x 256
|
229 |
+
}
|
230 |
+
gradio_model.body_renderer.render_full_video(
|
231 |
+
render_data_block, f"/tmp/sample{i}", audio_sr=48_000
|
232 |
+
)
|
233 |
+
results += [gr.Video(value=f"/tmp/sample{i}_pred.mp4", visible=True)]
|
234 |
+
results += [gr.Video(visible=False) for _ in range(B, 10)]
|
235 |
+
return results
|
236 |
+
|
237 |
+
|
238 |
+
gradio_model = GradioModel(
|
239 |
+
face_args="./checkpoints/diffusion/c1_face/args.json",
|
240 |
+
pose_args="./checkpoints/diffusion/c1_pose/args.json",
|
241 |
+
)
|
242 |
+
demo = gr.Interface(
|
243 |
+
audio_to_avatar, # function
|
244 |
+
[
|
245 |
+
gr.Audio(sources=["microphone"]),
|
246 |
+
gr.Number(
|
247 |
+
value=3,
|
248 |
+
label="Number of Samples (default = 3)",
|
249 |
+
precision=0,
|
250 |
+
minimum=1,
|
251 |
+
maximum=10,
|
252 |
+
),
|
253 |
+
gr.Number(
|
254 |
+
value=0.97,
|
255 |
+
label="Sample Diversity (default = 0.97)",
|
256 |
+
precision=None,
|
257 |
+
minimum=0.01,
|
258 |
+
step=0.01,
|
259 |
+
maximum=1.00,
|
260 |
+
),
|
261 |
+
], # input type
|
262 |
+
[gr.Video(format="mp4", visible=True)]
|
263 |
+
+ [gr.Video(format="mp4", visible=False) for _ in range(9)], # output type
|
264 |
+
title='"From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations" Demo',
|
265 |
+
description="You can generate a photorealistic avatar from your voice! <br/>\
|
266 |
+
1) Start by recording your audio. <br/>\
|
267 |
+
2) Specify the number of samples to generate. <br/>\
|
268 |
+
3) Specify how diverse you want the samples to be. This tunes the cumulative probability in nucleus sampling: 0.01 = low diversity, 1.0 = high diversity. <br/>\
|
269 |
+
4) Then, sit back and wait for the rendering to happen! This may take a while (e.g. 30 minutes) <br/>\
|
270 |
+
5) After, you can view the videos and download the ones you like. <br/>",
|
271 |
+
article="Relevant links: [Project Page](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal)", # TODO: code and arxiv
|
272 |
+
)
|
273 |
+
|
274 |
+
if __name__ == "__main__":
|
275 |
+
fixseed(10)
|
276 |
+
demo.launch(share=True)
|
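One detail of `generate_results` above that is easy to miss is the sequence-length bookkeeping: audio is resampled to 48 kHz, truncated to a whole multiple of 4 seconds, and mapped to motion frames at 30 fps. A standalone restatement of that arithmetic (constants taken from the code above, function name is illustrative):
```
SR = 48_000        # demo resamples every recording to 48 kHz
FPS = 30           # motion frames generated per second of audio
CHUNK_SECONDS = 4  # audio is truncated to a whole number of 4 s chunks

def motion_frames_for(num_audio_samples: int) -> int:
    # Same arithmetic as generate_results(): truncate, then seconds * 30 fps.
    cutoff = num_audio_samples // (SR * CHUNK_SECONDS)
    kept = cutoff * SR * CHUNK_SECONDS
    return (kept // SR) * FPS

print(motion_frames_for(SR * 10))  # 10 s of audio -> 8 s kept -> 240 frames
```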
demo/install.sh
ADDED
@@ -0,0 +1,20 @@
#!/bin/bash

# make sure to have cuda 11.7 and gcc 9.0 installed
# install environment
pip install -r scripts/requirements.txt
sh scripts/download_prereq.sh

# download pytorch3d
pip install "git+https://github.com/facebookresearch/pytorch3d.git"

# download model stuff
wget http://audio2photoreal_models.berkeleyvision.org/PXB184_models.tar || { echo 'downloading model failed' ; exit 1; }
tar xvf PXB184_models.tar
rm PXB184_models.tar

# install rendering stuff
mkdir -p checkpoints/ca_body/data/
wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/PXB184.tar.gz || { echo 'downloading ca body model failed' ; exit 1; }
tar xvf PXB184.tar.gz --directory checkpoints/ca_body/data/
rm PXB184.tar.gz
demo/requirements.txt
ADDED
@@ -0,0 +1,17 @@
attrdict
einops==0.7.0
fairseq==0.12.2
gradio==4.31.3
gradio_client==0.7.3
huggingface-hub==0.19.4
hydra-core==1.0.7
mediapy==1.2.0
numpy==1.26.2
omegaconf==2.0.6
opencv-python==4.8.1.78
protobuf==4.25.1
tensorboardX==2.6.2.2
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
tqdm==4.66.3
diffusion/fp16_util.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
"""
|
9 |
+
original code from
|
10 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
|
11 |
+
under an MIT license
|
12 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
|
13 |
+
"""
|
14 |
+
|
15 |
+
"""
|
16 |
+
Helpers to train with 16-bit precision.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch as th
|
21 |
+
import torch.nn as nn
|
22 |
+
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
23 |
+
|
24 |
+
from utils import logger
|
25 |
+
|
26 |
+
INITIAL_LOG_LOSS_SCALE = 20.0
|
27 |
+
|
28 |
+
|
29 |
+
def convert_module_to_f16(l):
|
30 |
+
"""
|
31 |
+
Convert primitive modules to float16.
|
32 |
+
"""
|
33 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
34 |
+
l.weight.data = l.weight.data.half()
|
35 |
+
if l.bias is not None:
|
36 |
+
l.bias.data = l.bias.data.half()
|
37 |
+
|
38 |
+
|
39 |
+
def convert_module_to_f32(l):
|
40 |
+
"""
|
41 |
+
Convert primitive modules to float32, undoing convert_module_to_f16().
|
42 |
+
"""
|
43 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
44 |
+
l.weight.data = l.weight.data.float()
|
45 |
+
if l.bias is not None:
|
46 |
+
l.bias.data = l.bias.data.float()
|
47 |
+
|
48 |
+
|
49 |
+
def make_master_params(param_groups_and_shapes):
|
50 |
+
"""
|
51 |
+
Copy model parameters into a (differently-shaped) list of full-precision
|
52 |
+
parameters.
|
53 |
+
"""
|
54 |
+
master_params = []
|
55 |
+
for param_group, shape in param_groups_and_shapes:
|
56 |
+
master_param = nn.Parameter(
|
57 |
+
_flatten_dense_tensors(
|
58 |
+
[param.detach().float() for (_, param) in param_group]
|
59 |
+
).view(shape)
|
60 |
+
)
|
61 |
+
master_param.requires_grad = True
|
62 |
+
master_params.append(master_param)
|
63 |
+
return master_params
|
64 |
+
|
65 |
+
|
66 |
+
def model_grads_to_master_grads(param_groups_and_shapes, master_params):
|
67 |
+
"""
|
68 |
+
Copy the gradients from the model parameters into the master parameters
|
69 |
+
from make_master_params().
|
70 |
+
"""
|
71 |
+
for master_param, (param_group, shape) in zip(
|
72 |
+
master_params, param_groups_and_shapes
|
73 |
+
):
|
74 |
+
master_param.grad = _flatten_dense_tensors(
|
75 |
+
[param_grad_or_zeros(param) for (_, param) in param_group]
|
76 |
+
).view(shape)
|
77 |
+
|
78 |
+
|
79 |
+
def master_params_to_model_params(param_groups_and_shapes, master_params):
|
80 |
+
"""
|
81 |
+
Copy the master parameter data back into the model parameters.
|
82 |
+
"""
|
83 |
+
# Without copying to a list, if a generator is passed, this will
|
84 |
+
# silently not copy any parameters.
|
85 |
+
for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
|
86 |
+
for (_, param), unflat_master_param in zip(
|
87 |
+
param_group, unflatten_master_params(param_group, master_param.view(-1))
|
88 |
+
):
|
89 |
+
param.detach().copy_(unflat_master_param)
|
90 |
+
|
91 |
+
|
92 |
+
def unflatten_master_params(param_group, master_param):
|
93 |
+
return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
|
94 |
+
|
95 |
+
|
96 |
+
def get_param_groups_and_shapes(named_model_params):
|
97 |
+
named_model_params = list(named_model_params)
|
98 |
+
scalar_vector_named_params = (
|
99 |
+
[(n, p) for (n, p) in named_model_params if p.ndim <= 1],
|
100 |
+
(-1),
|
101 |
+
)
|
102 |
+
matrix_named_params = (
|
103 |
+
[(n, p) for (n, p) in named_model_params if p.ndim > 1],
|
104 |
+
(1, -1),
|
105 |
+
)
|
106 |
+
return [scalar_vector_named_params, matrix_named_params]
|
107 |
+
|
108 |
+
|
109 |
+
def master_params_to_state_dict(
|
110 |
+
model, param_groups_and_shapes, master_params, use_fp16
|
111 |
+
):
|
112 |
+
if use_fp16:
|
113 |
+
state_dict = model.state_dict()
|
114 |
+
for master_param, (param_group, _) in zip(
|
115 |
+
master_params, param_groups_and_shapes
|
116 |
+
):
|
117 |
+
for (name, _), unflat_master_param in zip(
|
118 |
+
param_group, unflatten_master_params(param_group, master_param.view(-1))
|
119 |
+
):
|
120 |
+
assert name in state_dict
|
121 |
+
state_dict[name] = unflat_master_param
|
122 |
+
else:
|
123 |
+
state_dict = model.state_dict()
|
124 |
+
for i, (name, _value) in enumerate(model.named_parameters()):
|
125 |
+
assert name in state_dict
|
126 |
+
state_dict[name] = master_params[i]
|
127 |
+
return state_dict
|
128 |
+
|
129 |
+
|
130 |
+
def state_dict_to_master_params(model, state_dict, use_fp16):
|
131 |
+
if use_fp16:
|
132 |
+
named_model_params = [
|
133 |
+
(name, state_dict[name]) for name, _ in model.named_parameters()
|
134 |
+
]
|
135 |
+
param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
|
136 |
+
master_params = make_master_params(param_groups_and_shapes)
|
137 |
+
else:
|
138 |
+
master_params = [state_dict[name] for name, _ in model.named_parameters()]
|
139 |
+
return master_params
|
140 |
+
|
141 |
+
|
142 |
+
def zero_master_grads(master_params):
|
143 |
+
for param in master_params:
|
144 |
+
param.grad = None
|
145 |
+
|
146 |
+
|
147 |
+
def zero_grad(model_params):
|
148 |
+
for param in model_params:
|
149 |
+
# Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
|
150 |
+
if param.grad is not None:
|
151 |
+
param.grad.detach_()
|
152 |
+
param.grad.zero_()
|
153 |
+
|
154 |
+
|
155 |
+
def param_grad_or_zeros(param):
|
156 |
+
if param.grad is not None:
|
157 |
+
return param.grad.data.detach()
|
158 |
+
else:
|
159 |
+
return th.zeros_like(param)
|
160 |
+
|
161 |
+
|
162 |
+
class MixedPrecisionTrainer:
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
*,
|
166 |
+
model,
|
167 |
+
use_fp16=False,
|
168 |
+
fp16_scale_growth=1e-3,
|
169 |
+
initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
|
170 |
+
):
|
171 |
+
self.model = model
|
172 |
+
self.use_fp16 = use_fp16
|
173 |
+
self.fp16_scale_growth = fp16_scale_growth
|
174 |
+
|
175 |
+
self.model_params = list(self.model.parameters())
|
176 |
+
self.master_params = self.model_params
|
177 |
+
self.param_groups_and_shapes = None
|
178 |
+
self.lg_loss_scale = initial_lg_loss_scale
|
179 |
+
|
180 |
+
if self.use_fp16:
|
181 |
+
self.param_groups_and_shapes = get_param_groups_and_shapes(
|
182 |
+
self.model.named_parameters()
|
183 |
+
)
|
184 |
+
self.master_params = make_master_params(self.param_groups_and_shapes)
|
185 |
+
self.model.convert_to_fp16()
|
186 |
+
|
187 |
+
def zero_grad(self):
|
188 |
+
zero_grad(self.model_params)
|
189 |
+
|
190 |
+
def backward(self, loss: th.Tensor):
|
191 |
+
if self.use_fp16:
|
192 |
+
loss_scale = 2**self.lg_loss_scale
|
193 |
+
(loss * loss_scale).backward()
|
194 |
+
else:
|
195 |
+
loss.backward()
|
196 |
+
|
197 |
+
def optimize(self, opt: th.optim.Optimizer):
|
198 |
+
if self.use_fp16:
|
199 |
+
return self._optimize_fp16(opt)
|
200 |
+
else:
|
201 |
+
return self._optimize_normal(opt)
|
202 |
+
|
203 |
+
def _optimize_fp16(self, opt: th.optim.Optimizer):
|
204 |
+
logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
|
205 |
+
model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
|
206 |
+
grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale)
|
207 |
+
if check_overflow(grad_norm):
|
208 |
+
self.lg_loss_scale -= 1
|
209 |
+
logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
|
210 |
+
zero_master_grads(self.master_params)
|
211 |
+
return False
|
212 |
+
|
213 |
+
logger.logkv_mean("grad_norm", grad_norm)
|
214 |
+
logger.logkv_mean("param_norm", param_norm)
|
215 |
+
|
216 |
+
# un-scale the gradients of every master-param group before stepping
for p in self.master_params:
    p.grad.mul_(1.0 / (2**self.lg_loss_scale))
|
217 |
+
opt.step()
|
218 |
+
zero_master_grads(self.master_params)
|
219 |
+
master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
|
220 |
+
self.lg_loss_scale += self.fp16_scale_growth
|
221 |
+
return True
|
222 |
+
|
223 |
+
def _optimize_normal(self, opt: th.optim.Optimizer):
|
224 |
+
grad_norm, param_norm = self._compute_norms()
|
225 |
+
logger.logkv_mean("grad_norm", grad_norm)
|
226 |
+
logger.logkv_mean("param_norm", param_norm)
|
227 |
+
opt.step()
|
228 |
+
return True
|
229 |
+
|
230 |
+
def _compute_norms(self, grad_scale=1.0):
|
231 |
+
grad_norm = 0.0
|
232 |
+
param_norm = 0.0
|
233 |
+
for p in self.master_params:
|
234 |
+
with th.no_grad():
|
235 |
+
param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
|
236 |
+
if p.grad is not None:
|
237 |
+
grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
|
238 |
+
return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
|
239 |
+
|
240 |
+
def master_params_to_state_dict(self, master_params):
|
241 |
+
return master_params_to_state_dict(
|
242 |
+
self.model, self.param_groups_and_shapes, master_params, self.use_fp16
|
243 |
+
)
|
244 |
+
|
245 |
+
def state_dict_to_master_params(self, state_dict):
|
246 |
+
return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
|
247 |
+
|
248 |
+
|
249 |
+
def check_overflow(value):
|
250 |
+
return (value == float("inf")) or (value == -float("inf")) or (value != value)
|
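A minimal usage sketch for the MixedPrecisionTrainer defined above (illustrative only, not this repo's training entry point; it assumes diffusion.fp16_util and the utils.logger module it imports are on the path, and it uses a toy stand-in model with a placeholder loss):

import torch as th
import torch.nn as nn

from diffusion.fp16_util import MixedPrecisionTrainer

# Toy stand-in network; a real model must also implement convert_to_fp16()
# before use_fp16=True is meaningful.
model = nn.Sequential(nn.Conv1d(4, 8, 3, padding=1), nn.ReLU(), nn.Conv1d(8, 4, 3, padding=1))

trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

x = th.randn(2, 4, 32)
for step in range(3):
    trainer.zero_grad()
    loss = (model(x) ** 2).mean()  # placeholder loss
    trainer.backward(loss)         # scales the loss only when use_fp16=True
    trainer.optimize(opt)          # logs grad/param norms and steps the optimizer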
diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,1273 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
"""
|
9 |
+
original code from
|
10 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
|
11 |
+
under an MIT license
|
12 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
|
13 |
+
"""
|
14 |
+
|
15 |
+
import enum
|
16 |
+
import math
|
17 |
+
from copy import deepcopy
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch as th
|
22 |
+
from diffusion.losses import discretized_gaussian_log_likelihood, normal_kl
|
23 |
+
from diffusion.nn import mean_flat, sum_flat
|
24 |
+
|
25 |
+
|
26 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.0):
|
27 |
+
"""
|
28 |
+
Get a pre-defined beta schedule for the given name.
|
29 |
+
|
30 |
+
The beta schedule library consists of beta schedules which remain similar
|
31 |
+
in the limit of num_diffusion_timesteps.
|
32 |
+
Beta schedules may be added, but should not be removed or changed once
|
33 |
+
they are committed to maintain backwards compatibility.
|
34 |
+
"""
|
35 |
+
if schedule_name == "linear":
|
36 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
37 |
+
# diffusion steps.
|
38 |
+
scale = scale_betas * 1000 / num_diffusion_timesteps
|
39 |
+
beta_start = scale * 0.0001
|
40 |
+
beta_end = scale * 0.02
|
41 |
+
return np.linspace(
|
42 |
+
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
|
43 |
+
)
|
44 |
+
elif schedule_name == "cosine":
|
45 |
+
return betas_for_alpha_bar(
|
46 |
+
num_diffusion_timesteps,
|
47 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
51 |
+
|
52 |
+
|
53 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
54 |
+
"""
|
55 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
56 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
57 |
+
|
58 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
59 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
60 |
+
produces the cumulative product of (1-beta) up to that
|
61 |
+
part of the diffusion process.
|
62 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
63 |
+
prevent singularities.
|
64 |
+
"""
|
65 |
+
betas = []
|
66 |
+
for i in range(num_diffusion_timesteps):
|
67 |
+
t1 = i / num_diffusion_timesteps
|
68 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
69 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
70 |
+
return np.array(betas)
|
71 |
+
|
72 |
+
|
73 |
+
class ModelMeanType(enum.Enum):
|
74 |
+
"""
|
75 |
+
Which type of output the model predicts.
|
76 |
+
"""
|
77 |
+
|
78 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
79 |
+
START_X = enum.auto() # the model predicts x_0
|
80 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
81 |
+
|
82 |
+
|
83 |
+
class ModelVarType(enum.Enum):
|
84 |
+
"""
|
85 |
+
What is used as the model's output variance.
|
86 |
+
|
87 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
88 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
89 |
+
"""
|
90 |
+
|
91 |
+
LEARNED = enum.auto()
|
92 |
+
FIXED_SMALL = enum.auto()
|
93 |
+
FIXED_LARGE = enum.auto()
|
94 |
+
LEARNED_RANGE = enum.auto()
|
95 |
+
|
96 |
+
|
97 |
+
class LossType(enum.Enum):
|
98 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
99 |
+
RESCALED_MSE = (
|
100 |
+
enum.auto()
|
101 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
102 |
+
KL = enum.auto() # use the variational lower-bound
|
103 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
104 |
+
|
105 |
+
def is_vb(self):
|
106 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
107 |
+
|
108 |
+
|
109 |
+
class GaussianDiffusion:
|
110 |
+
"""
|
111 |
+
Utilities for training and sampling diffusion models.
|
112 |
+
|
113 |
+
Ported directly from here, and then adapted over time to further experimentation.
|
114 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
115 |
+
|
116 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
117 |
+
starting at T and going to 1.
|
118 |
+
:param model_mean_type: a ModelMeanType determining what the model outputs.
|
119 |
+
:param model_var_type: a ModelVarType determining how variance is output.
|
120 |
+
:param loss_type: a LossType determining the loss function to use.
|
121 |
+
:param rescale_timesteps: if True, pass floating point timesteps into the
|
122 |
+
model so that they are always scaled like in the
|
123 |
+
original paper (0 to 1000).
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
*,
|
129 |
+
betas,
|
130 |
+
model_mean_type,
|
131 |
+
model_var_type,
|
132 |
+
loss_type,
|
133 |
+
rescale_timesteps=False,
|
134 |
+
lambda_vel=0.0,
|
135 |
+
data_format="pose",
|
136 |
+
model_path=None,
|
137 |
+
):
|
138 |
+
self.model_mean_type = model_mean_type
|
139 |
+
self.model_var_type = model_var_type
|
140 |
+
self.loss_type = loss_type
|
141 |
+
self.rescale_timesteps = rescale_timesteps
|
142 |
+
self.data_format = data_format
|
143 |
+
self.lambda_vel = lambda_vel
|
144 |
+
if self.lambda_vel > 0.0:
|
145 |
+
assert (
|
146 |
+
self.loss_type == LossType.MSE
|
147 |
+
), "Geometric losses are supported by MSE loss type only!"
|
148 |
+
|
149 |
+
# Use float64 for accuracy.
|
150 |
+
betas = np.array(betas, dtype=np.float64)
|
151 |
+
self.betas = betas
|
152 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
153 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
154 |
+
|
155 |
+
self.num_timesteps = int(betas.shape[0])
|
156 |
+
|
157 |
+
alphas = 1.0 - betas
|
158 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
159 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
160 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
161 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
162 |
+
|
163 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
164 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
165 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
166 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
167 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
168 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
169 |
+
|
170 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
171 |
+
self.posterior_variance = (
|
172 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
173 |
+
)
|
174 |
+
# log calculation clipped because the posterior variance is 0 at the
|
175 |
+
# beginning of the diffusion chain.
|
176 |
+
self.posterior_log_variance_clipped = np.log(
|
177 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
178 |
+
)
|
179 |
+
self.posterior_mean_coef1 = (
|
180 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
181 |
+
)
|
182 |
+
self.posterior_mean_coef2 = (
|
183 |
+
(1.0 - self.alphas_cumprod_prev)
|
184 |
+
* np.sqrt(alphas)
|
185 |
+
/ (1.0 - self.alphas_cumprod)
|
186 |
+
)
|
187 |
+
|
188 |
+
self.l2_loss = lambda a, b: (a - b) ** 2
|
189 |
+
|
190 |
+
def masked_l2(self, a, b, mask):
|
191 |
+
loss = self.l2_loss(a, b)
|
192 |
+
loss = sum_flat(loss * mask.float())
|
193 |
+
n_entries = a.shape[1] * a.shape[2]
|
194 |
+
non_zero_elements = sum_flat(mask) * n_entries
|
195 |
+
mse_loss_val = loss / non_zero_elements
|
196 |
+
return mse_loss_val
|
197 |
+
|
198 |
+
def q_mean_variance(self, x_start, t):
|
199 |
+
"""
|
200 |
+
Get the distribution q(x_t | x_0).
|
201 |
+
|
202 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
203 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
204 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
205 |
+
"""
|
206 |
+
mean = (
|
207 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
208 |
+
)
|
209 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
210 |
+
log_variance = _extract_into_tensor(
|
211 |
+
self.log_one_minus_alphas_cumprod, t, x_start.shape
|
212 |
+
)
|
213 |
+
return mean, variance, log_variance
|
214 |
+
|
215 |
+
def q_sample(self, x_start, t, noise=None):
|
216 |
+
"""
|
217 |
+
Diffuse the dataset for a given number of diffusion steps.
|
218 |
+
|
219 |
+
In other words, sample from q(x_t | x_0).
|
220 |
+
|
221 |
+
:param x_start: the initial dataset batch.
|
222 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
223 |
+
:param noise: if specified, the split-out normal noise.
|
224 |
+
:return: A noisy version of x_start.
|
225 |
+
"""
|
226 |
+
if noise is None:
|
227 |
+
noise = th.randn_like(x_start)
|
228 |
+
assert noise.shape == x_start.shape
|
229 |
+
return (
|
230 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
231 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
232 |
+
* noise
|
233 |
+
)
|
234 |
+
|
235 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
236 |
+
"""
|
237 |
+
Compute the mean and variance of the diffusion posterior:
|
238 |
+
|
239 |
+
q(x_{t-1} | x_t, x_0)
|
240 |
+
|
241 |
+
"""
|
242 |
+
assert x_start.shape == x_t.shape, f"x_start: {x_start.shape}, x_t: {x_t.shape}"
|
243 |
+
posterior_mean = (
|
244 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
245 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
246 |
+
)
|
247 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
248 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
249 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
250 |
+
)
|
251 |
+
assert (
|
252 |
+
posterior_mean.shape[0]
|
253 |
+
== posterior_variance.shape[0]
|
254 |
+
== posterior_log_variance_clipped.shape[0]
|
255 |
+
== x_start.shape[0]
|
256 |
+
)
|
257 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
258 |
+
|
259 |
+
def p_mean_variance(
|
260 |
+
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
|
261 |
+
):
|
262 |
+
"""
|
263 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
264 |
+
the initial x, x_0.
|
265 |
+
|
266 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
267 |
+
as input.
|
268 |
+
:param x: the [N x C x ...] tensor at time t.
|
269 |
+
:param t: a 1-D Tensor of timesteps.
|
270 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
271 |
+
:param denoised_fn: if not None, a function which applies to the
|
272 |
+
x_start prediction before it is used to sample. Applies before
|
273 |
+
clip_denoised.
|
274 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
275 |
+
pass to the model. This can be used for conditioning.
|
276 |
+
:return: a dict with the following keys:
|
277 |
+
- 'mean': the model mean output.
|
278 |
+
- 'variance': the model variance output.
|
279 |
+
- 'log_variance': the log of 'variance'.
|
280 |
+
- 'pred_xstart': the prediction for x_0.
|
281 |
+
"""
|
282 |
+
if model_kwargs is None:
|
283 |
+
model_kwargs = {}
|
284 |
+
|
285 |
+
B, C = x.shape[:2]
|
286 |
+
assert t.shape == (B,)
|
287 |
+
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
288 |
+
|
289 |
+
model_variance, model_log_variance = {
|
290 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
291 |
+
# to get a better decoder log likelihood.
|
292 |
+
ModelVarType.FIXED_LARGE: (
|
293 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
294 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
295 |
+
),
|
296 |
+
ModelVarType.FIXED_SMALL: (
|
297 |
+
self.posterior_variance,
|
298 |
+
self.posterior_log_variance_clipped,
|
299 |
+
),
|
300 |
+
}[self.model_var_type]
|
301 |
+
|
302 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
303 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
304 |
+
|
305 |
+
def process_xstart(x):
|
306 |
+
if denoised_fn is not None:
|
307 |
+
x = denoised_fn(x)
|
308 |
+
if clip_denoised:
|
309 |
+
return x.clamp(-1, 1)
|
310 |
+
return x
|
311 |
+
|
312 |
+
pred_xstart = process_xstart(model_output)
|
313 |
+
pred_xstart = pred_xstart.permute(0, 2, 1).unsqueeze(2)
|
314 |
+
model_mean, _, _ = self.q_posterior_mean_variance(
|
315 |
+
x_start=pred_xstart, x_t=x, t=t
|
316 |
+
)
|
317 |
+
|
318 |
+
assert (
|
319 |
+
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
320 |
+
), (
    f"{model_mean.shape} == {model_log_variance.shape} == {pred_xstart.shape} == {x.shape}"
)
|
323 |
+
return {
|
324 |
+
"mean": model_mean,
|
325 |
+
"variance": model_variance,
|
326 |
+
"log_variance": model_log_variance,
|
327 |
+
"pred_xstart": pred_xstart,
|
328 |
+
}
|
329 |
+
|
330 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
331 |
+
assert x_t.shape == eps.shape
|
332 |
+
return (
|
333 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
334 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
335 |
+
)
|
336 |
+
|
337 |
+
def _predict_xstart_from_xprev(self, x_t, t, xprev):
|
338 |
+
assert x_t.shape == xprev.shape
|
339 |
+
return (
|
340 |
+
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
|
341 |
+
- _extract_into_tensor(
|
342 |
+
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
|
343 |
+
)
|
344 |
+
* x_t
|
345 |
+
)
|
346 |
+
|
347 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
348 |
+
return (
|
349 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
350 |
+
- pred_xstart
|
351 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
352 |
+
|
353 |
+
def _scale_timesteps(self, t):
|
354 |
+
if self.rescale_timesteps:
|
355 |
+
return t.float() * (1000.0 / self.num_timesteps)
|
356 |
+
return t
|
357 |
+
|
358 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
359 |
+
"""
|
360 |
+
Compute the mean for the previous step, given a function cond_fn that
|
361 |
+
computes the gradient of a conditional log probability with respect to
|
362 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
363 |
+
condition on y.
|
364 |
+
|
365 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
366 |
+
"""
|
367 |
+
gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
|
368 |
+
new_mean = (
|
369 |
+
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
370 |
+
)
|
371 |
+
return new_mean
|
372 |
+
|
373 |
+
def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
374 |
+
"""
|
375 |
+
Compute the mean for the previous step, given a function cond_fn that
|
376 |
+
computes the gradient of a conditional log probability with respect to
|
377 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
378 |
+
condition on y.
|
379 |
+
|
380 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
381 |
+
"""
|
382 |
+
gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
|
383 |
+
new_mean = (
|
384 |
+
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
385 |
+
)
|
386 |
+
return new_mean
|
387 |
+
|
388 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
389 |
+
"""
|
390 |
+
Compute what the p_mean_variance output would have been, should the
|
391 |
+
model's score function be conditioned by cond_fn.
|
392 |
+
|
393 |
+
See condition_mean() for details on cond_fn.
|
394 |
+
|
395 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
396 |
+
from Song et al (2020).
|
397 |
+
"""
|
398 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
399 |
+
|
400 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
401 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
|
402 |
+
x, self._scale_timesteps(t), **model_kwargs
|
403 |
+
)
|
404 |
+
|
405 |
+
out = p_mean_var.copy()
|
406 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
407 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(
|
408 |
+
x_start=out["pred_xstart"], x_t=x, t=t
|
409 |
+
)
|
410 |
+
return out
|
411 |
+
|
412 |
+
def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
413 |
+
"""
|
414 |
+
Compute what the p_mean_variance output would have been, should the
|
415 |
+
model's score function be conditioned by cond_fn.
|
416 |
+
|
417 |
+
See condition_mean() for details on cond_fn.
|
418 |
+
|
419 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
420 |
+
from Song et al (2020).
|
421 |
+
"""
|
422 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
423 |
+
|
424 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
425 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, p_mean_var, **model_kwargs)
|
426 |
+
|
427 |
+
out = p_mean_var.copy()
|
428 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
429 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(
|
430 |
+
x_start=out["pred_xstart"], x_t=x, t=t
|
431 |
+
)
|
432 |
+
return out
|
433 |
+
|
434 |
+
def p_sample(
|
435 |
+
self,
|
436 |
+
model,
|
437 |
+
x,
|
438 |
+
t,
|
439 |
+
clip_denoised=True,
|
440 |
+
denoised_fn=None,
|
441 |
+
cond_fn=None,
|
442 |
+
model_kwargs=None,
|
443 |
+
const_noise=False,
|
444 |
+
):
|
445 |
+
"""
|
446 |
+
Sample x_{t-1} from the model at the given timestep.
|
447 |
+
|
448 |
+
:param model: the model to sample from.
|
449 |
+
:param x: the current tensor at x_{t-1}.
|
450 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
451 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
452 |
+
:param denoised_fn: if not None, a function which applies to the
|
453 |
+
x_start prediction before it is used to sample.
|
454 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
455 |
+
similarly to the model.
|
456 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
457 |
+
pass to the model. This can be used for conditioning.
|
458 |
+
:return: a dict containing the following keys:
|
459 |
+
- 'sample': a random sample from the model.
|
460 |
+
- 'pred_xstart': a prediction of x_0.
|
461 |
+
"""
|
462 |
+
out = self.p_mean_variance(
|
463 |
+
model,
|
464 |
+
x,
|
465 |
+
t,
|
466 |
+
clip_denoised=clip_denoised,
|
467 |
+
denoised_fn=denoised_fn,
|
468 |
+
model_kwargs=model_kwargs,
|
469 |
+
)
|
470 |
+
|
471 |
+
noise = th.randn_like(x)
if const_noise:
    # reuse one noise sample across the whole batch when const_noise is requested
    noise = noise[[0]].repeat(x.shape[0], 1, 1, 1)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
472 |
+
if cond_fn is not None:
|
473 |
+
out["mean"] = self.condition_mean(
|
474 |
+
cond_fn, out, x, t, model_kwargs=model_kwargs
|
475 |
+
)
|
476 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
477 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
478 |
+
|
479 |
+
def p_sample_with_grad(
|
480 |
+
self,
|
481 |
+
model,
|
482 |
+
x,
|
483 |
+
t,
|
484 |
+
clip_denoised=True,
|
485 |
+
denoised_fn=None,
|
486 |
+
cond_fn=None,
|
487 |
+
model_kwargs=None,
|
488 |
+
):
|
489 |
+
"""
|
490 |
+
Sample x_{t-1} from the model at the given timestep.
|
491 |
+
|
492 |
+
:param model: the model to sample from.
|
493 |
+
:param x: the current tensor at x_{t-1}.
|
494 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
495 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
496 |
+
:param denoised_fn: if not None, a function which applies to the
|
497 |
+
x_start prediction before it is used to sample.
|
498 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
499 |
+
similarly to the model.
|
500 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
501 |
+
pass to the model. This can be used for conditioning.
|
502 |
+
:return: a dict containing the following keys:
|
503 |
+
- 'sample': a random sample from the model.
|
504 |
+
- 'pred_xstart': a prediction of x_0.
|
505 |
+
"""
|
506 |
+
with th.enable_grad():
|
507 |
+
x = x.detach().requires_grad_()
|
508 |
+
out = self.p_mean_variance(
|
509 |
+
model,
|
510 |
+
x,
|
511 |
+
t,
|
512 |
+
clip_denoised=clip_denoised,
|
513 |
+
denoised_fn=denoised_fn,
|
514 |
+
model_kwargs=model_kwargs,
|
515 |
+
)
|
516 |
+
noise = th.randn_like(x)
|
517 |
+
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
518 |
+
if cond_fn is not None:
|
519 |
+
out["mean"] = self.condition_mean_with_grad(
|
520 |
+
cond_fn, out, x, t, model_kwargs=model_kwargs
|
521 |
+
)
|
522 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
523 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()}
|
524 |
+
|
525 |
+
def p_sample_loop(
|
526 |
+
self,
|
527 |
+
model,
|
528 |
+
shape,
|
529 |
+
noise=None,
|
530 |
+
clip_denoised=True,
|
531 |
+
denoised_fn=None,
|
532 |
+
cond_fn=None,
|
533 |
+
model_kwargs=None,
|
534 |
+
device=None,
|
535 |
+
progress=False,
|
536 |
+
skip_timesteps=0,
|
537 |
+
init_image=None,
|
538 |
+
randomize_class=False,
|
539 |
+
cond_fn_with_grad=False,
|
540 |
+
dump_steps=None,
|
541 |
+
const_noise=False,
|
542 |
+
):
|
543 |
+
"""
|
544 |
+
Generate samples from the model.
|
545 |
+
|
546 |
+
:param model: the model module.
|
547 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
548 |
+
:param noise: if specified, the noise from the encoder to sample.
|
549 |
+
Should be of the same shape as `shape`.
|
550 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
551 |
+
:param denoised_fn: if not None, a function which applies to the
|
552 |
+
x_start prediction before it is used to sample.
|
553 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
554 |
+
similarly to the model.
|
555 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
556 |
+
pass to the model. This can be used for conditioning.
|
557 |
+
:param device: if specified, the device to create the samples on.
|
558 |
+
If not specified, use a model parameter's device.
|
559 |
+
:param progress: if True, show a tqdm progress bar.
|
560 |
+
:param const_noise: If True, will noise all samples with the same noise throughout sampling
|
561 |
+
:return: a non-differentiable batch of samples.
|
562 |
+
"""
|
563 |
+
final = None
|
564 |
+
if dump_steps is not None:
|
565 |
+
dump = []
|
566 |
+
|
567 |
+
for i, sample in enumerate(
|
568 |
+
self.p_sample_loop_progressive(
|
569 |
+
model,
|
570 |
+
shape,
|
571 |
+
noise=noise,
|
572 |
+
clip_denoised=clip_denoised,
|
573 |
+
denoised_fn=denoised_fn,
|
574 |
+
cond_fn=cond_fn,
|
575 |
+
model_kwargs=model_kwargs,
|
576 |
+
device=device,
|
577 |
+
progress=progress,
|
578 |
+
skip_timesteps=skip_timesteps,
|
579 |
+
init_image=init_image,
|
580 |
+
randomize_class=randomize_class,
|
581 |
+
cond_fn_with_grad=cond_fn_with_grad,
|
582 |
+
const_noise=const_noise,
|
583 |
+
)
|
584 |
+
):
|
585 |
+
if dump_steps is not None and i in dump_steps:
|
586 |
+
dump.append(deepcopy(sample["sample"]))
|
587 |
+
final = sample
|
588 |
+
if dump_steps is not None:
|
589 |
+
return dump
|
590 |
+
return final["sample"]
|
591 |
+
|
592 |
+
def p_sample_loop_progressive(
|
593 |
+
self,
|
594 |
+
model,
|
595 |
+
shape,
|
596 |
+
noise=None,
|
597 |
+
clip_denoised=True,
|
598 |
+
denoised_fn=None,
|
599 |
+
cond_fn=None,
|
600 |
+
model_kwargs=None,
|
601 |
+
device=None,
|
602 |
+
progress=False,
|
603 |
+
skip_timesteps=0,
|
604 |
+
init_image=None,
|
605 |
+
randomize_class=False,
|
606 |
+
cond_fn_with_grad=False,
|
607 |
+
const_noise=False,
|
608 |
+
):
|
609 |
+
"""
|
610 |
+
Generate samples from the model and yield intermediate samples from
|
611 |
+
each timestep of diffusion.
|
612 |
+
|
613 |
+
Arguments are the same as p_sample_loop().
|
614 |
+
Returns a generator over dicts, where each dict is the return value of
|
615 |
+
p_sample().
|
616 |
+
"""
|
617 |
+
if device is None:
|
618 |
+
device = next(model.parameters()).device
|
619 |
+
assert isinstance(shape, (tuple, list))
|
620 |
+
if noise is not None:
|
621 |
+
img = noise
|
622 |
+
else:
|
623 |
+
img = th.randn(*shape, device=device)
|
624 |
+
|
625 |
+
if skip_timesteps and init_image is None:
|
626 |
+
init_image = th.zeros_like(img)
|
627 |
+
|
628 |
+
indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
|
629 |
+
|
630 |
+
if init_image is not None:
|
631 |
+
my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
|
632 |
+
img = self.q_sample(init_image, my_t, img)
|
633 |
+
|
634 |
+
if progress:
|
635 |
+
# Lazy import so that we don't depend on tqdm.
|
636 |
+
from tqdm.auto import tqdm
|
637 |
+
|
638 |
+
indices = tqdm(indices)
|
639 |
+
|
640 |
+
# number of timestamps to diffuse
|
641 |
+
for i in indices:
|
642 |
+
t = th.tensor([i] * shape[0], device=device)
|
643 |
+
if randomize_class and "y" in model_kwargs:
|
644 |
+
model_kwargs["y"] = th.randint(
|
645 |
+
low=0,
|
646 |
+
high=model.num_classes,
|
647 |
+
size=model_kwargs["y"].shape,
|
648 |
+
device=model_kwargs["y"].device,
|
649 |
+
)
|
650 |
+
with th.no_grad():
|
651 |
+
sample_fn = (
|
652 |
+
self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
|
653 |
+
)
|
654 |
+
out = sample_fn(
|
655 |
+
model,
|
656 |
+
img,
|
657 |
+
t,
|
658 |
+
clip_denoised=clip_denoised,
|
659 |
+
denoised_fn=denoised_fn,
|
660 |
+
cond_fn=cond_fn,
|
661 |
+
model_kwargs=model_kwargs,
|
662 |
+
const_noise=const_noise,
|
663 |
+
)
|
664 |
+
yield out
|
665 |
+
img = out["sample"]
|
666 |
+
|
667 |
+
def ddim_sample(
|
668 |
+
self,
|
669 |
+
model,
|
670 |
+
x,
|
671 |
+
t,
|
672 |
+
clip_denoised=True,
|
673 |
+
denoised_fn=None,
|
674 |
+
cond_fn=None,
|
675 |
+
model_kwargs=None,
|
676 |
+
eta=0.0,
|
677 |
+
):
|
678 |
+
"""
|
679 |
+
Sample x_{t-1} from the model using DDIM.
|
680 |
+
|
681 |
+
Same usage as p_sample().
|
682 |
+
"""
|
683 |
+
out_orig = self.p_mean_variance(
|
684 |
+
model,
|
685 |
+
x,
|
686 |
+
t,
|
687 |
+
clip_denoised=clip_denoised,
|
688 |
+
denoised_fn=denoised_fn,
|
689 |
+
model_kwargs=model_kwargs,
|
690 |
+
)
|
691 |
+
if cond_fn is not None:
|
692 |
+
out = self.condition_score(
|
693 |
+
cond_fn, out_orig, x, t, model_kwargs=model_kwargs
|
694 |
+
)
|
695 |
+
else:
|
696 |
+
out = out_orig
|
697 |
+
# Usually our model outputs epsilon, but we re-derive it
|
698 |
+
# in case we used x_start or x_prev prediction.
|
699 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
700 |
+
|
701 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
702 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
703 |
+
sigma = (
|
704 |
+
eta
|
705 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
706 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
707 |
+
)
|
708 |
+
noise = th.randn_like(x)
|
709 |
+
|
710 |
+
mean_pred = (
|
711 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
712 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
713 |
+
)
|
714 |
+
nonzero_mask = (
|
715 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
716 |
+
) # no noise when t == 0
|
717 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
718 |
+
return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]}
|
719 |
+
|
720 |
+
def ddim_sample_with_grad(
|
721 |
+
self,
|
722 |
+
model,
|
723 |
+
x,
|
724 |
+
t,
|
725 |
+
clip_denoised=True,
|
726 |
+
denoised_fn=None,
|
727 |
+
cond_fn=None,
|
728 |
+
model_kwargs=None,
|
729 |
+
eta=0.0,
|
730 |
+
):
|
731 |
+
"""
|
732 |
+
Sample x_{t-1} from the model using DDIM.
|
733 |
+
|
734 |
+
Same usage as p_sample().
|
735 |
+
"""
|
736 |
+
with th.enable_grad():
|
737 |
+
x = x.detach().requires_grad_()
|
738 |
+
out_orig = self.p_mean_variance(
|
739 |
+
model,
|
740 |
+
x,
|
741 |
+
t,
|
742 |
+
clip_denoised=clip_denoised,
|
743 |
+
denoised_fn=denoised_fn,
|
744 |
+
model_kwargs=model_kwargs,
|
745 |
+
)
|
746 |
+
if cond_fn is not None:
|
747 |
+
out = self.condition_score_with_grad(
|
748 |
+
cond_fn, out_orig, x, t, model_kwargs=model_kwargs
|
749 |
+
)
|
750 |
+
else:
|
751 |
+
out = out_orig
|
752 |
+
|
753 |
+
out["pred_xstart"] = out["pred_xstart"].detach()
|
754 |
+
# Usually our model outputs epsilon, but we re-derive it
|
755 |
+
# in case we used x_start or x_prev prediction.
|
756 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
757 |
+
|
758 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
759 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
760 |
+
sigma = (
|
761 |
+
eta
|
762 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
763 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
764 |
+
)
|
765 |
+
# Equation 12.
|
766 |
+
noise = th.randn_like(x)
|
767 |
+
mean_pred = (
|
768 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
769 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
770 |
+
)
|
771 |
+
nonzero_mask = (
|
772 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
773 |
+
) # no noise when t == 0
|
774 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
775 |
+
return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()}
|
776 |
+
|
777 |
+
def ddim_reverse_sample(
|
778 |
+
self,
|
779 |
+
model,
|
780 |
+
x,
|
781 |
+
t,
|
782 |
+
clip_denoised=True,
|
783 |
+
denoised_fn=None,
|
784 |
+
model_kwargs=None,
|
785 |
+
eta=0.0,
|
786 |
+
):
|
787 |
+
"""
|
788 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
789 |
+
"""
|
790 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
791 |
+
out = self.p_mean_variance(
|
792 |
+
model,
|
793 |
+
x,
|
794 |
+
t,
|
795 |
+
clip_denoised=clip_denoised,
|
796 |
+
denoised_fn=denoised_fn,
|
797 |
+
model_kwargs=model_kwargs,
|
798 |
+
)
|
799 |
+
# Usually our model outputs epsilon, but we re-derive it
|
800 |
+
# in case we used x_start or x_prev prediction.
|
801 |
+
eps = (
|
802 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
803 |
+
- out["pred_xstart"]
|
804 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
805 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
806 |
+
|
807 |
+
# Equation 12. reversed
|
808 |
+
mean_pred = (
|
809 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_next)
|
810 |
+
+ th.sqrt(1 - alpha_bar_next) * eps
|
811 |
+
)
|
812 |
+
|
813 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
814 |
+
|
815 |
+
def ddim_sample_loop(
|
816 |
+
self,
|
817 |
+
model,
|
818 |
+
shape,
|
819 |
+
noise=None,
|
820 |
+
clip_denoised=True,
|
821 |
+
denoised_fn=None,
|
822 |
+
cond_fn=None,
|
823 |
+
model_kwargs=None,
|
824 |
+
device=None,
|
825 |
+
progress=False,
|
826 |
+
eta=0.0,
|
827 |
+
skip_timesteps=0,
|
828 |
+
init_image=None,
|
829 |
+
randomize_class=False,
|
830 |
+
cond_fn_with_grad=False,
|
831 |
+
dump_steps=None,
|
832 |
+
const_noise=False,
|
833 |
+
):
|
834 |
+
"""
|
835 |
+
Generate samples from the model using DDIM.
|
836 |
+
|
837 |
+
Same usage as p_sample_loop().
|
838 |
+
"""
|
839 |
+
if dump_steps is not None:
|
840 |
+
raise NotImplementedError()
|
841 |
+
if const_noise == True:
|
842 |
+
raise NotImplementedError()
|
843 |
+
|
844 |
+
final = None
|
845 |
+
for sample in self.ddim_sample_loop_progressive(
|
846 |
+
model,
|
847 |
+
shape,
|
848 |
+
noise=noise,
|
849 |
+
clip_denoised=clip_denoised,
|
850 |
+
denoised_fn=denoised_fn,
|
851 |
+
cond_fn=cond_fn,
|
852 |
+
model_kwargs=model_kwargs,
|
853 |
+
device=device,
|
854 |
+
progress=progress,
|
855 |
+
eta=eta,
|
856 |
+
skip_timesteps=skip_timesteps,
|
857 |
+
init_image=init_image,
|
858 |
+
randomize_class=randomize_class,
|
859 |
+
cond_fn_with_grad=cond_fn_with_grad,
|
860 |
+
):
|
861 |
+
final = sample
|
862 |
+
return final["pred_xstart"]
|
863 |
+
|
864 |
+
def ddim_sample_loop_progressive(
|
865 |
+
self,
|
866 |
+
model,
|
867 |
+
shape,
|
868 |
+
noise=None,
|
869 |
+
clip_denoised=True,
|
870 |
+
denoised_fn=None,
|
871 |
+
cond_fn=None,
|
872 |
+
model_kwargs=None,
|
873 |
+
device=None,
|
874 |
+
progress=False,
|
875 |
+
eta=0.0,
|
876 |
+
skip_timesteps=0,
|
877 |
+
init_image=None,
|
878 |
+
randomize_class=False,
|
879 |
+
cond_fn_with_grad=False,
|
880 |
+
):
|
881 |
+
"""
|
882 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
883 |
+
each timestep of DDIM.
|
884 |
+
|
885 |
+
Same usage as p_sample_loop_progressive().
|
886 |
+
"""
|
887 |
+
if device is None:
|
888 |
+
device = next(model.parameters()).device
|
889 |
+
assert isinstance(shape, (tuple, list))
|
890 |
+
if noise is not None:
|
891 |
+
img = noise
|
892 |
+
else:
|
893 |
+
img = th.randn(*shape, device=device)
|
894 |
+
|
895 |
+
if skip_timesteps and init_image is None:
|
896 |
+
init_image = th.zeros_like(img)
|
897 |
+
|
898 |
+
indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
|
899 |
+
|
900 |
+
if init_image is not None:
|
901 |
+
my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
|
902 |
+
img = self.q_sample(init_image, my_t, img)
|
903 |
+
|
904 |
+
if progress:
|
905 |
+
# Lazy import so that we don't depend on tqdm.
|
906 |
+
from tqdm.auto import tqdm
|
907 |
+
|
908 |
+
indices = tqdm(indices)
|
909 |
+
|
910 |
+
for i in indices:
|
911 |
+
t = th.tensor([i] * shape[0], device=device)
|
912 |
+
if randomize_class and "y" in model_kwargs:
|
913 |
+
model_kwargs["y"] = th.randint(
|
914 |
+
low=0,
|
915 |
+
high=model.num_classes,
|
916 |
+
size=model_kwargs["y"].shape,
|
917 |
+
device=model_kwargs["y"].device,
|
918 |
+
)
|
919 |
+
with th.no_grad():
|
920 |
+
sample_fn = (
|
921 |
+
self.ddim_sample_with_grad
|
922 |
+
if cond_fn_with_grad
|
923 |
+
else self.ddim_sample
|
924 |
+
)
|
925 |
+
out = sample_fn(
|
926 |
+
model,
|
927 |
+
img,
|
928 |
+
t,
|
929 |
+
clip_denoised=clip_denoised,
|
930 |
+
denoised_fn=denoised_fn,
|
931 |
+
cond_fn=cond_fn,
|
932 |
+
model_kwargs=model_kwargs,
|
933 |
+
eta=eta,
|
934 |
+
)
|
935 |
+
yield out
|
936 |
+
img = out["sample"]
|
937 |
+
|
938 |
+
def plms_sample(
|
939 |
+
self,
|
940 |
+
model,
|
941 |
+
x,
|
942 |
+
t,
|
943 |
+
clip_denoised=True,
|
944 |
+
denoised_fn=None,
|
945 |
+
cond_fn=None,
|
946 |
+
model_kwargs=None,
|
947 |
+
cond_fn_with_grad=False,
|
948 |
+
order=2,
|
949 |
+
old_out=None,
|
950 |
+
):
|
951 |
+
"""
|
952 |
+
Sample x_{t-1} from the model using Pseudo Linear Multistep.
|
953 |
+
|
954 |
+
Same usage as p_sample().
|
955 |
+
"""
|
956 |
+
if not int(order) or not 1 <= order <= 4:
|
957 |
+
raise ValueError("order is invalid (should be int from 1-4).")
|
958 |
+
|
959 |
+
def get_model_output(x, t):
|
960 |
+
with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None):
|
961 |
+
x = x.detach().requires_grad_() if cond_fn_with_grad else x
|
962 |
+
out_orig = self.p_mean_variance(
|
963 |
+
model,
|
964 |
+
x,
|
965 |
+
t,
|
966 |
+
clip_denoised=clip_denoised,
|
967 |
+
denoised_fn=denoised_fn,
|
968 |
+
model_kwargs=model_kwargs,
|
969 |
+
)
|
970 |
+
if cond_fn is not None:
|
971 |
+
if cond_fn_with_grad:
|
972 |
+
out = self.condition_score_with_grad(
|
973 |
+
cond_fn, out_orig, x, t, model_kwargs=model_kwargs
|
974 |
+
)
|
975 |
+
x = x.detach()
|
976 |
+
else:
|
977 |
+
out = self.condition_score(
|
978 |
+
cond_fn, out_orig, x, t, model_kwargs=model_kwargs
|
979 |
+
)
|
980 |
+
else:
|
981 |
+
out = out_orig
|
982 |
+
|
983 |
+
# Usually our model outputs epsilon, but we re-derive it
|
984 |
+
# in case we used x_start or x_prev prediction.
|
985 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
986 |
+
return eps, out, out_orig
|
987 |
+
|
988 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
989 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
990 |
+
eps, out, out_orig = get_model_output(x, t)
|
991 |
+
|
992 |
+
if order > 1 and old_out is None:
|
993 |
+
# Pseudo Improved Euler
|
994 |
+
old_eps = [eps]
|
995 |
+
mean_pred = (
|
996 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
997 |
+
+ th.sqrt(1 - alpha_bar_prev) * eps
|
998 |
+
)
|
999 |
+
eps_2, _, _ = get_model_output(mean_pred, t - 1)
|
1000 |
+
eps_prime = (eps + eps_2) / 2
|
1001 |
+
pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
|
1002 |
+
mean_pred = (
|
1003 |
+
pred_prime * th.sqrt(alpha_bar_prev)
|
1004 |
+
+ th.sqrt(1 - alpha_bar_prev) * eps_prime
|
1005 |
+
)
|
1006 |
+
else:
|
1007 |
+
# Pseudo Linear Multistep (Adams-Bashforth)
|
1008 |
+
old_eps = old_out["old_eps"]
|
1009 |
+
old_eps.append(eps)
|
1010 |
+
cur_order = min(order, len(old_eps))
|
1011 |
+
if cur_order == 1:
|
1012 |
+
eps_prime = old_eps[-1]
|
1013 |
+
elif cur_order == 2:
|
1014 |
+
eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2
|
1015 |
+
elif cur_order == 3:
|
1016 |
+
eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12
|
1017 |
+
elif cur_order == 4:
|
1018 |
+
eps_prime = (
|
1019 |
+
55 * old_eps[-1]
|
1020 |
+
- 59 * old_eps[-2]
|
1021 |
+
+ 37 * old_eps[-3]
|
1022 |
+
- 9 * old_eps[-4]
|
1023 |
+
) / 24
|
1024 |
+
else:
|
1025 |
+
raise RuntimeError("cur_order is invalid.")
|
1026 |
+
pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
|
1027 |
+
mean_pred = (
|
1028 |
+
pred_prime * th.sqrt(alpha_bar_prev)
|
1029 |
+
+ th.sqrt(1 - alpha_bar_prev) * eps_prime
|
1030 |
+
)
|
1031 |
+
|
1032 |
+
if len(old_eps) >= order:
|
1033 |
+
old_eps.pop(0)
|
1034 |
+
|
1035 |
+
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
1036 |
+
sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask)
|
1037 |
+
|
1038 |
+
return {
|
1039 |
+
"sample": sample,
|
1040 |
+
"pred_xstart": out_orig["pred_xstart"],
|
1041 |
+
"old_eps": old_eps,
|
1042 |
+
}
|
1043 |
+
|
1044 |
+
def plms_sample_loop(
|
1045 |
+
self,
|
1046 |
+
model,
|
1047 |
+
shape,
|
1048 |
+
noise=None,
|
1049 |
+
clip_denoised=True,
|
1050 |
+
denoised_fn=None,
|
1051 |
+
cond_fn=None,
|
1052 |
+
model_kwargs=None,
|
1053 |
+
device=None,
|
1054 |
+
progress=False,
|
1055 |
+
skip_timesteps=0,
|
1056 |
+
init_image=None,
|
1057 |
+
randomize_class=False,
|
1058 |
+
cond_fn_with_grad=False,
|
1059 |
+
order=2,
|
1060 |
+
):
|
1061 |
+
"""
|
1062 |
+
Generate samples from the model using Pseudo Linear Multistep.
|
1063 |
+
|
1064 |
+
Same usage as p_sample_loop().
|
1065 |
+
"""
|
1066 |
+
final = None
|
1067 |
+
for sample in self.plms_sample_loop_progressive(
|
1068 |
+
model,
|
1069 |
+
shape,
|
1070 |
+
noise=noise,
|
1071 |
+
clip_denoised=clip_denoised,
|
1072 |
+
denoised_fn=denoised_fn,
|
1073 |
+
cond_fn=cond_fn,
|
1074 |
+
model_kwargs=model_kwargs,
|
1075 |
+
device=device,
|
1076 |
+
progress=progress,
|
1077 |
+
skip_timesteps=skip_timesteps,
|
1078 |
+
init_image=init_image,
|
1079 |
+
randomize_class=randomize_class,
|
1080 |
+
cond_fn_with_grad=cond_fn_with_grad,
|
1081 |
+
order=order,
|
1082 |
+
):
|
1083 |
+
final = sample
|
1084 |
+
return final["sample"]
|
1085 |
+
|
1086 |
+
def plms_sample_loop_progressive(
|
1087 |
+
self,
|
1088 |
+
model,
|
1089 |
+
shape,
|
1090 |
+
noise=None,
|
1091 |
+
clip_denoised=True,
|
1092 |
+
denoised_fn=None,
|
1093 |
+
cond_fn=None,
|
1094 |
+
model_kwargs=None,
|
1095 |
+
device=None,
|
1096 |
+
progress=False,
|
1097 |
+
skip_timesteps=0,
|
1098 |
+
init_image=None,
|
1099 |
+
randomize_class=False,
|
1100 |
+
cond_fn_with_grad=False,
|
1101 |
+
order=2,
|
1102 |
+
):
|
1103 |
+
"""
|
1104 |
+
Use PLMS to sample from the model and yield intermediate samples from each
|
1105 |
+
timestep of PLMS.
|
1106 |
+
|
1107 |
+
Same usage as p_sample_loop_progressive().
|
1108 |
+
"""
|
1109 |
+
if device is None:
|
1110 |
+
device = next(model.parameters()).device
|
1111 |
+
assert isinstance(shape, (tuple, list))
|
1112 |
+
if noise is not None:
|
1113 |
+
img = noise
|
1114 |
+
else:
|
1115 |
+
img = th.randn(*shape, device=device)
|
1116 |
+
|
1117 |
+
if skip_timesteps and init_image is None:
|
1118 |
+
init_image = th.zeros_like(img)
|
1119 |
+
|
1120 |
+
indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
|
1121 |
+
|
1122 |
+
if init_image is not None:
|
1123 |
+
my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
|
1124 |
+
img = self.q_sample(init_image, my_t, img)
|
1125 |
+
|
1126 |
+
if progress:
|
1127 |
+
# Lazy import so that we don't depend on tqdm.
|
1128 |
+
from tqdm.auto import tqdm
|
1129 |
+
|
1130 |
+
indices = tqdm(indices)
|
1131 |
+
|
1132 |
+
old_out = None
|
1133 |
+
|
1134 |
+
for i in indices:
|
1135 |
+
t = th.tensor([i] * shape[0], device=device)
|
1136 |
+
if randomize_class and "y" in model_kwargs:
|
1137 |
+
model_kwargs["y"] = th.randint(
|
1138 |
+
low=0,
|
1139 |
+
high=model.num_classes,
|
1140 |
+
size=model_kwargs["y"].shape,
|
1141 |
+
device=model_kwargs["y"].device,
|
1142 |
+
)
|
1143 |
+
with th.no_grad():
|
1144 |
+
out = self.plms_sample(
|
1145 |
+
model,
|
1146 |
+
img,
|
1147 |
+
t,
|
1148 |
+
clip_denoised=clip_denoised,
|
1149 |
+
denoised_fn=denoised_fn,
|
1150 |
+
cond_fn=cond_fn,
|
1151 |
+
model_kwargs=model_kwargs,
|
1152 |
+
cond_fn_with_grad=cond_fn_with_grad,
|
1153 |
+
order=order,
|
1154 |
+
old_out=old_out,
|
1155 |
+
)
|
1156 |
+
yield out
|
1157 |
+
old_out = out
|
1158 |
+
img = out["sample"]
|
1159 |
+
|
1160 |
+
def _vb_terms_bpd(
|
1161 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
1162 |
+
):
|
1163 |
+
"""
|
1164 |
+
Get a term for the variational lower-bound.
|
1165 |
+
|
1166 |
+
The resulting units are bits (rather than nats, as one might expect).
|
1167 |
+
This allows for comparison to other papers.
|
1168 |
+
|
1169 |
+
:return: a dict with the following keys:
|
1170 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
1171 |
+
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        mask = model_kwargs["y"]["mask"]
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(
            x_start, t, noise=noise
        )  # use the formula to diffuse the starting tensor by t steps
        terms = {}

        # set random dropout for conditioning in training
        model_kwargs["cond_drop_prob"] = 0.2
        model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
        target = {
            ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0],
            ModelMeanType.START_X: x_start,
            ModelMeanType.EPSILON: noise,
        }[self.model_mean_type]

        model_output = model_output.permute(0, 2, 1).unsqueeze(2)
        assert model_output.shape == target.shape == x_start.shape

        missing_mask = model_kwargs["y"]["missing"][..., 0]
        missing_mask = missing_mask.unsqueeze(1).unsqueeze(1)
        missing_mask = mask * missing_mask
        terms["rot_mse"] = self.masked_l2(target, model_output, missing_mask)
        if self.lambda_vel > 0.0:
            target_vel = target[..., 1:] - target[..., :-1]
            model_output_vel = model_output[..., 1:] - model_output[..., :-1]
            terms["vel_mse"] = self.masked_l2(
                target_vel,
                model_output_vel,
                mask[:, :, :, 1:],
            )

        terms["loss"] = terms["rot_mse"] + (self.lambda_vel * terms.get("vel_mse", 0.0))

        with torch.no_grad():
            terms["vb"] = self._vb_terms_bpd(
                model,
                x_start,
                x_t,
                t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]

        return terms


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
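Note: the snippet below is an illustrative usage sketch and not part of the uploaded files. It shows how `_extract_into_tensor` broadcasts per-timestep schedule coefficients over a batched motion tensor, which is what `q_sample` relies on internally; the beta schedule and tensor shapes here are placeholder assumptions.

import numpy as np
import torch as th

betas = np.linspace(1e-4, 0.02, 1000)             # assumed linear beta schedule
alphas_cumprod = np.cumprod(1.0 - betas)

x_start = th.randn(4, 104, 1, 240)                # assumed B x nfeats x 1 x T layout
t = th.randint(0, 1000, (4,))
noise = th.randn_like(x_start)

# q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
coef1 = _extract_into_tensor(np.sqrt(alphas_cumprod), t, x_start.shape)
coef2 = _extract_into_tensor(np.sqrt(1.0 - alphas_cumprod), t, x_start.shape)
x_t = coef1 * x_start + coef2 * noise             # same tensor q_sample would produce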
diffusion/losses.py
ADDED
@@ -0,0 +1,83 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""

import numpy as np
import torch as th


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
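Note: a quick standalone sanity check (not from the diff) for `normal_kl`: for two unit-variance Gaussians the KL reduces to 0.5 * (difference in means)^2 nats, and dividing by log(2), as `_vb_terms_bpd` does above, converts nats to bits.

import numpy as np
import torch as th

mean1, mean2 = th.zeros(8), th.ones(8)
logvar = th.zeros(8)                        # log-variance 0 -> unit variance
kl_nats = normal_kl(mean1, logvar, mean2, logvar)
kl_bits = kl_nats / np.log(2.0)             # nats -> bits
assert th.allclose(kl_nats, th.full((8,), 0.5))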
diffusion/nn.py
ADDED
@@ -0,0 +1,213 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
"""
|
9 |
+
original code from
|
10 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
|
11 |
+
under an MIT license
|
12 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
|
13 |
+
"""
|
14 |
+
|
15 |
+
"""
|
16 |
+
Various utilities for neural networks.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import math
|
20 |
+
|
21 |
+
import torch as th
|
22 |
+
import torch.nn as nn
|
23 |
+
|
24 |
+
|
25 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
26 |
+
class SiLU(nn.Module):
|
27 |
+
def forward(self, x):
|
28 |
+
return x * th.sigmoid(x)
|
29 |
+
|
30 |
+
|
31 |
+
class GroupNorm32(nn.GroupNorm):
|
32 |
+
def forward(self, x):
|
33 |
+
return super().forward(x.float()).type(x.dtype)
|
34 |
+
|
35 |
+
|
36 |
+
def conv_nd(dims, *args, **kwargs):
|
37 |
+
"""
|
38 |
+
Create a 1D, 2D, or 3D convolution module.
|
39 |
+
"""
|
40 |
+
if dims == 1:
|
41 |
+
return nn.Conv1d(*args, **kwargs)
|
42 |
+
elif dims == 2:
|
43 |
+
return nn.Conv2d(*args, **kwargs)
|
44 |
+
elif dims == 3:
|
45 |
+
return nn.Conv3d(*args, **kwargs)
|
46 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
47 |
+
|
48 |
+
|
49 |
+
def linear(*args, **kwargs):
|
50 |
+
"""
|
51 |
+
Create a linear module.
|
52 |
+
"""
|
53 |
+
return nn.Linear(*args, **kwargs)
|
54 |
+
|
55 |
+
|
56 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
57 |
+
"""
|
58 |
+
Create a 1D, 2D, or 3D average pooling module.
|
59 |
+
"""
|
60 |
+
if dims == 1:
|
61 |
+
return nn.AvgPool1d(*args, **kwargs)
|
62 |
+
elif dims == 2:
|
63 |
+
return nn.AvgPool2d(*args, **kwargs)
|
64 |
+
elif dims == 3:
|
65 |
+
return nn.AvgPool3d(*args, **kwargs)
|
66 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
67 |
+
|
68 |
+
|
69 |
+
def update_ema(target_params, source_params, rate=0.99):
|
70 |
+
"""
|
71 |
+
Update target parameters to be closer to those of source parameters using
|
72 |
+
an exponential moving average.
|
73 |
+
|
74 |
+
:param target_params: the target parameter sequence.
|
75 |
+
:param source_params: the source parameter sequence.
|
76 |
+
:param rate: the EMA rate (closer to 1 means slower).
|
77 |
+
"""
|
78 |
+
for targ, src in zip(target_params, source_params):
|
79 |
+
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
80 |
+
|
81 |
+
|
82 |
+
def zero_module(module):
|
83 |
+
"""
|
84 |
+
Zero out the parameters of a module and return it.
|
85 |
+
"""
|
86 |
+
for p in module.parameters():
|
87 |
+
p.detach().zero_()
|
88 |
+
return module
|
89 |
+
|
90 |
+
|
91 |
+
def scale_module(module, scale):
|
92 |
+
"""
|
93 |
+
Scale the parameters of a module and return it.
|
94 |
+
"""
|
95 |
+
for p in module.parameters():
|
96 |
+
p.detach().mul_(scale)
|
97 |
+
return module
|
98 |
+
|
99 |
+
|
100 |
+
def mean_flat(tensor):
|
101 |
+
"""
|
102 |
+
Take the mean over all non-batch dimensions.
|
103 |
+
"""
|
104 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
105 |
+
|
106 |
+
|
107 |
+
def sum_flat(tensor):
|
108 |
+
"""
|
109 |
+
Take the sum over all non-batch dimensions.
|
110 |
+
"""
|
111 |
+
return tensor.sum(dim=list(range(1, len(tensor.shape))))
|
112 |
+
|
113 |
+
|
114 |
+
def normalization(channels):
|
115 |
+
"""
|
116 |
+
Make a standard normalization layer.
|
117 |
+
|
118 |
+
:param channels: number of input channels.
|
119 |
+
:return: an nn.Module for normalization.
|
120 |
+
"""
|
121 |
+
return GroupNorm32(32, channels)
|
122 |
+
|
123 |
+
|
124 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
125 |
+
"""
|
126 |
+
Create sinusoidal timestep embeddings.
|
127 |
+
|
128 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
129 |
+
These may be fractional.
|
130 |
+
:param dim: the dimension of the output.
|
131 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
132 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
133 |
+
"""
|
134 |
+
half = dim // 2
|
135 |
+
freqs = th.exp(
|
136 |
+
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
|
137 |
+
).to(device=timesteps.device)
|
138 |
+
args = timesteps[:, None].float() * freqs[None]
|
139 |
+
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
140 |
+
if dim % 2:
|
141 |
+
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
142 |
+
return embedding
|
143 |
+
|
144 |
+
|
145 |
+
def checkpoint(func, inputs, params, flag):
|
146 |
+
"""
|
147 |
+
Evaluate a function without caching intermediate activations, allowing for
|
148 |
+
reduced memory at the expense of extra compute in the backward pass.
|
149 |
+
:param func: the function to evaluate.
|
150 |
+
:param inputs: the argument sequence to pass to `func`.
|
151 |
+
:param params: a sequence of parameters `func` depends on but does not
|
152 |
+
explicitly take as arguments.
|
153 |
+
:param flag: if False, disable gradient checkpointing.
|
154 |
+
"""
|
155 |
+
if flag:
|
156 |
+
args = tuple(inputs) + tuple(params)
|
157 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
158 |
+
else:
|
159 |
+
return func(*inputs)
|
160 |
+
|
161 |
+
|
162 |
+
class CheckpointFunction(th.autograd.Function):
|
163 |
+
@staticmethod
|
164 |
+
@th.cuda.amp.custom_fwd
|
165 |
+
def forward(ctx, run_function, length, *args):
|
166 |
+
ctx.run_function = run_function
|
167 |
+
ctx.input_length = length
|
168 |
+
ctx.save_for_backward(*args)
|
169 |
+
with th.no_grad():
|
170 |
+
output_tensors = ctx.run_function(*args[:length])
|
171 |
+
return output_tensors
|
172 |
+
|
173 |
+
@staticmethod
|
174 |
+
@th.cuda.amp.custom_bwd
|
175 |
+
def backward(ctx, *output_grads):
|
176 |
+
args = list(ctx.saved_tensors)
|
177 |
+
|
178 |
+
# Filter for inputs that require grad. If none, exit early.
|
179 |
+
input_indices = [i for (i, x) in enumerate(args) if x.requires_grad]
|
180 |
+
if not input_indices:
|
181 |
+
return (None, None) + tuple(None for _ in args)
|
182 |
+
|
183 |
+
with th.enable_grad():
|
184 |
+
for i in input_indices:
|
185 |
+
if i < ctx.input_length:
|
186 |
+
# Not sure why the OAI code does this little
|
187 |
+
# dance. It might not be necessary.
|
188 |
+
args[i] = args[i].detach().requires_grad_()
|
189 |
+
args[i] = args[i].view_as(args[i])
|
190 |
+
output_tensors = ctx.run_function(*args[: ctx.input_length])
|
191 |
+
|
192 |
+
if isinstance(output_tensors, th.Tensor):
|
193 |
+
output_tensors = [output_tensors]
|
194 |
+
|
195 |
+
# Filter for outputs that require grad. If none, exit early.
|
196 |
+
out_and_grads = [
|
197 |
+
(o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad
|
198 |
+
]
|
199 |
+
if not out_and_grads:
|
200 |
+
return (None, None) + tuple(None for _ in args)
|
201 |
+
|
202 |
+
# Compute gradients on the filtered tensors.
|
203 |
+
computed_grads = th.autograd.grad(
|
204 |
+
[o for (o, g) in out_and_grads],
|
205 |
+
[args[i] for i in input_indices],
|
206 |
+
[g for (o, g) in out_and_grads],
|
207 |
+
)
|
208 |
+
|
209 |
+
# Reassemble the complete gradient tuple.
|
210 |
+
input_grads = [None for _ in args]
|
211 |
+
for i, g in zip(input_indices, computed_grads):
|
212 |
+
input_grads[i] = g
|
213 |
+
return (None, None) + tuple(input_grads)
|
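Note: an illustrative call (not part of the upload) showing what `timestep_embedding` returns: an [N x dim] tensor with cosine features in the first half and sine features in the second, so timestep 0 maps to all-ones in the cosine half.

import torch as th

t = th.arange(0, 8).float()
emb = timestep_embedding(t, dim=128)
print(emb.shape)        # torch.Size([8, 128])
print(emb[0, :4])       # cos(0 * freqs) == 1.0 for the first timestep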
diffusion/resample.py
ADDED
@@ -0,0 +1,168 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
"""
|
9 |
+
original code from
|
10 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
|
11 |
+
under an MIT license
|
12 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
|
13 |
+
"""
|
14 |
+
|
15 |
+
from abc import ABC, abstractmethod
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch as th
|
19 |
+
import torch.distributed as dist
|
20 |
+
|
21 |
+
|
22 |
+
def create_named_schedule_sampler(name, diffusion):
|
23 |
+
"""
|
24 |
+
Create a ScheduleSampler from a library of pre-defined samplers.
|
25 |
+
|
26 |
+
:param name: the name of the sampler.
|
27 |
+
:param diffusion: the diffusion object to sample for.
|
28 |
+
"""
|
29 |
+
if name == "uniform":
|
30 |
+
return UniformSampler(diffusion)
|
31 |
+
elif name == "loss-second-moment":
|
32 |
+
return LossSecondMomentResampler(diffusion)
|
33 |
+
else:
|
34 |
+
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
35 |
+
|
36 |
+
|
37 |
+
class ScheduleSampler(ABC):
|
38 |
+
"""
|
39 |
+
A distribution over timesteps in the diffusion process, intended to reduce
|
40 |
+
variance of the objective.
|
41 |
+
|
42 |
+
By default, samplers perform unbiased importance sampling, in which the
|
43 |
+
objective's mean is unchanged.
|
44 |
+
However, subclasses may override sample() to change how the resampled
|
45 |
+
terms are reweighted, allowing for actual changes in the objective.
|
46 |
+
"""
|
47 |
+
|
48 |
+
@abstractmethod
|
49 |
+
def weights(self):
|
50 |
+
"""
|
51 |
+
Get a numpy array of weights, one per diffusion step.
|
52 |
+
|
53 |
+
The weights needn't be normalized, but must be positive.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def sample(self, batch_size, device):
|
57 |
+
"""
|
58 |
+
Importance-sample timesteps for a batch.
|
59 |
+
|
60 |
+
:param batch_size: the number of timesteps.
|
61 |
+
:param device: the torch device to save to.
|
62 |
+
:return: a tuple (timesteps, weights):
|
63 |
+
- timesteps: a tensor of timestep indices.
|
64 |
+
- weights: a tensor of weights to scale the resulting losses.
|
65 |
+
"""
|
66 |
+
w = self.weights()
|
67 |
+
p = w / np.sum(w)
|
68 |
+
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
69 |
+
indices = th.from_numpy(indices_np).long().to(device)
|
70 |
+
weights_np = 1 / (len(p) * p[indices_np])
|
71 |
+
weights = th.from_numpy(weights_np).float().to(device)
|
72 |
+
return indices, weights
|
73 |
+
|
74 |
+
|
75 |
+
class UniformSampler(ScheduleSampler):
|
76 |
+
def __init__(self, diffusion):
|
77 |
+
self.diffusion = diffusion
|
78 |
+
self._weights = np.ones([diffusion.num_timesteps])
|
79 |
+
|
80 |
+
def weights(self):
|
81 |
+
return self._weights
|
82 |
+
|
83 |
+
|
84 |
+
class LossAwareSampler(ScheduleSampler):
|
85 |
+
def update_with_local_losses(self, local_ts, local_losses):
|
86 |
+
"""
|
87 |
+
Update the reweighting using losses from a model.
|
88 |
+
|
89 |
+
Call this method from each rank with a batch of timesteps and the
|
90 |
+
corresponding losses for each of those timesteps.
|
91 |
+
This method will perform synchronization to make sure all of the ranks
|
92 |
+
maintain the exact same reweighting.
|
93 |
+
|
94 |
+
:param local_ts: an integer Tensor of timesteps.
|
95 |
+
:param local_losses: a 1D Tensor of losses.
|
96 |
+
"""
|
97 |
+
batch_sizes = [
|
98 |
+
th.tensor([0], dtype=th.int32, device=local_ts.device)
|
99 |
+
for _ in range(dist.get_world_size())
|
100 |
+
]
|
101 |
+
dist.all_gather(
|
102 |
+
batch_sizes,
|
103 |
+
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
104 |
+
)
|
105 |
+
|
106 |
+
# Pad all_gather batches to be the maximum batch size.
|
107 |
+
batch_sizes = [x.item() for x in batch_sizes]
|
108 |
+
max_bs = max(batch_sizes)
|
109 |
+
|
110 |
+
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
111 |
+
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
112 |
+
dist.all_gather(timestep_batches, local_ts)
|
113 |
+
dist.all_gather(loss_batches, local_losses)
|
114 |
+
timesteps = [
|
115 |
+
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
|
116 |
+
]
|
117 |
+
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
118 |
+
self.update_with_all_losses(timesteps, losses)
|
119 |
+
|
120 |
+
@abstractmethod
|
121 |
+
def update_with_all_losses(self, ts, losses):
|
122 |
+
"""
|
123 |
+
Update the reweighting using losses from a model.
|
124 |
+
|
125 |
+
Sub-classes should override this method to update the reweighting
|
126 |
+
using losses from the model.
|
127 |
+
|
128 |
+
This method directly updates the reweighting without synchronizing
|
129 |
+
between workers. It is called by update_with_local_losses from all
|
130 |
+
ranks with identical arguments. Thus, it should have deterministic
|
131 |
+
behavior to maintain state across workers.
|
132 |
+
|
133 |
+
:param ts: a list of int timesteps.
|
134 |
+
:param losses: a list of float losses, one per timestep.
|
135 |
+
"""
|
136 |
+
|
137 |
+
|
138 |
+
class LossSecondMomentResampler(LossAwareSampler):
|
139 |
+
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
140 |
+
self.diffusion = diffusion
|
141 |
+
self.history_per_term = history_per_term
|
142 |
+
self.uniform_prob = uniform_prob
|
143 |
+
self._loss_history = np.zeros(
|
144 |
+
[diffusion.num_timesteps, history_per_term], dtype=np.float64
|
145 |
+
)
|
146 |
+
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)
|
147 |
+
|
148 |
+
def weights(self):
|
149 |
+
if not self._warmed_up():
|
150 |
+
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
151 |
+
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
|
152 |
+
weights /= np.sum(weights)
|
153 |
+
weights *= 1 - self.uniform_prob
|
154 |
+
weights += self.uniform_prob / len(weights)
|
155 |
+
return weights
|
156 |
+
|
157 |
+
def update_with_all_losses(self, ts, losses):
|
158 |
+
for t, loss in zip(ts, losses):
|
159 |
+
if self._loss_counts[t] == self.history_per_term:
|
160 |
+
# Shift out the oldest loss term.
|
161 |
+
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
162 |
+
self._loss_history[t, -1] = loss
|
163 |
+
else:
|
164 |
+
self._loss_history[t, self._loss_counts[t]] = loss
|
165 |
+
self._loss_counts[t] += 1
|
166 |
+
|
167 |
+
def _warmed_up(self):
|
168 |
+
return (self._loss_counts == self.history_per_term).all()
|
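Note: a minimal training-loop sketch showing how these samplers are typically consumed (assumptions: `diffusion`, `model`, `x_start`, and `model_kwargs` already exist, and torch.distributed is initialized if the loss-aware variant is used).

sampler = create_named_schedule_sampler("uniform", diffusion)
t, weights = sampler.sample(batch_size=x_start.shape[0], device=x_start.device)
losses = diffusion.training_losses(model, x_start, t, model_kwargs=model_kwargs)
loss = (losses["loss"] * weights).mean()   # importance weights keep the objective unbiased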
diffusion/respace.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
"""
|
9 |
+
original code from
|
10 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
|
11 |
+
under an MIT license
|
12 |
+
https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
|
13 |
+
"""
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch as th
|
17 |
+
|
18 |
+
from .gaussian_diffusion import GaussianDiffusion
|
19 |
+
|
20 |
+
|
21 |
+
def space_timesteps(num_timesteps, section_counts):
|
22 |
+
"""
|
23 |
+
Create a list of timesteps to use from an original diffusion process,
|
24 |
+
given the number of timesteps we want to take from equally-sized portions
|
25 |
+
of the original process.
|
26 |
+
|
27 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
28 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
29 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
30 |
+
|
31 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
32 |
+
from the DDIM paper is used, and only one section is allowed.
|
33 |
+
|
34 |
+
:param num_timesteps: the number of diffusion steps in the original
|
35 |
+
process to divide up.
|
36 |
+
:param section_counts: either a list of numbers, or a string containing
|
37 |
+
comma-separated numbers, indicating the step count
|
38 |
+
per section. As a special case, use "ddimN" where N
|
39 |
+
is a number of steps to use the striding from the
|
40 |
+
DDIM paper.
|
41 |
+
:return: a set of diffusion steps from the original process to use.
|
42 |
+
"""
|
43 |
+
if isinstance(section_counts, str):
|
44 |
+
if section_counts.startswith("ddim"):
|
45 |
+
desired_count = int(section_counts[len("ddim") :])
|
46 |
+
for i in range(1, num_timesteps):
|
47 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
48 |
+
return set(range(0, num_timesteps, i))
|
49 |
+
raise ValueError(
|
50 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
51 |
+
)
|
52 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
53 |
+
size_per = num_timesteps // len(section_counts)
|
54 |
+
extra = num_timesteps % len(section_counts)
|
55 |
+
start_idx = 0
|
56 |
+
all_steps = []
|
57 |
+
for i, section_count in enumerate(section_counts):
|
58 |
+
size = size_per + (1 if i < extra else 0)
|
59 |
+
if size < section_count:
|
60 |
+
raise ValueError(
|
61 |
+
f"cannot divide section of {size} steps into {section_count}"
|
62 |
+
)
|
63 |
+
if section_count <= 1:
|
64 |
+
frac_stride = 1
|
65 |
+
else:
|
66 |
+
frac_stride = (size - 1) / (section_count - 1)
|
67 |
+
cur_idx = 0.0
|
68 |
+
taken_steps = []
|
69 |
+
for _ in range(section_count):
|
70 |
+
taken_steps.append(start_idx + round(cur_idx))
|
71 |
+
cur_idx += frac_stride
|
72 |
+
all_steps += taken_steps
|
73 |
+
start_idx += size
|
74 |
+
return set(all_steps)
|
75 |
+
|
76 |
+
|
77 |
+
class SpacedDiffusion(GaussianDiffusion):
|
78 |
+
"""
|
79 |
+
A diffusion process which can skip steps in a base diffusion process.
|
80 |
+
|
81 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
82 |
+
original diffusion process to retain.
|
83 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, use_timesteps, **kwargs):
|
87 |
+
self.use_timesteps = set(use_timesteps)
|
88 |
+
self.timestep_map = []
|
89 |
+
self.original_num_steps = len(kwargs["betas"])
|
90 |
+
|
91 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
92 |
+
last_alpha_cumprod = 1.0
|
93 |
+
new_betas = []
|
94 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
95 |
+
if i in self.use_timesteps:
|
96 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
97 |
+
last_alpha_cumprod = alpha_cumprod
|
98 |
+
self.timestep_map.append(i)
|
99 |
+
kwargs["betas"] = np.array(new_betas)
|
100 |
+
super().__init__(**kwargs)
|
101 |
+
|
102 |
+
def p_mean_variance(
|
103 |
+
self, model, *args, **kwargs
|
104 |
+
): # pylint: disable=signature-differs
|
105 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
106 |
+
|
107 |
+
def training_losses(
|
108 |
+
self, model, *args, **kwargs
|
109 |
+
): # pylint: disable=signature-differs
|
110 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
111 |
+
|
112 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
113 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
114 |
+
|
115 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
116 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
117 |
+
|
118 |
+
def _wrap_model(self, model):
|
119 |
+
if isinstance(model, _WrappedModel):
|
120 |
+
return model
|
121 |
+
return _WrappedModel(
|
122 |
+
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
|
123 |
+
)
|
124 |
+
|
125 |
+
def _scale_timesteps(self, t):
|
126 |
+
# Scaling is done by the wrapped model.
|
127 |
+
return t
|
128 |
+
|
129 |
+
|
130 |
+
class _WrappedModel:
|
131 |
+
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
132 |
+
self.model = model
|
133 |
+
if hasattr(model, "step"):
|
134 |
+
self.step = model.step
|
135 |
+
self.add_frame_cond = model.add_frame_cond
|
136 |
+
self.timestep_map = timestep_map
|
137 |
+
self.rescale_timesteps = rescale_timesteps
|
138 |
+
self.original_num_steps = original_num_steps
|
139 |
+
|
140 |
+
def __call__(self, x, ts, **kwargs):
|
141 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
142 |
+
new_ts = map_tensor[ts]
|
143 |
+
if self.rescale_timesteps:
|
144 |
+
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
145 |
+
return self.model(x, new_ts, **kwargs)
|
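Note: an illustrative respacing call, not from the diff; the keyword arguments other than `use_timesteps` are placeholders standing in for however the base GaussianDiffusion is constructed elsewhere in this upload. "ddim50" selects 50 evenly strided steps out of a 1000-step process.

use_timesteps = space_timesteps(1000, "ddim50")    # {0, 20, 40, ..., 980}
print(len(use_timesteps))                          # 50
spaced = SpacedDiffusion(use_timesteps=use_timesteps, betas=betas, **other_diffusion_kwargs)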
flagged/audio/b90d90dbca93f47e8d01/audio.wav
ADDED
Binary file (696 kB)
flagged/audio/d8e03e2e6deae2f981b1/audio.wav
ADDED
Binary file (696 kB)
flagged/log.csv
ADDED
@@ -0,0 +1,4 @@
audio,Number of Samples (default = 3),Sample Diversity (default = 0.97),output 0,output 1,output 2,output 3,output 4,output 5,output 6,output 7,output 8,output 9,flag,username,timestamp
,1,0.69,,,,,,,,,,,,,2024-07-15 05:46:49.672259
flagged/audio/d8e03e2e6deae2f981b1/audio.wav,1,0.69,,,,,,,,,,,,,2024-07-15 06:28:21.003877
flagged/audio/b90d90dbca93f47e8d01/audio.wav,1,0.69,,,,,,,,,,,,,2024-07-15 06:28:24.442449
model/cfg_sampler.py
ADDED
@@ -0,0 +1,33 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn


# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model  # model is the actual model to run
        self.nfeats = self.model.nfeats
        self.cond_mode = self.model.cond_mode
        self.add_frame_cond = self.model.add_frame_cond
        if self.add_frame_cond is not None:
            if self.model.resume_trans is not None:
                self.transformer = self.model.transformer
                self.tokenizer = self.model.tokenizer
            self.step = self.model.step

    def forward(self, x, timesteps, y=None):
        out = self.model(x, timesteps, y, cond_drop_prob=0.0)
        out_uncond = self.model(x, timesteps, y, cond_drop_prob=1.0)
        return out_uncond + (y["scale"].view(-1, 1, 1) * (out - out_uncond))
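Note: a sampling-time sketch, not part of the upload; `diffusion.p_sample_loop`, the shape, and the scale value are assumptions about how the surrounding code is wired up. `y["scale"]` is the per-sample guidance weight: 1.0 reproduces the conditional prediction exactly, larger values push further away from the unconditional one.

guided = ClassifierFreeSampleModel(model)
y["scale"] = torch.full((batch_size,), 2.0, device=device)   # illustrative guidance scale
sample = diffusion.p_sample_loop(guided, shape, model_kwargs={"y": y})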
model/diffusion.py
ADDED
@@ -0,0 +1,403 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import json
|
9 |
+
from typing import Callable, Optional
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from einops import rearrange
|
14 |
+
from einops.layers.torch import Rearrange
|
15 |
+
|
16 |
+
from model.guide import GuideTransformer
|
17 |
+
from model.modules.audio_encoder import Wav2VecEncoder
|
18 |
+
from model.modules.rotary_embedding_torch import RotaryEmbedding
|
19 |
+
from model.modules.transformer_modules import (
|
20 |
+
DecoderLayerStack,
|
21 |
+
FiLMTransformerDecoderLayer,
|
22 |
+
RegressionTransformer,
|
23 |
+
TransformerEncoderLayerRotary,
|
24 |
+
)
|
25 |
+
from model.utils import (
|
26 |
+
init_weight,
|
27 |
+
PositionalEncoding,
|
28 |
+
prob_mask_like,
|
29 |
+
setup_lip_regressor,
|
30 |
+
SinusoidalPosEmb,
|
31 |
+
)
|
32 |
+
from model.vqvae import setup_tokenizer
|
33 |
+
from torch.nn import functional as F
|
34 |
+
from utils.misc import prGreen, prRed
|
35 |
+
|
36 |
+
|
37 |
+
class Audio2LipRegressionTransformer(torch.nn.Module):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
n_vertices: int = 338,
|
41 |
+
causal: bool = False,
|
42 |
+
train_wav2vec: bool = False,
|
43 |
+
transformer_encoder_layers: int = 2,
|
44 |
+
transformer_decoder_layers: int = 4,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
self.n_vertices = n_vertices
|
48 |
+
|
49 |
+
self.audio_encoder = Wav2VecEncoder()
|
50 |
+
if not train_wav2vec:
|
51 |
+
self.audio_encoder.eval()
|
52 |
+
for param in self.audio_encoder.parameters():
|
53 |
+
param.requires_grad = False
|
54 |
+
|
55 |
+
self.regression_model = RegressionTransformer(
|
56 |
+
transformer_encoder_layers=transformer_encoder_layers,
|
57 |
+
transformer_decoder_layers=transformer_decoder_layers,
|
58 |
+
d_model=512,
|
59 |
+
d_cond=512,
|
60 |
+
num_heads=4,
|
61 |
+
causal=causal,
|
62 |
+
)
|
63 |
+
self.project_output = torch.nn.Linear(512, self.n_vertices * 3)
|
64 |
+
|
65 |
+
def forward(self, audio):
|
66 |
+
"""
|
67 |
+
:param audio: tensor of shape B x T x 1600
|
68 |
+
:return: tensor of shape B x T x n_vertices x 3 containing reconstructed lip geometry
|
69 |
+
"""
|
70 |
+
B, T = audio.shape[0], audio.shape[1]
|
71 |
+
|
72 |
+
cond = self.audio_encoder(audio)
|
73 |
+
|
74 |
+
x = torch.zeros(B, T, 512, device=audio.device)
|
75 |
+
x = self.regression_model(x, cond)
|
76 |
+
x = self.project_output(x)
|
77 |
+
|
78 |
+
verts = x.view(B, T, self.n_vertices, 3)
|
79 |
+
return verts
|
80 |
+
|
81 |
+
|
82 |
+
class FiLMTransformer(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
args,
|
86 |
+
nfeats: int,
|
87 |
+
latent_dim: int = 512,
|
88 |
+
ff_size: int = 1024,
|
89 |
+
num_layers: int = 4,
|
90 |
+
num_heads: int = 4,
|
91 |
+
dropout: float = 0.1,
|
92 |
+
cond_feature_dim: int = 4800,
|
93 |
+
activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
|
94 |
+
use_rotary: bool = True,
|
95 |
+
cond_mode: str = "audio",
|
96 |
+
split_type: str = "train",
|
97 |
+
device: str = "cuda",
|
98 |
+
**kwargs,
|
99 |
+
) -> None:
|
100 |
+
super().__init__()
|
101 |
+
self.nfeats = nfeats
|
102 |
+
self.cond_mode = cond_mode
|
103 |
+
self.cond_feature_dim = cond_feature_dim
|
104 |
+
self.add_frame_cond = args.add_frame_cond
|
105 |
+
self.data_format = args.data_format
|
106 |
+
self.split_type = split_type
|
107 |
+
self.device = device
|
108 |
+
|
109 |
+
# positional embeddings
|
110 |
+
self.rotary = None
|
111 |
+
self.abs_pos_encoding = nn.Identity()
|
112 |
+
# if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
|
113 |
+
if use_rotary:
|
114 |
+
self.rotary = RotaryEmbedding(dim=latent_dim)
|
115 |
+
else:
|
116 |
+
self.abs_pos_encoding = PositionalEncoding(
|
117 |
+
latent_dim, dropout, batch_first=True
|
118 |
+
)
|
119 |
+
|
120 |
+
# time embedding processing
|
121 |
+
self.time_mlp = nn.Sequential(
|
122 |
+
SinusoidalPosEmb(latent_dim),
|
123 |
+
nn.Linear(latent_dim, latent_dim * 4),
|
124 |
+
nn.Mish(),
|
125 |
+
)
|
126 |
+
self.to_time_cond = nn.Sequential(
|
127 |
+
nn.Linear(latent_dim * 4, latent_dim),
|
128 |
+
)
|
129 |
+
self.to_time_tokens = nn.Sequential(
|
130 |
+
nn.Linear(latent_dim * 4, latent_dim * 2),
|
131 |
+
Rearrange("b (r d) -> b r d", r=2),
|
132 |
+
)
|
133 |
+
|
134 |
+
# null embeddings for guidance dropout
|
135 |
+
self.seq_len = args.max_seq_length
|
136 |
+
emb_len = 1998 # hardcoded for now
|
137 |
+
self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, latent_dim))
|
138 |
+
self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
|
139 |
+
self.norm_cond = nn.LayerNorm(latent_dim)
|
140 |
+
self.setup_audio_models()
|
141 |
+
|
142 |
+
# set up pose/face specific parts of the model
|
143 |
+
self.input_projection = nn.Linear(self.nfeats, latent_dim)
|
144 |
+
if self.data_format == "pose":
|
145 |
+
cond_feature_dim = 1024
|
146 |
+
key_feature_dim = 104
|
147 |
+
self.step = 30
|
148 |
+
self.use_cm = True
|
149 |
+
self.setup_guide_models(args, latent_dim, key_feature_dim)
|
150 |
+
self.post_pose_layers = self._build_single_pose_conv(self.nfeats)
|
151 |
+
self.post_pose_layers.apply(init_weight)
|
152 |
+
self.final_conv = torch.nn.Conv1d(self.nfeats, self.nfeats, kernel_size=1)
|
153 |
+
self.receptive_field = 25
|
154 |
+
elif self.data_format == "face":
|
155 |
+
self.use_cm = False
|
156 |
+
cond_feature_dim = 1024 + 1014
|
157 |
+
self.setup_lip_models()
|
158 |
+
self.cond_encoder = nn.Sequential()
|
159 |
+
for _ in range(2):
|
160 |
+
self.cond_encoder.append(
|
161 |
+
TransformerEncoderLayerRotary(
|
162 |
+
d_model=latent_dim,
|
163 |
+
nhead=num_heads,
|
164 |
+
dim_feedforward=ff_size,
|
165 |
+
dropout=dropout,
|
166 |
+
activation=activation,
|
167 |
+
batch_first=True,
|
168 |
+
rotary=self.rotary,
|
169 |
+
)
|
170 |
+
)
|
171 |
+
self.cond_encoder.apply(init_weight)
|
172 |
+
|
173 |
+
self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
|
174 |
+
self.non_attn_cond_projection = nn.Sequential(
|
175 |
+
nn.LayerNorm(latent_dim),
|
176 |
+
nn.Linear(latent_dim, latent_dim),
|
177 |
+
nn.SiLU(),
|
178 |
+
nn.Linear(latent_dim, latent_dim),
|
179 |
+
)
|
180 |
+
|
181 |
+
# decoder
|
182 |
+
decoderstack = nn.ModuleList([])
|
183 |
+
for _ in range(num_layers):
|
184 |
+
decoderstack.append(
|
185 |
+
FiLMTransformerDecoderLayer(
|
186 |
+
latent_dim,
|
187 |
+
num_heads,
|
188 |
+
dim_feedforward=ff_size,
|
189 |
+
dropout=dropout,
|
190 |
+
activation=activation,
|
191 |
+
batch_first=True,
|
192 |
+
rotary=self.rotary,
|
193 |
+
use_cm=self.use_cm,
|
194 |
+
)
|
195 |
+
)
|
196 |
+
self.seqTransDecoder = DecoderLayerStack(decoderstack)
|
197 |
+
self.seqTransDecoder.apply(init_weight)
|
198 |
+
self.final_layer = nn.Linear(latent_dim, self.nfeats)
|
199 |
+
self.final_layer.apply(init_weight)
|
200 |
+
|
201 |
+
def _build_single_pose_conv(self, nfeats: int) -> nn.ModuleList:
|
202 |
+
post_pose_layers = torch.nn.ModuleList(
|
203 |
+
[
|
204 |
+
torch.nn.Conv1d(nfeats, max(256, nfeats), kernel_size=3, dilation=1),
|
205 |
+
torch.nn.Conv1d(max(256, nfeats), nfeats, kernel_size=3, dilation=2),
|
206 |
+
torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
|
207 |
+
torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=1),
|
208 |
+
torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=2),
|
209 |
+
torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
|
210 |
+
]
|
211 |
+
)
|
212 |
+
return post_pose_layers
|
213 |
+
|
214 |
+
def _run_single_pose_conv(self, output: torch.Tensor) -> torch.Tensor:
|
215 |
+
output = torch.nn.functional.pad(output, pad=[self.receptive_field - 1, 0])
|
216 |
+
for _, layer in enumerate(self.post_pose_layers):
|
217 |
+
y = torch.nn.functional.leaky_relu(layer(output), negative_slope=0.2)
|
218 |
+
if self.split_type == "train":
|
219 |
+
y = torch.nn.functional.dropout(y, 0.2)
|
220 |
+
if output.shape[1] == y.shape[1]:
|
221 |
+
output = (output[:, :, -y.shape[-1] :] + y) / 2.0 # skip connection
|
222 |
+
else:
|
223 |
+
output = y
|
224 |
+
return output
|
225 |
+
|
226 |
+
def setup_guide_models(self, args, latent_dim: int, key_feature_dim: int) -> None:
|
227 |
+
# set up conditioning info
|
228 |
+
max_keyframe_len = len(list(range(self.seq_len))[:: self.step])
|
229 |
+
self.null_pose_embed = nn.Parameter(
|
230 |
+
torch.randn(1, max_keyframe_len, latent_dim)
|
231 |
+
)
|
232 |
+
prGreen(f"using keyframes: {self.null_pose_embed.shape}")
|
233 |
+
self.frame_cond_projection = nn.Linear(key_feature_dim, latent_dim)
|
234 |
+
self.frame_norm_cond = nn.LayerNorm(latent_dim)
|
235 |
+
# for test time set up keyframe transformer
|
236 |
+
self.resume_trans = None
|
237 |
+
if self.split_type == "test":
|
238 |
+
if hasattr(args, "resume_trans") and args.resume_trans is not None:
|
239 |
+
self.resume_trans = args.resume_trans
|
240 |
+
self.setup_guide_predictor(args.resume_trans)
|
241 |
+
else:
|
242 |
+
prRed("not using transformer, just using ground truth")
|
243 |
+
|
244 |
+
def setup_guide_predictor(self, cp_path: str) -> None:
|
245 |
+
cp_dir = cp_path.split("checkpoints/iter-")[0]
|
246 |
+
with open(f"{cp_dir}/args.json") as f:
|
247 |
+
trans_args = json.load(f)
|
248 |
+
|
249 |
+
# set up tokenizer based on trans_arg load point
|
250 |
+
self.tokenizer = setup_tokenizer(trans_args["resume_pth"])
|
251 |
+
|
252 |
+
# set up transformer
|
253 |
+
self.transformer = GuideTransformer(
|
254 |
+
tokens=self.tokenizer.n_clusters,
|
255 |
+
num_layers=trans_args["layers"],
|
256 |
+
dim=trans_args["dim"],
|
257 |
+
emb_len=1998,
|
258 |
+
num_audio_layers=trans_args["num_audio_layers"],
|
259 |
+
)
|
260 |
+
for param in self.transformer.parameters():
|
261 |
+
param.requires_grad = False
|
262 |
+
prGreen("loading TRANSFORMER checkpoint from {}".format(cp_path))
|
263 |
+
cp = torch.load(cp_path)
|
264 |
+
missing_keys, unexpected_keys = self.transformer.load_state_dict(
|
265 |
+
cp["model_state_dict"], strict=False
|
266 |
+
)
|
267 |
+
assert len(missing_keys) == 0, missing_keys
|
268 |
+
assert len(unexpected_keys) == 0, unexpected_keys
|
269 |
+
|
270 |
+
def setup_audio_models(self) -> None:
|
271 |
+
self.audio_model, self.audio_resampler = setup_lip_regressor()
|
272 |
+
|
273 |
+
def setup_lip_models(self) -> None:
|
274 |
+
self.lip_model = Audio2LipRegressionTransformer()
|
275 |
+
cp_path = "./assets/iter-0200000.pt"
|
276 |
+
cp = torch.load(cp_path, map_location=torch.device(self.device))
|
277 |
+
self.lip_model.load_state_dict(cp["model_state_dict"])
|
278 |
+
for param in self.lip_model.parameters():
|
279 |
+
param.requires_grad = False
|
280 |
+
prGreen(f"adding lip conditioning {cp_path}")
|
281 |
+
|
282 |
+
def parameters_w_grad(self):
|
283 |
+
return [p for p in self.parameters() if p.requires_grad]
|
284 |
+
|
285 |
+
def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
|
286 |
+
device = next(self.parameters()).device
|
287 |
+
a0 = self.audio_resampler(raw_audio[:, :, 0].to(device))
|
288 |
+
a1 = self.audio_resampler(raw_audio[:, :, 1].to(device))
|
289 |
+
with torch.no_grad():
|
290 |
+
z0 = self.audio_model.feature_extractor(a0)
|
291 |
+
z1 = self.audio_model.feature_extractor(a1)
|
292 |
+
emb = torch.cat((z0, z1), axis=1).permute(0, 2, 1)
|
293 |
+
return emb
|
294 |
+
|
295 |
+
def encode_lip(self, audio: torch.Tensor, cond_embed: torch.Tensor) -> torch.Tensor:
|
296 |
+
reshaped_audio = audio.reshape((audio.shape[0], -1, 1600, 2))[..., 0]
|
297 |
+
# processes 4 seconds at a time
|
298 |
+
B, T, _ = reshaped_audio.shape
|
299 |
+
lip_cond = torch.zeros(
|
300 |
+
(audio.shape[0], T, 338, 3),
|
301 |
+
device=audio.device,
|
302 |
+
dtype=audio.dtype,
|
303 |
+
)
|
304 |
+
for i in range(0, T, 120):
|
305 |
+
lip_cond[:, i : i + 120, ...] = self.lip_model(
|
306 |
+
reshaped_audio[:, i : i + 120, ...]
|
307 |
+
)
|
308 |
+
lip_cond = lip_cond.permute(0, 2, 3, 1).reshape((B, 338 * 3, -1))
|
309 |
+
lip_cond = torch.nn.functional.interpolate(
|
310 |
+
lip_cond, size=cond_embed.shape[1], mode="nearest-exact"
|
311 |
+
).permute(0, 2, 1)
|
312 |
+
cond_embed = torch.cat((cond_embed, lip_cond), dim=-1)
|
313 |
+
return cond_embed
|
314 |
+
|
315 |
+
def encode_keyframes(
|
316 |
+
self, y: torch.Tensor, cond_drop_prob: float, batch_size: int
|
317 |
+
) -> torch.Tensor:
|
318 |
+
pred = y["keyframes"]
|
319 |
+
new_mask = y["mask"][..., :: self.step].squeeze((1, 2))
|
320 |
+
pred[~new_mask] = 0.0 # pad the unknown
|
321 |
+
pose_hidden = self.frame_cond_projection(pred.detach().clone().cuda())
|
322 |
+
pose_embed = self.abs_pos_encoding(pose_hidden)
|
323 |
+
pose_tokens = self.frame_norm_cond(pose_embed)
|
324 |
+
# do conditional dropout for guide poses
|
325 |
+
key_cond_drop_prob = cond_drop_prob
|
326 |
+
keep_mask_pose = prob_mask_like(
|
327 |
+
(batch_size,), 1 - key_cond_drop_prob, device=pose_tokens.device
|
328 |
+
)
|
329 |
+
keep_mask_pose_embed = rearrange(keep_mask_pose, "b -> b 1 1")
|
330 |
+
null_pose_embed = self.null_pose_embed.to(pose_tokens.dtype)
|
331 |
+
pose_tokens = torch.where(
|
332 |
+
keep_mask_pose_embed,
|
333 |
+
pose_tokens,
|
334 |
+
null_pose_embed[:, : pose_tokens.shape[1], :],
|
335 |
+
)
|
336 |
+
return pose_tokens
|
337 |
+
|
338 |
+
def forward(
|
339 |
+
self,
|
340 |
+
x: torch.Tensor,
|
341 |
+
times: torch.Tensor,
|
342 |
+
y: Optional[torch.Tensor] = None,
|
343 |
+
cond_drop_prob: float = 0.0,
|
344 |
+
) -> torch.Tensor:
|
345 |
+
if x.dim() == 4:
|
346 |
+
x = x.permute(0, 3, 1, 2).squeeze(-1)
|
347 |
+
batch_size, device = x.shape[0], x.device
|
348 |
+
if self.cond_mode == "uncond":
|
349 |
+
cond_embed = torch.zeros(
|
350 |
+
(x.shape[0], x.shape[1], self.cond_feature_dim),
|
351 |
+
dtype=x.dtype,
|
352 |
+
device=x.device,
|
353 |
+
)
|
354 |
+
else:
|
355 |
+
cond_embed = y["audio"]
|
356 |
+
cond_embed = self.encode_audio(cond_embed)
|
357 |
+
if self.data_format == "face":
|
358 |
+
cond_embed = self.encode_lip(y["audio"], cond_embed)
|
359 |
+
pose_tokens = None
|
360 |
+
if self.data_format == "pose":
|
361 |
+
pose_tokens = self.encode_keyframes(y, cond_drop_prob, batch_size)
|
362 |
+
assert cond_embed is not None, "cond emb should not be none"
|
363 |
+
# process conditioning information
|
364 |
+
x = self.input_projection(x)
|
365 |
+
x = self.abs_pos_encoding(x)
|
366 |
+
audio_cond_drop_prob = cond_drop_prob
|
367 |
+
keep_mask = prob_mask_like(
|
368 |
+
(batch_size,), 1 - audio_cond_drop_prob, device=device
|
369 |
+
)
|
370 |
+
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
|
371 |
+
keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
|
372 |
+
cond_tokens = self.cond_projection(cond_embed)
|
373 |
+
cond_tokens = self.abs_pos_encoding(cond_tokens)
|
374 |
+
if self.data_format == "face":
|
375 |
+
cond_tokens = self.cond_encoder(cond_tokens)
|
376 |
+
null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
|
377 |
+
cond_tokens = torch.where(
|
378 |
+
keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
|
379 |
+
)
|
380 |
+
mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
|
381 |
+
cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
|
382 |
+
|
383 |
+
# create t conditioning
|
384 |
+
t_hidden = self.time_mlp(times)
|
385 |
+
t = self.to_time_cond(t_hidden)
|
386 |
+
t_tokens = self.to_time_tokens(t_hidden)
|
387 |
+
null_cond_hidden = self.null_cond_hidden.to(t.dtype)
|
388 |
+
cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
|
389 |
+
t += cond_hidden
|
390 |
+
|
391 |
+
# cross-attention conditioning
|
392 |
+
c = torch.cat((cond_tokens, t_tokens), dim=-2)
|
393 |
+
cond_tokens = self.norm_cond(c)
|
394 |
+
|
395 |
+
# Pass through the transformer decoder
|
396 |
+
output = self.seqTransDecoder(x, cond_tokens, t, memory2=pose_tokens)
|
397 |
+
output = self.final_layer(output)
|
398 |
+
if self.data_format == "pose":
|
399 |
+
output = output.permute(0, 2, 1)
|
400 |
+
output = self._run_single_pose_conv(output)
|
401 |
+
output = self.final_conv(output)
|
402 |
+
output = output.permute(0, 2, 1)
|
403 |
+
return output
|
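Note: a standalone toy sketch (not the module's actual helpers) of the conditioning-dropout pattern used in `forward` above: with probability `cond_drop_prob` a sample's audio tokens are swapped for a learned null embedding, which is what later enables classifier-free guidance at sampling time. All shapes are illustrative.

import torch

def drop_condition(cond_tokens, null_embed, cond_drop_prob):
    B, T, _ = cond_tokens.shape
    keep = torch.rand(B, device=cond_tokens.device) >= cond_drop_prob
    keep = keep.view(B, 1, 1)                        # broadcast over time and feature dims
    return torch.where(keep, cond_tokens, null_embed[:, :T, :])

tokens = torch.randn(4, 600, 512)
null_embed = torch.randn(1, 1998, 512)               # mirrors the shape of null_cond_embed
dropped = drop_condition(tokens, null_embed, cond_drop_prob=0.2)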
model/guide.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Callable, List
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch as th
|
12 |
+
import torch.nn as nn
|
13 |
+
from einops import rearrange
|
14 |
+
from model.modules.rotary_embedding_torch import RotaryEmbedding
|
15 |
+
|
16 |
+
from model.modules.transformer_modules import (
|
17 |
+
DecoderLayerStack,
|
18 |
+
FiLMTransformerDecoderLayer,
|
19 |
+
PositionalEncoding,
|
20 |
+
)
|
21 |
+
from model.utils import prob_mask_like, setup_lip_regressor
|
22 |
+
from torch.distributions import Categorical
|
23 |
+
from torch.nn import functional as F
|
24 |
+
|
25 |
+
|
26 |
+
class GuideTransformer(nn.Module):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
tokens: int,
|
30 |
+
num_heads: int = 4,
|
31 |
+
num_layers: int = 4,
|
32 |
+
dim: int = 512,
|
33 |
+
ff_size: int = 1024,
|
34 |
+
dropout: float = 0.1,
|
35 |
+
activation: Callable = F.gelu,
|
36 |
+
use_rotary: bool = True,
|
37 |
+
cond_feature_dim: int = 1024,
|
38 |
+
emb_len: int = 798,
|
39 |
+
num_audio_layers: int = 2,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
self.tokens = tokens
|
43 |
+
self.token_embedding = th.nn.Embedding(
|
44 |
+
num_embeddings=tokens + 1, # account for sequence start and end tokens
|
45 |
+
embedding_dim=dim,
|
46 |
+
)
|
47 |
+
self.abs_pos_encoding = nn.Identity()
|
48 |
+
# if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
|
49 |
+
if use_rotary:
|
50 |
+
self.rotary = RotaryEmbedding(dim=dim)
|
51 |
+
else:
|
52 |
+
self.abs_pos_encoding = PositionalEncoding(dim, dropout, batch_first=True)
|
53 |
+
self.setup_audio_models(cond_feature_dim, num_audio_layers)
|
54 |
+
|
55 |
+
self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, dim))
|
56 |
+
self.null_cond_hidden = nn.Parameter(torch.randn(1, dim))
|
57 |
+
self.norm_cond = nn.LayerNorm(dim)
|
58 |
+
|
59 |
+
self.cond_projection = nn.Linear(cond_feature_dim, dim)
|
60 |
+
self.non_attn_cond_projection = nn.Sequential(
|
61 |
+
nn.LayerNorm(dim),
|
62 |
+
nn.Linear(dim, dim),
|
63 |
+
nn.SiLU(),
|
64 |
+
nn.Linear(dim, dim),
|
65 |
+
)
|
66 |
+
# decoder
|
67 |
+
decoderstack = nn.ModuleList([])
|
68 |
+
for _ in range(num_layers):
|
69 |
+
decoderstack.append(
|
70 |
+
FiLMTransformerDecoderLayer(
|
71 |
+
dim,
|
72 |
+
num_heads,
|
73 |
+
dim_feedforward=ff_size,
|
74 |
+
dropout=dropout,
|
75 |
+
activation=activation,
|
76 |
+
batch_first=True,
|
77 |
+
rotary=self.rotary,
|
78 |
+
)
|
79 |
+
)
|
80 |
+
self.seqTransDecoder = DecoderLayerStack(decoderstack)
|
81 |
+
self.final_layer = nn.Linear(dim, tokens)
|
82 |
+
|
83 |
+
def _build_single_audio_conv(self, c: int) -> List[nn.Module]:
|
84 |
+
return [
|
85 |
+
torch.nn.Conv1d(c, max(256, c), kernel_size=3, dilation=1),
|
86 |
+
torch.nn.LeakyReLU(negative_slope=0.2),
|
87 |
+
torch.nn.Dropout(0.2),
|
88 |
+
#
|
89 |
+
torch.nn.Conv1d(max(256, c), max(256, c), kernel_size=3, dilation=2),
|
90 |
+
torch.nn.LeakyReLU(negative_slope=0.2),
|
91 |
+
torch.nn.Dropout(0.2),
|
92 |
+
#
|
93 |
+
torch.nn.Conv1d(max(128, c), max(128, c), kernel_size=3, dilation=3),
|
94 |
+
torch.nn.LeakyReLU(negative_slope=0.2),
|
95 |
+
torch.nn.Dropout(0.2),
|
96 |
+
#
|
97 |
+
torch.nn.Conv1d(max(128, c), c, kernel_size=3, dilation=1),
|
98 |
+
torch.nn.LeakyReLU(negative_slope=0.2),
|
99 |
+
torch.nn.Dropout(0.2),
|
100 |
+
#
|
101 |
+
torch.nn.Conv1d(c, c, kernel_size=3, dilation=2),
|
102 |
+
torch.nn.LeakyReLU(negative_slope=0.2),
|
103 |
+
torch.nn.Dropout(0.2),
|
104 |
+
#
|
105 |
+
torch.nn.Conv1d(c, c, kernel_size=3, dilation=3),
|
106 |
+
torch.nn.LeakyReLU(negative_slope=0.2),
|
107 |
+
torch.nn.Dropout(0.2),
|
108 |
+
]
|
109 |
+
|
110 |
+
def setup_audio_models(self, cond_feature_dim: int, num_audio_layers: int) -> None:
|
111 |
+
pre_layers = []
|
112 |
+
for _ in range(num_audio_layers):
|
113 |
+
pre_layers += self._build_single_audio_conv(cond_feature_dim)
|
114 |
+
pre_layers += [
|
115 |
+
torch.nn.Conv1d(cond_feature_dim, cond_feature_dim, kernel_size=1)
|
116 |
+
]
|
117 |
+
pre_layers = torch.nn.ModuleList(pre_layers)
|
118 |
+
self.pre_audio = nn.Sequential(*pre_layers)
|
119 |
+
self.audio_model, self.audio_resampler = setup_lip_regressor()
|
120 |
+
|
121 |
+
def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
|
122 |
+
device = next(self.parameters()).device
|
123 |
+
a0 = self.audio_resampler(raw_audio[:, :, 0].to(device)) # B x T
|
124 |
+
a1 = self.audio_resampler(raw_audio[:, :, 1].to(device)) # B x T
|
125 |
+
with torch.no_grad():
|
126 |
+
z0 = self.audio_model.feature_extractor(a0)
|
127 |
+
z1 = self.audio_model.feature_extractor(a1)
|
128 |
+
emb = torch.cat((z0, z1), axis=1).permute(0, 2, 1)
|
129 |
+
return emb
|
130 |
+
|
131 |
+
def get_tgt_mask(self, size: int, device: str) -> torch.tensor:
|
132 |
+
mask = torch.tril(
|
133 |
+
torch.ones((size, size), device=device) == 1
|
134 |
+
) # Lower triangular matrix
|
135 |
+
mask = mask.float()
|
136 |
+
mask = mask.masked_fill(mask == 0, float("-inf")) # Convert zeros to -inf
|
137 |
+
mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
|
138 |
+
return mask
|
139 |
+
|
140 |
+
def forward(
|
141 |
+
self, tokens: th.Tensor, condition: th.Tensor, cond_drop_prob: float = 0.0
|
142 |
+
) -> torch.Tensor:
|
143 |
+
batch_size, device = tokens.shape[0], tokens.device
|
144 |
+
|
145 |
+
x = self.token_embedding(tokens)
|
146 |
+
x = self.abs_pos_encoding(x)
|
147 |
+
tgt_mask = self.get_tgt_mask(x.shape[1], x.device)
|
148 |
+
|
149 |
+
cond_embed = self.encode_audio(condition)
|
150 |
+
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
|
151 |
+
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
|
152 |
+
keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
|
153 |
+
cond_tokens = self.pre_audio(cond_embed.permute(0, 2, 1)).permute(0, 2, 1)
|
154 |
+
#
|
155 |
+
cond_tokens = self.cond_projection(cond_tokens)
|
156 |
+
cond_tokens = self.abs_pos_encoding(cond_tokens)
|
157 |
+
|
158 |
+
null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
|
159 |
+
cond_tokens = torch.where(
|
160 |
+
keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
|
161 |
+
)
|
162 |
+
mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
|
163 |
+
cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
|
164 |
+
|
165 |
+
# FiLM conditioning
|
166 |
+
null_cond_hidden = self.null_cond_hidden.to(cond_tokens.dtype)
|
167 |
+
cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
|
168 |
+
cond_tokens = self.norm_cond(cond_tokens)
|
169 |
+
|
170 |
+
output = self.seqTransDecoder(x, cond_tokens, cond_hidden, tgt_mask=tgt_mask)
|
171 |
+
output = self.final_layer(output)
|
172 |
+
return output
|
173 |
+
|
174 |
+
def generate(
|
175 |
+
self,
|
176 |
+
condition: th.Tensor,
|
177 |
+
sequence_length: int,
|
178 |
+
layers: int,
|
179 |
+
n_sequences: int = 1,
|
180 |
+
max_key_len: int = 8,
|
181 |
+
max_seq_len: int = 240,
|
182 |
+
top_p: float = 0.94,
|
183 |
+
) -> torch.Tensor:
|
184 |
+
"""
|
185 |
+
:param sequence_length: number of tokens to generate in autoregressive fashion
|
186 |
+
:param n_sequences: number of sequences to generate simultaneously
|
187 |
+
:param temperature: temperature of the softmax for sampling from the output logits
|
188 |
+
:return n_sequences x sequence_length LongTensor containing generated tokens
|
189 |
+
"""
|
190 |
+
assert max_key_len == int(max_seq_len / 30), "currently only running for 1fps"
|
191 |
+
max_key_len *= layers
|
192 |
+
with th.no_grad():
|
193 |
+
input_tokens = (
|
194 |
+
th.zeros(n_sequences, 1, dtype=th.int64).to(condition.device)
|
195 |
+
+ self.tokens
|
196 |
+
)
|
197 |
+
for _ in range(sequence_length * layers):
|
198 |
+
curr_input_tokens = input_tokens
|
199 |
+
curr_condition = condition
|
200 |
+
logits = self.forward(curr_input_tokens, curr_condition)
|
201 |
+
logits = logits[:, -1, :] # only most recent time step is relevant
|
202 |
+
one_hot = th.nn.functional.softmax(logits, dim=-1)
|
203 |
+
sorted_probs, indices = torch.sort(one_hot, dim=-1, descending=True)
|
204 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
205 |
+
nucleus = cumulative_probs < top_p
|
206 |
+
nucleus = torch.cat(
|
207 |
+
[
|
208 |
+
nucleus.new_ones(nucleus.shape[:-1] + (1,)),
|
209 |
+
nucleus[..., :-1],
|
210 |
+
],
|
211 |
+
dim=-1,
|
212 |
+
)
|
213 |
+
sorted_probs[~nucleus] = 0
|
214 |
+
sorted_probs /= sorted_probs.sum(-1, keepdim=True)
|
215 |
+
dist = Categorical(sorted_probs)
|
216 |
+
idx = dist.sample()
|
217 |
+
tokens = indices.gather(-1, idx.unsqueeze(-1))
|
218 |
+
input_tokens = th.cat([input_tokens, tokens], dim=-1)
|
219 |
+
|
220 |
+
# return generated tokens except for sequence start token
|
221 |
+
tokens = input_tokens[:, 1:].contiguous()
|
222 |
+
return tokens
|
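The sampling loop in generate above keeps only the smallest set of tokens whose cumulative probability stays under top_p, re-normalises, and samples from that truncated distribution. A minimal standalone sketch of that step, using a hypothetical nucleus_sample helper (not part of guide.py) so the filtering and re-normalisation are easy to see:

import torch
from torch.distributions import Categorical

def nucleus_sample(logits: torch.Tensor, top_p: float = 0.94) -> torch.Tensor:
    # sample one token per row, restricted to the top-p nucleus of the softmax
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    nucleus = cumulative < top_p
    # shift right so the single most likely token is always kept
    nucleus = torch.cat(
        [nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1
    )
    sorted_probs = sorted_probs.masked_fill(~nucleus, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(-1, keepdim=True)
    idx = Categorical(sorted_probs).sample()
    return indices.gather(-1, idx.unsqueeze(-1)).squeeze(-1)

tokens = nucleus_sample(torch.randn(2, 8))  # e.g. 2 sequences over a toy 8-token vocabulary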
model/modules/audio_encoder.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import fairseq
|
9 |
+
import torch as th
|
10 |
+
import torchaudio as ta
|
11 |
+
|
12 |
+
wav2vec_model_path = "./assets/wav2vec_large.pt"
|
13 |
+
|
14 |
+
|
15 |
+
def weights_init(m):
|
16 |
+
if isinstance(m, th.nn.Conv1d):
|
17 |
+
th.nn.init.xavier_uniform_(m.weight)
|
18 |
+
try:
|
19 |
+
th.nn.init.constant_(m.bias, 0.01)
|
20 |
+
except:
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class Wav2VecEncoder(th.nn.Module):
|
25 |
+
def __init__(self):
|
26 |
+
super().__init__()
|
27 |
+
self.resampler = ta.transforms.Resample(orig_freq=48000, new_freq=16000)
|
28 |
+
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
29 |
+
[wav2vec_model_path]
|
30 |
+
)
|
31 |
+
self.wav2vec_model = model[0]
|
32 |
+
|
33 |
+
def forward(self, audio: th.Tensor):
|
34 |
+
"""
|
35 |
+
:param audio: B x T x 1600
|
36 |
+
:return: B x T_wav2vec x 512
|
37 |
+
"""
|
38 |
+
audio = audio.view(audio.shape[0], audio.shape[1] * 1600)
|
39 |
+
audio = self.resampler(audio)
|
40 |
+
audio = th.cat(
|
41 |
+
[th.zeros(audio.shape[0], 320, device=audio.device), audio], dim=-1
|
42 |
+
) # zero padding on the left
|
43 |
+
x = self.wav2vec_model.feature_extractor(audio)
|
44 |
+
x = self.wav2vec_model.feature_aggregator(x)
|
45 |
+
x = x.permute(0, 2, 1).contiguous()
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
class Wav2VecDownsampler(th.nn.Module):
|
50 |
+
def __init__(self):
|
51 |
+
super().__init__()
|
52 |
+
self.conv1 = th.nn.Conv1d(512, 512, kernel_size=3)
|
53 |
+
self.conv2 = th.nn.Conv1d(512, 512, kernel_size=3)
|
54 |
+
self.norm = th.nn.LayerNorm(512)
|
55 |
+
|
56 |
+
def forward(self, x: th.Tensor, target_length: int):
|
57 |
+
"""
|
58 |
+
:param x: B x T x 512 tensor containing wav2vec features at 100Hz
|
59 |
+
:return: B x target_length x 512 tensor containing downsampled wav2vec features at 30Hz
|
60 |
+
"""
|
61 |
+
x = x.permute(0, 2, 1).contiguous()
|
62 |
+
# first conv
|
63 |
+
x = th.nn.functional.pad(x, pad=(2, 0))
|
64 |
+
x = th.nn.functional.relu(self.conv1(x))
|
65 |
+
# first downsampling
|
66 |
+
x = th.nn.functional.interpolate(x, size=(x.shape[-1] + target_length) // 2)
|
67 |
+
# second conv
|
68 |
+
x = th.nn.functional.pad(x, pad=(2, 0))
|
69 |
+
x = self.conv2(x)
|
70 |
+
# second downsampling
|
71 |
+
x = th.nn.functional.interpolate(x, size=target_length)
|
72 |
+
# layer norm
|
73 |
+
x = x.permute(0, 2, 1).contiguous()
|
74 |
+
x = self.norm(x)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class AudioTcn(th.nn.Module):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
encoding_dim: int = 128,
|
82 |
+
use_melspec: bool = True,
|
83 |
+
use_wav2vec: bool = True,
|
84 |
+
):
|
85 |
+
"""
|
86 |
+
:param encoding_dim: size of encoding
|
87 |
+
:param use_melspec: extract mel spectrogram features as input
|
88 |
+
:param use_wav2vec: extract wav2vec features as input
|
89 |
+
"""
|
90 |
+
super().__init__()
|
91 |
+
self.encoding_dim = encoding_dim
|
92 |
+
self.use_melspec = use_melspec
|
93 |
+
self.use_wav2vec = use_wav2vec
|
94 |
+
|
95 |
+
if use_melspec:
|
96 |
+
# hop_length=400 -> two feature vectors per visual frame (downsampling to 24kHz -> 800 samples per frame)
|
97 |
+
self.melspec = th.nn.Sequential(
|
98 |
+
ta.transforms.Resample(orig_freq=48000, new_freq=24000),
|
99 |
+
ta.transforms.MelSpectrogram(
|
100 |
+
sample_rate=24000,
|
101 |
+
n_fft=1024,
|
102 |
+
win_length=800,
|
103 |
+
hop_length=400,
|
104 |
+
n_mels=80,
|
105 |
+
),
|
106 |
+
)
|
107 |
+
|
108 |
+
if use_wav2vec:
|
109 |
+
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
110 |
+
[wav2vec_model_path]
|
111 |
+
)
|
112 |
+
self.wav2vec_model = model[0]
|
113 |
+
self.wav2vec_model.eval()
|
114 |
+
self.wav2vec_postprocess = th.nn.Conv1d(512, 256, kernel_size=3)
|
115 |
+
self.wav2vec_postprocess.apply(lambda x: weights_init(x))
|
116 |
+
|
117 |
+
# temporal model
|
118 |
+
input_dim = 0 + (160 if use_melspec else 0) + (256 if use_wav2vec else 0)
|
119 |
+
self.layers = th.nn.ModuleList(
|
120 |
+
[
|
121 |
+
th.nn.Conv1d(
|
122 |
+
input_dim, max(256, encoding_dim), kernel_size=3, dilation=1
|
123 |
+
), # 2 (+1)
|
124 |
+
th.nn.Conv1d(
|
125 |
+
max(256, encoding_dim), encoding_dim, kernel_size=3, dilation=2
|
126 |
+
), # 4 (+1)
|
127 |
+
th.nn.Conv1d(
|
128 |
+
encoding_dim, encoding_dim, kernel_size=3, dilation=3
|
129 |
+
), # 6 (+1)
|
130 |
+
th.nn.Conv1d(
|
131 |
+
encoding_dim, encoding_dim, kernel_size=3, dilation=1
|
132 |
+
), # 2 (+1)
|
133 |
+
th.nn.Conv1d(
|
134 |
+
encoding_dim, encoding_dim, kernel_size=3, dilation=2
|
135 |
+
), # 4 (+1)
|
136 |
+
th.nn.Conv1d(
|
137 |
+
encoding_dim, encoding_dim, kernel_size=3, dilation=3
|
138 |
+
), # 6 (+1)
|
139 |
+
]
|
140 |
+
)
|
141 |
+
self.layers.apply(lambda x: weights_init(x))
|
142 |
+
self.receptive_field = 25
|
143 |
+
|
144 |
+
self.final = th.nn.Conv1d(encoding_dim, encoding_dim, kernel_size=1)
|
145 |
+
self.final.apply(lambda x: weights_init(x))
|
146 |
+
|
147 |
+
def forward(self, audio):
|
148 |
+
"""
|
149 |
+
:param audio: B x T x 1600 tensor containing audio samples for each frame
|
150 |
+
:return: B x T x encoding_dim tensor containing audio encodings for each frame
|
151 |
+
"""
|
152 |
+
B, T = audio.shape[0], audio.shape[1]
|
153 |
+
|
154 |
+
# preprocess raw audio signal to extract feature vectors
|
155 |
+
audio = audio.view(B, T * 1600)
|
156 |
+
x_mel, x_w2v = th.zeros(B, 0, T).to(audio.device), th.zeros(B, 0, T).to(
|
157 |
+
audio.device
|
158 |
+
)
|
159 |
+
if self.use_melspec:
|
160 |
+
x_mel = self.melspec(audio)[:, :, 1:].contiguous()
|
161 |
+
x_mel = th.log(x_mel.clamp(min=1e-10, max=None))
|
162 |
+
x_mel = (
|
163 |
+
x_mel.permute(0, 2, 1)
|
164 |
+
.contiguous()
|
165 |
+
.view(x_mel.shape[0], T, 160)
|
166 |
+
.permute(0, 2, 1)
|
167 |
+
.contiguous()
|
168 |
+
)
|
169 |
+
if self.use_wav2vec:
|
170 |
+
with th.no_grad():
|
171 |
+
x_w2v = ta.functional.resample(audio, 48000, 16000)
|
172 |
+
x_w2v = self.wav2vec_model.feature_extractor(x_w2v)
|
173 |
+
x_w2v = self.wav2vec_model.feature_aggregator(x_w2v)
|
174 |
+
x_w2v = self.wav2vec_postprocess(th.nn.functional.pad(x_w2v, pad=[2, 0]))
|
175 |
+
x_w2v = th.nn.functional.interpolate(
|
176 |
+
x_w2v, size=T, align_corners=True, mode="linear"
|
177 |
+
)
|
178 |
+
x = th.cat([x_mel, x_w2v], dim=1)
|
179 |
+
|
180 |
+
# process signal with TCN
|
181 |
+
x = th.nn.functional.pad(x, pad=[self.receptive_field - 1, 0])
|
182 |
+
for layer_idx, layer in enumerate(self.layers):
|
183 |
+
y = th.nn.functional.leaky_relu(layer(x), negative_slope=0.2)
|
184 |
+
if self.training:
|
185 |
+
y = th.nn.functional.dropout(y, 0.2)
|
186 |
+
if x.shape[1] == y.shape[1]:
|
187 |
+
x = (x[:, :, -y.shape[-1] :] + y) / 2.0 # skip connection
|
188 |
+
else:
|
189 |
+
x = y
|
190 |
+
|
191 |
+
x = self.final(x)
|
192 |
+
x = x.permute(0, 2, 1).contiguous()
|
193 |
+
|
194 |
+
return x
|
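AudioTcn's mel branch relies on the hop-length arithmetic from its inline comment: after downsampling to 24 kHz, hop_length=400 yields two 80-bin mel vectors per 30 fps visual frame, which the module flattens into 160 features per frame. A small shape check of just that path (a sketch only; the module additionally log-scales the mels and concatenates wav2vec features):

import torch
import torchaudio as ta

B, T = 2, 30                               # 30 visual frames at 30 fps
audio = torch.randn(B, T, 1600)            # 1600 samples per frame at 48 kHz

melspec = torch.nn.Sequential(
    ta.transforms.Resample(orig_freq=48000, new_freq=24000),
    ta.transforms.MelSpectrogram(
        sample_rate=24000, n_fft=1024, win_length=800, hop_length=400, n_mels=80
    ),
)
x = melspec(audio.view(B, T * 1600))[:, :, 1:]  # drop the leading hop -> 2 * T mel frames
x = x.permute(0, 2, 1).reshape(B, T, 160)       # two 80-bin vectors per visual frame
print(x.shape)                                  # torch.Size([2, 30, 160])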
model/modules/rotary_embedding_torch.py
ADDED
@@ -0,0 +1,139 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from inspect import isfunction
|
9 |
+
from math import log, pi
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
from torch import einsum, nn
|
14 |
+
|
15 |
+
# helper functions
|
16 |
+
|
17 |
+
|
18 |
+
def exists(val):
|
19 |
+
return val is not None
|
20 |
+
|
21 |
+
|
22 |
+
def broadcat(tensors, dim=-1):
|
23 |
+
num_tensors = len(tensors)
|
24 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
25 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
26 |
+
shape_len = list(shape_lens)[0]
|
27 |
+
|
28 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
29 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
30 |
+
|
31 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
32 |
+
assert all(
|
33 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
34 |
+
), "invalid dimensions for broadcastable concatentation"
|
35 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
36 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
37 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
38 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
39 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
40 |
+
return torch.cat(tensors, dim=dim)
|
41 |
+
|
42 |
+
|
43 |
+
# rotary embedding helper functions
|
44 |
+
|
45 |
+
|
46 |
+
def rotate_half(x):
|
47 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
48 |
+
x1, x2 = x.unbind(dim=-1)
|
49 |
+
x = torch.stack((-x2, x1), dim=-1)
|
50 |
+
return rearrange(x, "... d r -> ... (d r)")
|
51 |
+
|
52 |
+
|
53 |
+
def apply_rotary_emb(freqs, t, start_index=0):
|
54 |
+
freqs = freqs.to(t)
|
55 |
+
rot_dim = freqs.shape[-1]
|
56 |
+
end_index = start_index + rot_dim
|
57 |
+
assert (
|
58 |
+
rot_dim <= t.shape[-1]
|
59 |
+
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
60 |
+
t_left, t, t_right = (
|
61 |
+
t[..., :start_index],
|
62 |
+
t[..., start_index:end_index],
|
63 |
+
t[..., end_index:],
|
64 |
+
)
|
65 |
+
t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
|
66 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
67 |
+
|
68 |
+
|
69 |
+
# learned rotation helpers
|
70 |
+
|
71 |
+
|
72 |
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
|
73 |
+
if exists(freq_ranges):
|
74 |
+
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
|
75 |
+
rotations = rearrange(rotations, "... r f -> ... (r f)")
|
76 |
+
|
77 |
+
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
|
78 |
+
return apply_rotary_emb(rotations, t, start_index=start_index)
|
79 |
+
|
80 |
+
|
81 |
+
# classes
|
82 |
+
|
83 |
+
|
84 |
+
class RotaryEmbedding(nn.Module):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
dim,
|
88 |
+
custom_freqs=None,
|
89 |
+
freqs_for="lang",
|
90 |
+
theta=10000,
|
91 |
+
max_freq=10,
|
92 |
+
num_freqs=1,
|
93 |
+
learned_freq=False,
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
if exists(custom_freqs):
|
97 |
+
freqs = custom_freqs
|
98 |
+
elif freqs_for == "lang":
|
99 |
+
freqs = 1.0 / (
|
100 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
101 |
+
)
|
102 |
+
elif freqs_for == "pixel":
|
103 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
104 |
+
elif freqs_for == "constant":
|
105 |
+
freqs = torch.ones(num_freqs).float()
|
106 |
+
else:
|
107 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
108 |
+
|
109 |
+
self.cache = dict()
|
110 |
+
|
111 |
+
if learned_freq:
|
112 |
+
self.freqs = nn.Parameter(freqs)
|
113 |
+
else:
|
114 |
+
self.register_buffer("freqs", freqs)
|
115 |
+
|
116 |
+
def rotate_queries_or_keys(self, t, seq_dim=-2):
|
117 |
+
device = t.device
|
118 |
+
seq_len = t.shape[seq_dim]
|
119 |
+
freqs = self.forward(
|
120 |
+
lambda: torch.arange(seq_len, device=device), cache_key=seq_len
|
121 |
+
)
|
122 |
+
return apply_rotary_emb(freqs, t)
|
123 |
+
|
124 |
+
def forward(self, t, cache_key=None):
|
125 |
+
if exists(cache_key) and cache_key in self.cache:
|
126 |
+
return self.cache[cache_key]
|
127 |
+
|
128 |
+
if isfunction(t):
|
129 |
+
t = t()
|
130 |
+
|
131 |
+
freqs = self.freqs
|
132 |
+
|
133 |
+
freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
134 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
135 |
+
|
136 |
+
if exists(cache_key):
|
137 |
+
self.cache[cache_key] = freqs
|
138 |
+
|
139 |
+
return freqs
|
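The rotary embedding above is consumed by the transformer layers in the next file through rotate_queries_or_keys, which builds (and caches) per-position frequencies and rotates the first dim channels of the queries and keys. A short usage sketch, assuming the repository root is on PYTHONPATH:

import torch
from model.modules.rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim=64)          # rotate 64 channels (here the full head dimension)
q = torch.randn(2, 8, 240, 64)            # batch x heads x sequence x head_dim
q_rot = rotary.rotate_queries_or_keys(q)  # same shape, with positions encoded as rotations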
model/modules/transformer_modules.py
ADDED
@@ -0,0 +1,702 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
from typing import Any, Callable, List, Optional, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from einops import rearrange
|
14 |
+
from torch import Tensor
|
15 |
+
from torch.nn import functional as F
|
16 |
+
|
17 |
+
|
18 |
+
def generate_causal_mask(source_length, target_length, device="cpu"):
|
19 |
+
if source_length == target_length:
|
20 |
+
mask = (
|
21 |
+
torch.triu(torch.ones(target_length, source_length, device=device)) == 1
|
22 |
+
).transpose(0, 1)
|
23 |
+
else:
|
24 |
+
mask = torch.zeros(target_length, source_length, device=device)
|
25 |
+
idx = torch.linspace(0, source_length, target_length + 1)[1:].round().long()
|
26 |
+
for i in range(target_length):
|
27 |
+
mask[i, 0 : idx[i]] = 1
|
28 |
+
|
29 |
+
return (
|
30 |
+
mask.float()
|
31 |
+
.masked_fill(mask == 0, float("-inf"))
|
32 |
+
.masked_fill(mask == 1, float(0.0))
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
class TransformerEncoderLayerRotary(nn.Module):
|
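When source and target lengths differ, generate_causal_mask gives each target step a proportional prefix of the source, which is how the decoders below keep conditioning causal relative to the target tokens. A tiny example of the additive float mask it returns:

import torch
from model.modules.transformer_modules import generate_causal_mask

mask = generate_causal_mask(source_length=6, target_length=3)
# row i allows the first round(6 * (i + 1) / 3) source positions:
# [[0., 0., -inf, -inf, -inf, -inf],
#  [0., 0., 0.,   0.,   -inf, -inf],
#  [0., 0., 0.,   0.,   0.,   0.  ]]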
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
d_model: int,
|
40 |
+
nhead: int,
|
41 |
+
dim_feedforward: int = 2048,
|
42 |
+
dropout: float = 0.1,
|
43 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
44 |
+
layer_norm_eps: float = 1e-5,
|
45 |
+
batch_first: bool = False,
|
46 |
+
norm_first: bool = True,
|
47 |
+
rotary=None,
|
48 |
+
) -> None:
|
49 |
+
super().__init__()
|
50 |
+
self.self_attn = nn.MultiheadAttention(
|
51 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
52 |
+
)
|
53 |
+
# Implementation of Feedforward model
|
54 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
55 |
+
self.dropout = nn.Dropout(dropout)
|
56 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
57 |
+
|
58 |
+
self.norm_first = norm_first
|
59 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
60 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
61 |
+
self.dropout1 = nn.Dropout(dropout)
|
62 |
+
self.dropout2 = nn.Dropout(dropout)
|
63 |
+
self.activation = activation
|
64 |
+
|
65 |
+
self.rotary = rotary
|
66 |
+
self.use_rotary = rotary is not None
|
67 |
+
|
68 |
+
def forward(
|
69 |
+
self,
|
70 |
+
src: Tensor,
|
71 |
+
src_mask: Optional[Tensor] = None,
|
72 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
73 |
+
) -> Tensor:
|
74 |
+
x = src
|
75 |
+
if self.norm_first:
|
76 |
+
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
|
77 |
+
x = x + self._ff_block(self.norm2(x))
|
78 |
+
else:
|
79 |
+
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
|
80 |
+
x = self.norm2(x + self._ff_block(x))
|
81 |
+
|
82 |
+
return x
|
83 |
+
|
84 |
+
# self-attention block
|
85 |
+
def _sa_block(
|
86 |
+
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
|
87 |
+
) -> Tensor:
|
88 |
+
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
89 |
+
x = self.self_attn(
|
90 |
+
qk,
|
91 |
+
qk,
|
92 |
+
x,
|
93 |
+
attn_mask=attn_mask,
|
94 |
+
key_padding_mask=key_padding_mask,
|
95 |
+
need_weights=False,
|
96 |
+
)[0]
|
97 |
+
return self.dropout1(x)
|
98 |
+
|
99 |
+
# feed forward block
|
100 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
101 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
102 |
+
return self.dropout2(x)
|
103 |
+
|
104 |
+
|
105 |
+
class DenseFiLM(nn.Module):
|
106 |
+
"""Feature-wise linear modulation (FiLM) generator."""
|
107 |
+
|
108 |
+
def __init__(self, embed_channels):
|
109 |
+
super().__init__()
|
110 |
+
self.embed_channels = embed_channels
|
111 |
+
self.block = nn.Sequential(
|
112 |
+
nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
|
113 |
+
)
|
114 |
+
|
115 |
+
def forward(self, position):
|
116 |
+
pos_encoding = self.block(position)
|
117 |
+
pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
|
118 |
+
scale_shift = pos_encoding.chunk(2, dim=-1)
|
119 |
+
return scale_shift
|
120 |
+
|
121 |
+
|
122 |
+
def featurewise_affine(x, scale_shift):
|
123 |
+
scale, shift = scale_shift
|
124 |
+
return (scale + 1) * x + shift
|
125 |
+
|
126 |
+
|
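DenseFiLM maps a timestep embedding to per-channel (scale, shift) tensors that featurewise_affine broadcasts over time; this is the conditioning path used by FiLMTransformerDecoderLayer below. A quick shape sketch:

import torch
from model.modules.transformer_modules import DenseFiLM, featurewise_affine

film = DenseFiLM(embed_channels=512)
t_emb = torch.randn(2, 512)             # one timestep embedding per batch element
x = torch.randn(2, 240, 512)            # decoder activations
scale_shift = film(t_emb)               # two (2, 1, 512) tensors
y = featurewise_affine(x, scale_shift)  # (scale + 1) * x + shift, broadcast over the 240 frames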
127 |
+
class FiLMTransformerDecoderLayer(nn.Module):
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
d_model: int,
|
131 |
+
nhead: int,
|
132 |
+
dim_feedforward=2048,
|
133 |
+
dropout=0.1,
|
134 |
+
activation=F.relu,
|
135 |
+
layer_norm_eps=1e-5,
|
136 |
+
batch_first=False,
|
137 |
+
norm_first=True,
|
138 |
+
rotary=None,
|
139 |
+
use_cm=False,
|
140 |
+
):
|
141 |
+
super().__init__()
|
142 |
+
self.self_attn = nn.MultiheadAttention(
|
143 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
144 |
+
)
|
145 |
+
self.multihead_attn = nn.MultiheadAttention(
|
146 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
147 |
+
)
|
148 |
+
# Feedforward
|
149 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
150 |
+
self.dropout = nn.Dropout(dropout)
|
151 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
152 |
+
|
153 |
+
self.norm_first = norm_first
|
154 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
155 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
156 |
+
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
157 |
+
self.dropout1 = nn.Dropout(dropout)
|
158 |
+
self.dropout2 = nn.Dropout(dropout)
|
159 |
+
self.dropout3 = nn.Dropout(dropout)
|
160 |
+
self.activation = activation
|
161 |
+
|
162 |
+
self.film1 = DenseFiLM(d_model)
|
163 |
+
self.film2 = DenseFiLM(d_model)
|
164 |
+
self.film3 = DenseFiLM(d_model)
|
165 |
+
|
166 |
+
if use_cm:
|
167 |
+
self.multihead_attn2 = nn.MultiheadAttention( # 2
|
168 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
169 |
+
)
|
170 |
+
self.norm2a = nn.LayerNorm(d_model, eps=layer_norm_eps) # 2
|
171 |
+
self.dropout2a = nn.Dropout(dropout) # 2
|
172 |
+
self.film2a = DenseFiLM(d_model) # 2
|
173 |
+
|
174 |
+
self.rotary = rotary
|
175 |
+
self.use_rotary = rotary is not None
|
176 |
+
|
177 |
+
# x, cond, t
|
178 |
+
def forward(
|
179 |
+
self,
|
180 |
+
tgt,
|
181 |
+
memory,
|
182 |
+
t,
|
183 |
+
tgt_mask=None,
|
184 |
+
memory_mask=None,
|
185 |
+
tgt_key_padding_mask=None,
|
186 |
+
memory_key_padding_mask=None,
|
187 |
+
memory2=None,
|
188 |
+
):
|
189 |
+
x = tgt
|
190 |
+
if self.norm_first:
|
191 |
+
# self-attention -> film -> residual
|
192 |
+
x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
|
193 |
+
x = x + featurewise_affine(x_1, self.film1(t))
|
194 |
+
# cross-attention -> film -> residual
|
195 |
+
x_2 = self._mha_block(
|
196 |
+
self.norm2(x),
|
197 |
+
memory,
|
198 |
+
memory_mask,
|
199 |
+
memory_key_padding_mask,
|
200 |
+
self.multihead_attn,
|
201 |
+
self.dropout2,
|
202 |
+
)
|
203 |
+
x = x + featurewise_affine(x_2, self.film2(t))
|
204 |
+
if memory2 is not None:
|
205 |
+
# cross-attention x2 -> film -> residual
|
206 |
+
x_2a = self._mha_block(
|
207 |
+
self.norm2a(x),
|
208 |
+
memory2,
|
209 |
+
memory_mask,
|
210 |
+
memory_key_padding_mask,
|
211 |
+
self.multihead_attn2,
|
212 |
+
self.dropout2a,
|
213 |
+
)
|
214 |
+
x = x + featurewise_affine(x_2a, self.film2a(t))
|
215 |
+
# feedforward -> film -> residual
|
216 |
+
x_3 = self._ff_block(self.norm3(x))
|
217 |
+
x = x + featurewise_affine(x_3, self.film3(t))
|
218 |
+
else:
|
219 |
+
x = self.norm1(
|
220 |
+
x
|
221 |
+
+ featurewise_affine(
|
222 |
+
self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
|
223 |
+
)
|
224 |
+
)
|
225 |
+
x = self.norm2(
|
226 |
+
x
|
227 |
+
+ featurewise_affine(
|
228 |
+
self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
|
229 |
+
self.film2(t),
|
230 |
+
)
|
231 |
+
)
|
232 |
+
x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
|
233 |
+
return x
|
234 |
+
|
235 |
+
# self-attention block
|
236 |
+
# qkv
|
237 |
+
def _sa_block(self, x, attn_mask, key_padding_mask):
|
238 |
+
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
239 |
+
x = self.self_attn(
|
240 |
+
qk,
|
241 |
+
qk,
|
242 |
+
x,
|
243 |
+
attn_mask=attn_mask,
|
244 |
+
key_padding_mask=key_padding_mask,
|
245 |
+
need_weights=False,
|
246 |
+
)[0]
|
247 |
+
return self.dropout1(x)
|
248 |
+
|
249 |
+
# multihead attention block
|
250 |
+
# qkv
|
251 |
+
def _mha_block(self, x, mem, attn_mask, key_padding_mask, mha, dropout):
|
252 |
+
q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
253 |
+
k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
|
254 |
+
x = mha(
|
255 |
+
q,
|
256 |
+
k,
|
257 |
+
mem,
|
258 |
+
attn_mask=attn_mask,
|
259 |
+
key_padding_mask=key_padding_mask,
|
260 |
+
need_weights=False,
|
261 |
+
)[0]
|
262 |
+
return dropout(x)
|
263 |
+
|
264 |
+
# feed forward block
|
265 |
+
def _ff_block(self, x):
|
266 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
267 |
+
return self.dropout3(x)
|
268 |
+
|
269 |
+
|
270 |
+
class DecoderLayerStack(nn.Module):
|
271 |
+
def __init__(self, stack):
|
272 |
+
super().__init__()
|
273 |
+
self.stack = stack
|
274 |
+
|
275 |
+
def forward(self, x, cond, t, tgt_mask=None, memory2=None):
|
276 |
+
for layer in self.stack:
|
277 |
+
x = layer(x, cond, t, tgt_mask=tgt_mask, memory2=memory2)
|
278 |
+
return x
|
279 |
+
|
280 |
+
|
281 |
+
class PositionalEncoding(nn.Module):
|
282 |
+
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
|
283 |
+
super().__init__()
|
284 |
+
pe = torch.zeros(max_len, d_model)
|
285 |
+
position = torch.arange(0, max_len).unsqueeze(1)
|
286 |
+
div_term = torch.exp(
|
287 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
288 |
+
)
|
289 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
290 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
291 |
+
|
292 |
+
self.register_buffer("pe", pe)
|
293 |
+
self.dropout = nn.Dropout(p=dropout)
|
294 |
+
|
295 |
+
def forward(self, x: torch.Tensor):
|
296 |
+
"""
|
297 |
+
:param x: B x T x d_model tensor
|
298 |
+
:return: B x T x d_model tensor
|
299 |
+
"""
|
300 |
+
x = x + self.pe[None, : x.shape[1], :]
|
301 |
+
x = self.dropout(x)
|
302 |
+
return x
|
303 |
+
|
304 |
+
|
305 |
+
class TimestepEncoding(nn.Module):
|
306 |
+
def __init__(self, embedding_dim: int):
|
307 |
+
super().__init__()
|
308 |
+
|
309 |
+
# Fourier embedding
|
310 |
+
half_dim = embedding_dim // 2
|
311 |
+
emb = math.log(10000) / (half_dim - 1)
|
312 |
+
emb = torch.exp(torch.arange(half_dim) * -emb)
|
313 |
+
self.register_buffer("emb", emb)
|
314 |
+
|
315 |
+
# encoding
|
316 |
+
self.encoding = nn.Sequential(
|
317 |
+
nn.Linear(embedding_dim, 4 * embedding_dim),
|
318 |
+
nn.Mish(),
|
319 |
+
nn.Linear(4 * embedding_dim, embedding_dim),
|
320 |
+
)
|
321 |
+
|
322 |
+
def forward(self, t: torch.Tensor):
|
323 |
+
"""
|
324 |
+
:param t: B-dimensional tensor containing timesteps in range [0, 1]
|
325 |
+
:return: B x embedding_dim tensor containing timestep encodings
|
326 |
+
"""
|
327 |
+
x = t[:, None] * self.emb[None, :]
|
328 |
+
x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
|
329 |
+
x = self.encoding(x)
|
330 |
+
return x
|
331 |
+
|
332 |
+
|
333 |
+
class FiLM(nn.Module):
|
334 |
+
def __init__(self, dim: int):
|
335 |
+
super().__init__()
|
336 |
+
self.dim = dim
|
337 |
+
self.film = nn.Sequential(nn.Mish(), nn.Linear(dim, dim * 2))
|
338 |
+
|
339 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor):
|
340 |
+
"""
|
341 |
+
:param x: ... x dim tensor
|
342 |
+
:param cond: ... x dim tensor
|
343 |
+
:return: ... x dim tensor as scale(cond) * x + bias(cond)
|
344 |
+
"""
|
345 |
+
cond = self.film(cond)
|
346 |
+
scale, bias = torch.chunk(cond, chunks=2, dim=-1)
|
347 |
+
x = (scale + 1) * x + bias
|
348 |
+
return x
|
349 |
+
|
350 |
+
|
351 |
+
class FeedforwardBlock(nn.Module):
|
352 |
+
def __init__(self, d_model: int, d_feedforward: int = 1024, dropout: float = 0.1):
|
353 |
+
super().__init__()
|
354 |
+
self.ff = nn.Sequential(
|
355 |
+
nn.Linear(d_model, d_feedforward),
|
356 |
+
nn.ReLU(),
|
357 |
+
nn.Dropout(p=dropout),
|
358 |
+
nn.Linear(d_feedforward, d_model),
|
359 |
+
nn.Dropout(p=dropout),
|
360 |
+
)
|
361 |
+
|
362 |
+
def forward(self, x: torch.Tensor):
|
363 |
+
"""
|
364 |
+
:param x: ... x d_model tensor
|
365 |
+
:return: ... x d_model tensor
|
366 |
+
"""
|
367 |
+
return self.ff(x)
|
368 |
+
|
369 |
+
|
370 |
+
class SelfAttention(nn.Module):
|
371 |
+
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
|
372 |
+
super().__init__()
|
373 |
+
self.self_attn = nn.MultiheadAttention(
|
374 |
+
d_model, num_heads, dropout=dropout, batch_first=True
|
375 |
+
)
|
376 |
+
self.dropout = nn.Dropout(p=dropout)
|
377 |
+
|
378 |
+
def forward(
|
379 |
+
self,
|
380 |
+
x: torch.Tensor,
|
381 |
+
attn_mask: torch.Tensor = None,
|
382 |
+
key_padding_mask: torch.Tensor = None,
|
383 |
+
):
|
384 |
+
"""
|
385 |
+
:param x: B x T x d_model input tensor
|
386 |
+
:param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length
|
387 |
+
for a float mask: values will be added to attention weight
|
388 |
+
for a binary mask: True indicates that the element is not allowed to attend
|
389 |
+
:param key_padding_mask: B x S mask
|
390 |
+
for a float mask: values will be added directly to the corresponding key values
|
391 |
+
for a binary mask: True indicates that the corresponding key value will be ignored
|
392 |
+
:return: B x T x d_model output tensor
|
393 |
+
"""
|
394 |
+
x = self.self_attn(
|
395 |
+
x,
|
396 |
+
x,
|
397 |
+
x,
|
398 |
+
attn_mask=attn_mask,
|
399 |
+
key_padding_mask=key_padding_mask,
|
400 |
+
need_weights=False,
|
401 |
+
)[0]
|
402 |
+
x = self.dropout(x)
|
403 |
+
return x
|
404 |
+
|
405 |
+
|
406 |
+
class CrossAttention(nn.Module):
|
407 |
+
def __init__(self, d_model: int, d_cond: int, num_heads: int, dropout: float = 0.1):
|
408 |
+
super().__init__()
|
409 |
+
self.cross_attn = nn.MultiheadAttention(
|
410 |
+
d_model,
|
411 |
+
num_heads,
|
412 |
+
dropout=dropout,
|
413 |
+
batch_first=True,
|
414 |
+
kdim=d_cond,
|
415 |
+
vdim=d_cond,
|
416 |
+
)
|
417 |
+
self.dropout = nn.Dropout(p=dropout)
|
418 |
+
|
419 |
+
def forward(
|
420 |
+
self,
|
421 |
+
x: torch.Tensor,
|
422 |
+
cond: torch.Tensor,
|
423 |
+
attn_mask: torch.Tensor = None,
|
424 |
+
key_padding_mask: torch.Tensor = None,
|
425 |
+
):
|
426 |
+
"""
|
427 |
+
:param x: B x T_target x d_model input tensor
|
428 |
+
:param cond: B x T_cond x d_cond condition tensor
|
429 |
+
:param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length
|
430 |
+
for a float mask: values will be added to attention weight
|
431 |
+
for a binary mask: True indicates that the element is not allowed to attend
|
432 |
+
:param key_padding_mask: B x S mask
|
433 |
+
for a float mask: values will be added directly to the corresponding key values
|
434 |
+
for a binary mask: True indicates that the corresponding key value will be ignored
|
435 |
+
:return: B x T x d_model output tensor
|
436 |
+
"""
|
437 |
+
x = self.cross_attn(
|
438 |
+
x,
|
439 |
+
cond,
|
440 |
+
cond,
|
441 |
+
attn_mask=attn_mask,
|
442 |
+
key_padding_mask=key_padding_mask,
|
443 |
+
need_weights=False,
|
444 |
+
)[0]
|
445 |
+
x = self.dropout(x)
|
446 |
+
return x
|
447 |
+
|
448 |
+
|
449 |
+
class TransformerEncoderLayer(nn.Module):
|
450 |
+
def __init__(
|
451 |
+
self,
|
452 |
+
d_model: int,
|
453 |
+
num_heads: int,
|
454 |
+
d_feedforward: int = 1024,
|
455 |
+
dropout: float = 0.1,
|
456 |
+
):
|
457 |
+
super().__init__()
|
458 |
+
self.norm1 = nn.LayerNorm(d_model)
|
459 |
+
self.self_attn = SelfAttention(d_model, num_heads, dropout)
|
460 |
+
self.norm2 = nn.LayerNorm(d_model)
|
461 |
+
self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
|
462 |
+
|
463 |
+
def forward(
|
464 |
+
self,
|
465 |
+
x: torch.Tensor,
|
466 |
+
mask: torch.Tensor = None,
|
467 |
+
key_padding_mask: torch.Tensor = None,
|
468 |
+
):
|
469 |
+
x = x + self.self_attn(self.norm1(x), mask, key_padding_mask)
|
470 |
+
x = x + self.feedforward(self.norm2(x))
|
471 |
+
return x
|
472 |
+
|
473 |
+
|
474 |
+
class TransformerDecoderLayer(nn.Module):
|
475 |
+
def __init__(
|
476 |
+
self,
|
477 |
+
d_model: int,
|
478 |
+
d_cond: int,
|
479 |
+
num_heads: int,
|
480 |
+
d_feedforward: int = 1024,
|
481 |
+
dropout: float = 0.1,
|
482 |
+
):
|
483 |
+
super().__init__()
|
484 |
+
self.norm1 = nn.LayerNorm(d_model)
|
485 |
+
self.self_attn = SelfAttention(d_model, num_heads, dropout)
|
486 |
+
self.norm2 = nn.LayerNorm(d_model)
|
487 |
+
self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
|
488 |
+
self.norm3 = nn.LayerNorm(d_model)
|
489 |
+
self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
|
490 |
+
|
491 |
+
def forward(
|
492 |
+
self,
|
493 |
+
x: torch.Tensor,
|
494 |
+
cross_cond: torch.Tensor,
|
495 |
+
target_mask: torch.Tensor = None,
|
496 |
+
target_key_padding_mask: torch.Tensor = None,
|
497 |
+
cross_cond_mask: torch.Tensor = None,
|
498 |
+
cross_cond_key_padding_mask: torch.Tensor = None,
|
499 |
+
):
|
500 |
+
"""
|
501 |
+
:param x: B x T x d_model tensor
|
502 |
+
:param cross_cond: B x T x d_cond tensor containing the conditioning input to cross attention layers
|
503 |
+
:return: B x T x d_model tensor
|
504 |
+
"""
|
505 |
+
x = x + self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
|
506 |
+
x = x + self.cross_attn(
|
507 |
+
self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
|
508 |
+
)
|
509 |
+
x = x + self.feedforward(self.norm3(x))
|
510 |
+
return x
|
511 |
+
|
512 |
+
|
513 |
+
class FilmTransformerDecoderLayer(nn.Module):
|
514 |
+
def __init__(
|
515 |
+
self,
|
516 |
+
d_model: int,
|
517 |
+
d_cond: int,
|
518 |
+
num_heads: int,
|
519 |
+
d_feedforward: int = 1024,
|
520 |
+
dropout: float = 0.1,
|
521 |
+
):
|
522 |
+
super().__init__()
|
523 |
+
self.norm1 = nn.LayerNorm(d_model)
|
524 |
+
self.self_attn = SelfAttention(d_model, num_heads, dropout)
|
525 |
+
self.film1 = FiLM(d_model)
|
526 |
+
self.norm2 = nn.LayerNorm(d_model)
|
527 |
+
self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
|
528 |
+
self.film2 = FiLM(d_model)
|
529 |
+
self.norm3 = nn.LayerNorm(d_model)
|
530 |
+
self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
|
531 |
+
self.film3 = FiLM(d_model)
|
532 |
+
|
533 |
+
def forward(
|
534 |
+
self,
|
535 |
+
x: torch.Tensor,
|
536 |
+
cross_cond: torch.Tensor,
|
537 |
+
film_cond: torch.Tensor,
|
538 |
+
target_mask: torch.Tensor = None,
|
539 |
+
target_key_padding_mask: torch.Tensor = None,
|
540 |
+
cross_cond_mask: torch.Tensor = None,
|
541 |
+
cross_cond_key_padding_mask: torch.Tensor = None,
|
542 |
+
):
|
543 |
+
"""
|
544 |
+
:param x: B x T x d_model tensor
|
545 |
+
:param cross_cond: B x T x d_cond tensor containing the conditioning input to cross attention layers
|
546 |
+
:param film_cond: B x [1 or T] x film_cond tensor containing the conditioning input to FiLM layers
|
547 |
+
:return: B x T x d_model tensor
|
548 |
+
"""
|
549 |
+
x1 = self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
|
550 |
+
x = x + self.film1(x1, film_cond)
|
551 |
+
x2 = self.cross_attn(
|
552 |
+
self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
|
553 |
+
)
|
554 |
+
x = x + self.film2(x2, film_cond)
|
555 |
+
x3 = self.feedforward(self.norm3(x))
|
556 |
+
x = x + self.film3(x3, film_cond)
|
557 |
+
return x
|
558 |
+
|
559 |
+
|
560 |
+
class RegressionTransformer(nn.Module):
|
561 |
+
def __init__(
|
562 |
+
self,
|
563 |
+
transformer_encoder_layers: int = 2,
|
564 |
+
transformer_decoder_layers: int = 4,
|
565 |
+
d_model: int = 512,
|
566 |
+
d_cond: int = 512,
|
567 |
+
num_heads: int = 4,
|
568 |
+
d_feedforward: int = 1024,
|
569 |
+
dropout: float = 0.1,
|
570 |
+
causal: bool = False,
|
571 |
+
):
|
572 |
+
super().__init__()
|
573 |
+
self.causal = causal
|
574 |
+
|
575 |
+
self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
|
576 |
+
self.target_positional_encoding = PositionalEncoding(d_model, dropout)
|
577 |
+
|
578 |
+
self.transformer_encoder = nn.ModuleList(
|
579 |
+
[
|
580 |
+
TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
|
581 |
+
for _ in range(transformer_encoder_layers)
|
582 |
+
]
|
583 |
+
)
|
584 |
+
|
585 |
+
self.transformer_decoder = nn.ModuleList(
|
586 |
+
[
|
587 |
+
TransformerDecoderLayer(
|
588 |
+
d_model, d_cond, num_heads, d_feedforward, dropout
|
589 |
+
)
|
590 |
+
for _ in range(transformer_decoder_layers)
|
591 |
+
]
|
592 |
+
)
|
593 |
+
|
594 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor):
|
595 |
+
"""
|
596 |
+
:param x: B x T x d_model input tensor
|
597 |
+
:param cond: B x T x d_cond conditional tensor
|
598 |
+
:return: B x T x d_model output tensor
|
599 |
+
"""
|
600 |
+
x = self.target_positional_encoding(x)
|
601 |
+
cond = self.cond_positional_encoding(cond)
|
602 |
+
|
603 |
+
if self.causal:
|
604 |
+
encoder_mask = generate_causal_mask(
|
605 |
+
cond.shape[1], cond.shape[1], device=cond.device
|
606 |
+
)
|
607 |
+
decoder_self_attn_mask = generate_causal_mask(
|
608 |
+
x.shape[1], x.shape[1], device=x.device
|
609 |
+
)
|
610 |
+
decoder_cross_attn_mask = generate_causal_mask(
|
611 |
+
cond.shape[1], x.shape[1], device=x.device
|
612 |
+
)
|
613 |
+
else:
|
614 |
+
encoder_mask = None
|
615 |
+
decoder_self_attn_mask = None
|
616 |
+
decoder_cross_attn_mask = None
|
617 |
+
|
618 |
+
for encoder_layer in self.transformer_encoder:
|
619 |
+
cond = encoder_layer(cond, mask=encoder_mask)
|
620 |
+
for decoder_layer in self.transformer_decoder:
|
621 |
+
x = decoder_layer(
|
622 |
+
x,
|
623 |
+
cond,
|
624 |
+
target_mask=decoder_self_attn_mask,
|
625 |
+
cross_cond_mask=decoder_cross_attn_mask,
|
626 |
+
)
|
627 |
+
return x
|
628 |
+
|
629 |
+
|
630 |
+
class DiffusionTransformer(nn.Module):
|
631 |
+
def __init__(
|
632 |
+
self,
|
633 |
+
transformer_encoder_layers: int = 2,
|
634 |
+
transformer_decoder_layers: int = 4,
|
635 |
+
d_model: int = 512,
|
636 |
+
d_cond: int = 512,
|
637 |
+
num_heads: int = 4,
|
638 |
+
d_feedforward: int = 1024,
|
639 |
+
dropout: float = 0.1,
|
640 |
+
causal: bool = False,
|
641 |
+
):
|
642 |
+
super().__init__()
|
643 |
+
self.causal = causal
|
644 |
+
|
645 |
+
self.timestep_encoder = TimestepEncoding(d_model)
|
646 |
+
self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
|
647 |
+
self.target_positional_encoding = PositionalEncoding(d_model, dropout)
|
648 |
+
|
649 |
+
self.transformer_encoder = nn.ModuleList(
|
650 |
+
[
|
651 |
+
TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
|
652 |
+
for _ in range(transformer_encoder_layers)
|
653 |
+
]
|
654 |
+
)
|
655 |
+
|
656 |
+
self.transformer_decoder = nn.ModuleList(
|
657 |
+
[
|
658 |
+
FilmTransformerDecoderLayer(
|
659 |
+
d_model, d_cond, num_heads, d_feedforward, dropout
|
660 |
+
)
|
661 |
+
for _ in range(transformer_decoder_layers)
|
662 |
+
]
|
663 |
+
)
|
664 |
+
|
665 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
666 |
+
"""
|
667 |
+
:param x: B x T x d_model input tensor
|
668 |
+
:param cond: B x T x d_cond conditional tensor
|
669 |
+
:param t: B-dimensional tensor containing diffusion timesteps in range [0, 1]
|
670 |
+
:return: B x T x d_model output tensor
|
671 |
+
"""
|
672 |
+
t = self.timestep_encoder(t).unsqueeze(1) # B x 1 x d_model
|
673 |
+
x = self.target_positional_encoding(x)
|
674 |
+
cond = self.cond_positional_encoding(cond)
|
675 |
+
|
676 |
+
if self.causal:
|
677 |
+
encoder_mask = generate_causal_mask(
|
678 |
+
cond.shape[1], cond.shape[1], device=cond.device
|
679 |
+
)
|
680 |
+
decoder_self_attn_mask = generate_causal_mask(
|
681 |
+
x.shape[1], x.shape[1], device=x.device
|
682 |
+
)
|
683 |
+
decoder_cross_attn_mask = generate_causal_mask(
|
684 |
+
cond.shape[1], x.shape[1], device=x.device
|
685 |
+
)
|
686 |
+
else:
|
687 |
+
encoder_mask = None
|
688 |
+
decoder_self_attn_mask = None
|
689 |
+
decoder_cross_attn_mask = None
|
690 |
+
|
691 |
+
for encoder_layer in self.transformer_encoder:
|
692 |
+
cond = encoder_layer(cond, mask=encoder_mask)
|
693 |
+
for decoder_layer in self.transformer_decoder:
|
694 |
+
x = decoder_layer(
|
695 |
+
x,
|
696 |
+
cond,
|
697 |
+
t,
|
698 |
+
target_mask=decoder_self_attn_mask,
|
699 |
+
cross_cond_mask=decoder_cross_attn_mask,
|
700 |
+
)
|
701 |
+
|
702 |
+
return x
|
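Tying the pieces in this file together, DiffusionTransformer runs the conditioning through plain encoder layers and then FiLM-modulated decoder layers over the noisy input, optionally with causal masks. A minimal shape check with the default sizes (a sketch only, not the training entry point):

import torch
from model.modules.transformer_modules import DiffusionTransformer

model = DiffusionTransformer(d_model=512, d_cond=512, causal=True)
x = torch.randn(2, 240, 512)     # noisy target features
cond = torch.randn(2, 240, 512)  # per-frame conditioning features
t = torch.rand(2)                # diffusion timesteps in [0, 1]
out = model(x, cond, t)          # -> (2, 240, 512)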
model/utils.py
ADDED
@@ -0,0 +1,130 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
import fairseq
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torchaudio.transforms as T
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
|
18 |
+
def setup_lip_regressor() -> ("Audio2LipRegressionTransformer", T.Resample):
|
19 |
+
cp_path = "./assets/vq-wav2vec.pt"
|
20 |
+
audio_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
|
21 |
+
audio_model = audio_model[0]
|
22 |
+
for param in audio_model.parameters():
|
23 |
+
param.requires_grad = False
|
24 |
+
audio_model.eval()
|
25 |
+
audio_resampler = T.Resample(48000, 16000)
|
26 |
+
return audio_model, audio_resampler
|
27 |
+
|
28 |
+
|
29 |
+
def init_weight(m):
|
30 |
+
if (
|
31 |
+
isinstance(m, nn.Conv1d)
|
32 |
+
or isinstance(m, nn.Linear)
|
33 |
+
or isinstance(m, nn.ConvTranspose1d)
|
34 |
+
):
|
35 |
+
nn.init.xavier_normal_(m.weight)
|
36 |
+
# m.bias.data.fill_(0.01)
|
37 |
+
if m.bias is not None:
|
38 |
+
nn.init.constant_(m.bias, 0)
|
39 |
+
|
40 |
+
|
41 |
+
# absolute positional embedding used for vanilla transformer sequential data
|
42 |
+
class PositionalEncoding(nn.Module):
|
43 |
+
def __init__(self, d_model, dropout=0.1, max_len=800, batch_first=False):
|
44 |
+
super().__init__()
|
45 |
+
self.batch_first = batch_first
|
46 |
+
|
47 |
+
self.dropout = nn.Dropout(p=dropout)
|
48 |
+
|
49 |
+
pe = torch.zeros(max_len, d_model)
|
50 |
+
position = torch.arange(0, max_len).unsqueeze(1)
|
51 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
|
52 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
53 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
54 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
55 |
+
|
56 |
+
self.register_buffer("pe", pe)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
if self.batch_first:
|
60 |
+
x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
|
61 |
+
else:
|
62 |
+
x = x + self.pe[: x.shape[0], :]
|
63 |
+
return self.dropout(x)
|
64 |
+
|
65 |
+
|
66 |
+
# very similar positional embedding used for diffusion timesteps
|
67 |
+
class SinusoidalPosEmb(nn.Module):
|
68 |
+
def __init__(self, dim):
|
69 |
+
super().__init__()
|
70 |
+
self.dim = dim
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
device = x.device
|
74 |
+
half_dim = self.dim // 2
|
75 |
+
emb = math.log(10000) / (half_dim - 1)
|
76 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
77 |
+
emb = x[:, None] * emb[None, :]
|
78 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
79 |
+
return emb
|
80 |
+
|
81 |
+
|
82 |
+
# dropout mask
|
83 |
+
def prob_mask_like(shape, prob, device):
|
84 |
+
if prob == 1:
|
85 |
+
return torch.ones(shape, device=device, dtype=torch.bool)
|
86 |
+
elif prob == 0:
|
87 |
+
return torch.zeros(shape, device=device, dtype=torch.bool)
|
88 |
+
else:
|
89 |
+
return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
|
90 |
+
|
91 |
+
|
92 |
+
def extract(a, t, x_shape):
|
93 |
+
b, *_ = t.shape
|
94 |
+
out = a.gather(-1, t)
|
95 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
96 |
+
|
97 |
+
|
98 |
+
def make_beta_schedule(
|
99 |
+
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
100 |
+
):
|
101 |
+
if schedule == "linear":
|
102 |
+
betas = (
|
103 |
+
torch.linspace(
|
104 |
+
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
|
105 |
+
)
|
106 |
+
** 2
|
107 |
+
)
|
108 |
+
|
109 |
+
elif schedule == "cosine":
|
110 |
+
timesteps = (
|
111 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
112 |
+
)
|
113 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
114 |
+
alphas = torch.cos(alphas).pow(2)
|
115 |
+
alphas = alphas / alphas[0]
|
116 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
117 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
118 |
+
|
119 |
+
elif schedule == "sqrt_linear":
|
120 |
+
betas = torch.linspace(
|
121 |
+
linear_start, linear_end, n_timestep, dtype=torch.float64
|
122 |
+
)
|
123 |
+
elif schedule == "sqrt":
|
124 |
+
betas = (
|
125 |
+
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
126 |
+
** 0.5
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
130 |
+
return betas.numpy()
|
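prob_mask_like implements the classifier-free-guidance dropout used in guide.py's forward pass: a per-sample Boolean keep-mask decides whether the real conditioning or a null embedding reaches the network. A compact sketch (the zeros tensor here stands in for the learned null_cond_embed parameter):

import torch
from model.utils import prob_mask_like

cond = torch.randn(4, 240, 512)                            # conditioning tokens
null_cond = torch.zeros(1, 240, 512)                       # placeholder for the learned null embedding
keep = prob_mask_like((4,), 1 - 0.2, device=cond.device)   # keep conditioning ~80% of the time
cond = torch.where(keep.view(4, 1, 1), cond, null_cond)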
model/vqvae.py
ADDED
@@ -0,0 +1,550 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from einops import rearrange, repeat
|
15 |
+
from utils.misc import broadcast_tensors
|
16 |
+
|
17 |
+
|
18 |
+
def setup_tokenizer(resume_pth: str) -> "TemporalVertexCodec":
|
19 |
+
args_path = os.path.dirname(resume_pth)
|
20 |
+
with open(os.path.join(args_path, "args.json")) as f:
|
21 |
+
trans_args = json.load(f)
|
22 |
+
tokenizer = TemporalVertexCodec(
|
23 |
+
n_vertices=trans_args["nb_joints"],
|
24 |
+
latent_dim=trans_args["output_emb_width"],
|
25 |
+
categories=trans_args["code_dim"],
|
26 |
+
residual_depth=trans_args["depth"],
|
27 |
+
)
|
28 |
+
print("loading checkpoint from {}".format(resume_pth))
|
29 |
+
ckpt = torch.load(resume_pth, map_location="cpu")
|
30 |
+
tokenizer.load_state_dict(ckpt["net"], strict=True)
|
31 |
+
for p in tokenizer.parameters():
|
32 |
+
p.requires_grad = False
|
33 |
+
tokenizer.cuda()
|
34 |
+
return tokenizer
|
35 |
+
|
36 |
+
|
37 |
+
def default(val, d):
|
38 |
+
return val if val is not None else d
|
39 |
+
|
40 |
+
|
41 |
+
def ema_inplace(moving_avg, new, decay: float):
|
42 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
43 |
+
|
44 |
+
|
45 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
46 |
+
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
47 |
+
|
48 |
+
|
49 |
+
def uniform_init(*shape: int):
|
50 |
+
t = torch.empty(shape)
|
51 |
+
nn.init.kaiming_uniform_(t)
|
52 |
+
return t
|
53 |
+
|
54 |
+
|
55 |
+
def sum_flat(tensor):
|
56 |
+
"""
|
57 |
+
Take the sum over all non-batch dimensions.
|
58 |
+
"""
|
59 |
+
return tensor.sum(dim=list(range(1, len(tensor.shape))))
|
60 |
+
|
61 |
+
|
62 |
+
def sample_vectors(samples, num: int):
|
63 |
+
num_samples, device = samples.shape[0], samples.device
|
64 |
+
|
65 |
+
if num_samples >= num:
|
66 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
67 |
+
else:
|
68 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
69 |
+
|
70 |
+
return samples[indices]
|
71 |
+
|
72 |
+
|
73 |
+
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
74 |
+
dim, dtype = samples.shape[-1], samples.dtype
|
75 |
+
|
76 |
+
means = sample_vectors(samples, num_clusters)
|
77 |
+
|
78 |
+
for _ in range(num_iters):
|
79 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
80 |
+
dists = -(diffs**2).sum(dim=-1)
|
81 |
+
|
82 |
+
buckets = dists.max(dim=-1).indices
|
83 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
84 |
+
zero_mask = bins == 0
|
85 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
86 |
+
|
87 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
88 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
89 |
+
new_means = new_means / bins_min_clamped[..., None]
|
90 |
+
|
91 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
92 |
+
|
93 |
+
return means, bins
|
94 |
+
|
95 |
+
|
96 |
+
class EuclideanCodebook(nn.Module):
|
97 |
+
"""Codebook with Euclidean distance.
|
98 |
+
Args:
|
99 |
+
dim (int): Dimension.
|
100 |
+
codebook_size (int): Codebook size.
|
101 |
+
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
102 |
+
If set to true, run the k-means algorithm on the first training batch and use
|
103 |
+
the learned centroids as initialization.
|
104 |
+
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
105 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
106 |
+
epsilon (float): Epsilon value for numerical stability.
|
107 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
108 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
109 |
+
randomly selected vector from the current batch.
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
dim: int,
|
115 |
+
codebook_size: int,
|
116 |
+
kmeans_init: int = False,
|
117 |
+
kmeans_iters: int = 10,
|
118 |
+
decay: float = 0.99,
|
119 |
+
epsilon: float = 1e-5,
|
120 |
+
threshold_ema_dead_code: int = 2,
|
121 |
+
):
|
122 |
+
super().__init__()
|
123 |
+
        self.decay = decay
        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        x = self.preprocess(x)
        embed_ind = self.quantize(x)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim=None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        )

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size
        self.l2_loss = lambda a, b: (a - b) ** 2

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        return quantize

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: B x dim input tensor
        :return: quantize: B x dim tensor containing reconstruction after quantization
                 embed_ind: B-dimensional tensor containing embedding indices
                 loss: scalar tensor containing commitment loss
        """
        device = x.device
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers: int, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(self, x, B, T, mask, n_q=None):
        """
        :param x: B x dim tensor
        :return: quantized_out: B x dim tensor
                 out_indices: B x n_q LongTensor containing indices for each quantizer
                 out_losses: scalar tensor containing commitment loss
        """
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = (
                residual - quantized
            )  # would need quantizer.detach() to have commitment gradients beyond the first quantizer, but this seems to harm performance
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        out_indices = torch.stack(all_indices, dim=-1)
        out_losses = torch.mean(torch.stack(all_losses))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q=None) -> torch.Tensor:
        """
        :param x: B x dim input tensor
        :return: B x n_q LongTensor containing indices for each quantizer
        """
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)  # indices = 16 x 8 = B x T
            # print(indices.shape, residual.shape, x.shape)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices, dim=-1)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """
        :param q_indices: B x n_q LongTensor containing indices for each quantizer
        :return: B x dim tensor containing reconstruction after quantization
        """
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        q_indices = q_indices.permute(1, 0).contiguous()
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out


class TemporalVertexEncoder(nn.Module):
    def __init__(
        self,
        n_vertices: int = 338,
        latent_dim: int = 128,
    ):
        super().__init__()
        self.input_dim = n_vertices
        self.enc = nn.Sequential(
            nn.Conv1d(self.input_dim, latent_dim, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=3),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
        )
        self.receptive_field = 8

    def forward(self, verts):
        """
        :param verts: B x T x n_vertices x 3 tensor containing batched sequences of vertices
        :return: B x T x latent_dim tensor containing the latent representation
        """
        if verts.dim() == 4:
            verts = verts.permute(0, 2, 3, 1).contiguous()
            verts = verts.view(verts.shape[0], self.input_dim, verts.shape[3])
        else:
            verts = verts.permute(0, 2, 1)
        verts = nn.functional.pad(verts, pad=[self.receptive_field - 1, 0])
        x = self.enc(verts)
        x = x.permute(0, 2, 1).contiguous()
        return x


class TemporalVertexDecoder(nn.Module):
    def __init__(
        self,
        n_vertices: int = 338,
        latent_dim: int = 128,
    ):
        super().__init__()
        self.output_dim = n_vertices
        self.project_mean_shape = nn.Linear(self.output_dim, latent_dim)
        self.dec = nn.Sequential(
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=3),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(latent_dim, self.output_dim, kernel_size=1),
        )
        self.receptive_field = 8

    def forward(self, x):
        """
        :param x: B x T x latent_dim tensor containing batched sequences of vertex encodings
        :return: B x T x n_vertices x 3 tensor containing batched sequences of vertices
        """
        x = x.permute(0, 2, 1).contiguous()
        x = nn.functional.pad(x, pad=[self.receptive_field - 1, 0])
        verts = self.dec(x)
        verts = verts.permute(0, 2, 1)
        return verts


class TemporalVertexCodec(nn.Module):
    def __init__(
        self,
        n_vertices: int = 338,
        latent_dim: int = 128,
        categories: int = 128,
        residual_depth: int = 4,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.categories = categories
        self.residual_depth = residual_depth
        self.n_clusters = categories
        self.encoder = TemporalVertexEncoder(
            n_vertices=n_vertices, latent_dim=latent_dim
        )
        self.decoder = TemporalVertexDecoder(
            n_vertices=n_vertices, latent_dim=latent_dim
        )
        self.quantizer = ResidualVectorQuantization(
            dim=latent_dim,
            codebook_size=categories,
            num_quantizers=residual_depth,
            decay=0.99,
            kmeans_init=True,
            kmeans_iters=10,
            threshold_ema_dead_code=2,
        )

    def predict(self, verts):
        """wrapper to provide compatibility with kmeans"""
        return self.encode(verts)

    def encode(self, verts):
        """
        :param verts: B x T x n_vertices x 3 tensor containing batched sequences of vertices
        :return: B x T x categories x residual_depth LongTensor containing quantized encodings
        """
        enc = self.encoder(verts)
        q = self.quantizer.encode(enc)
        return q

    def decode(self, q):
        """
        :param q: B x T x categories x residual_depth LongTensor containing quantized encodings
        :return: B x T x n_vertices x 3 tensor containing decoded vertices
        """
        reformat = q.dim() > 2
        if reformat:
            B, T, _ = q.shape
            q = q.reshape((-1, self.residual_depth))
        enc = self.quantizer.decode(q)
        if reformat:
            enc = enc.reshape((B, T, -1))
        verts = self.decoder(enc)
        return verts

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        # Calculate new centres
        code_onehot = torch.zeros(
            self.categories, code_idx.shape[0], device=code_idx.device
        )  # categories, N * L
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)  # categories
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    def forward(self, verts, mask=None):
        """
        :param verts: B x T x n_vertices x 3 tensor containing mesh sequences
        :return: verts: B x T x n_vertices x 3 tensor containing reconstructed mesh sequences
                 vq_loss: scalar tensor for vq commitment loss
        """
        B, T = verts.shape[0], verts.shape[1]
        x = self.encoder(verts)
        x, code_idx, vq_loss = self.quantizer(
            x.view(B * T, self.latent_dim), B, T, mask
        )
        perplexity = self.compute_perplexity(code_idx[:, -1].view((-1)))
        verts = self.decoder(x.view(B, T, self.latent_dim))
        verts = verts.reshape((verts.shape[0], verts.shape[1], -1))
        return verts, vq_loss, perplexity
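As a quick orientation to the classes above: TemporalVertexCodec chains the temporal encoder, the residual quantizer, and the temporal decoder, so a full round trip is encode-to-tokens followed by decode-to-vertices. The sketch below is illustrative only and not part of the diff; it assumes model/vqvae.py is importable with its default constructor arguments, and that trained weights (for example the shipped checkpoints/vq/c1_pose/net_iter300000.pth) would normally be loaded before the token ids carry any meaning.

# Illustrative round-trip through the codec; shapes follow the docstrings above.
# Without trained weights the token ids are arbitrary, but the shapes are representative.
import torch
from model.vqvae import TemporalVertexCodec

codec = TemporalVertexCodec(n_vertices=338, latent_dim=128, categories=128, residual_depth=4)
codec.eval()

verts = torch.randn(2, 16, 338)       # B x T x feature_dim fed to the temporal encoder
with torch.no_grad():
    tokens = codec.encode(verts)      # B x T x residual_depth LongTensor of codebook indices
    recon = codec.decode(tokens)      # B x T x feature_dim reconstruction
print(tokens.shape, recon.shape)      # torch.Size([2, 16, 4]) torch.Size([2, 16, 338])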
sample/generate.py
ADDED
@@ -0,0 +1,316 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import os

from typing import Callable, Dict, Union

import numpy as np
import torch
from data_loaders.get_data import get_dataset_loader, load_local_data
from diffusion.respace import SpacedDiffusion
from model.cfg_sampler import ClassifierFreeSampleModel
from model.diffusion import FiLMTransformer

from torch.utils.data import DataLoader
from utils.diff_parser_utils import generate_args
from utils.misc import fixseed, prGreen
from utils.model_util import create_model_and_diffusion, get_person_num, load_model


def _construct_template_variables(unconstrained: bool) -> (str,):
    row_file_template = "sample{:02d}.mp4"
    all_file_template = "samples_{:02d}_to_{:02d}.mp4"
    if unconstrained:
        sample_file_template = "row{:02d}_col{:02d}.mp4"
        sample_print_template = "[{} row #{:02d} column #{:02d} | -> {}]"
        row_file_template = row_file_template.replace("sample", "row")
        row_print_template = "[{} row #{:02d} | all columns | -> {}]"
        all_file_template = all_file_template.replace("samples", "rows")
        all_print_template = "[rows {:02d} to {:02d} | -> {}]"
    else:
        sample_file_template = "sample{:02d}_rep{:02d}.mp4"
        sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]'
        row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]'
        all_print_template = "[samples {:02d} to {:02d} | all repetitions | -> {}]"

    return (
        sample_print_template,
        row_print_template,
        all_print_template,
        sample_file_template,
        row_file_template,
        all_file_template,
    )


def _replace_keyframes(
    model_kwargs: Dict[str, Dict[str, torch.Tensor]],
    model: Union[FiLMTransformer, ClassifierFreeSampleModel],
) -> torch.Tensor:
    B, T = (
        model_kwargs["y"]["keyframes"].shape[0],
        model_kwargs["y"]["keyframes"].shape[1],
    )
    with torch.no_grad():
        tokens = model.transformer.generate(
            model_kwargs["y"]["audio"],
            T,
            layers=model.tokenizer.residual_depth,
            n_sequences=B,
        )
    tokens = tokens.reshape((B, -1, model.tokenizer.residual_depth))
    pred = model.tokenizer.decode(tokens).detach().cpu()
    assert (
        model_kwargs["y"]["keyframes"].shape == pred.shape
    ), f"{model_kwargs['y']['keyframes'].shape} vs {pred.shape}"
    return pred


def _run_single_diffusion(
    args,
    model_kwargs: Dict[str, Dict[str, torch.Tensor]],
    diffusion: SpacedDiffusion,
    model: Union[FiLMTransformer, ClassifierFreeSampleModel],
    inv_transform: Callable,
    gt: torch.Tensor,
) -> (torch.Tensor,):
    if args.data_format == "pose" and args.resume_trans is not None:
        model_kwargs["y"]["keyframes"] = _replace_keyframes(model_kwargs, model)

    sample_fn = diffusion.ddim_sample_loop
    with torch.no_grad():
        sample = sample_fn(
            model,
            (args.batch_size, model.nfeats, 1, args.curr_seq_length),
            clip_denoised=False,
            model_kwargs=model_kwargs,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
    sample = inv_transform(sample.cpu().permute(0, 2, 3, 1), args.data_format).permute(
        0, 3, 1, 2
    )
    curr_audio = inv_transform(model_kwargs["y"]["audio"].cpu().numpy(), "audio")
    keyframes = inv_transform(model_kwargs["y"]["keyframes"], args.data_format)
    gt_seq = inv_transform(gt.cpu().permute(0, 2, 3, 1), args.data_format).permute(
        0, 3, 1, 2
    )

    return sample, curr_audio, keyframes, gt_seq


def _generate_sequences(
    args,
    model_kwargs: Dict[str, Dict[str, torch.Tensor]],
    diffusion: SpacedDiffusion,
    model: Union[FiLMTransformer, ClassifierFreeSampleModel],
    test_data: torch.Tensor,
    gt: torch.Tensor,
) -> Dict[str, np.ndarray]:
    all_motions = []
    all_lengths = []
    all_audio = []
    all_gt = []
    all_keyframes = []

    for rep_i in range(args.num_repetitions):
        print(f"### Sampling [repetitions #{rep_i}]")
        # add CFG scale to batch
        if args.guidance_param != 1:
            model_kwargs["y"]["scale"] = (
                torch.ones(args.batch_size, device=args.device) * args.guidance_param
            )
        model_kwargs["y"] = {
            key: val.to(args.device) if torch.is_tensor(val) else val
            for key, val in model_kwargs["y"].items()
        }
        sample, curr_audio, keyframes, gt_seq = _run_single_diffusion(
            args, model_kwargs, diffusion, model, test_data.dataset.inv_transform, gt
        )
        all_motions.append(sample.cpu().numpy())
        all_audio.append(curr_audio)
        all_keyframes.append(keyframes.cpu().numpy())
        all_gt.append(gt_seq.cpu().numpy())
        all_lengths.append(model_kwargs["y"]["lengths"].cpu().numpy())

        print(f"created {len(all_motions) * args.batch_size} samples")

    return {
        "motions": np.concatenate(all_motions, axis=0),
        "audio": np.concatenate(all_audio, axis=0),
        "gt": np.concatenate(all_gt, axis=0),
        "lengths": np.concatenate(all_lengths, axis=0),
        "keyframes": np.concatenate(all_keyframes, axis=0),
    }


def _render_pred(
    args,
    data_block: Dict[str, torch.Tensor],
    sample_file_template: str,
    audio_per_frame: int,
) -> None:
    from visualize.render_codes import BodyRenderer

    face_codes = None
    if args.face_codes is not None:
        face_codes = np.load(args.face_codes, allow_pickle=True).item()
        face_motions = face_codes["motions"]
        face_gts = face_codes["gt"]
        face_audio = face_codes["audio"]

    config_base = f"./checkpoints/ca_body/data/{get_person_num(args.data_root)}"
    body_renderer = BodyRenderer(
        config_base=config_base,
        render_rgb=True,
    )

    for sample_i in range(args.num_samples):
        for rep_i in range(args.num_repetitions):
            idx = rep_i * args.batch_size + sample_i
            save_file = sample_file_template.format(sample_i, rep_i)
            animation_save_path = os.path.join(args.output_dir, save_file)
            # format data
            length = data_block["lengths"][idx]
            body_motion = (
                data_block["motions"][idx].transpose(2, 0, 1)[:length].squeeze(-1)
            )
            face_motion = face_motions[idx].transpose(2, 0, 1)[:length].squeeze(-1)
            assert np.array_equal(
                data_block["audio"][idx], face_audio[idx]
            ), "face audio is not the same"
            audio = data_block["audio"][idx, : length * audio_per_frame, :].T
            # set up render data block to pass into renderer
            render_data_block = {
                "audio": audio,
                "body_motion": body_motion,
                "face_motion": face_motion,
            }
            if args.render_gt:
                gt_body = data_block["gt"][idx].transpose(2, 0, 1)[:length].squeeze(-1)
                gt_face = face_gts[idx].transpose(2, 0, 1)[:length].squeeze(-1)
                render_data_block["gt_body"] = gt_body
                render_data_block["gt_face"] = gt_face
            body_renderer.render_full_video(
                render_data_block,
                animation_save_path,
                audio_sr=audio_per_frame * 30,
                render_gt=args.render_gt,
            )


def _reset_sample_args(args) -> None:
    # set the sequence length to match the one specified by user
    name = os.path.basename(os.path.dirname(args.model_path))
    niter = os.path.basename(args.model_path).replace("model", "").replace(".pt", "")
    args.curr_seq_length = (
        args.curr_seq_length
        if args.curr_seq_length is not None
        else args.max_seq_length
    )
    # add the resume predictor model path
    resume_trans_name = ""
    if args.data_format == "pose" and args.resume_trans is not None:
        resume_trans_parts = args.resume_trans.split("/")
        resume_trans_name = f"{resume_trans_parts[1]}_{resume_trans_parts[-1]}"
    # reformat the output directory
    args.output_dir = os.path.join(
        os.path.dirname(args.model_path),
        "samples_{}_{}_seed{}_{}".format(name, niter, args.seed, resume_trans_name),
    )
    assert (
        args.num_samples <= args.batch_size
    ), f"Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})"
    # set the batch size to match the number of samples to generate
    args.batch_size = args.num_samples


def _setup_dataset(args) -> DataLoader:
    data_root = args.data_root
    data_dict = load_local_data(
        data_root,
        audio_per_frame=1600,
        flip_person=args.flip_person,
    )
    test_data = get_dataset_loader(
        args=args,
        data_dict=data_dict,
        split="test",
        chunk=True,
    )
    return test_data


def _setup_model(
    args,
) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion):
    model, diffusion = create_model_and_diffusion(args, split_type="test")
    print(f"Loading checkpoints from [{args.model_path}]...")
    state_dict = torch.load(args.model_path, map_location="cpu")
    load_model(model, state_dict)

    if not args.unconstrained:
        assert args.guidance_param != 1

    if args.guidance_param != 1:
        prGreen("[CFS] wrapping model in classifier free sample")
        model = ClassifierFreeSampleModel(model)
    model.to(args.device)
    model.eval()
    return model, diffusion


def main():
    args = generate_args()
    fixseed(args.seed)
    _reset_sample_args(args)

    print("Loading dataset...")
    test_data = _setup_dataset(args)
    iterator = iter(test_data)

    print("Creating model and diffusion...")
    model, diffusion = _setup_model(args)

    if args.pose_codes is None:
        # generate sequences
        gt, model_kwargs = next(iterator)
        data_block = _generate_sequences(
            args, model_kwargs, diffusion, model, test_data, gt
        )
        os.makedirs(args.output_dir, exist_ok=True)
        npy_path = os.path.join(args.output_dir, "results.npy")
        print(f"saving results file to [{npy_path}]")
        np.save(npy_path, data_block)
    else:
        # load the pre generated results
        data_block = np.load(args.pose_codes, allow_pickle=True).item()

    # plot function only if face_codes exist and we are on pose prediction
    if args.plot:
        assert args.face_codes is not None, "need body and faces"
        assert (
            args.data_format == "pose"
        ), "currently only supporting plot on pose stuff"
        print(f"saving visualizations to [{args.output_dir}]...")
        _, _, _, sample_file_template, _, _ = _construct_template_variables(
            args.unconstrained
        )
        _render_pred(
            args,
            data_block,
            sample_file_template,
            test_data.dataset.audio_per_frame,
        )


if __name__ == "__main__":
    main()
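Since main() above serializes the generated block with np.save, the file can be reopened the same way _render_pred reopens face_codes. The snippet below is a hypothetical inspection helper, not part of the repository; the path is a placeholder for whatever output_dir the run produced, and the keys mirror the dict returned by _generate_sequences.

# Hypothetical inspection of a results.npy produced by sample/generate.py.
import numpy as np

data_block = np.load("path/to/output_dir/results.npy", allow_pickle=True).item()
for key in ("motions", "audio", "gt", "lengths", "keyframes"):
    print(key, np.asarray(data_block[key]).shape)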
scripts/download_alldatasets.sh
ADDED
@@ -0,0 +1,6 @@
for i in "PXB184" "RLW104" "TXB805" "GQS883"
do
    curl -L https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/${i}.zip -o ${i}.zip || { echo 'downloading dataset failed' ; exit 1; }
    unzip ${i}.zip -d dataset/
    rm ${i}.zip
done
scripts/download_allmodels.sh
ADDED
@@ -0,0 +1,13 @@
for i in "PXB184" "RLW104" "TXB805" "GQS883"
do
    # download motion models
    wget http://audio2photoreal_models.berkeleyvision.org/${i}_models.tar || { echo 'downloading model failed' ; exit 1; }
    tar xvf ${i}_models.tar
    rm ${i}_models.tar

    # download ca body rendering checkpoints and assets
    mkdir -p checkpoints/ca_body/data/
    wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/${i}.tar.gz || { echo 'downloading ca body model failed' ; exit 1; }
    tar xvf ${i}.tar.gz --directory checkpoints/ca_body/data/
    rm ${i}.tar.gz
done
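After the two download scripts above have run, the motion checkpoints and the ca_body rendering assets should sit under checkpoints/. A hypothetical sanity check, not part of the repository, using paths that also appear in this upload's file listing:

# Hypothetical check that the download scripts produced the expected layout.
import os

for path in (
    "checkpoints/ca_body/data/PXB184/body_dec.ckpt",
    "checkpoints/ca_body/data/PXB184/config.yml",
    "checkpoints/diffusion/c1_pose/args.json",
):
    print(path, "OK" if os.path.exists(path) else "missing")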
scripts/download_prereq.sh
ADDED
@@ -0,0 +1,9 @@

# install the prerequisite asset models (lip regressor and wav2vec)
wget http://audio2photoreal_models.berkeleyvision.org/asset_models.tar
tar xvf asset_models.tar
rm asset_models.tar

# we obtained the wav2vec models via these links:
# wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt -P ./assets/
# wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt -P ./assets/
scripts/installation.sh
ADDED
@@ -0,0 +1,4 @@
# download the prerequisite asset models (lip regressor and wav2vec)
wget http://audio2photoreal_models.berkeleyvision.org/asset_models.tar
tar xvf asset_models.tar
rm asset_models.tar
scripts/requirements.txt
ADDED
@@ -0,0 +1,17 @@
attrdict
blobfile
einops
fairseq
gradio
matplotlib
mediapy
numpy==1.23.0
opencv-python
packaging
scikit-learn
tensorboard
tensorboardX
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
tqdm
train/train_diffusion.py
ADDED
@@ -0,0 +1,83 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import json
import os

import torch
import torch.multiprocessing as mp

from data_loaders.get_data import get_dataset_loader, load_local_data
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from train.train_platforms import ClearmlPlatform, NoPlatform, TensorboardPlatform
from train.training_loop import TrainLoop
from utils.diff_parser_utils import train_args
from utils.misc import cleanup, fixseed, setup_dist
from utils.model_util import create_model_and_diffusion


def main(rank: int, world_size: int):
    args = train_args()
    fixseed(args.seed)
    train_platform_type = eval(args.train_platform_type)
    train_platform = train_platform_type(args.save_dir)
    train_platform.report_args(args, name="Args")
    setup_dist(args.device)

    if rank == 0:
        if args.save_dir is None:
            raise FileNotFoundError("save_dir was not specified.")
        elif os.path.exists(args.save_dir) and not args.overwrite:
            raise FileExistsError("save_dir [{}] already exists.".format(args.save_dir))
        elif not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        args_path = os.path.join(args.save_dir, "args.json")
        with open(args_path, "w") as fw:
            json.dump(vars(args), fw, indent=4, sort_keys=True)

    if not os.path.exists(args.data_root):
        args.data_root = args.data_root.replace("/home/", "/derived/")

    data_dict = load_local_data(args.data_root, audio_per_frame=1600)
    print("creating data loader...")
    data = get_dataset_loader(args=args, data_dict=data_dict)

    print("creating logger...")
    writer = SummaryWriter(args.save_dir)

    print("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(args, split_type="train")
    model.to(rank)

    if world_size > 1:
        model = DDP(
            model, device_ids=[rank], output_device=rank, find_unused_parameters=True
        )

    params = (
        model.module.parameters_w_grad()
        if world_size > 1
        else model.parameters_w_grad()
    )
    print("Total params: %.2fM" % (sum(p.numel() for p in params) / 1000000.0))
    print("Training...")

    TrainLoop(
        args, train_platform, model, diffusion, data, writer, rank, world_size
    ).run_loop()
    train_platform.close()
    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"using {world_size} gpus")
    if world_size > 1:
        mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
    else:
        main(rank=0, world_size=1)
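One practical consequence of the json.dump call in main() above: every training run leaves a plain args.json next to its checkpoints, and the released model directories in this upload (for example checkpoints/diffusion/c1_pose/args.json) ship the same kind of file. A hypothetical way to read it back, not part of the repository:

# Hypothetical sketch: inspect the training arguments saved by train_diffusion.py.
import json

with open("checkpoints/diffusion/c1_pose/args.json") as f:
    saved_args = json.load(f)
print(sorted(saved_args))  # argument names recorded at training time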