Commit
·
934fdee
0
Parent(s):
first commit
Browse files- .gitattributes +42 -0
- .gitignore +6 -0
- Checkpoint_License.txt +40 -0
- Example_Data_License.txt +29 -0
- Multi-HMR_License.txt +98 -0
- README.md +117 -0
- app.py +269 -0
- assets/visu1.gif +3 -0
- assets/visu2.gif +3 -0
- blocks/__init__.py +8 -0
- blocks/camera_embed.py +58 -0
- blocks/cross_attn_transformer.py +359 -0
- blocks/dinov2.py +27 -0
- blocks/smpl_layer.py +153 -0
- demo.py +262 -0
- example_data/170149601_13aa4e4483_c.jpg +3 -0
- example_data/3692623581_aca6eb02d4_e.jpg +3 -0
- example_data/3969570423_58eb848b75_c.jpg +3 -0
- example_data/39742984604_46934fbd50_c.jpg +3 -0
- example_data/4446582661_b188f82f3c_c.jpg +3 -0
- example_data/51960182045_d5d6407a3c_c.jpg +3 -0
- example_data/5850091922_73ba296093_c.jpg +3 -0
- model.py +485 -0
- models/multiHMR/multiHMR.pt +3 -0
- packages.txt +3 -0
- requirements.txt +15 -0
- utils/__init__.py +15 -0
- utils/camera.py +75 -0
- utils/color.py +22 -0
- utils/constants.py +9 -0
- utils/download.py +103 -0
- utils/humans.py +24 -0
- utils/image.py +40 -0
- utils/render.py +448 -0
- utils/tensor_manip.py +45 -0
.gitattributes
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png* filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.pt* filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.jpg* filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.jpeg* filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/visu1.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/visu2.gif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
models/multiHMR/multiHMR.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*.glb
|
3 |
+
*.npz
|
4 |
+
tmp_data
|
5 |
+
._.DS_Store
|
6 |
+
.DS_Store
|
Checkpoint_License.txt
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Multi-HMR Checkpoints, Copyright (c) 2024 Naver Corporation, are licensed under the Checkpoint License below.
|
2 |
+
|
3 |
+
The following datasets, which are not being distributed with the Multi-HMR Checkpoints (hereinafter referred to as "Checkpoints"), were used to train one or more of the Checkpoints:
|
4 |
+
|
5 |
+
(A) BEDLAM Dataset: see https://bedlam.is.tue.mpg.de
|
6 |
+
made available under the following license: https://bedlam.is.tue.mpg.de/license.html
|
7 |
+
Also see: Michael Black et al., "A Synthetic Dataset of Bodies Exhibiting Detailed Lifelike Animated Motion" in Proceedings IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR), pp. 8726-8737, June 2023.
|
8 |
+
|
9 |
+
(B) AGORA Dataset: see https://agora.is.tue.mpg.de/index.html
|
10 |
+
made available under the following license: https://agora.is.tue.mpg.de/license.html
|
11 |
+
Also see: Priyanka Patel et al, "{AGORA}: Avatars in Geography Optimized for Regression Analysis" in Proceedings IEEE/CVF Conf.~on Computer Vision and Pattern Recognition ({CVPR}), June 2021.
|
12 |
+
|
13 |
+
(C) 3DPW Dataset: see https://virtualhumans.mpi-inf.mpg.de/3DPW/evaluation.html
|
14 |
+
made available under the following license: https://virtualhumans.mpi-inf.mpg.de/3DPW/license.html
|
15 |
+
Also see: von Marcard et al., "Recovering Accurate 3D Human Pose in The Wild Using IMUs and a Moving Camera" in European Conference on Computer Vision (ECCV), Sept. 2018.
|
16 |
+
|
17 |
+
(D) UBody Dataset: see https://osx-ubody.github.io/
|
18 |
+
made available under the following license: https://docs.google.com/document/d/1R-nn6qguO0YDkPKBleZ8NyrqGrjLXfJ7AQTTsATMYZc/edit
|
19 |
+
Also see: Jing Lin et al., "One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer" in CVPR 2023.
|
20 |
+
|
21 |
+
|
22 |
+
----------------------------------------------------------------
|
23 |
+
CHECKPOINT LICENSE WHICH ACCOUNT FOR DATASET LICENSES ABOVE:
|
24 |
+
----------------------------------------------------------------
|
25 |
+
|
26 |
+
LICENSE GRANT
|
27 |
+
|
28 |
+
BY EXERCISING ANY RIGHTS TO THE CHECKPOINTS, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, NAVER GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
|
29 |
+
|
30 |
+
Subject to the terms and conditions of this License, Naver hereby grants you a personal, revocable, royalty-free, non-exclusive, non-sublicensable, non-transferable license to use the Checkpoints subject to the following conditions:
|
31 |
+
|
32 |
+
(1) PERMITTED USES: You may use the Checkpoints (1) solely for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects and (2) for no purpose that is excluded by any Dataset License under (A)-(D) above (“Purpose”).
|
33 |
+
|
34 |
+
(2) COPYRIGHT: You will retain the above copyright notice and license along with the disclaimer below in all copies or substantial portions of the Checkpoints.
|
35 |
+
|
36 |
+
(3) TERM: You agree the License automatically terminates without notice if you fail to comply with its terms, or you may terminate this License by ceasing to use the Checkpoints for the Purpose. Upon termination you agree to delete any and all copies of the Checkpoints.
|
37 |
+
|
38 |
+
ALL RIGHTS NOT EXPRESSLY GRANTED IN THIS LICENSE ARE RESERVED BY NAVER.
|
39 |
+
|
40 |
+
THE CHECKPOINTS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL NAVER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE CHECKPOINTS OR THE USE OR OTHER DEALINGS IN THE CHECKPOINTS.
|
Example_Data_License.txt
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
THE FOLLOWING FILES AVAILABLE IN THIS DIRECTORY ARE BEING REDISTRIBUTED WITHOUT MODIFICATION FROM THEIR ORGINAL SOURCE INDICATED BELOW:
|
2 |
+
|
3 |
+
[A] FILE: 39742984604_46934fbd50_c.jpg CC0 [LICENSE: https://creativecommons.org/publicdomain/zero/1.0/]
|
4 |
+
LINK TO ORIGINAL: https://www.flickr.com/photos/138224835@N02/39742984604/in/photolist-23xXbGu-2kXXquq-2kXWBq4-kMD93w-kMB2PK-2kXWAvD-kMB4i6-kMDaeu-kMB4Np-kMBmEi-2kRpySa-2gLwm9L-2jjfFMN-2kRmJmq-2bBtz9S-2jjfFKt-2kSKY3X-2j8jGWF-2ipiiZ8-2ipbovG-2kNanr8-kMD93S-kMB3st-2kYGdpc-2dmWhrJ-kMB2AZ-kMCMRj-UPgVWc-2kYAXNe-2d5bp4B-2kYGAep-2kN98Ni-214amvd-kMBJnr-2j8k8cj-2kN8kmS-2kYBJzk-2kQG8N2-kMBJAn-kMALNH-kMBnhR-kMD9uU-2kN8jWZ-kMBKiV-2dmWhih-2kXNmL8-kMAH5a-kMCEA7-kMACwx-2gLwm4F
|
5 |
+
ATTRIBUTION: Yun_Q
|
6 |
+
|
7 |
+
[B] FILE: 4446582661_b188f82f3c_c.jpg [LICENSE: CC BY-NC-ND 2.0 AVAILABLE HERE: https://creativecommons.org/licenses/by-nc-nd/2.0/]
|
8 |
+
LINK TO ORIGINAL: https://www.flickr.com/photos/mirsasha/4446582661/in/photolist-7LVU7i-e3mDLj-e3mDxL-amhvdR-bw3UQR-amhBAH-amkkDA-e3fYZP-6sRpSL-SqQnjB-THieCq-9uZrTE-e3fYtH-e3fYYF-fQpAoR-2nNEszq-y4Faqu-4Vt164-mjzSco-amhB4P-amhBSP-YW9exq-bUPLmh-yiX3A9-5uEfDG-5PJRav-7wgFZn-XQKh1N-fBSsEX-5uzT4n-5qFqyV-mmpMR-5wju3g-YUFLBo-e5F5wq-e5yNza-4U1qLk-e5FkTG-4Vrbfc-4CNy3W-4X2Zyv-5t969C-amkjah-5qKKSG-5rzJuu-5vcjew-5uEgdC-e5EMef-4Lm1s9-ymYibg
|
9 |
+
ATTRIBUTION: mirsasha
|
10 |
+
|
11 |
+
[C] FILE: 3692623581_aca6eb02d4_e.jpg [LICENSE: CC BY-NC-ND 2.0 AVAILABLE HERE: https://creativecommons.org/licenses/by-nc-nd/2.0/]
|
12 |
+
LINK TO ORIGINAL: https://www.flickr.com/photos/chasemcalpine/3692623581/in/photolist-6CiEkP-6CiEpR-6CiEx2-6CiEZT-6CnMiC-6CnNoN-6CiErz-7kZN9Y-6CnN1h-6CiETx-6CnN4W-6CiF6P-6NRfC8-5fw88t-tMsMc-M9JFr6-audTNc-7p39Vu-7oYi88-8fbPSG-6CiFmM-6CiDMv-6CnLCW-6CnMnu-6CnMsu-6CiECr-6CiDYz-6CnMRJ-6CnN8N-5655q7-7p3a59-y9dZj-9yXwAb-7snC7p-4FRj19-8jaDN6-qtjT1-7pn36Y-dP4XkX-dP4Xh4-6fba2g-9EtDM3-9EqJEn-9EqJtB-2i2h8EN-6P9Mee-f11f3F-f11dBx-eZZTPF-551BoQ
|
13 |
+
ATTRIBUTION: Chase McAlpine
|
14 |
+
|
15 |
+
[D] FILE: 5850091922_73ba296093_c.jpg [LICENSE: CC BY-NC-ND 2.0 AVAILABLE HERE: https://creativecommons.org/licenses/by-nc/2.0/]
|
16 |
+
LINK TO ORIGINAL: https://www.flickr.com/photos/rwoan/5850091922/in/photolist-9UXfvQ-9UXbHo-9UXumJ-9UXrqf-9UXzou-9UUGiV-9UXn7s-9UX1Hf-9UXpk3-9UXnM1-9UXaTd-9UX6Am-9UUFne-9UXsVj-9UUfJT-9UXkkj-9UXogs-9UUvV8-9UUdBX-9UU9wB-9UUD1p-9UUkA2-9UXjQy-9UX8ym-9UX19A-9UWZqh-9UUtJP-9UXxxh-9UXavG-9UX6V3-9UUEb8-9UXf3m-9UUDQe-9UXy9C-9UXtMh-8VRM4n-8VUQPQ-9UUBsc-9UXoMj-9UXhtU-9UUf2Z-9UUy7T-9UXe31-9UXhcy-9UXuR5-9UXs8G-9UX2Xh-9UXgWj-9UUzxn-9UUdSM
|
17 |
+
ATTRIBUTION: Ronald Woan
|
18 |
+
|
19 |
+
[E] FILE: 51960182045_d5d6407a3c_c.jpg [LICENSE: CC BY 2.0 AVAILABLE HERE: https://creativecommons.org/licenses/by/2.0/]
|
20 |
+
LINK TO ORIGINAL: https://www.flickr.com/photos/edrost88/51960182045/in/photolist-2dhKmwF-v4wCzs-vHVFcm-vHVFrj-2naxvRF-s5y56-2ozwe9b-8r14iU-8r14Bj-8qWWxB-s5y1M-2jA1RTm-NxtH1E-9QNiVy-oDvaVX-oDLQgD-oDMaZg-oDLUqv-onherq-NxtwnL-onhp6z-NPLSwA-PSsqmC-oDLVXD-onhPX7-oni2jp-oDM9Nt-onhw5q-NxtA2j-oDyjkQ-oDyj6b-oDyDSq-7RcmkH-oDvzHr-onho2F-onhqoz-onha73-onhmaz-oDvaPp-oBK9Us-oDLShc-oDM4nt-273QQFs-onhivh-onhJp4-oBKrWy-oBKxdf-w1WKvv-vHVE89-v4F4NX
|
21 |
+
ATTRIBUTION: Erik Drost
|
22 |
+
|
23 |
+
[F] FILE: 3969570423_58eb848b75_c.jpg [LICENSE: CC BY 2.0 AVAILABLE HERE: https://creativecommons.org/licenses/by-sa/2.0/]
|
24 |
+
LINK TO ORIGINAL: https://www.flickr.com/photos/jansolo09/3969570423/in/photolist-73M5Zn-8dQJqW-8aVGvd-7ZK2Ch-8vmBrz-N3U17e-9zkuGB-yLQNS-9zoufL-6TTR1w-8q6oSW-5MzhfE-9zotMY-dJvQKH-9gaGZg-9wfaFK-7p8Eqh-62BtZL-7UzsUq-rovX6Z-7UwaMB-7UzsCE-7Uw9Fr-7Uzp1J-62xdfP-2m14TNX-4o1kr6-4oXNj2-nsi9QD-6TWGFX-7R6qCC-VbRRUk-2ooTq9C-8bCzjp-9zotxb-7p4NDM-q3gbSg-71fw1w-2oLpmTq-dhK9bi-7iwmvQ-6TPV9x-6TTSNC-7UzoTQ-bwnd4G-8jdTfA-PnFy7Y-nSL5Fv-dCHAKX-6TPPBx
|
25 |
+
ATTRIBUTION: Jan S0L0
|
26 |
+
|
27 |
+
[G] 170149601_13aa4e4483_c.jpg [LICENSE: CC BY-NC-ND 2.0 AVAILABLE HERE: https://creativecommons.org/licenses/by-nc-nd/2.0/]
|
28 |
+
LINK TO ORIGINAL: https://www.flickr.com/photos/wallyg/170149601/in/photolist-g34xP-9yRoKk-haeGC-i4u8dc-dbzfGd-6XJgsx-37hzc-2m792Tp-2n6BnqE-5zH4Yv-8KGj4D-8XMdHG-2n6D2Yc-yyCTU-8qevrJ-9CaiFs-4Be6kH-57oaAr-2m794vN-uindz-bTvAZD-47p3Fj-bEATx9-2m3xhFm-8KGiRx-6VQxan-6bUYCv-bvdoT6-9y8bbY-6jovR1-2n6BnUq-zrvUX-2ncrjMP-2ncxVyn-2ncrjkb-aVsBcB-ShZV6w-2m79N6e-2m3tqy5-2m3usgd-9CahZ9-57oaVH-8AoThL-57ob86-57oaKi-2m4hjp2-kZQsEY-57snZL-9xVH4B-7GQdij
|
29 |
+
ATTRIBUTION: Wally Gobetz
|
Multi-HMR_License.txt
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Multi-HMR, Copyright (c) 2024 Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
2 |
+
|
3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
5 |
+
|
6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
8 |
+
|
9 |
+
|
10 |
+
**************************************************************************
|
11 |
+
SEE NOTICES BELOW CONCERNING SOFTWARE AND DATA:
|
12 |
+
**************************************************************************
|
13 |
+
|
14 |
+
----------------------------------------------------------------
|
15 |
+
PART 1: NOTICES CONCERNING SOFTWARE FILES:
|
16 |
+
----------------------------------------------------------------
|
17 |
+
|
18 |
+
(A) NOTICE WITH RESPECT TO THE SOFTWARE: blocks/cross_attn_transformer.py
|
19 |
+
|
20 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
21 |
+
|
22 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/t_cond_mlp.py
|
23 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py
|
24 |
+
|
25 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW in Part 3 at [A]:
|
26 |
+
|
27 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/LICENSE.md
|
28 |
+
|
29 |
+
|
30 |
+
(B) NOTICE WITH RESPECT TO THE SOFTWARE: model.py
|
31 |
+
|
32 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
33 |
+
|
34 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py
|
35 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/heads/smpl_head.py
|
36 |
+
|
37 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW in Part 3 at [B]:
|
38 |
+
|
39 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/LICENSE.md
|
40 |
+
|
41 |
+
|
42 |
+
(C) NOTICE WITH RESPECT TO THE SOFTWARE: blocks/cross_attn_transformer.py
|
43 |
+
|
44 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
45 |
+
|
46 |
+
https://github.com/lucidrains/vit-pytorch
|
47 |
+
|
48 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW in Part 3 at [C]:
|
49 |
+
|
50 |
+
https://github.com/lucidrains/vit-pytorch/blob/main/LICENSE
|
51 |
+
|
52 |
+
|
53 |
+
----------------------------------------------------------------
|
54 |
+
PART 2: ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE:
|
55 |
+
----------------------------------------------------------------
|
56 |
+
|
57 |
+
NOTICE WITH RESPECT TO DATA IN THIS DIRECTORY: example_data
|
58 |
+
jpg files available in the directory are made available subject to the license set forth therein.
|
59 |
+
|
60 |
+
----------------------------------------------------------------
|
61 |
+
PART 3: ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE:
|
62 |
+
----------------------------------------------------------------
|
63 |
+
|
64 |
+
[A] / [B] https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/LICENSE.md
|
65 |
+
|
66 |
+
MIT License
|
67 |
+
|
68 |
+
Copyright (c) 2023 UC Regents, Shubham Goel
|
69 |
+
|
70 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
71 |
+
|
72 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
73 |
+
|
74 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
75 |
+
|
76 |
+
[C] https://github.com/lucidrains/vit-pytorch
|
77 |
+
|
78 |
+
MIT License
|
79 |
+
|
80 |
+
Copyright (c) 2020 Phil Wang
|
81 |
+
|
82 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
83 |
+
of this software and associated documentation files (the "Software"), to deal
|
84 |
+
in the Software without restriction, including without limitation the rights
|
85 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
86 |
+
copies of the Software, and to permit persons to whom the Software is
|
87 |
+
furnished to do so, subject to the following conditions:
|
88 |
+
|
89 |
+
The above copyright notice and this permission notice shall be included in all
|
90 |
+
copies or substantial portions of the Software.
|
91 |
+
|
92 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
93 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
94 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
95 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
96 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
97 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
98 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Multi HMR
|
3 |
+
emoji: 👬
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.13.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
<p align="center">
|
13 |
+
<h1 align="center">Multi-HMR: Regressing Whole-Body Human Meshes <br> for Multiple Persons in a Single Shot</h1>
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
Fabien Baradel*,
|
17 |
+
Matthieu Armando,
|
18 |
+
Salma Galaaoui,
|
19 |
+
Romain Brégier, <br>
|
20 |
+
Philippe Weinzaepfel,
|
21 |
+
Grégory Rogez,
|
22 |
+
Thomas Lucas*
|
23 |
+
</p>
|
24 |
+
|
25 |
+
<p align="center">
|
26 |
+
<sup>*</sup> equal contribution
|
27 |
+
</p>
|
28 |
+
|
29 |
+
<p align="center">
|
30 |
+
<a href="./"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-xxxx.xxxxx-00ff00.svg"></a>
|
31 |
+
<a href="./"><img alt="Blogpost" src="https://img.shields.io/badge/Blogpost-up-yellow"></a>
|
32 |
+
<a href="./"><img alt="Demo" src="https://img.shields.io/badge/Demo-up-blue"></a>
|
33 |
+
<a href="./"><img alt="Hugging Face Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"></a>
|
34 |
+
</p>
|
35 |
+
|
36 |
+
<div align="center">
|
37 |
+
<img width="49%" alt="Multi-HMR illustration 1" src="assets/visu1.gif">
|
38 |
+
<img width="49%" alt="Multi-HMR illustration 2" src="assets/visu2.gif">
|
39 |
+
|
40 |
+
<br>
|
41 |
+
Multi-HMR is a simple yet effective single-shot model for multi-person and expressive human mesh recovery.
|
42 |
+
It takes as input a single RGB image and efficiently performs 3D reconstruction of multiple humans in camera space.
|
43 |
+
<br>
|
44 |
+
</div>
|
45 |
+
</p>
|
46 |
+
|
47 |
+
## Installation
|
48 |
+
First, you need to clone the repo.
|
49 |
+
|
50 |
+
We recommand to use virtual enviroment for running MultiHMR.
|
51 |
+
Please run the following lines for creating the environment with ```venv```:
|
52 |
+
```bash
|
53 |
+
python3.9 -m venv .multihmr
|
54 |
+
source .multihmr/bin/activate
|
55 |
+
pip install -r requirements.txt
|
56 |
+
```
|
57 |
+
|
58 |
+
Otherwise you can also create a conda environment.
|
59 |
+
```bash
|
60 |
+
conda env create -f conda.yaml
|
61 |
+
conda activate multihmr
|
62 |
+
```
|
63 |
+
|
64 |
+
The installation has been tested with CUDA 11.7.
|
65 |
+
|
66 |
+
Checkpoints will automatically be downloaded to `$HOME/models/multiHMR` the first time you run the demo code.
|
67 |
+
|
68 |
+
Besides these files, you also need to download the *SMPLX* model.
|
69 |
+
You will need the [neutral model](http://smplify.is.tue.mpg.de) for running the demo code.
|
70 |
+
Please go to the corresponding website and register to get access to the downloads section.
|
71 |
+
Download the model and place `SMPLX_NEUTRAL.npz` in `./models/smplx/`.
|
72 |
+
|
73 |
+
## Run Multi-HMR on images
|
74 |
+
The following command will run Multi-HMR on all images in the specified `--img_folder`, and save renderings of the reconstructions in `--out_folder`.
|
75 |
+
The `--model_name` flag specifies the model to use.
|
76 |
+
The `--extra_views` flags additionally renders the side and bev view of the reconstructed scene, `--save_mesh` saves meshes as in a '.npy' file.
|
77 |
+
```bash
|
78 |
+
python3.9 demo.py \
|
79 |
+
--img_folder example_data \
|
80 |
+
--out_folder demo_out \
|
81 |
+
--extra_views 1 \
|
82 |
+
--model_name multiHMR_896_L_synth
|
83 |
+
```
|
84 |
+
|
85 |
+
## Pre-trained models
|
86 |
+
We provide multiple pre-trained checkpoints.
|
87 |
+
Here is a list of their associated features.
|
88 |
+
Once downloaded you need to place them into `$HOME/models/multiHMR`.
|
89 |
+
|
90 |
+
| modelname | training data | backbone | resolution | runtime (ms) |
|
91 |
+
|-------------------------------|-----------------------------------|----------|------------|--------------|
|
92 |
+
| [multiHMR_896_L_synth](./) | BEDLAM+AGORA | ViT-L | 896x896 | 126 |
|
93 |
+
|
94 |
+
We compute the runtime on GPU V100-32GB.
|
95 |
+
|
96 |
+
## License
|
97 |
+
The code is distributed under the CC BY-NC-SA 4.0 License.\
|
98 |
+
See [Multi-HMR LICENSE](Multi-HMR_License.txt), [Checkpoint LICENSE](Checkpoint_License.txt) and [Example Data LICENSE](Example_Data_License.txt) for more information.
|
99 |
+
|
100 |
+
## Citing
|
101 |
+
If you find this code useful for your research, please consider citing the following paper:
|
102 |
+
```bibtex
|
103 |
+
@inproceedings{multi-hmr2024,
|
104 |
+
title={Multi-HMR: Single-Shot Multi-Person Expressive Human Mesh Recovery},
|
105 |
+
author={Baradel*, Fabien and
|
106 |
+
Armando, Matthieu and
|
107 |
+
Galaaoui, Salma and
|
108 |
+
Br{\'e}gier, Romain and
|
109 |
+
Weinzaepfel, Philippe and
|
110 |
+
Rogez, Gr{\'e}gory and
|
111 |
+
Lucas*, Thomas
|
112 |
+
},
|
113 |
+
booktitle={arXiv},
|
114 |
+
year={2024}
|
115 |
+
}
|
116 |
+
```
|
117 |
+
|
app.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
On chaos-01:
|
3 |
+
CUDA_VISIBLE_DEVICES="" XFORMERS_DISABLED=1 python app.py
|
4 |
+
CUDA_VISIBLE_DEVICES="0" XFORMERS_DISABLED=0 python app.py
|
5 |
+
|
6 |
+
On laptop:
|
7 |
+
ssh -N -L 8000:127.0.0.1:7860 chaos-01
|
8 |
+
|
9 |
+
"""
|
10 |
+
import spaces
|
11 |
+
import os
|
12 |
+
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
|
13 |
+
|
14 |
+
from utils.constants import SMPLX_DIR, MEAN_PARAMS
|
15 |
+
from argparse import ArgumentParser
|
16 |
+
import torch
|
17 |
+
import gradio as gr
|
18 |
+
from PIL import Image, ImageOps
|
19 |
+
import numpy as np
|
20 |
+
from pathlib import Path
|
21 |
+
|
22 |
+
if torch.cuda.is_available() and torch.cuda.device_count()>0:
|
23 |
+
device = torch.device('cuda:0')
|
24 |
+
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
25 |
+
device_name = torch.cuda.get_device_name(0)
|
26 |
+
print(f"Device - GPU: {device_name}")
|
27 |
+
else:
|
28 |
+
device = torch.device('cpu')
|
29 |
+
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
30 |
+
device_name = 'CPU'
|
31 |
+
print("Device - CPU")
|
32 |
+
|
33 |
+
from demo import forward_model, get_camera_parameters, overlay_human_meshes, load_model as _load_model
|
34 |
+
from utils import normalize_rgb, demo_color as color, create_scene
|
35 |
+
import time
|
36 |
+
import shutil
|
37 |
+
|
38 |
+
model = None
|
39 |
+
example_data_dir = 'example_data'
|
40 |
+
list_examples = os.listdir(example_data_dir)
|
41 |
+
list_examples_basename = [x for x in list_examples if x.endswith(('.jpg', 'jpeg', 'png')) and not x.startswith('._')]
|
42 |
+
list_examples = [[os.path.join(example_data_dir, x)] for x in list_examples_basename]
|
43 |
+
_list_examples_basename = [Path(x).stem for x in list_examples_basename]
|
44 |
+
tmp_data_dir = 'tmp_data'
|
45 |
+
|
46 |
+
def download_smplx():
|
47 |
+
os.makedirs(os.path.join(SMPLX_DIR, 'smplx'), exist_ok=True)
|
48 |
+
smplx_fname = os.path.join(SMPLX_DIR, 'smplx', 'SMPLX_NEUTRAL.npz')
|
49 |
+
|
50 |
+
if not os.path.isfile(smplx_fname):
|
51 |
+
print('Start to download the SMPL-X model')
|
52 |
+
if not ('SMPLX_LOGIN' in os.environ and 'SMPLX_PWD' in os.environ):
|
53 |
+
raise ValueError('You need to set a secret for SMPLX_LOGIN and for SMPLX_PWD to run this space')
|
54 |
+
fname = "models_smplx_v1_1.zip"
|
55 |
+
username = os.environ['SMPLX_LOGIN'].replace('@','%40')
|
56 |
+
password = os.environ['SMPLX_PWD']
|
57 |
+
cmd = f"wget -O {fname} --save-cookies cookies.txt --keep-session-cookies --post-data 'username={username}&password={password}' \"https://download.is.tue.mpg.de/download.php?domain=smplx&sfile={fname}\""
|
58 |
+
os.system(cmd)
|
59 |
+
assert os.path.isfile(fname), "failed to download"
|
60 |
+
os.system(f'unzip {fname}')
|
61 |
+
os.system(f"cp models/smplx/SMPLX_NEUTRAL.npz {smplx_fname}")
|
62 |
+
assert os.path.isfile(smplx_fname), "failed to find smplx file"
|
63 |
+
print('SMPL-X has been succesfully downloaded')
|
64 |
+
else:
|
65 |
+
print('SMPL-X is already here')
|
66 |
+
|
67 |
+
if not os.path.isfile(MEAN_PARAMS):
|
68 |
+
print('Start to download the SMPL mean params')
|
69 |
+
os.system(f"wget -O {MEAN_PARAMS} https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/smpl_mean_params.npz?versionId=CAEQHhiBgICN6M3V6xciIDU1MzUzNjZjZGNiOTQ3OWJiZTJmNThiZmY4NmMxMTM4")
|
70 |
+
print('SMPL mean params have been succesfully downloaded')
|
71 |
+
else:
|
72 |
+
print('SMPL mean params is already here')
|
73 |
+
|
74 |
+
@spaces.GPU
|
75 |
+
def infer(fn, det_thresh, nms_kernel_size):
|
76 |
+
global device
|
77 |
+
global model
|
78 |
+
|
79 |
+
# Is it an image from example_data_dir ?
|
80 |
+
basename = Path(os.path.basename(fn)).stem
|
81 |
+
_basename = f"{basename}_thresh{int(det_thresh*100)}_nms{int(nms_kernel_size)}"
|
82 |
+
is_known_image = (basename in _list_examples_basename) # only images from example_data
|
83 |
+
|
84 |
+
# Filenames
|
85 |
+
if not is_known_image:
|
86 |
+
_basename = 'output' # such that we do not save all the uploaded results - not sure ?
|
87 |
+
_glb_fn = f"{_basename}.glb"
|
88 |
+
_rend_fn = f"{_basename}.png"
|
89 |
+
glb_fn = os.path.join(tmp_data_dir, _glb_fn)
|
90 |
+
rend_fn = os.path.join(tmp_data_dir, _rend_fn)
|
91 |
+
os.makedirs(tmp_data_dir, exist_ok=True)
|
92 |
+
|
93 |
+
# Already processed
|
94 |
+
is_preprocessed = False
|
95 |
+
if is_known_image:
|
96 |
+
_tmp_data_dir_files = os.listdir(tmp_data_dir)
|
97 |
+
is_preprocessed = (_glb_fn in _tmp_data_dir_files) and (_rend_fn in _tmp_data_dir_files) # already preprocessed
|
98 |
+
|
99 |
+
is_known = is_known_image and is_preprocessed
|
100 |
+
if not is_known:
|
101 |
+
im = Image.open(fn)
|
102 |
+
fov, p_x, p_y = 60, None, None # FOV=60 always here!
|
103 |
+
img_size = model.img_size
|
104 |
+
|
105 |
+
# Get camera information
|
106 |
+
p_x, p_y = None, None
|
107 |
+
K = get_camera_parameters(img_size, fov=fov, p_x=p_x, p_y=p_y, device=device)
|
108 |
+
|
109 |
+
# Resise but keep aspect ratio
|
110 |
+
img_pil = ImageOps.contain(im, (img_size,img_size)) # keep the same aspect ratio
|
111 |
+
|
112 |
+
# Which side is too small/big
|
113 |
+
width, height = img_pil.size
|
114 |
+
pad = abs(width - height) // 2
|
115 |
+
|
116 |
+
# Pad
|
117 |
+
img_pil_bis = ImageOps.pad(img_pil.copy(), size=(img_size, img_size), color=(255, 255, 255))
|
118 |
+
img_pil = ImageOps.pad(img_pil, size=(img_size, img_size)) # pad with zero on the smallest side
|
119 |
+
|
120 |
+
# Numpy - normalize - torch.
|
121 |
+
resize_img = normalize_rgb(np.asarray(img_pil))
|
122 |
+
x = torch.from_numpy(resize_img).unsqueeze(0).to(device)
|
123 |
+
|
124 |
+
img_array = np.asarray(img_pil_bis)
|
125 |
+
img_pil_visu = Image.fromarray(img_array)
|
126 |
+
|
127 |
+
start = time.time()
|
128 |
+
humans = forward_model(model, x, K, det_thresh=det_thresh, nms_kernel_size=nms_kernel_size)
|
129 |
+
print(f"Forward: {time.time() - start:.2f}sec")
|
130 |
+
|
131 |
+
# Overlay
|
132 |
+
start = time.time()
|
133 |
+
pred_rend_array, _ = overlay_human_meshes(humans, K, model, img_pil_visu)
|
134 |
+
rend_pil = Image.fromarray(pred_rend_array.astype(np.uint8))
|
135 |
+
rend_pil.crop()
|
136 |
+
if width > height:
|
137 |
+
rend_pil = rend_pil.crop((0,pad,width,pad+height))
|
138 |
+
else:
|
139 |
+
rend_pil =rend_pil.crop((pad,0,pad+width,height))
|
140 |
+
rend_pil.save(rend_fn)
|
141 |
+
print(f"Rendering with pyrender: {time.time() - start:.2f}sec")
|
142 |
+
|
143 |
+
# Save into glb
|
144 |
+
start = time.time()
|
145 |
+
l_mesh = [humans[j]['verts_smplx'].detach().cpu().numpy() for j in range(len(humans))]
|
146 |
+
l_face = [model.smpl_layer['neutral'].bm_x.faces for j in range(len(humans))]
|
147 |
+
scene = create_scene(img_pil_visu, l_mesh, l_face, color=color, metallicFactor=0., roughnessFactor=0.5)
|
148 |
+
scene.export(glb_fn)
|
149 |
+
print(f"Exporting scene in glb: {time.time() - start:.2f}sec")
|
150 |
+
else:
|
151 |
+
print("We already have the predictions-visus stored somewhere...")
|
152 |
+
|
153 |
+
out = [rend_fn, glb_fn]
|
154 |
+
print(out)
|
155 |
+
return out
|
156 |
+
# return [rend_fn, hidden_glb_fn]
|
157 |
+
# return [rend_fn, my_glb_fn]
|
158 |
+
|
159 |
+
|
160 |
+
if __name__ == "__main__":
|
161 |
+
parser = ArgumentParser()
|
162 |
+
parser.add_argument("--model_name", type=str, default='multiHMR')
|
163 |
+
parser.add_argument("--logs_path", type=str, default='./data')
|
164 |
+
|
165 |
+
args = parser.parse_args()
|
166 |
+
|
167 |
+
# Info
|
168 |
+
### Description and style
|
169 |
+
logo = r"""
|
170 |
+
<center>
|
171 |
+
<img src='https://europe.naverlabs.com/wp-content/uploads/2020/10/NLE_1_WHITE_264x60_opti.png' alt='Multi-HMR logo' style="width:250px; margin-bottom:10px">
|
172 |
+
</center>
|
173 |
+
"""
|
174 |
+
title = r"""
|
175 |
+
<center>
|
176 |
+
<h1 align="center">Multi-HMR: Regressing Whole-Body Human Meshes for Multiple Persons in a Single Shot</h1>
|
177 |
+
</center>
|
178 |
+
"""
|
179 |
+
|
180 |
+
description = f"""
|
181 |
+
The demo is running on a {device_name}.
|
182 |
+
<br>
|
183 |
+
[<b>Demo code</b>] If you want to run Multi-HMR on several images please consider using the demo code available on [our Github repo](https://github.com/naver/multiHMR)
|
184 |
+
"""
|
185 |
+
|
186 |
+
article = r"""
|
187 |
+
---
|
188 |
+
📝 **Citation**
|
189 |
+
<br>
|
190 |
+
If our work is useful for your research, please consider citing:
|
191 |
+
```bibtex
|
192 |
+
@inproceedings{multihmr2024,
|
193 |
+
title={Multi-HMR: Regressing Whole-Body Human Meshes for Multiple Persons in a Single Shot},
|
194 |
+
author={Baradel*, Fabien and
|
195 |
+
Armando, Matthieu and
|
196 |
+
Galaaoui, Salma and
|
197 |
+
Br{\'e}gier, Romain and
|
198 |
+
Weinzaepfel, Philippe and
|
199 |
+
Rogez, Gr{\'e}gory and
|
200 |
+
Lucas*, Thomas},
|
201 |
+
booktitle={arXiv},
|
202 |
+
year={2024}
|
203 |
+
}
|
204 |
+
```
|
205 |
+
📋 **License**
|
206 |
+
<br>
|
207 |
+
CC BY-NC-SA 4.0 License. Please refer to the [LICENSE file](./Multi-HMR_License.txt) for details.
|
208 |
+
<br>
|
209 |
+
📧 **Contact**
|
210 |
+
<br>
|
211 |
+
If you have any questions, please feel free to send a message to <b>[email protected]</b> or open an issue on the [Github repo](https://github.com/naver/multi-hmr).
|
212 |
+
"""
|
213 |
+
|
214 |
+
# Download SMPLX model and mean params
|
215 |
+
download_smplx()
|
216 |
+
|
217 |
+
# Loading the model
|
218 |
+
model = _load_model(args.model_name, device=device)
|
219 |
+
|
220 |
+
# Gradio demo
|
221 |
+
with gr.Blocks(title="Multi-HMR", css=".gradio-container") as demo:
|
222 |
+
# gr.HTML("""
|
223 |
+
# <div style="font-weight:bold; text-align:center; color:royalblue;">Multi-HMR: <br> Multi-Person Whole-Body Human Mesh Recovery in a Single Shot </div>
|
224 |
+
# """)
|
225 |
+
gr.Markdown(logo)
|
226 |
+
gr.Markdown(title)
|
227 |
+
gr.Markdown(description)
|
228 |
+
|
229 |
+
with gr.Row():
|
230 |
+
with gr.Column():
|
231 |
+
input_image = gr.Image(label="Input image",
|
232 |
+
# type="pil",
|
233 |
+
type="filepath",
|
234 |
+
sources=['upload', 'clipboard'])
|
235 |
+
with gr.Column():
|
236 |
+
output_image = gr.Image(label="Reconstructions - Overlay",
|
237 |
+
# type="pil",
|
238 |
+
type="filepath",
|
239 |
+
)
|
240 |
+
|
241 |
+
gr.HTML("""<br/>""")
|
242 |
+
|
243 |
+
with gr.Row():
|
244 |
+
with gr.Column():
|
245 |
+
alpha = -70 # longitudinal rotation in degree
|
246 |
+
beta = 70 # latitudinal rotation in degree
|
247 |
+
radius = 3. # distance to the 3D model
|
248 |
+
radius = None # distance to the 3D model
|
249 |
+
output_model3d = gr.Model3D(label="Reconstructions - 3D scene",
|
250 |
+
camera_position=(alpha, beta, radius),
|
251 |
+
clear_color=[1.0, 1.0, 1.0, 0.0])
|
252 |
+
|
253 |
+
gr.HTML("""<br/>""")
|
254 |
+
|
255 |
+
with gr.Row():
|
256 |
+
threshold = gr.Slider(0.1, 0.7, step=0.1, value=0.3, label='Detection Threshold')
|
257 |
+
nms = gr.Radio(label="NMS kernel size", choices=[1, 3, 5], value=3)
|
258 |
+
send_btn = gr.Button("Infer")
|
259 |
+
send_btn.click(fn=infer, inputs=[input_image, threshold, nms], outputs=[output_image, output_model3d])
|
260 |
+
|
261 |
+
gr.Examples(list_examples,
|
262 |
+
inputs=[input_image, 0.3, 3])
|
263 |
+
|
264 |
+
gr.Markdown(article)
|
265 |
+
|
266 |
+
demo.queue() # <-- Sets up a queue with default parameters
|
267 |
+
demo.launch(debug=True, share=False)
|
268 |
+
|
269 |
+
|
assets/visu1.gif
ADDED
![]() |
Git LFS Details
|
assets/visu2.gif
ADDED
![]() |
Git LFS Details
|
blocks/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .camera_embed import FourierPositionEncoding
|
3 |
+
|
4 |
+
from .dinov2 import Dinov2Backbone
|
5 |
+
|
6 |
+
from .cross_attn_transformer import TransformerDecoder
|
7 |
+
|
8 |
+
from .smpl_layer import SMPL_Layer
|
blocks/camera_embed.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class FourierPositionEncoding(nn.Module):
|
10 |
+
def __init__(self, n, num_bands, max_resolution):
|
11 |
+
"""
|
12 |
+
Module that generate Fourier encoding - no learning involved
|
13 |
+
"""
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
self.num_bands = num_bands
|
17 |
+
self.max_resolution = [max_resolution] * n
|
18 |
+
|
19 |
+
@property
|
20 |
+
def channels(self):
|
21 |
+
"""
|
22 |
+
Return the output dimension
|
23 |
+
"""
|
24 |
+
num_dims = len(self.max_resolution)
|
25 |
+
encoding_size = self.num_bands * num_dims
|
26 |
+
encoding_size *= 2 # sin-cos
|
27 |
+
encoding_size += num_dims # concat
|
28 |
+
|
29 |
+
return encoding_size
|
30 |
+
|
31 |
+
def forward(self, pos):
|
32 |
+
"""
|
33 |
+
Forward pass that take rays as input and generate Fourier positional encodings
|
34 |
+
"""
|
35 |
+
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
|
36 |
+
return fourier_pos_enc
|
37 |
+
|
38 |
+
|
39 |
+
def _generate_fourier_features(pos, num_bands, max_resolution):
|
40 |
+
"""Generate fourier features from a given set of positions and frequencies"""
|
41 |
+
b, n = pos.shape[:2]
|
42 |
+
device = pos.device
|
43 |
+
|
44 |
+
# Linear frequency sampling
|
45 |
+
min_freq = 1.0
|
46 |
+
freq_bands = torch.stack([torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=device) for res in max_resolution], dim=0)
|
47 |
+
|
48 |
+
# Stacking
|
49 |
+
per_pos_features = torch.stack([pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0)
|
50 |
+
per_pos_features = per_pos_features.reshape(b, n, -1)
|
51 |
+
|
52 |
+
# Sin-Cos
|
53 |
+
per_pos_features = torch.cat([torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1)
|
54 |
+
|
55 |
+
# Concat with initial pos
|
56 |
+
per_pos_features = torch.cat([pos, per_pos_features], dim=-1)
|
57 |
+
|
58 |
+
return per_pos_features
|
blocks/cross_attn_transformer.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
from typing import Callable, Optional
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from inspect import isfunction
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
class AdaptiveLayerNorm1D(torch.nn.Module):
|
12 |
+
"""
|
13 |
+
Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/t_cond_mlp.py#L7
|
14 |
+
"""
|
15 |
+
def __init__(self, data_dim: int, norm_cond_dim: int):
|
16 |
+
super().__init__()
|
17 |
+
if data_dim <= 0:
|
18 |
+
raise ValueError(f"data_dim must be positive, but got {data_dim}")
|
19 |
+
if norm_cond_dim <= 0:
|
20 |
+
raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
|
21 |
+
self.norm = torch.nn.LayerNorm(
|
22 |
+
data_dim
|
23 |
+
) # TODO: Check if elementwise_affine=True is correct
|
24 |
+
self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
|
25 |
+
torch.nn.init.zeros_(self.linear.weight)
|
26 |
+
torch.nn.init.zeros_(self.linear.bias)
|
27 |
+
|
28 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
29 |
+
# x: (batch, ..., data_dim)
|
30 |
+
# t: (batch, norm_cond_dim)
|
31 |
+
# return: (batch, data_dim)
|
32 |
+
x = self.norm(x)
|
33 |
+
alpha, beta = self.linear(t).chunk(2, dim=-1)
|
34 |
+
|
35 |
+
# Add singleton dimensions to alpha and beta
|
36 |
+
if x.dim() > 2:
|
37 |
+
alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
|
38 |
+
beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
|
39 |
+
|
40 |
+
return x * (1 + alpha) + beta
|
41 |
+
|
42 |
+
|
43 |
+
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
|
44 |
+
"""
|
45 |
+
Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/t_cond_mlp.py#L48
|
46 |
+
"""
|
47 |
+
if norm == "batch":
|
48 |
+
return torch.nn.BatchNorm1d(dim)
|
49 |
+
elif norm == "layer":
|
50 |
+
return torch.nn.LayerNorm(dim)
|
51 |
+
elif norm == "ada":
|
52 |
+
assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
|
53 |
+
return AdaptiveLayerNorm1D(dim, norm_cond_dim)
|
54 |
+
elif norm is None:
|
55 |
+
return torch.nn.Identity()
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Unknown norm: {norm}")
|
58 |
+
|
59 |
+
|
60 |
+
def exists(val):
|
61 |
+
"Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L17"
|
62 |
+
return val is not None
|
63 |
+
|
64 |
+
|
65 |
+
def default(val, d):
|
66 |
+
"Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L21"
|
67 |
+
if exists(val):
|
68 |
+
return val
|
69 |
+
return d() if isfunction(d) else d
|
70 |
+
|
71 |
+
|
72 |
+
class PreNorm(nn.Module):
|
73 |
+
"""
|
74 |
+
Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L27
|
75 |
+
"""
|
76 |
+
def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
|
77 |
+
super().__init__()
|
78 |
+
self.norm = normalization_layer(norm, dim, norm_cond_dim)
|
79 |
+
self.fn = fn
|
80 |
+
|
81 |
+
def forward(self, x: torch.Tensor, *args, **kwargs):
|
82 |
+
if isinstance(self.norm, AdaptiveLayerNorm1D):
|
83 |
+
return self.fn(self.norm(x, *args), **kwargs)
|
84 |
+
else:
|
85 |
+
return self.fn(self.norm(x), **kwargs)
|
86 |
+
|
87 |
+
|
88 |
+
class FeedForward(nn.Module):
|
89 |
+
"""
|
90 |
+
Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L40
|
91 |
+
"""
|
92 |
+
def __init__(self, dim, hidden_dim, dropout=0.0):
|
93 |
+
super().__init__()
|
94 |
+
self.net = nn.Sequential(
|
95 |
+
nn.Linear(dim, hidden_dim),
|
96 |
+
nn.GELU(),
|
97 |
+
nn.Dropout(dropout),
|
98 |
+
nn.Linear(hidden_dim, dim),
|
99 |
+
nn.Dropout(dropout),
|
100 |
+
)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
return self.net(x)
|
104 |
+
|
105 |
+
|
106 |
+
class Attention(nn.Module):
|
107 |
+
"""
|
108 |
+
Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L55
|
109 |
+
"""
|
110 |
+
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
|
111 |
+
super().__init__()
|
112 |
+
inner_dim = dim_head * heads
|
113 |
+
project_out = not (heads == 1 and dim_head == dim)
|
114 |
+
|
115 |
+
self.heads = heads
|
116 |
+
self.scale = dim_head**-0.5
|
117 |
+
|
118 |
+
self.attend = nn.Softmax(dim=-1)
|
119 |
+
self.dropout = nn.Dropout(dropout)
|
120 |
+
|
121 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
122 |
+
|
123 |
+
self.to_out = (
|
124 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
125 |
+
if project_out
|
126 |
+
else nn.Identity()
|
127 |
+
)
|
128 |
+
|
129 |
+
def forward(self, x, mask=None):
|
130 |
+
|
131 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
132 |
+
# n --> the num query dimension
|
133 |
+
|
134 |
+
# TODO reshape b into b2 n and mask.
|
135 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
|
136 |
+
|
137 |
+
if mask is not None:
|
138 |
+
q, k, v = [x * mask[:, None, :, None] for x in [q, k, v]]
|
139 |
+
|
140 |
+
# q, k, v: [13:51:03.400365] torch.Size([22, 1, 256])
|
141 |
+
#q, k ,vk after reshape: torch.Size([16, 8, 1, 32])
|
142 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
143 |
+
|
144 |
+
if mask is not None:
|
145 |
+
dots = dots - (1 - mask)[:, None, None, :] * 10e10
|
146 |
+
|
147 |
+
attn = self.attend(dots)
|
148 |
+
|
149 |
+
if mask is not None: # Just for good measure; this is probably overkill
|
150 |
+
attn = attn * mask[:, None, None, :]
|
151 |
+
|
152 |
+
attn = self.dropout(attn)
|
153 |
+
|
154 |
+
out = torch.matmul(attn, v)
|
155 |
+
|
156 |
+
# out shape :torch.Size([16, 8, 1, 32])
|
157 |
+
|
158 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
159 |
+
return self.to_out(out)
|
160 |
+
|
161 |
+
|
162 |
+
class CrossAttention(nn.Module):
|
163 |
+
"Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L89"
|
164 |
+
def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
165 |
+
super().__init__()
|
166 |
+
inner_dim = dim_head * heads
|
167 |
+
project_out = not (heads == 1 and dim_head == dim)
|
168 |
+
|
169 |
+
self.heads = heads
|
170 |
+
self.scale = dim_head**-0.5
|
171 |
+
|
172 |
+
self.attend = nn.Softmax(dim=-1)
|
173 |
+
self.dropout = nn.Dropout(dropout)
|
174 |
+
|
175 |
+
context_dim = default(context_dim, dim)
|
176 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
|
177 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
178 |
+
|
179 |
+
self.to_out = (
|
180 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
181 |
+
if project_out
|
182 |
+
else nn.Identity()
|
183 |
+
)
|
184 |
+
|
185 |
+
def forward(self, x, context=None, mask=None):
|
186 |
+
|
187 |
+
context = default(context, x)
|
188 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
189 |
+
q = self.to_q(x)
|
190 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
|
191 |
+
|
192 |
+
if mask is not None:
|
193 |
+
q = q * mask[:, None, :, None]
|
194 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
195 |
+
if mask is not None:
|
196 |
+
dots = dots - (1 - mask).float()[:, None, :, None] * 1e6
|
197 |
+
attn = self.attend(dots)
|
198 |
+
attn = self.dropout(attn)
|
199 |
+
|
200 |
+
out = torch.matmul(attn, v)
|
201 |
+
|
202 |
+
if mask is not None: # Just for good measure; this is probably overkill
|
203 |
+
out = out * mask[:, None, :, None]
|
204 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
205 |
+
return self.to_out(out)
|
206 |
+
|
207 |
+
class TransformerCrossAttn(nn.Module):
|
208 |
+
"Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L160"
|
209 |
+
def __init__(
|
210 |
+
self,
|
211 |
+
dim: int,
|
212 |
+
depth: int,
|
213 |
+
heads: int,
|
214 |
+
dim_head: int,
|
215 |
+
mlp_dim: int,
|
216 |
+
dropout: float = 0.0,
|
217 |
+
norm: str = "layer",
|
218 |
+
norm_cond_dim: int = -1,
|
219 |
+
context_dim: Optional[int] = None,
|
220 |
+
):
|
221 |
+
super().__init__()
|
222 |
+
self.layers = nn.ModuleList([])
|
223 |
+
for _ in range(depth):
|
224 |
+
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
225 |
+
ca = CrossAttention(
|
226 |
+
dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
|
227 |
+
)
|
228 |
+
ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
229 |
+
self.layers.append(
|
230 |
+
nn.ModuleList(
|
231 |
+
[
|
232 |
+
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
|
233 |
+
PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
|
234 |
+
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
|
235 |
+
]
|
236 |
+
)
|
237 |
+
)
|
238 |
+
|
239 |
+
def forward(self, x: torch.Tensor, *args, context=None, context_list=None, mask=None):
|
240 |
+
|
241 |
+
if context_list is None:
|
242 |
+
context_list = [context] * len(self.layers)
|
243 |
+
|
244 |
+
if len(context_list) != len(self.layers):
|
245 |
+
raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
|
246 |
+
|
247 |
+
for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
|
248 |
+
if mask is not None:
|
249 |
+
try:
|
250 |
+
x = x * mask[:, :, None]
|
251 |
+
except:
|
252 |
+
print("see ")
|
253 |
+
import pdb; pdb.set_trace()
|
254 |
+
x = self_attn(x, mask=mask, *args) + x
|
255 |
+
x = cross_attn(x, mask=mask, *args, context=context_list[i]) + x
|
256 |
+
x = ff(x, *args) + x
|
257 |
+
|
258 |
+
if mask is not None:
|
259 |
+
x = x * mask[:, :, None]
|
260 |
+
|
261 |
+
return x
|
262 |
+
|
263 |
+
class DropTokenDropout(nn.Module):
|
264 |
+
"Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L204"
|
265 |
+
def __init__(self, p: float = 0.1):
|
266 |
+
super().__init__()
|
267 |
+
if p < 0 or p > 1:
|
268 |
+
raise ValueError(
|
269 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
270 |
+
)
|
271 |
+
self.p = p
|
272 |
+
|
273 |
+
def forward(self, x: torch.Tensor):
|
274 |
+
# x: (batch_size, seq_len, dim)
|
275 |
+
if self.training and self.p > 0:
|
276 |
+
zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
|
277 |
+
# TODO: permutation idx for each batch using torch.argsort
|
278 |
+
if zero_mask.any():
|
279 |
+
x = x[:, ~zero_mask, :]
|
280 |
+
return x
|
281 |
+
|
282 |
+
|
283 |
+
class ZeroTokenDropout(nn.Module):
|
284 |
+
"Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L223"
|
285 |
+
def __init__(self, p: float = 0.1):
|
286 |
+
super().__init__()
|
287 |
+
if p < 0 or p > 1:
|
288 |
+
raise ValueError(
|
289 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
290 |
+
)
|
291 |
+
self.p = p
|
292 |
+
|
293 |
+
def forward(self, x: torch.Tensor):
|
294 |
+
# x: (batch_size, seq_len, dim)
|
295 |
+
if self.training and self.p > 0:
|
296 |
+
zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
|
297 |
+
# Zero-out the masked tokens
|
298 |
+
x[zero_mask, :] = 0
|
299 |
+
return x
|
300 |
+
|
301 |
+
|
302 |
+
class TransformerDecoder(nn.Module):
|
303 |
+
"Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L301"
|
304 |
+
def __init__(
|
305 |
+
self,
|
306 |
+
num_tokens: int,
|
307 |
+
token_dim: int,
|
308 |
+
dim: int,
|
309 |
+
depth: int,
|
310 |
+
heads: int,
|
311 |
+
mlp_dim: int,
|
312 |
+
dim_head: int = 64,
|
313 |
+
dropout: float = 0.0,
|
314 |
+
emb_dropout: float = 0.0,
|
315 |
+
emb_dropout_type: str = 'drop',
|
316 |
+
norm: str = "layer",
|
317 |
+
norm_cond_dim: int = -1,
|
318 |
+
context_dim: Optional[int] = None,
|
319 |
+
skip_token_embedding: bool = False,
|
320 |
+
):
|
321 |
+
super().__init__()
|
322 |
+
if not skip_token_embedding:
|
323 |
+
self.to_token_embedding = nn.Linear(token_dim, dim)
|
324 |
+
else:
|
325 |
+
self.to_token_embedding = nn.Identity()
|
326 |
+
if token_dim != dim:
|
327 |
+
raise ValueError(
|
328 |
+
f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
|
329 |
+
)
|
330 |
+
|
331 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
|
332 |
+
if emb_dropout_type == "drop":
|
333 |
+
self.dropout = DropTokenDropout(emb_dropout)
|
334 |
+
elif emb_dropout_type == "zero":
|
335 |
+
self.dropout = ZeroTokenDropout(emb_dropout)
|
336 |
+
elif emb_dropout_type == "normal":
|
337 |
+
self.dropout = nn.Dropout(emb_dropout)
|
338 |
+
|
339 |
+
self.transformer = TransformerCrossAttn(
|
340 |
+
dim,
|
341 |
+
depth,
|
342 |
+
heads,
|
343 |
+
dim_head,
|
344 |
+
mlp_dim,
|
345 |
+
dropout,
|
346 |
+
norm=norm,
|
347 |
+
norm_cond_dim=norm_cond_dim,
|
348 |
+
context_dim=context_dim,
|
349 |
+
)
|
350 |
+
|
351 |
+
def forward(self, inp: torch.Tensor, *args, context=None, context_list=None, mask=None):
|
352 |
+
x = self.to_token_embedding(inp)
|
353 |
+
b, n, _ = x.shape
|
354 |
+
|
355 |
+
x = self.dropout(x)
|
356 |
+
#x += self.pos_embedding[:, :n]
|
357 |
+
x += self.pos_embedding[:, 0][:, None, :] # For now, we don't wish to embed a position. We might in future versions though.
|
358 |
+
x = self.transformer(x, *args, context=context, context_list=context_list, mask=mask)
|
359 |
+
return x
|
blocks/dinov2.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
class Dinov2Backbone(nn.Module):
|
9 |
+
def __init__(self, name='dinov2_vitb14', *args, **kwargs):
|
10 |
+
super().__init__()
|
11 |
+
self.name = name
|
12 |
+
self.encoder = torch.hub.load('facebookresearch/dinov2', self.name, pretrained=False)
|
13 |
+
self.patch_size = self.encoder.patch_size
|
14 |
+
self.embed_dim = self.encoder.embed_dim
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
"""
|
18 |
+
Encode a RGB image using a ViT-backbone
|
19 |
+
Args:
|
20 |
+
- x: torch.Tensor of shape [bs,3,w,h]
|
21 |
+
Return:
|
22 |
+
- y: torch.Tensor of shape [bs,k,d] - image in patchified mode
|
23 |
+
"""
|
24 |
+
assert len(x.shape) == 4
|
25 |
+
y = self.encoder.get_intermediate_layers(x)[0] # ViT-L+896x896: [bs,4096,1024] - [bs,nb_patches,emb]
|
26 |
+
return y
|
27 |
+
|
blocks/smpl_layer.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch import nn
|
8 |
+
import smplx
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
import utils
|
12 |
+
from utils import inverse_perspective_projection, perspective_projection
|
13 |
+
import roma
|
14 |
+
import pickle
|
15 |
+
import os
|
16 |
+
from utils.constants import SMPLX_DIR
|
17 |
+
|
18 |
+
class SMPL_Layer(nn.Module):
|
19 |
+
"""
|
20 |
+
Extension of the SMPL Layer with information about the camera for (inverse) projection the camera plane.
|
21 |
+
"""
|
22 |
+
def __init__(self,
|
23 |
+
type='smplx',
|
24 |
+
gender='neutral',
|
25 |
+
num_betas=10,
|
26 |
+
kid=False,
|
27 |
+
person_center=None,
|
28 |
+
*args,
|
29 |
+
**kwargs,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
# Args
|
34 |
+
assert type == 'smplx'
|
35 |
+
self.type = type
|
36 |
+
self.kid = kid
|
37 |
+
self.num_betas = num_betas
|
38 |
+
self.bm_x = smplx.create(SMPLX_DIR, 'smplx', gender=gender, use_pca=False, flat_hand_mean=True, num_betas=num_betas)
|
39 |
+
|
40 |
+
# Primary keypoint - root
|
41 |
+
self.joint_names = eval(f"utils.get_{self.type}_joint_names")()
|
42 |
+
self.person_center = person_center
|
43 |
+
self.person_center_idx = None
|
44 |
+
if self.person_center is not None:
|
45 |
+
self.person_center_idx = self.joint_names.index(self.person_center)
|
46 |
+
|
47 |
+
def forward(self,
|
48 |
+
pose, shape,
|
49 |
+
loc, dist, transl,
|
50 |
+
K,
|
51 |
+
expression=None, # facial expression
|
52 |
+
):
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
- pose: pose of the person in axis-angle - torch.Tensor [bs,24,3]
|
56 |
+
- shape: torch.Tensor [bs,10]
|
57 |
+
- loc: 2D location of the pelvis in pixel space - torch.Tensor [bs,2]
|
58 |
+
- dist: distance of the pelvis from the camera in m - torch.Tensor [bs,1]
|
59 |
+
Return:
|
60 |
+
- dict containing a bunch of useful information about each person
|
61 |
+
"""
|
62 |
+
|
63 |
+
if loc is not None and dist is not None:
|
64 |
+
assert pose.shape[0] == shape.shape[0] == loc.shape[0] == dist.shape[0]
|
65 |
+
if self.type == 'smpl':
|
66 |
+
assert len(pose.shape) == 3 and list(pose.shape[1:]) == [24,3]
|
67 |
+
elif self.type == 'smplx':
|
68 |
+
assert len(pose.shape) == 3 and list(pose.shape[1:]) == [53,3] # taking root_orient, body_pose, lhand, rhan and jaw for the moment
|
69 |
+
else:
|
70 |
+
raise NameError
|
71 |
+
assert len(shape.shape) == 2 and (list(shape.shape[1:]) == [self.num_betas] or list(shape.shape[1:]) == [self.num_betas+1])
|
72 |
+
if loc is not None and dist is not None:
|
73 |
+
assert len(loc.shape) == 2 and list(loc.shape[1:]) == [2]
|
74 |
+
assert len(dist.shape) == 2 and list(dist.shape[1:]) == [1]
|
75 |
+
|
76 |
+
bs = pose.shape[0]
|
77 |
+
|
78 |
+
out = {}
|
79 |
+
|
80 |
+
# No humans
|
81 |
+
if bs == 0:
|
82 |
+
return {}
|
83 |
+
|
84 |
+
# Low dimensional parameters
|
85 |
+
kwargs_pose = {
|
86 |
+
'betas': shape,
|
87 |
+
}
|
88 |
+
kwargs_pose['global_orient'] = self.bm_x.global_orient.repeat(bs,1)
|
89 |
+
kwargs_pose['body_pose'] = pose[:,1:22].flatten(1)
|
90 |
+
kwargs_pose['left_hand_pose'] = pose[:,22:37].flatten(1)
|
91 |
+
kwargs_pose['right_hand_pose'] = pose[:,37:52].flatten(1)
|
92 |
+
kwargs_pose['jaw_pose'] = pose[:,52:53].flatten(1)
|
93 |
+
|
94 |
+
if expression is not None:
|
95 |
+
kwargs_pose['expression'] = expression.flatten(1) # [bs,10]
|
96 |
+
else:
|
97 |
+
kwargs_pose['expression'] = self.bm_x.expression.repeat(bs,1)
|
98 |
+
|
99 |
+
# default - to be generalized
|
100 |
+
kwargs_pose['leye_pose'] = self.bm_x.leye_pose.repeat(bs,1)
|
101 |
+
kwargs_pose['reye_pose'] = self.bm_x.reye_pose.repeat(bs,1)
|
102 |
+
|
103 |
+
# Forward using the parametric 3d model SMPL-X layer
|
104 |
+
output = self.bm_x(**kwargs_pose)
|
105 |
+
verts = output.vertices
|
106 |
+
j3d = output.joints # 45 joints
|
107 |
+
R = roma.rotvec_to_rotmat(pose[:,0])
|
108 |
+
|
109 |
+
# Apply global orientation on 3D points
|
110 |
+
pelvis = j3d[:,[0]]
|
111 |
+
j3d = (R.unsqueeze(1) @ (j3d - pelvis).unsqueeze(-1)).squeeze(-1)
|
112 |
+
|
113 |
+
# Apply global orientation on 3D points - bis
|
114 |
+
verts = (R.unsqueeze(1) @ (verts - pelvis).unsqueeze(-1)).squeeze(-1)
|
115 |
+
|
116 |
+
# Location of the person in 3D
|
117 |
+
if transl is None:
|
118 |
+
if K.dtype == torch.float16:
|
119 |
+
# because of torch.inverse - not working with float16 at the moment
|
120 |
+
transl = inverse_perspective_projection(loc.unsqueeze(1).float(), K.float(), dist.unsqueeze(1).float())[:,0]
|
121 |
+
transl = transl.half()
|
122 |
+
else:
|
123 |
+
transl = inverse_perspective_projection(loc.unsqueeze(1), K, dist.unsqueeze(1))[:,0]
|
124 |
+
|
125 |
+
# Updating transl if we choose a certain person center
|
126 |
+
transl_up = transl.clone()
|
127 |
+
|
128 |
+
# Definition of the translation depend on the args: 1) vanilla SMPL - 2) computed from a given joint
|
129 |
+
if self.person_center_idx is None:
|
130 |
+
# Add pelvis to transl - standard way for SMPLX layer
|
131 |
+
transl_up = transl_up + pelvis[:,0]
|
132 |
+
else:
|
133 |
+
# Center around the joint because teh translation is computed from this joint
|
134 |
+
person_center = j3d[:, [self.person_center_idx]]
|
135 |
+
verts = verts - person_center
|
136 |
+
j3d = j3d - person_center
|
137 |
+
|
138 |
+
# Moving into the camera coordinate system
|
139 |
+
j3d_cam = j3d + transl_up.unsqueeze(1)
|
140 |
+
verts_cam = verts + transl_up.unsqueeze(1)
|
141 |
+
|
142 |
+
# Projection in camera plane
|
143 |
+
j2d = perspective_projection(j3d_cam, K)
|
144 |
+
|
145 |
+
out.update({
|
146 |
+
'verts_smplx_cam': verts_cam,
|
147 |
+
'j3d': j3d_cam,
|
148 |
+
'j2d': j2d,
|
149 |
+
'transl': transl, # translation of the primary keypoint
|
150 |
+
'transl_pelvis': j3d_cam[:,[0]], # root=pelvis
|
151 |
+
})
|
152 |
+
|
153 |
+
return out
|
demo.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import os
|
6 |
+
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
7 |
+
os.environ['EGL_DEVICE_ID'] = '0'
|
8 |
+
|
9 |
+
import sys
|
10 |
+
from argparse import ArgumentParser
|
11 |
+
import random
|
12 |
+
import pickle as pkl
|
13 |
+
import numpy as np
|
14 |
+
from PIL import Image, ImageOps
|
15 |
+
import torch
|
16 |
+
from tqdm import tqdm
|
17 |
+
import time
|
18 |
+
|
19 |
+
from utils import normalize_rgb, render_meshes, get_focalLength_from_fieldOfView, demo_color as color, print_distance_on_image, render_side_views, create_scene, MEAN_PARAMS, CACHE_DIR_MULTIHMR, SMPLX_DIR
|
20 |
+
from model import Model
|
21 |
+
from pathlib import Path
|
22 |
+
import warnings
|
23 |
+
|
24 |
+
torch.cuda.empty_cache()
|
25 |
+
|
26 |
+
np.random.seed(seed=0)
|
27 |
+
random.seed(0)
|
28 |
+
|
29 |
+
def open_image(img_path, img_size, device=torch.device('cuda')):
|
30 |
+
""" Open image at path, resize and pad """
|
31 |
+
|
32 |
+
# Open and reshape
|
33 |
+
img_pil = Image.open(img_path).convert('RGB')
|
34 |
+
img_pil = ImageOps.contain(img_pil, (img_size,img_size)) # keep the same aspect ratio
|
35 |
+
|
36 |
+
# Keep a copy for visualisations.
|
37 |
+
img_pil_bis = ImageOps.pad(img_pil.copy(), size=(img_size,img_size), color=(255, 255, 255))
|
38 |
+
img_pil = ImageOps.pad(img_pil, size=(img_size,img_size)) # pad with zero on the smallest side
|
39 |
+
|
40 |
+
# Go to numpy
|
41 |
+
resize_img = np.asarray(img_pil)
|
42 |
+
|
43 |
+
# Normalize and go to torch.
|
44 |
+
resize_img = normalize_rgb(resize_img)
|
45 |
+
x = torch.from_numpy(resize_img).unsqueeze(0).to(device)
|
46 |
+
return x, img_pil_bis
|
47 |
+
|
48 |
+
def get_camera_parameters(img_size, fov=60, p_x=None, p_y=None, device=torch.device('cuda')):
|
49 |
+
""" Given image size, fov and principal point coordinates, return K the camera parameter matrix"""
|
50 |
+
K = torch.eye(3)
|
51 |
+
# Get focal length.
|
52 |
+
focal = get_focalLength_from_fieldOfView(fov=fov, img_size=img_size)
|
53 |
+
K[0,0], K[1,1] = focal, focal
|
54 |
+
|
55 |
+
# Set principal point
|
56 |
+
if p_x is not None and p_y is not None:
|
57 |
+
K[0,-1], K[1,-1] = p_x * img_size, p_y * img_size
|
58 |
+
else:
|
59 |
+
K[0,-1], K[1,-1] = img_size//2, img_size//2
|
60 |
+
|
61 |
+
# Add batch dimension
|
62 |
+
K = K.unsqueeze(0).to(device)
|
63 |
+
return K
|
64 |
+
|
65 |
+
def load_model(model_name, device=torch.device('cuda')):
|
66 |
+
""" Open a checkpoint, build Multi-HMR using saved arguments, load the model weigths. """
|
67 |
+
|
68 |
+
# Model
|
69 |
+
ckpt_path = os.path.join(CACHE_DIR_MULTIHMR, model_name+ '.pt')
|
70 |
+
if not os.path.isfile(ckpt_path):
|
71 |
+
os.makedirs(CACHE_DIR_MULTIHMR, exist_ok=True)
|
72 |
+
print(f"{ckpt_path} not found...")
|
73 |
+
print("It should be the first time you run the demo code")
|
74 |
+
print("Downloading checkpoint from NAVER LABS Europe website...")
|
75 |
+
|
76 |
+
try:
|
77 |
+
os.system(f"wget -O {ckpt_path} http://download.europe.naverlabs.com/multihmr/{model_name}.pt")
|
78 |
+
print(f"Ckpt downloaded to {ckpt_path}")
|
79 |
+
except:
|
80 |
+
assert "Please contact [email protected] or open an issue on the github repo"
|
81 |
+
|
82 |
+
# Load weights
|
83 |
+
print("Loading model")
|
84 |
+
ckpt = torch.load(ckpt_path, map_location=device)
|
85 |
+
|
86 |
+
# Get arguments saved in the checkpoint to rebuild the model
|
87 |
+
kwargs = {}
|
88 |
+
for k,v in vars(ckpt['args']).items():
|
89 |
+
kwargs[k] = v
|
90 |
+
|
91 |
+
# Build the model.
|
92 |
+
kwargs['type'] = ckpt['args'].train_return_type
|
93 |
+
kwargs['img_size'] = ckpt['args'].img_size[0]
|
94 |
+
model = Model(**kwargs).to(device)
|
95 |
+
|
96 |
+
# Load weights into model.
|
97 |
+
model.load_state_dict(ckpt['model_state_dict'], strict=False)
|
98 |
+
print("Weights have been loaded")
|
99 |
+
|
100 |
+
return model
|
101 |
+
|
102 |
+
def forward_model(model, input_image, camera_parameters,
|
103 |
+
det_thresh=0.3,
|
104 |
+
nms_kernel_size=1,
|
105 |
+
):
|
106 |
+
|
107 |
+
""" Make a forward pass on an input image and camera parameters. """
|
108 |
+
|
109 |
+
# Forward the model.
|
110 |
+
with torch.no_grad():
|
111 |
+
with torch.cuda.amp.autocast(enabled=True):
|
112 |
+
humans = model(input_image,
|
113 |
+
is_training=False,
|
114 |
+
nms_kernel_size=int(nms_kernel_size),
|
115 |
+
det_thresh=det_thresh,
|
116 |
+
K=camera_parameters)
|
117 |
+
|
118 |
+
return humans
|
119 |
+
|
120 |
+
def overlay_human_meshes(humans, K, model, img_pil, unique_color=False):
|
121 |
+
|
122 |
+
# Color of humans seen in the image.
|
123 |
+
_color = [color[0] for _ in range(len(humans))] if unique_color else color
|
124 |
+
|
125 |
+
# Get focal and princpt for rendering.
|
126 |
+
focal = np.asarray([K[0,0,0].cpu().numpy(),K[0,1,1].cpu().numpy()])
|
127 |
+
princpt = np.asarray([K[0,0,-1].cpu().numpy(),K[0,1,-1].cpu().numpy()])
|
128 |
+
|
129 |
+
# Get the vertices produced by the model.
|
130 |
+
verts_list = [humans[j]['verts_smplx'].cpu().numpy() for j in range(len(humans))]
|
131 |
+
faces_list = [model.smpl_layer['neutral'].bm_x.faces for j in range(len(humans))]
|
132 |
+
|
133 |
+
# Render the meshes onto the image.
|
134 |
+
pred_rend_array = render_meshes(np.asarray(img_pil),
|
135 |
+
verts_list,
|
136 |
+
faces_list,
|
137 |
+
{'focal': focal, 'princpt': princpt},
|
138 |
+
alpha=1.0,
|
139 |
+
color=_color)
|
140 |
+
|
141 |
+
return pred_rend_array, _color
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
parser = ArgumentParser()
|
145 |
+
parser.add_argument("--model_name", type=str, default='multiHMR_896_L_synth')
|
146 |
+
parser.add_argument("--img_folder", type=str, default='example_data')
|
147 |
+
parser.add_argument("--out_folder", type=str, default='demo_out')
|
148 |
+
parser.add_argument("--save_mesh", type=int, default=0, choices=[0,1])
|
149 |
+
parser.add_argument("--extra_views", type=int, default=0, choices=[0,1])
|
150 |
+
parser.add_argument("--det_thresh", type=float, default=0.3)
|
151 |
+
parser.add_argument("--nms_kernel_size", type=float, default=3)
|
152 |
+
parser.add_argument("--fov", type=float, default=60)
|
153 |
+
parser.add_argument("--distance", type=int, default=0, choices=[0,1], help='add distance on the reprojected mesh')
|
154 |
+
parser.add_argument("--unique_color", type=int, default=0, choices=[0,1], help='only one color for all humans')
|
155 |
+
|
156 |
+
args = parser.parse_args()
|
157 |
+
|
158 |
+
dict_args = vars(args)
|
159 |
+
|
160 |
+
assert torch.cuda.is_available()
|
161 |
+
|
162 |
+
# SMPL-X models
|
163 |
+
smplx_fn = os.path.join(SMPLX_DIR, 'smplx', 'SMPLX_NEUTRAL.npz')
|
164 |
+
if not os.path.isfile(smplx_fn):
|
165 |
+
print(f"{smplx_fn} not found, please download SMPLX_NEUTRAL.npz file")
|
166 |
+
print("To do so you need to create an account in https://smpl-x.is.tue.mpg.de")
|
167 |
+
print("Then download 'SMPL-X-v1.1 (NPZ+PKL, 830MB) - Use thsi for SMPL-X Python codebase'")
|
168 |
+
print(f"Extract the zip file and move SMPLX_NEUTRAL.npz to {smplx_fn}")
|
169 |
+
print("Sorry for this incovenience but we do not have license for redustributing SMPLX model")
|
170 |
+
assert NotImplementedError
|
171 |
+
else:
|
172 |
+
print('SMPLX found')
|
173 |
+
|
174 |
+
# SMPL mean params download
|
175 |
+
if not os.path.isfile(MEAN_PARAMS):
|
176 |
+
print('Start to download the SMPL mean params')
|
177 |
+
os.system(f"wget -O {MEAN_PARAMS} https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/smpl_mean_params.npz?versionId=CAEQHhiBgICN6M3V6xciIDU1MzUzNjZjZGNiOTQ3OWJiZTJmNThiZmY4NmMxMTM4")
|
178 |
+
print('SMPL mean params have been succesfully downloaded')
|
179 |
+
else:
|
180 |
+
print('SMPL mean params is already here')
|
181 |
+
|
182 |
+
# Input images
|
183 |
+
suffixes = ('.jpg', '.jpeg', '.png', '.webp')
|
184 |
+
l_img_path = [file for file in os.listdir(args.img_folder) if file.endswith(suffixes) and file[0] != '.']
|
185 |
+
|
186 |
+
# Loading
|
187 |
+
model = load_model(args.model_name)
|
188 |
+
|
189 |
+
# Model name for saving results.
|
190 |
+
model_name = os.path.basename(args.model_name)
|
191 |
+
|
192 |
+
# All images
|
193 |
+
os.makedirs(args.out_folder, exist_ok=True)
|
194 |
+
l_duration = []
|
195 |
+
for i, img_path in enumerate(tqdm(l_img_path)):
|
196 |
+
|
197 |
+
# Path where the image + overlays of human meshes + optional views will be saved.
|
198 |
+
save_fn = os.path.join(args.out_folder, f"{Path(img_path).stem}_{model_name}.png")
|
199 |
+
|
200 |
+
# Get input in the right format for the model
|
201 |
+
img_size = model.img_size
|
202 |
+
x, img_pil_nopad = open_image(os.path.join(args.img_folder, img_path), img_size)
|
203 |
+
|
204 |
+
# Get camera parameters
|
205 |
+
p_x, p_y = None, None
|
206 |
+
K = get_camera_parameters(model.img_size, fov=args.fov, p_x=p_x, p_y=p_y)
|
207 |
+
|
208 |
+
# Make model predictions
|
209 |
+
start = time.time()
|
210 |
+
humans = forward_model(model, x, K,
|
211 |
+
det_thresh=args.det_thresh,
|
212 |
+
nms_kernel_size=args.nms_kernel_size)
|
213 |
+
duration = time.time() - start
|
214 |
+
l_duration.append(duration)
|
215 |
+
|
216 |
+
# Superimpose predicted human meshes to the input image.
|
217 |
+
img_array = np.asarray(img_pil_nopad)
|
218 |
+
img_pil_visu= Image.fromarray(img_array)
|
219 |
+
pred_rend_array, _color = overlay_human_meshes(humans, K, model, img_pil_visu, unique_color=args.unique_color)
|
220 |
+
|
221 |
+
# Optionally add distance as an annotation to each mesh
|
222 |
+
if args.distance:
|
223 |
+
pred_rend_array = print_distance_on_image(pred_rend_array, humans, _color)
|
224 |
+
|
225 |
+
# List of images too view side by side.
|
226 |
+
l_img = [img_array, pred_rend_array]
|
227 |
+
|
228 |
+
# More views
|
229 |
+
if args.extra_views:
|
230 |
+
# Render more side views of the meshes.
|
231 |
+
pred_rend_array_bis, pred_rend_array_sideview, pred_rend_array_bev = render_side_views(img_array, _color, humans, model, K)
|
232 |
+
|
233 |
+
# Concat
|
234 |
+
_img1 = np.concatenate([img_array, pred_rend_array],1).astype(np.uint8)
|
235 |
+
_img2 = np.concatenate([pred_rend_array_bis, pred_rend_array_sideview, pred_rend_array_bev],1).astype(np.uint8)
|
236 |
+
_h = int(_img2.shape[0] * (_img1.shape[1]/_img2.shape[1]))
|
237 |
+
_img2 = np.asarray(Image.fromarray(_img2).resize((_img1.shape[1], _h)))
|
238 |
+
_img = np.concatenate([_img1, _img2],0).astype(np.uint8)
|
239 |
+
else:
|
240 |
+
# Concatenate side by side
|
241 |
+
_img = np.concatenate([img_array, pred_rend_array],1).astype(np.uint8)
|
242 |
+
|
243 |
+
# Save to path.
|
244 |
+
Image.fromarray(_img).save(save_fn)
|
245 |
+
print(f"Avg Multi-HMR inference time={int(1000*np.median(np.asarray(l_duration[-1:])))}ms on a {torch.cuda.get_device_name()}")
|
246 |
+
|
247 |
+
# Saving mesh
|
248 |
+
if args.save_mesh:
|
249 |
+
# npy file
|
250 |
+
l_mesh = [hum['verts_smplx'].cpu().numpy() for hum in humans]
|
251 |
+
mesh_fn = save_fn+'.npy'
|
252 |
+
np.save(mesh_fn, np.asarray(l_mesh), allow_pickle=True)
|
253 |
+
x = np.load(mesh_fn, allow_pickle=True)
|
254 |
+
|
255 |
+
# glb file
|
256 |
+
l_mesh = [humans[j]['verts_smplx'].detach().cpu().numpy() for j in range(len(humans))]
|
257 |
+
l_face = [model.smpl_layer['neutral'].bm_x.faces for j in range(len(humans))]
|
258 |
+
scene = create_scene(img_pil_visu, l_mesh, l_face, color=None, metallicFactor=0., roughnessFactor=0.5)
|
259 |
+
scene_fn = save_fn+'.glb'
|
260 |
+
scene.export(scene_fn)
|
261 |
+
|
262 |
+
print('end')
|
example_data/170149601_13aa4e4483_c.jpg
ADDED
![]() |
Git LFS Details
|
example_data/3692623581_aca6eb02d4_e.jpg
ADDED
![]() |
Git LFS Details
|
example_data/3969570423_58eb848b75_c.jpg
ADDED
![]() |
Git LFS Details
|
example_data/39742984604_46934fbd50_c.jpg
ADDED
![]() |
Git LFS Details
|
example_data/4446582661_b188f82f3c_c.jpg
ADDED
![]() |
Git LFS Details
|
example_data/51960182045_d5d6407a3c_c.jpg
ADDED
![]() |
Git LFS Details
|
example_data/5850091922_73ba296093_c.jpg
ADDED
![]() |
Git LFS Details
|
model.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import roma
|
9 |
+
import copy
|
10 |
+
|
11 |
+
from utils import unpatch, inverse_perspective_projection, undo_focal_length_normalization, undo_log_depth
|
12 |
+
from blocks import Dinov2Backbone, FourierPositionEncoding, TransformerDecoder, SMPL_Layer
|
13 |
+
from utils import rot6d_to_rotmat, rebatch, pad_to_max
|
14 |
+
import torch.nn as nn
|
15 |
+
import numpy as np
|
16 |
+
import einops
|
17 |
+
from utils.constants import MEAN_PARAMS
|
18 |
+
|
19 |
+
class Model(nn.Module):
|
20 |
+
""" A ViT backbone followed by a "HPH" head (stack of cross attention layers with queries corresponding to detected humans.) """
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
backbone='dinov2_vitb14',
|
24 |
+
img_size=896,
|
25 |
+
camera_embedding='geometric', # geometric encodes viewing directions with fourrier encoding
|
26 |
+
camera_embedding_num_bands=16, # increase the size of the camera embedding
|
27 |
+
camera_embedding_max_resolution=64, # does not increase the size of the camera embedding
|
28 |
+
nearness=True, # regress log(1/z)
|
29 |
+
xat_depth=2, # number of cross attention block (SA, CA, MLP) in the HPH head.
|
30 |
+
xat_num_heads=8, # Number of attention heads
|
31 |
+
dict_smpl_layer=None,
|
32 |
+
person_center='head',
|
33 |
+
clip_dist=True,
|
34 |
+
*args, **kwargs):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
# Save options
|
38 |
+
self.img_size = img_size
|
39 |
+
self.nearness = nearness
|
40 |
+
self.clip_dist = clip_dist,
|
41 |
+
self.xat_depth = xat_depth
|
42 |
+
self.xat_num_heads = xat_num_heads
|
43 |
+
|
44 |
+
# Setup backbone
|
45 |
+
self.backbone = Dinov2Backbone(backbone)
|
46 |
+
self.embed_dim = self.backbone.embed_dim
|
47 |
+
self.patch_size = self.backbone.patch_size
|
48 |
+
assert self.img_size % self.patch_size == 0, "Invalid img size"
|
49 |
+
|
50 |
+
# Camera instrinsics
|
51 |
+
self.fovn = 60
|
52 |
+
self.camera_embedding = camera_embedding
|
53 |
+
self.camera_embed_dim = 0
|
54 |
+
if self.camera_embedding is not None:
|
55 |
+
if not self.camera_embedding == 'geometric':
|
56 |
+
raise NotImplementedError("Only geometric camera embedding is implemented")
|
57 |
+
self.camera = FourierPositionEncoding(n=3, num_bands=camera_embedding_num_bands,max_resolution=camera_embedding_max_resolution)
|
58 |
+
# import pdb
|
59 |
+
# pdb.set_trace()
|
60 |
+
self.camera_embed_dim = self.camera.channels
|
61 |
+
|
62 |
+
# Heads - Detection
|
63 |
+
self.mlp_classif = regression_mlp([self.embed_dim, self.embed_dim, 1]) # bg or human
|
64 |
+
|
65 |
+
# Heads - Human properties
|
66 |
+
self.mlp_offset = regression_mlp([self.embed_dim, self.embed_dim, 2]) # offset
|
67 |
+
|
68 |
+
# Dense vetcor idx
|
69 |
+
self.nrot = 53
|
70 |
+
self.idx_score, self.idx_offset, self.idx_dist = [0], [1,2], [3]
|
71 |
+
self.idx_pose = list(range(4,4+self.nrot*9))
|
72 |
+
self.idx_shape = list(range(4+self.nrot*9,4+self.nrot*9+11))
|
73 |
+
self.idx_expr = list(range(4+self.nrot*9+11,4+self.nrot*9+11+10))
|
74 |
+
|
75 |
+
# SMPL Layers
|
76 |
+
dict_smpl_layer = {'neutral': {10: SMPL_Layer(type='smplx', gender='neutral', num_betas=10, kid=False, person_center=person_center)}}
|
77 |
+
_moduleDict = []
|
78 |
+
for k, _smpl_layer in dict_smpl_layer.items():
|
79 |
+
_moduleDict.append([k, copy.deepcopy(_smpl_layer[10])])
|
80 |
+
self.smpl_layer = nn.ModuleDict(_moduleDict)
|
81 |
+
|
82 |
+
self.x_attention_head = HPH(
|
83 |
+
num_body_joints=self.nrot-1, #23,
|
84 |
+
context_dim=self.embed_dim + self.camera_embed_dim,
|
85 |
+
dim=1024,
|
86 |
+
depth=self.xat_depth,
|
87 |
+
heads=self.xat_num_heads,
|
88 |
+
mlp_dim=1024,
|
89 |
+
dim_head=32,
|
90 |
+
dropout=0.0,
|
91 |
+
emb_dropout=0.0,
|
92 |
+
at_token_res=self.img_size // self.patch_size)
|
93 |
+
|
94 |
+
def detection(self, z, nms_kernel_size, det_thresh, N):
|
95 |
+
""" Detection score on the entire low res image """
|
96 |
+
scores = _sigmoid(self.mlp_classif(z)) # per token detection score.
|
97 |
+
# Restore Height and Width dimensions.
|
98 |
+
scores = unpatch(scores, patch_size=1, c=scores.shape[2], img_size=int(np.sqrt(N)))
|
99 |
+
|
100 |
+
if nms_kernel_size > 1: # Easy nms: supress adjacent high scores with max pooling.
|
101 |
+
scores = _nms(scores, kernel=nms_kernel_size)
|
102 |
+
_scores = torch.permute(scores, (0, 2, 3, 1))
|
103 |
+
|
104 |
+
# Binary decision (keep confident detections)
|
105 |
+
idx = apply_threshold(det_thresh, _scores)
|
106 |
+
|
107 |
+
# Scores
|
108 |
+
scores_detected = scores[idx[0], idx[3], idx[1],idx[2]] # scores of the detected humans only
|
109 |
+
scores = torch.permute(scores, (0, 2, 3, 1))
|
110 |
+
return scores, scores_detected, idx
|
111 |
+
|
112 |
+
def embedd_camera(self, K, z):
|
113 |
+
""" Embed viewing directions using fourrier encoding."""
|
114 |
+
bs = z.shape[0]
|
115 |
+
_h, _w = list(z.shape[-2:])
|
116 |
+
points = torch.stack([torch.arange(0,_h,1).reshape(-1,1).repeat(1,_w), torch.arange(0,_w,1).reshape(1,-1).repeat(_h,1)],-1).to(z.device).float() # [h,w,2]
|
117 |
+
points = points * self.patch_size + self.patch_size // 2 # move to pixel space - we give the pixel center of each token
|
118 |
+
points = points.reshape(1,-1,2).repeat(bs,1,1) # (bs, N, 2): 2D points
|
119 |
+
distance = torch.ones(bs,points.shape[1],1).to(K.device) # (bs, N, 1): distance in the 3D world
|
120 |
+
rays = inverse_perspective_projection(points, K, distance) # (bs, N, 3)
|
121 |
+
rays_embeddings = self.camera(pos=rays)
|
122 |
+
|
123 |
+
# Repeat for each element of the batch
|
124 |
+
z_K = rays_embeddings.reshape(bs,_h,_w,self.camera_embed_dim) # [bs,h,w,D]
|
125 |
+
return z_K
|
126 |
+
|
127 |
+
def to_euclidean_dist(self, x, dist, _K):
|
128 |
+
# Focal length normalization
|
129 |
+
focal = _K[:,[0],[0]]
|
130 |
+
dist = undo_focal_length_normalization(dist, focal, fovn=self.fovn, img_size=x.shape[-1])
|
131 |
+
# log space
|
132 |
+
if self.nearness:
|
133 |
+
dist = undo_log_depth(dist)
|
134 |
+
|
135 |
+
# Clamping
|
136 |
+
if self.clip_dist:
|
137 |
+
dist = torch.clamp(dist, 0, 50)
|
138 |
+
|
139 |
+
return dist
|
140 |
+
|
141 |
+
|
142 |
+
def forward(self,
|
143 |
+
x,
|
144 |
+
idx=None,
|
145 |
+
det_thresh=0.5,
|
146 |
+
nms_kernel_size=3,
|
147 |
+
K=None,
|
148 |
+
*args,
|
149 |
+
**kwargs):
|
150 |
+
"""
|
151 |
+
Forward pass of the model and compute the loss according to the groundtruth
|
152 |
+
Args:
|
153 |
+
- x: RGB image - [bs,3,224,224]
|
154 |
+
- idx: GT location of persons - tuple of 3 tensor of shape [p]
|
155 |
+
- idx_j2d: GT location of 2d-kpts for each detected humans - tensor of shape [bs',14,2] - location in pixel space
|
156 |
+
Return:
|
157 |
+
- y: [bs,D,16,16]
|
158 |
+
"""
|
159 |
+
persons = []
|
160 |
+
out = {}
|
161 |
+
|
162 |
+
# Feature extraction
|
163 |
+
z = self.backbone(x)
|
164 |
+
B,N,C = z.size() # [bs,256,768]
|
165 |
+
|
166 |
+
# Detection
|
167 |
+
scores, scores_det, idx = self.detection(z, nms_kernel_size=nms_kernel_size, det_thresh=det_thresh, N=N)
|
168 |
+
if len(idx[0]) == 0:
|
169 |
+
# no humans detected in the frame
|
170 |
+
return persons
|
171 |
+
|
172 |
+
# Map of Dense Feature
|
173 |
+
z = unpatch(z, patch_size=1, c=z.shape[2], img_size=int(np.sqrt(N))) # [bs,D,16,16]
|
174 |
+
z_all = z
|
175 |
+
|
176 |
+
# Extract the 'central' features
|
177 |
+
z = torch.reshape(z, (z.shape[0], 1, z.shape[1]//1, z.shape[2], z.shape[3])) # [bs,stack_K,D,16,16]
|
178 |
+
z_central = z[idx[0],idx[3],:,idx[1],idx[2]] # dense vectors
|
179 |
+
|
180 |
+
# 2D offset regression
|
181 |
+
offset = self.mlp_offset(z_central)
|
182 |
+
|
183 |
+
# Camera instrincs
|
184 |
+
K_det = K[idx[0]] # cameras for detected person
|
185 |
+
z_K = self.embedd_camera(K, z) # Embed viewing directions.
|
186 |
+
z_central = torch.cat([z_central, z_K[idx[0],idx[1], idx[2]]], 1) # Add to query tokens.
|
187 |
+
z_all = torch.cat([z_all, z_K.permute(0,3,1,2)], 1) # for the cross-attention only
|
188 |
+
z = torch.cat([z, z_K.permute(0,3,1,2).unsqueeze(1)],2)
|
189 |
+
|
190 |
+
# Distance for estimating the 3D location in 3D space
|
191 |
+
loc = torch.stack([idx[2],idx[1]]).permute(1,0) # Moving from higher resolution the location of the pelvis
|
192 |
+
loc = (loc + 0.5 + offset ) * self.patch_size
|
193 |
+
|
194 |
+
# SMPL parameter regression
|
195 |
+
kv = z_all[idx[0]] # retrieving dense features associated to each central vector
|
196 |
+
pred_smpl_params, pred_cam = self.x_attention_head(z_central, kv, idx_0=idx[0], idx_det=idx)
|
197 |
+
|
198 |
+
# Get outputs from the SMPL layer.
|
199 |
+
shape = pred_smpl_params['betas']
|
200 |
+
rotmat = torch.cat([pred_smpl_params['global_orient'],pred_smpl_params['body_pose']], 1)
|
201 |
+
expression = pred_smpl_params['expression']
|
202 |
+
rotvec = roma.rotmat_to_rotvec(rotmat)
|
203 |
+
|
204 |
+
# Distance
|
205 |
+
dist = pred_cam[:, 0][:, None]
|
206 |
+
out['dist_postprocessed'] = dist # before applying any post-processing such as focal length normalization, inverse or log
|
207 |
+
dist = self.to_euclidean_dist(x, dist, K_det)
|
208 |
+
|
209 |
+
# Populate output dictionnary
|
210 |
+
out.update({'scores': scores, 'offset': offset, 'dist': dist, 'expression': expression,
|
211 |
+
'rotmat': rotmat, 'shape': shape, 'rotvec': rotvec, 'loc': loc})
|
212 |
+
|
213 |
+
assert rotvec.shape[0] == shape.shape[0] == loc.shape[0] == dist.shape[0], "Incoherent shapes"
|
214 |
+
|
215 |
+
# Neutral
|
216 |
+
smpl_out = self.smpl_layer['neutral'](rotvec, shape, loc, dist, None, K=K_det, expression=expression)
|
217 |
+
out.update(smpl_out)
|
218 |
+
|
219 |
+
# Populate a dictionnary for each person
|
220 |
+
for i in range(idx[0].shape[0]):
|
221 |
+
person = {
|
222 |
+
# Detection
|
223 |
+
'scores': scores_det[i], # detection scores
|
224 |
+
'loc': out['loc'][i], # 2d pixel location of the primary keypoints
|
225 |
+
# SMPL-X params
|
226 |
+
'transl': out['transl'][i], # from the primary keypoint i.e. the head
|
227 |
+
'transl_pelvis': out['transl_pelvis'][i], # of the pelvis joint
|
228 |
+
'rotvec': out['rotvec'][i],
|
229 |
+
'expression': out['expression'][i],
|
230 |
+
'shape': out['shape'][i],
|
231 |
+
# SMPL-X meshs
|
232 |
+
'verts_smplx': out['verts_smplx_cam'][i],
|
233 |
+
'j3d_smplx': out['j3d'][i],
|
234 |
+
'j2d_smplx': out['j2d'][i],
|
235 |
+
}
|
236 |
+
persons.append(person)
|
237 |
+
|
238 |
+
return persons
|
239 |
+
|
240 |
+
class HPH(nn.Module):
|
241 |
+
""" Cross-attention based SMPL Transformer decoder
|
242 |
+
|
243 |
+
Code modified from:
|
244 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/heads/smpl_head.py#L17
|
245 |
+
https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L301
|
246 |
+
"""
|
247 |
+
|
248 |
+
def __init__(self,
|
249 |
+
num_body_joints=52,
|
250 |
+
context_dim=1280,
|
251 |
+
dim=1024,
|
252 |
+
depth=2,
|
253 |
+
heads=8,
|
254 |
+
mlp_dim=1024,
|
255 |
+
dim_head=64,
|
256 |
+
dropout=0.0,
|
257 |
+
emb_dropout=0.0,
|
258 |
+
at_token_res=32,
|
259 |
+
):
|
260 |
+
super().__init__()
|
261 |
+
|
262 |
+
self.joint_rep_type, self.joint_rep_dim = '6d', 6
|
263 |
+
self.num_body_joints = num_body_joints
|
264 |
+
self.nrot = self.num_body_joints + 1
|
265 |
+
|
266 |
+
npose = self.joint_rep_dim * (self.num_body_joints + 1)
|
267 |
+
self.npose = npose
|
268 |
+
|
269 |
+
self.depth = depth,
|
270 |
+
self.heads = heads,
|
271 |
+
self.res = at_token_res
|
272 |
+
self.input_is_mean_shape = True
|
273 |
+
_context_dim = context_dim # for the central features
|
274 |
+
|
275 |
+
# Transformer Decoder setup.
|
276 |
+
# Based on https://github.com/shubham-goel/4D-Humans/blob/8830bb330558eea2395b7f57088ef0aae7f8fa22/hmr2/configs_hydra/experiment/hmr_vit_transformer.yaml#L35
|
277 |
+
transformer_args = dict(
|
278 |
+
num_tokens=1,
|
279 |
+
token_dim=(npose + 10 + 3 + _context_dim) if self.input_is_mean_shape else 1,
|
280 |
+
dim=dim,
|
281 |
+
depth=depth,
|
282 |
+
heads=heads,
|
283 |
+
mlp_dim=mlp_dim,
|
284 |
+
dim_head=dim_head,
|
285 |
+
dropout=dropout,
|
286 |
+
emb_dropout=emb_dropout,
|
287 |
+
context_dim=context_dim,
|
288 |
+
)
|
289 |
+
self.transformer = TransformerDecoder(**transformer_args)
|
290 |
+
|
291 |
+
dim = transformer_args['dim']
|
292 |
+
|
293 |
+
# Final decoders to regress targets
|
294 |
+
self.decpose, self.decshape, self.deccam, self.decexpression = [nn.Linear(dim, od) for od in [npose, 10, 3, 10]]
|
295 |
+
|
296 |
+
# Register bufffers for the smpl layer.
|
297 |
+
self.set_smpl_init()
|
298 |
+
|
299 |
+
# Init learned embeddings for the cross attention queries
|
300 |
+
self.init_learned_queries(context_dim)
|
301 |
+
|
302 |
+
|
303 |
+
def init_learned_queries(self, context_dim, std=0.2):
|
304 |
+
""" Init learned embeddings for queries"""
|
305 |
+
self.cross_queries_x = nn.Parameter(torch.zeros(self.res, context_dim))
|
306 |
+
torch.nn.init.normal_(self.cross_queries_x, std=std)
|
307 |
+
|
308 |
+
self.cross_queries_y = nn.Parameter(torch.zeros(self.res, context_dim))
|
309 |
+
torch.nn.init.normal_(self.cross_queries_y, std=std)
|
310 |
+
|
311 |
+
self.cross_values_x = nn.Parameter(torch.zeros(self.res, context_dim))
|
312 |
+
torch.nn.init.normal_(self.cross_values_x, std=std)
|
313 |
+
|
314 |
+
self.cross_values_y = nn.Parameter(nn.Parameter(torch.zeros(self.res, context_dim)))
|
315 |
+
torch.nn.init.normal_(self.cross_values_y, std=std)
|
316 |
+
|
317 |
+
def set_smpl_init(self):
|
318 |
+
""" Fetch saved SMPL parameters and register buffers."""
|
319 |
+
mean_params = np.load(MEAN_PARAMS)
|
320 |
+
if self.nrot == 53:
|
321 |
+
init_body_pose = torch.eye(3).reshape(1,3,3).repeat(self.nrot,1,1)[:,:,:2].flatten(1).reshape(1, -1)
|
322 |
+
init_body_pose[:,:24*6] = torch.from_numpy(mean_params['pose'][:]).float() # global_orient+body_pose from SMPL
|
323 |
+
else:
|
324 |
+
init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
|
325 |
+
|
326 |
+
init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
|
327 |
+
init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
|
328 |
+
init_betas_kid = torch.cat([init_betas, torch.zeros_like(init_betas[:,[0]])],1)
|
329 |
+
init_expression = 0. * torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
|
330 |
+
|
331 |
+
self.register_buffer('init_body_pose', init_body_pose)
|
332 |
+
self.register_buffer('init_betas', init_betas)
|
333 |
+
self.register_buffer('init_betas_kid', init_betas_kid)
|
334 |
+
self.register_buffer('init_cam', init_cam)
|
335 |
+
self.register_buffer('init_expression', init_expression)
|
336 |
+
|
337 |
+
|
338 |
+
def cross_attn_inputs(self, x, x_central, idx_0, idx_det):
|
339 |
+
""" Reshape and pad x_central to have the right shape for Cross-attention processing.
|
340 |
+
Inject learned embeddings to query and key inputs at the location of detected people. """
|
341 |
+
|
342 |
+
h, w = x.shape[2], x.shape[3]
|
343 |
+
x = einops.rearrange(x, 'b c h w -> b (h w) c')
|
344 |
+
|
345 |
+
assert idx_0 is not None, "Learned cross queries only work with multicross"
|
346 |
+
|
347 |
+
if idx_0.shape[0] > 0:
|
348 |
+
# reconstruct the batch/nb_people dimensions: pad for images with fewer people than max.
|
349 |
+
counts, idx_det_0 = rebatch(idx_0, idx_det)
|
350 |
+
old_shape = x_central.shape
|
351 |
+
|
352 |
+
# Legacy check for old versions
|
353 |
+
assert idx_det is not None, 'idx_det needed for learned_attention'
|
354 |
+
|
355 |
+
# xx is the tensor with all features
|
356 |
+
xx = einops.rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
357 |
+
# Get learned embeddings for queries, at positions with detected people.
|
358 |
+
queries_xy = self.cross_queries_x[idx_det[1]] + self.cross_queries_y[idx_det[2]]
|
359 |
+
# Add the embedding to the central features.
|
360 |
+
x_central = x_central + queries_xy
|
361 |
+
assert x_central.shape == old_shape, "Problem with shape"
|
362 |
+
|
363 |
+
# Make it a tensor of dim. [batch, max_ppl_along_batch, ...]
|
364 |
+
x_central, mask = pad_to_max(x_central, counts)
|
365 |
+
|
366 |
+
#xx = einops.rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
367 |
+
xx = xx[torch.cumsum(counts, dim=0)-1]
|
368 |
+
|
369 |
+
# Inject leared embeddings for key/values at detected locations.
|
370 |
+
values_xy = self.cross_values_x[idx_det[1]] + self.cross_values_y[idx_det[2]]
|
371 |
+
xx[idx_det_0, :, idx_det[1], idx_det[2]] += values_xy
|
372 |
+
|
373 |
+
x = einops.rearrange(xx, 'b c h w -> b (h w) c')
|
374 |
+
num_ppl = x_central.shape[1]
|
375 |
+
else:
|
376 |
+
mask = None
|
377 |
+
num_ppl = 1
|
378 |
+
counts = None
|
379 |
+
return x, x_central, mask, num_ppl, counts
|
380 |
+
|
381 |
+
|
382 |
+
def forward(self,
|
383 |
+
x_central,
|
384 |
+
x,
|
385 |
+
idx_0=None,
|
386 |
+
idx_det=None,
|
387 |
+
**kwargs):
|
388 |
+
""""
|
389 |
+
Forward the HPH module.
|
390 |
+
"""
|
391 |
+
batch_size = x.shape[0]
|
392 |
+
|
393 |
+
# Reshape inputs for cross attention and inject learned embeddings for queries and values.
|
394 |
+
x, x_central, mask, num_ppl, counts = self.cross_attn_inputs(x, x_central, idx_0, idx_det)
|
395 |
+
|
396 |
+
# Add init (mean smpl params) to the query for each quantity being regressed.
|
397 |
+
bs = x_central.shape[0] if idx_0.shape[0] else batch_size
|
398 |
+
expand = lambda x: x.expand(bs, num_ppl , -1)
|
399 |
+
pred_body_pose, pred_betas, pred_cam, pred_expression = [expand(x) for x in
|
400 |
+
[self.init_body_pose, self.init_betas, self.init_cam, self.init_expression]]
|
401 |
+
token = torch.cat([x_central, pred_body_pose, pred_betas, pred_cam], dim=-1)
|
402 |
+
if len(token.shape) == 2:
|
403 |
+
token = token[:,None,:]
|
404 |
+
|
405 |
+
# Process query and inputs with the cross-attention module.
|
406 |
+
token_out = self.transformer(token, context=x, mask=mask)
|
407 |
+
|
408 |
+
# Reshape outputs from [batch_size, nmax_ppl, ...] to [total_ppl, ...]
|
409 |
+
if mask is not None:
|
410 |
+
# Stack along batch axis.
|
411 |
+
token_out_list = [token_out[i, :c, ...] for i, c in enumerate(counts)]
|
412 |
+
token_out = torch.concat(token_out_list, dim=0)
|
413 |
+
else:
|
414 |
+
token_out = token_out.squeeze(1) # (B, C)
|
415 |
+
|
416 |
+
# Decoded output token and add to init for each quantity to regress.
|
417 |
+
reshape = (lambda x: x) if idx_0.shape[0] == 0 else (lambda x: x[0, 0, ...][None, ...])
|
418 |
+
decoders = [self.decpose, self.decshape, self.deccam, self.decexpression]
|
419 |
+
inits = [pred_body_pose, pred_betas, pred_cam, pred_expression]
|
420 |
+
pred_body_pose, pred_betas, pred_cam, pred_expression = [d(token_out) + reshape(i) for d, i in zip(decoders, inits)]
|
421 |
+
|
422 |
+
# Convert self.joint_rep_type -> rotmat
|
423 |
+
joint_conversion_fn = rot6d_to_rotmat
|
424 |
+
|
425 |
+
# conversion
|
426 |
+
pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, self.num_body_joints+1, 3, 3)
|
427 |
+
|
428 |
+
# Build the output dict
|
429 |
+
pred_smpl_params = {'global_orient': pred_body_pose[:, [0]],
|
430 |
+
'body_pose': pred_body_pose[:, 1:],
|
431 |
+
'betas': pred_betas,
|
432 |
+
#'betas_kid': pred_betas_kid,
|
433 |
+
'expression': pred_expression}
|
434 |
+
return pred_smpl_params, pred_cam #, pred_smpl_params_list
|
435 |
+
|
436 |
+
def regression_mlp(layers_sizes):
|
437 |
+
"""
|
438 |
+
Return a fully connected network.
|
439 |
+
"""
|
440 |
+
assert len(layers_sizes) >= 2
|
441 |
+
in_features = layers_sizes[0]
|
442 |
+
layers = []
|
443 |
+
for i in range(1, len(layers_sizes)-1):
|
444 |
+
out_features = layers_sizes[i]
|
445 |
+
layers.append(torch.nn.Linear(in_features, out_features))
|
446 |
+
layers.append(torch.nn.ReLU())
|
447 |
+
in_features = out_features
|
448 |
+
layers.append(torch.nn.Linear(in_features, layers_sizes[-1]))
|
449 |
+
return torch.nn.Sequential(*layers)
|
450 |
+
|
451 |
+
def apply_threshold(det_thresh, _scores):
|
452 |
+
""" Apply thresholding to detection scores; if stack_K is used and det_thresh is a list, apply to each channel separately """
|
453 |
+
if isinstance(det_thresh, list):
|
454 |
+
det_thresh = det_thresh[0]
|
455 |
+
idx = torch.where(_scores >= det_thresh)
|
456 |
+
return idx
|
457 |
+
|
458 |
+
def _nms(heat, kernel=3):
|
459 |
+
""" easy non maximal supression (as in CenterNet) """
|
460 |
+
|
461 |
+
if kernel not in [2, 4]:
|
462 |
+
pad = (kernel - 1) // 2
|
463 |
+
else:
|
464 |
+
if kernel == 2:
|
465 |
+
pad = 1
|
466 |
+
else:
|
467 |
+
pad = 2
|
468 |
+
|
469 |
+
hmax = nn.functional.max_pool2d( heat, (kernel, kernel), stride=1, padding=pad)
|
470 |
+
|
471 |
+
if hmax.shape[2] > heat.shape[2]:
|
472 |
+
hmax = hmax[:, :, :heat.shape[2], :heat.shape[3]]
|
473 |
+
|
474 |
+
keep = (hmax == heat).float()
|
475 |
+
|
476 |
+
return heat * keep
|
477 |
+
|
478 |
+
def _sigmoid(x):
|
479 |
+
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
|
480 |
+
return y
|
481 |
+
|
482 |
+
|
483 |
+
|
484 |
+
if __name__ == "__main__":
|
485 |
+
Model()
|
models/multiHMR/multiHMR.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54bda6698dd3a11639d54c5ae71190817549232fa57e48072e0fa533ea52639c
|
3 |
+
size 1286462544
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
libglfw3-dev
|
2 |
+
libgles2-mesa-dev
|
3 |
+
freeglut3-dev
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
trimesh==3.22.3
|
3 |
+
pyrender==0.1.45
|
4 |
+
einops==0.6.1
|
5 |
+
roma
|
6 |
+
pillow==10.0.1
|
7 |
+
smplx
|
8 |
+
pyvista==0.42.3
|
9 |
+
numpy==1.22.4
|
10 |
+
pyglet==1.5.24
|
11 |
+
tqdm==4.65.0
|
12 |
+
#xformers==0.0.20 # does not work for CPU demo on HF
|
13 |
+
# for huggingface
|
14 |
+
gradio==4.18.0
|
15 |
+
spaces==0.19.4
|
utils/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .humans import (get_smplx_joint_names, rot6d_to_rotmat)
|
2 |
+
|
3 |
+
from .camera import (perspective_projection, get_focalLength_from_fieldOfView, inverse_perspective_projection,
|
4 |
+
undo_focal_length_normalization, undo_log_depth)
|
5 |
+
|
6 |
+
from .image import normalize_rgb, unpatch
|
7 |
+
|
8 |
+
from .render import render_meshes, print_distance_on_image, render_side_views, create_scene
|
9 |
+
|
10 |
+
from .tensor_manip import rebatch, pad, pad_to_max
|
11 |
+
|
12 |
+
from .color import demo_color
|
13 |
+
|
14 |
+
from .constants import SMPLX_DIR, MEAN_PARAMS, CACHE_DIR_MULTIHMR
|
15 |
+
|
utils/camera.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
|
9 |
+
OPENCV_TO_OPENGL_CAMERA_CONVENTION = np.array([[1, 0, 0, 0],
|
10 |
+
[0, -1, 0, 0],
|
11 |
+
[0, 0, -1, 0],
|
12 |
+
[0, 0, 0, 1]])
|
13 |
+
|
14 |
+
def perspective_projection(x, K):
|
15 |
+
"""
|
16 |
+
This function computes the perspective projection of a set of points assuming the extrinsinc params have already been applied
|
17 |
+
Args:
|
18 |
+
- x [bs,N,3]: 3D points
|
19 |
+
- K [bs,3,3]: Camera instrincs params
|
20 |
+
"""
|
21 |
+
# Apply perspective distortion
|
22 |
+
y = x / x[:, :, -1].unsqueeze(-1) # (bs, N, 3)
|
23 |
+
|
24 |
+
# Apply camera intrinsics
|
25 |
+
y = torch.einsum('bij,bkj->bki', K, y) # (bs, N, 3)
|
26 |
+
|
27 |
+
return y[:, :, :2]
|
28 |
+
|
29 |
+
|
30 |
+
def inverse_perspective_projection(points, K, distance):
|
31 |
+
"""
|
32 |
+
This function computes the inverse perspective projection of a set of points given an estimated distance.
|
33 |
+
Input:
|
34 |
+
points (bs, N, 2): 2D points
|
35 |
+
K (bs,3,3): camera intrinsics params
|
36 |
+
distance (bs, N, 1): distance in the 3D world
|
37 |
+
Similar to:
|
38 |
+
- pts_l_norm = cv2.undistortPoints(np.expand_dims(pts_l, axis=1), cameraMatrix=K_l, distCoeffs=None)
|
39 |
+
"""
|
40 |
+
# Apply camera intrinsics
|
41 |
+
points = torch.cat([points, torch.ones_like(points[..., :1])], -1)
|
42 |
+
points = torch.einsum('bij,bkj->bki', torch.inverse(K), points)
|
43 |
+
|
44 |
+
# Apply perspective distortion
|
45 |
+
if distance == None:
|
46 |
+
return points
|
47 |
+
points = points * distance
|
48 |
+
return points
|
49 |
+
|
50 |
+
def get_focalLength_from_fieldOfView(fov=60, img_size=512):
|
51 |
+
"""
|
52 |
+
Compute the focal length of the camera lens by assuming a certain FOV for the entire image
|
53 |
+
Args:
|
54 |
+
- fov: float, expressed in degree
|
55 |
+
- img_size: int
|
56 |
+
Return:
|
57 |
+
focal: float
|
58 |
+
"""
|
59 |
+
focal = img_size / (2 * np.tan(np.radians(fov) /2))
|
60 |
+
return focal
|
61 |
+
|
62 |
+
def undo_focal_length_normalization(y, f, fovn=60, img_size=448):
|
63 |
+
"""
|
64 |
+
Undo focal_length_normalization()
|
65 |
+
"""
|
66 |
+
fn = get_focalLength_from_fieldOfView(fov=fovn, img_size=img_size)
|
67 |
+
x = y * (f/fn)
|
68 |
+
return x
|
69 |
+
|
70 |
+
def undo_log_depth(y):
|
71 |
+
"""
|
72 |
+
Undo log_depth()
|
73 |
+
"""
|
74 |
+
return torch.exp(y)
|
75 |
+
|
utils/color.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def hex_to_rgb(hex):
|
8 |
+
y = tuple(int(hex[i:i+2], 16) for i in (0, 2, 4))
|
9 |
+
return (y[0]/255,y[1]/255,y[2]/255)
|
10 |
+
|
11 |
+
# Define colors for the demo
|
12 |
+
color = ['0047AB', # cobaltblue
|
13 |
+
'6495ED', # cornerblue
|
14 |
+
'FF9999', 'FF9933', '00CC66', '66B2FF', 'FF6666', 'FF3333', 'C0C0C0', '9933FF'] # rosé - orange - green - blue - red - grey - violet
|
15 |
+
color = [ hex_to_rgb(x) for x in color]
|
16 |
+
|
17 |
+
for i in range(200):
|
18 |
+
color_i = list(np.random.choice(range(256), size=3))
|
19 |
+
color.append((color_i[0]/225, color_i[1]/225, color_i[2]/225))
|
20 |
+
|
21 |
+
demo_color = color
|
22 |
+
|
utils/constants.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
SMPLX_DIR = 'models'
|
8 |
+
MEAN_PARAMS = 'models/smpl_mean_params.npz'
|
9 |
+
CACHE_DIR_MULTIHMR = 'models/multiHMR'
|
utils/download.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import sys
|
4 |
+
from urllib import request as urlrequest
|
5 |
+
|
6 |
+
from constants import CACHE_DIR_MULTIHMR, SMPLX_DIR, MEAN_PARAMS
|
7 |
+
|
8 |
+
def _progress_bar(count, total):
|
9 |
+
"""Report download progress. Credit:
|
10 |
+
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
|
11 |
+
"""
|
12 |
+
bar_len = 60
|
13 |
+
filled_len = int(round(bar_len * count / float(total)))
|
14 |
+
percents = round(100.0 * count / float(total), 1)
|
15 |
+
bar = "=" * filled_len + "-" * (bar_len - filled_len)
|
16 |
+
sys.stdout.write(
|
17 |
+
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
|
18 |
+
)
|
19 |
+
sys.stdout.flush()
|
20 |
+
if count >= total:
|
21 |
+
sys.stdout.write("\n")
|
22 |
+
|
23 |
+
|
24 |
+
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
|
25 |
+
"""Download url and write it to dst_file_path. Credit:
|
26 |
+
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
|
27 |
+
"""
|
28 |
+
# url = url + "?dl=1" if "dropbox" in url else url
|
29 |
+
req = urlrequest.Request(url)
|
30 |
+
response = urlrequest.urlopen(req)
|
31 |
+
total_size = response.info().get("Content-Length")
|
32 |
+
if total_size is None:
|
33 |
+
raise ValueError("Cannot determine size of download from {}".format(url))
|
34 |
+
total_size = int(total_size.strip())
|
35 |
+
bytes_so_far = 0
|
36 |
+
|
37 |
+
with open(dst_file_path, "wb") as f:
|
38 |
+
while 1:
|
39 |
+
chunk = response.read(chunk_size)
|
40 |
+
bytes_so_far += len(chunk)
|
41 |
+
if not chunk:
|
42 |
+
break
|
43 |
+
|
44 |
+
if progress_hook:
|
45 |
+
progress_hook(bytes_so_far, total_size)
|
46 |
+
|
47 |
+
f.write(chunk)
|
48 |
+
return bytes_so_far
|
49 |
+
|
50 |
+
|
51 |
+
def cache_url(url_or_file, cache_file_path, download=True):
|
52 |
+
"""Download the file specified by the URL to the cache_dir and return the path to
|
53 |
+
the cached file. If the argument is not a URL, simply return it as is.
|
54 |
+
"""
|
55 |
+
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
|
56 |
+
if not is_url:
|
57 |
+
return url_or_file
|
58 |
+
url = url_or_file
|
59 |
+
if os.path.exists(cache_file_path):
|
60 |
+
return cache_file_path
|
61 |
+
cache_file_dir = os.path.dirname(cache_file_path)
|
62 |
+
if not os.path.exists(cache_file_dir):
|
63 |
+
os.makedirs(cache_file_dir)
|
64 |
+
if download:
|
65 |
+
print("Downloading remote file {} to {}".format(url, cache_file_path))
|
66 |
+
download_url(url, cache_file_path)
|
67 |
+
return cache_file_path
|
68 |
+
|
69 |
+
|
70 |
+
def download_models(folder=CACHE_DIR_MULTIHMR):
|
71 |
+
"""Download checkpoints and files for running inference.
|
72 |
+
"""
|
73 |
+
import os
|
74 |
+
os.makedirs(folder, exist_ok=True)
|
75 |
+
download_files = {
|
76 |
+
"hmr2_data.tar.gz" : ["https://people.eecs.berkeley.edu/~jathushan/projects/4dhumans/hmr2_data.tar.gz", folder],
|
77 |
+
}
|
78 |
+
|
79 |
+
for file_name, url in download_files.items():
|
80 |
+
output_path = os.path.join(url[1], file_name)
|
81 |
+
if not os.path.exists(output_path):
|
82 |
+
print("Downloading file: " + file_name)
|
83 |
+
# output = gdown.cached_download(url[0], output_path, fuzzy=True)
|
84 |
+
output = cache_url(url[0], output_path)
|
85 |
+
assert os.path.exists(output_path), f"{output} does not exist"
|
86 |
+
|
87 |
+
# if ends with tar.gz, tar -xzf
|
88 |
+
if file_name.endswith(".tar.gz"):
|
89 |
+
print("Extracting file: " + file_name)
|
90 |
+
os.system("tar -xvf " + output_path + " -C " + url[1])
|
91 |
+
|
92 |
+
def check_smplx_exists():
|
93 |
+
import os
|
94 |
+
candidates = [
|
95 |
+
f'{SMPLX_DIR}/data/smplx/SMPLX_NEUTRAL.npz',
|
96 |
+
f'{MEAN_PARAMS}'
|
97 |
+
]
|
98 |
+
candidates_exist = [os.path.exists(c) for c in candidates]
|
99 |
+
if not any(candidates_exist):
|
100 |
+
raise FileNotFoundError(f"SMPLX model not found. Please download it from https://smplify.is.tue.mpg.de/ and place it at {candidates[1]}")
|
101 |
+
|
102 |
+
return True
|
103 |
+
|
utils/humans.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch
|
8 |
+
import roma
|
9 |
+
from smplx.joint_names import JOINT_NAMES
|
10 |
+
|
11 |
+
def rot6d_to_rotmat(x):
|
12 |
+
"""
|
13 |
+
6D rotation representation to 3x3 rotation matrix.
|
14 |
+
Args:
|
15 |
+
x: (B,6) Batch of 6-D rotation representations.
|
16 |
+
Returns:
|
17 |
+
torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
|
18 |
+
"""
|
19 |
+
x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
|
20 |
+
y = roma.special_gramschmidt(x)
|
21 |
+
return y
|
22 |
+
|
23 |
+
def get_smplx_joint_names(*args, **kwargs):
|
24 |
+
return JOINT_NAMES[:127]
|
utils/image.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
|
9 |
+
IMG_NORM_STD = [0.229, 0.224, 0.225]
|
10 |
+
|
11 |
+
|
12 |
+
def normalize_rgb(img, imagenet_normalization=True):
|
13 |
+
"""
|
14 |
+
Args:
|
15 |
+
- img: np.array - (W,H,3) - np.uint8 - 0/255
|
16 |
+
Return:
|
17 |
+
- img: np.array - (3,W,H) - np.float - -3/3
|
18 |
+
"""
|
19 |
+
img = img.astype(np.float32) / 255.
|
20 |
+
img = np.transpose(img, (2,0,1))
|
21 |
+
if imagenet_normalization:
|
22 |
+
img = (img - np.asarray(IMG_NORM_MEAN).reshape(3,1,1)) / np.asarray(IMG_NORM_STD).reshape(3,1,1)
|
23 |
+
img = img.astype(np.float32)
|
24 |
+
return img
|
25 |
+
|
26 |
+
def unpatch(data, patch_size=14, c=3, img_size=224):
|
27 |
+
# c = 3
|
28 |
+
if len(data.shape) == 2:
|
29 |
+
c=1
|
30 |
+
data = data[:,:,None].repeat([1,1,patch_size**2])
|
31 |
+
|
32 |
+
B,N,HWC = data.shape
|
33 |
+
HW = patch_size**2
|
34 |
+
c = int(HWC / HW)
|
35 |
+
h = w = int(N**.5)
|
36 |
+
p = q = int(HW**.5)
|
37 |
+
data = data.reshape([B,h,w,p,q,c])
|
38 |
+
data = torch.einsum('nhwpqc->nchpwq', data)
|
39 |
+
return data.reshape([B,c,img_size,img_size])
|
40 |
+
|
utils/render.py
ADDED
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import pyrender
|
8 |
+
import trimesh
|
9 |
+
import math
|
10 |
+
from scipy.spatial.transform import Rotation
|
11 |
+
from PIL import ImageFont, ImageDraw, Image
|
12 |
+
|
13 |
+
OPENCV_TO_OPENGL_CAMERA_CONVENTION = np.array([[1, 0, 0, 0],
|
14 |
+
[0, -1, 0, 0],
|
15 |
+
[0, 0, -1, 0],
|
16 |
+
[0, 0, 0, 1]])
|
17 |
+
|
18 |
+
def geotrf( Trf, pts, ncol=None, norm=False):
|
19 |
+
""" Apply a geometric transformation to a list of 3-D points.
|
20 |
+
H: 3x3 or 4x4 projection matrix (typically a Homography)
|
21 |
+
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
|
22 |
+
|
23 |
+
ncol: int. number of columns of the result (2 or 3)
|
24 |
+
norm: float. if != 0, the resut is projected on the z=norm plane.
|
25 |
+
|
26 |
+
Returns an array of projected 2d points.
|
27 |
+
"""
|
28 |
+
assert Trf.ndim in (2,3)
|
29 |
+
if isinstance(Trf, np.ndarray):
|
30 |
+
pts = np.asarray(pts)
|
31 |
+
elif isinstance(Trf, torch.Tensor):
|
32 |
+
pts = torch.as_tensor(pts, dtype=Trf.dtype)
|
33 |
+
|
34 |
+
ncol = ncol or pts.shape[-1]
|
35 |
+
|
36 |
+
# adapt shape if necessary
|
37 |
+
output_reshape = pts.shape[:-1]
|
38 |
+
if Trf.ndim == 3:
|
39 |
+
assert len(Trf) == len(pts), 'batch size does not match'
|
40 |
+
if Trf.ndim == 3 and pts.ndim > 3:
|
41 |
+
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
|
42 |
+
pts = pts.reshape(pts.shape[0], -1, pts.shape[-1])
|
43 |
+
elif Trf.ndim == 3 and pts.ndim == 2:
|
44 |
+
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
|
45 |
+
pts = pts[:, None, :]
|
46 |
+
|
47 |
+
if pts.shape[-1]+1 == Trf.shape[-1]:
|
48 |
+
Trf = Trf.swapaxes(-1,-2) # transpose Trf
|
49 |
+
pts = pts @ Trf[...,:-1,:] + Trf[...,-1:,:]
|
50 |
+
elif pts.shape[-1] == Trf.shape[-1]:
|
51 |
+
Trf = Trf.swapaxes(-1,-2) # transpose Trf
|
52 |
+
pts = pts @ Trf
|
53 |
+
else:
|
54 |
+
pts = Trf @ pts.T
|
55 |
+
if pts.ndim >= 2: pts = pts.swapaxes(-1,-2)
|
56 |
+
if norm:
|
57 |
+
pts = pts / pts[...,-1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
|
58 |
+
if norm != 1: pts *= norm
|
59 |
+
|
60 |
+
return pts[...,:ncol].reshape(*output_reshape, ncol)
|
61 |
+
|
62 |
+
def create_scene(img_pil, l_mesh, l_face, color=None, metallicFactor=0., roughnessFactor=0.5, focal=600):
|
63 |
+
|
64 |
+
scene = trimesh.Scene(
|
65 |
+
lights=trimesh.scene.lighting.Light(intensity=3.0)
|
66 |
+
)
|
67 |
+
|
68 |
+
# Human meshes
|
69 |
+
for i, mesh in enumerate(l_mesh):
|
70 |
+
if color is None:
|
71 |
+
_color = (np.random.choice(range(1,225))/255, np.random.choice(range(1,225))/255, np.random.choice(range(1,225))/255)
|
72 |
+
else:
|
73 |
+
if isinstance(color,list):
|
74 |
+
_color = color[i]
|
75 |
+
elif isinstance(color,tuple):
|
76 |
+
_color = color
|
77 |
+
else:
|
78 |
+
raise NotImplementedError
|
79 |
+
mesh = trimesh.Trimesh(mesh, l_face[i])
|
80 |
+
mesh.visual = trimesh.visual.TextureVisuals(
|
81 |
+
uv=None,
|
82 |
+
material=trimesh.visual.material.PBRMaterial(
|
83 |
+
metallicFactor=metallicFactor,
|
84 |
+
roughnessFactor=roughnessFactor,
|
85 |
+
alphaMode='OPAQUE',
|
86 |
+
baseColorFactor=(_color[0], _color[1], _color[2], 1.0)
|
87 |
+
),
|
88 |
+
image=None,
|
89 |
+
face_materials=None
|
90 |
+
)
|
91 |
+
scene.add_geometry(mesh)
|
92 |
+
|
93 |
+
# Image
|
94 |
+
H, W = img_pil.size[0], img_pil.size[1]
|
95 |
+
screen_width = 0.3
|
96 |
+
height = focal * screen_width / H
|
97 |
+
width = screen_width * 0.5**0.5
|
98 |
+
rot45 = np.eye(4)
|
99 |
+
rot45[:3,:3] = Rotation.from_euler('z',np.deg2rad(45)).as_matrix()
|
100 |
+
rot45[2,3] = -height # set the tip of the cone = optical center
|
101 |
+
aspect_ratio = np.eye(4)
|
102 |
+
aspect_ratio[0,0] = W/H
|
103 |
+
transform = OPENCV_TO_OPENGL_CAMERA_CONVENTION @ aspect_ratio @ rot45
|
104 |
+
cam = trimesh.creation.cone(width, height, sections=4, transform=transform)
|
105 |
+
# cam.apply_transform(transform)
|
106 |
+
# import ipdb
|
107 |
+
# ipdb.set_trace()
|
108 |
+
|
109 |
+
# vertices = geotrf(transform, cam.vertices[[4,5,1,3]])
|
110 |
+
vertices = cam.vertices[[4,5,1,3]]
|
111 |
+
faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
|
112 |
+
img = trimesh.Trimesh(vertices=vertices, faces=faces)
|
113 |
+
uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
|
114 |
+
# img_pil = Image.fromarray((255. * np.ones((20,20,3))).astype(np.uint8)) # white only!
|
115 |
+
material = trimesh.visual.texture.SimpleMaterial(image=img_pil,
|
116 |
+
diffuse=[255,255,255,0],
|
117 |
+
ambient=[255,255,255,0],
|
118 |
+
specular=[255,255,255,0],
|
119 |
+
glossiness=1.0)
|
120 |
+
img.visual = trimesh.visual.TextureVisuals(uv=uv_coords, image=img_pil) #, material=material)
|
121 |
+
# _main_color = [255,255,255,0]
|
122 |
+
# print(img.visual.material.ambient)
|
123 |
+
# print(img.visual.material.diffuse)
|
124 |
+
# print(img.visual.material.specular)
|
125 |
+
# print(img.visual.material.main_color)
|
126 |
+
|
127 |
+
# img.visual.material.ambient = _main_color
|
128 |
+
# img.visual.material.diffuse = _main_color
|
129 |
+
# img.visual.material.specular = _main_color
|
130 |
+
|
131 |
+
# img.visual.material.main_color = _main_color
|
132 |
+
# img.visual.material.glossiness = _main_color
|
133 |
+
scene.add_geometry(img)
|
134 |
+
|
135 |
+
# this is the camera mesh
|
136 |
+
rot2 = np.eye(4)
|
137 |
+
rot2[:3,:3] = Rotation.from_euler('z',np.deg2rad(2)).as_matrix()
|
138 |
+
# import ipdb
|
139 |
+
# ipdb.set_trace()
|
140 |
+
# vertices = cam.vertices
|
141 |
+
# print(rot2)
|
142 |
+
vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
|
143 |
+
# vertices = np.r_[cam.vertices, 0.95*cam.vertices, 1.05*cam.vertices]
|
144 |
+
faces = []
|
145 |
+
for face in cam.faces:
|
146 |
+
if 0 in face: continue
|
147 |
+
a,b,c = face
|
148 |
+
a2,b2,c2 = face + len(cam.vertices)
|
149 |
+
a3,b3,c3 = face + 2*len(cam.vertices)
|
150 |
+
|
151 |
+
# add 3 pseudo-edges
|
152 |
+
faces.append((a,b,b2))
|
153 |
+
faces.append((a,a2,c))
|
154 |
+
faces.append((c2,b,c))
|
155 |
+
|
156 |
+
faces.append((a,b,b3))
|
157 |
+
faces.append((a,a3,c))
|
158 |
+
faces.append((c3,b,c))
|
159 |
+
|
160 |
+
# no culling
|
161 |
+
faces += [(c,b,a) for a,b,c in faces]
|
162 |
+
|
163 |
+
cam = trimesh.Trimesh(vertices=vertices, faces=faces)
|
164 |
+
cam.visual.face_colors[:,:3] = (255, 0, 0)
|
165 |
+
scene.add_geometry(cam)
|
166 |
+
|
167 |
+
# OpenCV to OpenGL
|
168 |
+
rot = np.eye(4)
|
169 |
+
cams2world = np.eye(4)
|
170 |
+
rot[:3,:3] = Rotation.from_euler('y',np.deg2rad(180)).as_matrix()
|
171 |
+
scene.apply_transform(np.linalg.inv(cams2world @ OPENCV_TO_OPENGL_CAMERA_CONVENTION @ rot))
|
172 |
+
|
173 |
+
return scene
|
174 |
+
|
175 |
+
def render_meshes(img, l_mesh, l_face, cam_param, color=None, alpha=1.0,
|
176 |
+
show_camera=False,
|
177 |
+
intensity=3.0,
|
178 |
+
metallicFactor=0., roughnessFactor=0.5, smooth=True,
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Rendering multiple mesh and project then in the initial image.
|
182 |
+
Args:
|
183 |
+
- img: np.array [w,h,3]
|
184 |
+
- l_mesh: np.array list of [v,3]
|
185 |
+
- l_face: np.array list of [f,3]
|
186 |
+
- cam_param: info about the camera intrinsics (focal, princpt) and (R,t) is possible
|
187 |
+
Return:
|
188 |
+
- img: np.array [w,h,3]
|
189 |
+
"""
|
190 |
+
# scene
|
191 |
+
scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
|
192 |
+
|
193 |
+
# mesh
|
194 |
+
for i, mesh in enumerate(l_mesh):
|
195 |
+
if color is None:
|
196 |
+
_color = (np.random.choice(range(1,225))/255, np.random.choice(range(1,225))/255, np.random.choice(range(1,225))/255)
|
197 |
+
else:
|
198 |
+
if isinstance(color,list):
|
199 |
+
_color = color[i]
|
200 |
+
elif isinstance(color,tuple):
|
201 |
+
_color = color
|
202 |
+
else:
|
203 |
+
raise NotImplementedError
|
204 |
+
mesh = trimesh.Trimesh(mesh, l_face[i])
|
205 |
+
|
206 |
+
# import ipdb
|
207 |
+
# ipdb.set_trace()
|
208 |
+
|
209 |
+
# mesh.visual = trimesh.visual.TextureVisuals(
|
210 |
+
# uv=None,
|
211 |
+
# material=trimesh.visual.material.PBRMaterial(
|
212 |
+
# metallicFactor=metallicFactor,
|
213 |
+
# roughnessFactor=roughnessFactor,
|
214 |
+
# alphaMode='OPAQUE',
|
215 |
+
# baseColorFactor=(_color[0], _color[1], _color[2], 1.0)
|
216 |
+
# ),
|
217 |
+
# image=None,
|
218 |
+
# face_materials=None
|
219 |
+
# )
|
220 |
+
# print('saving')
|
221 |
+
# mesh.export('human.obj')
|
222 |
+
# mesh = trimesh.load('human.obj')
|
223 |
+
# print('loading')
|
224 |
+
# mesh = pyrender.Mesh.from_trimesh(mesh, smooth=smooth)
|
225 |
+
|
226 |
+
material = pyrender.MetallicRoughnessMaterial(
|
227 |
+
metallicFactor=metallicFactor,
|
228 |
+
roughnessFactor=roughnessFactor,
|
229 |
+
alphaMode='OPAQUE',
|
230 |
+
baseColorFactor=(_color[0], _color[1], _color[2], 1.0))
|
231 |
+
mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=smooth)
|
232 |
+
scene.add(mesh, f"mesh_{i}")
|
233 |
+
|
234 |
+
# Adding coordinate system at (0,0,2) for the moment
|
235 |
+
# Using lines defined in pyramid https://docs.pyvista.org/version/stable/api/utilities/_autosummary/pyvista.Pyramid.html
|
236 |
+
if show_camera:
|
237 |
+
import pyvista
|
238 |
+
|
239 |
+
def get_faces(x):
|
240 |
+
return x.faces.astype(np.uint32).reshape((x.n_faces, 4))[:, 1:]
|
241 |
+
|
242 |
+
# Camera = Box + Cone (or Cylinder?)
|
243 |
+
material_cam = pyrender.MetallicRoughnessMaterial(metallicFactor=metallicFactor, roughnessFactor=roughnessFactor, alphaMode='OPAQUE', baseColorFactor=(0.5,0.5,0.5))
|
244 |
+
height = 0.2
|
245 |
+
radius = 0.1
|
246 |
+
cone = pyvista.Cone(center=(0.0, 0.0, -height/2), direction=(0.0, 0.0, -1.0), height=height, radius=radius).extract_surface().triangulate()
|
247 |
+
verts = cone.points
|
248 |
+
mesh = pyrender.Mesh.from_trimesh(trimesh.Trimesh(verts, get_faces(cone)), material=material_cam, smooth=smooth)
|
249 |
+
scene.add(mesh, f"cone")
|
250 |
+
|
251 |
+
size = 0.1
|
252 |
+
box = pyvista.Box(bounds=(-size, size,
|
253 |
+
-size, size,
|
254 |
+
verts[:,-1].min() - 3*size, verts[:,-1].min())).extract_surface().triangulate()
|
255 |
+
verts = box.points
|
256 |
+
mesh = pyrender.Mesh.from_trimesh(trimesh.Trimesh(verts, get_faces(box)), material=material_cam, smooth=smooth)
|
257 |
+
scene.add(mesh, f"box")
|
258 |
+
|
259 |
+
|
260 |
+
# Coordinate system
|
261 |
+
# https://docs.pyvista.org/version/stable/api/utilities/_autosummary/pyvista.Arrow.html
|
262 |
+
l_color = [(1,0,0,1.0), (0,1,0,1.0), (0,0,1,1.0)]
|
263 |
+
l_direction = [(1,0,0), (0,1,0), (0,0,1)]
|
264 |
+
scale = 0.2
|
265 |
+
pose3d = [2*scale, 0.0, -scale]
|
266 |
+
for i in range(len(l_color)):
|
267 |
+
arrow = pyvista.Arrow(direction=l_direction[i], scale=scale)
|
268 |
+
arrow = arrow.extract_surface().triangulate()
|
269 |
+
verts = arrow.points + np.asarray([pose3d])
|
270 |
+
faces = arrow.faces.astype(np.uint32).reshape((arrow.n_faces, 4))[:, 1:]
|
271 |
+
mesh = trimesh.Trimesh(verts, faces)
|
272 |
+
material = pyrender.MetallicRoughnessMaterial(metallicFactor=metallicFactor, roughnessFactor=roughnessFactor, alphaMode='OPAQUE', baseColorFactor=l_color[i])
|
273 |
+
mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=smooth)
|
274 |
+
scene.add(mesh, f"arrow_{i}")
|
275 |
+
|
276 |
+
focal, princpt = cam_param['focal'], cam_param['princpt']
|
277 |
+
camera_pose = np.eye(4)
|
278 |
+
if 'R' in cam_param.keys():
|
279 |
+
camera_pose[:3, :3] = cam_param['R']
|
280 |
+
if 't' in cam_param.keys():
|
281 |
+
camera_pose[:3, 3] = cam_param['t']
|
282 |
+
camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
|
283 |
+
|
284 |
+
# camera
|
285 |
+
camera_pose = OPENCV_TO_OPENGL_CAMERA_CONVENTION @ camera_pose
|
286 |
+
camera_pose = np.linalg.inv(camera_pose)
|
287 |
+
scene.add(camera, pose=camera_pose)
|
288 |
+
|
289 |
+
# renderer
|
290 |
+
renderer = pyrender.OffscreenRenderer(viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)
|
291 |
+
|
292 |
+
# light
|
293 |
+
light = pyrender.DirectionalLight(intensity=intensity)
|
294 |
+
scene.add(light, pose=camera_pose)
|
295 |
+
|
296 |
+
# render
|
297 |
+
rgb, depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
|
298 |
+
rgb = rgb[:,:,:3].astype(np.float32)
|
299 |
+
fg = (depth > 0)[:,:,None].astype(np.float32)
|
300 |
+
|
301 |
+
# Simple smoothing of the mask
|
302 |
+
bg_blending_radius = 1
|
303 |
+
bg_blending_kernel = 2.0 * torch.ones((1, 1, 2 * bg_blending_radius + 1, 2 * bg_blending_radius + 1)) / (2 * bg_blending_radius + 1) ** 2
|
304 |
+
bg_blending_bias = -torch.ones(1)
|
305 |
+
fg = fg.reshape((fg.shape[0],fg.shape[1]))
|
306 |
+
fg = torch.from_numpy(fg).unsqueeze(0)
|
307 |
+
fg = torch.clamp_min(torch.nn.functional.conv2d(fg, weight=bg_blending_kernel, bias=bg_blending_bias, padding=bg_blending_radius) * fg, 0.0)
|
308 |
+
fg = fg.permute(1,2,0).numpy()
|
309 |
+
|
310 |
+
# Alpha-blending
|
311 |
+
img = (fg * (alpha * rgb + (1.0-alpha) * img) + (1-fg) * img).astype(np.uint8)
|
312 |
+
|
313 |
+
renderer.delete()
|
314 |
+
|
315 |
+
return img.astype(np.uint8)
|
316 |
+
|
317 |
+
def length(v):
|
318 |
+
return math.sqrt(v[0]*v[0]+v[1]*v[1]+v[2]*v[2])
|
319 |
+
|
320 |
+
def cross(v0, v1):
|
321 |
+
return [
|
322 |
+
v0[1]*v1[2]-v1[1]*v0[2],
|
323 |
+
v0[2]*v1[0]-v1[2]*v0[0],
|
324 |
+
v0[0]*v1[1]-v1[0]*v0[1]]
|
325 |
+
|
326 |
+
def dot(v0, v1):
|
327 |
+
return v0[0]*v1[0]+v0[1]*v1[1]+v0[2]*v1[2]
|
328 |
+
|
329 |
+
def normalize(v, eps=1e-13):
|
330 |
+
l = length(v)
|
331 |
+
return [v[0]/(l+eps), v[1]/(l+eps), v[2]/(l+eps)]
|
332 |
+
|
333 |
+
def lookAt(eye, target, *args, **kwargs):
|
334 |
+
"""
|
335 |
+
eye is the point of view, target is the point which is looked at and up is the upwards direction.
|
336 |
+
|
337 |
+
Input should be in OpenCV format - we transform arguments to OpenGL
|
338 |
+
Do compute in OpenGL and then transform back to OpenCV
|
339 |
+
|
340 |
+
"""
|
341 |
+
# Transform from OpenCV to OpenGL format
|
342 |
+
# eye = [eye[0], -eye[1], -eye[2]]
|
343 |
+
# target = [target[0], -target[1], -target[2]]
|
344 |
+
up = [0,-1,0]
|
345 |
+
|
346 |
+
eye, at, up = eye, target, up
|
347 |
+
zaxis = normalize((at[0]-eye[0], at[1]-eye[1], at[2]-eye[2]))
|
348 |
+
xaxis = normalize(cross(zaxis, up))
|
349 |
+
yaxis = cross(xaxis, zaxis)
|
350 |
+
|
351 |
+
zaxis = [-zaxis[0],-zaxis[1],-zaxis[2]]
|
352 |
+
|
353 |
+
viewMatrix = np.asarray([
|
354 |
+
[xaxis[0], xaxis[1], xaxis[2], -dot(xaxis, eye)],
|
355 |
+
[yaxis[0], yaxis[1], yaxis[2], -dot(yaxis, eye)],
|
356 |
+
[zaxis[0], zaxis[1], zaxis[2], -dot(zaxis, eye)],
|
357 |
+
[0, 0, 0, 1]]
|
358 |
+
).reshape(4,4)
|
359 |
+
|
360 |
+
# OpenGL to OpenCV
|
361 |
+
viewMatrix = OPENCV_TO_OPENGL_CAMERA_CONVENTION @ viewMatrix
|
362 |
+
|
363 |
+
return viewMatrix
|
364 |
+
|
365 |
+
def print_distance_on_image(pred_rend_array, humans, _color):
|
366 |
+
# Add distance to the image.
|
367 |
+
font = ImageFont.load_default()
|
368 |
+
rend_pil = Image.fromarray(pred_rend_array)
|
369 |
+
draw = ImageDraw.Draw(rend_pil)
|
370 |
+
for i_hum, hum in enumerate(humans):
|
371 |
+
# distance
|
372 |
+
transl = hum['transl_pelvis'].cpu().numpy().reshape(3)
|
373 |
+
dist_cam = np.sqrt(((transl[[0,2]])**2).sum()) # discarding Y axis
|
374 |
+
# 2d - bbox
|
375 |
+
bbox = get_bbox(hum['j2d_smplx'].cpu().numpy(), factor=1.35, output_format='x1y1x2y2')
|
376 |
+
loc = [(bbox[0] + bbox[2]) / 2., bbox[1]]
|
377 |
+
txt = f"{dist_cam:.2f}m"
|
378 |
+
length = font.getlength(txt)
|
379 |
+
loc[0] = loc[0] - length // 2
|
380 |
+
fill = tuple((np.asarray(_color[i_hum]) * 255).astype(np.int32).tolist())
|
381 |
+
draw.text((loc[0], loc[1]), txt, fill=fill, font=font)
|
382 |
+
return np.asarray(rend_pil)
|
383 |
+
|
384 |
+
def get_bbox(points, factor=1., output_format='xywh'):
|
385 |
+
"""
|
386 |
+
Args:
|
387 |
+
- y: [k,2]
|
388 |
+
Return:
|
389 |
+
- bbox: [4] in a specific format
|
390 |
+
"""
|
391 |
+
assert len(points.shape) == 2, f"Wrong shape, expected two-dimensional array. Got shape {points.shape}"
|
392 |
+
assert points.shape[1] == 2
|
393 |
+
x1, x2 = points[:,0].min(), points[:,0].max()
|
394 |
+
y1, y2 = points[:,1].min(), points[:,1].max()
|
395 |
+
cx, cy = (x2 + x1) / 2., (y2 + y1) / 2.
|
396 |
+
sx, sy = np.abs(x2 - x1), np.abs(y2 - y1)
|
397 |
+
sx, sy = int(factor * sx), int(factor * sy)
|
398 |
+
x1, y1 = int(cx - sx / 2.), int(cy - sy / 2.)
|
399 |
+
x2, y2 = int(cx + sx / 2.), int(cy + sy / 2.)
|
400 |
+
if output_format == 'xywh':
|
401 |
+
return [x1,y1,sx,sy]
|
402 |
+
elif output_format == 'x1y1x2y2':
|
403 |
+
return [x1,y1,x2,y2]
|
404 |
+
else:
|
405 |
+
raise NotImplementedError
|
406 |
+
|
407 |
+
def render_side_views(img_array, _color, humans, model, K):
|
408 |
+
_bg = 255. # white
|
409 |
+
|
410 |
+
# camera
|
411 |
+
focal = np.asarray([K[0,0,0].cpu().numpy(),K[0,1,1].cpu().numpy()])
|
412 |
+
princpt = np.asarray([K[0,0,-1].cpu().numpy(),K[0,1,-1].cpu().numpy()])
|
413 |
+
|
414 |
+
# Get the vertices produced by the model.
|
415 |
+
l_verts = [humans[j]['verts_smplx'].cpu().numpy() for j in range(len(humans))]
|
416 |
+
l_faces = [model.smpl_layer['neutral'].bm_x.faces for j in range(len(humans))]
|
417 |
+
|
418 |
+
bg_array = 1 + 0.*img_array.copy()
|
419 |
+
if len(humans) == 0:
|
420 |
+
pred_rend_array_bis = _bg * bg_array.copy()
|
421 |
+
pred_rend_array_sideview = _bg * bg_array.copy()
|
422 |
+
pred_rend_array_bev = _bg * bg_array.copy()
|
423 |
+
else:
|
424 |
+
# Small displacement
|
425 |
+
H_bis = lookAt(eye=[2.,-1,-2], target=[0,0,3])
|
426 |
+
pred_rend_array_bis = render_meshes(_bg* bg_array.copy(), l_verts, l_faces,
|
427 |
+
{'focal': focal, 'princpt': princpt, 'R': H_bis[:3,:3], 't': H_bis[:3,3]},
|
428 |
+
alpha=1.0, color=_color, show_camera=True)
|
429 |
+
|
430 |
+
# Where to look at
|
431 |
+
l_z = []
|
432 |
+
for hum in humans:
|
433 |
+
l_z.append(hum['transl_pelvis'].cpu().numpy().reshape(-1)[-1])
|
434 |
+
target_z = np.median(np.asarray(l_z))
|
435 |
+
|
436 |
+
# Sideview
|
437 |
+
H_sideview = lookAt(eye=[2.2*target_z,0,target_z], target=[0,0,target_z])
|
438 |
+
pred_rend_array_sideview = render_meshes(_bg*bg_array.copy(), l_verts, l_faces,
|
439 |
+
{'focal': focal, 'princpt': princpt, 'R': H_sideview[:3,:3], 't': H_sideview[:3,3]},
|
440 |
+
alpha=1.0, color=_color, show_camera=True)
|
441 |
+
|
442 |
+
# Bird-Eye-View
|
443 |
+
H_bev = lookAt(eye=[0.,-2*target_z,target_z-0.001], target=[0,0,target_z])
|
444 |
+
pred_rend_array_bev = render_meshes(_bg* bg_array.copy(), l_verts, l_faces,
|
445 |
+
{'focal': focal, 'princpt': princpt, 'R': H_bev[:3,:3], 't': H_bev[:3,3]},
|
446 |
+
alpha=1.0, color=_color, show_camera=True)
|
447 |
+
|
448 |
+
return pred_rend_array_bis, pred_rend_array_sideview, pred_rend_array_bev
|
utils/tensor_manip.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Multi-HMR
|
2 |
+
# Copyright (c) 2024-present NAVER Corp.
|
3 |
+
# CC BY-NC-SA 4.0 license
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
def rebatch(idx_0, idx_det):
|
8 |
+
# Rebuild the batch dimension : (N, ...) is turned into (batch_dim, nb_max, ...)
|
9 |
+
# with zero padding for batch elements with fewer people.
|
10 |
+
values, counts = torch.unique(idx_0, sorted=True, return_counts=True)
|
11 |
+
#print(idx_0)
|
12 |
+
if not len(values) == values.max() + 1:
|
13 |
+
# Abnormal jumps in the idx_0: some images in the batch did not produce any inputs.
|
14 |
+
jumps = (values - torch.concat([torch.Tensor([-1]).to(values.device), values])[:-1]) - 1
|
15 |
+
offsets = torch.cumsum(jumps.int(), dim=0)
|
16 |
+
|
17 |
+
# Correcting idx_0 to account for missing batch elements
|
18 |
+
# This is actually wrong: in the case where we have 2 consecutive images without ppl, this will fail.
|
19 |
+
# But two consecutive jumps has proba so close to 0 that I consider it 'impossible'.
|
20 |
+
offsets = [c * [o] for o, c in [(offsets[i], counts[i]) for i in range(offsets.shape[0])]]
|
21 |
+
offsets = torch.Tensor([e for o in offsets for e in o]).to(jumps.device).int()
|
22 |
+
idx_0 = idx_0 - offsets
|
23 |
+
idx_det_0 = idx_det[0] - offsets
|
24 |
+
else:
|
25 |
+
idx_det_0 = idx_det[0]
|
26 |
+
return counts, idx_det_0
|
27 |
+
|
28 |
+
def pad(x, padlen, dim):
|
29 |
+
assert x.shape[dim] <= padlen, "Incoherent dimensions"
|
30 |
+
if not dim == 1:
|
31 |
+
raise NotImplementedError("Not implemented for this dim.")
|
32 |
+
padded = torch.concat([x, x.new_zeros((x.shape[0], padlen - x.shape[dim],) + x.shape[2:])], dim=dim)
|
33 |
+
mask = torch.concat([x.new_ones((x.shape[0], x.shape[dim])), x.new_zeros((x.shape[0], padlen - x.shape[dim]))], dim=dim)
|
34 |
+
return padded, mask
|
35 |
+
|
36 |
+
def pad_to_max(x_central, counts):
|
37 |
+
"""Pad so that each batch images has the same number of x_central queries.
|
38 |
+
Mask is used in attention to remove the fact queries. """
|
39 |
+
max_count = counts.max()
|
40 |
+
xlist = torch.split(x_central, tuple(counts), dim=0)
|
41 |
+
xlist2 = [x.unsqueeze(0) for x in xlist]
|
42 |
+
xlist3 = [pad(x, max_count, dim=1) for x in xlist2]
|
43 |
+
xlist4, mask = [x[0] for x in xlist3], [x[1] for x in xlist3]
|
44 |
+
x_central, mask = torch.concat(xlist4, dim=0), torch.concat(mask, dim=0)
|
45 |
+
return x_central, mask
|