whyun13 committed
Commit 882f6e2 · verified · 1 Parent(s): 828a2a7

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +4 -0
  3. CODE_OF_CONDUCT.md +80 -0
  4. CONTRIBUTING.md +31 -0
  5. LICENSE +400 -0
  6. README.md +373 -7
  7. assets/demo1.gif +0 -0
  8. assets/demo2.gif +3 -0
  9. assets/render_defaults_GQS883.pth +3 -0
  10. assets/render_defaults_PXB184.pth +3 -0
  11. assets/render_defaults_RLW104.pth +3 -0
  12. assets/render_defaults_TXB805.pth +3 -0
  13. checkpoints/ca_body/data/PXB184/body_dec.ckpt +3 -0
  14. checkpoints/ca_body/data/PXB184/config.yml +56 -0
  15. checkpoints/diffusion/c1_face/args.json +34 -0
  16. checkpoints/diffusion/c1_pose/args.json +66 -0
  17. checkpoints/guide/c1_pose/args.json +41 -0
  18. checkpoints/vq/c1_pose/args.json +43 -0
  19. checkpoints/vq/c1_pose/net_iter300000.pth +3 -0
  20. data_loaders/data.py +253 -0
  21. data_loaders/get_data.py +129 -0
  22. data_loaders/tensors.py +86 -0
  23. demo/.ipynb_checkpoints/demo-checkpoint.py +276 -0
  24. demo/demo.py +276 -0
  25. demo/install.sh +20 -0
  26. demo/requirements.txt +17 -0
  27. diffusion/fp16_util.py +250 -0
  28. diffusion/gaussian_diffusion.py +1273 -0
  29. diffusion/losses.py +83 -0
  30. diffusion/nn.py +213 -0
  31. diffusion/resample.py +168 -0
  32. diffusion/respace.py +145 -0
  33. flagged/audio/b90d90dbca93f47e8d01/audio.wav +0 -0
  34. flagged/audio/d8e03e2e6deae2f981b1/audio.wav +0 -0
  35. flagged/log.csv +4 -0
  36. model/cfg_sampler.py +33 -0
  37. model/diffusion.py +403 -0
  38. model/guide.py +222 -0
  39. model/modules/audio_encoder.py +194 -0
  40. model/modules/rotary_embedding_torch.py +139 -0
  41. model/modules/transformer_modules.py +702 -0
  42. model/utils.py +130 -0
  43. model/vqvae.py +550 -0
  44. sample/generate.py +316 -0
  45. scripts/download_alldatasets.sh +6 -0
  46. scripts/download_allmodels.sh +13 -0
  47. scripts/download_prereq.sh +9 -0
  48. scripts/installation.sh +4 -0
  49. scripts/requirements.txt +17 -0
  50. train/train_diffusion.py +83 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo2.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ *.pyc
2
+ *.pt
3
+ !dataset/*/data_stats.pth
4
+ dataset
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to audio2photoreal
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Meta's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to audio2photoreal, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,400 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
400
+
README.md CHANGED
@@ -1,12 +1,378 @@
1
  ---
2
- title: Test Virtual
3
- emoji: 📈
4
- colorFrom: gray
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.38.1
8
- app_file: app.py
9
- pinned: false
10
  ---
 
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: test_virtual
3
+ app_file: ./demo/demo.py
 
 
4
  sdk: gradio
5
  sdk_version: 4.38.1
 
 
6
  ---
7
+ # From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations
8
+ This repository contains a pytorch implementation of ["From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations"](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal/)
9
 
10
+ :hatching_chick: **Try out our demo [here](https://colab.research.google.com/drive/1lnX3d-3T3LaO3nlN6R8s6pPvVNAk5mdK?usp=sharing)** or continue following the steps below to run code locally!
11
+ And thanks to everyone for the support via contributions, comments, and issues!
12
+
13
+ https://github.com/facebookresearch/audio2photoreal/assets/17986358/5cba4079-275e-48b6-aecc-f84f3108c810
14
+
15
+ This codebase provides:
16
+ - train code
17
+ - test code
18
+ - pretrained motion models
19
+ - access to dataset
20
+
21
+ If you use the dataset or code, please cite our [Paper](https://arxiv.org/abs/2401.01885)
22
+
23
+ ```
24
+ @inproceedings{ng2024audio2photoreal,
25
+ title={From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations},
26
+ author={Ng, Evonne and Romero, Javier and Bagautdinov, Timur and Bai, Shaojie and Darrell, Trevor and Kanazawa, Angjoo and Richard, Alexander},
27
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
28
+ year={2024}
29
+ }
30
+ ```
31
+
32
+ ### Repository Contents
33
+
34
+ - [**Quickstart:**](#quickstart) easy gradio demo that lets you record audio and render a video
35
+ - [**Installation:**](#installation) environment setup and installation (for more details on the rendering pipeline, please refer to [Codec Avatar Body](https://github.com/facebookresearch/ca_body))
36
+ - [**Download data and models:**](#download-data-and-models) download annotations and pre-trained models
37
+ - [Dataset desc.](#dataset): description of dataset annotations
38
+ - [Visualize Dataset](#visualize-ground-truth): script for visualizing ground truth annotations
39
+ - [Model desc.](#pretrained-models): description of pretrained models
40
+ - [**Running the pretrained models:**](#running-the-pretrained-models) how to generate results files and visualize the results using the rendering pipeline.
41
+ - [Face generation](#face-generation): commands to generate the results file for the faces
42
+ - [Body generation](#body-generation): commands to generate the results file for the bodies
43
+ - [Visualization](#visualization): how to call into the rendering api. For full details, please refer to [this repo](https://github.com/facebookresearch/ca_body).
44
+ - [**Training from scratch (3 models):**](#training-from-scratch) scripts to get the training pipeline running from scratch for face, guide poses, and body models.
45
+ - [Face diffusion model](#1-face-diffusion-model)
46
+ - [Body diffusion](#2-body-diffusion-model)
47
+ - [Body vq vae](#3-body-vq-vae)
48
+ - [Body guide transformer](#4-body-guide-transformer)
49
+
50
+ We annotate code that you can directly copy and paste into your terminal using the :point_down: icon.
51
+
52
+ # Quickstart
53
+ With this demo, you can record an audio clip and select the number of samples you want to generate.
54
+
55
+ Make sure you have CUDA 11.7 and gcc/g++ 9.0 for pytorch3d compatibility.
56
+
57
+ :point_down: Install necessary components. This will do the environment configuration and install the corresponding rendering assets, prerequisite models, and pretrained models:
58
+ ```
59
+ conda create --name a2p_env python=3.9
60
+ conda activate a2p_env
61
+ sh demo/install.sh
62
+ ```
63
+ :point_down: Run the demo. You can record your audio and then render corresponding results!
64
+ ```
65
+ python -m demo.demo
66
+ ```
67
+
68
+ :microphone: First, record your audio
69
+
70
+ ![](assets/demo1.gif)
71
+
72
+ :hourglass: Hold tight because the rendering can take a while!
73
+
74
+ You can change the number of samples (1-10) you want to generate, and download your favorite video by clicking on the download button on the top right of each video.
75
+
76
+ ![](assets/demo2.gif)
77
+
78
+ # Installation
79
+ The code has been tested with CUDA 11.7, Python 3.9, and gcc/g++ 9.0.
80
+
81
+ :point_down: If you haven't done so already via the demo setup, configure the environments and download prerequisite models:
82
+ ```
83
+ conda create --name a2p_env python=3.9
84
+ conda activate a2p_env
85
+ pip install -r scripts/requirements.txt
86
+ sh scripts/download_prereq.sh
87
+ ```
88
+ :point_down: To get the rendering working, please also make sure you install [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md).
89
+ ```
90
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git"
91
+ ```
92
+ Please see [CA Bodies repo](https://github.com/facebookresearch/ca_body) for more details on the renderer.
93
+
94
+ # Download data and models
95
+ To download any of the datasets, you can find them at `https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/<person_id>.zip`, where you can replace `<person_id>` with any of `PXB184`, `RLW104`, `TXB805`, or `GQS883`.
96
+ Downloading over the command line can be done with these commands:
97
+ ```
98
+ curl -L https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/<person_id>.zip -o <person_id>.zip
99
+ unzip <person_id>.zip -d dataset/
100
+ rm <person_id>.zip
101
+ ```
102
+ :point_down: To download *all* of the datasets, you can simply run the following, which will download and unpack all of them.
103
+ ```
104
+ sh scripts/download_alldatasets.sh
105
+ ```
106
+
107
+ Similarly, to download any of the models, you can find them at `http://audio2photoreal_models.berkeleyvision.org/<person_id>_models.tar`.
108
+ ```
109
+ # download the motion generation
110
+ wget http://audio2photoreal_models.berkeleyvision.org/<person_id>_models.tar
111
+ tar xvf <person_id>_models.tar
112
+ rm <person_id>_models.tar
113
+
114
+ # download the body decoder/rendering assets and place them in the right place
115
+ mkdir -p checkpoints/ca_body/data/
116
+ wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/<person_id>.tar.gz
117
+ tar xvf <person_id>.tar.gz --directory checkpoints/ca_body/data/
118
+ rm <person_id>.tar.gz
119
+ ```
120
+ :point_down: You can also download all of the models with this script:
121
+ ```
122
+ sh scripts/download_allmodels.sh
123
+ ```
124
+ The above model script will download *both* the models for motion generation and the body decoder/rendering models. Please view the script for more details.
125
+
126
+ ### Dataset
127
+ Once the dataset is downloaded and unzipped (via `scripts/download_alldatasets.sh`), it should unpack into the following directory structure:
128
+ ```
129
+ |-- dataset/
130
+ |-- PXB184/
131
+ |-- data_stats.pth
132
+ |-- scene01_audio.wav
133
+ |-- scene01_body_pose.npy
134
+ |-- scene01_face_expression.npy
135
+ |-- scene01_missing_face_frames.npy
136
+ |-- ...
137
+ |-- scene30_audio.wav
138
+ |-- scene30_body_pose.npy
139
+ |-- scene30_face_expression.npy
140
+ |-- scene30_missing_face_frames.npy
141
+ |-- RLW104/
142
+ |-- TXB805/
143
+ |-- GQS883/
144
+ ```
145
+ Each of the four participants (`PXB184`, `RLW104`, `TXB805`, `GQS883`) should have independent "scenes" (1 to 26 or so).
146
+ For each scene, there are 3 types of data annotations that we save.
147
+ ```
148
+ *audio.wav: wave file containing the raw two-channel audio (1600*T samples) at 48kHz, i.e. 1600 audio samples per 30 fps motion frame; channel 0 is the audio associated with the current person, channel 1 is the audio associated with their conversational partner.
149
+
150
+ *body_pose.npy: (T x 104) array of joint angles in a kinematic skeleton. Not all of the joints are represented with 3DoF. Each 104-d vector can be used to reconstruct a full-body skeleton.
151
+
152
+ *face_expression.npy: (T x 256) array of facial codes, where each 256-d vector reconstructs a face mesh.
153
+
154
+ *missing_face_frames.npy: List of indices (t) where the facial code is missing or corrupted.
155
+
156
+ data_stats.pth: carries the mean and std for each modality of each person.
157
+ ```
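As a quick sanity check on these formats, below is a minimal loading sketch. The file names and array shapes come from the description above; the exact contents of `data_stats.pth` and the audio-to-frame alignment (48kHz / 30 fps = 1600 audio samples per motion frame, matching `audio_per_frame` in `data_loaders/data.py`) are assumptions you should verify against the dataloader code.

```python
import numpy as np
import torch
from scipy.io import wavfile  # assumes scipy is available in your environment

root = "dataset/PXB184"   # any downloaded participant
scene = "scene01"         # hypothetical scene id

# (T x 104) joint angles, (T x 256) facial codes, and missing-frame indices
pose = np.load(f"{root}/{scene}_body_pose.npy")
face = np.load(f"{root}/{scene}_face_expression.npy")
missing = np.load(f"{root}/{scene}_missing_face_frames.npy")

# two-channel 48kHz audio; at 30 fps this is roughly 1600 audio samples per motion frame
sr, audio = wavfile.read(f"{root}/{scene}_audio.wav")
print("pose", pose.shape, "face", face.shape, "missing", missing.shape)
print("audio", audio.shape, "at", sr, "Hz ->", audio.shape[0] // 1600, "frames vs", pose.shape[0])

# per-person, per-modality normalization statistics (keys are not documented here)
stats = torch.load(f"{root}/data_stats.pth")
print(type(stats), list(stats.keys()) if isinstance(stats, dict) else stats)
```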
158
+
159
+ For the train/val/test split the indices are defined in `data_loaders/data.py` as:
160
+ ```
161
+ train_idx = list(range(0, len(data_dict["data"]) - 6))
162
+ val_idx = list(range(len(data_dict["data"]) - 6, len(data_dict["data"]) - 4))
163
+ test_idx = list(range(len(data_dict["data"]) - 4, len(data_dict["data"])))
164
+ ```
165
+ for any of the four dataset participants we train on (i.e., the last 4 sequences are held out for test and the 2 before them for validation).
166
+
167
+ ### Visualize ground truth
168
+ If you've properly installed the rendering requirements, you can then visualize the full dataset with the following command:
169
+ ```
170
+ python -m visualize.render_anno
171
+ --save_dir <path/to/save/dir>
172
+ --data_root <path/to/data/root>
173
+ --max_seq_length <num>
174
+ ```
175
+
176
+ The videos will be chunked into lengths according to the `--max_seq_length` arg, which you can specify (the default is 600).
177
+
178
+ :point_down: For example, to visualize ground truth annotations for `PXB184`, you can run the following.
179
+ ```
180
+ python -m visualize.render_anno --save_dir vis_anno_test --data_root dataset/PXB184 --max_seq_length 600
181
+ ```
182
+
183
+ ### Pretrained models
184
+ We train person-specific models, so each person should have an associated directory. For instance, for `PXB184`, their complete models should unzip into the following structure.
185
+ ```
186
+ |-- checkpoints/
187
+ |-- diffusion/
188
+ |-- c1_face/
189
+ |-- args.json
190
+ |-- model:09d.pt
191
+ |-- c1_pose/
192
+ |-- args.json
193
+ |-- model:09d.pt
194
+ |-- guide/
195
+ |-- c1_pose/
196
+ |-- args.json
197
+ |-- checkpoints/
198
+ |-- iter-:07d.pt
199
+ |-- vq/
200
+ |-- c1_pose/
201
+ |-- args.json
202
+ |-- net_iter:06d.pth
203
+ ```
204
+ There are 4 models for each person and each model has an associated `args.json`.
205
+ 1. a face diffusion model that outputs 256 facial codes conditioned on audio
206
+ 2. a pose diffusion model that outputs 104 joint rotations conditioned on audio and guide poses
207
+ 3. a guide vq pose model that outputs vq tokens conditioned on audio at 1 fps
208
+ 4. a vq encoder-decoder model that vector quantizes the continuous 104-d pose space.
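To keep these four pieces straight, here is a conceptual sketch of how they fit together at inference time. The functions below are stand-ins that only mimic the input/output shapes described above; the real entry point is `sample.generate`, so treat this purely as a mental model rather than the repo's API.

```python
import torch

T = 600  # motion frames at 30 fps

# Illustrative stand-ins for the four pretrained networks (not the repo's API).
def face_diffusion(audio):         return torch.randn(T, 256)                 # 1) audio -> facial codes
def pose_diffusion(audio, guide):  return torch.randn(T, 104)                 # 2) audio + guide -> joint rotations
def guide_vq_transformer(audio):   return torch.randint(0, 1024, (T // 30,))  # 3) audio -> vq tokens at 1 fps
def vq_decoder(tokens):            return torch.randn(tokens.shape[0], 104)   # 4) vq tokens -> coarse guide poses

audio = torch.randn(T * 1600)  # 48kHz audio for the current person (channel 0)

face_codes = face_diffusion(audio)
guide_poses = vq_decoder(guide_vq_transformer(audio))
body_pose = pose_diffusion(audio, guide_poses)

# face_codes and body_pose are what the photoreal renderer (ca_body) consumes.
print(face_codes.shape, body_pose.shape)
```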
209
+
210
+ # Running the pretrained models
211
+ To run the pretrained models, you will need to generate the associated results files before visualizing them.
212
+
213
+ ### Face generation
214
+ To generate the results file for the face,
215
+ ```
216
+ python -m sample.generate
217
+ --model_path <path/to/model>
218
+ --num_samples <xsamples>
219
+ --num_repetitions <xreps>
220
+ --timestep_respacing ddim500
221
+ --guidance_param 10.0
222
+ ```
223
+
224
+ The `<path/to/model>` should be the path to the diffusion model that is associated with generating the face.
225
+ E.g., for participant `PXB184`, the path might be `./checkpoints/diffusion/c1_face/model000155000.pt`.
226
+ The other parameters are:
227
+ ```
228
+ --num_samples: number of samples to generate. To sample the full dataset, use 56 (except for TXB805, which is 58).
229
+ --num_repetitions: number of times to repeat the sampling, such that total number of sequences generated is (num_samples * num_repetitions).
230
+ --timestep_respacing: how many diffusion steps to take. Format will always be ddim<number>.
231
+ --guidance_param: how influential the conditioning is on the results. I usually use a range of 2.0-10.0, tending towards higher values for the face.
232
+ ```
233
+
234
+ :point_down: A full example of running the face model for `PXB184` with the provided pretrained models would then be:
235
+ ```
236
+ python -m sample.generate --model_path checkpoints/diffusion/c1_face/model000155000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 10.0
237
+ ```
238
+ This generates 10 samples from the dataset with 5 repetitions each (50 sequences total). The output results file will be saved to:
239
+ `./checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy`
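If you want to inspect the generated file before rendering, a minimal sketch is below. It only assumes that `results.npy` is a pickled NumPy object (the usual convention for this kind of results dump); the authoritative format is whatever `sample.generate` writes, so check that script for the exact keys.

```python
import numpy as np

path = "checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy"
results = np.load(path, allow_pickle=True)

# np.load wraps a pickled dict in a 0-d object array; unwrap it if needed.
if results.dtype == object and results.shape == ():
    results = results.item()

if isinstance(results, dict):
    for key, value in results.items():
        print(key, type(value).__name__, getattr(value, "shape", None))
else:
    print(type(results).__name__, getattr(results, "shape", None))
```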
240
+
241
+ ### Body generation
242
+ Generating the corresponding body is very similar to generating the face, except that we now also have to feed in the model that generates the guide poses.
243
+ ```
244
+ python -m sample.generate
245
+ --model_path <path/to/model>
246
+ --resume_trans <path/to/guide/model>
247
+ --num_samples <xsamples>
248
+ --num_repetitions <xreps>
249
+ --timestep_respacing ddim500
250
+ --guidance_param 2.0
251
+ ```
252
+
253
+ :point_down: Here, `<path/to/guide/model>` should point to the guide transformer. The full command would be:
254
+ ```
255
+ python -m sample.generate --model_path checkpoints/diffusion/c1_pose/model000340000.pt --resume_trans checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 2.0
256
+ ```
257
+ Similarly, the output will be saved to:
258
+ `./checkpoints/diffusion/c1_pose/samples_c1_pose_000340000_seed10_guide_iter-0100000.pt/results.npy`
259
+
260
+ ### Visualization
261
+ On the body generation side of things, you can also optionally pass in the `--plot` flag in order to render out the photorealistic avatar. You will also need to pass in the corresponding generated face codes with the `--face_codes` flag.
262
+ Optionally, if you already have the poses precomputed, you can also pass in the generated body with the `--pose_codes` flag.
263
+ This will save videos in the same directory as where the body's `results.npy` is stored.
264
+
265
+ :point_down: An example of the full command with *the three new flags added* is:
266
+ ```
267
+ python -m sample.generate --model_path checkpoints/diffusion/c1_pose/model000340000.pt --resume_trans checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt --num_samples 10 --num_repetitions 5 --timestep_respacing ddim500 --guidance_param 2.0 --face_codes ./checkpoints/diffusion/c1_face/samples_c1_face_000155000_seed10_/results.npy --pose_codes ./checkpoints/diffusion/c1_pose/samples_c1_pose_000340000_seed10_guide_iter-0100000.pt/results.npy --plot
268
+ ```
269
+ The remaining flags can be the same as before. For the actual rendering api, please see [Codec Avatar Body](https://github.com/facebookresearch/ca_body) for installation etc.
270
+ *Important: in order to visualize the full photorealistic avatar, you will need to run the face codes first, then pass them into the body generation code.* It will not work if you try to call generate with `--plot` for the face codes.
271
+
272
+ # Training from scratch
273
+ There are four models you will need to train: 1) the face diffusion model, 2) the body diffusion model, 3) the body VQ VAE, and 4) the body guide transformer.
274
+ The only dependency is that 3) is needed for 4). All other models can be trained in parallel.
275
+
276
+ ### 1) Face diffusion model
277
+ To train the face model, you will need to run the following script:
278
+ ```
279
+ python -m train.train_diffusion
280
+ --save_dir <path/to/save/dir>
281
+ --data_root <path/to/data/root>
282
+ --batch_size <bs>
283
+ --dataset social
284
+ --data_format face
285
+ --layers 8
286
+ --heads 8
287
+ --timestep_respacing ''
288
+ --max_seq_length 600
289
+ ```
290
+ Importantly, a few of the flags are as follows:
291
+ ```
292
+ --save_dir: path to directory where all outputs are stored
293
+ --data_root: path to the directory of where to load the data from
294
+ --dataset: name of dataset to load; right now we only support the 'social' dataset
295
+ --data_format: set to 'face' for the face, as opposed to pose
296
+ --timestep_respacing: set to '' which does the default spacing of 1k diffusion steps
297
+ --max_seq_length: the maximum number of frames for a given sequence to train on
298
+ ```
299
+ :point_down: A full example for training on person `PXB184` is:
300
+ ```
301
+ python -m train.train_diffusion --save_dir checkpoints/diffusion/c1_face_test --data_root ./dataset/PXB184/ --batch_size 4 --dataset social --data_format face --layers 8 --heads 8 --timestep_respacing '' --max_seq_length 600
302
+ ```
303
+
304
+ ### 2) Body diffusion model
305
+ Training the body model is similar to training the face model, but with the following additional parameters:
306
+ ```
307
+ python -m train.train_diffusion
308
+ --save_dir <path/to/save/dir>
309
+ --data_root <path/to/data/root>
310
+ --lambda_vel <num>
311
+ --batch_size <bs>
312
+ --dataset social
313
+ --add_frame_cond 1
314
+ --data_format pose
315
+ --layers 6
316
+ --heads 8
317
+ --timestep_respacing ''
318
+ --max_seq_length 600
319
+ ```
320
+ The flags that differ from the face training are as follows:
321
+ ```
322
+ --lambda_vel: additional auxiliary loss for training with velocity
323
+ --add_frame_cond: set to '1' for 1 fps. If not specified, it will default to 30 fps.
324
+ --data_format: set to 'pose' for the body, as opposed to face
325
+ ```
326
+ :point_down: A full example for training on person `PXB184` is:
327
+ ```
328
+ python -m train.train_diffusion --save_dir checkpoints/diffusion/c1_pose_test --data_root ./dataset/PXB184/ --lambda_vel 2.0 --batch_size 4 --dataset social --add_frame_cond 1 --data_format pose --layers 6 --heads 8 --timestep_respacing '' --max_seq_length 600
329
+ ```
330
+
331
+ ### 3) Body VQ VAE
332
+ To train a vq encoder-decoder, you will need to run the following script:
333
+ ```
334
+ python -m train.train_vq
335
+ --out_dir <path/to/out/dir>
336
+ --data_root <path/to/data/root>
337
+ --batch_size <bs>
338
+ --lr 1e-3
339
+ --code_dim 1024
340
+ --output_emb_width 64
341
+ --depth 4
342
+ --dataname social
343
+ --loss_vel 0.0
344
+ --add_frame_cond 1
345
+ --data_format pose
346
+ --max_seq_length 600
347
+ ```
348
+ :point_down: For person `PXB184`, it would be:
349
+ ```
350
+ python -m train.train_vq --out_dir checkpoints/vq/c1_vq_test --data_root ./dataset/PXB184/ --lr 1e-3 --code_dim 1024 --output_emb_width 64 --depth 4 --dataname social --loss_vel 0.0 --data_format pose --batch_size 4 --add_frame_cond 1 --max_seq_length 600
351
+ ```
352
+
353
+ ### 4) Body guide transformer
354
+ Once you have the vq trained from 3) you can then pass it in to train the body guide pose transformer:
355
+ ```
356
+ python -m train.train_guide
357
+ --out_dir <path/to/out/dir>
358
+ --data_root <path/to/data/root>
359
+ --batch_size <bs>
360
+ --resume_pth <path/to/vq/model>
361
+ --add_frame_cond 1
362
+ --layers 6
363
+ --lr 2e-4
364
+ --gn
365
+ --dim 64
366
+ ```
367
+ :point_down: For person `PXB184`, it would be:
368
+ ```
369
+ python -m train.train_guide --out_dir checkpoints/guide/c1_trans_test --data_root ./dataset/PXB184/ --batch_size 4 --resume_pth checkpoints/vq/c1_vq_test/net_iter300000.pth --add_frame_cond 1 --layers 6 --lr 2e-4 --gn --dim 64
370
+ ```
371
+
372
+ After training these 4 models, you can now follow the ["Running the pretrained models"](#running-the-pretrained-models) section to generate samples and visualize results.
373
+
374
+ You can also visualize the corresponding ground truth sequences by passing in the `--render_gt` flag.
375
+
376
+
377
+ # License
378
+ The code and dataset are released under the [CC BY-NC 4.0 International license](https://github.com/facebookresearch/audio2photoreal/blob/main/LICENSE).
assets/demo1.gif ADDED
assets/demo2.gif ADDED

Git LFS Details

  • SHA256: 4d07d3817b4a23bdb0a36869a469d051b9b10fe68d9e6f02f6cc8765cd6f5bc3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
assets/render_defaults_GQS883.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ae7ee73849e258bbb8d8a04aa674960896fc1dff8757fefbd2df1685225dd7d
3
+ size 71354547
assets/render_defaults_PXB184.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c86ba14a58d4829c8d05428f5e601072dc4bab1bdc60bc53ce6c73990e9b97d7
3
+ size 71354547
assets/render_defaults_RLW104.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:808a3fbf33115d3cc132bad48c2e95bfca29bb1847d912b1f72e5e5b4a081db5
3
+ size 71354547
assets/render_defaults_TXB805.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7985c79edfba70f83f560859f2ce214d9779a46031aa8ca6a917d8fd4417e24
3
+ size 71354547
checkpoints/ca_body/data/PXB184/body_dec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26394ae03c1726b7c90b5633696d0eea733a3c5e423893c4e79b490c80c35ddf
3
+ size 893279810
checkpoints/ca_body/data/PXB184/config.yml ADDED
@@ -0,0 +1,56 @@
1
+
2
+ model:
3
+ class_name: ca_body.models.mesh_vae_drivable.AutoEncoder
4
+
5
+ encoder:
6
+ n_embs: 1024
7
+ noise_std: 1.0
8
+
9
+ encoder_face:
10
+ n_embs: 256
11
+ noise_std: 1.0
12
+
13
+ decoder_face:
14
+ n_latent: 256
15
+ n_vert_out: 21918
16
+
17
+ decoder:
18
+ init_uv_size: 64
19
+ n_init_channels: 64
20
+ n_min_channels: 4
21
+ n_pose_dims: 98
22
+ n_pose_enc_channels: 16
23
+ n_embs: 1024
24
+ n_embs_enc_channels: 32
25
+ n_face_embs: 256
26
+ uv_size: 1024
27
+
28
+ decoder_view:
29
+ net_uv_size: 1024
30
+
31
+ upscale_net:
32
+ n_ftrs: 4
33
+
34
+ shadow_net:
35
+ uv_size: 2048
36
+ shadow_size: 256
37
+ n_dims: 4
38
+
39
+ pose_to_shadow:
40
+ n_pose_dims: 104
41
+ uv_size: 2048
42
+
43
+ renderer:
44
+ image_height: 2048
45
+ image_width: 1334
46
+ depth_disc_ksize: 3
47
+
48
+ cal:
49
+ identity_camera: '400143'
50
+
51
+ pixel_cal:
52
+ image_height: 2048
53
+ image_width: 1334
54
+ ds_rate: 8
55
+
56
+ learn_blur: true
checkpoints/diffusion/c1_face/args.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_frame_cond": null,
3
+ "batch_size": 4,
4
+ "cond_mask_prob": 0.2,
5
+ "cuda": true,
6
+ "data_format": "face",
7
+ "data_root": "./dataset/PXB184/",
8
+ "dataset": "social",
9
+ "device": 0,
10
+ "diffusion_steps": 10,
11
+ "heads": 8,
12
+ "lambda_vel": 0.0,
13
+ "latent_dim": 512,
14
+ "layers": 8,
15
+ "log_interval": 1000,
16
+ "lr": 0.0001,
17
+ "lr_anneal_steps": 0,
18
+ "max_seq_length": 600,
19
+ "noise_schedule": "cosine",
20
+ "not_rotary": false,
21
+ "num_audio_layers": 3,
22
+ "num_steps": 800000,
23
+ "overwrite": false,
24
+ "resume_checkpoint": "",
25
+ "save_dir": "checkpoints/diffusion/c1_face/",
26
+ "save_interval": 5000,
27
+ "seed": 10,
28
+ "sigma_small": true,
29
+ "simplify_audio": false,
30
+ "timestep_respacing": "",
31
+ "train_platform_type": "NoPlatform",
32
+ "unconstrained": false,
33
+ "weight_decay": 0.0
34
+ }
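Each `args.json` simply records the flags the corresponding checkpoint was trained with. If you want to inspect one programmatically, a minimal sketch (not the repo's own loading code) is:

```python
import json
from argparse import Namespace

# load the recorded training configuration into an argparse-style namespace
with open("checkpoints/diffusion/c1_face/args.json") as f:
    args = Namespace(**json.load(f))

# a few fields referenced throughout the README
print(args.data_format, args.max_seq_length, args.timestep_respacing, args.cond_mask_prob)
```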
checkpoints/diffusion/c1_pose/args.json ADDED
@@ -0,0 +1,66 @@
1
+ {
2
+ "add_frame_cond": 1.0,
3
+ "arch": "trans_enc",
4
+ "batch_size": 32,
5
+ "clip_body": false,
6
+ "clip_use_delta": false,
7
+ "clip_use_vae": false,
8
+ "cond_mask_prob": 0.1,
9
+ "cuda": true,
10
+ "data_format": "pose",
11
+ "data_root": "./dataset/PXB184/",
12
+ "dataset": "social",
13
+ "device": 0,
14
+ "diffusion_steps": 10,
15
+ "emb_trans_dec": false,
16
+ "eval_batch_size": 32,
17
+ "eval_during_training": false,
18
+ "eval_num_samples": 1000,
19
+ "eval_rep_times": 3,
20
+ "eval_split": "val",
21
+ "filter": false,
22
+ "heads": 8,
23
+ "lambda_fc": 0.0,
24
+ "lambda_hands": 0.0,
25
+ "lambda_lips": 0.0,
26
+ "lambda_rcxyz": 0.0,
27
+ "lambda_vel": 2.0,
28
+ "lambda_xyz": 0.0,
29
+ "lambda_xyz_vel": 0.0,
30
+ "latent_dim": 512,
31
+ "layers": 6,
32
+ "log_interval": 1000,
33
+ "lr": 0.0001,
34
+ "lr_anneal_steps": 0,
35
+ "max_seq_length": 600,
36
+ "no_split": false,
37
+ "noise_schedule": "cosine",
38
+ "not_rotary": false,
39
+ "num_frames": 60,
40
+ "num_steps": 800000,
41
+ "overwrite": false,
42
+ "partial": false,
43
+ "resume_checkpoint": "",
44
+ "save_dir": "checkpoints/diffusion/c1_pose/",
45
+ "save_interval": 5000,
46
+ "seed": 10,
47
+ "sigma_small": true,
48
+ "simplify_audio": false,
49
+ "split_net": false,
50
+ "timestep_respacing": "",
51
+ "train_platform_type": "NoPlatform",
52
+ "unconstrained": false,
53
+ "use_clip": false,
54
+ "use_cm": true,
55
+ "use_full_dataset": false,
56
+ "use_kp": false,
57
+ "use_mask": true,
58
+ "use_mdm": false,
59
+ "use_nort": false,
60
+ "use_nort_mdm": false,
61
+ "use_pose_pos": false,
62
+ "use_resnet": true,
63
+ "use_vae": null,
64
+ "weight_decay": 0.0,
65
+ "z_norm": true
66
+ }
checkpoints/guide/c1_pose/args.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "add_audio_pe": true,
3
+ "add_conv": true,
4
+ "add_frame_cond": 1,
5
+ "batch_size": 16,
6
+ "data_format": "pose",
7
+ "data_root": "./dataset/PXB184/",
8
+ "dataset": "social",
9
+ "dec_layers": null,
10
+ "dim": 64,
11
+ "enc_layers": null,
12
+ "eval_interval": 1000,
13
+ "filter": false,
14
+ "gamma": 0.1,
15
+ "gn": true,
16
+ "layers": 6,
17
+ "log_interval": 1000,
18
+ "lr": 0.0001,
19
+ "lr_scheduler": [
20
+ 50000,
21
+ 400000
22
+ ],
23
+ "no_split": false,
24
+ "num_audio_layers": 2,
25
+ "out_dir": "checkpoints/guide/c1_pose",
26
+ "partial": false,
27
+ "resume_pth": "checkpoints/vq/c1_pose/net_iter300000.pth",
28
+ "resume_trans": null,
29
+ "save_interval": 5000,
30
+ "seed": 10,
31
+ "simplify_audio": false,
32
+ "total_iter": 1000000,
33
+ "use_full_dataset": false,
34
+ "use_kp": false,
35
+ "use_lstm": false,
36
+ "use_nort": false,
37
+ "use_nort_mdm": false,
38
+ "use_torch": false,
39
+ "warm_up_iter": 5000,
40
+ "weight_decay": 0.1
41
+ }
checkpoints/vq/c1_pose/args.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "add_frame_cond": 1.0,
3
+ "batch_size": 16,
4
+ "code_dim": 1024,
5
+ "commit": 0.02,
6
+ "data_format": "pose",
7
+ "data_root": "./dataset/PXB184/",
8
+ "dataname": "social",
9
+ "dataset": "social",
10
+ "depth": 4,
11
+ "eval_iter": 1000,
12
+ "exp_name": "c1_pose",
13
+ "filter": false,
14
+ "gamma": 0.05,
15
+ "loss_vel": 0.0,
16
+ "lr": 0.001,
17
+ "lr_scheduler": [
18
+ 300000
19
+ ],
20
+ "max_seq_length": 600,
21
+ "nb_joints": 104,
22
+ "no_split": true,
23
+ "out_dir": "checkpoints/vq/c1_pose",
24
+ "output_emb_width": 64,
25
+ "partial": false,
26
+ "print_iter": 200,
27
+ "results_dir": "visual_results/",
28
+ "resume_pth": null,
29
+ "seed": 123,
30
+ "simplify_audio": false,
31
+ "total_iter": 10000000,
32
+ "use_full_dataset": false,
33
+ "use_kp": false,
34
+ "use_linear": false,
35
+ "use_nort": false,
36
+ "use_nort_mdm": false,
37
+ "use_quant": true,
38
+ "use_vae": false,
39
+ "visual_name": "baseline",
40
+ "warm_up_iter": 1000,
41
+ "weight_decay": 0.0,
42
+ "z_norm": true
43
+ }
checkpoints/vq/c1_pose/net_iter300000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5649ad5e49e0e1afcd9a7390f0ee79ee66de275a67ecb1cfe7fc691cb4ceb332
3
+ size 3129275
data_loaders/data.py ADDED
@@ -0,0 +1,253 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import os
9
+ from typing import Dict, Iterable, List, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from torch.utils import data
14
+
15
+ from utils.misc import prGreen
16
+
17
+
18
+ class Social(data.Dataset):
19
+ def __init__(
20
+ self,
21
+ args,
22
+ data_dict: Dict[str, Iterable],
23
+ split: str = "train",
24
+ chunk: bool = False,
25
+ add_padding: bool = True,
26
+ ) -> None:
27
+ if args.data_format == "face":
28
+ prGreen("[dataset.py] training face only model")
29
+ data_dict["data"] = data_dict["face"]
30
+ elif args.data_format == "pose":
31
+ prGreen("[dataset.py] training pose only model")
32
+ missing = []
33
+ for d in data_dict["data"]:
34
+ missing.append(np.ones_like(d))
35
+ data_dict["missing"] = missing
36
+
37
+ # set up variables for dataloader
38
+ self.data_format = args.data_format
39
+ self.add_frame_cond = args.add_frame_cond
40
+ self._register_keyframe_step()
41
+ self.data_root = args.data_root
42
+ self.max_seq_length = args.max_seq_length
43
+ if hasattr(args, "curr_seq_length") and args.curr_seq_length is not None:
44
+ self.max_seq_length = args.curr_seq_length
45
+ prGreen([f"[dataset.py] sequences of {self.max_seq_length}"])
46
+ self.add_padding = add_padding
47
+ self.audio_per_frame = 1600
48
+ self.max_audio_length = self.max_seq_length * self.audio_per_frame
49
+ self.min_seq_length = 400
50
+
51
+ # set up training/validation splits
52
+ train_idx = list(range(0, len(data_dict["data"]) - 6))
53
+ val_idx = list(range(len(data_dict["data"]) - 6, len(data_dict["data"]) - 4))
54
+ test_idx = list(range(len(data_dict["data"]) - 4, len(data_dict["data"])))
55
+ self.split = split
56
+ if split == "train":
57
+ self._pick_sequences(data_dict, train_idx)
58
+ elif split == "val":
59
+ self._pick_sequences(data_dict, val_idx)
60
+ else:
61
+ self._pick_sequences(data_dict, test_idx)
62
+ self.chunk = chunk
63
+ if split == "test":
64
+ print("[dataset.py] chunking data...")
65
+ self._chunk_data()
66
+ self._load_std()
67
+ prGreen(
68
+ f"[dataset.py] {split} | {len(self.data)} sequences ({self.data[0].shape}) | total len {self.total_len}"
69
+ )
70
+
71
+ def inv_transform(
72
+ self, data: Union[np.ndarray, torch.Tensor], data_type: str
73
+ ) -> Union[np.ndarray, torch.Tensor]:
74
+ if data_type == "pose":
75
+ std = self.std
76
+ mean = self.mean
77
+ elif data_type == "face":
78
+ std = self.face_std
79
+ mean = self.face_mean
80
+ elif data_type == "audio":
81
+ std = self.audio_std
82
+ mean = self.audio_mean
83
+ else:
84
+ assert False, f"datatype not defined: {data_type}"
85
+
86
+ if torch.is_tensor(data):
87
+ return data * torch.tensor(
88
+ std, device=data.device, requires_grad=False
89
+ ) + torch.tensor(mean, device=data.device, requires_grad=False)
90
+ else:
91
+ return data * std + mean
92
+
93
+ def _pick_sequences(self, data_dict: Dict[str, Iterable], idx: List[int]) -> None:
94
+ self.data = np.take(data_dict["data"], idx, axis=0)
95
+ self.missing = np.take(data_dict["missing"], idx, axis=0)
96
+ self.audio = np.take(data_dict["audio"], idx, axis=0)
97
+ self.lengths = np.take(data_dict["lengths"], idx, axis=0)
98
+ self.total_len = sum([len(d) for d in self.data])
99
+
100
+ def _load_std(self) -> None:
101
+ stats = torch.load(os.path.join(self.data_root, "data_stats.pth"))
102
+ print(
103
+ f'[dataset.py] loading from... {os.path.join(self.data_root, "data_stats.pth")}'
104
+ )
105
+ self.mean = stats["pose_mean"].reshape(-1)
106
+ self.std = stats["pose_std"].reshape(-1)
107
+ self.face_mean = stats["code_mean"]
108
+ self.face_std = stats["code_std"]
109
+ self.audio_mean = stats["audio_mean"]
110
+ self.audio_std = stats["audio_std_flat"]
111
+
112
+ def _chunk_data(self) -> None:
113
+ chunk_data = []
114
+ chunk_missing = []
115
+ chunk_lengths = []
116
+ chunk_audio = []
117
+ # create sequences of set lengths
118
+ for d_idx in range(len(self.data)):
119
+ curr_data = self.data[d_idx]
120
+ curr_missing = self.missing[d_idx]
121
+ curr_audio = self.audio[d_idx]
122
+ end_range = len(self.data[d_idx]) - self.max_seq_length
123
+ for chunk_idx in range(0, end_range, self.max_seq_length):
124
+ chunk_end = chunk_idx + self.max_seq_length
125
+ curr_data_chunk = curr_data[chunk_idx:chunk_end, :]
126
+ curr_missing_chunk = curr_missing[chunk_idx:chunk_end, :]
127
+ curr_audio_chunk = curr_audio[
128
+ chunk_idx * self.audio_per_frame : chunk_end * self.audio_per_frame,
129
+ :,
130
+ ]
131
+ if curr_data_chunk.shape[0] < self.max_seq_length:
132
+ # do not add a short chunk to the list
133
+ continue
134
+ chunk_lengths.append(curr_data_chunk.shape[0])
135
+ chunk_data.append(curr_data_chunk)
136
+ chunk_missing.append(curr_missing_chunk)
137
+ chunk_audio.append(curr_audio_chunk)
138
+ idx = np.random.permutation(len(chunk_data))
139
+ print("==> shuffle", idx)
140
+ self.data = np.take(chunk_data, idx, axis=0)
141
+ self.missing = np.take(chunk_missing, idx, axis=0)
142
+ self.lengths = np.take(chunk_lengths, idx, axis=0)
143
+ self.audio = np.take(chunk_audio, idx, axis=0)
144
+ self.total_len = len(self.data)
145
+
146
+ def _register_keyframe_step(self) -> None:
147
+ if self.add_frame_cond == 1:
148
+ self.step = 30
149
+ if self.add_frame_cond is None:
150
+ self.step = 1
151
+
152
+ def _pad_sequence(
153
+ self, sequence: np.ndarray, actual_length: int, max_length: int
154
+ ) -> np.ndarray:
155
+ sequence = np.concatenate(
156
+ (
157
+ sequence,
158
+ np.zeros((max_length - actual_length, sequence.shape[-1])),
159
+ ),
160
+ axis=0,
161
+ )
162
+ return sequence
163
+
164
+ def _get_idx(self, item: int) -> int:
165
+ cumulative_len = 0
166
+ seq_idx = 0
167
+ while item > cumulative_len:
168
+ cumulative_len += len(self.data[seq_idx])
169
+ seq_idx += 1
170
+ item = seq_idx - 1
171
+ return item
172
+
173
+ def _get_random_subsection(
174
+ self, data_dict: Dict[str, Iterable]
175
+ ) -> Dict[str, np.ndarray]:
176
+ isnonzero = False
177
+ while not isnonzero:
178
+ start = np.random.randint(0, data_dict["m_length"] - self.max_seq_length)
179
+ if self.add_padding:
180
+ length = (
181
+ np.random.randint(self.min_seq_length, self.max_seq_length)
182
+ if not self.split == "test"
183
+ else self.max_seq_length
184
+ )
185
+ else:
186
+ length = self.max_seq_length
187
+ curr_missing = data_dict["missing"][start : start + length]
188
+ isnonzero = np.any(curr_missing)
189
+ missing = curr_missing
190
+ motion = data_dict["motion"][start : start + length, :]
191
+ keyframes = motion[:: self.step]
192
+ audio = data_dict["audio"][
193
+ start * self.audio_per_frame : (start + length) * self.audio_per_frame,
194
+ :,
195
+ ]
196
+ data_dict["m_length"] = len(motion)
197
+ data_dict["k_length"] = len(keyframes)
198
+ data_dict["a_length"] = len(audio)
199
+
200
+ if data_dict["m_length"] < self.max_seq_length:
201
+ motion = self._pad_sequence(
202
+ motion, data_dict["m_length"], self.max_seq_length
203
+ )
204
+ missing = self._pad_sequence(
205
+ missing, data_dict["m_length"], self.max_seq_length
206
+ )
207
+ audio = self._pad_sequence(
208
+ audio, data_dict["a_length"], self.max_audio_length
209
+ )
210
+ max_step_length = len(np.zeros(self.max_seq_length)[:: self.step])
211
+ keyframes = self._pad_sequence(
212
+ keyframes, data_dict["k_length"], max_step_length
213
+ )
214
+ data_dict["motion"] = motion
215
+ data_dict["keyframes"] = keyframes
216
+ data_dict["audio"] = audio
217
+ data_dict["missing"] = missing
218
+ return data_dict
219
+
220
+ def __len__(self) -> int:
221
+ return self.total_len
222
+
223
+ def __getitem__(self, item: int) -> Dict[str, np.ndarray]:
224
+ # figure out which sequence to randomly sample from
225
+ if not self.split == "test":
226
+ item = self._get_idx(item)
227
+ motion = self.data[item]
228
+ audio = self.audio[item]
229
+ m_length = self.lengths[item]
230
+ missing = self.missing[item]
231
+ a_length = len(audio)
232
+ # Z Normalization
233
+ if self.data_format == "pose":
234
+ motion = (motion - self.mean) / self.std
235
+ elif self.data_format == "face":
236
+ motion = (motion - self.face_mean) / self.face_std
237
+ audio = (audio - self.audio_mean) / self.audio_std
238
+ keyframes = motion[:: self.step]
239
+ k_length = len(keyframes)
240
+ data_dict = {
241
+ "motion": motion,
242
+ "m_length": m_length,
243
+ "audio": audio,
244
+ "a_length": a_length,
245
+ "keyframes": keyframes,
246
+ "k_length": k_length,
247
+ "missing": missing,
248
+ }
249
+ if not self.split == "test" and not self.chunk:
250
+ data_dict = self._get_random_subsection(data_dict)
251
+ if self.data_format == "face":
252
+ data_dict["motion"] *= data_dict["missing"]
253
+ return data_dict
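Social z-normalizes each modality with the statistics loaded from data_stats.pth, and inv_transform undoes that mapping. A minimal round-trip sketch, assuming a dataset instance ds has already been built (see get_data.py below) so that the statistics are available:

    import numpy as np

    sample = ds[0]                                      # "motion" is normalized pose, T x 104
    raw = ds.inv_transform(sample["motion"], "pose")    # back to original units
    assert np.allclose((raw - ds.mean) / ds.std, sample["motion"])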
data_loaders/get_data.py ADDED
@@ -0,0 +1,129 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import os
9
+
10
+ from typing import Dict, List
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio
15
+ from data_loaders.data import Social
16
+ from data_loaders.tensors import social_collate
17
+ from torch.utils.data import DataLoader
18
+ from utils.misc import prGreen
19
+
20
+
21
+ def get_dataset_loader(
22
+ args,
23
+ data_dict: Dict[str, np.ndarray],
24
+ split: str = "train",
25
+ chunk: bool = False,
26
+ add_padding: bool = True,
27
+ ) -> DataLoader:
28
+ dataset = Social(
29
+ args=args,
30
+ data_dict=data_dict,
31
+ split=split,
32
+ chunk=chunk,
33
+ add_padding=add_padding,
34
+ )
35
+ loader = DataLoader(
36
+ dataset,
37
+ batch_size=args.batch_size,
38
+ shuffle=not split == "test",
39
+ num_workers=8,
40
+ drop_last=True,
41
+ collate_fn=social_collate,
42
+ pin_memory=True,
43
+ )
44
+ return loader
45
+
46
+
47
+ def _load_pose_data(
48
+ all_paths: List[str], audio_per_frame: int, flip_person: bool = False
49
+ ) -> Dict[str, List]:
50
+ data = []
51
+ face = []
52
+ audio = []
53
+ lengths = []
54
+ missing = []
55
+ for _, curr_path_name in enumerate(all_paths):
56
+ if not curr_path_name.endswith("_body_pose.npy"):
57
+ continue
58
+ # load face information and deal with missing codes
59
+ curr_code = np.load(
60
+ curr_path_name.replace("_body_pose.npy", "_face_expression.npy")
61
+ ).astype(float)
62
+ # curr_code = np.array(curr_face["codes"], dtype=float)
63
+ missing_list = np.load(
64
+ curr_path_name.replace("_body_pose.npy", "_missing_face_frames.npy")
65
+ )
66
+ if len(missing_list) == len(curr_code):
67
+ print("skipping", curr_path_name, curr_code.shape)
68
+ continue
69
+ curr_missing = np.ones_like(curr_code)
70
+ curr_missing[missing_list] = 0.0
71
+
72
+ # load pose information and deal with discontinuities
73
+ curr_pose = np.load(curr_path_name)
74
+ if "PXB184" in curr_path_name or "RLW104" in curr_path_name: # Capture 1 or 2
75
+ curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)
76
+ curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)
77
+
78
+ # load audio information
79
+ curr_audio, _ = torchaudio.load(
80
+ curr_path_name.replace("_body_pose.npy", "_audio.wav")
81
+ )
82
+ curr_audio = curr_audio.T
83
+ if flip_person:
84
+ prGreen("[get_data.py] flipping the dataset of left right person")
85
+ tmp = torch.zeros_like(curr_audio)
86
+ tmp[:, 1] = curr_audio[:, 0]
87
+ tmp[:, 0] = curr_audio[:, 1]
88
+ curr_audio = tmp
89
+
90
+ assert len(curr_pose) * audio_per_frame == len(
91
+ curr_audio
92
+ ), f"motion {curr_pose.shape} vs audio {curr_audio.shape}"
93
+
94
+ data.append(curr_pose)
95
+ face.append(curr_code)
96
+ missing.append(curr_missing)
97
+ audio.append(curr_audio)
98
+ lengths.append(len(curr_pose))
99
+
100
+ data_dict = {
101
+ "data": data,
102
+ "face": face,
103
+ "audio": audio,
104
+ "lengths": lengths,
105
+ "missing": missing,
106
+ }
107
+ return data_dict
108
+
109
+
110
+ def load_local_data(
111
+ data_root: str, audio_per_frame: int, flip_person: bool = False
112
+ ) -> Dict[str, List]:
113
+ if flip_person:
114
+ if "PXB184" in data_root:
115
+ data_root = data_root.replace("PXB184", "RLW104")
116
+ elif "RLW104" in data_root:
117
+ data_root = data_root.replace("RLW104", "PXB184")
118
+ elif "TXB805" in data_root:
119
+ data_root = data_root.replace("TXB805", "GQS883")
120
+ elif "GQS883" in data_root:
121
+ data_root = data_root.replace("GQS883", "TXB805")
122
+
123
+ all_paths = [os.path.join(data_root, x) for x in os.listdir(data_root)]
124
+ all_paths.sort()
125
+ return _load_pose_data(
126
+ all_paths,
127
+ audio_per_frame,
128
+ flip_person=flip_person,
129
+ )
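load_local_data builds the raw per-sequence dictionary and get_dataset_loader wraps it in a batched, collated DataLoader. A minimal usage sketch, assuming the PXB184 dataset has been downloaded and with argument values mirroring checkpoints/*/c1_pose/args.json:

    from argparse import Namespace

    from data_loaders.get_data import get_dataset_loader, load_local_data

    args = Namespace(data_format="pose", add_frame_cond=1, data_root="./dataset/PXB184/",
                     max_seq_length=600, batch_size=16)
    data_dict = load_local_data(args.data_root, audio_per_frame=1600)
    loader = get_dataset_loader(args, data_dict, split="train")
    motion, cond = next(iter(loader))
    print(motion.shape, cond["y"]["audio"].shape)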
data_loaders/tensors.py ADDED
@@ -0,0 +1,86 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import torch
9
+ from torch.utils.data._utils.collate import default_collate
10
+
11
+
12
+ def lengths_to_mask(lengths, max_len):
13
+ mask = torch.arange(max_len, device=lengths.device).expand(
14
+ len(lengths), max_len
15
+ ) < lengths.unsqueeze(1)
16
+ return mask
17
+
18
+
19
+ def collate_tensors(batch):
20
+ dims = batch[0].dim()
21
+ max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
22
+ size = (len(batch),) + tuple(max_size)
23
+ canvas = batch[0].new_zeros(size=size)
24
+ for i, b in enumerate(batch):
25
+ sub_tensor = canvas[i]
26
+ for d in range(dims):
27
+ sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
28
+ sub_tensor.add_(b)
29
+ return canvas
30
+
31
+
32
+ ## social collate
33
+ def collate_v2(batch):
34
+ notnone_batches = [b for b in batch if b is not None]
35
+ databatch = [b["inp"] for b in notnone_batches]
36
+ missingbatch = [b["missing"] for b in notnone_batches]
37
+ audiobatch = [b["audio"] for b in notnone_batches]
38
+ lenbatch = [b["lengths"] for b in notnone_batches]
39
+ alenbatch = [b["audio_lengths"] for b in notnone_batches]
40
+ keyframebatch = [b["keyframes"] for b in notnone_batches]
41
+ klenbatch = [b["key_lengths"] for b in notnone_batches]
42
+
43
+ databatchTensor = collate_tensors(databatch)
44
+ missingbatchTensor = collate_tensors(missingbatch)
45
+ audiobatchTensor = collate_tensors(audiobatch)
46
+ lenbatchTensor = torch.as_tensor(lenbatch)
47
+ alenbatchTensor = torch.as_tensor(alenbatch)
48
+ keyframeTensor = collate_tensors(keyframebatch)
49
+ klenbatchTensor = torch.as_tensor(klenbatch)
50
+
51
+ maskbatchTensor = (
52
+ lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1])
53
+ .unsqueeze(1)
54
+ .unsqueeze(1)
55
+ ) # unsqueeze for broadcasting
56
+ motion = databatchTensor
57
+ cond = {
58
+ "y": {
59
+ "missing": missingbatchTensor,
60
+ "mask": maskbatchTensor,
61
+ "lengths": lenbatchTensor,
62
+ "audio": audiobatchTensor,
63
+ "alengths": alenbatchTensor,
64
+ "keyframes": keyframeTensor,
65
+ "klengths": klenbatchTensor,
66
+ }
67
+ }
68
+ return motion, cond
69
+
70
+
71
+ def social_collate(batch):
72
+ adapted_batch = [
73
+ {
74
+ "inp": torch.tensor(b["motion"].T).to(torch.float32).unsqueeze(1),
75
+ "lengths": b["m_length"],
76
+ "audio": b["audio"]
77
+ if torch.is_tensor(b["audio"])
78
+ else torch.tensor(b["audio"]).to(torch.float32),
79
+ "keyframes": torch.tensor(b["keyframes"]).to(torch.float32),
80
+ "key_lengths": b["k_length"],
81
+ "audio_lengths": b["a_length"],
82
+ "missing": torch.tensor(b["missing"]).to(torch.float32),
83
+ }
84
+ for b in batch
85
+ ]
86
+ return collate_v2(adapted_batch)
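lengths_to_mask is what produces the boolean padding mask that collate_v2 stores under cond["y"]["mask"]. A small worked example of its behavior:

    import torch

    from data_loaders.tensors import lengths_to_mask

    mask = lengths_to_mask(torch.tensor([3, 5]), max_len=5)
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])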
demo/.ipynb_checkpoints/demo-checkpoint.py ADDED
@@ -0,0 +1,276 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import copy
9
+ import json
10
+ from typing import Dict, Union
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ from attrdict import AttrDict
17
+ from diffusion.respace import SpacedDiffusion
18
+ from model.cfg_sampler import ClassifierFreeSampleModel
19
+ from model.diffusion import FiLMTransformer
20
+ from utils.misc import fixseed
21
+ from utils.model_util import create_model_and_diffusion, load_model
22
+ from visualize.render_codes import BodyRenderer
23
+
24
+
25
+ class GradioModel:
26
+ def __init__(self, face_args, pose_args) -> None:
27
+ self.face_model, self.face_diffusion, self.device = self._setup_model(
28
+ face_args, "checkpoints/diffusion/c1_face/model000155000.pt"
29
+ )
30
+ self.pose_model, self.pose_diffusion, _ = self._setup_model(
31
+ pose_args, "checkpoints/diffusion/c1_pose/model000340000.pt"
32
+ )
33
+ # load standardization stuff
34
+ stats = torch.load("dataset/PXB184/data_stats.pth")
35
+ stats["pose_mean"] = stats["pose_mean"].reshape(-1)
36
+ stats["pose_std"] = stats["pose_std"].reshape(-1)
37
+ self.stats = stats
38
+ # set up renderer
39
+ config_base = f"./checkpoints/ca_body/data/PXB184"
40
+ self.body_renderer = BodyRenderer(
41
+ config_base=config_base,
42
+ render_rgb=True,
43
+ )
44
+
45
+ def _setup_model(
46
+ self,
47
+ args_path: str,
48
+ model_path: str,
49
+ ) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion):
50
+ with open(args_path) as f:
51
+ args = json.load(f)
52
+ args = AttrDict(args)
53
+ args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
54
+ print("running on...", args.device)
55
+ args.model_path = model_path
56
+ args.output_dir = "/tmp/gradio/"
57
+ args.timestep_respacing = "ddim100"
58
+ if args.data_format == "pose":
59
+ args.resume_trans = "checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt"
60
+
61
+ ## create model
62
+ model, diffusion = create_model_and_diffusion(args, split_type="test")
63
+ print(f"Loading checkpoints from [{args.model_path}]...")
64
+ state_dict = torch.load(args.model_path, map_location=args.device)
65
+ load_model(model, state_dict)
66
+ model = ClassifierFreeSampleModel(model)
67
+ model.eval()
68
+ model.to(args.device)
69
+ return model, diffusion, args.device
70
+
71
+ def _replace_keyframes(
72
+ self,
73
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
74
+ B: int,
75
+ T: int,
76
+ top_p: float = 0.97,
77
+ ) -> torch.Tensor:
78
+ with torch.no_grad():
79
+ tokens = self.pose_model.transformer.generate(
80
+ model_kwargs["y"]["audio"],
81
+ T,
82
+ layers=self.pose_model.tokenizer.residual_depth,
83
+ n_sequences=B,
84
+ top_p=top_p,
85
+ )
86
+ tokens = tokens.reshape((B, -1, self.pose_model.tokenizer.residual_depth))
87
+ pred = self.pose_model.tokenizer.decode(tokens).detach()
88
+ return pred
89
+
90
+ def _run_single_diffusion(
91
+ self,
92
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
93
+ diffusion: SpacedDiffusion,
94
+ model: Union[FiLMTransformer, ClassifierFreeSampleModel],
95
+ curr_seq_length: int,
96
+ num_repetitions: int = 1,
97
+ ) -> (torch.Tensor,):
98
+ sample_fn = diffusion.ddim_sample_loop
99
+ with torch.no_grad():
100
+ sample = sample_fn(
101
+ model,
102
+ (num_repetitions, model.nfeats, 1, curr_seq_length),
103
+ clip_denoised=False,
104
+ model_kwargs=model_kwargs,
105
+ init_image=None,
106
+ progress=True,
107
+ dump_steps=None,
108
+ noise=None,
109
+ const_noise=False,
110
+ )
111
+ return sample
112
+
113
+ def generate_sequences(
114
+ self,
115
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
116
+ data_format: str,
117
+ curr_seq_length: int,
118
+ num_repetitions: int = 5,
119
+ guidance_param: float = 10.0,
120
+ top_p: float = 0.97,
121
+ # batch_size: int = 1,
122
+ ) -> Dict[str, np.ndarray]:
123
+ if data_format == "pose":
124
+ model = self.pose_model
125
+ diffusion = self.pose_diffusion
126
+ else:
127
+ model = self.face_model
128
+ diffusion = self.face_diffusion
129
+
130
+ all_motions = []
131
+ model_kwargs["y"]["scale"] = torch.ones(num_repetitions) * guidance_param
132
+ model_kwargs["y"] = {
133
+ key: val.to(self.device) if torch.is_tensor(val) else val
134
+ for key, val in model_kwargs["y"].items()
135
+ }
136
+ if data_format == "pose":
137
+ model_kwargs["y"]["mask"] = (
138
+ torch.ones((num_repetitions, 1, 1, curr_seq_length))
139
+ .to(self.device)
140
+ .bool()
141
+ )
142
+ model_kwargs["y"]["keyframes"] = self._replace_keyframes(
143
+ model_kwargs,
144
+ num_repetitions,
145
+ int(curr_seq_length / 30),
146
+ top_p=top_p,
147
+ )
148
+ sample = self._run_single_diffusion(
149
+ model_kwargs, diffusion, model, curr_seq_length, num_repetitions
150
+ )
151
+ all_motions.append(sample.cpu().numpy())
152
+ print(f"created {len(all_motions) * num_repetitions} samples")
153
+ return np.concatenate(all_motions, axis=0)
154
+
155
+
156
+ def generate_results(audio: np.ndarray, num_repetitions: int, top_p: float):
157
+ if audio is None:
158
+ raise gr.Error("Please record audio to start")
159
+ sr, y = audio
160
+ # set to mono and perform resampling
161
+ y = torch.Tensor(y)
162
+ if y.dim() == 2:
163
+ dim = 0 if y.shape[0] == 2 else 1
164
+ y = torch.mean(y, dim=dim)
165
+ y = torchaudio.functional.resample(torch.Tensor(y), orig_freq=sr, new_freq=48_000)
166
+ sr = 48_000
167
+ # make it so that it is 4 seconds long
168
+ if len(y) < (sr * 4):
169
+ raise gr.Error("Please record at least 4 second of audio")
170
+ if num_repetitions is None or num_repetitions <= 0 or num_repetitions > 10:
171
+ raise gr.Error(
172
+ f"Invalid number of samples: {num_repetitions}. Please specify a number between 1-10"
173
+ )
174
+ cutoff = int(len(y) / (sr * 4))
175
+ y = y[: cutoff * sr * 4]
176
+ curr_seq_length = int(len(y) / sr) * 30
177
+ # create model_kwargs
178
+ model_kwargs = {"y": {}}
179
+ dual_audio = np.random.normal(0, 0.001, (1, len(y), 2))
180
+ dual_audio[:, :, 0] = y / max(y)
181
+ dual_audio = (dual_audio - gradio_model.stats["audio_mean"]) / gradio_model.stats[
182
+ "audio_std_flat"
183
+ ]
184
+ model_kwargs["y"]["audio"] = (
185
+ torch.Tensor(dual_audio).float().tile(num_repetitions, 1, 1)
186
+ )
187
+ face_results = (
188
+ gradio_model.generate_sequences(
189
+ model_kwargs, "face", curr_seq_length, num_repetitions=int(num_repetitions)
190
+ )
191
+ .squeeze(2)
192
+ .transpose(0, 2, 1)
193
+ )
194
+ face_results = (
195
+ face_results * gradio_model.stats["code_std"] + gradio_model.stats["code_mean"]
196
+ )
197
+ pose_results = (
198
+ gradio_model.generate_sequences(
199
+ model_kwargs,
200
+ "pose",
201
+ curr_seq_length,
202
+ num_repetitions=int(num_repetitions),
203
+ guidance_param=2.0,
204
+ top_p=top_p,
205
+ )
206
+ .squeeze(2)
207
+ .transpose(0, 2, 1)
208
+ )
209
+ pose_results = (
210
+ pose_results * gradio_model.stats["pose_std"] + gradio_model.stats["pose_mean"]
211
+ )
212
+ dual_audio = (
213
+ dual_audio * gradio_model.stats["audio_std_flat"]
214
+ + gradio_model.stats["audio_mean"]
215
+ )
216
+ return face_results, pose_results, dual_audio[0].transpose(1, 0).astype(np.float32)
217
+
218
+
219
+ def audio_to_avatar(audio: np.ndarray, num_repetitions: int, top_p: float):
220
+ face_results, pose_results, audio = generate_results(audio, num_repetitions, top_p)
221
+ # returns: num_rep x T x 104
222
+ B = len(face_results)
223
+ results = []
224
+ for i in range(B):
225
+ render_data_block = {
226
+ "audio": audio, # 2 x T
227
+ "body_motion": pose_results[i, ...], # T x 104
228
+ "face_motion": face_results[i, ...], # T x 256
229
+ }
230
+ gradio_model.body_renderer.render_full_video(
231
+ render_data_block, f"/tmp/sample{i}", audio_sr=48_000
232
+ )
233
+ results += [gr.Video(value=f"/tmp/sample{i}_pred.mp4", visible=True)]
234
+ results += [gr.Video(visible=False) for _ in range(B, 10)]
235
+ return results
236
+
237
+
238
+ gradio_model = GradioModel(
239
+ face_args="./checkpoints/diffusion/c1_face/args.json",
240
+ pose_args="./checkpoints/diffusion/c1_pose/args.json",
241
+ )
242
+ demo = gr.Interface(
243
+ audio_to_avatar, # function
244
+ [
245
+ gr.Audio(sources=["microphone"]),
246
+ gr.Number(
247
+ value=3,
248
+ label="Number of Samples (default = 3)",
249
+ precision=0,
250
+ minimum=1,
251
+ maximum=10,
252
+ ),
253
+ gr.Number(
254
+ value=0.97,
255
+ label="Sample Diversity (default = 0.97)",
256
+ precision=None,
257
+ minimum=0.01,
258
+ step=0.01,
259
+ maximum=1.00,
260
+ ),
261
+ ], # input type
262
+ [gr.Video(format="mp4", visible=True)]
263
+ + [gr.Video(format="mp4", visible=False) for _ in range(9)], # output type
264
+ title='"From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations" Demo',
265
+ description="You can generate a photorealistic avatar from your voice! <br/>\
266
+ 1) Start by recording your audio. <br/>\
267
+ 2) Specify the number of samples to generate. <br/>\
268
+ 3) Specify how diverse you want the samples to be. This tunes the cumulative probability in nucleus sampling: 0.01 = low diversity, 1.0 = high diversity. <br/>\
269
+ 4) Then, sit back and wait for the rendering to happen! This may take a while (e.g. 30 minutes) <br/>\
270
+ 5) After, you can view the videos and download the ones you like. <br/>",
271
+ article="Relevant links: [Project Page](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal)", # TODO: code and arxiv
272
+ )
273
+
274
+ if __name__ == "__main__":
275
+ fixseed(10)
276
+ demo.launch(share=True)
demo/demo.py ADDED
@@ -0,0 +1,276 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import copy
9
+ import json
10
+ from typing import Dict, Union
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ from attrdict import AttrDict
17
+ from diffusion.respace import SpacedDiffusion
18
+ from model.cfg_sampler import ClassifierFreeSampleModel
19
+ from model.diffusion import FiLMTransformer
20
+ from utils.misc import fixseed
21
+ from utils.model_util import create_model_and_diffusion, load_model
22
+ from visualize.render_codes import BodyRenderer
23
+
24
+
25
+ class GradioModel:
26
+ def __init__(self, face_args, pose_args) -> None:
27
+ self.face_model, self.face_diffusion, self.device = self._setup_model(
28
+ face_args, "checkpoints/diffusion/c1_face/model000155000.pt"
29
+ )
30
+ self.pose_model, self.pose_diffusion, _ = self._setup_model(
31
+ pose_args, "checkpoints/diffusion/c1_pose/model000340000.pt"
32
+ )
33
+ # load standardization stuff
34
+ stats = torch.load("dataset/PXB184/data_stats.pth")
35
+ stats["pose_mean"] = stats["pose_mean"].reshape(-1)
36
+ stats["pose_std"] = stats["pose_std"].reshape(-1)
37
+ self.stats = stats
38
+ # set up renderer
39
+ config_base = f"./checkpoints/ca_body/data/PXB184"
40
+ self.body_renderer = BodyRenderer(
41
+ config_base=config_base,
42
+ render_rgb=True,
43
+ )
44
+
45
+ def _setup_model(
46
+ self,
47
+ args_path: str,
48
+ model_path: str,
49
+ ) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion):
50
+ with open(args_path) as f:
51
+ args = json.load(f)
52
+ args = AttrDict(args)
53
+ args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
54
+ print("running on...", args.device)
55
+ args.model_path = model_path
56
+ args.output_dir = "/tmp/gradio/"
57
+ args.timestep_respacing = "ddim100"
58
+ if args.data_format == "pose":
59
+ args.resume_trans = "checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt"
60
+
61
+ ## create model
62
+ model, diffusion = create_model_and_diffusion(args, split_type="test")
63
+ print(f"Loading checkpoints from [{args.model_path}]...")
64
+ state_dict = torch.load(args.model_path, map_location=args.device)
65
+ load_model(model, state_dict)
66
+ model = ClassifierFreeSampleModel(model)
67
+ model.eval()
68
+ model.to(args.device)
69
+ return model, diffusion, args.device
70
+
71
+ def _replace_keyframes(
72
+ self,
73
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
74
+ B: int,
75
+ T: int,
76
+ top_p: float = 0.97,
77
+ ) -> torch.Tensor:
78
+ with torch.no_grad():
79
+ tokens = self.pose_model.transformer.generate(
80
+ model_kwargs["y"]["audio"],
81
+ T,
82
+ layers=self.pose_model.tokenizer.residual_depth,
83
+ n_sequences=B,
84
+ top_p=top_p,
85
+ )
86
+ tokens = tokens.reshape((B, -1, self.pose_model.tokenizer.residual_depth))
87
+ pred = self.pose_model.tokenizer.decode(tokens).detach()
88
+ return pred
89
+
90
+ def _run_single_diffusion(
91
+ self,
92
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
93
+ diffusion: SpacedDiffusion,
94
+ model: Union[FiLMTransformer, ClassifierFreeSampleModel],
95
+ curr_seq_length: int,
96
+ num_repetitions: int = 1,
97
+ ) -> (torch.Tensor,):
98
+ sample_fn = diffusion.ddim_sample_loop
99
+ with torch.no_grad():
100
+ sample = sample_fn(
101
+ model,
102
+ (num_repetitions, model.nfeats, 1, curr_seq_length),
103
+ clip_denoised=False,
104
+ model_kwargs=model_kwargs,
105
+ init_image=None,
106
+ progress=True,
107
+ dump_steps=None,
108
+ noise=None,
109
+ const_noise=False,
110
+ )
111
+ return sample
112
+
113
+ def generate_sequences(
114
+ self,
115
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
116
+ data_format: str,
117
+ curr_seq_length: int,
118
+ num_repetitions: int = 5,
119
+ guidance_param: float = 10.0,
120
+ top_p: float = 0.97,
121
+ # batch_size: int = 1,
122
+ ) -> Dict[str, np.ndarray]:
123
+ if data_format == "pose":
124
+ model = self.pose_model
125
+ diffusion = self.pose_diffusion
126
+ else:
127
+ model = self.face_model
128
+ diffusion = self.face_diffusion
129
+
130
+ all_motions = []
131
+ model_kwargs["y"]["scale"] = torch.ones(num_repetitions) * guidance_param
132
+ model_kwargs["y"] = {
133
+ key: val.to(self.device) if torch.is_tensor(val) else val
134
+ for key, val in model_kwargs["y"].items()
135
+ }
136
+ if data_format == "pose":
137
+ model_kwargs["y"]["mask"] = (
138
+ torch.ones((num_repetitions, 1, 1, curr_seq_length))
139
+ .to(self.device)
140
+ .bool()
141
+ )
142
+ model_kwargs["y"]["keyframes"] = self._replace_keyframes(
143
+ model_kwargs,
144
+ num_repetitions,
145
+ int(curr_seq_length / 30),
146
+ top_p=top_p,
147
+ )
148
+ sample = self._run_single_diffusion(
149
+ model_kwargs, diffusion, model, curr_seq_length, num_repetitions
150
+ )
151
+ all_motions.append(sample.cpu().numpy())
152
+ print(f"created {len(all_motions) * num_repetitions} samples")
153
+ return np.concatenate(all_motions, axis=0)
154
+
155
+
156
+ def generate_results(audio: np.ndarray, num_repetitions: int, top_p: float):
157
+ if audio is None:
158
+ raise gr.Error("Please record audio to start")
159
+ sr, y = audio
160
+ # set to mono and perform resampling
161
+ y = torch.Tensor(y)
162
+ if y.dim() == 2:
163
+ dim = 0 if y.shape[0] == 2 else 1
164
+ y = torch.mean(y, dim=dim)
165
+ y = torchaudio.functional.resample(torch.Tensor(y), orig_freq=sr, new_freq=48_000)
166
+ sr = 48_000
167
+ # make it so that it is 4 seconds long
168
+ if len(y) < (sr * 4):
169
+ raise gr.Error("Please record at least 4 second of audio")
170
+ if num_repetitions is None or num_repetitions <= 0 or num_repetitions > 10:
171
+ raise gr.Error(
172
+ f"Invalid number of samples: {num_repetitions}. Please specify a number between 1-10"
173
+ )
174
+ cutoff = int(len(y) / (sr * 4))
175
+ y = y[: cutoff * sr * 4]
176
+ curr_seq_length = int(len(y) / sr) * 30
177
+ # create model_kwargs
178
+ model_kwargs = {"y": {}}
179
+ dual_audio = np.random.normal(0, 0.001, (1, len(y), 2))
180
+ dual_audio[:, :, 0] = y / max(y)
181
+ dual_audio = (dual_audio - gradio_model.stats["audio_mean"]) / gradio_model.stats[
182
+ "audio_std_flat"
183
+ ]
184
+ model_kwargs["y"]["audio"] = (
185
+ torch.Tensor(dual_audio).float().tile(num_repetitions, 1, 1)
186
+ )
187
+ face_results = (
188
+ gradio_model.generate_sequences(
189
+ model_kwargs, "face", curr_seq_length, num_repetitions=int(num_repetitions)
190
+ )
191
+ .squeeze(2)
192
+ .transpose(0, 2, 1)
193
+ )
194
+ face_results = (
195
+ face_results * gradio_model.stats["code_std"] + gradio_model.stats["code_mean"]
196
+ )
197
+ pose_results = (
198
+ gradio_model.generate_sequences(
199
+ model_kwargs,
200
+ "pose",
201
+ curr_seq_length,
202
+ num_repetitions=int(num_repetitions),
203
+ guidance_param=2.0,
204
+ top_p=top_p,
205
+ )
206
+ .squeeze(2)
207
+ .transpose(0, 2, 1)
208
+ )
209
+ pose_results = (
210
+ pose_results * gradio_model.stats["pose_std"] + gradio_model.stats["pose_mean"]
211
+ )
212
+ dual_audio = (
213
+ dual_audio * gradio_model.stats["audio_std_flat"]
214
+ + gradio_model.stats["audio_mean"]
215
+ )
216
+ return face_results, pose_results, dual_audio[0].transpose(1, 0).astype(np.float32)
217
+
218
+
219
+ def audio_to_avatar(audio: np.ndarray, num_repetitions: int, top_p: float):
220
+ face_results, pose_results, audio = generate_results(audio, num_repetitions, top_p)
221
+ # returns: num_rep x T x 104
222
+ B = len(face_results)
223
+ results = []
224
+ for i in range(B):
225
+ render_data_block = {
226
+ "audio": audio, # 2 x T
227
+ "body_motion": pose_results[i, ...], # T x 104
228
+ "face_motion": face_results[i, ...], # T x 256
229
+ }
230
+ gradio_model.body_renderer.render_full_video(
231
+ render_data_block, f"/tmp/sample{i}", audio_sr=48_000
232
+ )
233
+ results += [gr.Video(value=f"/tmp/sample{i}_pred.mp4", visible=True)]
234
+ results += [gr.Video(visible=False) for _ in range(B, 10)]
235
+ return results
236
+
237
+
238
+ gradio_model = GradioModel(
239
+ face_args="./checkpoints/diffusion/c1_face/args.json",
240
+ pose_args="./checkpoints/diffusion/c1_pose/args.json",
241
+ )
242
+ demo = gr.Interface(
243
+ audio_to_avatar, # function
244
+ [
245
+ gr.Audio(sources=["microphone"]),
246
+ gr.Number(
247
+ value=3,
248
+ label="Number of Samples (default = 3)",
249
+ precision=0,
250
+ minimum=1,
251
+ maximum=10,
252
+ ),
253
+ gr.Number(
254
+ value=0.97,
255
+ label="Sample Diversity (default = 0.97)",
256
+ precision=None,
257
+ minimum=0.01,
258
+ step=0.01,
259
+ maximum=1.00,
260
+ ),
261
+ ], # input type
262
+ [gr.Video(format="mp4", visible=True)]
263
+ + [gr.Video(format="mp4", visible=False) for _ in range(9)], # output type
264
+ title='"From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations" Demo',
265
+ description="You can generate a photorealistic avatar from your voice! <br/>\
266
+ 1) Start by recording your audio. <br/>\
267
+ 2) Specify the number of samples to generate. <br/>\
268
+ 3) Specify how diverse you want the samples to be. This tunes the cumulative probability in nucleus sampling: 0.01 = low diversity, 1.0 = high diversity. <br/>\
269
+ 4) Then, sit back and wait for the rendering to happen! This may take a while (e.g. 30 minutes) <br/>\
270
+ 5) After, you can view the videos and download the ones you like. <br/>",
271
+ article="Relevant links: [Project Page](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal)", # TODO: code and arxiv
272
+ )
273
+
274
+ if __name__ == "__main__":
275
+ fixseed(10)
276
+ demo.launch(share=True)
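Before sampling, generate_results trims the recording to a whole number of 4-second chunks at 48 kHz and converts the kept duration into 30 fps motion frames. A worked example of that arithmetic (the 10-second recording length is hypothetical):

    sr = 48_000
    n_samples = 10 * sr                  # e.g. a 10 s recording
    cutoff = n_samples // (sr * 4)       # 2 complete 4 s chunks
    kept = cutoff * sr * 4               # 384_000 samples, i.e. 8 s of audio
    curr_seq_length = (kept // sr) * 30  # 240 motion frames at 30 fps
    print(cutoff, kept, curr_seq_length)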
demo/install.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/bin/bash
2
+
3
+ # make sure to have cuda 11.7 and gcc 9.0 installed
4
+ # install environment
5
+ pip install -r scripts/requirements.txt
6
+ sh scripts/download_prereq.sh
7
+
8
+ # download pytorch3d
9
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git"
10
+
11
+ # download model stuff
12
+ wget http://audio2photoreal_models.berkeleyvision.org/PXB184_models.tar || { echo 'downloading model failed' ; exit 1; }
13
+ tar xvf PXB184_models.tar
14
+ rm PXB184_models.tar
15
+
16
+ # install rendering stuff
17
+ mkdir -p checkpoints/ca_body/data/
18
+ wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/PXB184.tar.gz || { echo 'downloading ca body model failed' ; exit 1; }
19
+ tar xvf PXB184.tar.gz --directory checkpoints/ca_body/data/
20
+ rm PXB184.tar.gz
demo/requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ attrdict
2
+ einops==0.7.0
3
+ fairseq==0.12.2
4
+ gradio==4.31.3
5
+ gradio_client==0.7.3
6
+ huggingface-hub==0.19.4
7
+ hydra-core==1.0.7
8
+ mediapy==1.2.0
9
+ numpy==1.26.2
10
+ omegaconf==2.0.6
11
+ opencv-python==4.8.1.78
12
+ protobuf==4.25.1
13
+ tensorboardX==2.6.2.2
14
+ torch==2.0.1
15
+ torchaudio==2.0.2
16
+ torchvision==0.15.2
17
+ tqdm==4.66.3
diffusion/fp16_util.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ """
9
+ original code from
10
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
11
+ under an MIT license
12
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
13
+ """
14
+
15
+ """
16
+ Helpers to train with 16-bit precision.
17
+ """
18
+
19
+ import numpy as np
20
+ import torch as th
21
+ import torch.nn as nn
22
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
23
+
24
+ from utils import logger
25
+
26
+ INITIAL_LOG_LOSS_SCALE = 20.0
27
+
28
+
29
+ def convert_module_to_f16(l):
30
+ """
31
+ Convert primitive modules to float16.
32
+ """
33
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
34
+ l.weight.data = l.weight.data.half()
35
+ if l.bias is not None:
36
+ l.bias.data = l.bias.data.half()
37
+
38
+
39
+ def convert_module_to_f32(l):
40
+ """
41
+ Convert primitive modules to float32, undoing convert_module_to_f16().
42
+ """
43
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
44
+ l.weight.data = l.weight.data.float()
45
+ if l.bias is not None:
46
+ l.bias.data = l.bias.data.float()
47
+
48
+
49
+ def make_master_params(param_groups_and_shapes):
50
+ """
51
+ Copy model parameters into a (differently-shaped) list of full-precision
52
+ parameters.
53
+ """
54
+ master_params = []
55
+ for param_group, shape in param_groups_and_shapes:
56
+ master_param = nn.Parameter(
57
+ _flatten_dense_tensors(
58
+ [param.detach().float() for (_, param) in param_group]
59
+ ).view(shape)
60
+ )
61
+ master_param.requires_grad = True
62
+ master_params.append(master_param)
63
+ return master_params
64
+
65
+
66
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
67
+ """
68
+ Copy the gradients from the model parameters into the master parameters
69
+ from make_master_params().
70
+ """
71
+ for master_param, (param_group, shape) in zip(
72
+ master_params, param_groups_and_shapes
73
+ ):
74
+ master_param.grad = _flatten_dense_tensors(
75
+ [param_grad_or_zeros(param) for (_, param) in param_group]
76
+ ).view(shape)
77
+
78
+
79
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
80
+ """
81
+ Copy the master parameter data back into the model parameters.
82
+ """
83
+ # Without copying to a list, if a generator is passed, this will
84
+ # silently not copy any parameters.
85
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
86
+ for (_, param), unflat_master_param in zip(
87
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
88
+ ):
89
+ param.detach().copy_(unflat_master_param)
90
+
91
+
92
+ def unflatten_master_params(param_group, master_param):
93
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
94
+
95
+
96
+ def get_param_groups_and_shapes(named_model_params):
97
+ named_model_params = list(named_model_params)
98
+ scalar_vector_named_params = (
99
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
100
+ (-1),
101
+ )
102
+ matrix_named_params = (
103
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
104
+ (1, -1),
105
+ )
106
+ return [scalar_vector_named_params, matrix_named_params]
107
+
108
+
109
+ def master_params_to_state_dict(
110
+ model, param_groups_and_shapes, master_params, use_fp16
111
+ ):
112
+ if use_fp16:
113
+ state_dict = model.state_dict()
114
+ for master_param, (param_group, _) in zip(
115
+ master_params, param_groups_and_shapes
116
+ ):
117
+ for (name, _), unflat_master_param in zip(
118
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
119
+ ):
120
+ assert name in state_dict
121
+ state_dict[name] = unflat_master_param
122
+ else:
123
+ state_dict = model.state_dict()
124
+ for i, (name, _value) in enumerate(model.named_parameters()):
125
+ assert name in state_dict
126
+ state_dict[name] = master_params[i]
127
+ return state_dict
128
+
129
+
130
+ def state_dict_to_master_params(model, state_dict, use_fp16):
131
+ if use_fp16:
132
+ named_model_params = [
133
+ (name, state_dict[name]) for name, _ in model.named_parameters()
134
+ ]
135
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
136
+ master_params = make_master_params(param_groups_and_shapes)
137
+ else:
138
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
139
+ return master_params
140
+
141
+
142
+ def zero_master_grads(master_params):
143
+ for param in master_params:
144
+ param.grad = None
145
+
146
+
147
+ def zero_grad(model_params):
148
+ for param in model_params:
149
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
150
+ if param.grad is not None:
151
+ param.grad.detach_()
152
+ param.grad.zero_()
153
+
154
+
155
+ def param_grad_or_zeros(param):
156
+ if param.grad is not None:
157
+ return param.grad.data.detach()
158
+ else:
159
+ return th.zeros_like(param)
160
+
161
+
162
+ class MixedPrecisionTrainer:
163
+ def __init__(
164
+ self,
165
+ *,
166
+ model,
167
+ use_fp16=False,
168
+ fp16_scale_growth=1e-3,
169
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
170
+ ):
171
+ self.model = model
172
+ self.use_fp16 = use_fp16
173
+ self.fp16_scale_growth = fp16_scale_growth
174
+
175
+ self.model_params = list(self.model.parameters())
176
+ self.master_params = self.model_params
177
+ self.param_groups_and_shapes = None
178
+ self.lg_loss_scale = initial_lg_loss_scale
179
+
180
+ if self.use_fp16:
181
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
182
+ self.model.named_parameters()
183
+ )
184
+ self.master_params = make_master_params(self.param_groups_and_shapes)
185
+ self.model.convert_to_fp16()
186
+
187
+ def zero_grad(self):
188
+ zero_grad(self.model_params)
189
+
190
+ def backward(self, loss: th.Tensor):
191
+ if self.use_fp16:
192
+ loss_scale = 2**self.lg_loss_scale
193
+ (loss * loss_scale).backward()
194
+ else:
195
+ loss.backward()
196
+
197
+ def optimize(self, opt: th.optim.Optimizer):
198
+ if self.use_fp16:
199
+ return self._optimize_fp16(opt)
200
+ else:
201
+ return self._optimize_normal(opt)
202
+
203
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
204
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
205
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
206
+ grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale)
207
+ if check_overflow(grad_norm):
208
+ self.lg_loss_scale -= 1
209
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
210
+ zero_master_grads(self.master_params)
211
+ return False
212
+
213
+ logger.logkv_mean("grad_norm", grad_norm)
214
+ logger.logkv_mean("param_norm", param_norm)
215
+
216
+ self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale))
217
+ opt.step()
218
+ zero_master_grads(self.master_params)
219
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
220
+ self.lg_loss_scale += self.fp16_scale_growth
221
+ return True
222
+
223
+ def _optimize_normal(self, opt: th.optim.Optimizer):
224
+ grad_norm, param_norm = self._compute_norms()
225
+ logger.logkv_mean("grad_norm", grad_norm)
226
+ logger.logkv_mean("param_norm", param_norm)
227
+ opt.step()
228
+ return True
229
+
230
+ def _compute_norms(self, grad_scale=1.0):
231
+ grad_norm = 0.0
232
+ param_norm = 0.0
233
+ for p in self.master_params:
234
+ with th.no_grad():
235
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
236
+ if p.grad is not None:
237
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
238
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
239
+
240
+ def master_params_to_state_dict(self, master_params):
241
+ return master_params_to_state_dict(
242
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
243
+ )
244
+
245
+ def state_dict_to_master_params(self, state_dict):
246
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
247
+
248
+
249
+ def check_overflow(value):
250
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
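MixedPrecisionTrainer implements dynamic loss scaling: in fp16 mode the loss is multiplied by 2**lg_loss_scale before backward, the scale is lowered and the update skipped whenever the gradient norm overflows, and it grows by fp16_scale_growth after each successful step. A toy training-step sketch; the linear model and random data are placeholders, and use_fp16 is left False here so no convert_to_fp16() method is required:

    import torch as th
    import torch.nn as nn

    from diffusion.fp16_util import MixedPrecisionTrainer

    model = nn.Linear(8, 1)                       # toy stand-in for the diffusion model
    trainer = MixedPrecisionTrainer(model=model)  # use_fp16=False by default
    opt = th.optim.AdamW(trainer.master_params, lr=1e-4)
    for _ in range(3):
        trainer.zero_grad()
        loss = model(th.randn(4, 8)).pow(2).mean()
        trainer.backward(loss)   # with use_fp16=True this would scale the loss first
        trainer.optimize(opt)    # with use_fp16=True this skips the step on overflow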
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1273 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ """
9
+ original code from
10
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
11
+ under an MIT license
12
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
13
+ """
14
+
15
+ import enum
16
+ import math
17
+ from copy import deepcopy
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch as th
22
+ from diffusion.losses import discretized_gaussian_log_likelihood, normal_kl
23
+ from diffusion.nn import mean_flat, sum_flat
24
+
25
+
26
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.0):
27
+ """
28
+ Get a pre-defined beta schedule for the given name.
29
+
30
+ The beta schedule library consists of beta schedules which remain similar
31
+ in the limit of num_diffusion_timesteps.
32
+ Beta schedules may be added, but should not be removed or changed once
33
+ they are committed to maintain backwards compatibility.
34
+ """
35
+ if schedule_name == "linear":
36
+ # Linear schedule from Ho et al, extended to work for any number of
37
+ # diffusion steps.
38
+ scale = scale_betas * 1000 / num_diffusion_timesteps
39
+ beta_start = scale * 0.0001
40
+ beta_end = scale * 0.02
41
+ return np.linspace(
42
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
43
+ )
44
+ elif schedule_name == "cosine":
45
+ return betas_for_alpha_bar(
46
+ num_diffusion_timesteps,
47
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
48
+ )
49
+ else:
50
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
51
+
52
+
53
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
54
+ """
55
+ Create a beta schedule that discretizes the given alpha_t_bar function,
56
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
57
+
58
+ :param num_diffusion_timesteps: the number of betas to produce.
59
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
60
+ produces the cumulative product of (1-beta) up to that
61
+ part of the diffusion process.
62
+ :param max_beta: the maximum beta to use; use values lower than 1 to
63
+ prevent singularities.
64
+ """
65
+ betas = []
66
+ for i in range(num_diffusion_timesteps):
67
+ t1 = i / num_diffusion_timesteps
68
+ t2 = (i + 1) / num_diffusion_timesteps
69
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
70
+ return np.array(betas)
71
+
72
+
73
+ class ModelMeanType(enum.Enum):
74
+ """
75
+ Which type of output the model predicts.
76
+ """
77
+
78
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
79
+ START_X = enum.auto() # the model predicts x_0
80
+ EPSILON = enum.auto() # the model predicts epsilon
81
+
82
+
83
+ class ModelVarType(enum.Enum):
84
+ """
85
+ What is used as the model's output variance.
86
+
87
+ The LEARNED_RANGE option has been added to allow the model to predict
88
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
89
+ """
90
+
91
+ LEARNED = enum.auto()
92
+ FIXED_SMALL = enum.auto()
93
+ FIXED_LARGE = enum.auto()
94
+ LEARNED_RANGE = enum.auto()
95
+
96
+
97
+ class LossType(enum.Enum):
98
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
99
+ RESCALED_MSE = (
100
+ enum.auto()
101
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
102
+ KL = enum.auto() # use the variational lower-bound
103
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
104
+
105
+ def is_vb(self):
106
+ return self == LossType.KL or self == LossType.RESCALED_KL
107
+
108
+
109
+ class GaussianDiffusion:
110
+ """
111
+ Utilities for training and sampling diffusion models.
112
+
113
+ Ported directly from here, and then adapted over time to further experimentation.
114
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
115
+
116
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
117
+ starting at T and going to 1.
118
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
119
+ :param model_var_type: a ModelVarType determining how variance is output.
120
+ :param loss_type: a LossType determining the loss function to use.
121
+ :param rescale_timesteps: if True, pass floating point timesteps into the
122
+ model so that they are always scaled like in the
123
+ original paper (0 to 1000).
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ *,
129
+ betas,
130
+ model_mean_type,
131
+ model_var_type,
132
+ loss_type,
133
+ rescale_timesteps=False,
134
+ lambda_vel=0.0,
135
+ data_format="pose",
136
+ model_path=None,
137
+ ):
138
+ self.model_mean_type = model_mean_type
139
+ self.model_var_type = model_var_type
140
+ self.loss_type = loss_type
141
+ self.rescale_timesteps = rescale_timesteps
142
+ self.data_format = data_format
143
+ self.lambda_vel = lambda_vel
144
+ if self.lambda_vel > 0.0:
145
+ assert (
146
+ self.loss_type == LossType.MSE
147
+ ), "Geometric losses are supported by MSE loss type only!"
148
+
149
+ # Use float64 for accuracy.
150
+ betas = np.array(betas, dtype=np.float64)
151
+ self.betas = betas
152
+ assert len(betas.shape) == 1, "betas must be 1-D"
153
+ assert (betas > 0).all() and (betas <= 1).all()
154
+
155
+ self.num_timesteps = int(betas.shape[0])
156
+
157
+ alphas = 1.0 - betas
158
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
159
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
160
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
161
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
162
+
163
+ # calculations for diffusion q(x_t | x_{t-1}) and others
164
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
165
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
166
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
167
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
168
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
169
+
170
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
171
+ self.posterior_variance = (
172
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
173
+ )
174
+ # log calculation clipped because the posterior variance is 0 at the
175
+ # beginning of the diffusion chain.
176
+ self.posterior_log_variance_clipped = np.log(
177
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
178
+ )
179
+ self.posterior_mean_coef1 = (
180
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
181
+ )
182
+ self.posterior_mean_coef2 = (
183
+ (1.0 - self.alphas_cumprod_prev)
184
+ * np.sqrt(alphas)
185
+ / (1.0 - self.alphas_cumprod)
186
+ )
187
+
188
+ self.l2_loss = lambda a, b: (a - b) ** 2
189
+
190
+ def masked_l2(self, a, b, mask):
191
+ loss = self.l2_loss(a, b)
192
+ loss = sum_flat(loss * mask.float())
193
+ n_entries = a.shape[1] * a.shape[2]
194
+ non_zero_elements = sum_flat(mask) * n_entries
195
+ mse_loss_val = loss / non_zero_elements
196
+ return mse_loss_val
197
+
198
+ def q_mean_variance(self, x_start, t):
199
+ """
200
+ Get the distribution q(x_t | x_0).
201
+
202
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
203
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
204
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
205
+ """
206
+ mean = (
207
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
208
+ )
209
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
210
+ log_variance = _extract_into_tensor(
211
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
212
+ )
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the dataset for a given number of diffusion steps.
218
+
219
+ In other words, sample from q(x_t | x_0).
220
+
221
+ :param x_start: the initial dataset batch.
222
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
223
+ :param noise: if specified, the split-out normal noise.
224
+ :return: A noisy version of x_start.
225
+ """
226
+ if noise is None:
227
+ noise = th.randn_like(x_start)
228
+ assert noise.shape == x_start.shape
229
+ return (
230
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
231
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
232
+ * noise
233
+ )
234
+
235
+ def q_posterior_mean_variance(self, x_start, x_t, t):
236
+ """
237
+ Compute the mean and variance of the diffusion posterior:
238
+
239
+ q(x_{t-1} | x_t, x_0)
240
+
241
+ """
242
+ assert x_start.shape == x_t.shape, f"x_start: {x_start.shape}, x_t: {x_t.shape}"
243
+ posterior_mean = (
244
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
245
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
246
+ )
247
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
248
+ posterior_log_variance_clipped = _extract_into_tensor(
249
+ self.posterior_log_variance_clipped, t, x_t.shape
250
+ )
251
+ assert (
252
+ posterior_mean.shape[0]
253
+ == posterior_variance.shape[0]
254
+ == posterior_log_variance_clipped.shape[0]
255
+ == x_start.shape[0]
256
+ )
257
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
258
+
259
+ def p_mean_variance(
260
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
261
+ ):
262
+ """
263
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
264
+ the initial x, x_0.
265
+
266
+ :param model: the model, which takes a signal and a batch of timesteps
267
+ as input.
268
+ :param x: the [N x C x ...] tensor at time t.
269
+ :param t: a 1-D Tensor of timesteps.
270
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
271
+ :param denoised_fn: if not None, a function which applies to the
272
+ x_start prediction before it is used to sample. Applies before
273
+ clip_denoised.
274
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
275
+ pass to the model. This can be used for conditioning.
276
+ :return: a dict with the following keys:
277
+ - 'mean': the model mean output.
278
+ - 'variance': the model variance output.
279
+ - 'log_variance': the log of 'variance'.
280
+ - 'pred_xstart': the prediction for x_0.
281
+ """
282
+ if model_kwargs is None:
283
+ model_kwargs = {}
284
+
285
+ B, C = x.shape[:2]
286
+ assert t.shape == (B,)
287
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
288
+
289
+ model_variance, model_log_variance = {
290
+ # for fixedlarge, we set the initial (log-)variance like so
291
+ # to get a better decoder log likelihood.
292
+ ModelVarType.FIXED_LARGE: (
293
+ np.append(self.posterior_variance[1], self.betas[1:]),
294
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
295
+ ),
296
+ ModelVarType.FIXED_SMALL: (
297
+ self.posterior_variance,
298
+ self.posterior_log_variance_clipped,
299
+ ),
300
+ }[self.model_var_type]
301
+
302
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
303
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
304
+
305
+ def process_xstart(x):
306
+ if denoised_fn is not None:
307
+ x = denoised_fn(x)
308
+ if clip_denoised:
309
+ return x.clamp(-1, 1)
310
+ return x
311
+
312
+ pred_xstart = process_xstart(model_output)
313
+ pred_xstart = pred_xstart.permute(0, 2, 1).unsqueeze(2)
314
+ model_mean, _, _ = self.q_posterior_mean_variance(
315
+ x_start=pred_xstart, x_t=x, t=t
316
+ )
317
+
318
+ assert (
319
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
320
+ ), (
321
+ f"{model_mean.shape} == {model_log_variance.shape} == {pred_xstart.shape} == {x.shape}"
322
+ )
323
+ return {
324
+ "mean": model_mean,
325
+ "variance": model_variance,
326
+ "log_variance": model_log_variance,
327
+ "pred_xstart": pred_xstart,
328
+ }
329
+
330
+ def _predict_xstart_from_eps(self, x_t, t, eps):
331
+ assert x_t.shape == eps.shape
332
+ return (
333
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
334
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
335
+ )
336
+
337
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
338
+ assert x_t.shape == xprev.shape
339
+ return (
340
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
341
+ - _extract_into_tensor(
342
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
343
+ )
344
+ * x_t
345
+ )
346
+
347
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
348
+ return (
349
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
350
+ - pred_xstart
351
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
352
+
353
+ def _scale_timesteps(self, t):
354
+ if self.rescale_timesteps:
355
+ return t.float() * (1000.0 / self.num_timesteps)
356
+ return t
357
+
358
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute the mean for the previous step, given a function cond_fn that
361
+ computes the gradient of a conditional log probability with respect to
362
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
363
+ condition on y.
364
+
365
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
366
+ """
367
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
368
+ new_mean = (
369
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
370
+ )
371
+ return new_mean
372
+
373
+ def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
374
+ """
375
+ Compute the mean for the previous step, given a function cond_fn that
376
+ computes the gradient of a conditional log probability with respect to
377
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
378
+ condition on y.
379
+
380
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
381
+ """
382
+ gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
383
+ new_mean = (
384
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
385
+ )
386
+ return new_mean
387
+
388
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
389
+ """
390
+ Compute what the p_mean_variance output would have been, should the
391
+ model's score function be conditioned by cond_fn.
392
+
393
+ See condition_mean() for details on cond_fn.
394
+
395
+ Unlike condition_mean(), this instead uses the conditioning strategy
396
+ from Song et al (2020).
397
+ """
398
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
399
+
400
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
401
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
402
+ x, self._scale_timesteps(t), **model_kwargs
403
+ )
404
+
405
+ out = p_mean_var.copy()
406
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
407
+ out["mean"], _, _ = self.q_posterior_mean_variance(
408
+ x_start=out["pred_xstart"], x_t=x, t=t
409
+ )
410
+ return out
411
+
412
+ def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
413
+ """
414
+ Compute what the p_mean_variance output would have been, should the
415
+ model's score function be conditioned by cond_fn.
416
+
417
+ See condition_mean() for details on cond_fn.
418
+
419
+ Unlike condition_mean(), this instead uses the conditioning strategy
420
+ from Song et al (2020).
421
+ """
422
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
423
+
424
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
425
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, p_mean_var, **model_kwargs)
426
+
427
+ out = p_mean_var.copy()
428
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
429
+ out["mean"], _, _ = self.q_posterior_mean_variance(
430
+ x_start=out["pred_xstart"], x_t=x, t=t
431
+ )
432
+ return out
433
+
434
+ def p_sample(
435
+ self,
436
+ model,
437
+ x,
438
+ t,
439
+ clip_denoised=True,
440
+ denoised_fn=None,
441
+ cond_fn=None,
442
+ model_kwargs=None,
443
+ const_noise=False,
444
+ ):
445
+ """
446
+ Sample x_{t-1} from the model at the given timestep.
447
+
448
+ :param model: the model to sample from.
449
+ :param x: the current tensor at x_{t-1}.
450
+ :param t: the value of t, starting at 0 for the first diffusion step.
451
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
452
+ :param denoised_fn: if not None, a function which applies to the
453
+ x_start prediction before it is used to sample.
454
+ :param cond_fn: if not None, this is a gradient function that acts
455
+ similarly to the model.
456
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
457
+ pass to the model. This can be used for conditioning.
458
+ :return: a dict containing the following keys:
459
+ - 'sample': a random sample from the model.
460
+ - 'pred_xstart': a prediction of x_0.
461
+ """
462
+ out = self.p_mean_variance(
463
+ model,
464
+ x,
465
+ t,
466
+ clip_denoised=clip_denoised,
467
+ denoised_fn=denoised_fn,
468
+ model_kwargs=model_kwargs,
469
+ )
470
+
471
+ noise = th.randn_like(x)
+ if const_noise:
+ # share one noise draw across the batch so every sample is perturbed identically
+ noise = noise[[0]].repeat(x.shape[0], *([1] * (len(x.shape) - 1)))
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
472
+ if cond_fn is not None:
473
+ out["mean"] = self.condition_mean(
474
+ cond_fn, out, x, t, model_kwargs=model_kwargs
475
+ )
476
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
477
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
478
+
479
+ def p_sample_with_grad(
480
+ self,
481
+ model,
482
+ x,
483
+ t,
484
+ clip_denoised=True,
485
+ denoised_fn=None,
486
+ cond_fn=None,
487
+ model_kwargs=None,
488
+ ):
489
+ """
490
+ Sample x_{t-1} from the model at the given timestep.
491
+
492
+ :param model: the model to sample from.
493
+ :param x: the current tensor at x_{t-1}.
494
+ :param t: the value of t, starting at 0 for the first diffusion step.
495
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
496
+ :param denoised_fn: if not None, a function which applies to the
497
+ x_start prediction before it is used to sample.
498
+ :param cond_fn: if not None, this is a gradient function that acts
499
+ similarly to the model.
500
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
501
+ pass to the model. This can be used for conditioning.
502
+ :return: a dict containing the following keys:
503
+ - 'sample': a random sample from the model.
504
+ - 'pred_xstart': a prediction of x_0.
505
+ """
506
+ with th.enable_grad():
507
+ x = x.detach().requires_grad_()
508
+ out = self.p_mean_variance(
509
+ model,
510
+ x,
511
+ t,
512
+ clip_denoised=clip_denoised,
513
+ denoised_fn=denoised_fn,
514
+ model_kwargs=model_kwargs,
515
+ )
516
+ noise = th.randn_like(x)
517
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
518
+ if cond_fn is not None:
519
+ out["mean"] = self.condition_mean_with_grad(
520
+ cond_fn, out, x, t, model_kwargs=model_kwargs
521
+ )
522
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
523
+ return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()}
524
+
525
+ def p_sample_loop(
526
+ self,
527
+ model,
528
+ shape,
529
+ noise=None,
530
+ clip_denoised=True,
531
+ denoised_fn=None,
532
+ cond_fn=None,
533
+ model_kwargs=None,
534
+ device=None,
535
+ progress=False,
536
+ skip_timesteps=0,
537
+ init_image=None,
538
+ randomize_class=False,
539
+ cond_fn_with_grad=False,
540
+ dump_steps=None,
541
+ const_noise=False,
542
+ ):
543
+ """
544
+ Generate samples from the model.
545
+
546
+ :param model: the model module.
547
+ :param shape: the shape of the samples, (N, C, H, W).
548
+ :param noise: if specified, the noise from the encoder to sample.
549
+ Should be of the same shape as `shape`.
550
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
551
+ :param denoised_fn: if not None, a function which applies to the
552
+ x_start prediction before it is used to sample.
553
+ :param cond_fn: if not None, this is a gradient function that acts
554
+ similarly to the model.
555
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
556
+ pass to the model. This can be used for conditioning.
557
+ :param device: if specified, the device to create the samples on.
558
+ If not specified, use a model parameter's device.
559
+ :param progress: if True, show a tqdm progress bar.
560
+ :param const_noise: If True, will noise all samples with the same noise throughout sampling
561
+ :return: a non-differentiable batch of samples.
562
+ """
563
+ final = None
564
+ if dump_steps is not None:
565
+ dump = []
566
+
567
+ for i, sample in enumerate(
568
+ self.p_sample_loop_progressive(
569
+ model,
570
+ shape,
571
+ noise=noise,
572
+ clip_denoised=clip_denoised,
573
+ denoised_fn=denoised_fn,
574
+ cond_fn=cond_fn,
575
+ model_kwargs=model_kwargs,
576
+ device=device,
577
+ progress=progress,
578
+ skip_timesteps=skip_timesteps,
579
+ init_image=init_image,
580
+ randomize_class=randomize_class,
581
+ cond_fn_with_grad=cond_fn_with_grad,
582
+ const_noise=const_noise,
583
+ )
584
+ ):
585
+ if dump_steps is not None and i in dump_steps:
586
+ dump.append(deepcopy(sample["sample"]))
587
+ final = sample
588
+ if dump_steps is not None:
589
+ return dump
590
+ return final["sample"]
591
+
592
+ def p_sample_loop_progressive(
593
+ self,
594
+ model,
595
+ shape,
596
+ noise=None,
597
+ clip_denoised=True,
598
+ denoised_fn=None,
599
+ cond_fn=None,
600
+ model_kwargs=None,
601
+ device=None,
602
+ progress=False,
603
+ skip_timesteps=0,
604
+ init_image=None,
605
+ randomize_class=False,
606
+ cond_fn_with_grad=False,
607
+ const_noise=False,
608
+ ):
609
+ """
610
+ Generate samples from the model and yield intermediate samples from
611
+ each timestep of diffusion.
612
+
613
+ Arguments are the same as p_sample_loop().
614
+ Returns a generator over dicts, where each dict is the return value of
615
+ p_sample().
616
+ """
617
+ if device is None:
618
+ device = next(model.parameters()).device
619
+ assert isinstance(shape, (tuple, list))
620
+ if noise is not None:
621
+ img = noise
622
+ else:
623
+ img = th.randn(*shape, device=device)
624
+
625
+ if skip_timesteps and init_image is None:
626
+ init_image = th.zeros_like(img)
627
+
628
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
629
+
630
+ if init_image is not None:
631
+ my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
632
+ img = self.q_sample(init_image, my_t, img)
633
+
634
+ if progress:
635
+ # Lazy import so that we don't depend on tqdm.
636
+ from tqdm.auto import tqdm
637
+
638
+ indices = tqdm(indices)
639
+
640
+ # number of timestamps to diffuse
641
+ for i in indices:
642
+ t = th.tensor([i] * shape[0], device=device)
643
+ if randomize_class and "y" in model_kwargs:
644
+ model_kwargs["y"] = th.randint(
645
+ low=0,
646
+ high=model.num_classes,
647
+ size=model_kwargs["y"].shape,
648
+ device=model_kwargs["y"].device,
649
+ )
650
+ with th.no_grad():
651
+ sample_fn = (
652
+ self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
653
+ )
654
+ out = sample_fn(
655
+ model,
656
+ img,
657
+ t,
658
+ clip_denoised=clip_denoised,
659
+ denoised_fn=denoised_fn,
660
+ cond_fn=cond_fn,
661
+ model_kwargs=model_kwargs,
662
+ const_noise=const_noise,
663
+ )
664
+ yield out
665
+ img = out["sample"]
666
+
667
+ def ddim_sample(
668
+ self,
669
+ model,
670
+ x,
671
+ t,
672
+ clip_denoised=True,
673
+ denoised_fn=None,
674
+ cond_fn=None,
675
+ model_kwargs=None,
676
+ eta=0.0,
677
+ ):
678
+ """
679
+ Sample x_{t-1} from the model using DDIM.
680
+
681
+ Same usage as p_sample().
682
+ """
683
+ out_orig = self.p_mean_variance(
684
+ model,
685
+ x,
686
+ t,
687
+ clip_denoised=clip_denoised,
688
+ denoised_fn=denoised_fn,
689
+ model_kwargs=model_kwargs,
690
+ )
691
+ if cond_fn is not None:
692
+ out = self.condition_score(
693
+ cond_fn, out_orig, x, t, model_kwargs=model_kwargs
694
+ )
695
+ else:
696
+ out = out_orig
697
+ # Usually our model outputs epsilon, but we re-derive it
698
+ # in case we used x_start or x_prev prediction.
699
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
700
+
701
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
702
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
703
+ sigma = (
704
+ eta
705
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
706
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
707
+ )
708
+ noise = th.randn_like(x)
709
+
710
+ mean_pred = (
711
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
712
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
713
+ )
714
+ nonzero_mask = (
715
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
716
+ ) # no noise when t == 0
717
+ sample = mean_pred + nonzero_mask * sigma * noise
718
+ return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]}
719
+
720
+ def ddim_sample_with_grad(
721
+ self,
722
+ model,
723
+ x,
724
+ t,
725
+ clip_denoised=True,
726
+ denoised_fn=None,
727
+ cond_fn=None,
728
+ model_kwargs=None,
729
+ eta=0.0,
730
+ ):
731
+ """
732
+ Sample x_{t-1} from the model using DDIM.
733
+
734
+ Same usage as p_sample().
735
+ """
736
+ with th.enable_grad():
737
+ x = x.detach().requires_grad_()
738
+ out_orig = self.p_mean_variance(
739
+ model,
740
+ x,
741
+ t,
742
+ clip_denoised=clip_denoised,
743
+ denoised_fn=denoised_fn,
744
+ model_kwargs=model_kwargs,
745
+ )
746
+ if cond_fn is not None:
747
+ out = self.condition_score_with_grad(
748
+ cond_fn, out_orig, x, t, model_kwargs=model_kwargs
749
+ )
750
+ else:
751
+ out = out_orig
752
+
753
+ out["pred_xstart"] = out["pred_xstart"].detach()
754
+ # Usually our model outputs epsilon, but we re-derive it
755
+ # in case we used x_start or x_prev prediction.
756
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
757
+
758
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
759
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
760
+ sigma = (
761
+ eta
762
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
763
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
764
+ )
765
+ # Equation 12.
766
+ noise = th.randn_like(x)
767
+ mean_pred = (
768
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
769
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
770
+ )
771
+ nonzero_mask = (
772
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
773
+ ) # no noise when t == 0
774
+ sample = mean_pred + nonzero_mask * sigma * noise
775
+ return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()}
776
+
777
+ def ddim_reverse_sample(
778
+ self,
779
+ model,
780
+ x,
781
+ t,
782
+ clip_denoised=True,
783
+ denoised_fn=None,
784
+ model_kwargs=None,
785
+ eta=0.0,
786
+ ):
787
+ """
788
+ Sample x_{t+1} from the model using DDIM reverse ODE.
789
+ """
790
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
791
+ out = self.p_mean_variance(
792
+ model,
793
+ x,
794
+ t,
795
+ clip_denoised=clip_denoised,
796
+ denoised_fn=denoised_fn,
797
+ model_kwargs=model_kwargs,
798
+ )
799
+ # Usually our model outputs epsilon, but we re-derive it
800
+ # in case we used x_start or x_prev prediction.
801
+ eps = (
802
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
803
+ - out["pred_xstart"]
804
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
805
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
806
+
807
+ # Equation 12. reversed
808
+ mean_pred = (
809
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
810
+ + th.sqrt(1 - alpha_bar_next) * eps
811
+ )
812
+
813
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
814
+
815
+ def ddim_sample_loop(
816
+ self,
817
+ model,
818
+ shape,
819
+ noise=None,
820
+ clip_denoised=True,
821
+ denoised_fn=None,
822
+ cond_fn=None,
823
+ model_kwargs=None,
824
+ device=None,
825
+ progress=False,
826
+ eta=0.0,
827
+ skip_timesteps=0,
828
+ init_image=None,
829
+ randomize_class=False,
830
+ cond_fn_with_grad=False,
831
+ dump_steps=None,
832
+ const_noise=False,
833
+ ):
834
+ """
835
+ Generate samples from the model using DDIM.
836
+
837
+ Same usage as p_sample_loop().
838
+ """
839
+ if dump_steps is not None:
840
+ raise NotImplementedError()
841
+ if const_noise == True:
842
+ raise NotImplementedError()
843
+
844
+ final = None
845
+ for sample in self.ddim_sample_loop_progressive(
846
+ model,
847
+ shape,
848
+ noise=noise,
849
+ clip_denoised=clip_denoised,
850
+ denoised_fn=denoised_fn,
851
+ cond_fn=cond_fn,
852
+ model_kwargs=model_kwargs,
853
+ device=device,
854
+ progress=progress,
855
+ eta=eta,
856
+ skip_timesteps=skip_timesteps,
857
+ init_image=init_image,
858
+ randomize_class=randomize_class,
859
+ cond_fn_with_grad=cond_fn_with_grad,
860
+ ):
861
+ final = sample
862
+ return final["pred_xstart"]
863
+
864
+ def ddim_sample_loop_progressive(
865
+ self,
866
+ model,
867
+ shape,
868
+ noise=None,
869
+ clip_denoised=True,
870
+ denoised_fn=None,
871
+ cond_fn=None,
872
+ model_kwargs=None,
873
+ device=None,
874
+ progress=False,
875
+ eta=0.0,
876
+ skip_timesteps=0,
877
+ init_image=None,
878
+ randomize_class=False,
879
+ cond_fn_with_grad=False,
880
+ ):
881
+ """
882
+ Use DDIM to sample from the model and yield intermediate samples from
883
+ each timestep of DDIM.
884
+
885
+ Same usage as p_sample_loop_progressive().
886
+ """
887
+ if device is None:
888
+ device = next(model.parameters()).device
889
+ assert isinstance(shape, (tuple, list))
890
+ if noise is not None:
891
+ img = noise
892
+ else:
893
+ img = th.randn(*shape, device=device)
894
+
895
+ if skip_timesteps and init_image is None:
896
+ init_image = th.zeros_like(img)
897
+
898
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
899
+
900
+ if init_image is not None:
901
+ my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
902
+ img = self.q_sample(init_image, my_t, img)
903
+
904
+ if progress:
905
+ # Lazy import so that we don't depend on tqdm.
906
+ from tqdm.auto import tqdm
907
+
908
+ indices = tqdm(indices)
909
+
910
+ for i in indices:
911
+ t = th.tensor([i] * shape[0], device=device)
912
+ if randomize_class and "y" in model_kwargs:
913
+ model_kwargs["y"] = th.randint(
914
+ low=0,
915
+ high=model.num_classes,
916
+ size=model_kwargs["y"].shape,
917
+ device=model_kwargs["y"].device,
918
+ )
919
+ with th.no_grad():
920
+ sample_fn = (
921
+ self.ddim_sample_with_grad
922
+ if cond_fn_with_grad
923
+ else self.ddim_sample
924
+ )
925
+ out = sample_fn(
926
+ model,
927
+ img,
928
+ t,
929
+ clip_denoised=clip_denoised,
930
+ denoised_fn=denoised_fn,
931
+ cond_fn=cond_fn,
932
+ model_kwargs=model_kwargs,
933
+ eta=eta,
934
+ )
935
+ yield out
936
+ img = out["sample"]
937
+
938
+ def plms_sample(
939
+ self,
940
+ model,
941
+ x,
942
+ t,
943
+ clip_denoised=True,
944
+ denoised_fn=None,
945
+ cond_fn=None,
946
+ model_kwargs=None,
947
+ cond_fn_with_grad=False,
948
+ order=2,
949
+ old_out=None,
950
+ ):
951
+ """
952
+ Sample x_{t-1} from the model using Pseudo Linear Multistep.
953
+
954
+ Same usage as p_sample().
955
+ """
956
+ if not int(order) or not 1 <= order <= 4:
957
+ raise ValueError("order is invalid (should be int from 1-4).")
958
+
959
+ def get_model_output(x, t):
960
+ with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None):
961
+ x = x.detach().requires_grad_() if cond_fn_with_grad else x
962
+ out_orig = self.p_mean_variance(
963
+ model,
964
+ x,
965
+ t,
966
+ clip_denoised=clip_denoised,
967
+ denoised_fn=denoised_fn,
968
+ model_kwargs=model_kwargs,
969
+ )
970
+ if cond_fn is not None:
971
+ if cond_fn_with_grad:
972
+ out = self.condition_score_with_grad(
973
+ cond_fn, out_orig, x, t, model_kwargs=model_kwargs
974
+ )
975
+ x = x.detach()
976
+ else:
977
+ out = self.condition_score(
978
+ cond_fn, out_orig, x, t, model_kwargs=model_kwargs
979
+ )
980
+ else:
981
+ out = out_orig
982
+
983
+ # Usually our model outputs epsilon, but we re-derive it
984
+ # in case we used x_start or x_prev prediction.
985
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
986
+ return eps, out, out_orig
987
+
988
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
989
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
990
+ eps, out, out_orig = get_model_output(x, t)
991
+
992
+ if order > 1 and old_out is None:
993
+ # Pseudo Improved Euler
994
+ old_eps = [eps]
995
+ mean_pred = (
996
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
997
+ + th.sqrt(1 - alpha_bar_prev) * eps
998
+ )
999
+ eps_2, _, _ = get_model_output(mean_pred, t - 1)
1000
+ eps_prime = (eps + eps_2) / 2
1001
+ pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
1002
+ mean_pred = (
1003
+ pred_prime * th.sqrt(alpha_bar_prev)
1004
+ + th.sqrt(1 - alpha_bar_prev) * eps_prime
1005
+ )
1006
+ else:
1007
+ # Pseudo Linear Multistep (Adams-Bashforth)
1008
+ old_eps = old_out["old_eps"]
1009
+ old_eps.append(eps)
1010
+ cur_order = min(order, len(old_eps))
1011
+ if cur_order == 1:
1012
+ eps_prime = old_eps[-1]
1013
+ elif cur_order == 2:
1014
+ eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2
1015
+ elif cur_order == 3:
1016
+ eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12
1017
+ elif cur_order == 4:
1018
+ eps_prime = (
1019
+ 55 * old_eps[-1]
1020
+ - 59 * old_eps[-2]
1021
+ + 37 * old_eps[-3]
1022
+ - 9 * old_eps[-4]
1023
+ ) / 24
1024
+ else:
1025
+ raise RuntimeError("cur_order is invalid.")
1026
+ pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
1027
+ mean_pred = (
1028
+ pred_prime * th.sqrt(alpha_bar_prev)
1029
+ + th.sqrt(1 - alpha_bar_prev) * eps_prime
1030
+ )
1031
+
1032
+ if len(old_eps) >= order:
1033
+ old_eps.pop(0)
1034
+
1035
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
1036
+ sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask)
1037
+
1038
+ return {
1039
+ "sample": sample,
1040
+ "pred_xstart": out_orig["pred_xstart"],
1041
+ "old_eps": old_eps,
1042
+ }
1043
+
1044
+ def plms_sample_loop(
1045
+ self,
1046
+ model,
1047
+ shape,
1048
+ noise=None,
1049
+ clip_denoised=True,
1050
+ denoised_fn=None,
1051
+ cond_fn=None,
1052
+ model_kwargs=None,
1053
+ device=None,
1054
+ progress=False,
1055
+ skip_timesteps=0,
1056
+ init_image=None,
1057
+ randomize_class=False,
1058
+ cond_fn_with_grad=False,
1059
+ order=2,
1060
+ ):
1061
+ """
1062
+ Generate samples from the model using Pseudo Linear Multistep.
1063
+
1064
+ Same usage as p_sample_loop().
1065
+ """
1066
+ final = None
1067
+ for sample in self.plms_sample_loop_progressive(
1068
+ model,
1069
+ shape,
1070
+ noise=noise,
1071
+ clip_denoised=clip_denoised,
1072
+ denoised_fn=denoised_fn,
1073
+ cond_fn=cond_fn,
1074
+ model_kwargs=model_kwargs,
1075
+ device=device,
1076
+ progress=progress,
1077
+ skip_timesteps=skip_timesteps,
1078
+ init_image=init_image,
1079
+ randomize_class=randomize_class,
1080
+ cond_fn_with_grad=cond_fn_with_grad,
1081
+ order=order,
1082
+ ):
1083
+ final = sample
1084
+ return final["sample"]
1085
+
1086
+ def plms_sample_loop_progressive(
1087
+ self,
1088
+ model,
1089
+ shape,
1090
+ noise=None,
1091
+ clip_denoised=True,
1092
+ denoised_fn=None,
1093
+ cond_fn=None,
1094
+ model_kwargs=None,
1095
+ device=None,
1096
+ progress=False,
1097
+ skip_timesteps=0,
1098
+ init_image=None,
1099
+ randomize_class=False,
1100
+ cond_fn_with_grad=False,
1101
+ order=2,
1102
+ ):
1103
+ """
1104
+ Use PLMS to sample from the model and yield intermediate samples from each
1105
+ timestep of PLMS.
1106
+
1107
+ Same usage as p_sample_loop_progressive().
1108
+ """
1109
+ if device is None:
1110
+ device = next(model.parameters()).device
1111
+ assert isinstance(shape, (tuple, list))
1112
+ if noise is not None:
1113
+ img = noise
1114
+ else:
1115
+ img = th.randn(*shape, device=device)
1116
+
1117
+ if skip_timesteps and init_image is None:
1118
+ init_image = th.zeros_like(img)
1119
+
1120
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
1121
+
1122
+ if init_image is not None:
1123
+ my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
1124
+ img = self.q_sample(init_image, my_t, img)
1125
+
1126
+ if progress:
1127
+ # Lazy import so that we don't depend on tqdm.
1128
+ from tqdm.auto import tqdm
1129
+
1130
+ indices = tqdm(indices)
1131
+
1132
+ old_out = None
1133
+
1134
+ for i in indices:
1135
+ t = th.tensor([i] * shape[0], device=device)
1136
+ if randomize_class and "y" in model_kwargs:
1137
+ model_kwargs["y"] = th.randint(
1138
+ low=0,
1139
+ high=model.num_classes,
1140
+ size=model_kwargs["y"].shape,
1141
+ device=model_kwargs["y"].device,
1142
+ )
1143
+ with th.no_grad():
1144
+ out = self.plms_sample(
1145
+ model,
1146
+ img,
1147
+ t,
1148
+ clip_denoised=clip_denoised,
1149
+ denoised_fn=denoised_fn,
1150
+ cond_fn=cond_fn,
1151
+ model_kwargs=model_kwargs,
1152
+ cond_fn_with_grad=cond_fn_with_grad,
1153
+ order=order,
1154
+ old_out=old_out,
1155
+ )
1156
+ yield out
1157
+ old_out = out
1158
+ img = out["sample"]
1159
+
1160
+ def _vb_terms_bpd(
1161
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
1162
+ ):
1163
+ """
1164
+ Get a term for the variational lower-bound.
1165
+
1166
+ The resulting units are bits (rather than nats, as one might expect).
1167
+ This allows for comparison to other papers.
1168
+
1169
+ :return: a dict with the following keys:
1170
+ - 'output': a shape [N] tensor of NLLs or KLs.
1171
+ - 'pred_xstart': the x_0 predictions.
1172
+ """
1173
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
1174
+ x_start=x_start, x_t=x_t, t=t
1175
+ )
1176
+ out = self.p_mean_variance(
1177
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
1178
+ )
1179
+ kl = normal_kl(
1180
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
1181
+ )
1182
+ kl = mean_flat(kl) / np.log(2.0)
1183
+
1184
+ decoder_nll = -discretized_gaussian_log_likelihood(
1185
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
1186
+ )
1187
+ assert decoder_nll.shape == x_start.shape
1188
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
1189
+
1190
+ # At the first timestep return the decoder NLL,
1191
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
1192
+ output = th.where((t == 0), decoder_nll, kl)
1193
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
1194
+
1195
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
1196
+ """
1197
+ Compute training losses for a single timestep.
1198
+
1199
+ :param model: the model to evaluate loss on.
1200
+ :param x_start: the [N x C x ...] tensor of inputs.
1201
+ :param t: a batch of timestep indices.
1202
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1203
+ pass to the model. This can be used for conditioning.
1204
+ :param noise: if specified, the specific Gaussian noise to try to remove.
1205
+ :return: a dict with the key "loss" containing a tensor of shape [N].
1206
+ Some mean or variance settings may also have other keys.
1207
+ """
1208
+ mask = model_kwargs["y"]["mask"]
1209
+ if model_kwargs is None:
1210
+ model_kwargs = {}
1211
+ if noise is None:
1212
+ noise = th.randn_like(x_start)
1213
+ x_t = self.q_sample(
1214
+ x_start, t, noise=noise
1215
+ ) # use the formula to diffuse the starting tensor by t steps
1216
+ terms = {}
1217
+
1218
+ # set random dropout for conditioning in training
1219
+ model_kwargs["cond_drop_prob"] = 0.2
1220
+ model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
1221
+ target = {
1222
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
1223
+ x_start=x_start, x_t=x_t, t=t
1224
+ )[0],
1225
+ ModelMeanType.START_X: x_start,
1226
+ ModelMeanType.EPSILON: noise,
1227
+ }[self.model_mean_type]
1228
+
1229
+ model_output = model_output.permute(0, 2, 1).unsqueeze(2)
1230
+ assert model_output.shape == target.shape == x_start.shape
1231
+
1232
+ missing_mask = model_kwargs["y"]["missing"][..., 0]
1233
+ missing_mask = missing_mask.unsqueeze(1).unsqueeze(1)
1234
+ missing_mask = mask * missing_mask
1235
+ terms["rot_mse"] = self.masked_l2(target, model_output, missing_mask)
1236
+ if self.lambda_vel > 0.0:
1237
+ target_vel = target[..., 1:] - target[..., :-1]
1238
+ model_output_vel = model_output[..., 1:] - model_output[..., :-1]
1239
+ terms["vel_mse"] = self.masked_l2(
1240
+ target_vel,
1241
+ model_output_vel,
1242
+ mask[:, :, :, 1:],
1243
+ )
1244
+
1245
+ terms["loss"] = terms["rot_mse"] + (self.lambda_vel * terms.get("vel_mse", 0.0))
1246
+
1247
+ with th.no_grad():
1248
+ terms["vb"] = self._vb_terms_bpd(
1249
+ model,
1250
+ x_start,
1251
+ x_t,
1252
+ t,
1253
+ clip_denoised=False,
1254
+ model_kwargs=model_kwargs,
1255
+ )["output"]
1256
+
1257
+ return terms
1258
+
1259
+
1260
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1261
+ """
1262
+ Extract values from a 1-D numpy array for a batch of indices.
1263
+
1264
+ :param arr: the 1-D numpy array.
1265
+ :param timesteps: a tensor of indices into the array to extract.
1266
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1267
+ dimension equal to the length of timesteps.
1268
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1269
+ """
1270
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1271
+ while len(res.shape) < len(broadcast_shape):
1272
+ res = res[..., None]
1273
+ return res.expand(broadcast_shape)
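For reference, a minimal sketch of driving the forward process above, assuming the module is importable as diffusion.gaussian_diffusion (run from the repo root), a simple linear beta schedule, and a hypothetical [N, C, 1, T] pose tensor:

    import numpy as np
    import torch as th
    from diffusion.gaussian_diffusion import (
        GaussianDiffusion, LossType, ModelMeanType, ModelVarType,
    )

    # Hypothetical linear schedule; the real training script may build betas differently.
    betas = np.linspace(1e-4, 0.02, 1000)
    diffusion = GaussianDiffusion(
        betas=betas,
        model_mean_type=ModelMeanType.START_X,   # the model predicts x_0 directly
        model_var_type=ModelVarType.FIXED_SMALL,
        loss_type=LossType.MSE,
    )

    x_start = th.randn(4, 104, 1, 240)            # hypothetical [N, C, 1, T] pose batch
    t = th.randint(0, diffusion.num_timesteps, (4,))
    x_t = diffusion.q_sample(x_start, t)          # sample from q(x_t | x_0); same shape as x_start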
diffusion/losses.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ """
9
+ Helpers for various likelihood-based losses. These are ported from the original
10
+ Ho et al. diffusion models codebase:
11
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
12
+ """
13
+
14
+ import numpy as np
15
+ import torch as th
16
+
17
+
18
+ def normal_kl(mean1, logvar1, mean2, logvar2):
19
+ """
20
+ Compute the KL divergence between two gaussians.
21
+
22
+ Shapes are automatically broadcasted, so batches can be compared to
23
+ scalars, among other use cases.
24
+ """
25
+ tensor = None
26
+ for obj in (mean1, logvar1, mean2, logvar2):
27
+ if isinstance(obj, th.Tensor):
28
+ tensor = obj
29
+ break
30
+ assert tensor is not None, "at least one argument must be a Tensor"
31
+
32
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
33
+ # Tensors, but it does not work for th.exp().
34
+ logvar1, logvar2 = [
35
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
36
+ for x in (logvar1, logvar2)
37
+ ]
38
+
39
+ return 0.5 * (
40
+ -1.0
41
+ + logvar2
42
+ - logvar1
43
+ + th.exp(logvar1 - logvar2)
44
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
45
+ )
46
+
47
+
48
+ def approx_standard_normal_cdf(x):
49
+ """
50
+ A fast approximation of the cumulative distribution function of the
51
+ standard normal.
52
+ """
53
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
54
+
55
+
56
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
57
+ """
58
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
59
+ given image.
60
+
61
+ :param x: the target images. It is assumed that this was uint8 values,
62
+ rescaled to the range [-1, 1].
63
+ :param means: the Gaussian mean Tensor.
64
+ :param log_scales: the Gaussian log stddev Tensor.
65
+ :return: a tensor like x of log probabilities (in nats).
66
+ """
67
+ assert x.shape == means.shape == log_scales.shape
68
+ centered_x = x - means
69
+ inv_stdv = th.exp(-log_scales)
70
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
71
+ cdf_plus = approx_standard_normal_cdf(plus_in)
72
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
73
+ cdf_min = approx_standard_normal_cdf(min_in)
74
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
75
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
76
+ cdf_delta = cdf_plus - cdf_min
77
+ log_probs = th.where(
78
+ x < -0.999,
79
+ log_cdf_plus,
80
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
81
+ )
82
+ assert log_probs.shape == x.shape
83
+ return log_probs
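A short sketch of these helpers in isolation, assuming the file is importable as diffusion.losses; the division by log(2) mirrors the bits-per-dim conversion used in _vb_terms_bpd:

    import numpy as np
    import torch as th
    from diffusion.losses import discretized_gaussian_log_likelihood, normal_kl

    mean1, logvar1 = th.zeros(8), th.zeros(8)
    mean2, logvar2 = 0.1 * th.ones(8), th.zeros(8)
    kl_bits = normal_kl(mean1, logvar1, mean2, logvar2) / np.log(2.0)

    x = th.zeros(8)  # targets are assumed to be rescaled to [-1, 1]
    nll = -discretized_gaussian_log_likelihood(x, means=mean2, log_scales=0.5 * logvar2)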
diffusion/nn.py ADDED
@@ -0,0 +1,213 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ """
9
+ original code from
10
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
11
+ under an MIT license
12
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
13
+ """
14
+
15
+ """
16
+ Various utilities for neural networks.
17
+ """
18
+
19
+ import math
20
+
21
+ import torch as th
22
+ import torch.nn as nn
23
+
24
+
25
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
26
+ class SiLU(nn.Module):
27
+ def forward(self, x):
28
+ return x * th.sigmoid(x)
29
+
30
+
31
+ class GroupNorm32(nn.GroupNorm):
32
+ def forward(self, x):
33
+ return super().forward(x.float()).type(x.dtype)
34
+
35
+
36
+ def conv_nd(dims, *args, **kwargs):
37
+ """
38
+ Create a 1D, 2D, or 3D convolution module.
39
+ """
40
+ if dims == 1:
41
+ return nn.Conv1d(*args, **kwargs)
42
+ elif dims == 2:
43
+ return nn.Conv2d(*args, **kwargs)
44
+ elif dims == 3:
45
+ return nn.Conv3d(*args, **kwargs)
46
+ raise ValueError(f"unsupported dimensions: {dims}")
47
+
48
+
49
+ def linear(*args, **kwargs):
50
+ """
51
+ Create a linear module.
52
+ """
53
+ return nn.Linear(*args, **kwargs)
54
+
55
+
56
+ def avg_pool_nd(dims, *args, **kwargs):
57
+ """
58
+ Create a 1D, 2D, or 3D average pooling module.
59
+ """
60
+ if dims == 1:
61
+ return nn.AvgPool1d(*args, **kwargs)
62
+ elif dims == 2:
63
+ return nn.AvgPool2d(*args, **kwargs)
64
+ elif dims == 3:
65
+ return nn.AvgPool3d(*args, **kwargs)
66
+ raise ValueError(f"unsupported dimensions: {dims}")
67
+
68
+
69
+ def update_ema(target_params, source_params, rate=0.99):
70
+ """
71
+ Update target parameters to be closer to those of source parameters using
72
+ an exponential moving average.
73
+
74
+ :param target_params: the target parameter sequence.
75
+ :param source_params: the source parameter sequence.
76
+ :param rate: the EMA rate (closer to 1 means slower).
77
+ """
78
+ for targ, src in zip(target_params, source_params):
79
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
80
+
81
+
82
+ def zero_module(module):
83
+ """
84
+ Zero out the parameters of a module and return it.
85
+ """
86
+ for p in module.parameters():
87
+ p.detach().zero_()
88
+ return module
89
+
90
+
91
+ def scale_module(module, scale):
92
+ """
93
+ Scale the parameters of a module and return it.
94
+ """
95
+ for p in module.parameters():
96
+ p.detach().mul_(scale)
97
+ return module
98
+
99
+
100
+ def mean_flat(tensor):
101
+ """
102
+ Take the mean over all non-batch dimensions.
103
+ """
104
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
105
+
106
+
107
+ def sum_flat(tensor):
108
+ """
109
+ Take the sum over all non-batch dimensions.
110
+ """
111
+ return tensor.sum(dim=list(range(1, len(tensor.shape))))
112
+
113
+
114
+ def normalization(channels):
115
+ """
116
+ Make a standard normalization layer.
117
+
118
+ :param channels: number of input channels.
119
+ :return: an nn.Module for normalization.
120
+ """
121
+ return GroupNorm32(32, channels)
122
+
123
+
124
+ def timestep_embedding(timesteps, dim, max_period=10000):
125
+ """
126
+ Create sinusoidal timestep embeddings.
127
+
128
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
129
+ These may be fractional.
130
+ :param dim: the dimension of the output.
131
+ :param max_period: controls the minimum frequency of the embeddings.
132
+ :return: an [N x dim] Tensor of positional embeddings.
133
+ """
134
+ half = dim // 2
135
+ freqs = th.exp(
136
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
137
+ ).to(device=timesteps.device)
138
+ args = timesteps[:, None].float() * freqs[None]
139
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
140
+ if dim % 2:
141
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
142
+ return embedding
143
+
144
+
145
+ def checkpoint(func, inputs, params, flag):
146
+ """
147
+ Evaluate a function without caching intermediate activations, allowing for
148
+ reduced memory at the expense of extra compute in the backward pass.
149
+ :param func: the function to evaluate.
150
+ :param inputs: the argument sequence to pass to `func`.
151
+ :param params: a sequence of parameters `func` depends on but does not
152
+ explicitly take as arguments.
153
+ :param flag: if False, disable gradient checkpointing.
154
+ """
155
+ if flag:
156
+ args = tuple(inputs) + tuple(params)
157
+ return CheckpointFunction.apply(func, len(inputs), *args)
158
+ else:
159
+ return func(*inputs)
160
+
161
+
162
+ class CheckpointFunction(th.autograd.Function):
163
+ @staticmethod
164
+ @th.cuda.amp.custom_fwd
165
+ def forward(ctx, run_function, length, *args):
166
+ ctx.run_function = run_function
167
+ ctx.input_length = length
168
+ ctx.save_for_backward(*args)
169
+ with th.no_grad():
170
+ output_tensors = ctx.run_function(*args[:length])
171
+ return output_tensors
172
+
173
+ @staticmethod
174
+ @th.cuda.amp.custom_bwd
175
+ def backward(ctx, *output_grads):
176
+ args = list(ctx.saved_tensors)
177
+
178
+ # Filter for inputs that require grad. If none, exit early.
179
+ input_indices = [i for (i, x) in enumerate(args) if x.requires_grad]
180
+ if not input_indices:
181
+ return (None, None) + tuple(None for _ in args)
182
+
183
+ with th.enable_grad():
184
+ for i in input_indices:
185
+ if i < ctx.input_length:
186
+ # Not sure why the OAI code does this little
187
+ # dance. It might not be necessary.
188
+ args[i] = args[i].detach().requires_grad_()
189
+ args[i] = args[i].view_as(args[i])
190
+ output_tensors = ctx.run_function(*args[: ctx.input_length])
191
+
192
+ if isinstance(output_tensors, th.Tensor):
193
+ output_tensors = [output_tensors]
194
+
195
+ # Filter for outputs that require grad. If none, exit early.
196
+ out_and_grads = [
197
+ (o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad
198
+ ]
199
+ if not out_and_grads:
200
+ return (None, None) + tuple(None for _ in args)
201
+
202
+ # Compute gradients on the filtered tensors.
203
+ computed_grads = th.autograd.grad(
204
+ [o for (o, g) in out_and_grads],
205
+ [args[i] for i in input_indices],
206
+ [g for (o, g) in out_and_grads],
207
+ )
208
+
209
+ # Reassemble the complete gradient tuple.
210
+ input_grads = [None for _ in args]
211
+ for i, g in zip(input_indices, computed_grads):
212
+ input_grads[i] = g
213
+ return (None, None) + tuple(input_grads)
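A small sketch of the embedding helper above, assuming imports from diffusion.nn:

    import torch as th
    from diffusion.nn import mean_flat, timestep_embedding

    t = th.arange(16)                      # one timestep index per batch element
    emb = timestep_embedding(t, dim=128)   # -> [16, 128] sinusoidal features
    pooled = mean_flat(emb)                # -> [16], mean over the non-batch dimension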
diffusion/resample.py ADDED
@@ -0,0 +1,168 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ """
9
+ original code from
10
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
11
+ under an MIT license
12
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
13
+ """
14
+
15
+ from abc import ABC, abstractmethod
16
+
17
+ import numpy as np
18
+ import torch as th
19
+ import torch.distributed as dist
20
+
21
+
22
+ def create_named_schedule_sampler(name, diffusion):
23
+ """
24
+ Create a ScheduleSampler from a library of pre-defined samplers.
25
+
26
+ :param name: the name of the sampler.
27
+ :param diffusion: the diffusion object to sample for.
28
+ """
29
+ if name == "uniform":
30
+ return UniformSampler(diffusion)
31
+ elif name == "loss-second-moment":
32
+ return LossSecondMomentResampler(diffusion)
33
+ else:
34
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
35
+
36
+
37
+ class ScheduleSampler(ABC):
38
+ """
39
+ A distribution over timesteps in the diffusion process, intended to reduce
40
+ variance of the objective.
41
+
42
+ By default, samplers perform unbiased importance sampling, in which the
43
+ objective's mean is unchanged.
44
+ However, subclasses may override sample() to change how the resampled
45
+ terms are reweighted, allowing for actual changes in the objective.
46
+ """
47
+
48
+ @abstractmethod
49
+ def weights(self):
50
+ """
51
+ Get a numpy array of weights, one per diffusion step.
52
+
53
+ The weights needn't be normalized, but must be positive.
54
+ """
55
+
56
+ def sample(self, batch_size, device):
57
+ """
58
+ Importance-sample timesteps for a batch.
59
+
60
+ :param batch_size: the number of timesteps.
61
+ :param device: the torch device to save to.
62
+ :return: a tuple (timesteps, weights):
63
+ - timesteps: a tensor of timestep indices.
64
+ - weights: a tensor of weights to scale the resulting losses.
65
+ """
66
+ w = self.weights()
67
+ p = w / np.sum(w)
68
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
69
+ indices = th.from_numpy(indices_np).long().to(device)
70
+ weights_np = 1 / (len(p) * p[indices_np])
71
+ weights = th.from_numpy(weights_np).float().to(device)
72
+ return indices, weights
73
+
74
+
75
+ class UniformSampler(ScheduleSampler):
76
+ def __init__(self, diffusion):
77
+ self.diffusion = diffusion
78
+ self._weights = np.ones([diffusion.num_timesteps])
79
+
80
+ def weights(self):
81
+ return self._weights
82
+
83
+
84
+ class LossAwareSampler(ScheduleSampler):
85
+ def update_with_local_losses(self, local_ts, local_losses):
86
+ """
87
+ Update the reweighting using losses from a model.
88
+
89
+ Call this method from each rank with a batch of timesteps and the
90
+ corresponding losses for each of those timesteps.
91
+ This method will perform synchronization to make sure all of the ranks
92
+ maintain the exact same reweighting.
93
+
94
+ :param local_ts: an integer Tensor of timesteps.
95
+ :param local_losses: a 1D Tensor of losses.
96
+ """
97
+ batch_sizes = [
98
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
99
+ for _ in range(dist.get_world_size())
100
+ ]
101
+ dist.all_gather(
102
+ batch_sizes,
103
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
104
+ )
105
+
106
+ # Pad all_gather batches to be the maximum batch size.
107
+ batch_sizes = [x.item() for x in batch_sizes]
108
+ max_bs = max(batch_sizes)
109
+
110
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
111
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
112
+ dist.all_gather(timestep_batches, local_ts)
113
+ dist.all_gather(loss_batches, local_losses)
114
+ timesteps = [
115
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
116
+ ]
117
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
118
+ self.update_with_all_losses(timesteps, losses)
119
+
120
+ @abstractmethod
121
+ def update_with_all_losses(self, ts, losses):
122
+ """
123
+ Update the reweighting using losses from a model.
124
+
125
+ Sub-classes should override this method to update the reweighting
126
+ using losses from the model.
127
+
128
+ This method directly updates the reweighting without synchronizing
129
+ between workers. It is called by update_with_local_losses from all
130
+ ranks with identical arguments. Thus, it should have deterministic
131
+ behavior to maintain state across workers.
132
+
133
+ :param ts: a list of int timesteps.
134
+ :param losses: a list of float losses, one per timestep.
135
+ """
136
+
137
+
138
+ class LossSecondMomentResampler(LossAwareSampler):
139
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
140
+ self.diffusion = diffusion
141
+ self.history_per_term = history_per_term
142
+ self.uniform_prob = uniform_prob
143
+ self._loss_history = np.zeros(
144
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
145
+ )
146
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
147
+
148
+ def weights(self):
149
+ if not self._warmed_up():
150
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
151
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
152
+ weights /= np.sum(weights)
153
+ weights *= 1 - self.uniform_prob
154
+ weights += self.uniform_prob / len(weights)
155
+ return weights
156
+
157
+ def update_with_all_losses(self, ts, losses):
158
+ for t, loss in zip(ts, losses):
159
+ if self._loss_counts[t] == self.history_per_term:
160
+ # Shift out the oldest loss term.
161
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
162
+ self._loss_history[t, -1] = loss
163
+ else:
164
+ self._loss_history[t, self._loss_counts[t]] = loss
165
+ self._loss_counts[t] += 1
166
+
167
+ def _warmed_up(self):
168
+ return (self._loss_counts == self.history_per_term).all()
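A minimal sketch of drawing training timesteps with the uniform sampler, assuming the same hypothetical linear schedule and import paths as the earlier sketches; the loss-second-moment variant additionally needs torch.distributed to be initialized:

    import numpy as np
    import torch as th
    from diffusion.gaussian_diffusion import (
        GaussianDiffusion, LossType, ModelMeanType, ModelVarType,
    )
    from diffusion.resample import create_named_schedule_sampler

    diffusion = GaussianDiffusion(
        betas=np.linspace(1e-4, 0.02, 1000),
        model_mean_type=ModelMeanType.START_X,
        model_var_type=ModelVarType.FIXED_SMALL,
        loss_type=LossType.MSE,
    )
    sampler = create_named_schedule_sampler("uniform", diffusion)
    t, weights = sampler.sample(batch_size=32, device=th.device("cpu"))
    # t picks a diffusion step per batch element; weights rescale the per-sample training loss.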
diffusion/respace.py ADDED
@@ -0,0 +1,145 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ """
9
+ original code from
10
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/diffusion/gaussian_diffusion.py
11
+ under an MIT license
12
+ https://github.com/GuyTevet/motion-diffusion-model/blob/main/LICENSE
13
+ """
14
+
15
+ import numpy as np
16
+ import torch as th
17
+
18
+ from .gaussian_diffusion import GaussianDiffusion
19
+
20
+
21
+ def space_timesteps(num_timesteps, section_counts):
22
+ """
23
+ Create a list of timesteps to use from an original diffusion process,
24
+ given the number of timesteps we want to take from equally-sized portions
25
+ of the original process.
26
+
27
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
28
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
29
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
30
+
31
+ If the stride is a string starting with "ddim", then the fixed striding
32
+ from the DDIM paper is used, and only one section is allowed.
33
+
34
+ :param num_timesteps: the number of diffusion steps in the original
35
+ process to divide up.
36
+ :param section_counts: either a list of numbers, or a string containing
37
+ comma-separated numbers, indicating the step count
38
+ per section. As a special case, use "ddimN" where N
39
+ is a number of steps to use the striding from the
40
+ DDIM paper.
41
+ :return: a set of diffusion steps from the original process to use.
42
+ """
43
+ if isinstance(section_counts, str):
44
+ if section_counts.startswith("ddim"):
45
+ desired_count = int(section_counts[len("ddim") :])
46
+ for i in range(1, num_timesteps):
47
+ if len(range(0, num_timesteps, i)) == desired_count:
48
+ return set(range(0, num_timesteps, i))
49
+ raise ValueError(
50
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
51
+ )
52
+ section_counts = [int(x) for x in section_counts.split(",")]
53
+ size_per = num_timesteps // len(section_counts)
54
+ extra = num_timesteps % len(section_counts)
55
+ start_idx = 0
56
+ all_steps = []
57
+ for i, section_count in enumerate(section_counts):
58
+ size = size_per + (1 if i < extra else 0)
59
+ if size < section_count:
60
+ raise ValueError(
61
+ f"cannot divide section of {size} steps into {section_count}"
62
+ )
63
+ if section_count <= 1:
64
+ frac_stride = 1
65
+ else:
66
+ frac_stride = (size - 1) / (section_count - 1)
67
+ cur_idx = 0.0
68
+ taken_steps = []
69
+ for _ in range(section_count):
70
+ taken_steps.append(start_idx + round(cur_idx))
71
+ cur_idx += frac_stride
72
+ all_steps += taken_steps
73
+ start_idx += size
74
+ return set(all_steps)
75
+
76
+
77
+ class SpacedDiffusion(GaussianDiffusion):
78
+ """
79
+ A diffusion process which can skip steps in a base diffusion process.
80
+
81
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
82
+ original diffusion process to retain.
83
+ :param kwargs: the kwargs to create the base diffusion process.
84
+ """
85
+
86
+ def __init__(self, use_timesteps, **kwargs):
87
+ self.use_timesteps = set(use_timesteps)
88
+ self.timestep_map = []
89
+ self.original_num_steps = len(kwargs["betas"])
90
+
91
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
92
+ last_alpha_cumprod = 1.0
93
+ new_betas = []
94
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
95
+ if i in self.use_timesteps:
96
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
97
+ last_alpha_cumprod = alpha_cumprod
98
+ self.timestep_map.append(i)
99
+ kwargs["betas"] = np.array(new_betas)
100
+ super().__init__(**kwargs)
101
+
102
+ def p_mean_variance(
103
+ self, model, *args, **kwargs
104
+ ): # pylint: disable=signature-differs
105
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
106
+
107
+ def training_losses(
108
+ self, model, *args, **kwargs
109
+ ): # pylint: disable=signature-differs
110
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
111
+
112
+ def condition_mean(self, cond_fn, *args, **kwargs):
113
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
114
+
115
+ def condition_score(self, cond_fn, *args, **kwargs):
116
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
117
+
118
+ def _wrap_model(self, model):
119
+ if isinstance(model, _WrappedModel):
120
+ return model
121
+ return _WrappedModel(
122
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
123
+ )
124
+
125
+ def _scale_timesteps(self, t):
126
+ # Scaling is done by the wrapped model.
127
+ return t
128
+
129
+
130
+ class _WrappedModel:
131
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
132
+ self.model = model
133
+ if hasattr(model, "step"):
134
+ self.step = model.step
135
+ self.add_frame_cond = model.add_frame_cond
136
+ self.timestep_map = timestep_map
137
+ self.rescale_timesteps = rescale_timesteps
138
+ self.original_num_steps = original_num_steps
139
+
140
+ def __call__(self, x, ts, **kwargs):
141
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
142
+ new_ts = map_tensor[ts]
143
+ if self.rescale_timesteps:
144
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
145
+ return self.model(x, new_ts, **kwargs)
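The spacing utility above is easiest to see on concrete inputs. A minimal sketch, assuming only the diffusion.respace module added in this commit; the step counts here are illustrative, not the repo's defaults:

from diffusion.respace import space_timesteps

# 100-step base process, evenly re-spaced to 10 steps (a single section)
print(sorted(space_timesteps(100, [10])))      # [0, 11, 22, ..., 99]

# fixed DDIM striding: exactly 25 evenly strided steps
print(sorted(space_timesteps(100, "ddim25")))  # [0, 4, 8, ..., 96]

A SpacedDiffusion built from such a set keeps only the retained betas and remaps timesteps through _WrappedModel, so sampling can run with far fewer steps than the base process was trained with.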
flagged/audio/b90d90dbca93f47e8d01/audio.wav ADDED
Binary file (696 kB). View file
 
flagged/audio/d8e03e2e6deae2f981b1/audio.wav ADDED
Binary file (696 kB). View file
 
flagged/log.csv ADDED
@@ -0,0 +1,4 @@
1
+ audio,Number of Samples (default = 3),Sample Diversity (default = 0.97),output 0,output 1,output 2,output 3,output 4,output 5,output 6,output 7,output 8,output 9,flag,username,timestamp
2
+ ,1,0.69,,,,,,,,,,,,,2024-07-15 05:46:49.672259
3
+ flagged/audio/d8e03e2e6deae2f981b1/audio.wav,1,0.69,,,,,,,,,,,,,2024-07-15 06:28:21.003877
4
+ flagged/audio/b90d90dbca93f47e8d01/audio.wav,1,0.69,,,,,,,,,,,,,2024-07-15 06:28:24.442449
model/cfg_sampler.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ from copy import deepcopy
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ # A wrapper model for Classifier-free guidance **SAMPLING** only
16
+ # https://arxiv.org/abs/2207.12598
17
+ class ClassifierFreeSampleModel(nn.Module):
18
+ def __init__(self, model):
19
+ super().__init__()
20
+ self.model = model # model is the actual model to run
21
+ self.nfeats = self.model.nfeats
22
+ self.cond_mode = self.model.cond_mode
23
+ self.add_frame_cond = self.model.add_frame_cond
24
+ if self.add_frame_cond is not None:
25
+ if self.model.resume_trans is not None:
26
+ self.transformer = self.model.transformer
27
+ self.tokenizer = self.model.tokenizer
28
+ self.step = self.model.step
29
+
30
+ def forward(self, x, timesteps, y=None):
31
+ out = self.model(x, timesteps, y, cond_drop_prob=0.0)
32
+ out_uncond = self.model(x, timesteps, y, cond_drop_prob=1.0)
33
+ return out_uncond + (y["scale"].view(-1, 1, 1) * (out - out_uncond))
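For reference, a self-contained sketch of how this wrapper mixes the two passes; DummyModel below is hypothetical and only mimics the attributes the wrapper reads, it is not the repo's FiLMTransformer:

import torch
import torch.nn as nn
from model.cfg_sampler import ClassifierFreeSampleModel

class DummyModel(nn.Module):
    nfeats = 8
    cond_mode = "audio"
    add_frame_cond = None   # face-style model: no keyframe conditioning
    resume_trans = None     # defensive: attributes the wrapper may touch
    step = 30
    def forward(self, x, timesteps, y=None, cond_drop_prob=0.0):
        # pretend the unconditional pass (cond_drop_prob=1.0) predicts zeros
        return torch.zeros_like(x) if cond_drop_prob == 1.0 else x

sampler = ClassifierFreeSampleModel(DummyModel())
x = torch.randn(2, 16, 8)
y = {"scale": torch.full((2,), 2.0)}
guided = sampler(x, torch.zeros(2), y)  # = out_uncond + 2.0 * (out - out_uncond) = 2 * x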
model/diffusion.py ADDED
@@ -0,0 +1,403 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import json
9
+ from typing import Callable, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from einops import rearrange
14
+ from einops.layers.torch import Rearrange
15
+
16
+ from model.guide import GuideTransformer
17
+ from model.modules.audio_encoder import Wav2VecEncoder
18
+ from model.modules.rotary_embedding_torch import RotaryEmbedding
19
+ from model.modules.transformer_modules import (
20
+ DecoderLayerStack,
21
+ FiLMTransformerDecoderLayer,
22
+ RegressionTransformer,
23
+ TransformerEncoderLayerRotary,
24
+ )
25
+ from model.utils import (
26
+ init_weight,
27
+ PositionalEncoding,
28
+ prob_mask_like,
29
+ setup_lip_regressor,
30
+ SinusoidalPosEmb,
31
+ )
32
+ from model.vqvae import setup_tokenizer
33
+ from torch.nn import functional as F
34
+ from utils.misc import prGreen, prRed
35
+
36
+
37
+ class Audio2LipRegressionTransformer(torch.nn.Module):
38
+ def __init__(
39
+ self,
40
+ n_vertices: int = 338,
41
+ causal: bool = False,
42
+ train_wav2vec: bool = False,
43
+ transformer_encoder_layers: int = 2,
44
+ transformer_decoder_layers: int = 4,
45
+ ):
46
+ super().__init__()
47
+ self.n_vertices = n_vertices
48
+
49
+ self.audio_encoder = Wav2VecEncoder()
50
+ if not train_wav2vec:
51
+ self.audio_encoder.eval()
52
+ for param in self.audio_encoder.parameters():
53
+ param.requires_grad = False
54
+
55
+ self.regression_model = RegressionTransformer(
56
+ transformer_encoder_layers=transformer_encoder_layers,
57
+ transformer_decoder_layers=transformer_decoder_layers,
58
+ d_model=512,
59
+ d_cond=512,
60
+ num_heads=4,
61
+ causal=causal,
62
+ )
63
+ self.project_output = torch.nn.Linear(512, self.n_vertices * 3)
64
+
65
+ def forward(self, audio):
66
+ """
67
+ :param audio: tensor of shape B x T x 1600
68
+ :return: tensor of shape B x T x n_vertices x 3 containing reconstructed lip geometry
69
+ """
70
+ B, T = audio.shape[0], audio.shape[1]
71
+
72
+ cond = self.audio_encoder(audio)
73
+
74
+ x = torch.zeros(B, T, 512, device=audio.device)
75
+ x = self.regression_model(x, cond)
76
+ x = self.project_output(x)
77
+
78
+ verts = x.view(B, T, self.n_vertices, 3)
79
+ return verts
80
+
81
+
82
+ class FiLMTransformer(nn.Module):
83
+ def __init__(
84
+ self,
85
+ args,
86
+ nfeats: int,
87
+ latent_dim: int = 512,
88
+ ff_size: int = 1024,
89
+ num_layers: int = 4,
90
+ num_heads: int = 4,
91
+ dropout: float = 0.1,
92
+ cond_feature_dim: int = 4800,
93
+ activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
94
+ use_rotary: bool = True,
95
+ cond_mode: str = "audio",
96
+ split_type: str = "train",
97
+ device: str = "cuda",
98
+ **kwargs,
99
+ ) -> None:
100
+ super().__init__()
101
+ self.nfeats = nfeats
102
+ self.cond_mode = cond_mode
103
+ self.cond_feature_dim = cond_feature_dim
104
+ self.add_frame_cond = args.add_frame_cond
105
+ self.data_format = args.data_format
106
+ self.split_type = split_type
107
+ self.device = device
108
+
109
+ # positional embeddings
110
+ self.rotary = None
111
+ self.abs_pos_encoding = nn.Identity()
112
+ # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
113
+ if use_rotary:
114
+ self.rotary = RotaryEmbedding(dim=latent_dim)
115
+ else:
116
+ self.abs_pos_encoding = PositionalEncoding(
117
+ latent_dim, dropout, batch_first=True
118
+ )
119
+
120
+ # time embedding processing
121
+ self.time_mlp = nn.Sequential(
122
+ SinusoidalPosEmb(latent_dim),
123
+ nn.Linear(latent_dim, latent_dim * 4),
124
+ nn.Mish(),
125
+ )
126
+ self.to_time_cond = nn.Sequential(
127
+ nn.Linear(latent_dim * 4, latent_dim),
128
+ )
129
+ self.to_time_tokens = nn.Sequential(
130
+ nn.Linear(latent_dim * 4, latent_dim * 2),
131
+ Rearrange("b (r d) -> b r d", r=2),
132
+ )
133
+
134
+ # null embeddings for guidance dropout
135
+ self.seq_len = args.max_seq_length
136
+ emb_len = 1998 # hardcoded for now
137
+ self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, latent_dim))
138
+ self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
139
+ self.norm_cond = nn.LayerNorm(latent_dim)
140
+ self.setup_audio_models()
141
+
142
+ # set up pose/face specific parts of the model
143
+ self.input_projection = nn.Linear(self.nfeats, latent_dim)
144
+ if self.data_format == "pose":
145
+ cond_feature_dim = 1024
146
+ key_feature_dim = 104
147
+ self.step = 30
148
+ self.use_cm = True
149
+ self.setup_guide_models(args, latent_dim, key_feature_dim)
150
+ self.post_pose_layers = self._build_single_pose_conv(self.nfeats)
151
+ self.post_pose_layers.apply(init_weight)
152
+ self.final_conv = torch.nn.Conv1d(self.nfeats, self.nfeats, kernel_size=1)
153
+ self.receptive_field = 25
154
+ elif self.data_format == "face":
155
+ self.use_cm = False
156
+ cond_feature_dim = 1024 + 1014
157
+ self.setup_lip_models()
158
+ self.cond_encoder = nn.Sequential()
159
+ for _ in range(2):
160
+ self.cond_encoder.append(
161
+ TransformerEncoderLayerRotary(
162
+ d_model=latent_dim,
163
+ nhead=num_heads,
164
+ dim_feedforward=ff_size,
165
+ dropout=dropout,
166
+ activation=activation,
167
+ batch_first=True,
168
+ rotary=self.rotary,
169
+ )
170
+ )
171
+ self.cond_encoder.apply(init_weight)
172
+
173
+ self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
174
+ self.non_attn_cond_projection = nn.Sequential(
175
+ nn.LayerNorm(latent_dim),
176
+ nn.Linear(latent_dim, latent_dim),
177
+ nn.SiLU(),
178
+ nn.Linear(latent_dim, latent_dim),
179
+ )
180
+
181
+ # decoder
182
+ decoderstack = nn.ModuleList([])
183
+ for _ in range(num_layers):
184
+ decoderstack.append(
185
+ FiLMTransformerDecoderLayer(
186
+ latent_dim,
187
+ num_heads,
188
+ dim_feedforward=ff_size,
189
+ dropout=dropout,
190
+ activation=activation,
191
+ batch_first=True,
192
+ rotary=self.rotary,
193
+ use_cm=self.use_cm,
194
+ )
195
+ )
196
+ self.seqTransDecoder = DecoderLayerStack(decoderstack)
197
+ self.seqTransDecoder.apply(init_weight)
198
+ self.final_layer = nn.Linear(latent_dim, self.nfeats)
199
+ self.final_layer.apply(init_weight)
200
+
201
+ def _build_single_pose_conv(self, nfeats: int) -> nn.ModuleList:
202
+ post_pose_layers = torch.nn.ModuleList(
203
+ [
204
+ torch.nn.Conv1d(nfeats, max(256, nfeats), kernel_size=3, dilation=1),
205
+ torch.nn.Conv1d(max(256, nfeats), nfeats, kernel_size=3, dilation=2),
206
+ torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
207
+ torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=1),
208
+ torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=2),
209
+ torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
210
+ ]
211
+ )
212
+ return post_pose_layers
213
+
214
+ def _run_single_pose_conv(self, output: torch.Tensor) -> torch.Tensor:
215
+ output = torch.nn.functional.pad(output, pad=[self.receptive_field - 1, 0])
216
+ for _, layer in enumerate(self.post_pose_layers):
217
+ y = torch.nn.functional.leaky_relu(layer(output), negative_slope=0.2)
218
+ if self.split_type == "train":
219
+ y = torch.nn.functional.dropout(y, 0.2)
220
+ if output.shape[1] == y.shape[1]:
221
+ output = (output[:, :, -y.shape[-1] :] + y) / 2.0 # skip connection
222
+ else:
223
+ output = y
224
+ return output
225
+
226
+ def setup_guide_models(self, args, latent_dim: int, key_feature_dim: int) -> None:
227
+ # set up conditioning info
228
+ max_keyframe_len = len(list(range(self.seq_len))[:: self.step])
229
+ self.null_pose_embed = nn.Parameter(
230
+ torch.randn(1, max_keyframe_len, latent_dim)
231
+ )
232
+ prGreen(f"using keyframes: {self.null_pose_embed.shape}")
233
+ self.frame_cond_projection = nn.Linear(key_feature_dim, latent_dim)
234
+ self.frame_norm_cond = nn.LayerNorm(latent_dim)
235
+ # for test time set up keyframe transformer
236
+ self.resume_trans = None
237
+ if self.split_type == "test":
238
+ if hasattr(args, "resume_trans") and args.resume_trans is not None:
239
+ self.resume_trans = args.resume_trans
240
+ self.setup_guide_predictor(args.resume_trans)
241
+ else:
242
+ prRed("not using transformer, just using ground truth")
243
+
244
+ def setup_guide_predictor(self, cp_path: str) -> None:
245
+ cp_dir = cp_path.split("checkpoints/iter-")[0]
246
+ with open(f"{cp_dir}/args.json") as f:
247
+ trans_args = json.load(f)
248
+
249
+ # set up tokenizer based on trans_arg load point
250
+ self.tokenizer = setup_tokenizer(trans_args["resume_pth"])
251
+
252
+ # set up transformer
253
+ self.transformer = GuideTransformer(
254
+ tokens=self.tokenizer.n_clusters,
255
+ num_layers=trans_args["layers"],
256
+ dim=trans_args["dim"],
257
+ emb_len=1998,
258
+ num_audio_layers=trans_args["num_audio_layers"],
259
+ )
260
+ for param in self.transformer.parameters():
261
+ param.requires_grad = False
262
+ prGreen("loading TRANSFORMER checkpoint from {}".format(cp_path))
263
+ cp = torch.load(cp_path)
264
+ missing_keys, unexpected_keys = self.transformer.load_state_dict(
265
+ cp["model_state_dict"], strict=False
266
+ )
267
+ assert len(missing_keys) == 0, missing_keys
268
+ assert len(unexpected_keys) == 0, unexpected_keys
269
+
270
+ def setup_audio_models(self) -> None:
271
+ self.audio_model, self.audio_resampler = setup_lip_regressor()
272
+
273
+ def setup_lip_models(self) -> None:
274
+ self.lip_model = Audio2LipRegressionTransformer()
275
+ cp_path = "./assets/iter-0200000.pt"
276
+ cp = torch.load(cp_path, map_location=torch.device(self.device))
277
+ self.lip_model.load_state_dict(cp["model_state_dict"])
278
+ for param in self.lip_model.parameters():
279
+ param.requires_grad = False
280
+ prGreen(f"adding lip conditioning {cp_path}")
281
+
282
+ def parameters_w_grad(self):
283
+ return [p for p in self.parameters() if p.requires_grad]
284
+
285
+ def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
286
+ device = next(self.parameters()).device
287
+ a0 = self.audio_resampler(raw_audio[:, :, 0].to(device))
288
+ a1 = self.audio_resampler(raw_audio[:, :, 1].to(device))
289
+ with torch.no_grad():
290
+ z0 = self.audio_model.feature_extractor(a0)
291
+ z1 = self.audio_model.feature_extractor(a1)
292
+ emb = torch.cat((z0, z1), axis=1).permute(0, 2, 1)
293
+ return emb
294
+
295
+ def encode_lip(self, audio: torch.Tensor, cond_embed: torch.Tensor) -> torch.Tensor:
296
+ reshaped_audio = audio.reshape((audio.shape[0], -1, 1600, 2))[..., 0]
297
+ # processes 4 seconds at a time
298
+ B, T, _ = reshaped_audio.shape
299
+ lip_cond = torch.zeros(
300
+ (audio.shape[0], T, 338, 3),
301
+ device=audio.device,
302
+ dtype=audio.dtype,
303
+ )
304
+ for i in range(0, T, 120):
305
+ lip_cond[:, i : i + 120, ...] = self.lip_model(
306
+ reshaped_audio[:, i : i + 120, ...]
307
+ )
308
+ lip_cond = lip_cond.permute(0, 2, 3, 1).reshape((B, 338 * 3, -1))
309
+ lip_cond = torch.nn.functional.interpolate(
310
+ lip_cond, size=cond_embed.shape[1], mode="nearest-exact"
311
+ ).permute(0, 2, 1)
312
+ cond_embed = torch.cat((cond_embed, lip_cond), dim=-1)
313
+ return cond_embed
314
+
315
+ def encode_keyframes(
316
+ self, y: torch.Tensor, cond_drop_prob: float, batch_size: int
317
+ ) -> torch.Tensor:
318
+ pred = y["keyframes"]
319
+ new_mask = y["mask"][..., :: self.step].squeeze((1, 2))
320
+ pred[~new_mask] = 0.0 # pad the unknown
321
+ pose_hidden = self.frame_cond_projection(pred.detach().clone().cuda())
322
+ pose_embed = self.abs_pos_encoding(pose_hidden)
323
+ pose_tokens = self.frame_norm_cond(pose_embed)
324
+ # do conditional dropout for guide poses
325
+ key_cond_drop_prob = cond_drop_prob
326
+ keep_mask_pose = prob_mask_like(
327
+ (batch_size,), 1 - key_cond_drop_prob, device=pose_tokens.device
328
+ )
329
+ keep_mask_pose_embed = rearrange(keep_mask_pose, "b -> b 1 1")
330
+ null_pose_embed = self.null_pose_embed.to(pose_tokens.dtype)
331
+ pose_tokens = torch.where(
332
+ keep_mask_pose_embed,
333
+ pose_tokens,
334
+ null_pose_embed[:, : pose_tokens.shape[1], :],
335
+ )
336
+ return pose_tokens
337
+
338
+ def forward(
339
+ self,
340
+ x: torch.Tensor,
341
+ times: torch.Tensor,
342
+ y: Optional[dict] = None,
343
+ cond_drop_prob: float = 0.0,
344
+ ) -> torch.Tensor:
345
+ if x.dim() == 4:
346
+ x = x.permute(0, 3, 1, 2).squeeze(-1)
347
+ batch_size, device = x.shape[0], x.device
348
+ if self.cond_mode == "uncond":
349
+ cond_embed = torch.zeros(
350
+ (x.shape[0], x.shape[1], self.cond_feature_dim),
351
+ dtype=x.dtype,
352
+ device=x.device,
353
+ )
354
+ else:
355
+ cond_embed = y["audio"]
356
+ cond_embed = self.encode_audio(cond_embed)
357
+ if self.data_format == "face":
358
+ cond_embed = self.encode_lip(y["audio"], cond_embed)
359
+ pose_tokens = None
360
+ if self.data_format == "pose":
361
+ pose_tokens = self.encode_keyframes(y, cond_drop_prob, batch_size)
362
+ assert cond_embed is not None, "cond emb should not be none"
363
+ # process conditioning information
364
+ x = self.input_projection(x)
365
+ x = self.abs_pos_encoding(x)
366
+ audio_cond_drop_prob = cond_drop_prob
367
+ keep_mask = prob_mask_like(
368
+ (batch_size,), 1 - audio_cond_drop_prob, device=device
369
+ )
370
+ keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
371
+ keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
372
+ cond_tokens = self.cond_projection(cond_embed)
373
+ cond_tokens = self.abs_pos_encoding(cond_tokens)
374
+ if self.data_format == "face":
375
+ cond_tokens = self.cond_encoder(cond_tokens)
376
+ null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
377
+ cond_tokens = torch.where(
378
+ keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
379
+ )
380
+ mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
381
+ cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
382
+
383
+ # create t conditioning
384
+ t_hidden = self.time_mlp(times)
385
+ t = self.to_time_cond(t_hidden)
386
+ t_tokens = self.to_time_tokens(t_hidden)
387
+ null_cond_hidden = self.null_cond_hidden.to(t.dtype)
388
+ cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
389
+ t += cond_hidden
390
+
391
+ # cross-attention conditioning
392
+ c = torch.cat((cond_tokens, t_tokens), dim=-2)
393
+ cond_tokens = self.norm_cond(c)
394
+
395
+ # Pass through the transformer decoder
396
+ output = self.seqTransDecoder(x, cond_tokens, t, memory2=pose_tokens)
397
+ output = self.final_layer(output)
398
+ if self.data_format == "pose":
399
+ output = output.permute(0, 2, 1)
400
+ output = self._run_single_pose_conv(output)
401
+ output = self.final_conv(output)
402
+ output = output.permute(0, 2, 1)
403
+ return output
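The null-embedding logic in forward() is what makes classifier-free guidance possible at sampling time. A self-contained sketch of that pattern with hypothetical tensors (not the repo's modules or shapes):

import torch
from einops import rearrange

batch, seq_len, dim = 4, 1998, 512
cond_tokens = torch.randn(batch, seq_len, dim)   # projected audio tokens
null_cond_embed = torch.randn(1, seq_len, dim)   # learned null embedding
cond_drop_prob = 0.2

keep_mask = torch.rand(batch) > cond_drop_prob   # True = keep the conditioning
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)

During training some batch elements therefore see the null embedding instead of audio, and at test time ClassifierFreeSampleModel exploits that by running the conditional and unconditional branches explicitly.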
model/guide.py ADDED
@@ -0,0 +1,222 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ from typing import Callable, List
9
+
10
+ import torch
11
+ import torch as th
12
+ import torch.nn as nn
13
+ from einops import rearrange
14
+ from model.modules.rotary_embedding_torch import RotaryEmbedding
15
+
16
+ from model.modules.transformer_modules import (
17
+ DecoderLayerStack,
18
+ FiLMTransformerDecoderLayer,
19
+ PositionalEncoding,
20
+ )
21
+ from model.utils import prob_mask_like, setup_lip_regressor
22
+ from torch.distributions import Categorical
23
+ from torch.nn import functional as F
24
+
25
+
26
+ class GuideTransformer(nn.Module):
27
+ def __init__(
28
+ self,
29
+ tokens: int,
30
+ num_heads: int = 4,
31
+ num_layers: int = 4,
32
+ dim: int = 512,
33
+ ff_size: int = 1024,
34
+ dropout: float = 0.1,
35
+ activation: Callable = F.gelu,
36
+ use_rotary: bool = True,
37
+ cond_feature_dim: int = 1024,
38
+ emb_len: int = 798,
39
+ num_audio_layers: int = 2,
40
+ ):
41
+ super().__init__()
42
+ self.tokens = tokens
43
+ self.token_embedding = th.nn.Embedding(
44
+ num_embeddings=tokens + 1, # account for sequence start and end tokens
45
+ embedding_dim=dim,
46
+ )
47
+ self.abs_pos_encoding = nn.Identity()
48
+ # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
49
+ if use_rotary:
50
+ self.rotary = RotaryEmbedding(dim=dim)
51
+ else:
52
+ self.abs_pos_encoding = PositionalEncoding(dim, dropout, batch_first=True)
53
+ self.setup_audio_models(cond_feature_dim, num_audio_layers)
54
+
55
+ self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, dim))
56
+ self.null_cond_hidden = nn.Parameter(torch.randn(1, dim))
57
+ self.norm_cond = nn.LayerNorm(dim)
58
+
59
+ self.cond_projection = nn.Linear(cond_feature_dim, dim)
60
+ self.non_attn_cond_projection = nn.Sequential(
61
+ nn.LayerNorm(dim),
62
+ nn.Linear(dim, dim),
63
+ nn.SiLU(),
64
+ nn.Linear(dim, dim),
65
+ )
66
+ # decoder
67
+ decoderstack = nn.ModuleList([])
68
+ for _ in range(num_layers):
69
+ decoderstack.append(
70
+ FiLMTransformerDecoderLayer(
71
+ dim,
72
+ num_heads,
73
+ dim_feedforward=ff_size,
74
+ dropout=dropout,
75
+ activation=activation,
76
+ batch_first=True,
77
+ rotary=self.rotary,
78
+ )
79
+ )
80
+ self.seqTransDecoder = DecoderLayerStack(decoderstack)
81
+ self.final_layer = nn.Linear(dim, tokens)
82
+
83
+ def _build_single_audio_conv(self, c: int) -> List[nn.Module]:
84
+ return [
85
+ torch.nn.Conv1d(c, max(256, c), kernel_size=3, dilation=1),
86
+ torch.nn.LeakyReLU(negative_slope=0.2),
87
+ torch.nn.Dropout(0.2),
88
+ #
89
+ torch.nn.Conv1d(max(256, c), max(256, c), kernel_size=3, dilation=2),
90
+ torch.nn.LeakyReLU(negative_slope=0.2),
91
+ torch.nn.Dropout(0.2),
92
+ #
93
+ torch.nn.Conv1d(max(128, c), max(128, c), kernel_size=3, dilation=3),
94
+ torch.nn.LeakyReLU(negative_slope=0.2),
95
+ torch.nn.Dropout(0.2),
96
+ #
97
+ torch.nn.Conv1d(max(128, c), c, kernel_size=3, dilation=1),
98
+ torch.nn.LeakyReLU(negative_slope=0.2),
99
+ torch.nn.Dropout(0.2),
100
+ #
101
+ torch.nn.Conv1d(c, c, kernel_size=3, dilation=2),
102
+ torch.nn.LeakyReLU(negative_slope=0.2),
103
+ torch.nn.Dropout(0.2),
104
+ #
105
+ torch.nn.Conv1d(c, c, kernel_size=3, dilation=3),
106
+ torch.nn.LeakyReLU(negative_slope=0.2),
107
+ torch.nn.Dropout(0.2),
108
+ ]
109
+
110
+ def setup_audio_models(self, cond_feature_dim: int, num_audio_layers: int) -> None:
111
+ pre_layers = []
112
+ for _ in range(num_audio_layers):
113
+ pre_layers += self._build_single_audio_conv(cond_feature_dim)
114
+ pre_layers += [
115
+ torch.nn.Conv1d(cond_feature_dim, cond_feature_dim, kernel_size=1)
116
+ ]
117
+ pre_layers = torch.nn.ModuleList(pre_layers)
118
+ self.pre_audio = nn.Sequential(*pre_layers)
119
+ self.audio_model, self.audio_resampler = setup_lip_regressor()
120
+
121
+ def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
122
+ device = next(self.parameters()).device
123
+ a0 = self.audio_resampler(raw_audio[:, :, 0].to(device)) # B x T
124
+ a1 = self.audio_resampler(raw_audio[:, :, 1].to(device)) # B x T
125
+ with torch.no_grad():
126
+ z0 = self.audio_model.feature_extractor(a0)
127
+ z1 = self.audio_model.feature_extractor(a1)
128
+ emb = torch.cat((z0, z1), axis=1).permute(0, 2, 1)
129
+ return emb
130
+
131
+ def get_tgt_mask(self, size: int, device: str) -> torch.Tensor:
132
+ mask = torch.tril(
133
+ torch.ones((size, size), device=device) == 1
134
+ ) # Lower triangular matrix
135
+ mask = mask.float()
136
+ mask = mask.masked_fill(mask == 0, float("-inf")) # Convert zeros to -inf
137
+ mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
138
+ return mask
139
+
140
+ def forward(
141
+ self, tokens: th.Tensor, condition: th.Tensor, cond_drop_prob: float = 0.0
142
+ ) -> torch.Tensor:
143
+ batch_size, device = tokens.shape[0], tokens.device
144
+
145
+ x = self.token_embedding(tokens)
146
+ x = self.abs_pos_encoding(x)
147
+ tgt_mask = self.get_tgt_mask(x.shape[1], x.device)
148
+
149
+ cond_embed = self.encode_audio(condition)
150
+ keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
151
+ keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
152
+ keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
153
+ cond_tokens = self.pre_audio(cond_embed.permute(0, 2, 1)).permute(0, 2, 1)
154
+ #
155
+ cond_tokens = self.cond_projection(cond_tokens)
156
+ cond_tokens = self.abs_pos_encoding(cond_tokens)
157
+
158
+ null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
159
+ cond_tokens = torch.where(
160
+ keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
161
+ )
162
+ mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
163
+ cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
164
+
165
+ # FiLM conditioning
166
+ null_cond_hidden = self.null_cond_hidden.to(cond_tokens.dtype)
167
+ cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
168
+ cond_tokens = self.norm_cond(cond_tokens)
169
+
170
+ output = self.seqTransDecoder(x, cond_tokens, cond_hidden, tgt_mask=tgt_mask)
171
+ output = self.final_layer(output)
172
+ return output
173
+
174
+ def generate(
175
+ self,
176
+ condition: th.Tensor,
177
+ sequence_length: int,
178
+ layers: int,
179
+ n_sequences: int = 1,
180
+ max_key_len: int = 8,
181
+ max_seq_len: int = 240,
182
+ top_p: float = 0.94,
183
+ ) -> torch.Tensor:
184
+ """
185
+ :param sequence_length: number of tokens to generate in autoregressive fashion
186
+ :param n_sequences: number of sequences to generate simultaneously
187
+ :param top_p: top-p (nucleus) threshold used when sampling from the output logits
188
+ :return: n_sequences x sequence_length LongTensor containing generated tokens
189
+ """
190
+ assert max_key_len == int(max_seq_len / 30), "currently only running for 1fps"
191
+ max_key_len *= layers
192
+ with th.no_grad():
193
+ input_tokens = (
194
+ th.zeros(n_sequences, 1, dtype=th.int64).to(condition.device)
195
+ + self.tokens
196
+ )
197
+ for _ in range(sequence_length * layers):
198
+ curr_input_tokens = input_tokens
199
+ curr_condition = condition
200
+ logits = self.forward(curr_input_tokens, curr_condition)
201
+ logits = logits[:, -1, :] # only most recent time step is relevant
202
+ one_hot = th.nn.functional.softmax(logits, dim=-1)
203
+ sorted_probs, indices = torch.sort(one_hot, dim=-1, descending=True)
204
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
205
+ nucleus = cumulative_probs < top_p
206
+ nucleus = torch.cat(
207
+ [
208
+ nucleus.new_ones(nucleus.shape[:-1] + (1,)),
209
+ nucleus[..., :-1],
210
+ ],
211
+ dim=-1,
212
+ )
213
+ sorted_probs[~nucleus] = 0
214
+ sorted_probs /= sorted_probs.sum(-1, keepdim=True)
215
+ dist = Categorical(sorted_probs)
216
+ idx = dist.sample()
217
+ tokens = indices.gather(-1, idx.unsqueeze(-1))
218
+ input_tokens = th.cat([input_tokens, tokens], dim=-1)
219
+
220
+ # return generated tokens except for sequence start token
221
+ tokens = input_tokens[:, 1:].contiguous()
222
+ return tokens
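The sampling loop in generate() uses nucleus (top-p) filtering. A minimal, self-contained sketch of that step on dummy logits rather than the transformer's output:

import torch
from torch.distributions import Categorical

logits = torch.randn(2, 1024)                  # (n_sequences, vocab)
probs = torch.softmax(logits, dim=-1)
sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
nucleus = cumulative < 0.94                    # top_p = 0.94
# shift right by one so the most likely token is always kept
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
sorted_probs[~nucleus] = 0
sorted_probs /= sorted_probs.sum(-1, keepdim=True)
idx = Categorical(sorted_probs).sample()       # position in the sorted order
tokens = indices.gather(-1, idx.unsqueeze(-1)) # map back to vocabulary ids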
model/modules/audio_encoder.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import fairseq
9
+ import torch as th
10
+ import torchaudio as ta
11
+
12
+ wav2vec_model_path = "./assets/wav2vec_large.pt"
13
+
14
+
15
+ def weights_init(m):
16
+ if isinstance(m, th.nn.Conv1d):
17
+ th.nn.init.xavier_uniform_(m.weight)
18
+ try:
19
+ th.nn.init.constant_(m.bias, 0.01)
20
+ except:
21
+ pass
22
+
23
+
24
+ class Wav2VecEncoder(th.nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.resampler = ta.transforms.Resample(orig_freq=48000, new_freq=16000)
28
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
29
+ [wav2vec_model_path]
30
+ )
31
+ self.wav2vec_model = model[0]
32
+
33
+ def forward(self, audio: th.Tensor):
34
+ """
35
+ :param audio: B x T x 1600
36
+ :return: B x T_wav2vec x 512
37
+ """
38
+ audio = audio.view(audio.shape[0], audio.shape[1] * 1600)
39
+ audio = self.resampler(audio)
40
+ audio = th.cat(
41
+ [th.zeros(audio.shape[0], 320, device=audio.device), audio], dim=-1
42
+ ) # zero padding on the left
43
+ x = self.wav2vec_model.feature_extractor(audio)
44
+ x = self.wav2vec_model.feature_aggregator(x)
45
+ x = x.permute(0, 2, 1).contiguous()
46
+ return x
47
+
48
+
49
+ class Wav2VecDownsampler(th.nn.Module):
50
+ def __init__(self):
51
+ super().__init__()
52
+ self.conv1 = th.nn.Conv1d(512, 512, kernel_size=3)
53
+ self.conv2 = th.nn.Conv1d(512, 512, kernel_size=3)
54
+ self.norm = th.nn.LayerNorm(512)
55
+
56
+ def forward(self, x: th.Tensor, target_length: int):
57
+ """
58
+ :param x: B x T x 512 tensor containing wav2vec features at 100Hz
59
+ :return: B x target_length x 512 tensor containing downsampled wav2vec features at 30Hz
60
+ """
61
+ x = x.permute(0, 2, 1).contiguous()
62
+ # first conv
63
+ x = th.nn.functional.pad(x, pad=(2, 0))
64
+ x = th.nn.functional.relu(self.conv1(x))
65
+ # first downsampling
66
+ x = th.nn.functional.interpolate(x, size=(x.shape[-1] + target_length) // 2)
67
+ # second conv
68
+ x = th.nn.functional.pad(x, pad=(2, 0))
69
+ x = self.conv2(x)
70
+ # second downsampling
71
+ x = th.nn.functional.interpolate(x, size=target_length)
72
+ # layer norm
73
+ x = x.permute(0, 2, 1).contiguous()
74
+ x = self.norm(x)
75
+ return x
76
+
77
+
78
+ class AudioTcn(th.nn.Module):
79
+ def __init__(
80
+ self,
81
+ encoding_dim: int = 128,
82
+ use_melspec: bool = True,
83
+ use_wav2vec: bool = True,
84
+ ):
85
+ """
86
+ :param encoding_dim: size of encoding
87
+ :param use_melspec: extract mel spectrogram features as input
88
+ :param use_wav2vec: extract wav2vec features as input
89
+ """
90
+ super().__init__()
91
+ self.encoding_dim = encoding_dim
92
+ self.use_melspec = use_melspec
93
+ self.use_wav2vec = use_wav2vec
94
+
95
+ if use_melspec:
96
+ # hop_length=400 -> two feature vectors per visual frame (downsampling to 24kHz -> 800 samples per frame)
97
+ self.melspec = th.nn.Sequential(
98
+ ta.transforms.Resample(orig_freq=48000, new_freq=24000),
99
+ ta.transforms.MelSpectrogram(
100
+ sample_rate=24000,
101
+ n_fft=1024,
102
+ win_length=800,
103
+ hop_length=400,
104
+ n_mels=80,
105
+ ),
106
+ )
107
+
108
+ if use_wav2vec:
109
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
110
+ [wav2vec_model_path]
111
+ )
112
+ self.wav2vec_model = model[0]
113
+ self.wav2vec_model.eval()
114
+ self.wav2vec_postprocess = th.nn.Conv1d(512, 256, kernel_size=3)
115
+ self.wav2vec_postprocess.apply(lambda x: weights_init(x))
116
+
117
+ # temporal model
118
+ input_dim = 0 + (160 if use_melspec else 0) + (256 if use_wav2vec else 0)
119
+ self.layers = th.nn.ModuleList(
120
+ [
121
+ th.nn.Conv1d(
122
+ input_dim, max(256, encoding_dim), kernel_size=3, dilation=1
123
+ ), # 2 (+1)
124
+ th.nn.Conv1d(
125
+ max(256, encoding_dim), encoding_dim, kernel_size=3, dilation=2
126
+ ), # 4 (+1)
127
+ th.nn.Conv1d(
128
+ encoding_dim, encoding_dim, kernel_size=3, dilation=3
129
+ ), # 6 (+1)
130
+ th.nn.Conv1d(
131
+ encoding_dim, encoding_dim, kernel_size=3, dilation=1
132
+ ), # 2 (+1)
133
+ th.nn.Conv1d(
134
+ encoding_dim, encoding_dim, kernel_size=3, dilation=2
135
+ ), # 4 (+1)
136
+ th.nn.Conv1d(
137
+ encoding_dim, encoding_dim, kernel_size=3, dilation=3
138
+ ), # 6 (+1)
139
+ ]
140
+ )
141
+ self.layers.apply(lambda x: weights_init(x))
142
+ self.receptive_field = 25
143
+
144
+ self.final = th.nn.Conv1d(encoding_dim, encoding_dim, kernel_size=1)
145
+ self.final.apply(lambda x: weights_init(x))
146
+
147
+ def forward(self, audio):
148
+ """
149
+ :param audio: B x T x 1600 tensor containing audio samples for each frame
150
+ :return: B x T x encoding_dim tensor containing audio encodings for each frame
151
+ """
152
+ B, T = audio.shape[0], audio.shape[1]
153
+
154
+ # preprocess raw audio signal to extract feature vectors
155
+ audio = audio.view(B, T * 1600)
156
+ x_mel, x_w2v = th.zeros(B, 0, T).to(audio.device), th.zeros(B, 0, T).to(
157
+ audio.device
158
+ )
159
+ if self.use_melspec:
160
+ x_mel = self.melspec(audio)[:, :, 1:].contiguous()
161
+ x_mel = th.log(x_mel.clamp(min=1e-10, max=None))
162
+ x_mel = (
163
+ x_mel.permute(0, 2, 1)
164
+ .contiguous()
165
+ .view(x_mel.shape[0], T, 160)
166
+ .permute(0, 2, 1)
167
+ .contiguous()
168
+ )
169
+ if self.use_wav2vec:
170
+ with th.no_grad():
171
+ x_w2v = ta.functional.resample(audio, 48000, 16000)
172
+ x_w2v = self.wav2vec_model.feature_extractor(x_w2v)
173
+ x_w2v = self.wav2vec_model.feature_aggregator(x_w2v)
174
+ x_w2v = self.wav2vec_postprocess(th.nn.functional.pad(x_w2v, pad=[2, 0]))
175
+ x_w2v = th.nn.functional.interpolate(
176
+ x_w2v, size=T, align_corners=True, mode="linear"
177
+ )
178
+ x = th.cat([x_mel, x_w2v], dim=1)
179
+
180
+ # process signal with TCN
181
+ x = th.nn.functional.pad(x, pad=[self.receptive_field - 1, 0])
182
+ for layer_idx, layer in enumerate(self.layers):
183
+ y = th.nn.functional.leaky_relu(layer(x), negative_slope=0.2)
184
+ if self.training:
185
+ y = th.nn.functional.dropout(y, 0.2)
186
+ if x.shape[1] == y.shape[1]:
187
+ x = (x[:, :, -y.shape[-1] :] + y) / 2.0 # skip connection
188
+ else:
189
+ x = y
190
+
191
+ x = self.final(x)
192
+ x = x.permute(0, 2, 1).contiguous()
193
+
194
+ return x
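The AudioTcn above relies on a fixed receptive field of 25 frames: with kernel size 3 and dilations (1, 2, 3, 1, 2, 3) the receptive field is 1 + 2*(1+2+3+1+2+3) = 25, so left-padding by 24 keeps the output length equal to T while staying causal. A self-contained sketch of just that length arithmetic (skip connections and channel changes omitted):

import torch as th

dilations = [1, 2, 3, 1, 2, 3]
receptive_field = 1 + sum(2 * d for d in dilations)   # = 25
layers = [th.nn.Conv1d(16, 16, kernel_size=3, dilation=d) for d in dilations]

x = th.randn(1, 16, 120)                              # B x C x T
x = th.nn.functional.pad(x, pad=[receptive_field - 1, 0])
for layer in layers:
    x = layer(x)
print(x.shape)                                        # torch.Size([1, 16, 120])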
model/modules/rotary_embedding_torch.py ADDED
@@ -0,0 +1,139 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ from inspect import isfunction
9
+ from math import log, pi
10
+
11
+ import torch
12
+ from einops import rearrange, repeat
13
+ from torch import einsum, nn
14
+
15
+ # helper functions
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def broadcat(tensors, dim=-1):
23
+ num_tensors = len(tensors)
24
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
25
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
26
+ shape_len = list(shape_lens)[0]
27
+
28
+ dim = (dim + shape_len) if dim < 0 else dim
29
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
30
+
31
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
32
+ assert all(
33
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
34
+ ), "invalid dimensions for broadcastable concatentation"
35
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
36
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
37
+ expanded_dims.insert(dim, (dim, dims[dim]))
38
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
39
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
40
+ return torch.cat(tensors, dim=dim)
41
+
42
+
43
+ # rotary embedding helper functions
44
+
45
+
46
+ def rotate_half(x):
47
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
48
+ x1, x2 = x.unbind(dim=-1)
49
+ x = torch.stack((-x2, x1), dim=-1)
50
+ return rearrange(x, "... d r -> ... (d r)")
51
+
52
+
53
+ def apply_rotary_emb(freqs, t, start_index=0):
54
+ freqs = freqs.to(t)
55
+ rot_dim = freqs.shape[-1]
56
+ end_index = start_index + rot_dim
57
+ assert (
58
+ rot_dim <= t.shape[-1]
59
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
60
+ t_left, t, t_right = (
61
+ t[..., :start_index],
62
+ t[..., start_index:end_index],
63
+ t[..., end_index:],
64
+ )
65
+ t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
66
+ return torch.cat((t_left, t, t_right), dim=-1)
67
+
68
+
69
+ # learned rotation helpers
70
+
71
+
72
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
73
+ if exists(freq_ranges):
74
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
75
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
76
+
77
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
78
+ return apply_rotary_emb(rotations, t, start_index=start_index)
79
+
80
+
81
+ # classes
82
+
83
+
84
+ class RotaryEmbedding(nn.Module):
85
+ def __init__(
86
+ self,
87
+ dim,
88
+ custom_freqs=None,
89
+ freqs_for="lang",
90
+ theta=10000,
91
+ max_freq=10,
92
+ num_freqs=1,
93
+ learned_freq=False,
94
+ ):
95
+ super().__init__()
96
+ if exists(custom_freqs):
97
+ freqs = custom_freqs
98
+ elif freqs_for == "lang":
99
+ freqs = 1.0 / (
100
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
101
+ )
102
+ elif freqs_for == "pixel":
103
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
104
+ elif freqs_for == "constant":
105
+ freqs = torch.ones(num_freqs).float()
106
+ else:
107
+ raise ValueError(f"unknown modality {freqs_for}")
108
+
109
+ self.cache = dict()
110
+
111
+ if learned_freq:
112
+ self.freqs = nn.Parameter(freqs)
113
+ else:
114
+ self.register_buffer("freqs", freqs)
115
+
116
+ def rotate_queries_or_keys(self, t, seq_dim=-2):
117
+ device = t.device
118
+ seq_len = t.shape[seq_dim]
119
+ freqs = self.forward(
120
+ lambda: torch.arange(seq_len, device=device), cache_key=seq_len
121
+ )
122
+ return apply_rotary_emb(freqs, t)
123
+
124
+ def forward(self, t, cache_key=None):
125
+ if exists(cache_key) and cache_key in self.cache:
126
+ return self.cache[cache_key]
127
+
128
+ if isfunction(t):
129
+ t = t()
130
+
131
+ freqs = self.freqs
132
+
133
+ freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
134
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
135
+
136
+ if exists(cache_key):
137
+ self.cache[cache_key] = freqs
138
+
139
+ return freqs
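A short usage sketch for the rotary embedding above, assuming the module path model.modules.rotary_embedding_torch from this commit; queries/keys of shape (batch, seq_len, dim) are rotated before attention, as the transformer layers in this repo do:

import torch
from model.modules.rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim=64)
q = torch.randn(2, 100, 64)                 # batch x seq_len x dim
q_rot = rotary.rotate_queries_or_keys(q)    # same shape, position-dependent rotation
print(q_rot.shape)                          # torch.Size([2, 100, 64])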
model/modules/transformer_modules.py ADDED
@@ -0,0 +1,702 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import math
9
+ from typing import Any, Callable, List, Optional, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from einops import rearrange
14
+ from torch import Tensor
15
+ from torch.nn import functional as F
16
+
17
+
18
+ def generate_causal_mask(source_length, target_length, device="cpu"):
19
+ if source_length == target_length:
20
+ mask = (
21
+ torch.triu(torch.ones(target_length, source_length, device=device)) == 1
22
+ ).transpose(0, 1)
23
+ else:
24
+ mask = torch.zeros(target_length, source_length, device=device)
25
+ idx = torch.linspace(0, source_length, target_length + 1)[1:].round().long()
26
+ for i in range(target_length):
27
+ mask[i, 0 : idx[i]] = 1
28
+
29
+ return (
30
+ mask.float()
31
+ .masked_fill(mask == 0, float("-inf"))
32
+ .masked_fill(mask == 1, float(0.0))
33
+ )
34
+
35
+
36
+ class TransformerEncoderLayerRotary(nn.Module):
37
+ def __init__(
38
+ self,
39
+ d_model: int,
40
+ nhead: int,
41
+ dim_feedforward: int = 2048,
42
+ dropout: float = 0.1,
43
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
44
+ layer_norm_eps: float = 1e-5,
45
+ batch_first: bool = False,
46
+ norm_first: bool = True,
47
+ rotary=None,
48
+ ) -> None:
49
+ super().__init__()
50
+ self.self_attn = nn.MultiheadAttention(
51
+ d_model, nhead, dropout=dropout, batch_first=batch_first
52
+ )
53
+ # Implementation of Feedforward model
54
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
55
+ self.dropout = nn.Dropout(dropout)
56
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
57
+
58
+ self.norm_first = norm_first
59
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
60
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
61
+ self.dropout1 = nn.Dropout(dropout)
62
+ self.dropout2 = nn.Dropout(dropout)
63
+ self.activation = activation
64
+
65
+ self.rotary = rotary
66
+ self.use_rotary = rotary is not None
67
+
68
+ def forward(
69
+ self,
70
+ src: Tensor,
71
+ src_mask: Optional[Tensor] = None,
72
+ src_key_padding_mask: Optional[Tensor] = None,
73
+ ) -> Tensor:
74
+ x = src
75
+ if self.norm_first:
76
+ x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
77
+ x = x + self._ff_block(self.norm2(x))
78
+ else:
79
+ x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
80
+ x = self.norm2(x + self._ff_block(x))
81
+
82
+ return x
83
+
84
+ # self-attention block
85
+ def _sa_block(
86
+ self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
87
+ ) -> Tensor:
88
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
89
+ x = self.self_attn(
90
+ qk,
91
+ qk,
92
+ x,
93
+ attn_mask=attn_mask,
94
+ key_padding_mask=key_padding_mask,
95
+ need_weights=False,
96
+ )[0]
97
+ return self.dropout1(x)
98
+
99
+ # feed forward block
100
+ def _ff_block(self, x: Tensor) -> Tensor:
101
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
102
+ return self.dropout2(x)
103
+
104
+
105
+ class DenseFiLM(nn.Module):
106
+ """Feature-wise linear modulation (FiLM) generator."""
107
+
108
+ def __init__(self, embed_channels):
109
+ super().__init__()
110
+ self.embed_channels = embed_channels
111
+ self.block = nn.Sequential(
112
+ nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
113
+ )
114
+
115
+ def forward(self, position):
116
+ pos_encoding = self.block(position)
117
+ pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
118
+ scale_shift = pos_encoding.chunk(2, dim=-1)
119
+ return scale_shift
120
+
121
+
122
+ def featurewise_affine(x, scale_shift):
123
+ scale, shift = scale_shift
124
+ return (scale + 1) * x + shift
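DenseFiLM and featurewise_affine together implement the time conditioning used by the decoder layers below. A minimal usage sketch with illustrative shapes:

import torch
from model.modules.transformer_modules import DenseFiLM, featurewise_affine

film = DenseFiLM(embed_channels=512)
t_embed = torch.randn(8, 512)            # one time embedding per batch element
x = torch.randn(8, 240, 512)             # B x T x d_model features
scale_shift = film(t_embed)              # pair of B x 1 x 512 tensors
x = featurewise_affine(x, scale_shift)   # (scale + 1) * x + shift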
125
+
126
+
127
+ class FiLMTransformerDecoderLayer(nn.Module):
128
+ def __init__(
129
+ self,
130
+ d_model: int,
131
+ nhead: int,
132
+ dim_feedforward=2048,
133
+ dropout=0.1,
134
+ activation=F.relu,
135
+ layer_norm_eps=1e-5,
136
+ batch_first=False,
137
+ norm_first=True,
138
+ rotary=None,
139
+ use_cm=False,
140
+ ):
141
+ super().__init__()
142
+ self.self_attn = nn.MultiheadAttention(
143
+ d_model, nhead, dropout=dropout, batch_first=batch_first
144
+ )
145
+ self.multihead_attn = nn.MultiheadAttention(
146
+ d_model, nhead, dropout=dropout, batch_first=batch_first
147
+ )
148
+ # Feedforward
149
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
150
+ self.dropout = nn.Dropout(dropout)
151
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
152
+
153
+ self.norm_first = norm_first
154
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
155
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
156
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
157
+ self.dropout1 = nn.Dropout(dropout)
158
+ self.dropout2 = nn.Dropout(dropout)
159
+ self.dropout3 = nn.Dropout(dropout)
160
+ self.activation = activation
161
+
162
+ self.film1 = DenseFiLM(d_model)
163
+ self.film2 = DenseFiLM(d_model)
164
+ self.film3 = DenseFiLM(d_model)
165
+
166
+ if use_cm:
167
+ self.multihead_attn2 = nn.MultiheadAttention( # 2
168
+ d_model, nhead, dropout=dropout, batch_first=batch_first
169
+ )
170
+ self.norm2a = nn.LayerNorm(d_model, eps=layer_norm_eps) # 2
171
+ self.dropout2a = nn.Dropout(dropout) # 2
172
+ self.film2a = DenseFiLM(d_model) # 2
173
+
174
+ self.rotary = rotary
175
+ self.use_rotary = rotary is not None
176
+
177
+ # x, cond, t
178
+ def forward(
179
+ self,
180
+ tgt,
181
+ memory,
182
+ t,
183
+ tgt_mask=None,
184
+ memory_mask=None,
185
+ tgt_key_padding_mask=None,
186
+ memory_key_padding_mask=None,
187
+ memory2=None,
188
+ ):
189
+ x = tgt
190
+ if self.norm_first:
191
+ # self-attention -> film -> residual
192
+ x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
193
+ x = x + featurewise_affine(x_1, self.film1(t))
194
+ # cross-attention -> film -> residual
195
+ x_2 = self._mha_block(
196
+ self.norm2(x),
197
+ memory,
198
+ memory_mask,
199
+ memory_key_padding_mask,
200
+ self.multihead_attn,
201
+ self.dropout2,
202
+ )
203
+ x = x + featurewise_affine(x_2, self.film2(t))
204
+ if memory2 is not None:
205
+ # cross-attention x2 -> film -> residual
206
+ x_2a = self._mha_block(
207
+ self.norm2a(x),
208
+ memory2,
209
+ memory_mask,
210
+ memory_key_padding_mask,
211
+ self.multihead_attn2,
212
+ self.dropout2a,
213
+ )
214
+ x = x + featurewise_affine(x_2a, self.film2a(t))
215
+ # feedforward -> film -> residual
216
+ x_3 = self._ff_block(self.norm3(x))
217
+ x = x + featurewise_affine(x_3, self.film3(t))
218
+ else:
219
+ x = self.norm1(
220
+ x
221
+ + featurewise_affine(
222
+ self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
223
+ )
224
+ )
225
+ x = self.norm2(
226
+ x
227
+ + featurewise_affine(
228
+ self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
229
+ self.film2(t),
230
+ )
231
+ )
232
+ x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
233
+ return x
234
+
235
+ # self-attention block
236
+ # qkv
237
+ def _sa_block(self, x, attn_mask, key_padding_mask):
238
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
239
+ x = self.self_attn(
240
+ qk,
241
+ qk,
242
+ x,
243
+ attn_mask=attn_mask,
244
+ key_padding_mask=key_padding_mask,
245
+ need_weights=False,
246
+ )[0]
247
+ return self.dropout1(x)
248
+
249
+ # multihead attention block
250
+ # qkv
251
+ def _mha_block(self, x, mem, attn_mask, key_padding_mask, mha, dropout):
252
+ q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
253
+ k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
254
+ x = mha(
255
+ q,
256
+ k,
257
+ mem,
258
+ attn_mask=attn_mask,
259
+ key_padding_mask=key_padding_mask,
260
+ need_weights=False,
261
+ )[0]
262
+ return dropout(x)
263
+
264
+ # feed forward block
265
+ def _ff_block(self, x):
266
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
267
+ return self.dropout3(x)
268
+
269
+
270
+ class DecoderLayerStack(nn.Module):
271
+ def __init__(self, stack):
272
+ super().__init__()
273
+ self.stack = stack
274
+
275
+ def forward(self, x, cond, t, tgt_mask=None, memory2=None):
276
+ for layer in self.stack:
277
+ x = layer(x, cond, t, tgt_mask=tgt_mask, memory2=memory2)
278
+ return x
279
+
280
+
281
+ class PositionalEncoding(nn.Module):
282
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
283
+ super().__init__()
284
+ pe = torch.zeros(max_len, d_model)
285
+ position = torch.arange(0, max_len).unsqueeze(1)
286
+ div_term = torch.exp(
287
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
288
+ )
289
+ pe[:, 0::2] = torch.sin(position * div_term)
290
+ pe[:, 1::2] = torch.cos(position * div_term)
291
+
292
+ self.register_buffer("pe", pe)
293
+ self.dropout = nn.Dropout(p=dropout)
294
+
295
+ def forward(self, x: torch.Tensor):
296
+ """
297
+ :param x: B x T x d_model tensor
298
+ :return: B x T x d_model tensor
299
+ """
300
+ x = x + self.pe[None, : x.shape[1], :]
301
+ x = self.dropout(x)
302
+ return x
303
+
304
+
305
+ class TimestepEncoding(nn.Module):
306
+ def __init__(self, embedding_dim: int):
307
+ super().__init__()
308
+
309
+ # Fourier embedding
310
+ half_dim = embedding_dim // 2
311
+ emb = math.log(10000) / (half_dim - 1)
312
+ emb = torch.exp(torch.arange(half_dim) * -emb)
313
+ self.register_buffer("emb", emb)
314
+
315
+ # encoding
316
+ self.encoding = nn.Sequential(
317
+ nn.Linear(embedding_dim, 4 * embedding_dim),
318
+ nn.Mish(),
319
+ nn.Linear(4 * embedding_dim, embedding_dim),
320
+ )
321
+
322
+ def forward(self, t: torch.Tensor):
323
+ """
324
+ :param t: B-dimensional tensor containing timesteps in range [0, 1]
325
+ :return: B x embedding_dim tensor containing timestep encodings
326
+ """
327
+ x = t[:, None] * self.emb[None, :]
328
+ x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
329
+ x = self.encoding(x)
330
+ return x
331
+
332
+
333
+ class FiLM(nn.Module):
334
+ def __init__(self, dim: int):
335
+ super().__init__()
336
+ self.dim = dim
337
+ self.film = nn.Sequential(nn.Mish(), nn.Linear(dim, dim * 2))
338
+
339
+ def forward(self, x: torch.Tensor, cond: torch.Tensor):
340
+ """
341
+ :param x: ... x dim tensor
342
+ :param cond: ... x dim tensor
343
+ :return: ... x dim tensor as scale(cond) * x + bias(cond)
344
+ """
345
+ cond = self.film(cond)
346
+ scale, bias = torch.chunk(cond, chunks=2, dim=-1)
347
+ x = (scale + 1) * x + bias
348
+ return x
349
+
350
+
351
+ class FeedforwardBlock(nn.Module):
352
+ def __init__(self, d_model: int, d_feedforward: int = 1024, dropout: float = 0.1):
353
+ super().__init__()
354
+ self.ff = nn.Sequential(
355
+ nn.Linear(d_model, d_feedforward),
356
+ nn.ReLU(),
357
+ nn.Dropout(p=dropout),
358
+ nn.Linear(d_feedforward, d_model),
359
+ nn.Dropout(p=dropout),
360
+ )
361
+
362
+ def forward(self, x: torch.Tensor):
363
+ """
364
+ :param x: ... x d_model tensor
365
+ :return: ... x d_model tensor
366
+ """
367
+ return self.ff(x)
368
+
369
+
370
+ class SelfAttention(nn.Module):
371
+ def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
372
+ super().__init__()
373
+ self.self_attn = nn.MultiheadAttention(
374
+ d_model, num_heads, dropout=dropout, batch_first=True
375
+ )
376
+ self.dropout = nn.Dropout(p=dropout)
377
+
378
+ def forward(
379
+ self,
380
+ x: torch.Tensor,
381
+ attn_mask: torch.Tensor = None,
382
+ key_padding_mask: torch.Tensor = None,
383
+ ):
384
+ """
385
+ :param x: B x T x d_model input tensor
386
+ :param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length
387
+ for a float mask: values will be added to attention weight
388
+ for a binary mask: True indicates that the element is not allowed to attend
389
+ :param key_padding_mask: B x S mask
390
+ for a float mask: values will be added directly to the corresponding key values
391
+ for a binary mask: True indicates that the corresponding key value will be ignored
392
+ :return: B x T x d_model output tensor
393
+ """
394
+ x = self.self_attn(
395
+ x,
396
+ x,
397
+ x,
398
+ attn_mask=attn_mask,
399
+ key_padding_mask=key_padding_mask,
400
+ need_weights=False,
401
+ )[0]
402
+ x = self.dropout(x)
403
+ return x
404
+
405
+
406
+ class CrossAttention(nn.Module):
407
+ def __init__(self, d_model: int, d_cond: int, num_heads: int, dropout: float = 0.1):
408
+ super().__init__()
409
+ self.cross_attn = nn.MultiheadAttention(
410
+ d_model,
411
+ num_heads,
412
+ dropout=dropout,
413
+ batch_first=True,
414
+ kdim=d_cond,
415
+ vdim=d_cond,
416
+ )
417
+ self.dropout = nn.Dropout(p=dropout)
418
+
419
+ def forward(
420
+ self,
421
+ x: torch.Tensor,
422
+ cond: torch.Tensor,
423
+ attn_mask: torch.Tensor = None,
424
+ key_padding_mask: torch.Tensor = None,
425
+ ):
426
+ """
427
+ :param x: B x T_target x d_model input tensor
428
+ :param cond: B x T_cond x d_cond condition tensor
429
+ :param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length
430
+ for a float mask: values will be added to attention weight
431
+ for a binary mask: True indicates that the element is not allowed to attend
432
+ :param key_padding_mask: B x S mask
433
+ for a float mask: values will be added directly to the corresponding key values
434
+ for a binary mask: True indicates that the corresponding key value will be ignored
435
+ :return: B x T x d_model output tensor
436
+ """
437
+ x = self.cross_attn(
438
+ x,
439
+ cond,
440
+ cond,
441
+ attn_mask=attn_mask,
442
+ key_padding_mask=key_padding_mask,
443
+ need_weights=False,
444
+ )[0]
445
+ x = self.dropout(x)
446
+ return x
447
+
448
+
449
+ class TransformerEncoderLayer(nn.Module):
450
+ def __init__(
451
+ self,
452
+ d_model: int,
453
+ num_heads: int,
454
+ d_feedforward: int = 1024,
455
+ dropout: float = 0.1,
456
+ ):
457
+ super().__init__()
458
+ self.norm1 = nn.LayerNorm(d_model)
459
+ self.self_attn = SelfAttention(d_model, num_heads, dropout)
460
+ self.norm2 = nn.LayerNorm(d_model)
461
+ self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
462
+
463
+ def forward(
464
+ self,
465
+ x: torch.Tensor,
466
+ mask: torch.Tensor = None,
467
+ key_padding_mask: torch.Tensor = None,
468
+ ):
469
+ x = x + self.self_attn(self.norm1(x), mask, key_padding_mask)
470
+ x = x + self.feedforward(self.norm2(x))
471
+ return x
472
+
473
+
474
+ class TransformerDecoderLayer(nn.Module):
475
+ def __init__(
476
+ self,
477
+ d_model: int,
478
+ d_cond: int,
479
+ num_heads: int,
480
+ d_feedforward: int = 1024,
481
+ dropout: float = 0.1,
482
+ ):
483
+ super().__init__()
484
+ self.norm1 = nn.LayerNorm(d_model)
485
+ self.self_attn = SelfAttention(d_model, num_heads, dropout)
486
+ self.norm2 = nn.LayerNorm(d_model)
487
+ self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
488
+ self.norm3 = nn.LayerNorm(d_model)
489
+ self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
490
+
491
+ def forward(
492
+ self,
493
+ x: torch.Tensor,
494
+ cross_cond: torch.Tensor,
495
+ target_mask: torch.Tensor = None,
496
+ target_key_padding_mask: torch.Tensor = None,
497
+ cross_cond_mask: torch.Tensor = None,
498
+ cross_cond_key_padding_mask: torch.Tensor = None,
499
+ ):
500
+ """
501
+ :param x: B x T x d_model tensor
502
+ :param cross_cond: B x T x d_cond tensor containing the conditioning input to cross attention layers
503
+ :return: B x T x d_model tensor
504
+ """
505
+ x = x + self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
506
+ x = x + self.cross_attn(
507
+ self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
508
+ )
509
+ x = x + self.feedforward(self.norm3(x))
510
+ return x
511
+
512
+
513
+ class FilmTransformerDecoderLayer(nn.Module):
514
+ def __init__(
515
+ self,
516
+ d_model: int,
517
+ d_cond: int,
518
+ num_heads: int,
519
+ d_feedforward: int = 1024,
520
+ dropout: float = 0.1,
521
+ ):
522
+ super().__init__()
523
+ self.norm1 = nn.LayerNorm(d_model)
524
+ self.self_attn = SelfAttention(d_model, num_heads, dropout)
525
+ self.film1 = FiLM(d_model)
526
+ self.norm2 = nn.LayerNorm(d_model)
527
+ self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
528
+ self.film2 = FiLM(d_model)
529
+ self.norm3 = nn.LayerNorm(d_model)
530
+ self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
531
+ self.film3 = FiLM(d_model)
532
+
533
+ def forward(
534
+ self,
535
+ x: torch.Tensor,
536
+ cross_cond: torch.Tensor,
537
+ film_cond: torch.Tensor,
538
+ target_mask: torch.Tensor = None,
539
+ target_key_padding_mask: torch.Tensor = None,
540
+ cross_cond_mask: torch.Tensor = None,
541
+ cross_cond_key_padding_mask: torch.Tensor = None,
542
+ ):
543
+ """
544
+ :param x: B x T x d_model tensor
545
+ :param cross_cond: B x T x d_cond tensor containing the conditioning input to cross attention layers
546
+ :param film_cond: B x [1 or T] x film_cond tensor containing the conditioning input to FiLM layers
547
+ :return: B x T x d_model tensor
548
+ """
549
+ x1 = self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
550
+ x = x + self.film1(x1, film_cond)
551
+ x2 = self.cross_attn(
552
+ self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
553
+ )
554
+ x = x + self.film2(x2, film_cond)
555
+ x3 = self.feedforward(self.norm3(x))
556
+ x = x + self.film3(x3, film_cond)
557
+ return x
558
+
559
+
560
+ class RegressionTransformer(nn.Module):
561
+ def __init__(
562
+ self,
563
+ transformer_encoder_layers: int = 2,
564
+ transformer_decoder_layers: int = 4,
565
+ d_model: int = 512,
566
+ d_cond: int = 512,
567
+ num_heads: int = 4,
568
+ d_feedforward: int = 1024,
569
+ dropout: float = 0.1,
570
+ causal: bool = False,
571
+ ):
572
+ super().__init__()
573
+ self.causal = causal
574
+
575
+ self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
576
+ self.target_positional_encoding = PositionalEncoding(d_model, dropout)
577
+
578
+ self.transformer_encoder = nn.ModuleList(
579
+ [
580
+ TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
581
+ for _ in range(transformer_encoder_layers)
582
+ ]
583
+ )
584
+
585
+ self.transformer_decoder = nn.ModuleList(
586
+ [
587
+ TransformerDecoderLayer(
588
+ d_model, d_cond, num_heads, d_feedforward, dropout
589
+ )
590
+ for _ in range(transformer_decoder_layers)
591
+ ]
592
+ )
593
+
594
+ def forward(self, x: torch.Tensor, cond: torch.Tensor):
595
+ """
596
+ :param x: B x T x d_model input tensor
597
+ :param cond: B x T x d_cond conditional tensor
598
+ :return: B x T x d_model output tensor
599
+ """
600
+ x = self.target_positional_encoding(x)
601
+ cond = self.cond_positional_encoding(cond)
602
+
603
+ if self.causal:
604
+ encoder_mask = generate_causal_mask(
605
+ cond.shape[1], cond.shape[1], device=cond.device
606
+ )
607
+ decoder_self_attn_mask = generate_causal_mask(
608
+ x.shape[1], x.shape[1], device=x.device
609
+ )
610
+ decoder_cross_attn_mask = generate_causal_mask(
611
+ cond.shape[1], x.shape[1], device=x.device
612
+ )
613
+ else:
614
+ encoder_mask = None
615
+ decoder_self_attn_mask = None
616
+ decoder_cross_attn_mask = None
617
+
618
+ for encoder_layer in self.transformer_encoder:
619
+ cond = encoder_layer(cond, mask=encoder_mask)
620
+ for decoder_layer in self.transformer_decoder:
621
+ x = decoder_layer(
622
+ x,
623
+ cond,
624
+ target_mask=decoder_self_attn_mask,
625
+ cross_cond_mask=decoder_cross_attn_mask,
626
+ )
627
+ return x
628
+
629
+
630
+ class DiffusionTransformer(nn.Module):
631
+ def __init__(
632
+ self,
633
+ transformer_encoder_layers: int = 2,
634
+ transformer_decoder_layers: int = 4,
635
+ d_model: int = 512,
636
+ d_cond: int = 512,
637
+ num_heads: int = 4,
638
+ d_feedforward: int = 1024,
639
+ dropout: float = 0.1,
640
+ causal: bool = False,
641
+ ):
642
+ super().__init__()
643
+ self.causal = causal
644
+
645
+ self.timestep_encoder = TimestepEncoding(d_model)
646
+ self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
647
+ self.target_positional_encoding = PositionalEncoding(d_model, dropout)
648
+
649
+ self.transformer_encoder = nn.ModuleList(
650
+ [
651
+ TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
652
+ for _ in range(transformer_encoder_layers)
653
+ ]
654
+ )
655
+
656
+ self.transformer_decoder = nn.ModuleList(
657
+ [
658
+ FilmTransformerDecoderLayer(
659
+ d_model, d_cond, num_heads, d_feedforward, dropout
660
+ )
661
+ for _ in range(transformer_decoder_layers)
662
+ ]
663
+ )
664
+
665
+ def forward(self, x: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
666
+ """
667
+ :param x: B x T x d_model input tensor
668
+ :param cond: B x T x d_cond conditional tensor
669
+ :param t: B-dimensional tensor containing diffusion timesteps in range [0, 1]
670
+ :return: B x T x d_model output tensor
671
+ """
672
+ t = self.timestep_encoder(t).unsqueeze(1) # B x 1 x d_model
673
+ x = self.target_positional_encoding(x)
674
+ cond = self.cond_positional_encoding(cond)
675
+
676
+ if self.causal:
677
+ encoder_mask = generate_causal_mask(
678
+ cond.shape[1], cond.shape[1], device=cond.device
679
+ )
680
+ decoder_self_attn_mask = generate_causal_mask(
681
+ x.shape[1], x.shape[1], device=x.device
682
+ )
683
+ decoder_cross_attn_mask = generate_causal_mask(
684
+ cond.shape[1], x.shape[1], device=x.device
685
+ )
686
+ else:
687
+ encoder_mask = None
688
+ decoder_self_attn_mask = None
689
+ decoder_cross_attn_mask = None
690
+
691
+ for encoder_layer in self.transformer_encoder:
692
+ cond = encoder_layer(cond, mask=encoder_mask)
693
+ for decoder_layer in self.transformer_decoder:
694
+ x = decoder_layer(
695
+ x,
696
+ cond,
697
+ t,
698
+ target_mask=decoder_self_attn_mask,
699
+ cross_cond_mask=decoder_cross_attn_mask,
700
+ )
701
+
702
+ return x
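For reference, a minimal shape check of the diffusion denoiser added above; this is a hypothetical sketch, not part of the commit, and it assumes the building blocks defined earlier in transformer_modules.py (PositionalEncoding, TimestepEncoding, FiLM) behave as their call sites suggest and that the repository root is on PYTHONPATH:

```python
# Hypothetical smoke test: DiffusionTransformer maps (x, cond, t) back to x's shape.
import torch
from model.modules.transformer_modules import DiffusionTransformer  # path as added in this commit

net = DiffusionTransformer(d_model=512, d_cond=512, num_heads=4)  # causal=False by default
x = torch.randn(2, 30, 512)     # B x T x d_model noisy motion features
cond = torch.randn(2, 30, 512)  # B x T x d_cond conditioning features
t = torch.rand(2)               # per-sample diffusion timesteps in [0, 1]
out = net(x, cond, t)
assert out.shape == x.shape     # B x T x d_model, per the forward() docstring
```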
model/utils.py ADDED
@@ -0,0 +1,130 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import math
9
+
10
+ import fairseq
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio.transforms as T
15
+ from torch import nn
16
+
17
+
18
+ def setup_lip_regressor() -> ("Audio2LipRegressionTransformer", T.Resample):
19
+ cp_path = "./assets/vq-wav2vec.pt"
20
+ audio_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
21
+ audio_model = audio_model[0]
22
+ for param in audio_model.parameters():
23
+ param.requires_grad = False
24
+ audio_model.eval()
25
+ audio_resampler = T.Resample(48000, 16000)
26
+ return audio_model, audio_resampler
27
+
28
+
29
+ def init_weight(m):
30
+ if (
31
+ isinstance(m, nn.Conv1d)
32
+ or isinstance(m, nn.Linear)
33
+ or isinstance(m, nn.ConvTranspose1d)
34
+ ):
35
+ nn.init.xavier_normal_(m.weight)
36
+ # m.bias.data.fill_(0.01)
37
+ if m.bias is not None:
38
+ nn.init.constant_(m.bias, 0)
39
+
40
+
41
+ # absolute positional embedding used for vanilla transformer sequential data
42
+ class PositionalEncoding(nn.Module):
43
+ def __init__(self, d_model, dropout=0.1, max_len=800, batch_first=False):
44
+ super().__init__()
45
+ self.batch_first = batch_first
46
+
47
+ self.dropout = nn.Dropout(p=dropout)
48
+
49
+ pe = torch.zeros(max_len, d_model)
50
+ position = torch.arange(0, max_len).unsqueeze(1)
51
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
52
+ pe[:, 0::2] = torch.sin(position * div_term)
53
+ pe[:, 1::2] = torch.cos(position * div_term)
54
+ pe = pe.unsqueeze(0).transpose(0, 1)
55
+
56
+ self.register_buffer("pe", pe)
57
+
58
+ def forward(self, x):
59
+ if self.batch_first:
60
+ x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
61
+ else:
62
+ x = x + self.pe[: x.shape[0], :]
63
+ return self.dropout(x)
64
+
65
+
66
+ # very similar positional embedding used for diffusion timesteps
67
+ class SinusoidalPosEmb(nn.Module):
68
+ def __init__(self, dim):
69
+ super().__init__()
70
+ self.dim = dim
71
+
72
+ def forward(self, x):
73
+ device = x.device
74
+ half_dim = self.dim // 2
75
+ emb = math.log(10000) / (half_dim - 1)
76
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
77
+ emb = x[:, None] * emb[None, :]
78
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
79
+ return emb
80
+
81
+
82
+ # dropout mask
83
+ def prob_mask_like(shape, prob, device):
84
+ if prob == 1:
85
+ return torch.ones(shape, device=device, dtype=torch.bool)
86
+ elif prob == 0:
87
+ return torch.zeros(shape, device=device, dtype=torch.bool)
88
+ else:
89
+ return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
90
+
91
+
92
+ def extract(a, t, x_shape):
93
+ b, *_ = t.shape
94
+ out = a.gather(-1, t)
95
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
96
+
97
+
98
+ def make_beta_schedule(
99
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
100
+ ):
101
+ if schedule == "linear":
102
+ betas = (
103
+ torch.linspace(
104
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
105
+ )
106
+ ** 2
107
+ )
108
+
109
+ elif schedule == "cosine":
110
+ timesteps = (
111
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
112
+ )
113
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
114
+ alphas = torch.cos(alphas).pow(2)
115
+ alphas = alphas / alphas[0]
116
+ betas = 1 - alphas[1:] / alphas[:-1]
117
+ betas = np.clip(betas, a_min=0, a_max=0.999)
118
+
119
+ elif schedule == "sqrt_linear":
120
+ betas = torch.linspace(
121
+ linear_start, linear_end, n_timestep, dtype=torch.float64
122
+ )
123
+ elif schedule == "sqrt":
124
+ betas = (
125
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
126
+ ** 0.5
127
+ )
128
+ else:
129
+ raise ValueError(f"schedule '{schedule}' unknown.")
130
+ return betas.numpy()
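As a quick illustration of how the diffusion helpers above fit together, here is a hypothetical sketch (assuming model.utils is importable; the cumulative-product bookkeeping shown is the standard DDPM recipe, not code from this commit):

```python
# Hypothetical sketch: build a beta schedule and gather per-timestep coefficients.
import torch
from model.utils import extract, make_beta_schedule  # path as added in this commit

betas = torch.from_numpy(make_beta_schedule("linear", n_timestep=1000))
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)       # standard DDPM cumulative product
t = torch.tensor([0, 499, 999])                          # one timestep index per batch element
coef = extract(alphas_cumprod, t, x_shape=(3, 30, 512))  # -> (3, 1, 1), broadcastable over x
print(betas.shape, coef.shape)
```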
model/vqvae.py ADDED
@@ -0,0 +1,550 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import json
9
+ import os
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from utils.misc import broadcast_tensors
16
+
17
+
18
+ def setup_tokenizer(resume_pth: str) -> "TemporalVertexCodec":
19
+ args_path = os.path.dirname(resume_pth)
20
+ with open(os.path.join(args_path, "args.json")) as f:
21
+ trans_args = json.load(f)
22
+ tokenizer = TemporalVertexCodec(
23
+ n_vertices=trans_args["nb_joints"],
24
+ latent_dim=trans_args["output_emb_width"],
25
+ categories=trans_args["code_dim"],
26
+ residual_depth=trans_args["depth"],
27
+ )
28
+ print("loading checkpoint from {}".format(resume_pth))
29
+ ckpt = torch.load(resume_pth, map_location="cpu")
30
+ tokenizer.load_state_dict(ckpt["net"], strict=True)
31
+ for p in tokenizer.parameters():
32
+ p.requires_grad = False
33
+ tokenizer.cuda()
34
+ return tokenizer
35
+
36
+
37
+ def default(val, d):
38
+ return val if val is not None else d
39
+
40
+
41
+ def ema_inplace(moving_avg, new, decay: float):
42
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
43
+
44
+
45
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
46
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
47
+
48
+
49
+ def uniform_init(*shape: int):
50
+ t = torch.empty(shape)
51
+ nn.init.kaiming_uniform_(t)
52
+ return t
53
+
54
+
55
+ def sum_flat(tensor):
56
+ """
57
+ Take the sum over all non-batch dimensions.
58
+ """
59
+ return tensor.sum(dim=list(range(1, len(tensor.shape))))
60
+
61
+
62
+ def sample_vectors(samples, num: int):
63
+ num_samples, device = samples.shape[0], samples.device
64
+
65
+ if num_samples >= num:
66
+ indices = torch.randperm(num_samples, device=device)[:num]
67
+ else:
68
+ indices = torch.randint(0, num_samples, (num,), device=device)
69
+
70
+ return samples[indices]
71
+
72
+
73
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
74
+ dim, dtype = samples.shape[-1], samples.dtype
75
+
76
+ means = sample_vectors(samples, num_clusters)
77
+
78
+ for _ in range(num_iters):
79
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
80
+ dists = -(diffs**2).sum(dim=-1)
81
+
82
+ buckets = dists.max(dim=-1).indices
83
+ bins = torch.bincount(buckets, minlength=num_clusters)
84
+ zero_mask = bins == 0
85
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
86
+
87
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89
+ new_means = new_means / bins_min_clamped[..., None]
90
+
91
+ means = torch.where(zero_mask[..., None], means, new_means)
92
+
93
+ return means, bins
94
+
95
+
96
+ class EuclideanCodebook(nn.Module):
97
+ """Codebook with Euclidean distance.
98
+ Args:
99
+ dim (int): Dimension.
100
+ codebook_size (int): Codebook size.
101
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102
+ If set to true, run the k-means algorithm on the first training batch and use
103
+ the learned centroids as initialization.
104
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105
+ decay (float): Decay for exponential moving average over the codebooks.
106
+ epsilon (float): Epsilon value for numerical stability.
107
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108
+ that have an exponential moving average cluster size less than the specified threshold with
109
+ a randomly selected vector from the current batch.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ dim: int,
115
+ codebook_size: int,
116
+ kmeans_init: int = False,
117
+ kmeans_iters: int = 10,
118
+ decay: float = 0.99,
119
+ epsilon: float = 1e-5,
120
+ threshold_ema_dead_code: int = 2,
121
+ ):
122
+ super().__init__()
123
+ self.decay = decay
124
+ init_fn = uniform_init if not kmeans_init else torch.zeros
125
+ embed = init_fn(codebook_size, dim)
126
+
127
+ self.codebook_size = codebook_size
128
+
129
+ self.kmeans_iters = kmeans_iters
130
+ self.epsilon = epsilon
131
+ self.threshold_ema_dead_code = threshold_ema_dead_code
132
+
133
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
134
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
135
+ self.register_buffer("embed", embed)
136
+ self.register_buffer("embed_avg", embed.clone())
137
+
138
+ @torch.jit.ignore
139
+ def init_embed_(self, data):
140
+ if self.inited:
141
+ return
142
+
143
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
144
+ self.embed.data.copy_(embed)
145
+ self.embed_avg.data.copy_(embed.clone())
146
+ self.cluster_size.data.copy_(cluster_size)
147
+ self.inited.data.copy_(torch.Tensor([True]))
148
+ # Make sure all buffers across workers are in sync after initialization
149
+ broadcast_tensors(self.buffers())
150
+
151
+ def replace_(self, samples, mask):
152
+ modified_codebook = torch.where(
153
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
154
+ )
155
+ self.embed.data.copy_(modified_codebook)
156
+
157
+ def expire_codes_(self, batch_samples):
158
+ if self.threshold_ema_dead_code == 0:
159
+ return
160
+
161
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
162
+ if not torch.any(expired_codes):
163
+ return
164
+
165
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
166
+ self.replace_(batch_samples, mask=expired_codes)
167
+ broadcast_tensors(self.buffers())
168
+
169
+ def preprocess(self, x):
170
+ x = rearrange(x, "... d -> (...) d")
171
+ return x
172
+
173
+ def quantize(self, x):
174
+ embed = self.embed.t()
175
+ dist = -(
176
+ x.pow(2).sum(1, keepdim=True)
177
+ - 2 * x @ embed
178
+ + embed.pow(2).sum(0, keepdim=True)
179
+ )
180
+ embed_ind = dist.max(dim=-1).indices
181
+ return embed_ind
182
+
183
+ def postprocess_emb(self, embed_ind, shape):
184
+ return embed_ind.view(*shape[:-1])
185
+
186
+ def dequantize(self, embed_ind):
187
+ quantize = F.embedding(embed_ind, self.embed)
188
+ return quantize
189
+
190
+ def encode(self, x):
191
+ shape = x.shape
192
+ x = self.preprocess(x)
193
+ embed_ind = self.quantize(x)
194
+ embed_ind = self.postprocess_emb(embed_ind, shape)
195
+ return embed_ind
196
+
197
+ def decode(self, embed_ind):
198
+ quantize = self.dequantize(embed_ind)
199
+ return quantize
200
+
201
+ def forward(self, x):
202
+ shape, dtype = x.shape, x.dtype
203
+ x = self.preprocess(x)
204
+
205
+ self.init_embed_(x)
206
+
207
+ embed_ind = self.quantize(x)
208
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
209
+ embed_ind = self.postprocess_emb(embed_ind, shape)
210
+ quantize = self.dequantize(embed_ind)
211
+
212
+ if self.training:
213
+ # We do the expiry of code at that point as buffers are in sync
214
+ # and all the workers will take the same decision.
215
+ self.expire_codes_(x)
216
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
217
+ embed_sum = x.t() @ embed_onehot
218
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
219
+ cluster_size = (
220
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
221
+ * self.cluster_size.sum()
222
+ )
223
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
224
+ self.embed.data.copy_(embed_normalized)
225
+
226
+ return quantize, embed_ind
227
+
228
+
229
+ class VectorQuantization(nn.Module):
230
+ """Vector quantization implementation.
231
+ Currently supports only euclidean distance.
232
+ Args:
233
+ dim (int): Dimension
234
+ codebook_size (int): Codebook size
235
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
236
+ decay (float): Decay for exponential moving average over the codebooks.
237
+ epsilon (float): Epsilon value for numerical stability.
238
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
239
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
240
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
241
+ that have an exponential moving average cluster size less than the specified threshold with
242
+ a randomly selected vector from the current batch.
243
+ commitment_weight (float): Weight for commitment loss.
244
+ """
245
+
246
+ def __init__(
247
+ self,
248
+ dim: int,
249
+ codebook_size: int,
250
+ codebook_dim=None,
251
+ decay: float = 0.99,
252
+ epsilon: float = 1e-5,
253
+ kmeans_init: bool = True,
254
+ kmeans_iters: int = 50,
255
+ threshold_ema_dead_code: int = 2,
256
+ commitment_weight: float = 1.0,
257
+ ):
258
+ super().__init__()
259
+ _codebook_dim: int = default(codebook_dim, dim)
260
+
261
+ requires_projection = _codebook_dim != dim
262
+ self.project_in = (
263
+ nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
264
+ )
265
+ self.project_out = (
266
+ nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
267
+ )
268
+
269
+ self.epsilon = epsilon
270
+ self.commitment_weight = commitment_weight
271
+
272
+ self._codebook = EuclideanCodebook(
273
+ dim=_codebook_dim,
274
+ codebook_size=codebook_size,
275
+ kmeans_init=kmeans_init,
276
+ kmeans_iters=kmeans_iters,
277
+ decay=decay,
278
+ epsilon=epsilon,
279
+ threshold_ema_dead_code=threshold_ema_dead_code,
280
+ )
281
+ self.codebook_size = codebook_size
282
+ self.l2_loss = lambda a, b: (a - b) ** 2
283
+
284
+ @property
285
+ def codebook(self):
286
+ return self._codebook.embed
287
+
288
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
289
+ x = self.project_in(x)
290
+ embed_in = self._codebook.encode(x)
291
+ return embed_in
292
+
293
+ def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
294
+ quantize = self._codebook.decode(embed_ind)
295
+ quantize = self.project_out(quantize)
296
+ return quantize
297
+
298
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
299
+ """
300
+ :param x: B x dim input tensor
301
+ :return: quantize: B x dim tensor containing reconstruction after quantization
302
+ embed_ind: B-dimensional tensor containing embedding indices
303
+ loss: scalar tensor containing commitment loss
304
+ """
305
+ device = x.device
306
+ x = self.project_in(x)
307
+
308
+ quantize, embed_ind = self._codebook(x)
309
+
310
+ if self.training:
311
+ quantize = x + (quantize - x).detach()
312
+
313
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
314
+
315
+ if self.training:
316
+ if self.commitment_weight > 0:
317
+ commit_loss = F.mse_loss(quantize.detach(), x)
318
+ loss = loss + commit_loss * self.commitment_weight
319
+
320
+ quantize = self.project_out(quantize)
321
+ return quantize, embed_ind, loss
322
+
323
+
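To make the straight-through behavior above concrete: in training mode the quantized output is rewritten as x + (quantize - x).detach(), so gradients reach the encoder input even though the codebook lookup itself is not differentiable, and the commitment loss is returned as the third value (in eval mode it stays at zero). A hypothetical single-process check, assuming utils.misc.broadcast_tensors is a no-op outside distributed runs:

```python
# Hypothetical sketch: gradients flow through the straight-through estimator.
import torch
from model.vqvae import VectorQuantization  # path as added in this commit

vq = VectorQuantization(dim=128, codebook_size=256, kmeans_init=False)
vq.train()
x = torch.randn(64, 128, requires_grad=True)
quantized, indices, commit_loss = vq(x)      # (64, 128), (64,), (1,)
(quantized.sum() + commit_loss).backward()
assert x.grad is not None                    # straight-through passes gradients to x
```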
324
+ class ResidualVectorQuantization(nn.Module):
325
+ """Residual vector quantization implementation.
326
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
327
+ """
328
+
329
+ def __init__(self, *, num_quantizers: int, **kwargs):
330
+ super().__init__()
331
+ self.layers = nn.ModuleList(
332
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
333
+ )
334
+
335
+ def forward(self, x, B, T, mask, n_q=None):
336
+ """
337
+ :param x: B x dim tensor
338
+ :return: quantized_out: B x dim tensor
339
+ out_indices: B x n_q LongTensor containing indices for each quantizer
340
+ out_losses: scalar tensor containing commitment loss
341
+ """
342
+ quantized_out = 0.0
343
+ residual = x
344
+
345
+ all_losses = []
346
+ all_indices = []
347
+
348
+ n_q = n_q or len(self.layers)
349
+
350
+ for layer in self.layers[:n_q]:
351
+ quantized, indices, loss = layer(residual)
352
+ residual = (
353
+ residual - quantized
354
+ ) # would need quantized.detach() to have commitment gradients beyond the first quantizer, but this seems to harm performance
355
+ quantized_out = quantized_out + quantized
356
+
357
+ all_indices.append(indices)
358
+ all_losses.append(loss)
359
+
360
+ out_indices = torch.stack(all_indices, dim=-1)
361
+ out_losses = torch.mean(torch.stack(all_losses))
362
+ return quantized_out, out_indices, out_losses
363
+
364
+ def encode(self, x: torch.Tensor, n_q=None) -> torch.Tensor:
365
+ """
366
+ :param x: B x dim input tensor
367
+ :return: B x n_q LongTensor containing indices for each quantizer
368
+ """
369
+ residual = x
370
+ all_indices = []
371
+ n_q = n_q or len(self.layers)
372
+ for layer in self.layers[:n_q]:
373
+ indices = layer.encode(residual)  # per-level codebook indices, shaped like residual without the feature dim (e.g. B x T)
374
+ # print(indices.shape, residual.shape, x.shape)
375
+ quantized = layer.decode(indices)
376
+ residual = residual - quantized
377
+ all_indices.append(indices)
378
+ out_indices = torch.stack(all_indices, dim=-1)
379
+ return out_indices
380
+
381
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
382
+ """
383
+ :param q_indices: B x n_q LongTensor containing indices for each quantizer
384
+ :return: B x dim tensor containing reconstruction after quantization
385
+ """
386
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
387
+ q_indices = q_indices.permute(1, 0).contiguous()
388
+ for i, indices in enumerate(q_indices):
389
+ layer = self.layers[i]
390
+ quantized = layer.decode(indices)
391
+ quantized_out = quantized_out + quantized
392
+ return quantized_out
393
+
394
+
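A small hypothetical round trip through the residual quantizer above (again assuming model.vqvae imports cleanly): encode stacks one codebook index per quantizer level, and decode sums the per-level codewords back into a latent vector.

```python
# Hypothetical sketch: residual VQ round trip over 4 quantizer levels.
import torch
from model.vqvae import ResidualVectorQuantization  # path as added in this commit

rvq = ResidualVectorQuantization(num_quantizers=4, dim=128, codebook_size=256, kmeans_init=False)
rvq.eval()
x = torch.randn(16 * 8, 128)   # (B*T) x dim latent vectors
codes = rvq.encode(x)          # -> (B*T) x 4 LongTensor, one index per level
recon = rvq.decode(codes)      # -> (B*T) x dim, sum of the per-level codewords
print(codes.shape, recon.shape)
```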
395
+ class TemporalVertexEncoder(nn.Module):
396
+ def __init__(
397
+ self,
398
+ n_vertices: int = 338,
399
+ latent_dim: int = 128,
400
+ ):
401
+ super().__init__()
402
+ self.input_dim = n_vertices
403
+ self.enc = nn.Sequential(
404
+ nn.Conv1d(self.input_dim, latent_dim, kernel_size=1),
405
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
406
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
407
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
408
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=2),
409
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
410
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=3),
411
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
412
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
413
+ )
414
+ self.receptive_field = 8
415
+
416
+ def forward(self, verts):
417
+ """
418
+ :param verts: B x T x n_vertices x 3 tensor containing batched sequences of vertices
419
+ :return: B x T x latent_dim tensor containing the latent representation
420
+ """
421
+ if verts.dim() == 4:
422
+ verts = verts.permute(0, 2, 3, 1).contiguous()
423
+ verts = verts.view(verts.shape[0], self.input_dim, verts.shape[3])
424
+ else:
425
+ verts = verts.permute(0, 2, 1)
426
+ verts = nn.functional.pad(verts, pad=[self.receptive_field - 1, 0])
427
+ x = self.enc(verts)
428
+ x = x.permute(0, 2, 1).contiguous()
429
+ return x
430
+
431
+
432
+ class TemporalVertexDecoder(nn.Module):
433
+ def __init__(
434
+ self,
435
+ n_vertices: int = 338,
436
+ latent_dim: int = 128,
437
+ ):
438
+ super().__init__()
439
+ self.output_dim = n_vertices
440
+ self.project_mean_shape = nn.Linear(self.output_dim, latent_dim)
441
+ self.dec = nn.Sequential(
442
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
443
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
444
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=2),
445
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
446
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=3),
447
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
448
+ nn.Conv1d(latent_dim, latent_dim, kernel_size=2, dilation=1),
449
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
450
+ nn.Conv1d(latent_dim, self.output_dim, kernel_size=1),
451
+ )
452
+ self.receptive_field = 8
453
+
454
+ def forward(self, x):
455
+ """
456
+ :param x: B x T x latent_dim tensor containing batched sequences of vertex encodings
457
+ :return: B x T x n_vertices x 3 tensor containing batched sequences of vertices
458
+ """
459
+ x = x.permute(0, 2, 1).contiguous()
460
+ x = nn.functional.pad(x, pad=[self.receptive_field - 1, 0])
461
+ verts = self.dec(x)
462
+ verts = verts.permute(0, 2, 1)
463
+ return verts
464
+
465
+
466
+ class TemporalVertexCodec(nn.Module):
467
+ def __init__(
468
+ self,
469
+ n_vertices: int = 338,
470
+ latent_dim: int = 128,
471
+ categories: int = 128,
472
+ residual_depth: int = 4,
473
+ ):
474
+ super().__init__()
475
+ self.latent_dim = latent_dim
476
+ self.categories = categories
477
+ self.residual_depth = residual_depth
478
+ self.n_clusters = categories
479
+ self.encoder = TemporalVertexEncoder(
480
+ n_vertices=n_vertices, latent_dim=latent_dim
481
+ )
482
+ self.decoder = TemporalVertexDecoder(
483
+ n_vertices=n_vertices, latent_dim=latent_dim
484
+ )
485
+ self.quantizer = ResidualVectorQuantization(
486
+ dim=latent_dim,
487
+ codebook_size=categories,
488
+ num_quantizers=residual_depth,
489
+ decay=0.99,
490
+ kmeans_init=True,
491
+ kmeans_iters=10,
492
+ threshold_ema_dead_code=2,
493
+ )
494
+
495
+ def predict(self, verts):
496
+ """wrapper to provide compatibility with kmeans"""
497
+ return self.encode(verts)
498
+
499
+ def encode(self, verts):
500
+ """
501
+ :param verts: B x T x n_vertices x 3 tensor containing batched sequences of vertices
502
+ :return: B x T x residual_depth LongTensor containing quantized encodings
503
+ """
504
+ enc = self.encoder(verts)
505
+ q = self.quantizer.encode(enc)
506
+ return q
507
+
508
+ def decode(self, q):
509
+ """
510
+ :param q: B x T x residual_depth LongTensor containing quantized encodings
511
+ :return: B x T x n_vertices x 3 tensor containing decoded vertices
512
+ """
513
+ reformat = q.dim() > 2
514
+ if reformat:
515
+ B, T, _ = q.shape
516
+ q = q.reshape((-1, self.residual_depth))
517
+ enc = self.quantizer.decode(q)
518
+ if reformat:
519
+ enc = enc.reshape((B, T, -1))
520
+ verts = self.decoder(enc)
521
+ return verts
522
+
523
+ @torch.no_grad()
524
+ def compute_perplexity(self, code_idx):
525
+ # Calculate new centres
526
+ code_onehot = torch.zeros(
527
+ self.categories, code_idx.shape[0], device=code_idx.device
528
+ ) # categories, N * L
529
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
530
+
531
+ code_count = code_onehot.sum(dim=-1) # categories
532
+ prob = code_count / torch.sum(code_count)
533
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
534
+ return perplexity
535
+
536
+ def forward(self, verts, mask=None):
537
+ """
538
+ :param verts: B x T x n_vertices x 3 tensor containing mesh sequences
539
+ :return: verts: B x T x n_vertices x 3 tensor containing reconstructed mesh sequences
540
+ vq_loss: scalar tensor for vq commitment loss
541
+ """
542
+ B, T = verts.shape[0], verts.shape[1]
543
+ x = self.encoder(verts)
544
+ x, code_idx, vq_loss = self.quantizer(
545
+ x.view(B * T, self.latent_dim), B, T, mask
546
+ )
547
+ perplexity = self.compute_perplexity(code_idx[:, -1].view((-1)))
548
+ verts = self.decoder(x.view(B, T, self.latent_dim))
549
+ verts = verts.reshape((verts.shape[0], verts.shape[1], -1))
550
+ return verts, vq_loss, perplexity
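For orientation, a hypothetical end-to-end pass through the codec with its default sizes (untrained here, so only the shapes are meaningful; the real entry point is setup_tokenizer above, which also loads a trained checkpoint):

```python
# Hypothetical sketch: encode poses to discrete tokens and decode them back.
import torch
from model.vqvae import TemporalVertexCodec  # path as added in this commit

codec = TemporalVertexCodec(n_vertices=338, latent_dim=128, categories=128, residual_depth=4)
codec.eval()
poses = torch.randn(2, 40, 338)   # B x T x flattened pose features
tokens = codec.encode(poses)      # -> B x T x residual_depth LongTensor
recon = codec.decode(tokens)      # -> B x T x 338 reconstruction
print(tokens.shape, recon.shape)
```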
sample/generate.py ADDED
@@ -0,0 +1,316 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import os
9
+
10
+ from typing import Callable, Dict, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ from data_loaders.get_data import get_dataset_loader, load_local_data
15
+ from diffusion.respace import SpacedDiffusion
16
+ from model.cfg_sampler import ClassifierFreeSampleModel
17
+ from model.diffusion import FiLMTransformer
18
+
19
+ from torch.utils.data import DataLoader
20
+ from utils.diff_parser_utils import generate_args
21
+ from utils.misc import fixseed, prGreen
22
+ from utils.model_util import create_model_and_diffusion, get_person_num, load_model
23
+
24
+
25
+ def _construct_template_variables(unconstrained: bool) -> (str,):
26
+ row_file_template = "sample{:02d}.mp4"
27
+ all_file_template = "samples_{:02d}_to_{:02d}.mp4"
28
+ if unconstrained:
29
+ sample_file_template = "row{:02d}_col{:02d}.mp4"
30
+ sample_print_template = "[{} row #{:02d} column #{:02d} | -> {}]"
31
+ row_file_template = row_file_template.replace("sample", "row")
32
+ row_print_template = "[{} row #{:02d} | all columns | -> {}]"
33
+ all_file_template = all_file_template.replace("samples", "rows")
34
+ all_print_template = "[rows {:02d} to {:02d} | -> {}]"
35
+ else:
36
+ sample_file_template = "sample{:02d}_rep{:02d}.mp4"
37
+ sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]'
38
+ row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]'
39
+ all_print_template = "[samples {:02d} to {:02d} | all repetitions | -> {}]"
40
+
41
+ return (
42
+ sample_print_template,
43
+ row_print_template,
44
+ all_print_template,
45
+ sample_file_template,
46
+ row_file_template,
47
+ all_file_template,
48
+ )
49
+
50
+
51
+ def _replace_keyframes(
52
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
53
+ model: Union[FiLMTransformer, ClassifierFreeSampleModel],
54
+ ) -> torch.Tensor:
55
+ B, T = (
56
+ model_kwargs["y"]["keyframes"].shape[0],
57
+ model_kwargs["y"]["keyframes"].shape[1],
58
+ )
59
+ with torch.no_grad():
60
+ tokens = model.transformer.generate(
61
+ model_kwargs["y"]["audio"],
62
+ T,
63
+ layers=model.tokenizer.residual_depth,
64
+ n_sequences=B,
65
+ )
66
+ tokens = tokens.reshape((B, -1, model.tokenizer.residual_depth))
67
+ pred = model.tokenizer.decode(tokens).detach().cpu()
68
+ assert (
69
+ model_kwargs["y"]["keyframes"].shape == pred.shape
70
+ ), f"{model_kwargs['y']['keyframes'].shape} vs {pred.shape}"
71
+ return pred
72
+
73
+
74
+ def _run_single_diffusion(
75
+ args,
76
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
77
+ diffusion: SpacedDiffusion,
78
+ model: Union[FiLMTransformer, ClassifierFreeSampleModel],
79
+ inv_transform: Callable,
80
+ gt: torch.Tensor,
81
+ ) -> (torch.Tensor,):
82
+ if args.data_format == "pose" and args.resume_trans is not None:
83
+ model_kwargs["y"]["keyframes"] = _replace_keyframes(model_kwargs, model)
84
+
85
+ sample_fn = diffusion.ddim_sample_loop
86
+ with torch.no_grad():
87
+ sample = sample_fn(
88
+ model,
89
+ (args.batch_size, model.nfeats, 1, args.curr_seq_length),
90
+ clip_denoised=False,
91
+ model_kwargs=model_kwargs,
92
+ init_image=None,
93
+ progress=True,
94
+ dump_steps=None,
95
+ noise=None,
96
+ const_noise=False,
97
+ )
98
+ sample = inv_transform(sample.cpu().permute(0, 2, 3, 1), args.data_format).permute(
99
+ 0, 3, 1, 2
100
+ )
101
+ curr_audio = inv_transform(model_kwargs["y"]["audio"].cpu().numpy(), "audio")
102
+ keyframes = inv_transform(model_kwargs["y"]["keyframes"], args.data_format)
103
+ gt_seq = inv_transform(gt.cpu().permute(0, 2, 3, 1), args.data_format).permute(
104
+ 0, 3, 1, 2
105
+ )
106
+
107
+ return sample, curr_audio, keyframes, gt_seq
108
+
109
+
110
+ def _generate_sequences(
111
+ args,
112
+ model_kwargs: Dict[str, Dict[str, torch.Tensor]],
113
+ diffusion: SpacedDiffusion,
114
+ model: Union[FiLMTransformer, ClassifierFreeSampleModel],
115
+ test_data: torch.Tensor,
116
+ gt: torch.Tensor,
117
+ ) -> Dict[str, np.ndarray]:
118
+ all_motions = []
119
+ all_lengths = []
120
+ all_audio = []
121
+ all_gt = []
122
+ all_keyframes = []
123
+
124
+ for rep_i in range(args.num_repetitions):
125
+ print(f"### Sampling [repetitions #{rep_i}]")
126
+ # add CFG scale to batch
127
+ if args.guidance_param != 1:
128
+ model_kwargs["y"]["scale"] = (
129
+ torch.ones(args.batch_size, device=args.device) * args.guidance_param
130
+ )
131
+ model_kwargs["y"] = {
132
+ key: val.to(args.device) if torch.is_tensor(val) else val
133
+ for key, val in model_kwargs["y"].items()
134
+ }
135
+ sample, curr_audio, keyframes, gt_seq = _run_single_diffusion(
136
+ args, model_kwargs, diffusion, model, test_data.dataset.inv_transform, gt
137
+ )
138
+ all_motions.append(sample.cpu().numpy())
139
+ all_audio.append(curr_audio)
140
+ all_keyframes.append(keyframes.cpu().numpy())
141
+ all_gt.append(gt_seq.cpu().numpy())
142
+ all_lengths.append(model_kwargs["y"]["lengths"].cpu().numpy())
143
+
144
+ print(f"created {len(all_motions) * args.batch_size} samples")
145
+
146
+ return {
147
+ "motions": np.concatenate(all_motions, axis=0),
148
+ "audio": np.concatenate(all_audio, axis=0),
149
+ "gt": np.concatenate(all_gt, axis=0),
150
+ "lengths": np.concatenate(all_lengths, axis=0),
151
+ "keyframes": np.concatenate(all_keyframes, axis=0),
152
+ }
153
+
154
+
155
+ def _render_pred(
156
+ args,
157
+ data_block: Dict[str, torch.Tensor],
158
+ sample_file_template: str,
159
+ audio_per_frame: int,
160
+ ) -> None:
161
+ from visualize.render_codes import BodyRenderer
162
+
163
+ face_codes = None
164
+ if args.face_codes is not None:
165
+ face_codes = np.load(args.face_codes, allow_pickle=True).item()
166
+ face_motions = face_codes["motions"]
167
+ face_gts = face_codes["gt"]
168
+ face_audio = face_codes["audio"]
169
+
170
+ config_base = f"./checkpoints/ca_body/data/{get_person_num(args.data_root)}"
171
+ body_renderer = BodyRenderer(
172
+ config_base=config_base,
173
+ render_rgb=True,
174
+ )
175
+
176
+ for sample_i in range(args.num_samples):
177
+ for rep_i in range(args.num_repetitions):
178
+ idx = rep_i * args.batch_size + sample_i
179
+ save_file = sample_file_template.format(sample_i, rep_i)
180
+ animation_save_path = os.path.join(args.output_dir, save_file)
181
+ # format data
182
+ length = data_block["lengths"][idx]
183
+ body_motion = (
184
+ data_block["motions"][idx].transpose(2, 0, 1)[:length].squeeze(-1)
185
+ )
186
+ face_motion = face_motions[idx].transpose(2, 0, 1)[:length].squeeze(-1)
187
+ assert np.array_equal(
188
+ data_block["audio"][idx], face_audio[idx]
189
+ ), "face audio is not the same"
190
+ audio = data_block["audio"][idx, : length * audio_per_frame, :].T
191
+ # set up render data block to pass into renderer
192
+ render_data_block = {
193
+ "audio": audio,
194
+ "body_motion": body_motion,
195
+ "face_motion": face_motion,
196
+ }
197
+ if args.render_gt:
198
+ gt_body = data_block["gt"][idx].transpose(2, 0, 1)[:length].squeeze(-1)
199
+ gt_face = face_gts[idx].transpose(2, 0, 1)[:length].squeeze(-1)
200
+ render_data_block["gt_body"] = gt_body
201
+ render_data_block["gt_face"] = gt_face
202
+ body_renderer.render_full_video(
203
+ render_data_block,
204
+ animation_save_path,
205
+ audio_sr=audio_per_frame * 30,
206
+ render_gt=args.render_gt,
207
+ )
208
+
209
+
210
+ def _reset_sample_args(args) -> None:
211
+ # set the sequence length to match the one specified by user
212
+ name = os.path.basename(os.path.dirname(args.model_path))
213
+ niter = os.path.basename(args.model_path).replace("model", "").replace(".pt", "")
214
+ args.curr_seq_length = (
215
+ args.curr_seq_length
216
+ if args.curr_seq_length is not None
217
+ else args.max_seq_length
218
+ )
219
+ # add the resume predictor model path
220
+ resume_trans_name = ""
221
+ if args.data_format == "pose" and args.resume_trans is not None:
222
+ resume_trans_parts = args.resume_trans.split("/")
223
+ resume_trans_name = f"{resume_trans_parts[1]}_{resume_trans_parts[-1]}"
224
+ # reformat the output directory
225
+ args.output_dir = os.path.join(
226
+ os.path.dirname(args.model_path),
227
+ "samples_{}_{}_seed{}_{}".format(name, niter, args.seed, resume_trans_name),
228
+ )
229
+ assert (
230
+ args.num_samples <= args.batch_size
231
+ ), f"Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})"
232
+ # set the batch size to match the number of samples to generate
233
+ args.batch_size = args.num_samples
234
+
235
+
236
+ def _setup_dataset(args) -> DataLoader:
237
+ data_root = args.data_root
238
+ data_dict = load_local_data(
239
+ data_root,
240
+ audio_per_frame=1600,
241
+ flip_person=args.flip_person,
242
+ )
243
+ test_data = get_dataset_loader(
244
+ args=args,
245
+ data_dict=data_dict,
246
+ split="test",
247
+ chunk=True,
248
+ )
249
+ return test_data
250
+
251
+
252
+ def _setup_model(
253
+ args,
254
+ ) -> (Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion):
255
+ model, diffusion = create_model_and_diffusion(args, split_type="test")
256
+ print(f"Loading checkpoints from [{args.model_path}]...")
257
+ state_dict = torch.load(args.model_path, map_location="cpu")
258
+ load_model(model, state_dict)
259
+
260
+ if not args.unconstrained:
261
+ assert args.guidance_param != 1
262
+
263
+ if args.guidance_param != 1:
264
+ prGreen("[CFS] wrapping model in classifier free sample")
265
+ model = ClassifierFreeSampleModel(model)
266
+ model.to(args.device)
267
+ model.eval()
268
+ return model, diffusion
269
+
270
+
271
+ def main():
272
+ args = generate_args()
273
+ fixseed(args.seed)
274
+ _reset_sample_args(args)
275
+
276
+ print("Loading dataset...")
277
+ test_data = _setup_dataset(args)
278
+ iterator = iter(test_data)
279
+
280
+ print("Creating model and diffusion...")
281
+ model, diffusion = _setup_model(args)
282
+
283
+ if args.pose_codes is None:
284
+ # generate sequences
285
+ gt, model_kwargs = next(iterator)
286
+ data_block = _generate_sequences(
287
+ args, model_kwargs, diffusion, model, test_data, gt
288
+ )
289
+ os.makedirs(args.output_dir, exist_ok=True)
290
+ npy_path = os.path.join(args.output_dir, "results.npy")
291
+ print(f"saving results file to [{npy_path}]")
292
+ np.save(npy_path, data_block)
293
+ else:
294
+ # load the pre generated results
295
+ data_block = np.load(args.pose_codes, allow_pickle=True).item()
296
+
297
+ # plot function only if face_codes exist and we are on pose prediction
298
+ if args.plot:
299
+ assert args.face_codes is not None, "need body and faces"
300
+ assert (
301
+ args.data_format == "pose"
302
+ ), "currently only supporting plot on pose stuff"
303
+ print(f"saving visualizations to [{args.output_dir}]...")
304
+ _, _, _, sample_file_template, _, _ = _construct_template_variables(
305
+ args.unconstrained
306
+ )
307
+ _render_pred(
308
+ args,
309
+ data_block,
310
+ sample_file_template,
311
+ test_data.dataset.audio_per_frame,
312
+ )
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main()
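Once generation finishes, results.npy holds the dictionary produced by _generate_sequences above; a short hypothetical snippet for inspecting it offline (key names taken directly from that function):

```python
# Hypothetical sketch: inspect a saved results.npy from sample/generate.py.
import numpy as np

data_block = np.load("results.npy", allow_pickle=True).item()
for key in ("motions", "audio", "gt", "lengths", "keyframes"):
    print(key, data_block[key].shape)
```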
scripts/download_alldatasets.sh ADDED
@@ -0,0 +1,6 @@
1
+ for i in "PXB184" "RLW104" "TXB805" "GQS883"
2
+ do
3
+ curl -L https://github.com/facebookresearch/audio2photoreal/releases/download/v1.0/${i}.zip -o ${i}.zip || { echo 'downloading dataset failed' ; exit 1; }
4
+ unzip ${i}.zip -d dataset/
5
+ rm ${i}.zip
6
+ done
scripts/download_allmodels.sh ADDED
@@ -0,0 +1,13 @@
1
+ for i in "PXB184" "RLW104" "TXB805" "GQS883"
2
+ do
3
+ # download motion models
4
+ wget http://audio2photoreal_models.berkeleyvision.org/${i}_models.tar || { echo 'downloading model failed' ; exit 1; }
5
+ tar xvf ${i}_models.tar
6
+ rm ${i}_models.tar
7
+
8
+ # download ca body rendering checkpoints and assets
9
+ mkdir -p checkpoints/ca_body/data/
10
+ wget https://github.com/facebookresearch/ca_body/releases/download/v0.0.1-alpha/${i}.tar.gz || { echo 'downloading ca body model failed' ; exit 1; }
11
+ tar xvf ${i}.tar.gz --directory checkpoints/ca_body/data/
12
+ rm ${i}.tar.gz
13
+ done
scripts/download_prereq.sh ADDED
@@ -0,0 +1,9 @@
1
+
2
+ # download the prerequisite asset models (lip regressor and wav2vec)
3
+ wget http://audio2photoreal_models.berkeleyvision.org/asset_models.tar
4
+ tar xvf asset_models.tar
5
+ rm asset_models.tar
6
+
7
+ # we obtained the wav2vec models via these links:
8
+ # wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt -P ./assets/
9
+ # wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt -P ./assets/
scripts/installation.sh ADDED
@@ -0,0 +1,4 @@
1
+ # download the prerequisite asset models (lip regressor and wav2vec)
2
+ wget http://audio2photoreal_models.berkeleyvision.org/asset_models.tar
3
+ tar xvf asset_models.tar
4
+ rm asset_models.tar
scripts/requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ attrdict
2
+ blobfile
3
+ einops
4
+ fairseq
5
+ gradio
6
+ matplotlib
7
+ mediapy
8
+ numpy==1.23.0
9
+ opencv-python
10
+ packaging
11
+ scikit-learn
12
+ tensorboard
13
+ tensorboardX
14
+ torch==2.0.1
15
+ torchaudio==2.0.2
16
+ torchvision==0.15.2
17
+ tqdm
train/train_diffusion.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ import json
9
+ import os
10
+
11
+ import torch
12
+ import torch.multiprocessing as mp
13
+
14
+ from data_loaders.get_data import get_dataset_loader, load_local_data
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from train.train_platforms import ClearmlPlatform, NoPlatform, TensorboardPlatform
18
+ from train.training_loop import TrainLoop
19
+ from utils.diff_parser_utils import train_args
20
+ from utils.misc import cleanup, fixseed, setup_dist
21
+ from utils.model_util import create_model_and_diffusion
22
+
23
+
24
+ def main(rank: int, world_size: int):
25
+ args = train_args()
26
+ fixseed(args.seed)
27
+ train_platform_type = eval(args.train_platform_type)
28
+ train_platform = train_platform_type(args.save_dir)
29
+ train_platform.report_args(args, name="Args")
30
+ setup_dist(args.device)
31
+
32
+ if rank == 0:
33
+ if args.save_dir is None:
34
+ raise FileNotFoundError("save_dir was not specified.")
35
+ elif os.path.exists(args.save_dir) and not args.overwrite:
36
+ raise FileExistsError("save_dir [{}] already exists.".format(args.save_dir))
37
+ elif not os.path.exists(args.save_dir):
38
+ os.makedirs(args.save_dir)
39
+ args_path = os.path.join(args.save_dir, "args.json")
40
+ with open(args_path, "w") as fw:
41
+ json.dump(vars(args), fw, indent=4, sort_keys=True)
42
+
43
+ if not os.path.exists(args.data_root):
44
+ args.data_root = args.data_root.replace("/home/", "/derived/")
45
+
46
+ data_dict = load_local_data(args.data_root, audio_per_frame=1600)
47
+ print("creating data loader...")
48
+ data = get_dataset_loader(args=args, data_dict=data_dict)
49
+
50
+ print("creating logger...")
51
+ writer = SummaryWriter(args.save_dir)
52
+
53
+ print("creating model and diffusion...")
54
+ model, diffusion = create_model_and_diffusion(args, split_type="train")
55
+ model.to(rank)
56
+
57
+ if world_size > 1:
58
+ model = DDP(
59
+ model, device_ids=[rank], output_device=rank, find_unused_parameters=True
60
+ )
61
+
62
+ params = (
63
+ model.module.parameters_w_grad()
64
+ if world_size > 1
65
+ else model.parameters_w_grad()
66
+ )
67
+ print("Total params: %.2fM" % (sum(p.numel() for p in params) / 1000000.0))
68
+ print("Training...")
69
+
70
+ TrainLoop(
71
+ args, train_platform, model, diffusion, data, writer, rank, world_size
72
+ ).run_loop()
73
+ train_platform.close()
74
+ cleanup()
75
+
76
+
77
+ if __name__ == "__main__":
78
+ world_size = torch.cuda.device_count()
79
+ print(f"using {world_size} gpus")
80
+ if world_size > 1:
81
+ mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
82
+ else:
83
+ main(rank=0, world_size=1)