Spaces · Commit 2d5f249 · Parent(s): done

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +37 -0
- .gitignore +15 -0
- LICENSE +53 -0
- README.md +12 -0
- app.py +129 -0
- apps/ICON.py +765 -0
- apps/Normal.py +213 -0
- apps/infer.py +467 -0
- assets/garment_teaser.png +3 -0
- assets/intermediate_results.png +3 -0
- assets/teaser.gif +3 -0
- assets/thumbnail.png +3 -0
- configs/icon-filter.yaml +25 -0
- configs/icon-nofilter.yaml +25 -0
- configs/pamir.yaml +24 -0
- configs/pifu.yaml +24 -0
- examples/22097467bffc92d4a5c4246f7d4edb75.png +3 -0
- examples/44c0f84c957b6b9bdf77662af5bb7078.png +3 -0
- examples/5a6a25963db2f667441d5076972c207c.png +3 -0
- examples/8da7ceb94669c2f65cbd28022e1f9876.png +3 -0
- examples/923d65f767c85a42212cae13fba3750b.png +3 -0
- examples/959c4c726a69901ce71b93a9242ed900.png +3 -0
- examples/c9856a2bc31846d684cbb965457fad59.png +3 -0
- examples/e1e7622af7074a022f5d96dc16672517.png +3 -0
- examples/fb9d20fdb93750584390599478ecf86e.png +3 -0
- examples/slack_trial2-000150.png +3 -0
- lib/__init__.py +0 -0
- lib/common/__init__.py +0 -0
- lib/common/cloth_extraction.py +170 -0
- lib/common/config.py +218 -0
- lib/common/render.py +388 -0
- lib/common/render_utils.py +221 -0
- lib/common/seg3d_lossless.py +604 -0
- lib/common/seg3d_utils.py +392 -0
- lib/common/smpl_vert_segmentation.json +0 -0
- lib/common/train_util.py +597 -0
- lib/dataset/Evaluator.py +264 -0
- lib/dataset/NormalDataset.py +212 -0
- lib/dataset/NormalModule.py +94 -0
- lib/dataset/PIFuDataModule.py +71 -0
- lib/dataset/PIFuDataset.py +589 -0
- lib/dataset/TestDataset.py +256 -0
- lib/dataset/__init__.py +0 -0
- lib/dataset/body_model.py +494 -0
- lib/dataset/hoppeMesh.py +116 -0
- lib/dataset/mesh_util.py +894 -0
- lib/dataset/tbfo.ttf +0 -0
- lib/net/BasePIFuNet.py +84 -0
- lib/net/FBNet.py +387 -0
- lib/net/HGFilters.py +197 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.obj filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.glb filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
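Every pattern above routes matching files through Git LFS, which is why the binary assets in this commit (the example PNGs, the teaser GIF, and so on) show up as +3 lines each: what is committed is a three-line LFS pointer file, not the binary itself. The snippet below is an illustrative sketch only, not part of the commit; real .gitattributes matching follows Git's own pathspec rules rather than Python's fnmatch.

# Illustrative only (not in the commit): approximate the simple "*.ext" LFS
# rules above with Python's fnmatch. Git's pathspec matching differs for
# patterns containing slashes, e.g. saved_model/**/*.
import fnmatch

lfs_patterns = ["*.png", "*.gif", "*.glb", "*.obj", "*.mp4", "*.pth", "*.pt"]

def lfs_tracked(path):
    # Match against the basename, as the extension-only rules above do.
    name = path.rsplit("/", 1)[-1]
    return any(fnmatch.fnmatch(name, pattern) for pattern in lfs_patterns)

print(lfs_tracked("assets/teaser.gif"))                  # True  -> stored as an LFS pointer
print(lfs_tracked("examples/slack_trial2-000150.png"))   # True
print(lfs_tracked("apps/ICON.py"))                       # False -> stored as plain text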
.gitignore
ADDED
@@ -0,0 +1,15 @@
data/*/*
data/thuman*
!data/tbfo.ttf
__pycache__
debug/
log/
.vscode
!.gitignore
force_push.sh
.idea
human_det/
kaolin/
neural_voxelization_layer/
pytorch3d/
force_push.sh
LICENSE
ADDED
@@ -0,0 +1,53 @@
License

Software Copyright License for non-commercial scientific research purposes
Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the ICON model, data and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License

Ownership / Licensees
The Software and the associated materials has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI"). Any copyright or patent right is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter "MPG"; MPI and MPG hereinafter collectively "Max-Planck") hereinafter the "Licensor".

License Grant
Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:

• To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
• To use the Model & Software for the sole purpose of performing peaceful non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
• To modify, adapt, translate or create derivative works based upon the Model & Software.

Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck's prior written permission.

The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.

No Distribution
The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.

Disclaimer of Representations and Warranties
You expressly acknowledge and agree that the Data & Software results from basic research, is provided "AS IS", may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.

Limitation of Liability
Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor's legal representatives or assistants in performance. Any further liability shall be excluded.
Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.

No Maintenance Services
You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.

Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.

Publications using the Model & Software
You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.

Citation:

@inproceedings{xiu2022icon,
  title={{ICON}: {I}mplicit {C}lothed humans {O}btained from {N}ormals},
  author={Xiu, Yuliang and Yang, Jinlong and Tzionas, Dimitrios and Black, Michael J.},
  booktitle={IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)},
  month = jun,
  year={2022}
}

Commercial licensing opportunities
For commercial uses of the Model & Software, please send email to [email protected]

This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
README.md
ADDED
@@ -0,0 +1,12 @@
---
title: ICON - Clothed Human Digitization
metaTitle: "Making yourself an ICON, by Yuliang Xiu"
emoji: 🤼
colorFrom: indigo
colorTo: yellow
sdk: gradio
sdk_version: 3.1.1
app_file: app.py
pinned: true
python_version: 3.8.13
---
app.py
ADDED
@@ -0,0 +1,129 @@
# install


import glob
import gradio as gr
import os, random

import subprocess

if os.getenv('SYSTEM') == 'spaces':
    subprocess.run('pip install pyembree'.split())
    subprocess.run('pip install rembg'.split())
    subprocess.run(
        'pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html'.split())
    subprocess.run(
        'pip install git+https://github.com/YuliangXiu/kaolin.git'.split())
    # subprocess.run('pip install https://download.is.tue.mpg.de/icon/kaolin-0.11.0-cp38-cp38-linux_x86_64.whl'.split())
    subprocess.run('pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html'.split())
    subprocess.run(
        'pip install git+https://github.com/Project-Splinter/human_det.git'.split())
    subprocess.run(
        'pip install git+https://github.com/YuliangXiu/neural_voxelization_layer.git'.split())

from apps.infer import generate_model

# running

description = '''
# ICON Clothed Human Digitization
### ICON: Implicit Clothed humans Obtained from Normals (CVPR 2022)

<table style="width:26%; padding:0; margin:0;">
  <tr>
    <th><iframe src="https://ghbtns.com/github-btn.html?user=yuliangxiu&repo=ICON&type=star&count=true&v=2&size=small" frameborder="0" scrolling="0" width="100" height="20"></iframe></th>
    <th><img alt="Twitter Follow" src="https://img.shields.io/twitter/follow/yuliangxiu?style=social"></th>
    <th><img alt="YouTube Video Views" src="https://img.shields.io/youtube/views/hZd6AYin2DE?style=social"></th>
  </tr>
</table>

#### Acknowledgments:

- [StyleGAN-Human, ECCV 2022](https://stylegan-human.github.io/)
- [nagolinc/styleGanHuman_and_PIFu](https://huggingface.co/spaces/nagolinc/styleGanHuman_and_PIFu)
- [radames/PIFu-Clothed-Human-Digitization](https://huggingface.co/spaces/radames/PIFu-Clothed-Human-Digitization)

#### The reconstruction + refinement + video take about 80 seconds for single image.

<details>

<summary>More</summary>

#### Image Credits

* [Pinterest](https://www.pinterest.com/search/pins/?q=parkour&rs=sitelinks_searchbox)
* [Qianli Ma](https://qianlim.github.io/)

#### Related works

* [ICON @ MPI](https://icon.is.tue.mpg.de/)
* [MonoPort @ USC](https://xiuyuliang.cn/monoport)
* [Phorhum @ Google](https://phorhum.github.io/)
* [PIFuHD @ Meta](https://shunsukesaito.github.io/PIFuHD/)
* [PaMIR @ Tsinghua](http://www.liuyebin.com/pamir/pamir.html)

</details>
'''


def generate_image(seed, psi):
    iface = gr.Interface.load("spaces/hysts/StyleGAN-Human")
    img = iface(seed, psi)
    return img

random.seed(1993)
model_types = ['icon-filter', 'pifu', 'pamir']
examples = [[item, random.choice(model_types)] for item in sorted(glob.glob('examples/*.png'))]

with gr.Blocks() as demo:
    gr.Markdown(description)

    out_lst = []
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    seed = gr.inputs.Slider(
                        0, 100, step=1, default=0, label='Seed (For Image Generation)')
                    psi = gr.inputs.Slider(
                        0, 2, step=0.05, default=0.7, label='Truncation psi (For Image Generation)')
                    radio_choice = gr.Radio(model_types, label='Method (For Reconstruction)', value='icon-filter')
                inp = gr.Image(type="filepath", label="Input Image")
            with gr.Row():
                btn_sample = gr.Button("Sample Image")
                btn_submit = gr.Button("Submit Image")

            gr.Examples(examples=examples,
                        inputs=[inp, radio_choice],
                        cache_examples=True,
                        fn=generate_model,
                        outputs=out_lst)

            out_vid_download = gr.File(label="Download Video, welcome share on Twitter with #ICON")

        with gr.Column():
            overlap_inp = gr.Image(type="filepath", label="Image Normal Overlap")
            out_smpl = gr.Model3D(
                clear_color=[0.0, 0.0, 0.0, 0.0], label="SMPL")
            out_smpl_download = gr.File(label="Download SMPL mesh")
            out_smpl_npy_download = gr.File(label="Download SMPL params")
            out_recon = gr.Model3D(
                clear_color=[0.0, 0.0, 0.0, 0.0], label="ICON")
            out_recon_download = gr.File(label="Download clothed human mesh")
            out_final = gr.Model3D(
                clear_color=[0.0, 0.0, 0.0, 0.0], label="ICON++")
            out_final_download = gr.File(label="Download refined clothed human mesh")

    out_lst = [out_smpl, out_smpl_download, out_smpl_npy_download, out_recon, out_recon_download,
               out_final, out_final_download, out_vid_download, overlap_inp]

    btn_submit.click(fn=generate_model, inputs=[inp, radio_choice], outputs=out_lst)
    btn_sample.click(fn=generate_image, inputs=[seed, psi], outputs=inp)

if __name__ == "__main__":

    # demo.launch(debug=False, enable_queue=False,
    #             auth=("[email protected]", "icon_2022"),
    #             auth_message="Register at icon.is.tue.mpg.de to get HuggingFace username and password.")

    demo.launch(debug=True, enable_queue=True)
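Both the Examples widget and the Submit button above route into generate_model(in_path, model_type) from apps/infer.py (added later in this commit). The sketch below is not part of the commit; it only illustrates how that entry point could be called directly, assuming the ICON checkpoints and SMPL assets referenced by the configs are available locally, and it uses one of the images committed under examples/.

# Minimal sketch (not part of the commit): calling the Space's entry point
# directly instead of through the Gradio UI. Assumes the ICON checkpoints and
# SMPL assets referenced by configs/icon-filter.yaml are available locally.
from apps.infer import generate_model

# One of the committed example images, plus one of the model_types offered by
# the radio button above.
outputs = generate_model("examples/22097467bffc92d4a5c4246f7d4edb75.png", "icon-filter")

# app.py maps the returned values onto out_lst in this order:
# [out_smpl, out_smpl_download, out_smpl_npy_download, out_recon,
#  out_recon_download, out_final, out_final_download, out_vid_download, overlap_inp]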
apps/ICON.py
ADDED
@@ -0,0 +1,765 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]


import os

from lib.common.seg3d_lossless import Seg3dLossless
from lib.dataset.Evaluator import Evaluator
from lib.net import HGPIFuNet
from lib.common.train_util import *
from lib.common.render import Render
from lib.dataset.mesh_util import SMPLX, update_mesh_shape_prior_losses, get_visibility
import warnings
import logging
import torch
import lib.smplx as smplx
import numpy as np
from torch import nn
import os.path as osp

from skimage.transform import resize
import pytorch_lightning as pl
from huggingface_hub import cached_download

torch.backends.cudnn.benchmark = True

logging.getLogger("lightning").setLevel(logging.ERROR)

warnings.filterwarnings("ignore")


class ICON(pl.LightningModule):
    def __init__(self, cfg):
        super(ICON, self).__init__()

        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_G = self.cfg.lr_G

        self.use_sdf = cfg.sdf
        self.prior_type = cfg.net.prior_type
        self.mcube_res = cfg.mcube_res
        self.clean_mesh_flag = cfg.clean_mesh

        self.netG = HGPIFuNet(
            self.cfg,
            self.cfg.projection_mode,
            error_term=nn.SmoothL1Loss() if self.use_sdf else nn.MSELoss(),
        )

        # TODO: replace the renderer from opengl to pytorch3d
        self.evaluator = Evaluator(
            device=torch.device(f"cuda:{self.cfg.gpus[0]}"))

        self.resolutions = (
            np.logspace(
                start=5,
                stop=np.log2(self.mcube_res),
                base=2,
                num=int(np.log2(self.mcube_res) - 4),
                endpoint=True,
            )
            + 1.0
        )
        self.resolutions = self.resolutions.astype(np.int16).tolist()

        self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
        self.pamir_keys = ["voxel_verts",
                           "voxel_faces", "pad_v_num", "pad_f_num"]

        self.reconEngine = Seg3dLossless(
            query_func=query_func,
            b_min=[[-1.0, 1.0, -1.0]],
            b_max=[[1.0, -1.0, 1.0]],
            resolutions=self.resolutions,
            align_corners=True,
            balance_value=0.50,
            device=torch.device(f"cuda:{self.cfg.test_gpus[0]}"),
            visualize=False,
            debug=False,
            use_cuda_impl=False,
            faster=True,
        )

        self.render = Render(
            size=512, device=torch.device(f"cuda:{self.cfg.test_gpus[0]}")
        )
        self.smpl_data = SMPLX()

        self.get_smpl_model = lambda smpl_type, gender, age, v_template: smplx.create(
            self.smpl_data.model_dir,
            kid_template_path=cached_download(osp.join(self.smpl_data.model_dir,
                                                       f"{smpl_type}/{smpl_type}_kid_template.npy"), use_auth_token=os.environ['ICON']),
            model_type=smpl_type,
            gender=gender,
            age=age,
            v_template=v_template,
            use_face_contour=False,
            ext="pkl",
        )

        self.in_geo = [item[0] for item in cfg.net.in_geo]
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
        self.in_total = self.in_geo + self.in_nml
        self.smpl_dim = cfg.net.smpl_dim

        self.export_dir = None
        self.result_eval = {}

    def get_progress_bar_dict(self):
        tqdm_dict = super().get_progress_bar_dict()
        if "v_num" in tqdm_dict:
            del tqdm_dict["v_num"]
        return tqdm_dict

    # Training related
    def configure_optimizers(self):

        # set optimizer
        weight_decay = self.cfg.weight_decay
        momentum = self.cfg.momentum

        optim_params_G = [
            {"params": self.netG.if_regressor.parameters(), "lr": self.lr_G}
        ]

        if self.cfg.net.use_filter:
            optim_params_G.append(
                {"params": self.netG.F_filter.parameters(), "lr": self.lr_G}
            )

        if self.cfg.net.prior_type == "pamir":
            optim_params_G.append(
                {"params": self.netG.ve.parameters(), "lr": self.lr_G}
            )

        if self.cfg.optim == "Adadelta":

            optimizer_G = torch.optim.Adadelta(
                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
            )

        elif self.cfg.optim == "Adam":

            optimizer_G = torch.optim.Adam(
                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
            )

        elif self.cfg.optim == "RMSprop":

            optimizer_G = torch.optim.RMSprop(
                optim_params_G,
                lr=self.lr_G,
                weight_decay=weight_decay,
                momentum=momentum,
            )

        else:
            raise NotImplementedError

        # set scheduler
        scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        return [optimizer_G], [scheduler_G]

    def training_step(self, batch, batch_idx):

        if not self.cfg.fast_dev:
            export_cfg(self.logger, self.cfg)

        self.netG.train()

        in_tensor_dict = {
            "sample": batch["samples_geo"].permute(0, 2, 1),
            "calib": batch["calib"],
            "label": batch["labels_geo"].unsqueeze(1),
        }

        for name in self.in_total:
            in_tensor_dict.update({name: batch[name]})

        if self.prior_type == "icon":
            for key in self.icon_keys:
                in_tensor_dict.update({key: batch[key]})
        elif self.prior_type == "pamir":
            for key in self.pamir_keys:
                in_tensor_dict.update({key: batch[key]})
        else:
            pass

        preds_G, error_G = self.netG(in_tensor_dict)

        acc, iou, prec, recall = self.evaluator.calc_acc(
            preds_G.flatten(),
            in_tensor_dict["label"].flatten(),
            0.5,
            use_sdf=self.cfg.sdf,
        )

        # metrics processing
        metrics_log = {
            "train_loss": error_G.item(),
            "train_acc": acc.item(),
            "train_iou": iou.item(),
            "train_prec": prec.item(),
            "train_recall": recall.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        if batch_idx % int(self.cfg.freq_show_train) == 0:

            with torch.no_grad():
                self.render_func(in_tensor_dict, dataset="train")

        metrics_return = {
            k.replace("train_", ""): torch.tensor(v) for k, v in metrics_log.items()
        }

        metrics_return.update(
            {"loss": error_G, "log": tf_log, "progress_bar": bar_log})

        return metrics_return

    def training_epoch_end(self, outputs):

        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgiou": batch_mean(outputs, "iou"),
            "train_avgprec": batch_mean(outputs, "prec"),
            "train_avgrecall": batch_mean(outputs, "recall"),
            "train_avgacc": batch_mean(outputs, "acc"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):

        self.netG.eval()
        self.netG.training = False

        in_tensor_dict = {
            "sample": batch["samples_geo"].permute(0, 2, 1),
            "calib": batch["calib"],
            "label": batch["labels_geo"].unsqueeze(1),
        }

        for name in self.in_total:
            in_tensor_dict.update({name: batch[name]})

        if self.prior_type == "icon":
            for key in self.icon_keys:
                in_tensor_dict.update({key: batch[key]})
        elif self.prior_type == "pamir":
            for key in self.pamir_keys:
                in_tensor_dict.update({key: batch[key]})
        else:
            pass

        preds_G, error_G = self.netG(in_tensor_dict)

        acc, iou, prec, recall = self.evaluator.calc_acc(
            preds_G.flatten(),
            in_tensor_dict["label"].flatten(),
            0.5,
            use_sdf=self.cfg.sdf,
        )

        if batch_idx % int(self.cfg.freq_show_val) == 0:
            with torch.no_grad():
                self.render_func(in_tensor_dict, dataset="val", idx=batch_idx)

        metrics_return = {
            "val_loss": error_G,
            "val_acc": acc,
            "val_iou": iou,
            "val_prec": prec,
            "val_recall": recall,
        }

        return metrics_return

    def validation_epoch_end(self, outputs):

        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgacc": batch_mean(outputs, "val_acc"),
            "val_avgiou": batch_mean(outputs, "val_iou"),
            "val_avgprec": batch_mean(outputs, "val_prec"),
            "val_avgrecall": batch_mean(outputs, "val_recall"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}

    def compute_vis_cmap(self, smpl_type, smpl_verts, smpl_faces):

        (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
        if smpl_type == "smpl":
            smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
        else:
            smplx_ind = np.arange(smpl_vis.shape[0])
        smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)

        return {
            "smpl_vis": smpl_vis.unsqueeze(0).to(self.device),
            "smpl_cmap": smpl_cmap.unsqueeze(0).to(self.device),
            "smpl_verts": smpl_verts.unsqueeze(0),
        }

    @torch.enable_grad()
    def optim_body(self, in_tensor_dict, batch):

        smpl_model = self.get_smpl_model(
            batch["type"][0], batch["gender"][0], batch["age"][0], None
        ).to(self.device)
        in_tensor_dict["smpl_faces"] = (
            torch.tensor(smpl_model.faces.astype(np.int))
            .long()
            .unsqueeze(0)
            .to(self.device)
        )

        # The optimizer and variables
        optimed_pose = torch.tensor(
            batch["body_pose"][0], device=self.device, requires_grad=True
        )  # [1,23,3,3]
        optimed_trans = torch.tensor(
            batch["transl"][0], device=self.device, requires_grad=True
        )  # [3]
        optimed_betas = torch.tensor(
            batch["betas"][0], device=self.device, requires_grad=True
        )  # [1,10]
        optimed_orient = torch.tensor(
            batch["global_orient"][0], device=self.device, requires_grad=True
        )  # [1,1,3,3]

        optimizer_smpl = torch.optim.SGD(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
            lr=1e-3,
            momentum=0.9,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl, mode="min", factor=0.5, verbose=0, min_lr=1e-5, patience=5
        )
        loop_smpl = range(50)
        for i in loop_smpl:

            optimizer_smpl.zero_grad()

            # prior_loss, optimed_pose = dataset.vposer_prior(optimed_pose)
            smpl_out = smpl_model(
                betas=optimed_betas,
                body_pose=optimed_pose,
                global_orient=optimed_orient,
                transl=optimed_trans,
                return_verts=True,
            )

            smpl_verts = smpl_out.vertices[0] * 100.0
            smpl_verts = projection(
                smpl_verts, batch["calib"][0], format="tensor")
            smpl_verts[:, 1] *= -1
            # render optimized mesh (normal, T_normal, image [-1,1])
            self.render.load_meshes(
                smpl_verts, in_tensor_dict["smpl_faces"])
            (
                in_tensor_dict["T_normal_F"],
                in_tensor_dict["T_normal_B"],
            ) = self.render.get_rgb_image()

            T_mask_F, T_mask_B = self.render.get_silhouette_image()

            with torch.no_grad():
                (
                    in_tensor_dict["normal_F"],
                    in_tensor_dict["normal_B"],
                ) = self.netG.normal_filter(in_tensor_dict)

            # mask = torch.abs(in_tensor['T_normal_F']).sum(dim=0, keepdims=True) > 0.0
            diff_F_smpl = torch.abs(
                in_tensor_dict["T_normal_F"] - in_tensor_dict["normal_F"]
            )
            diff_B_smpl = torch.abs(
                in_tensor_dict["T_normal_B"] - in_tensor_dict["normal_B"]
            )
            loss = (diff_F_smpl + diff_B_smpl).mean()

            # silhouette loss
            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
            gt_arr = torch.cat(
                [in_tensor_dict["normal_F"][0], in_tensor_dict["normal_B"][0]], dim=2
            ).permute(1, 2, 0)
            gt_arr = ((gt_arr + 1.0) * 0.5).to(self.device)
            bg_color = (
                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
                    0).unsqueeze(0).to(self.device)
            )
            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
            loss += torch.abs(smpl_arr - gt_arr).mean()

            # Image.fromarray(((in_tensor_dict['T_normal_F'][0].permute(1,2,0)+1.0)*0.5*255.0).detach().cpu().numpy().astype(np.uint8)).show()

            # loop_smpl.set_description(f"smpl = {loss:.3f}")

            loss.backward(retain_graph=True)
            optimizer_smpl.step()
            scheduler_smpl.step(loss)
        in_tensor_dict["smpl_verts"] = smpl_verts.unsqueeze(0)

        in_tensor_dict.update(
            self.compute_vis_cmap(
                batch["type"][0],
                in_tensor_dict["smpl_verts"][0],
                in_tensor_dict["smpl_faces"][0],
            )
        )

        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)

        return features, inter, in_tensor_dict

    @torch.enable_grad()
    def optim_cloth(self, verts_pr, faces_pr, inter):

        # convert from GT to SDF
        verts_pr -= (self.resolutions[-1] - 1) / 2.0
        verts_pr /= (self.resolutions[-1] - 1) / 2.0

        losses = {
            "cloth": {"weight": 5.0, "value": 0.0},
            "edge": {"weight": 100.0, "value": 0.0},
            "normal": {"weight": 0.2, "value": 0.0},
            "laplacian": {"weight": 100.0, "value": 0.0},
            "smpl": {"weight": 1.0, "value": 0.0},
            "deform": {"weight": 20.0, "value": 0.0},
        }

        deform_verts = torch.full(
            verts_pr.shape, 0.0, device=self.device, requires_grad=True
        )
        optimizer_cloth = torch.optim.SGD(
            [deform_verts], lr=1e-1, momentum=0.9)
        scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_cloth, mode="min", factor=0.1, verbose=0, min_lr=1e-3, patience=5
        )
        # cloth optimization
        loop_cloth = range(100)

        for i in loop_cloth:

            optimizer_cloth.zero_grad()

            self.render.load_meshes(
                verts_pr.unsqueeze(0).to(self.device),
                faces_pr.unsqueeze(0).to(self.device).long(),
                deform_verts,
            )
            P_normal_F, P_normal_B = self.render.get_rgb_image()

            update_mesh_shape_prior_losses(self.render.mesh, losses)
            diff_F_cloth = torch.abs(P_normal_F[0] - inter[:3])
            diff_B_cloth = torch.abs(P_normal_B[0] - inter[3:])
            losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
            losses["deform"]["value"] = torch.topk(
                torch.abs(deform_verts.flatten()), 30
            )[0].mean()

            # Weighted sum of the losses
            cloth_loss = torch.tensor(0.0, device=self.device)
            pbar_desc = ""

            for k in losses.keys():
                if k != "smpl":
                    cloth_loss_per_cls = losses[k]["value"] * \
                        losses[k]["weight"]
                    pbar_desc += f"{k}: {cloth_loss_per_cls:.3f} | "
                    cloth_loss += cloth_loss_per_cls

            # loop_cloth.set_description(pbar_desc)
            cloth_loss.backward(retain_graph=True)
            optimizer_cloth.step()
            scheduler_cloth.step(cloth_loss)

        # convert from GT to SDF
        deform_verts = deform_verts.flatten().detach()
        deform_verts[torch.topk(torch.abs(deform_verts), 30)[
            1]] = deform_verts.mean()
        deform_verts = deform_verts.view(-1, 3).cpu()

        verts_pr += deform_verts
        verts_pr *= (self.resolutions[-1] - 1) / 2.0
        verts_pr += (self.resolutions[-1] - 1) / 2.0

        return verts_pr

    def test_step(self, batch, batch_idx):

        # dict_keys(['dataset', 'subject', 'rotation', 'scale', 'calib',
        # 'normal_F', 'normal_B', 'image', 'T_normal_F', 'T_normal_B',
        # 'z-trans', 'verts', 'faces', 'samples_geo', 'labels_geo',
        # 'smpl_verts', 'smpl_faces', 'smpl_vis', 'smpl_cmap', 'pts_signs',
        # 'type', 'gender', 'age', 'body_pose', 'global_orient', 'betas', 'transl'])

        if self.evaluator._normal_render is None:
            self.evaluator.init_gl()

        self.netG.eval()
        self.netG.training = False
        in_tensor_dict = {}

        # export paths
        mesh_name = batch["subject"][0]
        mesh_rot = batch["rotation"][0].item()
        ckpt_dir = self.cfg.name

        for kid, key in enumerate(self.cfg.dataset.noise_type):
            ckpt_dir += f"_{key}_{self.cfg.dataset.noise_scale[kid]}"

        if self.cfg.optim_cloth:
            ckpt_dir += "_optim_cloth"
        if self.cfg.optim_body:
            ckpt_dir += "_optim_body"

        self.export_dir = osp.join(self.cfg.results_path, ckpt_dir, mesh_name)
        os.makedirs(self.export_dir, exist_ok=True)

        for name in self.in_total:
            if name in batch.keys():
                in_tensor_dict.update({name: batch[name]})

        # update the new T_normal_F/B
        in_tensor_dict.update(
            self.evaluator.render_normal(
                batch["smpl_verts"], batch["smpl_faces"])
        )

        # update the new smpl_vis
        (xy, z) = batch["smpl_verts"][0].split([2, 1], dim=1)
        smpl_vis = get_visibility(
            xy,
            z,
            torch.as_tensor(self.smpl_data.faces).type_as(
                batch["smpl_verts"]).long(),
        )
        in_tensor_dict.update({"smpl_vis": smpl_vis.unsqueeze(0)})

        if self.prior_type == "icon":
            for key in self.icon_keys:
                in_tensor_dict.update({key: batch[key]})
        elif self.prior_type == "pamir":
            for key in self.pamir_keys:
                in_tensor_dict.update({key: batch[key]})
        else:
            pass

        with torch.no_grad():
            if self.cfg.optim_body:
                features, inter, in_tensor_dict = self.optim_body(
                    in_tensor_dict, batch)
            else:
                features, inter = self.netG.filter(
                    in_tensor_dict, return_inter=True)
            sdf = self.reconEngine(
                opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
            )

        # save inter results
        image = (
            in_tensor_dict["image"][0].permute(
                1, 2, 0).detach().cpu().numpy() + 1.0
        ) * 0.5
        smpl_F = (
            in_tensor_dict["T_normal_F"][0].permute(
                1, 2, 0).detach().cpu().numpy()
            + 1.0
        ) * 0.5
        smpl_B = (
            in_tensor_dict["T_normal_B"][0].permute(
                1, 2, 0).detach().cpu().numpy()
            + 1.0
        ) * 0.5
        image_inter = np.concatenate(
            self.tensor2image(512, inter[0]) + [smpl_F, smpl_B, image], axis=1
        )
        Image.fromarray((image_inter * 255.0).astype(np.uint8)).save(
            osp.join(self.export_dir, f"{mesh_rot}_inter.png")
        )

        verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)

        if self.clean_mesh_flag:
            verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)

        if self.cfg.optim_cloth:
            verts_pr = self.optim_cloth(verts_pr, faces_pr, inter[0].detach())

        verts_gt = batch["verts"][0]
        faces_gt = batch["faces"][0]

        self.result_eval.update(
            {
                "verts_gt": verts_gt,
                "faces_gt": faces_gt,
                "verts_pr": verts_pr,
                "faces_pr": faces_pr,
                "recon_size": (self.resolutions[-1] - 1.0),
                "calib": batch["calib"][0],
            }
        )

        self.evaluator.set_mesh(self.result_eval, scale_factor=1.0)
        self.evaluator.space_transfer()

        chamfer, p2s = self.evaluator.calculate_chamfer_p2s(
            sampled_points=1000)
        normal_consist = self.evaluator.calculate_normal_consist(
            save_demo_img=osp.join(self.export_dir, f"{mesh_rot}_nc.png")
        )

        test_log = {"chamfer": chamfer, "p2s": p2s, "NC": normal_consist}

        return test_log

    def test_epoch_end(self, outputs):

        # make_test_gif("/".join(self.export_dir.split("/")[:-2]))

        accu_outputs = accumulate(
            outputs,
            rot_num=3,
            split={
                "thuman2": (0, 5),
            },
        )

        print(colored(self.cfg.name, "green"))
        print(colored(self.cfg.dataset.noise_scale, "green"))

        self.logger.experiment.add_hparams(
            hparam_dict={"lr_G": self.lr_G, "bsize": self.batch_size},
            metric_dict=accu_outputs,
        )

        np.save(
            osp.join(self.export_dir, "../test_results.npy"),
            accu_outputs,
            allow_pickle=True,
        )

        return accu_outputs

    def tensor2image(self, height, inter):

        all = []
        for dim in self.in_geo_dim:
            img = resize(
                np.tile(
                    ((inter[:dim].cpu().numpy() + 1.0) /
                     2.0).transpose(1, 2, 0),
                    (1, 1, int(3 / dim)),
                ),
                (height, height),
                anti_aliasing=True,
            )

            all.append(img)
            inter = inter[dim:]

        return all

    def render_func(self, in_tensor_dict, dataset="title", idx=0):

        for name in in_tensor_dict.keys():
            in_tensor_dict[name] = in_tensor_dict[name][0:1]

        self.netG.eval()
        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
        sdf = self.reconEngine(
            opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
        )

        if sdf is not None:
            render = self.reconEngine.display(sdf)

            image_pred = np.flip(render[:, :, ::-1], axis=0)
            height = image_pred.shape[0]

            image_gt = resize(
                ((in_tensor_dict["image"].cpu().numpy()[0] + 1.0) / 2.0).transpose(
                    1, 2, 0
                ),
                (height, height),
                anti_aliasing=True,
            )
            image_inter = self.tensor2image(height, inter[0])
            image = np.concatenate(
                [image_pred, image_gt] + image_inter, axis=1)

            step_id = self.global_step if dataset == "train" else self.global_step + idx
            self.logger.experiment.add_image(
                tag=f"Occupancy-{dataset}/{step_id}",
                img_tensor=image.transpose(2, 0, 1),
                global_step=step_id,
            )

    def test_single(self, batch):

        self.netG.eval()
        self.netG.training = False
        in_tensor_dict = {}

        for name in self.in_total:
            if name in batch.keys():
                in_tensor_dict.update({name: batch[name]})

        if self.prior_type == "icon":
            for key in self.icon_keys:
                in_tensor_dict.update({key: batch[key]})
        elif self.prior_type == "pamir":
            for key in self.pamir_keys:
                in_tensor_dict.update({key: batch[key]})
        else:
            pass

        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
        sdf = self.reconEngine(
            opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
        )

        verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)

        if self.clean_mesh_flag:
            verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)

        verts_pr -= (self.resolutions[-1] - 1) / 2.0
        verts_pr /= (self.resolutions[-1] - 1) / 2.0

        return verts_pr, faces_pr, inter
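One non-obvious detail in ICON.__init__ is the coarse-to-fine resolution schedule handed to Seg3dLossless. The short sketch below is not part of the commit; it simply evaluates that expression for the mcube_res of 256 that apps/infer.py later sets, to make the resulting grid sizes concrete.

# Small sketch reproducing the resolution schedule built in ICON.__init__
# for the mcube_res=256 that apps/infer.py passes via cfg_show_list.
import numpy as np

mcube_res = 256
resolutions = (
    np.logspace(
        start=5,
        stop=np.log2(mcube_res),
        base=2,
        num=int(np.log2(mcube_res) - 4),
        endpoint=True,
    )
    + 1.0
)
print(resolutions.astype(np.int16).tolist())  # [33, 65, 129, 257]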
apps/Normal.py
ADDED
@@ -0,0 +1,213 @@
from lib.net import NormalNet
from lib.common.train_util import *
import logging
import torch
import numpy as np
from torch import nn
from skimage.transform import resize
import pytorch_lightning as pl

torch.backends.cudnn.benchmark = True

logging.getLogger("lightning").setLevel(logging.ERROR)


class Normal(pl.LightningModule):
    def __init__(self, cfg):
        super(Normal, self).__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        self.schedulers = []

        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        self.in_nml = [item[0] for item in cfg.net.in_nml]

    def get_progress_bar_dict(self):
        tqdm_dict = super().get_progress_bar_dict()
        if "v_num" in tqdm_dict:
            del tqdm_dict["v_num"]
        return tqdm_dict

    # Training related
    def configure_optimizers(self):

        # set optimizer
        weight_decay = self.cfg.weight_decay
        momentum = self.cfg.momentum

        optim_params_N_F = [
            {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
        optim_params_N_B = [
            {"params": self.netG.netB.parameters(), "lr": self.lr_N}]

        optimizer_N_F = torch.optim.Adam(
            optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
        )

        optimizer_N_B = torch.optim.Adam(
            optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
        )

        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        self.schedulers = [scheduler_N_F, scheduler_N_B]
        optims = [optimizer_N_F, optimizer_N_B]

        return optims, self.schedulers

    def render_func(self, render_tensor):

        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            result_list.append(
                resize(
                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
                        1, 2, 0
                    ),
                    (height, height),
                    anti_aliasing=True,
                )
            )
        result_array = np.concatenate(result_list, axis=1)

        return result_array

    def training_step(self, batch, batch_idx, optimizer_idx):

        export_cfg(self.logger, self.cfg)

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        (opt_nf, opt_nb) = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        self.manual_backward(error_NF, opt_nf)
        self.manual_backward(error_NB, opt_nb)

        opt_nf.step()
        opt_nb.step()

        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:

            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-train/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        # metrics processing
        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

    def training_epoch_end(self, outputs):

        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
            batch_idx == 0
        ):

            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-val/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

    def validation_epoch_end(self, outputs):

        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}
apps/infer.py
ADDED
@@ -0,0 +1,467 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import os

import logging
from lib.common.render import query_color
from lib.common.config import cfg
from lib.dataset.mesh_util import (
    load_checkpoint,
    update_mesh_shape_prior_losses,
    blend_rgb_norm,
    unwrap,
    remesh,
    tensor2variable,
)

from lib.dataset.TestDataset import TestDataset
from lib.net.local_affine import LocalAffine
from pytorch3d.structures import Meshes
from apps.ICON import ICON

from termcolor import colored
import numpy as np
from PIL import Image
import trimesh
from tqdm import tqdm

import torch
torch.backends.cudnn.benchmark = True

logging.getLogger("trimesh").setLevel(logging.ERROR)


def generate_model(in_path, model_type):

    torch.cuda.empty_cache()

    config_dict = {'loop_smpl': 50,
                   'loop_cloth': 100,
                   'patience': 5,
                   'vis_freq': 10,
                   'out_dir': './results',
                   'hps_type': 'pymaf',
                   'config': f"./configs/{model_type}.yaml"}

    # cfg read and merge
    cfg.merge_from_file(config_dict['config'])
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")

    os.makedirs(config_dict['out_dir'], exist_ok=True)

    cfg_show_list = [
        "test_gpus",
        [0],
        "mcube_res",
        256,
        "clean_mesh",
        True,
    ]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device(f"cuda:0")

    # load model and dataloader
    model = ICON(cfg)
    model = load_checkpoint(model, cfg)

    dataset_param = {
        'image_path': in_path,
        'seg_dir': None,
        'has_det': True,  # w/ or w/o detection
        'hps_type': 'pymaf'  # pymaf/pare/pixie
    }

    if config_dict['hps_type'] == "pixie" and "pamir" in config_dict['config']:
        print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
        dataset_param["hps_type"] = "pymaf"

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    pbar = tqdm(dataset)

    for data in pbar:

        pbar.set_description(f"{data['name']}")

        in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}

        # The optimizer and variables
        optimed_pose = torch.tensor(
            data["body_pose"], device=device, requires_grad=True
        )  # [1,23,3,3]
        optimed_trans = torch.tensor(
            data["trans"], device=device, requires_grad=True
        )  # [3]
        optimed_betas = torch.tensor(
            data["betas"], device=device, requires_grad=True
        )  # [1,10]
        optimed_orient = torch.tensor(
            data["global_orient"], device=device, requires_grad=True
        )  # [1,1,3,3]

        optimizer_smpl = torch.optim.SGD(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
            lr=1e-3,
            momentum=0.9,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=config_dict['patience'],
        )

        losses = {
            # Cloth: Normal_recon - Normal_pred
            "cloth": {"weight": 1e1, "value": 0.0},
            # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
            "stiffness": {"weight": 1e5, "value": 0.0},
            # Cloth: det(R) = 1
            "rigid": {"weight": 1e5, "value": 0.0},
            # Cloth: edge length
            "edge": {"weight": 0, "value": 0.0},
            # Cloth: normal consistency
            "nc": {"weight": 0, "value": 0.0},
            # Cloth: laplacian smooth
            "laplacian": {"weight": 1e2, "value": 0.0},
            # Body: Normal_pred - Normal_smpl
            "normal": {"weight": 1e0, "value": 0.0},
            # Body: Silhouette_pred - Silhouette_smpl
            "silhouette": {"weight": 1e0, "value": 0.0},
        }

        # smpl optimization

        loop_smpl = tqdm(
            range(config_dict['loop_smpl'] if cfg.net.prior_type != "pifu" else 1))

        for i in loop_smpl:

            optimizer_smpl.zero_grad()

            if dataset_param["hps_type"] != "pixie":
                smpl_out = dataset.smpl_model(
                    betas=optimed_betas,
                    body_pose=optimed_pose,
                    global_orient=optimed_orient,
                    pose2rot=False,
                )

                smpl_verts = ((smpl_out.vertices) +
                              optimed_trans) * data["scale"]
            else:
                smpl_verts, _, _ = dataset.smpl_model(
                    shape_params=optimed_betas,
                    expression_params=tensor2variable(data["exp"], device),
                    body_pose=optimed_pose,
                    global_pose=optimed_orient,
                    jaw_pose=tensor2variable(data["jaw_pose"], device),
                    left_hand_pose=tensor2variable(
                        data["left_hand_pose"], device),
                    right_hand_pose=tensor2variable(
                        data["right_hand_pose"], device),
                )

                smpl_verts = (smpl_verts + optimed_trans) * data["scale"]

            # render optimized mesh (normal, T_normal, image [-1,1])
            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                smpl_verts *
                torch.tensor([1.0, -1.0, -1.0]
                             ).to(device), in_tensor["smpl_faces"]
            )
            T_mask_F, T_mask_B = dataset.render.get_silhouette_image()

            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
                    in_tensor
                )

            diff_F_smpl = torch.abs(
                in_tensor["T_normal_F"] - in_tensor["normal_F"])
            diff_B_smpl = torch.abs(
                in_tensor["T_normal_B"] - in_tensor["normal_B"])

            losses["normal"]["value"] = (diff_F_smpl + diff_B_smpl).mean()

            # silhouette loss
            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
            gt_arr = torch.cat(
                [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
            ).permute(1, 2, 0)
            gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
            bg_color = (
                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
                    0).unsqueeze(0).to(device)
            )
            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
            diff_S = torch.abs(smpl_arr - gt_arr)
            losses["silhouette"]["value"] = diff_S.mean()

            # Weighted sum of the losses
            smpl_loss = 0.0
            pbar_desc = "Body Fitting --- "
            for k in ["normal", "silhouette"]:
                pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
                smpl_loss += losses[k]["value"] * losses[k]["weight"]
            pbar_desc += f"Total: {smpl_loss:.3f}"
            loop_smpl.set_description(pbar_desc)

            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)
            in_tensor["smpl_verts"] = smpl_verts * \
                torch.tensor([1.0, 1.0, -1.0]).to(device)

        # visualize the optimization process
        # 1. SMPL Fitting
        # 2. Clothes Refinement

        os.makedirs(os.path.join(config_dict['out_dir'], cfg.name,
                                 "refinement"), exist_ok=True)

        # visualize the final results in self-rotation mode
        os.makedirs(os.path.join(config_dict['out_dir'],
                                 cfg.name, "vid"), exist_ok=True)

        # final results rendered as image
        # 1. Render the final fitted SMPL (xxx_smpl.png)
        # 2. Render the final reconstructed clothed human (xxx_cloth.png)
        # 3. Blend the original image with predicted cloth normal (xxx_overlap.png)

        os.makedirs(os.path.join(config_dict['out_dir'],
                                 cfg.name, "png"), exist_ok=True)

        # final reconstruction meshes
        # 1. SMPL mesh (xxx_smpl.obj)
        # 2. SMPL params (xxx_smpl.npy)
        # 3. clothed mesh (xxx_recon.obj)
        # 4. remeshed clothed mesh (xxx_remesh.obj)
        # 5. refined clothed mesh (xxx_refine.obj)

        os.makedirs(os.path.join(config_dict['out_dir'],
                                 cfg.name, "obj"), exist_ok=True)

        norm_pred = (
            ((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
            .detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )

        norm_orig = unwrap(norm_pred, data)
        mask_orig = unwrap(
            np.repeat(
                data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
            ).astype(np.uint8),
            data,
        )
        rgb_norm = blend_rgb_norm(data["ori_image"], norm_orig, mask_orig)

        Image.fromarray(
            np.concatenate(
                [data["ori_image"].astype(np.uint8), rgb_norm], axis=1)
        ).save(os.path.join(config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png"))

        smpl_obj = trimesh.Trimesh(
            in_tensor["smpl_verts"].detach().cpu()[0] *
            torch.tensor([1.0, -1.0, 1.0]),
            in_tensor['smpl_faces'].detach().cpu()[0],
            process=False,
            maintains_order=True
        )
        smpl_obj.export(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")

        smpl_info = {'betas': optimed_betas,
                     'pose': optimed_pose,
                     'orient': optimed_orient,
                     'trans': optimed_trans}

        np.save(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy", smpl_info, allow_pickle=True)

        # ------------------------------------------------------------------------------------------------------------------

        # cloth optimization

        # cloth recon
        in_tensor.update(
            dataset.compute_vis_cmap(
                in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0]
            )
        )

        if cfg.net.prior_type == "pamir":
            in_tensor.update(
                dataset.compute_voxel_verts(
                    optimed_pose,
                    optimed_orient,
                    optimed_betas,
                    optimed_trans,
                    data["scale"],
                )
            )

        with torch.no_grad():
            verts_pr, faces_pr, _ = model.test_single(in_tensor)

        recon_obj = trimesh.Trimesh(
            verts_pr, faces_pr, process=False, maintains_order=True
        )
        recon_obj.export(
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"obj/{data['name']}_recon.obj")
        )

        recon_obj.export(
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"obj/{data['name']}_recon.glb")
        )

        # Isotropic Explicit Remeshing for better geometry topology
        verts_refine, faces_refine = remesh(os.path.join(config_dict['out_dir'], cfg.name,
                                                         f"obj/{data['name']}_recon.obj"), 0.5, device)

        # define local_affine deform verts
        mesh_pr = Meshes(verts_refine, faces_refine).to(device)
        local_affine_model = LocalAffine(
            mesh_pr.verts_padded().shape[1], mesh_pr.verts_padded().shape[0], mesh_pr.edges_packed()).to(device)
        optimizer_cloth = torch.optim.Adam(
            [{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True)

        scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_cloth,
            mode="min",
            factor=0.1,
            verbose=0,
            min_lr=1e-5,
            patience=config_dict['patience'],
        )

        final = None

        if config_dict['loop_cloth'] > 0:

            loop_cloth = tqdm(range(config_dict['loop_cloth']))

            for i in loop_cloth:

                optimizer_cloth.zero_grad()

                deformed_verts, stiffness, rigid = local_affine_model(
                    verts_refine.to(device), return_stiff=True)
                mesh_pr = mesh_pr.update_padded(deformed_verts)

                # losses for laplacian, edge, normal consistency
                update_mesh_shape_prior_losses(mesh_pr, losses)

                in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
                    mesh_pr.verts_padded(), mesh_pr.faces_padded())

                diff_F_cloth = torch.abs(
                    in_tensor["P_normal_F"] - in_tensor["normal_F"])
                diff_B_cloth = torch.abs(
                    in_tensor["P_normal_B"] - in_tensor["normal_B"])

                losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
                losses["stiffness"]["value"] = torch.mean(stiffness)
                losses["rigid"]["value"] = torch.mean(rigid)

                # Weighted sum of the losses
                cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
                pbar_desc = "Cloth Refinement --- "

                for k in losses.keys():
                    if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
                        cloth_loss = cloth_loss + \
                            losses[k]["value"] * losses[k]["weight"]
                        pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.5f} | "

                pbar_desc += f"Total: {cloth_loss:.5f}"
                loop_cloth.set_description(pbar_desc)

                # update params
                cloth_loss.backward(retain_graph=True)
                optimizer_cloth.step()
                scheduler_cloth.step(cloth_loss)

            final = trimesh.Trimesh(
                mesh_pr.verts_packed().detach().squeeze(0).cpu(),
                mesh_pr.faces_packed().detach().squeeze(0).cpu(),
                process=False, maintains_order=True
            )
            final_colors = query_color(
                mesh_pr.verts_packed().detach().squeeze(0).cpu(),
                mesh_pr.faces_packed().detach().squeeze(0).cpu(),
                in_tensor["image"],
                device=device,
            )
            final.visual.vertex_colors = final_colors
            final.export(
                f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb")

        # always export visualized video regardless of the cloth refinement
        if final is not None:
            verts_lst = [verts_pr, final.vertices]
            faces_lst = [faces_pr, final.faces]
        else:
            verts_lst = [verts_pr]
            faces_lst = [faces_pr]

        # self-rotated video
        dataset.render.load_meshes(
            verts_lst, faces_lst)
        dataset.render.get_rendered_video(
            [data["ori_image"], rgb_norm],
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"vid/{data['name']}_cloth.mp4"),
        )

    smpl_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb"
    smpl_npy_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy"
    recon_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_recon.glb"
    refine_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb"

    video_path = os.path.join(config_dict['out_dir'], cfg.name, f"vid/{data['name']}_cloth.mp4")
    overlap_path = os.path.join(config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png")

    torch.cuda.empty_cache()
    del model
    del dataset
    del local_affine_model
    del optimizer_smpl
    del optimizer_cloth
    del scheduler_smpl
    del scheduler_cloth
    del losses
    del in_tensor

    return [smpl_path, smpl_path, smpl_npy_path, recon_path, recon_path, refine_path, refine_path, video_path, overlap_path]
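`generate_model` is the inference entry point of this Space. A minimal, hypothetical standalone driver sketch — the example image ships with the repo, and `model_type` must match one of the YAML files under configs/ (icon-filter, icon-nofilter, pifu, pamir):

```python
# Hypothetical driver for apps/infer.py; assumes a CUDA GPU and downloaded checkpoints.
from apps.infer import generate_model

results = generate_model("./examples/22097467bffc92d4a5c4246f7d4edb75.png", "icon-filter")

# Returned in order: SMPL mesh (.glb, listed twice), SMPL params (.npy),
# raw reconstruction (.glb, twice), refined mesh (.glb, twice),
# turntable video (.mp4), and the normal-overlap image (.png).
for path in results:
    print(path)
```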
assets/garment_teaser.png
ADDED
(binary file, stored via Git LFS)
assets/intermediate_results.png
ADDED
(binary file, stored via Git LFS)
assets/teaser.gif
ADDED
(binary file, stored via Git LFS)
assets/thumbnail.png
ADDED
(binary file, stored via Git LFS)
configs/icon-filter.yaml
ADDED
@@ -0,0 +1,25 @@
name: icon-filter
ckpt_dir: "./data/ckpt/"
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/icon-filter.ckpt"
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"

test_mode: True
batch_size: 1

net:
  mlp_dim: [256, 512, 256, 128, 1]
  res_layers: [2,3,4]
  num_stack: 2
  prior_type: "icon" # icon/pamir/pifu
  use_filter: True
  in_geo: (('normal_F',3), ('normal_B',3))
  in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
  smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
  gtype: 'HGPIFuNet'
  norm_mlp: 'batch'
  hourglass_dim: 6
  smpl_dim: 7

# user defined
mcube_res: 512 # occupancy field resolution, higher --> more details
clean_mesh: False # if True, will remove floating pieces
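Each of these YAML files is merged onto the yacs defaults in lib/common/config.py at run time, exactly as apps/infer.py does above. A minimal sketch of that pattern (the override values are illustrative):

```python
from lib.common.config import cfg

cfg.merge_from_file("./configs/icon-filter.yaml")
# Individual keys can still be overridden as flat (key, value) pairs.
cfg.merge_from_list(["mcube_res", 256, "clean_mesh", True])
cfg.freeze()

print(cfg.name, cfg.net.prior_type, cfg.mcube_res)
```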
configs/icon-nofilter.yaml
ADDED
@@ -0,0 +1,25 @@
name: icon-nofilter
ckpt_dir: "./data/ckpt/"
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/icon-nofilter.ckpt"
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"

test_mode: True
batch_size: 1

net:
  mlp_dim: [256, 512, 256, 128, 1]
  res_layers: [2,3,4]
  num_stack: 2
  prior_type: "icon" # icon/pamir/pifu
  use_filter: False
  in_geo: (('normal_F',3), ('normal_B',3))
  in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
  smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
  gtype: 'HGPIFuNet'
  norm_mlp: 'batch'
  hourglass_dim: 6
  smpl_dim: 7

# user defined
mcube_res: 512 # occupancy field resolution, higher --> more details
clean_mesh: False # if True, will remove floating pieces
configs/pamir.yaml
ADDED
@@ -0,0 +1,24 @@
name: pamir
ckpt_dir: "./data/ckpt/"
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/pamir.ckpt"
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"

test_mode: True
batch_size: 1

net:
  mlp_dim: [256, 512, 256, 128, 1]
  res_layers: [2,3,4]
  num_stack: 2
  prior_type: "pamir" # icon/pamir/pifu
  use_filter: True
  in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
  in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
  gtype: 'HGPIFuNet'
  norm_mlp: 'batch'
  hourglass_dim: 6
  voxel_dim: 7

# user defined
mcube_res: 512 # occupancy field resolution, higher --> more details
clean_mesh: False # if True, will remove floating pieces
configs/pifu.yaml
ADDED
@@ -0,0 +1,24 @@
name: pifu
ckpt_dir: "./data/ckpt/"
resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/pifu.ckpt"
normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"

test_mode: True
batch_size: 1

net:
  mlp_dim: [256, 512, 256, 128, 1]
  res_layers: [2,3,4]
  num_stack: 2
  prior_type: "pifu" # icon/pamir/pifu
  use_filter: True
  in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
  in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
  gtype: 'HGPIFuNet'
  norm_mlp: 'batch'
  hourglass_dim: 12


# user defined
mcube_res: 512 # occupancy field resolution, higher --> more details
clean_mesh: False # if True, will remove floating pieces
examples/22097467bffc92d4a5c4246f7d4edb75.png
ADDED
(binary file, stored via Git LFS)
examples/44c0f84c957b6b9bdf77662af5bb7078.png
ADDED
(binary file, stored via Git LFS)
examples/5a6a25963db2f667441d5076972c207c.png
ADDED
(binary file, stored via Git LFS)
examples/8da7ceb94669c2f65cbd28022e1f9876.png
ADDED
(binary file, stored via Git LFS)
examples/923d65f767c85a42212cae13fba3750b.png
ADDED
(binary file, stored via Git LFS)
examples/959c4c726a69901ce71b93a9242ed900.png
ADDED
(binary file, stored via Git LFS)
examples/c9856a2bc31846d684cbb965457fad59.png
ADDED
(binary file, stored via Git LFS)
examples/e1e7622af7074a022f5d96dc16672517.png
ADDED
(binary file, stored via Git LFS)
examples/fb9d20fdb93750584390599478ecf86e.png
ADDED
(binary file, stored via Git LFS)
examples/slack_trial2-000150.png
ADDED
(binary file, stored via Git LFS)
lib/__init__.py
ADDED
File without changes
lib/common/__init__.py
ADDED
File without changes
lib/common/cloth_extraction.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
import numpy as np
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import itertools
|
5 |
+
import trimesh
|
6 |
+
from matplotlib.path import Path
|
7 |
+
from collections import Counter
|
8 |
+
from sklearn.neighbors import KNeighborsClassifier
|
9 |
+
|
10 |
+
|
11 |
+
def load_segmentation(path, shape):
|
12 |
+
"""
|
13 |
+
Get a segmentation mask for a given image
|
14 |
+
Arguments:
|
15 |
+
path: path to the segmentation json file
|
16 |
+
shape: shape of the output mask
|
17 |
+
Returns:
|
18 |
+
Returns a segmentation mask
|
19 |
+
"""
|
20 |
+
with open(path) as json_file:
|
21 |
+
dict = json.load(json_file)
|
22 |
+
segmentations = []
|
23 |
+
for key, val in dict.items():
|
24 |
+
if not key.startswith('item'):
|
25 |
+
continue
|
26 |
+
|
27 |
+
# Each item can have multiple polygons. Combine them to one
|
28 |
+
# segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
|
29 |
+
# segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
|
30 |
+
|
31 |
+
coordinates = []
|
32 |
+
for segmentation_coord in val['segmentation']:
|
33 |
+
# The format before is [x1,y1, x2, y2, ....]
|
34 |
+
x = segmentation_coord[::2]
|
35 |
+
y = segmentation_coord[1::2]
|
36 |
+
xy = np.vstack((x, y)).T
|
37 |
+
coordinates.append(xy)
|
38 |
+
|
39 |
+
segmentations.append(
|
40 |
+
{'type': val['category_name'], 'type_id': val['category_id'], 'coordinates': coordinates})
|
41 |
+
|
42 |
+
return segmentations
|
43 |
+
|
44 |
+
|
45 |
+
def smpl_to_recon_labels(recon, smpl, k=1):
|
46 |
+
"""
|
47 |
+
Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
|
48 |
+
Arguments:
|
49 |
+
recon: trimesh object (fully clothed model)
|
50 |
+
shape: trimesh object (smpl model)
|
51 |
+
k: number of nearest neighbours to use
|
52 |
+
Returns:
|
53 |
+
Returns a dictionary containing the bodypart and the corresponding indices
|
54 |
+
"""
|
55 |
+
smpl_vert_segmentation = json.load(
|
56 |
+
open(os.path.join(os.path.dirname(__file__), 'smpl_vert_segmentation.json')))
|
57 |
+
n = smpl.vertices.shape[0]
|
58 |
+
y = np.array([None] * n)
|
59 |
+
for key, val in smpl_vert_segmentation.items():
|
60 |
+
y[val] = key
|
61 |
+
|
62 |
+
classifier = KNeighborsClassifier(n_neighbors=1)
|
63 |
+
classifier.fit(smpl.vertices, y)
|
64 |
+
|
65 |
+
y_pred = classifier.predict(recon.vertices)
|
66 |
+
|
67 |
+
recon_labels = {}
|
68 |
+
for key in smpl_vert_segmentation.keys():
|
69 |
+
recon_labels[key] = list(np.argwhere(
|
70 |
+
y_pred == key).flatten().astype(int))
|
71 |
+
|
72 |
+
return recon_labels
|
73 |
+
|
74 |
+
|
75 |
+
def extract_cloth(recon, segmentation, K, R, t, smpl=None):
|
76 |
+
"""
|
77 |
+
Extract a portion of a mesh using 2d segmentation coordinates
|
78 |
+
Arguments:
|
79 |
+
recon: fully clothed mesh
|
80 |
+
seg_coord: segmentation coordinates in 2D (NDC)
|
81 |
+
K: intrinsic matrix of the projection
|
82 |
+
R: rotation matrix of the projection
|
83 |
+
t: translation vector of the projection
|
84 |
+
Returns:
|
85 |
+
Returns a submesh using the segmentation coordinates
|
86 |
+
"""
|
87 |
+
seg_coord = segmentation['coord_normalized']
|
88 |
+
mesh = trimesh.Trimesh(recon.vertices, recon.faces)
|
89 |
+
extrinsic = np.zeros((3, 4))
|
90 |
+
extrinsic[:3, :3] = R
|
91 |
+
extrinsic[:, 3] = t
|
92 |
+
P = K[:3, :3] @ extrinsic
|
93 |
+
|
94 |
+
P_inv = np.linalg.pinv(P)
|
95 |
+
|
96 |
+
# Each segmentation can contain multiple polygons
|
97 |
+
# We need to check them separately
|
98 |
+
points_so_far = []
|
99 |
+
faces = recon.faces
|
100 |
+
for polygon in seg_coord:
|
101 |
+
n = len(polygon)
|
102 |
+
coords_h = np.hstack((polygon, np.ones((n, 1))))
|
103 |
+
# Apply the inverse projection on homogeneus 2D coordinates to get the corresponding 3d Coordinates
|
104 |
+
XYZ = P_inv @ coords_h[:, :, None]
|
105 |
+
XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
|
106 |
+
XYZ = XYZ[:, :3] / XYZ[:, 3, None]
|
107 |
+
|
108 |
+
p = Path(XYZ[:, :2])
|
109 |
+
|
110 |
+
grid = p.contains_points(recon.vertices[:, :2])
|
111 |
+
indeces = np.argwhere(grid == True)
|
112 |
+
points_so_far += list(indeces.flatten())
|
113 |
+
|
114 |
+
if smpl is not None:
|
115 |
+
num_verts = recon.vertices.shape[0]
|
116 |
+
recon_labels = smpl_to_recon_labels(recon, smpl)
|
117 |
+
body_parts_to_remove = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
|
118 |
+
'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand', 'rightHand']
|
119 |
+
type = segmentation['type_id']
|
120 |
+
|
121 |
+
# Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
|
122 |
+
# https://github.com/switchablenorms/DeepFashion2
|
123 |
+
# Short sleeve clothes
|
124 |
+
if type == 1 or type == 3 or type == 10:
|
125 |
+
body_parts_to_remove += ['leftForeArm', 'rightForeArm']
|
126 |
+
# No sleeves at all or lower body clothes
|
127 |
+
elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9:
|
128 |
+
body_parts_to_remove += ['leftForeArm',
|
129 |
+
'rightForeArm', 'leftArm', 'rightArm']
|
130 |
+
# Shorts
|
131 |
+
elif type == 7:
|
132 |
+
body_parts_to_remove += ['leftLeg', 'rightLeg',
|
133 |
+
'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
|
134 |
+
|
135 |
+
verts_to_remove = list(itertools.chain.from_iterable(
|
136 |
+
[recon_labels[part] for part in body_parts_to_remove]))
|
137 |
+
|
138 |
+
label_mask = np.zeros(num_verts, dtype=bool)
|
139 |
+
label_mask[verts_to_remove] = True
|
140 |
+
|
141 |
+
seg_mask = np.zeros(num_verts, dtype=bool)
|
142 |
+
seg_mask[points_so_far] = True
|
143 |
+
|
144 |
+
# Remove points that belong to other bodyparts
|
145 |
+
# If a vertice in pointsSoFar is included in the bodyparts to remove, then these points should be removed
|
146 |
+
extra_verts_to_remove = np.array(list(seg_mask) and list(label_mask))
|
147 |
+
|
148 |
+
combine_mask = np.zeros(num_verts, dtype=bool)
|
149 |
+
combine_mask[points_so_far] = True
|
150 |
+
combine_mask[extra_verts_to_remove] = False
|
151 |
+
|
152 |
+
all_indices = np.argwhere(combine_mask == True).flatten()
|
153 |
+
|
154 |
+
i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
|
155 |
+
i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
|
156 |
+
i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]
|
157 |
+
|
158 |
+
faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
|
159 |
+
mask = np.zeros(len(recon.faces), dtype=bool)
|
160 |
+
if len(faces_to_keep) > 0:
|
161 |
+
mask[faces_to_keep] = True
|
162 |
+
|
163 |
+
mesh.update_faces(mask)
|
164 |
+
mesh.remove_unreferenced_vertices()
|
165 |
+
|
166 |
+
# mesh.rezero()
|
167 |
+
|
168 |
+
return mesh
|
169 |
+
|
170 |
+
return None
|
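A rough usage sketch for the helpers above: it assumes a clothed reconstruction and the fitted SMPL mesh exported by a previous ICON run (paths are illustrative), and only demonstrates the label transfer, not the full garment extraction:

```python
import trimesh
from lib.common.cloth_extraction import smpl_to_recon_labels

# Illustrative outputs of a previous run of apps/infer.py
recon = trimesh.load("./results/icon-filter/obj/example_recon.obj", process=False)
smpl = trimesh.load("./results/icon-filter/obj/example_smpl.glb", process=False, force="mesh")

# Transfer SMPL body-part labels onto the clothed reconstruction via nearest neighbours
labels = smpl_to_recon_labels(recon, smpl)
print(sorted(labels.keys())[:5])
```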
lib/common/config.py
ADDED
@@ -0,0 +1,218 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
from yacs.config import CfgNode as CN
|
19 |
+
import os
|
20 |
+
|
21 |
+
_C = CN(new_allowed=True)
|
22 |
+
|
23 |
+
# needed by trainer
|
24 |
+
_C.name = 'default'
|
25 |
+
_C.gpus = [0]
|
26 |
+
_C.test_gpus = [1]
|
27 |
+
_C.root = "./data/"
|
28 |
+
_C.ckpt_dir = './data/ckpt/'
|
29 |
+
_C.resume_path = ''
|
30 |
+
_C.normal_path = ''
|
31 |
+
_C.corr_path = ''
|
32 |
+
_C.results_path = './data/results/'
|
33 |
+
_C.projection_mode = 'orthogonal'
|
34 |
+
_C.num_views = 1
|
35 |
+
_C.sdf = False
|
36 |
+
_C.sdf_clip = 5.0
|
37 |
+
|
38 |
+
_C.lr_G = 1e-3
|
39 |
+
_C.lr_C = 1e-3
|
40 |
+
_C.lr_N = 2e-4
|
41 |
+
_C.weight_decay = 0.0
|
42 |
+
_C.momentum = 0.0
|
43 |
+
_C.optim = 'RMSprop'
|
44 |
+
_C.schedule = [5, 10, 15]
|
45 |
+
_C.gamma = 0.1
|
46 |
+
|
47 |
+
_C.overfit = False
|
48 |
+
_C.resume = False
|
49 |
+
_C.test_mode = False
|
50 |
+
_C.test_uv = False
|
51 |
+
_C.draw_geo_thres = 0.60
|
52 |
+
_C.num_sanity_val_steps = 2
|
53 |
+
_C.fast_dev = 0
|
54 |
+
_C.get_fit = False
|
55 |
+
_C.agora = False
|
56 |
+
_C.optim_cloth = False
|
57 |
+
_C.optim_body = False
|
58 |
+
_C.mcube_res = 256
|
59 |
+
_C.clean_mesh = True
|
60 |
+
_C.remesh = False
|
61 |
+
|
62 |
+
_C.batch_size = 4
|
63 |
+
_C.num_threads = 8
|
64 |
+
|
65 |
+
_C.num_epoch = 10
|
66 |
+
_C.freq_plot = 0.01
|
67 |
+
_C.freq_show_train = 0.1
|
68 |
+
_C.freq_show_val = 0.2
|
69 |
+
_C.freq_eval = 0.5
|
70 |
+
_C.accu_grad_batch = 4
|
71 |
+
|
72 |
+
_C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt']
|
73 |
+
|
74 |
+
_C.net = CN()
|
75 |
+
_C.net.gtype = 'HGPIFuNet'
|
76 |
+
_C.net.ctype = 'resnet18'
|
77 |
+
_C.net.classifierIMF = 'MultiSegClassifier'
|
78 |
+
_C.net.netIMF = 'resnet18'
|
79 |
+
_C.net.norm = 'group'
|
80 |
+
_C.net.norm_mlp = 'group'
|
81 |
+
_C.net.norm_color = 'group'
|
82 |
+
_C.net.hg_down = 'ave_pool'
|
83 |
+
_C.net.num_views = 1
|
84 |
+
|
85 |
+
# kernel_size, stride, dilation, padding
|
86 |
+
|
87 |
+
_C.net.conv1 = [7, 2, 1, 3]
|
88 |
+
_C.net.conv3x3 = [3, 1, 1, 1]
|
89 |
+
|
90 |
+
_C.net.num_stack = 4
|
91 |
+
_C.net.num_hourglass = 2
|
92 |
+
_C.net.hourglass_dim = 256
|
93 |
+
_C.net.voxel_dim = 32
|
94 |
+
_C.net.resnet_dim = 120
|
95 |
+
_C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
|
96 |
+
_C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
|
97 |
+
_C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
|
98 |
+
_C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
|
99 |
+
_C.net.res_layers = [2, 3, 4]
|
100 |
+
_C.net.filter_dim = 256
|
101 |
+
_C.net.smpl_dim = 3
|
102 |
+
|
103 |
+
_C.net.cly_dim = 3
|
104 |
+
_C.net.soft_dim = 64
|
105 |
+
_C.net.z_size = 200.0
|
106 |
+
_C.net.N_freqs = 10
|
107 |
+
_C.net.geo_w = 0.1
|
108 |
+
_C.net.norm_w = 0.1
|
109 |
+
_C.net.dc_w = 0.1
|
110 |
+
_C.net.C_cat_to_G = False
|
111 |
+
|
112 |
+
_C.net.skip_hourglass = True
|
113 |
+
_C.net.use_tanh = True
|
114 |
+
_C.net.soft_onehot = True
|
115 |
+
_C.net.no_residual = True
|
116 |
+
_C.net.use_attention = False
|
117 |
+
|
118 |
+
_C.net.prior_type = "sdf"
|
119 |
+
_C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
|
120 |
+
_C.net.use_filter = True
|
121 |
+
_C.net.use_cc = False
|
122 |
+
_C.net.use_PE = False
|
123 |
+
_C.net.use_IGR = False
|
124 |
+
_C.net.in_geo = ()
|
125 |
+
_C.net.in_nml = ()
|
126 |
+
|
127 |
+
_C.dataset = CN()
|
128 |
+
_C.dataset.root = ''
|
129 |
+
_C.dataset.set_splits = [0.95, 0.04]
|
130 |
+
_C.dataset.types = [
|
131 |
+
"3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy"
|
132 |
+
]
|
133 |
+
_C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
|
134 |
+
_C.dataset.rp_type = "pifu900"
|
135 |
+
_C.dataset.th_type = 'train'
|
136 |
+
_C.dataset.input_size = 512
|
137 |
+
_C.dataset.rotation_num = 3
|
138 |
+
_C.dataset.num_precomp = 10 # Number of segmentation classifiers
|
139 |
+
_C.dataset.num_multiseg = 500 # Number of categories per classifier
|
140 |
+
_C.dataset.num_knn = 10 # for loss/error
|
141 |
+
_C.dataset.num_knn_dis = 20 # for accuracy
|
142 |
+
_C.dataset.num_verts_max = 20000
|
143 |
+
_C.dataset.zray_type = False
|
144 |
+
_C.dataset.online_smpl = False
|
145 |
+
_C.dataset.noise_type = ['z-trans', 'pose', 'beta']
|
146 |
+
_C.dataset.noise_scale = [0.0, 0.0, 0.0]
|
147 |
+
_C.dataset.num_sample_geo = 10000
|
148 |
+
_C.dataset.num_sample_color = 0
|
149 |
+
_C.dataset.num_sample_seg = 0
|
150 |
+
_C.dataset.num_sample_knn = 10000
|
151 |
+
|
152 |
+
_C.dataset.sigma_geo = 5.0
|
153 |
+
_C.dataset.sigma_color = 0.10
|
154 |
+
_C.dataset.sigma_seg = 0.10
|
155 |
+
_C.dataset.thickness_threshold = 20.0
|
156 |
+
_C.dataset.ray_sample_num = 2
|
157 |
+
_C.dataset.semantic_p = False
|
158 |
+
_C.dataset.remove_outlier = False
|
159 |
+
|
160 |
+
_C.dataset.train_bsize = 1.0
|
161 |
+
_C.dataset.val_bsize = 1.0
|
162 |
+
_C.dataset.test_bsize = 1.0
|
163 |
+
|
164 |
+
|
165 |
+
def get_cfg_defaults():
|
166 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
167 |
+
# Return a clone so that the defaults will not be altered
|
168 |
+
# This is for the "local variable" use pattern
|
169 |
+
return _C.clone()
|
170 |
+
|
171 |
+
|
172 |
+
# Alternatively, provide a way to import the defaults as
|
173 |
+
# a global singleton:
|
174 |
+
cfg = _C # users can `from config import cfg`
|
175 |
+
|
176 |
+
# cfg = get_cfg_defaults()
|
177 |
+
# cfg.merge_from_file('./configs/example.yaml')
|
178 |
+
|
179 |
+
# # Now override from a list (opts could come from the command line)
|
180 |
+
# opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
|
181 |
+
# cfg.merge_from_list(opts)
|
182 |
+
|
183 |
+
|
184 |
+
def update_cfg(cfg_file):
|
185 |
+
# cfg = get_cfg_defaults()
|
186 |
+
_C.merge_from_file(cfg_file)
|
187 |
+
# return cfg.clone()
|
188 |
+
return _C
|
189 |
+
|
190 |
+
|
191 |
+
def parse_args(args):
|
192 |
+
cfg_file = args.cfg_file
|
193 |
+
if args.cfg_file is not None:
|
194 |
+
cfg = update_cfg(args.cfg_file)
|
195 |
+
else:
|
196 |
+
cfg = get_cfg_defaults()
|
197 |
+
|
198 |
+
# if args.misc is not None:
|
199 |
+
# cfg.merge_from_list(args.misc)
|
200 |
+
|
201 |
+
return cfg
|
202 |
+
|
203 |
+
|
204 |
+
def parse_args_extend(args):
|
205 |
+
if args.resume:
|
206 |
+
if not os.path.exists(args.log_dir):
|
207 |
+
raise ValueError(
|
208 |
+
'Experiment are set to resume mode, but log directory does not exist.'
|
209 |
+
)
|
210 |
+
|
211 |
+
# load log's cfg
|
212 |
+
cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
|
213 |
+
cfg = update_cfg(cfg_file)
|
214 |
+
|
215 |
+
if args.misc is not None:
|
216 |
+
cfg.merge_from_list(args.misc)
|
217 |
+
else:
|
218 |
+
parse_args(args)
|
lib/common/render.py
ADDED
@@ -0,0 +1,388 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems. All rights reserved.
|
14 |
+
#
|
15 |
+
# Contact: [email protected]
|
16 |
+
|
17 |
+
from pytorch3d.renderer import (
|
18 |
+
BlendParams,
|
19 |
+
blending,
|
20 |
+
look_at_view_transform,
|
21 |
+
FoVOrthographicCameras,
|
22 |
+
PointLights,
|
23 |
+
RasterizationSettings,
|
24 |
+
PointsRasterizationSettings,
|
25 |
+
PointsRenderer,
|
26 |
+
AlphaCompositor,
|
27 |
+
PointsRasterizer,
|
28 |
+
MeshRenderer,
|
29 |
+
MeshRasterizer,
|
30 |
+
SoftPhongShader,
|
31 |
+
SoftSilhouetteShader,
|
32 |
+
TexturesVertex,
|
33 |
+
)
|
34 |
+
from pytorch3d.renderer.mesh import TexturesVertex
|
35 |
+
from pytorch3d.structures import Meshes
|
36 |
+
|
37 |
+
import os
|
38 |
+
|
39 |
+
from lib.dataset.mesh_util import SMPLX, get_visibility
|
40 |
+
import lib.common.render_utils as util
|
41 |
+
import torch
|
42 |
+
import numpy as np
|
43 |
+
from PIL import Image
|
44 |
+
from tqdm import tqdm
|
45 |
+
import cv2
|
46 |
+
import math
|
47 |
+
from termcolor import colored
|
48 |
+
|
49 |
+
|
50 |
+
def image2vid(images, vid_path):
|
51 |
+
|
52 |
+
w, h = images[0].size
|
53 |
+
videodims = (w, h)
|
54 |
+
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
55 |
+
video = cv2.VideoWriter(vid_path, fourcc, 30, videodims)
|
56 |
+
for image in images:
|
57 |
+
video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
|
58 |
+
video.release()
|
59 |
+
|
60 |
+
|
61 |
+
def query_color(verts, faces, image, device):
|
62 |
+
"""query colors from points and image
|
63 |
+
|
64 |
+
Args:
|
65 |
+
verts ([B, 3]): [query verts]
|
66 |
+
faces ([M, 3]): [query faces]
|
67 |
+
image ([B, 3, H, W]): [full image]
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
[np.float]: [return colors]
|
71 |
+
"""
|
72 |
+
|
73 |
+
verts = verts.float().to(device)
|
74 |
+
faces = faces.long().to(device)
|
75 |
+
|
76 |
+
(xy, z) = verts.split([2, 1], dim=1)
|
77 |
+
visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
|
78 |
+
uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
|
79 |
+
uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
|
80 |
+
colors = (torch.nn.functional.grid_sample(image, uv, align_corners=True)[
|
81 |
+
0, :, :, 0].permute(1, 0) + 1.0) * 0.5 * 255.0
|
82 |
+
colors[visibility == 0.0] = ((Meshes(verts.unsqueeze(0), faces.unsqueeze(
|
83 |
+
0)).verts_normals_padded().squeeze(0) + 1.0) * 0.5 * 255.0)[visibility == 0.0]
|
84 |
+
|
85 |
+
return colors.detach().cpu()
|
86 |
+
|
87 |
+
|
88 |
+
class cleanShader(torch.nn.Module):
|
89 |
+
def __init__(self, device="cpu", cameras=None, blend_params=None):
|
90 |
+
super().__init__()
|
91 |
+
self.cameras = cameras
|
92 |
+
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
93 |
+
|
94 |
+
def forward(self, fragments, meshes, **kwargs):
|
95 |
+
cameras = kwargs.get("cameras", self.cameras)
|
96 |
+
if cameras is None:
|
97 |
+
msg = "Cameras must be specified either at initialization \
|
98 |
+
or in the forward pass of TexturedSoftPhongShader"
|
99 |
+
|
100 |
+
raise ValueError(msg)
|
101 |
+
|
102 |
+
# get renderer output
|
103 |
+
blend_params = kwargs.get("blend_params", self.blend_params)
|
104 |
+
texels = meshes.sample_textures(fragments)
|
105 |
+
images = blending.softmax_rgb_blend(
|
106 |
+
texels, fragments, blend_params, znear=-256, zfar=256
|
107 |
+
)
|
108 |
+
|
109 |
+
return images
|
110 |
+
|
111 |
+
|
112 |
+
class Render:
|
113 |
+
def __init__(self, size=512, device=torch.device("cuda:0")):
|
114 |
+
self.device = device
|
115 |
+
self.mesh_y_center = 100.0
|
116 |
+
self.dis = 100.0
|
117 |
+
self.scale = 1.0
|
118 |
+
self.size = size
|
119 |
+
self.cam_pos = [(0, 100, 100)]
|
120 |
+
|
121 |
+
self.mesh = None
|
122 |
+
self.deform_mesh = None
|
123 |
+
self.pcd = None
|
124 |
+
self.renderer = None
|
125 |
+
self.meshRas = None
|
126 |
+
self.type = None
|
127 |
+
self.knn = None
|
128 |
+
self.knn_inverse = None
|
129 |
+
|
130 |
+
self.smpl_seg = None
|
131 |
+
self.smpl_cmap = None
|
132 |
+
|
133 |
+
self.smplx = SMPLX()
|
134 |
+
|
135 |
+
self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)
|
136 |
+
|
137 |
+
def get_camera(self, cam_id):
|
138 |
+
|
139 |
+
R, T = look_at_view_transform(
|
140 |
+
eye=[self.cam_pos[cam_id]],
|
141 |
+
at=((0, self.mesh_y_center, 0),),
|
142 |
+
up=((0, 1, 0),),
|
143 |
+
)
|
144 |
+
|
145 |
+
camera = FoVOrthographicCameras(
|
146 |
+
device=self.device,
|
147 |
+
R=R,
|
148 |
+
T=T,
|
149 |
+
znear=100.0,
|
150 |
+
zfar=-100.0,
|
151 |
+
max_y=100.0,
|
152 |
+
min_y=-100.0,
|
153 |
+
max_x=100.0,
|
154 |
+
min_x=-100.0,
|
155 |
+
scale_xyz=(self.scale * np.ones(3),),
|
156 |
+
)
|
157 |
+
|
158 |
+
return camera
|
159 |
+
|
160 |
+
def init_renderer(self, camera, type="clean_mesh", bg="gray"):
|
161 |
+
|
162 |
+
if "mesh" in type:
|
163 |
+
|
164 |
+
# rasterizer
|
165 |
+
self.raster_settings_mesh = RasterizationSettings(
|
166 |
+
image_size=self.size,
|
167 |
+
blur_radius=np.log(1.0 / 1e-4) * 1e-7,
|
168 |
+
faces_per_pixel=30,
|
169 |
+
)
|
170 |
+
self.meshRas = MeshRasterizer(
|
171 |
+
cameras=camera, raster_settings=self.raster_settings_mesh
|
172 |
+
)
|
173 |
+
|
174 |
+
if bg == "black":
|
175 |
+
blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
|
176 |
+
elif bg == "white":
|
177 |
+
blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
|
178 |
+
elif bg == "gray":
|
179 |
+
blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))
|
180 |
+
|
181 |
+
if type == "ori_mesh":
|
182 |
+
|
183 |
+
lights = PointLights(
|
184 |
+
device=self.device,
|
185 |
+
ambient_color=((0.8, 0.8, 0.8),),
|
186 |
+
diffuse_color=((0.2, 0.2, 0.2),),
|
187 |
+
specular_color=((0.0, 0.0, 0.0),),
|
188 |
+
location=[[0.0, 200.0, 0.0]],
|
189 |
+
)
|
190 |
+
|
191 |
+
self.renderer = MeshRenderer(
|
192 |
+
rasterizer=self.meshRas,
|
193 |
+
shader=SoftPhongShader(
|
194 |
+
device=self.device,
|
195 |
+
cameras=camera,
|
196 |
+
lights=lights,
|
197 |
+
blend_params=blendparam,
|
198 |
+
),
|
199 |
+
)
|
200 |
+
|
201 |
+
if type == "silhouette":
|
202 |
+
self.raster_settings_silhouette = RasterizationSettings(
|
203 |
+
image_size=self.size,
|
204 |
+
blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
|
205 |
+
faces_per_pixel=50,
|
206 |
+
cull_backfaces=True,
|
207 |
+
)
|
208 |
+
|
209 |
+
self.silhouetteRas = MeshRasterizer(
|
210 |
+
cameras=camera, raster_settings=self.raster_settings_silhouette
|
211 |
+
)
|
212 |
+
self.renderer = MeshRenderer(
|
213 |
+
rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
|
214 |
+
)
|
215 |
+
|
216 |
+
if type == "pointcloud":
|
217 |
+
self.raster_settings_pcd = PointsRasterizationSettings(
|
218 |
+
image_size=self.size, radius=0.006, points_per_pixel=10
|
219 |
+
)
|
220 |
+
|
221 |
+
self.pcdRas = PointsRasterizer(
|
222 |
+
cameras=camera, raster_settings=self.raster_settings_pcd
|
223 |
+
)
|
224 |
+
self.renderer = PointsRenderer(
|
225 |
+
rasterizer=self.pcdRas,
|
226 |
+
compositor=AlphaCompositor(background_color=(0, 0, 0)),
|
227 |
+
)
|
228 |
+
|
229 |
+
if type == "clean_mesh":
|
230 |
+
|
231 |
+
self.renderer = MeshRenderer(
|
232 |
+
rasterizer=self.meshRas,
|
233 |
+
shader=cleanShader(
|
234 |
+
device=self.device, cameras=camera, blend_params=blendparam
|
235 |
+
),
|
236 |
+
)
|
237 |
+
|
238 |
+
def VF2Mesh(self, verts, faces):
|
239 |
+
|
240 |
+
if not torch.is_tensor(verts):
|
241 |
+
verts = torch.tensor(verts)
|
242 |
+
if not torch.is_tensor(faces):
|
243 |
+
faces = torch.tensor(faces)
|
244 |
+
|
245 |
+
if verts.ndimension() == 2:
|
246 |
+
verts = verts.unsqueeze(0).float()
|
247 |
+
if faces.ndimension() == 2:
|
248 |
+
faces = faces.unsqueeze(0).long()
|
249 |
+
|
250 |
+
verts = verts.to(self.device)
|
251 |
+
faces = faces.to(self.device)
|
252 |
+
|
253 |
+
mesh = Meshes(verts, faces).to(self.device)
|
254 |
+
|
255 |
+
mesh.textures = TexturesVertex(
|
256 |
+
verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5
|
257 |
+
)
|
258 |
+
|
259 |
+
return mesh
|
260 |
+
|
261 |
+
def load_meshes(self, verts, faces):
|
262 |
+
"""load mesh into the pytorch3d renderer
|
263 |
+
|
264 |
+
Args:
|
265 |
+
verts ([N,3]): verts
|
266 |
+
faces ([N,3]): faces
|
267 |
+
offset ([N,3]): offset
|
268 |
+
"""
|
269 |
+
|
270 |
+
# camera setting
|
271 |
+
self.scale = 100.0
|
272 |
+
self.mesh_y_center = 0.0
|
273 |
+
|
274 |
+
self.cam_pos = [
|
275 |
+
(0, self.mesh_y_center, 100.0),
|
276 |
+
(100.0, self.mesh_y_center, 0),
|
277 |
+
(0, self.mesh_y_center, -100.0),
|
278 |
+
(-100.0, self.mesh_y_center, 0),
|
279 |
+
]
|
280 |
+
|
281 |
+
self.type = "color"
|
282 |
+
|
283 |
+
if isinstance(verts, list):
|
284 |
+
self.meshes = []
|
285 |
+
for V, F in zip(verts, faces):
|
286 |
+
self.meshes.append(self.VF2Mesh(V, F))
|
287 |
+
else:
|
288 |
+
self.meshes = [self.VF2Mesh(verts, faces)]
|
289 |
+
|
290 |
+
def get_depth_map(self, cam_ids=[0, 2]):
|
291 |
+
|
292 |
+
depth_maps = []
|
293 |
+
for cam_id in cam_ids:
|
294 |
+
self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
|
295 |
+
fragments = self.meshRas(self.meshes[0])
|
296 |
+
depth_map = fragments.zbuf[..., 0].squeeze(0)
|
297 |
+
if cam_id == 2:
|
298 |
+
depth_map = torch.fliplr(depth_map)
|
299 |
+
depth_maps.append(depth_map)
|
300 |
+
|
301 |
+
return depth_maps
|
302 |
+
|
303 |
+
def get_rgb_image(self, cam_ids=[0, 2]):
|
304 |
+
|
305 |
+
images = []
|
306 |
+
for cam_id in range(len(self.cam_pos)):
|
307 |
+
if cam_id in cam_ids:
|
308 |
+
self.init_renderer(self.get_camera(
|
309 |
+
cam_id), "clean_mesh", "gray")
|
310 |
+
if len(cam_ids) == 4:
|
311 |
+
rendered_img = (
|
312 |
+
self.renderer(self.meshes[0])[
|
313 |
+
0:1, :, :, :3].permute(0, 3, 1, 2)
|
314 |
+
- 0.5
|
315 |
+
) * 2.0
|
316 |
+
else:
|
317 |
+
rendered_img = (
|
318 |
+
self.renderer(self.meshes[0])[
|
319 |
+
0:1, :, :, :3].permute(0, 3, 1, 2)
|
320 |
+
- 0.5
|
321 |
+
) * 2.0
|
322 |
+
if cam_id == 2 and len(cam_ids) == 2:
|
323 |
+
rendered_img = torch.flip(rendered_img, dims=[3])
|
324 |
+
images.append(rendered_img)
|
325 |
+
|
326 |
+
return images
|
327 |
+
|
328 |
+
def get_rendered_video(self, images, save_path):
|
329 |
+
|
330 |
+
self.cam_pos = []
|
331 |
+
for angle in range(0, 360, 3):
|
332 |
+
self.cam_pos.append(
|
333 |
+
(
|
334 |
+
100.0 * math.cos(np.pi / 180 * angle),
|
335 |
+
self.mesh_y_center,
|
336 |
+
100.0 * math.sin(np.pi / 180 * angle),
|
337 |
+
)
|
338 |
+
)
|
339 |
+
|
340 |
+
old_shape = np.array(images[0].shape[:2])
|
341 |
+
new_shape = np.around(
|
342 |
+
(self.size / old_shape[0]) * old_shape).astype(np.int)
|
343 |
+
|
344 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
345 |
+
video = cv2.VideoWriter(
|
346 |
+
save_path, fourcc, 30, (self.size * len(self.meshes) +
|
347 |
+
new_shape[1] * len(images), self.size)
|
348 |
+
)
|
349 |
+
|
350 |
+
pbar = tqdm(range(len(self.cam_pos)))
|
351 |
+
pbar.set_description(colored(f"exporting video {os.path.basename(save_path)}...", "blue"))
|
352 |
+
for cam_id in pbar:
|
353 |
+
self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
|
354 |
+
|
355 |
+
img_lst = [
|
356 |
+
np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(np.uint8)[
|
357 |
+
:, :, [2, 1, 0]
|
358 |
+
]
|
359 |
+
for img in images
|
360 |
+
]
|
361 |
+
|
362 |
+
for mesh in self.meshes:
|
363 |
+
rendered_img = (
|
364 |
+
(self.renderer(mesh)[0, :, :, :3] * 255.0)
|
365 |
+
.detach()
|
366 |
+
.cpu()
|
367 |
+
.numpy()
|
368 |
+
.astype(np.uint8)
|
369 |
+
)
|
370 |
+
|
371 |
+
img_lst.append(rendered_img)
|
372 |
+
final_img = np.concatenate(img_lst, axis=1)
|
373 |
+
video.write(final_img)
|
374 |
+
|
375 |
+
video.release()
|
376 |
+
|
377 |
+
def get_silhouette_image(self, cam_ids=[0, 2]):
|
378 |
+
|
379 |
+
images = []
|
380 |
+
for cam_id in range(len(self.cam_pos)):
|
381 |
+
if cam_id in cam_ids:
|
382 |
+
self.init_renderer(self.get_camera(cam_id), "silhouette")
|
383 |
+
rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
|
384 |
+
if cam_id == 2 and len(cam_ids) == 2:
|
385 |
+
rendered_img = torch.flip(rendered_img, dims=[2])
|
386 |
+
images.append(rendered_img)
|
387 |
+
|
388 |
+
return images
|
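A rough usage sketch of the `Render` class above, assuming a CUDA device and a reconstructed mesh from a previous run (the path is illustrative): `load_meshes` places four orthographic cameras around the subject, and the getters below render with the front/back pair:

```python
import torch
import trimesh
from lib.common.render import Render

device = torch.device("cuda:0")
render = Render(size=512, device=device)

mesh = trimesh.load("./results/icon-filter/obj/example_recon.obj")  # illustrative path
render.load_meshes(mesh.vertices, mesh.faces)

normal_F, normal_B = render.get_rgb_image()      # normal-colored renders in [-1, 1]
mask_F, mask_B = render.get_silhouette_image()   # soft front/back silhouettes
```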
lib/common/render_utils.py
ADDED
@@ -0,0 +1,221 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
import trimesh
|
21 |
+
import math
|
22 |
+
from typing import NewType
|
23 |
+
from pytorch3d.structures import Meshes
|
24 |
+
from pytorch3d.renderer.mesh import rasterize_meshes
|
25 |
+
|
26 |
+
Tensor = NewType('Tensor', torch.Tensor)
|
27 |
+
|
28 |
+
|
29 |
+
def solid_angles(points: Tensor,
|
30 |
+
triangles: Tensor,
|
31 |
+
thresh: float = 1e-8) -> Tensor:
|
32 |
+
''' Compute solid angle between the input points and triangles
|
33 |
+
Follows the method described in:
|
34 |
+
The Solid Angle of a Plane Triangle
|
35 |
+
A. VAN OOSTEROM AND J. STRACKEE
|
36 |
+
IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
|
37 |
+
VOL. BME-30, NO. 2, FEBRUARY 1983
|
38 |
+
Parameters
|
39 |
+
-----------
|
40 |
+
points: BxQx3
|
41 |
+
Tensor of input query points
|
42 |
+
triangles: BxFx3x3
|
43 |
+
Target triangles
|
44 |
+
thresh: float
|
45 |
+
float threshold
|
46 |
+
Returns
|
47 |
+
-------
|
48 |
+
solid_angles: BxQxF
|
49 |
+
A tensor containing the solid angle between all query points
|
50 |
+
and input triangles
|
51 |
+
'''
|
52 |
+
# Center the triangles on the query points. Size should be BxQxFx3x3
|
53 |
+
centered_tris = triangles[:, None] - points[:, :, None, None]
|
54 |
+
|
55 |
+
# BxQxFx3
|
56 |
+
norms = torch.norm(centered_tris, dim=-1)
|
57 |
+
|
58 |
+
# Should be BxQxFx3
|
59 |
+
cross_prod = torch.cross(centered_tris[:, :, :, 1],
|
60 |
+
centered_tris[:, :, :, 2],
|
61 |
+
dim=-1)
|
62 |
+
# Should be BxQxF
|
63 |
+
numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
|
64 |
+
del cross_prod
|
65 |
+
|
66 |
+
dot01 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 1]).sum(dim=-1)
|
67 |
+
dot12 = (centered_tris[:, :, :, 1] * centered_tris[:, :, :, 2]).sum(dim=-1)
|
68 |
+
dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
|
69 |
+
del centered_tris
|
70 |
+
|
71 |
+
denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
|
72 |
+
dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
|
73 |
+
del dot01, dot12, dot02, norms
|
74 |
+
|
75 |
+
# Should be BxQxF
|
76 |
+
solid_angle = torch.atan2(numerator, denominator)
|
77 |
+
del numerator, denominator
|
78 |
+
|
79 |
+
torch.cuda.empty_cache()
|
80 |
+
|
81 |
+
return 2 * solid_angle
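For reference, the expression assembled above is the Van Oosterom–Strackee formula cited in the docstring: writing \(\mathbf{r}_1,\mathbf{r}_2,\mathbf{r}_3\) for the triangle vertices re-centered on the query point and \(r_i = \lVert\mathbf{r}_i\rVert\), the code returns

\[
\Omega = 2\,\operatorname{atan2}\!\big(\,\mathbf{r}_1 \cdot (\mathbf{r}_2 \times \mathbf{r}_3),\;
r_1 r_2 r_3 + (\mathbf{r}_1\cdot\mathbf{r}_2)\,r_3 + (\mathbf{r}_1\cdot\mathbf{r}_3)\,r_2 + (\mathbf{r}_2\cdot\mathbf{r}_3)\,r_1 \big),
\]

evaluated for every (query point, triangle) pair at once.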
|
82 |
+
|
83 |
+
|
84 |
+
def winding_numbers(points: Tensor,
|
85 |
+
triangles: Tensor,
|
86 |
+
thresh: float = 1e-8) -> Tensor:
|
87 |
+
''' Compute generalized winding numbers, used for robust inside/outside tests
|
88 |
+
Robust inside-outside segmentation using generalized winding numbers
|
89 |
+
Alec Jacobson,
|
90 |
+
Ladislav Kavan,
|
91 |
+
Olga Sorkine-Hornung
|
92 |
+
Fast Winding Numbers for Soups and Clouds SIGGRAPH 2018
|
93 |
+
Gavin Barill
|
94 |
+
NEIL G. Dickson
|
95 |
+
Ryan Schmidt
|
96 |
+
David I.W. Levin
|
97 |
+
and Alec Jacobson
|
98 |
+
Parameters
|
99 |
+
-----------
|
100 |
+
points: BxQx3
|
101 |
+
Tensor of input query points
|
102 |
+
triangles: BxFx3x3
|
103 |
+
Target triangles
|
104 |
+
thresh: float
|
105 |
+
float threshold
|
106 |
+
Returns
|
107 |
+
-------
|
108 |
+
winding_numbers: BxQ
|
109 |
+
A tensor containing the Generalized winding numbers
|
110 |
+
'''
|
111 |
+
# The generalized winding number is the sum of solid angles of the point
|
112 |
+
# with respect to all triangles.
|
113 |
+
return 1 / (4 * math.pi) * solid_angles(points, triangles,
|
114 |
+
thresh=thresh).sum(dim=-1)
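A minimal usage sketch (assuming this file is importable as lib.common.render_utils; face_vertices is defined further down in this file, and the cube is purely illustrative): winding numbers of query points against a closed, outward-oriented triangle soup, thresholded at 0.5 for inside/outside.

import torch
from lib.common.render_utils import face_vertices, winding_numbers

# unit cube with outward-facing triangles
verts = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.],
                       [0., 0., 1.], [1., 0., 1.], [1., 1., 1.], [0., 1., 1.]]])  # [1, 8, 3]
faces = torch.tensor([[[0, 2, 1], [0, 3, 2], [4, 5, 6], [4, 6, 7],
                       [0, 1, 5], [0, 5, 4], [1, 2, 6], [1, 6, 5],
                       [2, 3, 7], [2, 7, 6], [3, 0, 4], [3, 4, 7]]])              # [1, 12, 3]

triangles = face_vertices(verts, faces)                        # [1, 12, 3, 3]
points = torch.tensor([[[0.5, 0.5, 0.5], [2.0, 2.0, 2.0]]])    # [1, 2, 3]
wn = winding_numbers(points, triangles)                        # ~1 inside, ~0 outside
inside = wn > 0.5                                              # tensor([[ True, False]])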
|
115 |
+
|
116 |
+
|
117 |
+
def batch_contains(verts, faces, points):
|
118 |
+
|
119 |
+
B = verts.shape[0]
|
120 |
+
N = points.shape[1]
|
121 |
+
|
122 |
+
verts = verts.detach().cpu()
|
123 |
+
faces = faces.detach().cpu()
|
124 |
+
points = points.detach().cpu()
|
125 |
+
contains = torch.zeros(B, N)
|
126 |
+
|
127 |
+
for i in range(B):
|
128 |
+
contains[i] = torch.as_tensor(
|
129 |
+
trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
|
130 |
+
|
131 |
+
return 2.0 * (contains - 0.5)
|
132 |
+
|
133 |
+
|
134 |
+
def dict2obj(d):
|
135 |
+
# if isinstance(d, list):
|
136 |
+
# d = [dict2obj(x) for x in d]
|
137 |
+
if not isinstance(d, dict):
|
138 |
+
return d
|
139 |
+
|
140 |
+
class C(object):
|
141 |
+
pass
|
142 |
+
|
143 |
+
o = C()
|
144 |
+
for k in d:
|
145 |
+
o.__dict__[k] = dict2obj(d[k])
|
146 |
+
return o
|
147 |
+
|
148 |
+
|
149 |
+
def face_vertices(vertices, faces):
|
150 |
+
"""
|
151 |
+
:param vertices: [batch size, number of vertices, 3]
|
152 |
+
:param faces: [batch size, number of faces, 3]
|
153 |
+
:return: [batch size, number of faces, 3, 3]
|
154 |
+
"""
|
155 |
+
|
156 |
+
bs, nv = vertices.shape[:2]
|
157 |
+
bs, nf = faces.shape[:2]
|
158 |
+
device = vertices.device
|
159 |
+
faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
|
160 |
+
nv)[:, None, None]
|
161 |
+
vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
|
162 |
+
|
163 |
+
return vertices[faces.long()]
|
164 |
+
|
165 |
+
|
166 |
+
class Pytorch3dRasterizer(nn.Module):
|
167 |
+
""" Borrowed from https://github.com/facebookresearch/pytorch3d
|
168 |
+
Notice:
|
169 |
+
x,y,z are in image space, normalized
|
170 |
+
can only render square images for now
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, image_size=224):
|
174 |
+
"""
|
175 |
+
use fixed raster_settings for rendering faces
|
176 |
+
"""
|
177 |
+
super().__init__()
|
178 |
+
raster_settings = {
|
179 |
+
'image_size': image_size,
|
180 |
+
'blur_radius': 0.0,
|
181 |
+
'faces_per_pixel': 1,
|
182 |
+
'bin_size': None,
|
183 |
+
'max_faces_per_bin': None,
|
184 |
+
'perspective_correct': True,
|
185 |
+
'cull_backfaces': True,
|
186 |
+
}
|
187 |
+
raster_settings = dict2obj(raster_settings)
|
188 |
+
self.raster_settings = raster_settings
|
189 |
+
|
190 |
+
def forward(self, vertices, faces, attributes=None):
|
191 |
+
fixed_vertices = vertices.clone()
|
192 |
+
fixed_vertices[..., :2] = -fixed_vertices[..., :2]
|
193 |
+
meshes_screen = Meshes(verts=fixed_vertices.float(),
|
194 |
+
faces=faces.long())
|
195 |
+
raster_settings = self.raster_settings
|
196 |
+
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
197 |
+
meshes_screen,
|
198 |
+
image_size=raster_settings.image_size,
|
199 |
+
blur_radius=raster_settings.blur_radius,
|
200 |
+
faces_per_pixel=raster_settings.faces_per_pixel,
|
201 |
+
bin_size=raster_settings.bin_size,
|
202 |
+
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
203 |
+
perspective_correct=raster_settings.perspective_correct,
|
204 |
+
)
|
205 |
+
vismask = (pix_to_face > -1).float()
|
206 |
+
D = attributes.shape[-1]
|
207 |
+
attributes = attributes.clone()
|
208 |
+
attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
|
209 |
+
3, attributes.shape[-1])
|
210 |
+
N, H, W, K, _ = bary_coords.shape
|
211 |
+
mask = pix_to_face == -1
|
212 |
+
pix_to_face = pix_to_face.clone()
|
213 |
+
pix_to_face[mask] = 0
|
214 |
+
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
|
215 |
+
pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
|
216 |
+
pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
|
217 |
+
pixel_vals[mask] = 0 # Replace masked values in output.
|
218 |
+
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
|
219 |
+
pixel_vals = torch.cat(
|
220 |
+
[pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
|
221 |
+
return pixel_vals
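A hedged sketch of driving the rasterizer above (assumes pytorch3d is installed; vertices are given in the normalized square image space the class documents, and the per-vertex colors are only an example attribute):

import torch
from lib.common.render_utils import Pytorch3dRasterizer, face_vertices

rasterizer = Pytorch3dRasterizer(image_size=256)

# a single triangle with x, y in [-1, 1] and a constant positive depth
verts = torch.tensor([[[-0.5, -0.5, 0.5],
                       [0.5, -0.5, 0.5],
                       [0.0, 0.5, 0.5]]])      # [1, 3, 3]
faces = torch.tensor([[[0, 1, 2]]])            # [1, 1, 3]
colors = torch.tensor([[[1., 0., 0.],
                        [0., 1., 0.],
                        [0., 0., 1.]]])        # [1, 3, 3] per-vertex attribute

face_attrs = face_vertices(colors, faces)      # [1, 1, 3, 3] attribute per face corner
out = rasterizer(verts, faces, face_attrs)     # [1, 4, 256, 256]
rgb, mask = out[:, :3], out[:, 3:]             # interpolated attributes + visibility mask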
|
lib/common/seg3d_lossless.py
ADDED
@@ -0,0 +1,604 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
|
19 |
+
from .seg3d_utils import (
|
20 |
+
create_grid3D,
|
21 |
+
plot_mask3D,
|
22 |
+
SmoothConv3D,
|
23 |
+
)
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import numpy as np
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import mcubes
|
30 |
+
from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
|
31 |
+
import logging
|
32 |
+
|
33 |
+
logging.getLogger("lightning").setLevel(logging.ERROR)
|
34 |
+
|
35 |
+
|
36 |
+
class Seg3dLossless(nn.Module):
|
37 |
+
def __init__(self,
|
38 |
+
query_func,
|
39 |
+
b_min,
|
40 |
+
b_max,
|
41 |
+
resolutions,
|
42 |
+
channels=1,
|
43 |
+
balance_value=0.5,
|
44 |
+
align_corners=False,
|
45 |
+
visualize=False,
|
46 |
+
debug=False,
|
47 |
+
use_cuda_impl=False,
|
48 |
+
faster=False,
|
49 |
+
use_shadow=False,
|
50 |
+
**kwargs):
|
51 |
+
"""
|
52 |
+
align_corners: keep consistent with how the ground truth was processed (grid_sample / interpolate)
|
53 |
+
"""
|
54 |
+
super().__init__()
|
55 |
+
self.query_func = query_func
|
56 |
+
self.register_buffer(
|
57 |
+
'b_min',
|
58 |
+
torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
|
59 |
+
self.register_buffer(
|
60 |
+
'b_max',
|
61 |
+
torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
|
62 |
+
|
63 |
+
# ti.init(arch=ti.cuda)
|
64 |
+
# self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
|
65 |
+
|
66 |
+
if type(resolutions[0]) is int:
|
67 |
+
resolutions = torch.tensor([(res, res, res)
|
68 |
+
for res in resolutions])
|
69 |
+
else:
|
70 |
+
resolutions = torch.tensor(resolutions)
|
71 |
+
self.register_buffer('resolutions', resolutions)
|
72 |
+
self.batchsize = self.b_min.size(0)
|
73 |
+
assert self.batchsize == 1
|
74 |
+
self.balance_value = balance_value
|
75 |
+
self.channels = channels
|
76 |
+
assert self.channels == 1
|
77 |
+
self.align_corners = align_corners
|
78 |
+
self.visualize = visualize
|
79 |
+
self.debug = debug
|
80 |
+
self.use_cuda_impl = use_cuda_impl
|
81 |
+
self.faster = faster
|
82 |
+
self.use_shadow = use_shadow
|
83 |
+
|
84 |
+
for resolution in resolutions:
|
85 |
+
assert resolution[0] % 2 == 1 and resolution[1] % 2 == 1, \
|
86 |
+
f"resolution {resolution} need to be odd becuase of align_corner."
|
87 |
+
|
88 |
+
# init first resolution
|
89 |
+
init_coords = create_grid3D(0,
|
90 |
+
resolutions[-1] - 1,
|
91 |
+
steps=resolutions[0]) # [N, 3]
|
92 |
+
init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
|
93 |
+
1) # [bz, N, 3]
|
94 |
+
self.register_buffer('init_coords', init_coords)
|
95 |
+
|
96 |
+
# some useful tensors
|
97 |
+
calculated = torch.zeros(
|
98 |
+
(self.resolutions[-1][2], self.resolutions[-1][1],
|
99 |
+
self.resolutions[-1][0]),
|
100 |
+
dtype=torch.bool)
|
101 |
+
self.register_buffer('calculated', calculated)
|
102 |
+
|
103 |
+
gird8_offsets = torch.stack(
|
104 |
+
torch.meshgrid([
|
105 |
+
torch.tensor([-1, 0, 1]),
|
106 |
+
torch.tensor([-1, 0, 1]),
|
107 |
+
torch.tensor([-1, 0, 1])
|
108 |
+
])).int().view(3, -1).t() # [27, 3]
|
109 |
+
self.register_buffer('gird8_offsets', gird8_offsets)
|
110 |
+
|
111 |
+
# smooth convs
|
112 |
+
self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
|
113 |
+
out_channels=1,
|
114 |
+
kernel_size=3)
|
115 |
+
self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
|
116 |
+
out_channels=1,
|
117 |
+
kernel_size=5)
|
118 |
+
self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
|
119 |
+
out_channels=1,
|
120 |
+
kernel_size=7)
|
121 |
+
self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
|
122 |
+
out_channels=1,
|
123 |
+
kernel_size=9)
|
124 |
+
|
125 |
+
def batch_eval(self, coords, **kwargs):
|
126 |
+
"""
|
127 |
+
coords: in the coordinates of last resolution
|
128 |
+
**kwargs: for query_func
|
129 |
+
"""
|
130 |
+
coords = coords.detach()
|
131 |
+
# normalize coords to fit in [b_min, b_max]
|
132 |
+
if self.align_corners:
|
133 |
+
coords2D = coords.float() / (self.resolutions[-1] - 1)
|
134 |
+
else:
|
135 |
+
step = 1.0 / self.resolutions[-1].float()
|
136 |
+
coords2D = coords.float() / self.resolutions[-1] + step / 2
|
137 |
+
coords2D = coords2D * (self.b_max - self.b_min) + self.b_min
|
138 |
+
# query function
|
139 |
+
occupancys = self.query_func(**kwargs, points=coords2D)
|
140 |
+
if type(occupancys) is list:
|
141 |
+
occupancys = torch.stack(occupancys) # [bz, C, N]
|
142 |
+
assert len(occupancys.size()) == 3, \
|
143 |
+
"query_func should return a occupancy with shape of [bz, C, N]"
|
144 |
+
return occupancys
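A hedged sketch of a query_func that satisfies the contract checked above: it is called with points of shape [bz, N, 3] in world coordinates inside [b_min, b_max] and must return occupancies of shape [bz, C, N]. The sphere is only a stand-in for the real implicit network.

import torch

def sphere_query_func(points, **kwargs):
    # occupancy of a sphere of radius 0.5 centred at the origin
    occ = (points.norm(dim=-1) < 0.5).float()   # [bz, N]
    return occ.unsqueeze(1)                     # [bz, 1, N]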
|
145 |
+
|
146 |
+
def forward(self, **kwargs):
|
147 |
+
if self.faster:
|
148 |
+
return self._forward_faster(**kwargs)
|
149 |
+
else:
|
150 |
+
return self._forward(**kwargs)
|
151 |
+
|
152 |
+
def _forward_faster(self, **kwargs):
|
153 |
+
"""
|
154 |
+
In faster mode, we make the following changes to trade accuracy for speed:
|
155 |
+
1. no conflict checking: 4.88 fps -> 6.56 fps
|
156 |
+
2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution
|
157 |
+
3. no conflict examination at the last step
|
158 |
+
"""
|
159 |
+
final_W = self.resolutions[-1][0]
|
160 |
+
final_H = self.resolutions[-1][1]
|
161 |
+
final_D = self.resolutions[-1][2]
|
162 |
+
|
163 |
+
for resolution in self.resolutions:
|
164 |
+
W, H, D = resolution
|
165 |
+
stride = (self.resolutions[-1] - 1) / (resolution - 1)
|
166 |
+
|
167 |
+
# first step
|
168 |
+
if torch.equal(resolution, self.resolutions[0]):
|
169 |
+
coords = self.init_coords.clone() # torch.long
|
170 |
+
occupancys = self.batch_eval(coords, **kwargs)
|
171 |
+
occupancys = occupancys.view(self.batchsize, self.channels, D,
|
172 |
+
H, W)
|
173 |
+
if (occupancys > 0.5).sum() == 0:
|
174 |
+
# return F.interpolate(
|
175 |
+
# occupancys, size=(final_D, final_H, final_W),
|
176 |
+
# mode="linear", align_corners=True)
|
177 |
+
return None
|
178 |
+
|
179 |
+
if self.visualize:
|
180 |
+
self.plot(occupancys, coords, final_D, final_H, final_W)
|
181 |
+
|
182 |
+
with torch.no_grad():
|
183 |
+
coords_accum = coords / stride
|
184 |
+
|
185 |
+
# last step
|
186 |
+
elif torch.equal(resolution, self.resolutions[-1]):
|
187 |
+
|
188 |
+
with torch.no_grad():
|
189 |
+
# here true is correct!
|
190 |
+
valid = F.interpolate(
|
191 |
+
(occupancys > self.balance_value).float(),
|
192 |
+
size=(D, H, W),
|
193 |
+
mode="trilinear",
|
194 |
+
align_corners=True)
|
195 |
+
|
196 |
+
# here true is correct!
|
197 |
+
occupancys = F.interpolate(occupancys.float(),
|
198 |
+
size=(D, H, W),
|
199 |
+
mode="trilinear",
|
200 |
+
align_corners=True)
|
201 |
+
|
202 |
+
# is_boundary = (valid > 0.0) & (valid < 1.0)
|
203 |
+
is_boundary = valid == 0.5
|
204 |
+
|
205 |
+
# next steps
|
206 |
+
else:
|
207 |
+
coords_accum *= 2
|
208 |
+
|
209 |
+
with torch.no_grad():
|
210 |
+
# here true is correct!
|
211 |
+
valid = F.interpolate(
|
212 |
+
(occupancys > self.balance_value).float(),
|
213 |
+
size=(D, H, W),
|
214 |
+
mode="trilinear",
|
215 |
+
align_corners=True)
|
216 |
+
|
217 |
+
# here true is correct!
|
218 |
+
occupancys = F.interpolate(occupancys.float(),
|
219 |
+
size=(D, H, W),
|
220 |
+
mode="trilinear",
|
221 |
+
align_corners=True)
|
222 |
+
|
223 |
+
is_boundary = (valid > 0.0) & (valid < 1.0)
|
224 |
+
|
225 |
+
with torch.no_grad():
|
226 |
+
if torch.equal(resolution, self.resolutions[1]):
|
227 |
+
is_boundary = (self.smooth_conv9x9(is_boundary.float())
|
228 |
+
> 0)[0, 0]
|
229 |
+
elif torch.equal(resolution, self.resolutions[2]):
|
230 |
+
is_boundary = (self.smooth_conv7x7(is_boundary.float())
|
231 |
+
> 0)[0, 0]
|
232 |
+
else:
|
233 |
+
is_boundary = (self.smooth_conv3x3(is_boundary.float())
|
234 |
+
> 0)[0, 0]
|
235 |
+
|
236 |
+
coords_accum = coords_accum.long()
|
237 |
+
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
|
238 |
+
coords_accum[0, :, 0]] = False
|
239 |
+
point_coords = is_boundary.permute(
|
240 |
+
2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
|
241 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
242 |
+
point_coords[:, :, 1] * W +
|
243 |
+
point_coords[:, :, 0])
|
244 |
+
|
245 |
+
R, C, D, H, W = occupancys.shape
|
246 |
+
|
247 |
+
# inferred value
|
248 |
+
coords = point_coords * stride
|
249 |
+
|
250 |
+
if coords.size(1) == 0:
|
251 |
+
continue
|
252 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
253 |
+
|
254 |
+
# put mask point predictions to the right places on the upsampled grid.
|
255 |
+
R, C, D, H, W = occupancys.shape
|
256 |
+
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
257 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
258 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
259 |
+
|
260 |
+
with torch.no_grad():
|
261 |
+
voxels = coords / stride
|
262 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
263 |
+
dim=1).unique(dim=1)
|
264 |
+
|
265 |
+
return occupancys[0, 0]
|
266 |
+
|
267 |
+
def _forward(self, **kwargs):
|
268 |
+
"""
|
269 |
+
output occupancy field would be:
|
270 |
+
(bz, C, res, res, res)
|
271 |
+
"""
|
272 |
+
final_W = self.resolutions[-1][0]
|
273 |
+
final_H = self.resolutions[-1][1]
|
274 |
+
final_D = self.resolutions[-1][2]
|
275 |
+
|
276 |
+
calculated = self.calculated.clone()
|
277 |
+
|
278 |
+
for resolution in self.resolutions:
|
279 |
+
W, H, D = resolution
|
280 |
+
stride = (self.resolutions[-1] - 1) / (resolution - 1)
|
281 |
+
|
282 |
+
if self.visualize:
|
283 |
+
this_stage_coords = []
|
284 |
+
|
285 |
+
# first step
|
286 |
+
if torch.equal(resolution, self.resolutions[0]):
|
287 |
+
coords = self.init_coords.clone() # torch.long
|
288 |
+
occupancys = self.batch_eval(coords, **kwargs)
|
289 |
+
occupancys = occupancys.view(self.batchsize, self.channels, D,
|
290 |
+
H, W)
|
291 |
+
|
292 |
+
if self.visualize:
|
293 |
+
self.plot(occupancys, coords, final_D, final_H, final_W)
|
294 |
+
|
295 |
+
with torch.no_grad():
|
296 |
+
coords_accum = coords / stride
|
297 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
298 |
+
coords[0, :, 0]] = True
|
299 |
+
|
300 |
+
# next steps
|
301 |
+
else:
|
302 |
+
coords_accum *= 2
|
303 |
+
|
304 |
+
with torch.no_grad():
|
305 |
+
# here true is correct!
|
306 |
+
valid = F.interpolate(
|
307 |
+
(occupancys > self.balance_value).float(),
|
308 |
+
size=(D, H, W),
|
309 |
+
mode="trilinear",
|
310 |
+
align_corners=True)
|
311 |
+
|
312 |
+
# here true is correct!
|
313 |
+
occupancys = F.interpolate(occupancys.float(),
|
314 |
+
size=(D, H, W),
|
315 |
+
mode="trilinear",
|
316 |
+
align_corners=True)
|
317 |
+
|
318 |
+
is_boundary = (valid > 0.0) & (valid < 1.0)
|
319 |
+
|
320 |
+
with torch.no_grad():
|
321 |
+
# TODO
|
322 |
+
if self.use_shadow and torch.equal(resolution,
|
323 |
+
self.resolutions[-1]):
|
324 |
+
# larger z means smaller depth here
|
325 |
+
depth_res = resolution[2].item()
|
326 |
+
depth_index = torch.linspace(0,
|
327 |
+
depth_res - 1,
|
328 |
+
steps=depth_res).to(
|
329 |
+
occupancys.device)
|
330 |
+
depth_index_max = torch.max(
|
331 |
+
(occupancys > self.balance_value) *
|
332 |
+
(depth_index + 1),
|
333 |
+
dim=-1,
|
334 |
+
keepdim=True)[0] - 1
|
335 |
+
shadow = depth_index < depth_index_max
|
336 |
+
is_boundary[shadow] = False
|
337 |
+
is_boundary = is_boundary[0, 0]
|
338 |
+
else:
|
339 |
+
is_boundary = (self.smooth_conv3x3(is_boundary.float())
|
340 |
+
> 0)[0, 0]
|
341 |
+
# is_boundary = is_boundary[0, 0]
|
342 |
+
|
343 |
+
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
|
344 |
+
coords_accum[0, :, 0]] = False
|
345 |
+
point_coords = is_boundary.permute(
|
346 |
+
2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
|
347 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
348 |
+
point_coords[:, :, 1] * W +
|
349 |
+
point_coords[:, :, 0])
|
350 |
+
|
351 |
+
R, C, D, H, W = occupancys.shape
|
352 |
+
# interpolated value
|
353 |
+
occupancys_interp = torch.gather(
|
354 |
+
occupancys.reshape(R, C, D * H * W), 2,
|
355 |
+
point_indices.unsqueeze(1))
|
356 |
+
|
357 |
+
# inferred value
|
358 |
+
coords = point_coords * stride
|
359 |
+
|
360 |
+
if coords.size(1) == 0:
|
361 |
+
continue
|
362 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
363 |
+
if self.visualize:
|
364 |
+
this_stage_coords.append(coords)
|
365 |
+
|
366 |
+
# put mask point predictions to the right places on the upsampled grid.
|
367 |
+
R, C, D, H, W = occupancys.shape
|
368 |
+
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
369 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
370 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
371 |
+
|
372 |
+
with torch.no_grad():
|
373 |
+
# conflicts
|
374 |
+
conflicts = ((occupancys_interp - self.balance_value) *
|
375 |
+
(occupancys_topk - self.balance_value) < 0)[0,
|
376 |
+
0]
|
377 |
+
|
378 |
+
if self.visualize:
|
379 |
+
self.plot(occupancys, coords, final_D, final_H,
|
380 |
+
final_W)
|
381 |
+
|
382 |
+
voxels = coords / stride
|
383 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
384 |
+
dim=1).unique(dim=1)
|
385 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
386 |
+
coords[0, :, 0]] = True
|
387 |
+
|
388 |
+
while conflicts.sum() > 0:
|
389 |
+
if self.use_shadow and torch.equal(resolution,
|
390 |
+
self.resolutions[-1]):
|
391 |
+
break
|
392 |
+
|
393 |
+
with torch.no_grad():
|
394 |
+
conflicts_coords = coords[0, conflicts, :]
|
395 |
+
|
396 |
+
if self.debug:
|
397 |
+
self.plot(occupancys,
|
398 |
+
conflicts_coords.unsqueeze(0),
|
399 |
+
final_D,
|
400 |
+
final_H,
|
401 |
+
final_W,
|
402 |
+
title='conflicts')
|
403 |
+
|
404 |
+
conflicts_boundary = (conflicts_coords.int() +
|
405 |
+
self.gird8_offsets.unsqueeze(1) *
|
406 |
+
stride.int()).reshape(
|
407 |
+
-1, 3).long().unique(dim=0)
|
408 |
+
conflicts_boundary[:, 0] = (
|
409 |
+
conflicts_boundary[:, 0].clamp(
|
410 |
+
0,
|
411 |
+
calculated.size(2) - 1))
|
412 |
+
conflicts_boundary[:, 1] = (
|
413 |
+
conflicts_boundary[:, 1].clamp(
|
414 |
+
0,
|
415 |
+
calculated.size(1) - 1))
|
416 |
+
conflicts_boundary[:, 2] = (
|
417 |
+
conflicts_boundary[:, 2].clamp(
|
418 |
+
0,
|
419 |
+
calculated.size(0) - 1))
|
420 |
+
|
421 |
+
coords = conflicts_boundary[calculated[
|
422 |
+
conflicts_boundary[:, 2], conflicts_boundary[:, 1],
|
423 |
+
conflicts_boundary[:, 0]] == False]
|
424 |
+
|
425 |
+
if self.debug:
|
426 |
+
self.plot(occupancys,
|
427 |
+
coords.unsqueeze(0),
|
428 |
+
final_D,
|
429 |
+
final_H,
|
430 |
+
final_W,
|
431 |
+
title='coords')
|
432 |
+
|
433 |
+
coords = coords.unsqueeze(0)
|
434 |
+
point_coords = coords / stride
|
435 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
436 |
+
point_coords[:, :, 1] * W +
|
437 |
+
point_coords[:, :, 0])
|
438 |
+
|
439 |
+
R, C, D, H, W = occupancys.shape
|
440 |
+
# interpolated value
|
441 |
+
occupancys_interp = torch.gather(
|
442 |
+
occupancys.reshape(R, C, D * H * W), 2,
|
443 |
+
point_indices.unsqueeze(1))
|
444 |
+
|
445 |
+
# inferred value
|
446 |
+
coords = point_coords * stride
|
447 |
+
|
448 |
+
if coords.size(1) == 0:
|
449 |
+
break
|
450 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
451 |
+
if self.visualize:
|
452 |
+
this_stage_coords.append(coords)
|
453 |
+
|
454 |
+
with torch.no_grad():
|
455 |
+
# conflicts
|
456 |
+
conflicts = ((occupancys_interp - self.balance_value) *
|
457 |
+
(occupancys_topk - self.balance_value) <
|
458 |
+
0)[0, 0]
|
459 |
+
|
460 |
+
# put mask point predictions to the right places on the upsampled grid.
|
461 |
+
point_indices = point_indices.unsqueeze(1).expand(
|
462 |
+
-1, C, -1)
|
463 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
464 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
465 |
+
|
466 |
+
with torch.no_grad():
|
467 |
+
voxels = coords / stride
|
468 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
469 |
+
dim=1).unique(dim=1)
|
470 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
471 |
+
coords[0, :, 0]] = True
|
472 |
+
|
473 |
+
if self.visualize:
|
474 |
+
this_stage_coords = torch.cat(this_stage_coords, dim=1)
|
475 |
+
self.plot(occupancys, this_stage_coords, final_D, final_H,
|
476 |
+
final_W)
|
477 |
+
|
478 |
+
return occupancys[0, 0]
|
479 |
+
|
480 |
+
def plot(self,
|
481 |
+
occupancys,
|
482 |
+
coords,
|
483 |
+
final_D,
|
484 |
+
final_H,
|
485 |
+
final_W,
|
486 |
+
title='',
|
487 |
+
**kwargs):
|
488 |
+
final = F.interpolate(occupancys.float(),
|
489 |
+
size=(final_D, final_H, final_W),
|
490 |
+
mode="trilinear",
|
491 |
+
align_corners=True) # here true is correct!
|
492 |
+
x = coords[0, :, 0].to("cpu")
|
493 |
+
y = coords[0, :, 1].to("cpu")
|
494 |
+
z = coords[0, :, 2].to("cpu")
|
495 |
+
|
496 |
+
plot_mask3D(final[0, 0].to("cpu"), title, (x, y, z), **kwargs)
|
497 |
+
|
498 |
+
def find_vertices(self, sdf, direction="front"):
|
499 |
+
'''
|
500 |
+
- direction: "front" | "back" | "left" | "right"
|
501 |
+
'''
|
502 |
+
resolution = sdf.size(2)
|
503 |
+
if direction == "front":
|
504 |
+
pass
|
505 |
+
elif direction == "left":
|
506 |
+
sdf = sdf.permute(2, 1, 0)
|
507 |
+
elif direction == "back":
|
508 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
509 |
+
sdf = sdf[inv_idx, :, :]
|
510 |
+
elif direction == "right":
|
511 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
512 |
+
sdf = sdf[:, :, inv_idx]
|
513 |
+
sdf = sdf.permute(2, 1, 0)
|
514 |
+
|
515 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
516 |
+
sdf = sdf[inv_idx, :, :]
|
517 |
+
sdf_all = sdf.permute(2, 1, 0)
|
518 |
+
|
519 |
+
# shadow
|
520 |
+
grad_v = (sdf_all > 0.5) * torch.linspace(
|
521 |
+
resolution, 1, steps=resolution).to(sdf.device)
|
522 |
+
grad_c = torch.ones_like(sdf_all) * torch.linspace(
|
523 |
+
0, resolution - 1, steps=resolution).to(sdf.device)
|
524 |
+
max_v, max_c = grad_v.max(dim=2)
|
525 |
+
shadow = grad_c > max_c.view(resolution, resolution, 1)
|
526 |
+
keep = (sdf_all > 0.5) & (~shadow)
|
527 |
+
|
528 |
+
p1 = keep.nonzero(as_tuple=False).t() # [3, N]
|
529 |
+
p2 = p1.clone() # z
|
530 |
+
p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
|
531 |
+
p3 = p1.clone() # y
|
532 |
+
p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
|
533 |
+
p4 = p1.clone() # x
|
534 |
+
p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
|
535 |
+
|
536 |
+
v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
|
537 |
+
v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]]
|
538 |
+
v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
|
539 |
+
v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
|
540 |
+
|
541 |
+
X = p1[0, :].long() # [N,]
|
542 |
+
Y = p1[1, :].long() # [N,]
|
543 |
+
Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + \
|
544 |
+
p1[2, :].float() * (v2 - 0.5) / (v2 - v1) # [N,]
|
545 |
+
Z = Z.clamp(0, resolution)
|
546 |
+
|
547 |
+
# normal
|
548 |
+
norm_z = v2 - v1
|
549 |
+
norm_y = v3 - v1
|
550 |
+
norm_x = v4 - v1
|
551 |
+
# print (v2.min(dim=0)[0], v2.max(dim=0)[0], v3.min(dim=0)[0], v3.max(dim=0)[0])
|
552 |
+
|
553 |
+
norm = torch.stack([norm_x, norm_y, norm_z], dim=1)
|
554 |
+
norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True)
|
555 |
+
|
556 |
+
return X, Y, Z, norm
|
557 |
+
|
558 |
+
def render_normal(self, resolution, X, Y, Z, norm):
|
559 |
+
image = torch.ones((1, 3, resolution, resolution),
|
560 |
+
dtype=torch.float32).to(norm.device)
|
561 |
+
color = (norm + 1) / 2.0
|
562 |
+
color = color.clamp(0, 1)
|
563 |
+
image[0, :, Y, X] = color.t()
|
564 |
+
return image
|
565 |
+
|
566 |
+
def display(self, sdf):
|
567 |
+
|
568 |
+
# render
|
569 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="front")
|
570 |
+
image1 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
571 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="left")
|
572 |
+
image2 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
573 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="right")
|
574 |
+
image3 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
575 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="back")
|
576 |
+
image4 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
577 |
+
|
578 |
+
image = torch.cat([image1, image2, image3, image4], axis=3)
|
579 |
+
image = image.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0
|
580 |
+
|
581 |
+
return np.uint8(image)
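A hedged sketch of the preview helper above; engine and occupancy_grid are placeholders for a Seg3dLossless instance and the grid returned by its forward pass (see the usage sketch after export_mesh below).

from PIL import Image

preview = engine.display(occupancy_grid)   # uint8 array: front/left/right/back normal maps side by side
Image.fromarray(preview).save("normal_preview.png")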
|
582 |
+
|
583 |
+
def export_mesh(self, occupancys):
|
584 |
+
|
585 |
+
final = occupancys[1:, 1:, 1:].contiguous()
|
586 |
+
|
587 |
+
if final.shape[0] > 256:
|
588 |
+
# for voxelgrid larger than 256^3, the required GPU memory will be > 9GB
|
589 |
+
# thus we use CPU marching_cube to avoid "CUDA out of memory"
|
590 |
+
occu_arr = final.detach().cpu().numpy() # non-smooth surface
|
591 |
+
# occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface
|
592 |
+
vertices, triangles = mcubes.marching_cubes(
|
593 |
+
occu_arr, self.balance_value)
|
594 |
+
verts = torch.as_tensor(vertices[:, [2, 1, 0]])
|
595 |
+
faces = torch.as_tensor(triangles.astype(
|
596 |
+
np.int64), dtype=torch.long)[:, [0, 2, 1]]
|
597 |
+
else:
|
598 |
+
torch.cuda.empty_cache()
|
599 |
+
vertices, triangles = voxelgrids_to_trianglemeshes(
|
600 |
+
final.unsqueeze(0))
|
601 |
+
verts = vertices[0][:, [2, 1, 0]].cpu()
|
602 |
+
faces = triangles[0][:, [0, 2, 1]].cpu()
|
603 |
+
|
604 |
+
return verts, faces
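A hedged end-to-end sketch of this class (assumes kaolin and mcubes are installed, reuses the sphere_query_func stand-in shown after batch_eval above, and should be moved to CUDA for realistic resolutions):

from lib.common.seg3d_lossless import Seg3dLossless

engine = Seg3dLossless(query_func=sphere_query_func,
                       b_min=[[-1.0, -1.0, -1.0]],
                       b_max=[[1.0, 1.0, 1.0]],
                       resolutions=[17, 33, 65, 129, 257],  # odd, as the assert above requires
                       balance_value=0.5,
                       align_corners=True,
                       faster=True)              # faster=True skips the conflict checks

occupancy_grid = engine()                           # [257, 257, 257] occupancy field (None if empty)
verts, faces = engine.export_mesh(occupancy_grid)   # triangle mesh at the finest resolution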
|
lib/common/seg3d_utils.py
ADDED
@@ -0,0 +1,392 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
|
23 |
+
|
24 |
+
def plot_mask2D(mask,
|
25 |
+
title="",
|
26 |
+
point_coords=None,
|
27 |
+
figsize=10,
|
28 |
+
point_marker_size=5):
|
29 |
+
'''
|
30 |
+
Simple plotting tool to show intermediate mask predictions and points
|
31 |
+
where PointRend is applied.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
mask (Tensor): mask prediction of shape HxW
|
35 |
+
title (str): title for the plot
|
36 |
+
point_coords ((Tensor, Tensor)): x and y point coordinates
|
37 |
+
figsize (int): size of the figure to plot
|
38 |
+
point_marker_size (int): marker size for points
|
39 |
+
'''
|
40 |
+
|
41 |
+
H, W = mask.shape
|
42 |
+
plt.figure(figsize=(figsize, figsize))
|
43 |
+
if title:
|
44 |
+
title += ", "
|
45 |
+
plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
|
46 |
+
plt.ylabel(H, fontsize=30)
|
47 |
+
plt.xlabel(W, fontsize=30)
|
48 |
+
plt.xticks([], [])
|
49 |
+
plt.yticks([], [])
|
50 |
+
plt.imshow(mask.detach(),
|
51 |
+
interpolation="nearest",
|
52 |
+
cmap=plt.get_cmap('gray'))
|
53 |
+
if point_coords is not None:
|
54 |
+
plt.scatter(x=point_coords[0],
|
55 |
+
y=point_coords[1],
|
56 |
+
color="red",
|
57 |
+
s=point_marker_size,
|
58 |
+
clip_on=True)
|
59 |
+
plt.xlim(-0.5, W - 0.5)
|
60 |
+
plt.ylim(H - 0.5, -0.5)
|
61 |
+
plt.show()
|
62 |
+
|
63 |
+
|
64 |
+
def plot_mask3D(mask=None,
|
65 |
+
title="",
|
66 |
+
point_coords=None,
|
67 |
+
figsize=1500,
|
68 |
+
point_marker_size=8,
|
69 |
+
interactive=True):
|
70 |
+
'''
|
71 |
+
Simple plotting tool to show intermediate mask predictions and points
|
72 |
+
where PointRend is applied.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
mask (Tensor): mask prediction of shape DxHxW
|
76 |
+
title (str): title for the plot
|
77 |
+
point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates
|
78 |
+
figsize (int): size of the figure to plot
|
79 |
+
point_marker_size (int): marker size for points
|
80 |
+
'''
|
81 |
+
import trimesh
|
82 |
+
import vtkplotter
|
83 |
+
from skimage import measure
|
84 |
+
|
85 |
+
vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
|
86 |
+
vis_list = []
|
87 |
+
|
88 |
+
if mask is not None:
|
89 |
+
mask = mask.detach().to("cpu").numpy()
|
90 |
+
mask = mask.transpose(2, 1, 0)
|
91 |
+
|
92 |
+
# marching cube to find surface
|
93 |
+
verts, faces, normals, values = measure.marching_cubes_lewiner(
|
94 |
+
mask, 0.5, gradient_direction='ascent')
|
95 |
+
|
96 |
+
# create a mesh
|
97 |
+
mesh = trimesh.Trimesh(verts, faces)
|
98 |
+
mesh.visual.face_colors = [200, 200, 250, 100]
|
99 |
+
vis_list.append(mesh)
|
100 |
+
|
101 |
+
if point_coords is not None:
|
102 |
+
point_coords = torch.stack(point_coords, 1).to("cpu").numpy()
|
103 |
+
|
104 |
+
# import numpy as np
|
105 |
+
# select_x = np.logical_and(point_coords[:, 0] >= 16, point_coords[:, 0] <= 112)
|
106 |
+
# select_y = np.logical_and(point_coords[:, 1] >= 48, point_coords[:, 1] <= 272)
|
107 |
+
# select_z = np.logical_and(point_coords[:, 2] >= 16, point_coords[:, 2] <= 112)
|
108 |
+
# select = np.logical_and(np.logical_and(select_x, select_y), select_z)
|
109 |
+
# point_coords = point_coords[select, :]
|
110 |
+
|
111 |
+
pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
|
112 |
+
vis_list.append(pc)
|
113 |
+
|
114 |
+
vp.show(*vis_list,
|
115 |
+
bg="white",
|
116 |
+
axes=1,
|
117 |
+
interactive=interactive,
|
118 |
+
azimuth=30,
|
119 |
+
elevation=30)
|
120 |
+
|
121 |
+
|
122 |
+
def create_grid3D(min, max, steps):
|
123 |
+
if type(min) is int:
|
124 |
+
min = (min, min, min) # (x, y, z)
|
125 |
+
if type(max) is int:
|
126 |
+
max = (max, max, max) # (x, y, z)
|
127 |
+
if type(steps) is int:
|
128 |
+
steps = (steps, steps, steps) # (x, y, z)
|
129 |
+
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
130 |
+
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
131 |
+
arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
|
132 |
+
gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX])
|
133 |
+
coords = torch.stack([gridW, girdH,
|
134 |
+
gridD]) # [3, steps[0], steps[1], steps[2]]
|
135 |
+
coords = coords.view(3, -1).t() # [N, 3]
|
136 |
+
return coords
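A quick, hedged shape check for the helper above: the 9×9×9 lattice of integer coordinates spanning [0, 256] on every axis, the kind of grid that seeds the coarsest Seg3dLossless level.

from lib.common.seg3d_utils import create_grid3D

coords = create_grid3D(0, 256, steps=9)   # [729, 3] long tensor, columns are (x, y, z)
print(coords.shape, coords.min().item(), coords.max().item())   # torch.Size([729, 3]) 0 256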
|
137 |
+
|
138 |
+
|
139 |
+
def create_grid2D(min, max, steps):
|
140 |
+
if type(min) is int:
|
141 |
+
min = (min, min) # (x, y)
|
142 |
+
if type(max) is int:
|
143 |
+
max = (max, max) # (x, y)
|
144 |
+
if type(steps) is int:
|
145 |
+
steps = (steps, steps) # (x, y)
|
146 |
+
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
147 |
+
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
148 |
+
girdH, gridW = torch.meshgrid([arrangeY, arrangeX])
|
149 |
+
coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
|
150 |
+
coords = coords.view(2, -1).t() # [N, 2]
|
151 |
+
return coords
|
152 |
+
|
153 |
+
|
154 |
+
class SmoothConv2D(nn.Module):
|
155 |
+
def __init__(self, in_channels, out_channels, kernel_size=3):
|
156 |
+
super().__init__()
|
157 |
+
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
|
158 |
+
self.padding = (kernel_size - 1) // 2
|
159 |
+
|
160 |
+
weight = torch.ones(
|
161 |
+
(in_channels, out_channels, kernel_size, kernel_size),
|
162 |
+
dtype=torch.float32) / (kernel_size**2)
|
163 |
+
self.register_buffer('weight', weight)
|
164 |
+
|
165 |
+
def forward(self, input):
|
166 |
+
return F.conv2d(input, self.weight, padding=self.padding)
|
167 |
+
|
168 |
+
|
169 |
+
class SmoothConv3D(nn.Module):
|
170 |
+
def __init__(self, in_channels, out_channels, kernel_size=3):
|
171 |
+
super().__init__()
|
172 |
+
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
|
173 |
+
self.padding = (kernel_size - 1) // 2
|
174 |
+
|
175 |
+
weight = torch.ones(
|
176 |
+
(in_channels, out_channels, kernel_size, kernel_size, kernel_size),
|
177 |
+
dtype=torch.float32) / (kernel_size**3)
|
178 |
+
self.register_buffer('weight', weight)
|
179 |
+
|
180 |
+
def forward(self, input):
|
181 |
+
return F.conv3d(input, self.weight, padding=self.padding)
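A hedged sketch of what the box filter above is used for: averaging a binary occupancy grid produces fractional values only near the surface, which marks the band of voxels worth re-evaluating at the next resolution.

import torch
from lib.common.seg3d_utils import SmoothConv3D

occ = torch.zeros(1, 1, 33, 33, 33)
occ[:, :, 10:23, 10:23, 10:23] = 1.0                    # a solid block
smooth = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
band = smooth(occ)                                      # box-filtered occupancy
boundary = (band > 0.0) & (band < 1.0)                  # True only around the block's surface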
|
182 |
+
|
183 |
+
|
184 |
+
def build_smooth_conv3D(in_channels=1,
|
185 |
+
out_channels=1,
|
186 |
+
kernel_size=3,
|
187 |
+
padding=1):
|
188 |
+
smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
|
189 |
+
out_channels=out_channels,
|
190 |
+
kernel_size=kernel_size,
|
191 |
+
padding=padding)
|
192 |
+
smooth_conv.weight.data = torch.ones(
|
193 |
+
(in_channels, out_channels, kernel_size, kernel_size, kernel_size),
|
194 |
+
dtype=torch.float32) / (kernel_size**3)
|
195 |
+
smooth_conv.bias.data = torch.zeros(out_channels)
|
196 |
+
return smooth_conv
|
197 |
+
|
198 |
+
|
199 |
+
def build_smooth_conv2D(in_channels=1,
|
200 |
+
out_channels=1,
|
201 |
+
kernel_size=3,
|
202 |
+
padding=1):
|
203 |
+
smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
|
204 |
+
out_channels=out_channels,
|
205 |
+
kernel_size=kernel_size,
|
206 |
+
padding=padding)
|
207 |
+
smooth_conv.weight.data = torch.ones(
|
208 |
+
(in_channels, out_channels, kernel_size, kernel_size),
|
209 |
+
dtype=torch.float32) / (kernel_size**2)
|
210 |
+
smooth_conv.bias.data = torch.zeros(out_channels)
|
211 |
+
return smooth_conv
|
212 |
+
|
213 |
+
|
214 |
+
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
|
215 |
+
**kwargs):
|
216 |
+
"""
|
217 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
218 |
+
Args:
|
219 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
|
220 |
+
values for a set of points on a regular H x W x D grid.
|
221 |
+
num_points (int): The number of points P to select.
|
222 |
+
Returns:
|
223 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
224 |
+
[0, H x W x D) of the most uncertain points.
|
225 |
+
point_coords (Tensor): A tensor of shape (N, P, 3) that contains the (x, y, z) grid
|
226 |
+
coordinates of the most uncertain points from the H x W x D grid.
|
227 |
+
"""
|
228 |
+
R, _, D, H, W = uncertainty_map.shape
|
229 |
+
# h_step = 1.0 / float(H)
|
230 |
+
# w_step = 1.0 / float(W)
|
231 |
+
# d_step = 1.0 / float(D)
|
232 |
+
|
233 |
+
num_points = min(D * H * W, num_points)
|
234 |
+
point_scores, point_indices = torch.topk(uncertainty_map.view(
|
235 |
+
R, D * H * W),
|
236 |
+
k=num_points,
|
237 |
+
dim=1)
|
238 |
+
point_coords = torch.zeros(R,
|
239 |
+
num_points,
|
240 |
+
3,
|
241 |
+
dtype=torch.float,
|
242 |
+
device=uncertainty_map.device)
|
243 |
+
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
244 |
+
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
245 |
+
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
246 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
|
247 |
+
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
|
248 |
+
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
|
249 |
+
print(f"resolution {D} x {H} x {W}", point_scores.min(),
|
250 |
+
point_scores.max())
|
251 |
+
return point_indices, point_coords
|
252 |
+
|
253 |
+
|
254 |
+
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
|
255 |
+
clip_min):
|
256 |
+
"""
|
257 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
258 |
+
Args:
|
259 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
|
260 |
+
values for a set of points on a regular H x W x D grid.
|
261 |
+
num_points (int): The number of points P to select.
|
262 |
+
Returns:
|
263 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
264 |
+
[0, H x W x D) of the most uncertain points.
|
265 |
+
point_coords (Tensor): A tensor of shape (N, P, 3) that contains the (x, y, z) grid
|
266 |
+
coordinates of the most uncertain points from the H x W x D grid.
|
267 |
+
"""
|
268 |
+
R, _, D, H, W = uncertainty_map.shape
|
269 |
+
# h_step = 1.0 / float(H)
|
270 |
+
# w_step = 1.0 / float(W)
|
271 |
+
# d_step = 1.0 / float(D)
|
272 |
+
|
273 |
+
assert R == 1, "batchsize > 1 is not implemented!"
|
274 |
+
uncertainty_map = uncertainty_map.view(D * H * W)
|
275 |
+
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
276 |
+
num_points = min(num_points, indices.size(0))
|
277 |
+
point_scores, point_indices = torch.topk(uncertainty_map[indices],
|
278 |
+
k=num_points,
|
279 |
+
dim=0)
|
280 |
+
point_indices = indices[point_indices].unsqueeze(0)
|
281 |
+
|
282 |
+
point_coords = torch.zeros(R,
|
283 |
+
num_points,
|
284 |
+
3,
|
285 |
+
dtype=torch.float,
|
286 |
+
device=uncertainty_map.device)
|
287 |
+
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
288 |
+
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
289 |
+
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
290 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
|
291 |
+
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
|
292 |
+
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
|
293 |
+
# print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
|
294 |
+
return point_indices, point_coords
|
295 |
+
|
296 |
+
|
297 |
+
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
|
298 |
+
**kwargs):
|
299 |
+
"""
|
300 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
301 |
+
Args:
|
302 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
|
303 |
+
values for a set of points on a regular H x W grid.
|
304 |
+
num_points (int): The number of points P to select.
|
305 |
+
Returns:
|
306 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
307 |
+
[0, H x W) of the most uncertain points.
|
308 |
+
point_coords (Tensor): A tensor of shape (N, P, 2) that contains the (x, y) grid
|
309 |
+
coordinates of the most uncertain points from the H x W grid.
|
310 |
+
"""
|
311 |
+
R, _, H, W = uncertainty_map.shape
|
312 |
+
# h_step = 1.0 / float(H)
|
313 |
+
# w_step = 1.0 / float(W)
|
314 |
+
|
315 |
+
num_points = min(H * W, num_points)
|
316 |
+
point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
|
317 |
+
k=num_points,
|
318 |
+
dim=1)
|
319 |
+
point_coords = torch.zeros(R,
|
320 |
+
num_points,
|
321 |
+
2,
|
322 |
+
dtype=torch.long,
|
323 |
+
device=uncertainty_map.device)
|
324 |
+
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
325 |
+
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
326 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
327 |
+
point_coords[:, :, 1] = (point_indices // W).to(torch.long)
|
328 |
+
# print (point_scores.min(), point_scores.max())
|
329 |
+
return point_indices, point_coords
|
330 |
+
|
331 |
+
|
332 |
+
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
|
333 |
+
clip_min):
|
334 |
+
"""
|
335 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
336 |
+
Args:
|
337 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
|
338 |
+
values for a set of points on a regular H x W grid.
|
339 |
+
num_points (int): The number of points P to select.
|
340 |
+
Returns:
|
341 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
342 |
+
[0, H x W) of the most uncertain points.
|
343 |
+
point_coords (Tensor): A tensor of shape (N, P, 2) that contains the (x, y) grid
|
344 |
+
coordinates of the most uncertain points from the H x W grid.
|
345 |
+
"""
|
346 |
+
R, _, H, W = uncertainty_map.shape
|
347 |
+
# h_step = 1.0 / float(H)
|
348 |
+
# w_step = 1.0 / float(W)
|
349 |
+
|
350 |
+
assert R == 1, "batchsize > 1 is not implemented!"
|
351 |
+
uncertainty_map = uncertainty_map.view(H * W)
|
352 |
+
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
353 |
+
num_points = min(num_points, indices.size(0))
|
354 |
+
point_scores, point_indices = torch.topk(uncertainty_map[indices],
|
355 |
+
k=num_points,
|
356 |
+
dim=0)
|
357 |
+
point_indices = indices[point_indices].unsqueeze(0)
|
358 |
+
|
359 |
+
point_coords = torch.zeros(R,
|
360 |
+
num_points,
|
361 |
+
2,
|
362 |
+
dtype=torch.long,
|
363 |
+
device=uncertainty_map.device)
|
364 |
+
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
365 |
+
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
366 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
367 |
+
point_coords[:, :, 1] = (point_indices // W).to(torch.long)
|
368 |
+
# print (point_scores.min(), point_scores.max())
|
369 |
+
return point_indices, point_coords
|
370 |
+
|
371 |
+
|
372 |
+
def calculate_uncertainty(logits, classes=None, balance_value=0.5):
|
373 |
+
"""
|
374 |
+
We estimate uncertainty as the L1 distance between the balance value and the prediction in 'logits' for the
|
375 |
+
foreground class in `classes`.
|
376 |
+
Args:
|
377 |
+
logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
|
378 |
+
class-agnostic, where R is the total number of predicted masks in all images and C is
|
379 |
+
the number of foreground classes. The values are logits.
|
380 |
+
classes (list): A list of length R that contains either predicted or ground truth class
|
381 |
+
for each predicted mask.
|
382 |
+
Returns:
|
383 |
+
scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
|
384 |
+
the most uncertain locations having the highest uncertainty score.
|
385 |
+
"""
|
386 |
+
if logits.shape[1] == 1:
|
387 |
+
gt_class_logits = logits
|
388 |
+
else:
|
389 |
+
gt_class_logits = logits[
|
390 |
+
torch.arange(logits.shape[0], device=logits.device),
|
391 |
+
classes].unsqueeze(1)
|
392 |
+
return -torch.abs(gt_class_logits - balance_value)
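A hedged sketch combining the helpers in this file in the PointRend style used above: score every voxel by its distance to the 0.5 decision boundary, then keep the most ambiguous ones for re-querying.

import torch
from lib.common.seg3d_utils import (calculate_uncertainty,
                                    get_uncertain_point_coords_on_grid3D_faster)

occ = torch.rand(1, 1, 17, 17, 17)                            # stand-in occupancy predictions in [0, 1]
uncertainty = calculate_uncertainty(occ, balance_value=0.5)   # -|occ - 0.5|, higher = less certain
idx, pts = get_uncertain_point_coords_on_grid3D_faster(
    uncertainty, num_points=512, clip_min=-0.25)              # keep voxels with |occ - 0.5| <= 0.25
# idx: [1, P] flat indices into the D*H*W grid; pts: [1, P, 3] (x, y, z) grid coordinates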
|
lib/common/smpl_vert_segmentation.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lib/common/train_util.py
ADDED
@@ -0,0 +1,597 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import yaml
|
19 |
+
import os.path as osp
|
20 |
+
import torch
|
21 |
+
import numpy as np
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from ..dataset.mesh_util import *
|
24 |
+
from ..net.geometry import orthogonal
|
25 |
+
from pytorch3d.renderer.mesh import rasterize_meshes
|
26 |
+
from .render_utils import Pytorch3dRasterizer
|
27 |
+
from pytorch3d.structures import Meshes
|
28 |
+
import cv2
|
29 |
+
from PIL import Image
|
30 |
+
from tqdm import tqdm
|
31 |
+
import os
|
32 |
+
from termcolor import colored
|
33 |
+
|
34 |
+
|
35 |
+
def reshape_sample_tensor(sample_tensor, num_views):
|
36 |
+
if num_views == 1:
|
37 |
+
return sample_tensor
|
38 |
+
# Need to repeat sample_tensor along the batch dim num_views times
|
39 |
+
sample_tensor = sample_tensor.unsqueeze(dim=1)
|
40 |
+
sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
|
41 |
+
sample_tensor = sample_tensor.view(
|
42 |
+
sample_tensor.shape[0] * sample_tensor.shape[1],
|
43 |
+
sample_tensor.shape[2], sample_tensor.shape[3])
|
44 |
+
return sample_tensor
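A hedged shape sketch for the helper above (assuming the repo's dependencies resolve so the module imports cleanly): with num_views = 2 each batch of query points is duplicated so it lines up view-by-view with the stacked input images.

import torch
from lib.common.train_util import reshape_sample_tensor

pts = torch.rand(4, 3, 1000)                     # [B, 3, N] query points
out = reshape_sample_tensor(pts, num_views=2)    # [B * num_views, 3, N] -> [8, 3, 1000]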
|
45 |
+
|
46 |
+
|
47 |
+
def gen_mesh_eval(opt, net, cuda, data, resolution=None):
|
48 |
+
resolution = opt.resolution if resolution is None else resolution
|
49 |
+
image_tensor = data['img'].to(device=cuda)
|
50 |
+
calib_tensor = data['calib'].to(device=cuda)
|
51 |
+
|
52 |
+
net.filter(image_tensor)
|
53 |
+
|
54 |
+
b_min = data['b_min']
|
55 |
+
b_max = data['b_max']
|
56 |
+
try:
|
57 |
+
verts, faces, _, _ = reconstruction_faster(net,
|
58 |
+
cuda,
|
59 |
+
calib_tensor,
|
60 |
+
resolution,
|
61 |
+
b_min,
|
62 |
+
b_max,
|
63 |
+
use_octree=False)
|
64 |
+
|
65 |
+
except Exception as e:
|
66 |
+
print(e)
|
67 |
+
print('Can not create marching cubes at this time.')
|
68 |
+
verts, faces = None, None
|
69 |
+
return verts, faces
|
70 |
+
|
71 |
+
|
72 |
+
def gen_mesh(opt, net, cuda, data, save_path, resolution=None):
|
73 |
+
resolution = opt.resolution if resolution is None else resolution
|
74 |
+
image_tensor = data['img'].to(device=cuda)
|
75 |
+
calib_tensor = data['calib'].to(device=cuda)
|
76 |
+
|
77 |
+
net.filter(image_tensor)
|
78 |
+
|
79 |
+
b_min = data['b_min']
|
80 |
+
b_max = data['b_max']
|
81 |
+
try:
|
82 |
+
save_img_path = save_path[:-4] + '.png'
|
83 |
+
save_img_list = []
|
84 |
+
for v in range(image_tensor.shape[0]):
|
85 |
+
save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
|
86 |
+
(1, 2, 0)) * 0.5 +
|
87 |
+
0.5)[:, :, ::-1] * 255.0
|
88 |
+
save_img_list.append(save_img)
|
89 |
+
save_img = np.concatenate(save_img_list, axis=1)
|
90 |
+
Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
|
91 |
+
|
92 |
+
verts, faces, _, _ = reconstruction_faster(net, cuda, calib_tensor,
|
93 |
+
resolution, b_min, b_max)
|
94 |
+
verts_tensor = torch.from_numpy(
|
95 |
+
verts.T).unsqueeze(0).to(device=cuda).float()
|
96 |
+
xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
|
97 |
+
uv = xyz_tensor[:, :2, :]
|
98 |
+
color = net.index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
|
99 |
+
color = color * 0.5 + 0.5
|
100 |
+
save_obj_mesh_with_color(save_path, verts, faces, color)
|
101 |
+
except Exception as e:
|
102 |
+
print(e)
|
103 |
+
print('Can not create marching cubes at this time.')
|
104 |
+
verts, faces, color = None, None, None
|
105 |
+
return verts, faces, color
|
106 |
+
|
107 |
+
|
108 |
+
def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
|
109 |
+
image_tensor = data['img'].to(device=cuda)
|
110 |
+
calib_tensor = data['calib'].to(device=cuda)
|
111 |
+
|
112 |
+
netG.filter(image_tensor)
|
113 |
+
netC.filter(image_tensor)
|
114 |
+
netC.attach(netG.get_im_feat())
|
115 |
+
|
116 |
+
b_min = data['b_min']
|
117 |
+
b_max = data['b_max']
|
118 |
+
try:
|
119 |
+
save_img_path = save_path[:-4] + '.png'
|
120 |
+
save_img_list = []
|
121 |
+
for v in range(image_tensor.shape[0]):
|
122 |
+
save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
|
123 |
+
(1, 2, 0)) * 0.5 +
|
124 |
+
0.5)[:, :, ::-1] * 255.0
|
125 |
+
save_img_list.append(save_img)
|
126 |
+
save_img = np.concatenate(save_img_list, axis=1)
|
127 |
+
Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
|
128 |
+
|
129 |
+
verts, faces, _, _ = reconstruction_faster(netG,
|
130 |
+
cuda,
|
131 |
+
calib_tensor,
|
132 |
+
opt.resolution,
|
133 |
+
b_min,
|
134 |
+
b_max,
|
135 |
+
use_octree=use_octree)
|
136 |
+
|
137 |
+
# Now Getting colors
|
138 |
+
verts_tensor = torch.from_numpy(
|
139 |
+
verts.T).unsqueeze(0).to(device=cuda).float()
|
140 |
+
verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
|
141 |
+
color = np.zeros(verts.shape)
|
142 |
+
interval = 10000
|
143 |
+
for i in range(len(color) // interval):
|
144 |
+
left = i * interval
|
145 |
+
right = i * interval + interval
|
146 |
+
if i == len(color) // interval - 1:
|
147 |
+
right = -1
|
148 |
+
netC.query(verts_tensor[:, :, left:right], calib_tensor)
|
149 |
+
rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
|
150 |
+
color[left:right] = rgb.T
|
151 |
+
|
152 |
+
save_obj_mesh_with_color(save_path, verts, faces, color)
|
153 |
+
except Exception as e:
|
154 |
+
print(e)
|
155 |
+
print('Can not create marching cubes at this time.')
|
156 |
+
verts, faces, color = None, None, None
|
157 |
+
return verts, faces, color
|
158 |
+
|
159 |
+
|
160 |
+
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
|
161 |
+
"""Sets the learning rate to the initial LR decayed by schedule"""
|
162 |
+
if epoch in schedule:
|
163 |
+
lr *= gamma
|
164 |
+
for param_group in optimizer.param_groups:
|
165 |
+
param_group['lr'] = lr
|
166 |
+
return lr
|
167 |
+
|
168 |
+
|
169 |
+
def compute_acc(pred, gt, thresh=0.5):
|
170 |
+
'''
|
171 |
+
return:
|
172 |
+
IOU, precision, and recall
|
173 |
+
'''
|
174 |
+
with torch.no_grad():
|
175 |
+
vol_pred = pred > thresh
|
176 |
+
vol_gt = gt > thresh
|
177 |
+
|
178 |
+
union = vol_pred | vol_gt
|
179 |
+
inter = vol_pred & vol_gt
|
180 |
+
|
181 |
+
true_pos = inter.sum().float()
|
182 |
+
|
183 |
+
union = union.sum().float()
|
184 |
+
if union == 0:
|
185 |
+
union = 1
|
186 |
+
vol_pred = vol_pred.sum().float()
|
187 |
+
if vol_pred == 0:
|
188 |
+
vol_pred = 1
|
189 |
+
vol_gt = vol_gt.sum().float()
|
190 |
+
if vol_gt == 0:
|
191 |
+
vol_gt = 1
|
192 |
+
return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
|
193 |
+
|
194 |
+
|
195 |
+
# def calc_metrics(opt, net, cuda, dataset, num_tests,
|
196 |
+
# resolution=128, sampled_points=1000, use_kaolin=True):
|
197 |
+
# if num_tests > len(dataset):
|
198 |
+
# num_tests = len(dataset)
|
199 |
+
# with torch.no_grad():
|
200 |
+
# chamfer_arr, p2s_arr = [], []
|
201 |
+
# for idx in tqdm(range(num_tests)):
|
202 |
+
# data = dataset[idx * len(dataset) // num_tests]
|
203 |
+
|
204 |
+
# verts, faces = gen_mesh_eval(opt, net, cuda, data, resolution)
|
205 |
+
# if verts is None:
|
206 |
+
# continue
|
207 |
+
|
208 |
+
# mesh_gt = trimesh.load(data['mesh_path'])
|
209 |
+
# mesh_gt = mesh_gt.split(only_watertight=False)
|
210 |
+
# comp_num = [mesh.vertices.shape[0] for mesh in mesh_gt]
|
211 |
+
# mesh_gt = mesh_gt[comp_num.index(max(comp_num))]
|
212 |
+
|
213 |
+
# mesh_pred = trimesh.Trimesh(verts, faces)
|
214 |
+
|
215 |
+
# gt_surface_pts, _ = trimesh.sample.sample_surface_even(
|
216 |
+
# mesh_gt, sampled_points)
|
217 |
+
# pred_surface_pts, _ = trimesh.sample.sample_surface_even(
|
218 |
+
# mesh_pred, sampled_points)
|
219 |
+
|
220 |
+
# if use_kaolin and has_kaolin:
|
221 |
+
# kal_mesh_gt = kal.rep.TriangleMesh.from_tensors(
|
222 |
+
# torch.tensor(mesh_gt.vertices).float().to(device=cuda),
|
223 |
+
# torch.tensor(mesh_gt.faces).long().to(device=cuda))
|
224 |
+
# kal_mesh_pred = kal.rep.TriangleMesh.from_tensors(
|
225 |
+
# torch.tensor(mesh_pred.vertices).float().to(device=cuda),
|
226 |
+
# torch.tensor(mesh_pred.faces).long().to(device=cuda))
|
227 |
+
|
228 |
+
# kal_distance_0 = kal.metrics.mesh.point_to_surface(
|
229 |
+
# torch.tensor(pred_surface_pts).float().to(device=cuda), kal_mesh_gt)
|
230 |
+
# kal_distance_1 = kal.metrics.mesh.point_to_surface(
|
231 |
+
# torch.tensor(gt_surface_pts).float().to(device=cuda), kal_mesh_pred)
|
232 |
+
|
233 |
+
# dist_gt_pred = torch.sqrt(kal_distance_0).cpu().numpy()
|
234 |
+
# dist_pred_gt = torch.sqrt(kal_distance_1).cpu().numpy()
|
235 |
+
# else:
|
236 |
+
# try:
|
237 |
+
# _, dist_pred_gt, _ = trimesh.proximity.closest_point(mesh_pred, gt_surface_pts)
|
238 |
+
# _, dist_gt_pred, _ = trimesh.proximity.closest_point(mesh_gt, pred_surface_pts)
|
239 |
+
# except Exception as e:
|
240 |
+
# print (e)
|
241 |
+
# continue
|
242 |
+
|
243 |
+
# chamfer_dist = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean())
|
244 |
+
# p2s_dist = dist_pred_gt.mean()
|
245 |
+
|
246 |
+
# chamfer_arr.append(chamfer_dist)
|
247 |
+
# p2s_arr.append(p2s_dist)
|
248 |
+
|
249 |
+
# return np.average(chamfer_arr), np.average(p2s_arr)
|
250 |
+
|
251 |
+
|
252 |
+
def calc_error(opt, net, cuda, dataset, num_tests):
|
253 |
+
if num_tests > len(dataset):
|
254 |
+
num_tests = len(dataset)
|
255 |
+
with torch.no_grad():
|
256 |
+
erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
|
257 |
+
for idx in tqdm(range(num_tests)):
|
258 |
+
data = dataset[idx * len(dataset) // num_tests]
|
259 |
+
# retrieve the data
|
260 |
+
image_tensor = data['img'].to(device=cuda)
|
261 |
+
calib_tensor = data['calib'].to(device=cuda)
|
262 |
+
sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
|
263 |
+
if opt.num_views > 1:
|
264 |
+
sample_tensor = reshape_sample_tensor(sample_tensor,
|
265 |
+
opt.num_views)
|
266 |
+
label_tensor = data['labels'].to(device=cuda).unsqueeze(0)
|
267 |
+
|
268 |
+
res, error = net.forward(image_tensor,
|
269 |
+
sample_tensor,
|
270 |
+
calib_tensor,
|
271 |
+
labels=label_tensor)
|
272 |
+
|
273 |
+
IOU, prec, recall = compute_acc(res, label_tensor)
|
274 |
+
|
275 |
+
# print(
|
276 |
+
# '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
|
277 |
+
# .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
|
278 |
+
erorr_arr.append(error.item())
|
279 |
+
IOU_arr.append(IOU.item())
|
280 |
+
prec_arr.append(prec.item())
|
281 |
+
recall_arr.append(recall.item())
|
282 |
+
|
283 |
+
return np.average(erorr_arr), np.average(IOU_arr), np.average(
|
284 |
+
prec_arr), np.average(recall_arr)
|
285 |
+
|
286 |
+
|
287 |
+
def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
|
288 |
+
if num_tests > len(dataset):
|
289 |
+
num_tests = len(dataset)
|
290 |
+
with torch.no_grad():
|
291 |
+
error_color_arr = []
|
292 |
+
|
293 |
+
for idx in tqdm(range(num_tests)):
|
294 |
+
data = dataset[idx * len(dataset) // num_tests]
|
295 |
+
# retrieve the data
|
296 |
+
image_tensor = data['img'].to(device=cuda)
|
297 |
+
calib_tensor = data['calib'].to(device=cuda)
|
298 |
+
color_sample_tensor = data['color_samples'].to(
|
299 |
+
device=cuda).unsqueeze(0)
|
300 |
+
|
301 |
+
if opt.num_views > 1:
|
302 |
+
color_sample_tensor = reshape_sample_tensor(
|
303 |
+
color_sample_tensor, opt.num_views)
|
304 |
+
|
305 |
+
rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)
|
306 |
+
|
307 |
+
netG.filter(image_tensor)
|
308 |
+
_, errorC = netC.forward(image_tensor,
|
309 |
+
netG.get_im_feat(),
|
310 |
+
color_sample_tensor,
|
311 |
+
calib_tensor,
|
312 |
+
labels=rgb_tensor)
|
313 |
+
|
314 |
+
# print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
|
315 |
+
# .format(idx, num_tests, errorG.item(), errorC.item()))
|
316 |
+
error_color_arr.append(errorC.item())
|
317 |
+
|
318 |
+
return np.average(error_color_arr)
|
319 |
+
|
320 |
+
|
321 |
+
# pytorch lightning training related fucntions
|
322 |
+
|
323 |
+
|
324 |
+
def query_func(opt, netG, features, points, proj_matrix=None):
|
325 |
+
'''
|
326 |
+
- points: size of (bz, N, 3)
|
327 |
+
- proj_matrix: size of (bz, 4, 4)
|
328 |
+
return: size of (bz, 1, N)
|
329 |
+
'''
|
330 |
+
assert len(points) == 1
|
331 |
+
samples = points.repeat(opt.num_views, 1, 1)
|
332 |
+
samples = samples.permute(0, 2, 1) # [bz, 3, N]
|
333 |
+
|
334 |
+
# view specific query
|
335 |
+
if proj_matrix is not None:
|
336 |
+
samples = orthogonal(samples, proj_matrix)
|
337 |
+
|
338 |
+
calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples)
|
339 |
+
|
340 |
+
preds = netG.query(features=features,
|
341 |
+
points=samples,
|
342 |
+
calibs=calib_tensor,
|
343 |
+
regressor=netG.if_regressor)
|
344 |
+
|
345 |
+
if type(preds) is list:
|
346 |
+
preds = preds[0]
|
347 |
+
|
348 |
+
return preds
|
349 |
+
|
350 |
+
|
351 |
+
def isin(ar1, ar2):
|
352 |
+
return (ar1[..., None] == ar2).any(-1)
|
353 |
+
|
354 |
+
|
355 |
+
def in1d(ar1, ar2):
|
356 |
+
mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
|
357 |
+
mask[ar2.unique()] = True
|
358 |
+
return mask[ar1]
|
359 |
+
|
360 |
+
|
361 |
+
def get_visibility(xy, z, faces):
|
362 |
+
"""get the visibility of vertices
|
363 |
+
|
364 |
+
Args:
|
365 |
+
xy (torch.tensor): [N,2]
|
366 |
+
z (torch.tensor): [N,1]
|
367 |
+
faces (torch.tensor): [N,3]
|
368 |
+
size (int): resolution of rendered image
|
369 |
+
"""
|
370 |
+
|
371 |
+
xyz = torch.cat((xy, -z), dim=1)
|
372 |
+
xyz = (xyz + 1.0) / 2.0
|
373 |
+
faces = faces.long()
|
374 |
+
|
375 |
+
rasterizer = Pytorch3dRasterizer(image_size=2**12)
|
376 |
+
meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
|
377 |
+
raster_settings = rasterizer.raster_settings
|
378 |
+
|
379 |
+
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
380 |
+
meshes_screen,
|
381 |
+
image_size=raster_settings.image_size,
|
382 |
+
blur_radius=raster_settings.blur_radius,
|
383 |
+
faces_per_pixel=raster_settings.faces_per_pixel,
|
384 |
+
bin_size=raster_settings.bin_size,
|
385 |
+
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
386 |
+
perspective_correct=raster_settings.perspective_correct,
|
387 |
+
cull_backfaces=raster_settings.cull_backfaces,
|
388 |
+
)
|
389 |
+
|
390 |
+
vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
|
391 |
+
vis_mask = torch.zeros(size=(z.shape[0], 1))
|
392 |
+
vis_mask[vis_vertices_id] = 1.0
|
393 |
+
|
394 |
+
# print("------------------------\n")
|
395 |
+
# print(f"keep points : {vis_mask.sum()/len(vis_mask)}")
|
396 |
+
|
397 |
+
return vis_mask
|
398 |
+
|
399 |
+
|
400 |
+
def batch_mean(res, key):
|
401 |
+
# recursive mean for multilevel dicts
|
402 |
+
return torch.stack([
|
403 |
+
x[key] if isinstance(x, dict) else batch_mean(x, key) for x in res
|
404 |
+
]).mean()
|
405 |
+
|
406 |
+
|
407 |
+
def tf_log_convert(log_dict):
|
408 |
+
new_log_dict = log_dict.copy()
|
409 |
+
for k, v in log_dict.items():
|
410 |
+
new_log_dict[k.replace("_", "/")] = v
|
411 |
+
del new_log_dict[k]
|
412 |
+
|
413 |
+
return new_log_dict
|
414 |
+
|
415 |
+
|
416 |
+
def bar_log_convert(log_dict, name=None, rot=None):
|
417 |
+
from decimal import Decimal
|
418 |
+
|
419 |
+
new_log_dict = {}
|
420 |
+
|
421 |
+
if name is not None:
|
422 |
+
new_log_dict['name'] = name[0]
|
423 |
+
if rot is not None:
|
424 |
+
new_log_dict['rot'] = rot[0]
|
425 |
+
|
426 |
+
for k, v in log_dict.items():
|
427 |
+
color = "yellow"
|
428 |
+
if 'loss' in k:
|
429 |
+
color = "red"
|
430 |
+
k = k.replace("loss", "L")
|
431 |
+
elif 'acc' in k:
|
432 |
+
color = "green"
|
433 |
+
k = k.replace("acc", "A")
|
434 |
+
elif 'iou' in k:
|
435 |
+
color = "green"
|
436 |
+
k = k.replace("iou", "I")
|
437 |
+
elif 'prec' in k:
|
438 |
+
color = "green"
|
439 |
+
k = k.replace("prec", "P")
|
440 |
+
elif 'recall' in k:
|
441 |
+
color = "green"
|
442 |
+
k = k.replace("recall", "R")
|
443 |
+
|
444 |
+
if 'lr' not in k:
|
445 |
+
new_log_dict[colored(k.split("_")[1],
|
446 |
+
color)] = colored(f"{v:.3f}", color)
|
447 |
+
else:
|
448 |
+
new_log_dict[colored(k.split("_")[1],
|
449 |
+
color)] = colored(f"{Decimal(str(v)):.1E}",
|
450 |
+
color)
|
451 |
+
|
452 |
+
if 'loss' in new_log_dict.keys():
|
453 |
+
del new_log_dict['loss']
|
454 |
+
|
455 |
+
return new_log_dict
|
456 |
+
|
457 |
+
|
458 |
+
def accumulate(outputs, rot_num, split):
|
459 |
+
|
460 |
+
hparam_log_dict = {}
|
461 |
+
|
462 |
+
metrics = outputs[0].keys()
|
463 |
+
datasets = split.keys()
|
464 |
+
|
465 |
+
for dataset in datasets:
|
466 |
+
for metric in metrics:
|
467 |
+
keyword = f"hparam/{dataset}-{metric}"
|
468 |
+
if keyword not in hparam_log_dict.keys():
|
469 |
+
hparam_log_dict[keyword] = 0
|
470 |
+
for idx in range(split[dataset][0] * rot_num,
|
471 |
+
split[dataset][1] * rot_num):
|
472 |
+
hparam_log_dict[keyword] += outputs[idx][metric]
|
473 |
+
hparam_log_dict[keyword] /= (split[dataset][1] -
|
474 |
+
split[dataset][0]) * rot_num
|
475 |
+
|
476 |
+
print(colored(hparam_log_dict, "green"))
|
477 |
+
|
478 |
+
return hparam_log_dict
|
479 |
+
|
480 |
+
|
481 |
+
def calc_error_N(outputs, targets):
|
482 |
+
"""calculate the error of normal (IGR)
|
483 |
+
|
484 |
+
Args:
|
485 |
+
outputs (torch.tensor): [B, 3, N]
|
486 |
+
target (torch.tensor): [B, N, 3]
|
487 |
+
|
488 |
+
# manifold loss and grad_loss in IGR paper
|
489 |
+
grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
|
490 |
+
normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
|
491 |
+
|
492 |
+
Returns:
|
493 |
+
torch.tensor: error of valid normals on the surface
|
494 |
+
"""
|
495 |
+
# outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
|
496 |
+
outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
|
497 |
+
targets = targets.reshape(-1, 3)[:, 2:3]
|
498 |
+
with_normals = targets.sum(dim=1).abs() > 0.0
|
499 |
+
|
500 |
+
# eikonal loss
|
501 |
+
grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
|
502 |
+
# normals loss
|
503 |
+
normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()
|
504 |
+
|
505 |
+
return grad_loss * 0.0 + normal_loss
|
506 |
+
|
507 |
+
|
508 |
+
def calc_knn_acc(preds, carn_verts, labels, pick_num):
|
509 |
+
"""calculate knn accuracy
|
510 |
+
|
511 |
+
Args:
|
512 |
+
preds (torch.tensor): [B, 3, N]
|
513 |
+
carn_verts (torch.tensor): [SMPLX_V_num, 3]
|
514 |
+
labels (torch.tensor): [B, N_knn, N]
|
515 |
+
"""
|
516 |
+
N_knn_full = labels.shape[1]
|
517 |
+
preds = preds.permute(0, 2, 1).reshape(-1, 3)
|
518 |
+
labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn]
|
519 |
+
labels = labels[:, :pick_num]
|
520 |
+
|
521 |
+
dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num]
|
522 |
+
knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn]
|
523 |
+
cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
|
524 |
+
bool_col = torch.zeros_like(cat_mat)[:, 0]
|
525 |
+
for i in range(pick_num * 2 - 1):
|
526 |
+
bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
|
527 |
+
acc = (bool_col > 0).sum() / len(bool_col)
|
528 |
+
|
529 |
+
return acc
|
530 |
+
|
531 |
+
|
532 |
+
def calc_acc_seg(output, target, num_multiseg):
|
533 |
+
from pytorch_lightning.metrics import Accuracy
|
534 |
+
return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
|
535 |
+
target.flatten().cpu())
|
536 |
+
|
537 |
+
|
538 |
+
def add_watermark(imgs, titles):
|
539 |
+
|
540 |
+
# Write some Text
|
541 |
+
|
542 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
543 |
+
bottomLeftCornerOfText = (350, 50)
|
544 |
+
bottomRightCornerOfText = (800, 50)
|
545 |
+
fontScale = 1
|
546 |
+
fontColor = (1.0, 1.0, 1.0)
|
547 |
+
lineType = 2
|
548 |
+
|
549 |
+
for i in range(len(imgs)):
|
550 |
+
|
551 |
+
title = titles[i + 1]
|
552 |
+
cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
|
553 |
+
fontColor, lineType)
|
554 |
+
|
555 |
+
if i == 0:
|
556 |
+
cv2.putText(imgs[i], str(titles[i][0]), bottomRightCornerOfText,
|
557 |
+
font, fontScale, fontColor, lineType)
|
558 |
+
|
559 |
+
result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)
|
560 |
+
|
561 |
+
return result
|
562 |
+
|
563 |
+
|
564 |
+
def make_test_gif(img_dir):
|
565 |
+
|
566 |
+
if img_dir is not None and len(os.listdir(img_dir)) > 0:
|
567 |
+
for dataset in os.listdir(img_dir):
|
568 |
+
for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
|
569 |
+
img_lst = []
|
570 |
+
im1 = None
|
571 |
+
for file in sorted(
|
572 |
+
os.listdir(osp.join(img_dir, dataset, subject))):
|
573 |
+
if file[-3:] not in ['obj', 'gif']:
|
574 |
+
img_path = os.path.join(img_dir, dataset, subject,
|
575 |
+
file)
|
576 |
+
if im1 == None:
|
577 |
+
im1 = Image.open(img_path)
|
578 |
+
else:
|
579 |
+
img_lst.append(Image.open(img_path))
|
580 |
+
|
581 |
+
print(os.path.join(img_dir, dataset, subject, "out.gif"))
|
582 |
+
im1.save(os.path.join(img_dir, dataset, subject, "out.gif"),
|
583 |
+
save_all=True,
|
584 |
+
append_images=img_lst,
|
585 |
+
duration=500,
|
586 |
+
loop=0)
|
587 |
+
|
588 |
+
|
589 |
+
def export_cfg(logger, cfg):
|
590 |
+
|
591 |
+
cfg_export_file = osp.join(logger.save_dir, logger.name,
|
592 |
+
f"version_{logger.version}", "cfg.yaml")
|
593 |
+
|
594 |
+
if not osp.exists(cfg_export_file):
|
595 |
+
os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
|
596 |
+
with open(cfg_export_file, "w+") as file:
|
597 |
+
_ = yaml.dump(cfg, file)
|
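A minimal usage sketch of compute_acc above, assuming it is imported from this module; the four-point toy tensors below are made up for illustration only:

    import torch

    pred = torch.tensor([0.9, 0.8, 0.2, 0.1])  # predicted occupancy
    gt = torch.tensor([1.0, 0.0, 1.0, 0.0])    # ground-truth occupancy

    iou, prec, recall = compute_acc(pred, gt, thresh=0.5)
    # inter = 1, union = 3  -> IOU = 1/3
    # vol_pred = 2 -> precision = 1/2 ; vol_gt = 2 -> recall = 1/2
    print(iou.item(), prec.item(), recall.item())
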
lib/dataset/Evaluator.py
ADDED
@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

from lib.renderer.gl.normal_render import NormalRender
from lib.dataset.mesh_util import projection
from lib.common.render import Render
from PIL import Image
import numpy as np
import torch
from torch import nn
import trimesh
import os.path as osp


class Evaluator:

    _normal_render = None

    @staticmethod
    def init_gl():
        Evaluator._normal_render = NormalRender(width=512, height=512)

    def __init__(self, device):
        self.device = device
        self.render = Render(size=512, device=self.device)
        self.error_term = nn.MSELoss()

        self.offset = 0.0
        self.scale_factor = None

    def set_mesh(self, result_dict, scale_factor=1.0, offset=0.0):

        for key in result_dict.keys():
            if torch.is_tensor(result_dict[key]):
                result_dict[key] = result_dict[key].detach().cpu().numpy()

        for k, v in result_dict.items():
            setattr(self, k, v)

        self.scale_factor = scale_factor
        self.offset = offset

    def _render_normal(self, mesh, deg, norms=None):
        view_mat = np.identity(4)
        rz = deg / 180.0 * np.pi
        model_mat = np.identity(4)
        model_mat[:3, :3] = self._normal_render.euler_to_rot_mat(0, rz, 0)
        model_mat[1, 3] = self.offset
        view_mat[2, 2] *= -1

        self._normal_render.set_matrices(view_mat, model_mat)
        if norms is None:
            norms = mesh.vertex_normals
        self._normal_render.set_normal_mesh(self.scale_factor * mesh.vertices,
                                            mesh.faces, norms, mesh.faces)
        self._normal_render.draw()
        normal_img = self._normal_render.get_color()
        return normal_img

    def render_mesh_list(self, mesh_lst):

        self.offset = 0.0
        self.scale_factor = 1.0

        full_list = []
        for mesh in mesh_lst:
            row_lst = []
            for deg in np.arange(0, 360, 90):
                normal = self._render_normal(mesh, deg)
                row_lst.append(normal)
            full_list.append(np.concatenate(row_lst, axis=1))

        res_array = np.concatenate(full_list, axis=0)

        return res_array

    def _get_reproj_normal_error(self, deg):

        tgt_normal = self._render_normal(self.tgt_mesh, deg)
        src_normal = self._render_normal(self.src_mesh, deg)
        error = (((src_normal[:, :, :3] -
                   tgt_normal[:, :, :3])**2).sum(axis=2).mean(axis=(0, 1)))

        return error, [src_normal, tgt_normal]

    def render_normal(self, verts, faces):

        verts = verts[0].detach().cpu().numpy()
        faces = faces[0].detach().cpu().numpy()

        mesh_F = trimesh.Trimesh(verts * np.array([1.0, -1.0, 1.0]), faces)
        mesh_B = trimesh.Trimesh(verts * np.array([1.0, -1.0, -1.0]), faces)

        self.scale_factor = 1.0

        normal_F = self._render_normal(mesh_F, 0)
        normal_B = self._render_normal(mesh_B,
                                       0,
                                       norms=mesh_B.vertex_normals *
                                       np.array([-1.0, -1.0, 1.0]))

        mask = normal_F[:, :, 3:4]
        normal_F = (torch.as_tensor(2.0 * (normal_F - 0.5) * mask).permute(
            2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))
        normal_B = (torch.as_tensor(2.0 * (normal_B - 0.5) * mask).permute(
            2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))

        return {"T_normal_F": normal_F, "T_normal_B": normal_B}

    def calculate_normal_consist(
        self,
        frontal=True,
        back=True,
        left=True,
        right=True,
        save_demo_img=None,
        return_demo=False,
    ):

        # reproj error
        # if save_demo_img is not None, save a visualization at the given path (e.g. "./test.png")
        if self._normal_render is None:
            print(
                "In order to use normal render, "
                "you have to call init_gl() before initializing any evaluator objects."
            )
            return -1

        side_cnt = 0
        total_error = 0
        demo_list = []

        if frontal:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(0)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if back:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(180)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if left:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(90)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if right:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(270)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if save_demo_img is not None:
            res_array = np.concatenate(demo_list, axis=1)
            res_img = Image.fromarray((res_array * 255).astype(np.uint8))
            res_img.save(save_demo_img)

        if return_demo:
            res_array = np.concatenate(demo_list, axis=1)
            return res_array
        else:
            return total_error

    def space_transfer(self):

        # convert from GT to SDF
        self.verts_pr -= self.recon_size / 2.0
        self.verts_pr /= self.recon_size / 2.0

        self.verts_gt = projection(self.verts_gt, self.calib)
        self.verts_gt[:, 1] *= -1

        self.tgt_mesh = trimesh.Trimesh(self.verts_gt, self.faces_gt)
        self.src_mesh = trimesh.Trimesh(self.verts_pr, self.faces_pr)

        # (self.tgt_mesh+self.src_mesh).show()

    def export_mesh(self, dir, name):
        self.tgt_mesh.visual.vertex_colors = np.array([255, 0, 0])
        self.src_mesh.visual.vertex_colors = np.array([0, 255, 0])

        (self.tgt_mesh + self.src_mesh).export(
            osp.join(dir, f"{name}_gt_pr.obj"))

    def calculate_chamfer_p2s(self, sampled_points=1000):
        """calculate the geometry metrics [chamfer, p2s, chamfer_H, p2s_H]

        Args:
            verts_gt (torch.cuda.tensor): [N, 3]
            faces_gt (torch.cuda.tensor): [M, 3]
            verts_pr (torch.cuda.tensor): [N', 3]
            faces_pr (torch.cuda.tensor): [M', 3]
            sampled_points (int, optional): use smaller number for faster testing. Defaults to 1000.

        Returns:
            tuple: chamfer, p2s, chamfer_H, p2s_H
        """

        gt_surface_pts, _ = trimesh.sample.sample_surface_even(
            self.tgt_mesh, sampled_points)
        pred_surface_pts, _ = trimesh.sample.sample_surface_even(
            self.src_mesh, sampled_points)

        _, dist_pred_gt, _ = trimesh.proximity.closest_point(
            self.src_mesh, gt_surface_pts)
        _, dist_gt_pred, _ = trimesh.proximity.closest_point(
            self.tgt_mesh, pred_surface_pts)

        dist_pred_gt[np.isnan(dist_pred_gt)] = 0
        dist_gt_pred[np.isnan(dist_gt_pred)] = 0
        chamfer_dist = 0.5 * (dist_pred_gt.mean() +
                              dist_gt_pred.mean()).item() * 100
        p2s_dist = dist_pred_gt.mean().item() * 100

        return chamfer_dist, p2s_dist

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):

        # # remove the surface points with thres
        # non_surf_ids = (target != thres)
        # output = output[non_surf_ids]
        # target = target[non_surf_ids]

        with torch.no_grad():
            output = output.masked_fill(output < thres, 0.0)
            output = output.masked_fill(output > thres, 1.0)

            if use_sdf:
                target = target.masked_fill(target < thres, 0.0)
                target = target.masked_fill(target > thres, 1.0)

            acc = output.eq(target).float().mean()

            # iou, precision, recall
            output = output > thres
            target = target > thres

            union = output | target
            inter = output & target

            _max = torch.tensor(1.0).to(output.device)

            union = max(union.sum().float(), _max)
            true_pos = max(inter.sum().float(), _max)
            vol_pred = max(output.sum().float(), _max)
            vol_gt = max(target.sum().float(), _max)

            return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt
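A minimal sketch of the chamfer / point-to-surface computation performed by calculate_chamfer_p2s, run on two toy spheres instead of reconstructed and ground-truth scans; the mesh resolution and 1000-point sample count are arbitrary choices for the example:

    import trimesh
    import numpy as np

    src_mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.00)  # stand-in "prediction"
    tgt_mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.05)  # stand-in "ground truth"

    gt_pts, _ = trimesh.sample.sample_surface_even(tgt_mesh, 1000)
    pr_pts, _ = trimesh.sample.sample_surface_even(src_mesh, 1000)

    _, dist_pred_gt, _ = trimesh.proximity.closest_point(src_mesh, gt_pts)
    _, dist_gt_pred, _ = trimesh.proximity.closest_point(tgt_mesh, pr_pts)

    chamfer = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean()) * 100  # x100 as in calculate_chamfer_p2s
    p2s = dist_pred_gt.mean() * 100
    print(chamfer, p2s)
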
lib/dataset/NormalDataset.py
ADDED
@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms


class NormalDataset():
    def __init__(self, cfg, split='train'):

        self.split = split
        self.root = cfg.root
        self.overfit = cfg.overfit

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.set_splits = self.opt.set_splits
        self.scales = self.opt.scales
        self.pifu = self.opt.pifu

        # input data types and dimensions
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
        self.in_total = self.in_nml + ['normal_F', 'normal_B']
        self.in_total_dim = self.in_nml_dim + [3, 3]

        if self.split != 'train':
            self.rotations = range(0, 360, 120)
        else:
            self.rotations = np.arange(0, 360, 360 /
                                       self.opt.rotation_num).astype(np.int32)

        self.datasets_dict = {}
        for dataset_id, dataset in enumerate(self.datasets):
            dataset_dir = osp.join(self.root, dataset, "smplx")
            self.datasets_dict[dataset] = {
                "subjects":
                np.loadtxt(osp.join(self.root, dataset, "all.txt"), dtype=str),
                "path":
                dataset_dir,
                "scale":
                self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)

        # PIL to tensor
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

    def get_subject_list(self, split):

        subject_list = []

        for dataset in self.datasets:

            if self.pifu:
                txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
            else:
                txt = osp.join(self.root, dataset, f'{split}.txt')

            if osp.exists(txt):
                print(f"load from {txt}")
                subject_list += sorted(np.loadtxt(txt, dtype=str).tolist())

                if self.pifu:
                    miss_pifu = sorted(
                        np.loadtxt(osp.join(self.root, dataset,
                                            "miss_pifu.txt"),
                                   dtype=str).tolist())
                    subject_list = [
                        subject for subject in subject_list
                        if subject not in miss_pifu
                    ]
                    subject_list = [
                        "renderpeople/" + subject for subject in subject_list
                    ]

            else:
                train_txt = osp.join(self.root, dataset, 'train.txt')
                val_txt = osp.join(self.root, dataset, 'val.txt')
                test_txt = osp.join(self.root, dataset, 'test.txt')

                print(
                    f"generate lists of [train, val, test] \n {train_txt} \n {val_txt} \n {test_txt} \n"
                )

                split_txt = osp.join(self.root, dataset, f'{split}.txt')

                subjects = self.datasets_dict[dataset]['subjects']
                train_split = int(len(subjects) * self.set_splits[0])
                val_split = int(
                    len(subjects) * self.set_splits[1]) + train_split

                with open(train_txt, "w") as f:
                    f.write("\n".join(dataset + "/" + item
                                      for item in subjects[:train_split]))
                with open(val_txt, "w") as f:
                    f.write("\n".join(
                        dataset + "/" + item
                        for item in subjects[train_split:val_split]))
                with open(test_txt, "w") as f:
                    f.write("\n".join(dataset + "/" + item
                                      for item in subjects[val_split:]))

                subject_list += sorted(
                    np.loadtxt(split_txt, dtype=str).tolist())

        bug_list = sorted(
            np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())

        subject_list = [
            subject for subject in subject_list if (subject not in bug_list)
        ]

        return subject_list

    def __len__(self):
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):

        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]

        # choose specific test sets
        subject = self.subject_list[mid]

        subject_render = "/".join(
            [subject.split("/")[0] + "_12views",
             subject.split("/")[1]])

        # setup paths
        data_dict = {
            'dataset': subject.split("/")[0],
            'subject': subject,
            'rotation': rotation,
            'image_path': osp.join(self.root, subject_render, 'render',
                                   f'{rotation:03d}.png')
        }

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            if name != 'image':
                data_dict.update({
                    f'{name}_path':
                    osp.join(self.root, subject_render, name,
                             f'{rotation:03d}.png')
                })
            data_dict.update({
                name:
                self.imagepath2tensor(data_dict[f'{name}_path'],
                                      channel,
                                      inv='depth_B' in name)
            })

        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):

        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        return (image * (0.5 - inv) * 2.0).float()
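A small illustration (not part of the dataset file) of the final `(image * (0.5 - inv) * 2.0)` step in imagepath2tensor: with inv=False it is an identity on the already-normalized map, and with inv=True (used for 'depth_B' inputs) it flips the sign; the three-value tensor is a made-up placeholder:

    import torch

    image = torch.tensor([-1.0, 0.0, 1.0])        # already normalized to [-1, 1]
    keep = (image * (0.5 - False) * 2.0).float()  # -> [-1., 0., 1.]
    flip = (image * (0.5 - True) * 2.0).float()   # -> [ 1., 0., -1.]
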
lib/dataset/NormalModule.py
ADDED
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import numpy as np
from torch.utils.data import DataLoader
from .NormalDataset import NormalDataset

# pytorch lightning related libs
import pytorch_lightning as pl


class NormalModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super(NormalModule, self).__init__()
        self.cfg = cfg
        self.overfit = self.cfg.overfit

        if self.overfit:
            self.batch_size = 1
        else:
            self.batch_size = self.cfg.batch_size

        self.data_size = {}

    def prepare_data(self):

        pass

    @staticmethod
    def worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    def setup(self, stage):

        if stage == 'fit' or stage is None:
            self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
            self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
            self.data_size = {
                'train': len(self.train_dataset),
                'val': len(self.val_dataset)
            }

        if stage == 'test' or stage is None:
            self.test_dataset = NormalDataset(cfg=self.cfg, split="test")

    def train_dataloader(self):

        train_data_loader = DataLoader(self.train_dataset,
                                       batch_size=self.batch_size,
                                       shuffle=not self.overfit,
                                       num_workers=self.cfg.num_threads,
                                       pin_memory=True,
                                       worker_init_fn=self.worker_init_fn)

        return train_data_loader

    def val_dataloader(self):

        if self.overfit:
            current_dataset = self.train_dataset
        else:
            current_dataset = self.val_dataset

        val_data_loader = DataLoader(current_dataset,
                                     batch_size=self.batch_size,
                                     shuffle=False,
                                     num_workers=self.cfg.num_threads,
                                     pin_memory=True)

        return val_data_loader

    def test_dataloader(self):

        test_data_loader = DataLoader(self.test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=self.cfg.num_threads,
                                      pin_memory=True)

        return test_data_loader
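A hedged sketch of how a LightningDataModule such as NormalModule is typically wired into training; `cfg` stands for the loaded YAML config and `MyNormalNet` is a hypothetical LightningModule, and the Trainer arguments actually used by this repo are defined elsewhere (presumably in apps/Normal.py):

    import pytorch_lightning as pl

    datamodule = NormalModule(cfg)            # cfg: parsed config (assumption)
    model = MyNormalNet(cfg)                  # hypothetical pl.LightningModule
    trainer = pl.Trainer(max_epochs=10, gpus=1)
    trainer.fit(model, datamodule=datamodule)
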
lib/dataset/PIFuDataModule.py
ADDED
@@ -0,0 +1,71 @@
import numpy as np
from torch.utils.data import DataLoader
from .PIFuDataset import PIFuDataset
import pytorch_lightning as pl


class PIFuDataModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super(PIFuDataModule, self).__init__()
        self.cfg = cfg
        self.overfit = self.cfg.overfit

        if self.overfit:
            self.batch_size = 1
        else:
            self.batch_size = self.cfg.batch_size

        self.data_size = {}

    def prepare_data(self):

        pass

    @staticmethod
    def worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    def setup(self, stage):

        if stage == 'fit':
            self.train_dataset = PIFuDataset(cfg=self.cfg, split="train")
            self.val_dataset = PIFuDataset(cfg=self.cfg, split="val")
            self.data_size = {'train': len(self.train_dataset),
                              'val': len(self.val_dataset)}

        if stage == 'test':
            self.test_dataset = PIFuDataset(cfg=self.cfg, split="test")

    def train_dataloader(self):

        train_data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size, shuffle=True,
            num_workers=self.cfg.num_threads, pin_memory=True,
            worker_init_fn=self.worker_init_fn)

        return train_data_loader

    def val_dataloader(self):

        if self.overfit:
            current_dataset = self.train_dataset
        else:
            current_dataset = self.val_dataset

        val_data_loader = DataLoader(
            current_dataset,
            batch_size=1, shuffle=False,
            num_workers=self.cfg.num_threads, pin_memory=True,
            worker_init_fn=self.worker_init_fn)

        return val_data_loader

    def test_dataloader(self):

        test_data_loader = DataLoader(
            self.test_dataset,
            batch_size=1, shuffle=False,
            num_workers=self.cfg.num_threads, pin_memory=True)

        return test_data_loader
lib/dataset/PIFuDataset.py
ADDED
@@ -0,0 +1,589 @@
from lib.renderer.mesh import load_fit_body
from lib.dataset.hoppeMesh import HoppeMesh
from lib.dataset.body_model import TetraSMPLModel
from lib.common.render import Render
from lib.dataset.mesh_util import SMPLX, projection, cal_sdf_batch, get_visibility
from lib.pare.pare.utils.geometry import rotation_matrix_to_angle_axis
from termcolor import colored
import os.path as osp
import numpy as np
from PIL import Image
import random
import os
import trimesh
import torch
from kaolin.ops.mesh import check_sign
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download, cached_download


class PIFuDataset():
    def __init__(self, cfg, split='train', vis=False):

        self.split = split
        self.root = cfg.root
        self.bsize = cfg.batch_size
        self.overfit = cfg.overfit

        # for debug, only used in visualize_sampling3D
        self.vis = vis

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.scales = self.opt.scales
        self.workers = cfg.num_threads
        self.prior_type = cfg.net.prior_type

        self.noise_type = self.opt.noise_type
        self.noise_scale = self.opt.noise_scale

        noise_joints = [4, 5, 7, 8, 13, 14, 16, 17, 18, 19, 20, 21]

        self.noise_smpl_idx = []
        self.noise_smplx_idx = []

        for idx in noise_joints:
            self.noise_smpl_idx.append(idx * 3)
            self.noise_smpl_idx.append(idx * 3 + 1)
            self.noise_smpl_idx.append(idx * 3 + 2)

            self.noise_smplx_idx.append((idx - 1) * 3)
            self.noise_smplx_idx.append((idx - 1) * 3 + 1)
            self.noise_smplx_idx.append((idx - 1) * 3 + 2)

        self.use_sdf = cfg.sdf
        self.sdf_clip = cfg.sdf_clip

        # [(feat_name, channel_num),...]
        self.in_geo = [item[0] for item in cfg.net.in_geo]
        self.in_nml = [item[0] for item in cfg.net.in_nml]

        self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]

        self.in_total = self.in_geo + self.in_nml
        self.in_total_dim = self.in_geo_dim + self.in_nml_dim

        if self.split == 'train':
            self.rotations = np.arange(
                0, 360, 360 / self.opt.rotation_num).astype(np.int32)
        else:
            self.rotations = range(0, 360, 120)

        self.datasets_dict = {}

        for dataset_id, dataset in enumerate(self.datasets):

            mesh_dir = None
            smplx_dir = None

            dataset_dir = osp.join(self.root, dataset)

            if dataset in ['thuman2']:
                mesh_dir = osp.join(dataset_dir, "scans")
                smplx_dir = osp.join(dataset_dir, "fits")
                smpl_dir = osp.join(dataset_dir, "smpl")

            self.datasets_dict[dataset] = {
                "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
                "smplx_dir": smplx_dir,
                "smpl_dir": smpl_dir,
                "mesh_dir": mesh_dir,
                "scale": self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)
        self.smplx = SMPLX()

        # PIL to tensor
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

        self.device = torch.device(f"cuda:{cfg.gpus[0]}")
        self.render = Render(size=512, device=self.device)

    def render_normal(self, verts, faces):

        # render optimized mesh (normal, T_normal, image [-1,1])
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def get_subject_list(self, split):

        subject_list = []

        for dataset in self.datasets:

            split_txt = osp.join(self.root, dataset, f'{split}.txt')

            if osp.exists(split_txt):
                print(f"load from {split_txt}")
                subject_list += np.loadtxt(split_txt, dtype=str).tolist()
            else:
                full_txt = osp.join(self.root, dataset, 'all.txt')
                print(f"split {full_txt} into train/val/test")

                full_lst = np.loadtxt(full_txt, dtype=str)
                full_lst = [dataset + "/" + item for item in full_lst]
                [train_lst, test_lst, val_lst] = np.split(
                    full_lst, [500, 500 + 5, ])

                np.savetxt(full_txt.replace(
                    "all", "train"), train_lst, fmt="%s")
                np.savetxt(full_txt.replace("all", "test"), test_lst, fmt="%s")
                np.savetxt(full_txt.replace("all", "val"), val_lst, fmt="%s")

                print(f"load from {split_txt}")
                subject_list += np.loadtxt(split_txt, dtype=str).tolist()

        if self.split != 'test':
            subject_list += subject_list[:self.bsize -
                                         len(subject_list) % self.bsize]
            print(colored(f"total: {len(subject_list)}", "yellow"))
            random.shuffle(subject_list)

        # subject_list = ["thuman2/0008"]
        return subject_list

    def __len__(self):
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):

        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]
        subject = self.subject_list[mid].split("/")[1]
        dataset = self.subject_list[mid].split("/")[0]
        render_folder = "/".join([dataset +
                                  f"_{self.opt.rotation_num}views", subject])

        # setup paths
        data_dict = {
            'dataset': dataset,
            'subject': subject,
            'rotation': rotation,
            'scale': self.datasets_dict[dataset]["scale"],
            'mesh_path': osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}/{subject}.obj"),
            'smplx_path': osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}/smplx_param.pkl"),
            'smpl_path': osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.pkl"),
            'calib_path': osp.join(self.root, render_folder, 'calib', f'{rotation:03d}.txt'),
            'vis_path': osp.join(self.root, render_folder, 'vis', f'{rotation:03d}.pt'),
            'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png')
        }

        # load training data
        data_dict.update(self.load_calib(data_dict))

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            if f'{name}_path' not in data_dict.keys():
                data_dict.update({
                    f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png')
                })

            # tensor update
            data_dict.update({
                name: self.imagepath2tensor(
                    data_dict[f'{name}_path'], channel, inv=False)
            })

        data_dict.update(self.load_mesh(data_dict))
        data_dict.update(self.get_sampling_geo(
            data_dict, is_valid=self.split == "val", is_sdf=self.use_sdf))
        data_dict.update(self.load_smpl(data_dict, self.vis))

        if self.prior_type == 'pamir':
            data_dict.update(self.load_smpl_voxel(data_dict))

        if (self.split != 'test') and (not self.vis):

            del data_dict['verts']
            del data_dict['faces']

        if not self.vis:
            del data_dict['mesh']

        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):

        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        return (image * (0.5 - inv) * 2.0).float()

    def load_calib(self, data_dict):
        calib_data = np.loadtxt(data_dict['calib_path'], dtype=float)
        extrinsic = calib_data[:4, :4]
        intrinsic = calib_data[4:8, :4]
        calib_mat = np.matmul(intrinsic, extrinsic)
        calib_mat = torch.from_numpy(calib_mat).float()
        return {'calib': calib_mat}

    def load_mesh(self, data_dict):
        mesh_path = data_dict['mesh_path']
        scale = data_dict['scale']

        mesh_ori = trimesh.load(mesh_path,
                                skip_materials=True,
                                process=False,
                                maintain_order=True)
        verts = mesh_ori.vertices * scale
        faces = mesh_ori.faces

        vert_normals = np.array(mesh_ori.vertex_normals)
        face_normals = np.array(mesh_ori.face_normals)

        mesh = HoppeMesh(verts, faces, vert_normals, face_normals)

        return {
            'mesh': mesh,
            'verts': torch.as_tensor(mesh.verts).float(),
            'faces': torch.as_tensor(mesh.faces).long()
        }

    def add_noise(self,
                  beta_num,
                  smpl_pose,
                  smpl_betas,
                  noise_type,
                  noise_scale,
                  type,
                  hashcode):

        np.random.seed(hashcode)

        if type == 'smplx':
            noise_idx = self.noise_smplx_idx
        else:
            noise_idx = self.noise_smpl_idx

        if 'beta' in noise_type and noise_scale[noise_type.index("beta")] > 0.0:
            smpl_betas += (np.random.rand(beta_num) -
                           0.5) * 2.0 * noise_scale[noise_type.index("beta")]
            smpl_betas = smpl_betas.astype(np.float32)

        if 'pose' in noise_type and noise_scale[noise_type.index("pose")] > 0.0:
            smpl_pose[noise_idx] += (
                np.random.rand(len(noise_idx)) -
                0.5) * 2.0 * np.pi * noise_scale[noise_type.index("pose")]
            smpl_pose = smpl_pose.astype(np.float32)
        if type == 'smplx':
            return torch.as_tensor(smpl_pose[None, ...]), torch.as_tensor(smpl_betas[None, ...])
        else:
            return smpl_pose, smpl_betas

    def compute_smpl_verts(self, data_dict, noise_type=None, noise_scale=None):

        dataset = data_dict['dataset']
        smplx_dict = {}

        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
        smplx_pose = smplx_param["body_pose"]  # [1,63]
        smplx_betas = smplx_param["betas"]  # [1,10]
        smplx_pose, smplx_betas = self.add_noise(
            smplx_betas.shape[1],
            smplx_pose[0],
            smplx_betas[0],
            noise_type,
            noise_scale,
            type='smplx',
            hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))

        smplx_out, _ = load_fit_body(fitted_path=data_dict['smplx_path'],
                                     scale=self.datasets_dict[dataset]['scale'],
                                     smpl_type='smplx',
                                     smpl_gender='male',
                                     noise_dict=dict(betas=smplx_betas, body_pose=smplx_pose))

        smplx_dict.update({"type": "smplx",
                           "gender": 'male',
                           "body_pose": torch.as_tensor(smplx_pose),
                           "betas": torch.as_tensor(smplx_betas)})

        return smplx_out.vertices, smplx_dict

    def compute_voxel_verts(self,
                            data_dict,
                            noise_type=None,
                            noise_scale=None):

        smpl_param = np.load(data_dict['smpl_path'], allow_pickle=True)
        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)

        smpl_pose = rotation_matrix_to_angle_axis(
            torch.as_tensor(smpl_param['full_pose'][0])).numpy()
        smpl_betas = smpl_param["betas"]

        smpl_path = cached_download(osp.join(self.smplx.model_dir, "smpl/SMPL_MALE.pkl"), use_auth_token=os.environ['ICON'])
        tetra_path = cached_download(osp.join(self.smplx.tedra_dir,
                                              "tetra_male_adult_smpl.npz"), use_auth_token=os.environ['ICON'])

        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        smpl_pose, smpl_betas = self.add_noise(
            smpl_model.beta_shape[0],
            smpl_pose.flatten(),
            smpl_betas[0],
            noise_type,
            noise_scale,
            type='smpl',
            hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))

        smpl_model.set_params(pose=smpl_pose.reshape(-1, 3),
                              beta=smpl_betas,
                              trans=smpl_param["transl"])

        verts = (np.concatenate([smpl_model.verts, smpl_model.verts_added],
                                axis=0) * smplx_param["scale"] + smplx_param["translation"]
                 ) * self.datasets_dict[data_dict['dataset']]['scale']
        faces = np.loadtxt(cached_download(osp.join(self.smplx.tedra_dir, "tetrahedrons_male_adult.txt"), use_auth_token=os.environ['ICON']),
                           dtype=np.int32) - 1

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32)
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        return verts, faces, pad_v_num, pad_f_num

    def load_smpl(self, data_dict, vis=False):

        smplx_verts, smplx_dict = self.compute_smpl_verts(
            data_dict, self.noise_type,
            self.noise_scale)  # compute using smpl model

        smplx_verts = projection(smplx_verts, data_dict['calib']).float()
        smplx_faces = torch.as_tensor(self.smplx.faces).long()
        smplx_vis = torch.load(data_dict['vis_path']).float()
        smplx_cmap = torch.as_tensor(
            np.load(self.smplx.cmap_vert_path)).float()

        # get smpl_signs
        query_points = projection(data_dict['samples_geo'],
                                  data_dict['calib']).float()

        pts_signs = 2.0 * (check_sign(smplx_verts.unsqueeze(0),
                                      smplx_faces,
                                      query_points.unsqueeze(0)).float() - 0.5).squeeze(0)

        return_dict = {
            'smpl_verts': smplx_verts,
            'smpl_faces': smplx_faces,
            'smpl_vis': smplx_vis,
            'smpl_cmap': smplx_cmap,
            'pts_signs': pts_signs
        }
        if smplx_dict is not None:
            return_dict.update(smplx_dict)

        if vis:

            (xy, z) = torch.as_tensor(smplx_verts).to(
                self.device).split([2, 1], dim=1)
            smplx_vis = get_visibility(xy, z, torch.as_tensor(
                smplx_faces).to(self.device).long())

            T_normal_F, T_normal_B = self.render_normal(
                (smplx_verts * torch.tensor([1.0, -1.0, 1.0])).to(self.device),
                smplx_faces.to(self.device))

            return_dict.update({"T_normal_F": T_normal_F.squeeze(0),
                                "T_normal_B": T_normal_B.squeeze(0)})
            query_points = projection(data_dict['samples_geo'],
                                      data_dict['calib']).float()

            smplx_sdf, smplx_norm, smplx_cmap, smplx_vis = cal_sdf_batch(
                smplx_verts.unsqueeze(0).to(self.device),
|
432 |
+
smplx_faces.unsqueeze(0).to(self.device),
|
433 |
+
smplx_cmap.unsqueeze(0).to(self.device),
|
434 |
+
smplx_vis.unsqueeze(0).to(self.device),
|
435 |
+
query_points.unsqueeze(0).contiguous().to(self.device))
|
436 |
+
|
437 |
+
return_dict.update({
|
438 |
+
'smpl_feat':
|
439 |
+
torch.cat(
|
440 |
+
(smplx_sdf[0].detach().cpu(),
|
441 |
+
smplx_cmap[0].detach().cpu(),
|
442 |
+
smplx_norm[0].detach().cpu(),
|
443 |
+
smplx_vis[0].detach().cpu()),
|
444 |
+
dim=1)
|
445 |
+
})
|
446 |
+
|
447 |
+
return return_dict
|
448 |
+
|
449 |
+
def load_smpl_voxel(self, data_dict):
|
450 |
+
|
451 |
+
smpl_verts, smpl_faces, pad_v_num, pad_f_num = self.compute_voxel_verts(
|
452 |
+
data_dict, self.noise_type,
|
453 |
+
self.noise_scale) # compute using smpl model
|
454 |
+
smpl_verts = projection(smpl_verts, data_dict['calib'])
|
455 |
+
|
456 |
+
smpl_verts *= 0.5
|
457 |
+
|
458 |
+
return {
|
459 |
+
'voxel_verts': smpl_verts,
|
460 |
+
'voxel_faces': smpl_faces,
|
461 |
+
'pad_v_num': pad_v_num,
|
462 |
+
'pad_f_num': pad_f_num
|
463 |
+
}
|
464 |
+
|
465 |
+
def get_sampling_geo(self, data_dict, is_valid=False, is_sdf=False):
|
466 |
+
|
467 |
+
mesh = data_dict['mesh']
|
468 |
+
calib = data_dict['calib']
|
469 |
+
|
470 |
+
# Samples are around the true surface with an offset
|
471 |
+
n_samples_surface = 4 * self.opt.num_sample_geo
|
472 |
+
vert_ids = np.arange(mesh.verts.shape[0])
|
473 |
+
thickness_sample_ratio = np.ones_like(vert_ids).astype(np.float32)
|
474 |
+
|
475 |
+
thickness_sample_ratio /= thickness_sample_ratio.sum()
|
476 |
+
|
477 |
+
samples_surface_ids = np.random.choice(vert_ids,
|
478 |
+
n_samples_surface,
|
479 |
+
replace=True,
|
480 |
+
p=thickness_sample_ratio)
|
481 |
+
|
482 |
+
samples_normal_ids = np.random.choice(vert_ids,
|
483 |
+
self.opt.num_sample_geo // 2,
|
484 |
+
replace=False,
|
485 |
+
p=thickness_sample_ratio)
|
486 |
+
|
487 |
+
surf_samples = mesh.verts[samples_normal_ids, :]
|
488 |
+
surf_normals = mesh.vert_normals[samples_normal_ids, :]
|
489 |
+
|
490 |
+
samples_surface = mesh.verts[samples_surface_ids, :]
|
491 |
+
|
492 |
+
# Sampling offsets are random noise with constant scale (15cm - 20cm)
|
493 |
+
offset = np.random.normal(scale=self.opt.sigma_geo,
|
494 |
+
size=(n_samples_surface, 1))
|
495 |
+
samples_surface += mesh.vert_normals[samples_surface_ids, :] * offset
|
496 |
+
|
497 |
+
# Uniform samples in [-1, 1]
|
498 |
+
calib_inv = np.linalg.inv(calib)
|
499 |
+
n_samples_space = self.opt.num_sample_geo // 4
|
500 |
+
samples_space_img = 2.0 * np.random.rand(n_samples_space, 3) - 1.0
|
501 |
+
samples_space = projection(samples_space_img, calib_inv)
|
502 |
+
|
503 |
+
# z-ray direction samples
|
504 |
+
if self.opt.zray_type and not is_valid:
|
505 |
+
n_samples_rayz = self.opt.ray_sample_num
|
506 |
+
samples_surface_cube = projection(samples_surface, calib)
|
507 |
+
samples_surface_cube_repeat = np.repeat(samples_surface_cube,
|
508 |
+
n_samples_rayz,
|
509 |
+
axis=0)
|
510 |
+
|
511 |
+
thickness_repeat = np.repeat(0.5 *
|
512 |
+
np.ones_like(samples_surface_ids),
|
513 |
+
n_samples_rayz,
|
514 |
+
axis=0)
|
515 |
+
|
516 |
+
noise_repeat = np.random.normal(scale=0.40,
|
517 |
+
size=(n_samples_surface *
|
518 |
+
n_samples_rayz, ))
|
519 |
+
samples_surface_cube_repeat[:,
|
520 |
+
-1] += thickness_repeat * noise_repeat
|
521 |
+
samples_surface_rayz = projection(samples_surface_cube_repeat,
|
522 |
+
calib_inv)
|
523 |
+
|
524 |
+
samples = np.concatenate(
|
525 |
+
[samples_surface, samples_space, samples_surface_rayz], 0)
|
526 |
+
else:
|
527 |
+
samples = np.concatenate([samples_surface, samples_space], 0)
|
528 |
+
|
529 |
+
np.random.shuffle(samples)
|
530 |
+
|
531 |
+
# labels: in->1.0; out->0.0.
|
532 |
+
if is_sdf:
|
533 |
+
sdfs = mesh.get_sdf(samples)
|
534 |
+
inside_samples = samples[sdfs < 0]
|
535 |
+
outside_samples = samples[sdfs >= 0]
|
536 |
+
|
537 |
+
inside_sdfs = sdfs[sdfs < 0]
|
538 |
+
outside_sdfs = sdfs[sdfs >= 0]
|
539 |
+
else:
|
540 |
+
inside = mesh.contains(samples)
|
541 |
+
inside_samples = samples[inside >= 0.5]
|
542 |
+
outside_samples = samples[inside < 0.5]
|
543 |
+
|
544 |
+
nin = inside_samples.shape[0]
|
545 |
+
|
546 |
+
if nin > self.opt.num_sample_geo // 2:
|
547 |
+
inside_samples = inside_samples[:self.opt.num_sample_geo // 2]
|
548 |
+
outside_samples = outside_samples[:self.opt.num_sample_geo // 2]
|
549 |
+
if is_sdf:
|
550 |
+
inside_sdfs = inside_sdfs[:self.opt.num_sample_geo // 2]
|
551 |
+
outside_sdfs = outside_sdfs[:self.opt.num_sample_geo // 2]
|
552 |
+
else:
|
553 |
+
outside_samples = outside_samples[:(self.opt.num_sample_geo - nin)]
|
554 |
+
if is_sdf:
|
555 |
+
outside_sdfs = outside_sdfs[:(self.opt.num_sample_geo - nin)]
|
556 |
+
|
557 |
+
if is_sdf:
|
558 |
+
samples = np.concatenate(
|
559 |
+
[inside_samples, outside_samples, surf_samples], 0)
|
560 |
+
|
561 |
+
labels = np.concatenate([
|
562 |
+
inside_sdfs, outside_sdfs, 0.0 * np.ones(surf_samples.shape[0])
|
563 |
+
])
|
564 |
+
|
565 |
+
normals = np.zeros_like(samples)
|
566 |
+
normals[-self.opt.num_sample_geo // 2:, :] = surf_normals
|
567 |
+
|
568 |
+
# convert sdf from [-14, 130] to [0, 1]
|
569 |
+
# outside: 0, inside: 1
|
570 |
+
# Note: Marching cubes is defined on occupancy space (inside=1.0, outside=0.0)
|
571 |
+
|
572 |
+
labels = -labels.clip(min=-self.sdf_clip, max=self.sdf_clip)
|
573 |
+
labels += self.sdf_clip
|
574 |
+
labels /= (self.sdf_clip * 2)
|
575 |
+
|
576 |
+
else:
|
577 |
+
samples = np.concatenate([inside_samples, outside_samples])
|
578 |
+
labels = np.concatenate([
|
579 |
+
np.ones(inside_samples.shape[0]),
|
580 |
+
np.zeros(outside_samples.shape[0])
|
581 |
+
])
|
582 |
+
|
583 |
+
normals = np.zeros_like(samples)
|
584 |
+
|
585 |
+
samples = torch.from_numpy(samples).float()
|
586 |
+
labels = torch.from_numpy(labels).float()
|
587 |
+
normals = torch.from_numpy(normals).float()
|
588 |
+
|
589 |
+
return {'samples_geo': samples, 'labels_geo': labels}
|
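Note on the geometry sampling above: points are drawn near the surface by offsetting vertices along their normals, mixed with uniform points in the query cube, then labeled by an inside/outside test and rebalanced to half inside, half outside. A minimal self-contained sketch of that idea (a hypothetical helper using plain trimesh instead of HoppeMesh, with the z-ray and SDF branches dropped):

import numpy as np
import trimesh

def sample_occupancy(mesh: trimesh.Trimesh, n_samples=8000, sigma=0.05):
    # surface samples, displaced along vertex normals by Gaussian noise
    ids = np.random.choice(len(mesh.vertices), 4 * n_samples, replace=True)
    offset = np.random.normal(scale=sigma, size=(4 * n_samples, 1))
    surface = mesh.vertices[ids] + mesh.vertex_normals[ids] * offset
    # uniform samples inside the bounding box
    lo, hi = mesh.bounds
    space = np.random.rand(n_samples // 4, 3) * (hi - lo) + lo
    samples = np.concatenate([surface, space], axis=0)
    np.random.shuffle(samples)
    # occupancy labels (inside -> 1, outside -> 0), rebalanced to n_samples total
    inside = mesh.contains(samples)
    pts_in, pts_out = samples[inside], samples[~inside]
    nin = min(len(pts_in), n_samples // 2)
    pts_in, pts_out = pts_in[:nin], pts_out[:n_samples - nin]
    samples = np.concatenate([pts_in, pts_out], axis=0)
    labels = np.concatenate([np.ones(len(pts_in)), np.zeros(len(pts_out))])
    return samples.astype(np.float32), labels.astype(np.float32)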
lib/dataset/TestDataset.py
ADDED
@@ -0,0 +1,256 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import os
|
19 |
+
|
20 |
+
import lib.smplx as smplx
|
21 |
+
from lib.pymaf.utils.geometry import rotation_matrix_to_angle_axis, batch_rodrigues
|
22 |
+
from lib.pymaf.utils.imutils import process_image
|
23 |
+
from lib.pymaf.core import path_config
|
24 |
+
from lib.pymaf.models import pymaf_net
|
25 |
+
from lib.common.config import cfg
|
26 |
+
from lib.common.render import Render
|
27 |
+
from lib.dataset.body_model import TetraSMPLModel
|
28 |
+
from lib.dataset.mesh_util import get_visibility, SMPLX
|
29 |
+
import os.path as osp
|
30 |
+
import torch
|
31 |
+
import numpy as np
|
32 |
+
import random
|
33 |
+
import human_det
|
34 |
+
from termcolor import colored
|
35 |
+
from PIL import ImageFile
|
36 |
+
from huggingface_hub import cached_download
|
37 |
+
|
38 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
39 |
+
|
40 |
+
|
41 |
+
class TestDataset():
|
42 |
+
def __init__(self, cfg, device):
|
43 |
+
|
44 |
+
random.seed(1993)
|
45 |
+
|
46 |
+
self.image_path = cfg['image_path']
|
47 |
+
self.seg_dir = cfg['seg_dir']
|
48 |
+
self.has_det = cfg['has_det']
|
49 |
+
self.hps_type = cfg['hps_type']
|
50 |
+
self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
|
51 |
+
self.smpl_gender = 'neutral'
|
52 |
+
|
53 |
+
self.device = device
|
54 |
+
|
55 |
+
if self.has_det:
|
56 |
+
self.det = human_det.Detection()
|
57 |
+
else:
|
58 |
+
self.det = None
|
59 |
+
|
60 |
+
|
61 |
+
self.subject_list = [self.image_path]
|
62 |
+
|
63 |
+
# smpl related
|
64 |
+
self.smpl_data = SMPLX()
|
65 |
+
|
66 |
+
self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(
|
67 |
+
model_path=self.smpl_data.model_dir,
|
68 |
+
gender=smpl_gender,
|
69 |
+
model_type=smpl_type,
|
70 |
+
ext='npz')
|
71 |
+
|
72 |
+
# Load SMPL model
|
73 |
+
self.smpl_model = self.get_smpl_model(
|
74 |
+
self.smpl_type, self.smpl_gender).to(self.device)
|
75 |
+
self.faces = self.smpl_model.faces
|
76 |
+
|
77 |
+
self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS,
|
78 |
+
pretrained=True).to(self.device)
|
79 |
+
self.hps.load_state_dict(torch.load(
|
80 |
+
path_config.CHECKPOINT_FILE)['model'],
|
81 |
+
strict=True)
|
82 |
+
self.hps.eval()
|
83 |
+
|
84 |
+
print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))
|
85 |
+
|
86 |
+
self.render = Render(size=512, device=device)
|
87 |
+
|
88 |
+
def __len__(self):
|
89 |
+
return len(self.subject_list)
|
90 |
+
|
91 |
+
def compute_vis_cmap(self, smpl_verts, smpl_faces):
|
92 |
+
|
93 |
+
(xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
|
94 |
+
smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
|
95 |
+
if self.smpl_type == 'smpl':
|
96 |
+
smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
|
97 |
+
else:
|
98 |
+
smplx_ind = np.arange(smpl_vis.shape[0])
|
99 |
+
smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)
|
100 |
+
|
101 |
+
return {
|
102 |
+
'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
|
103 |
+
'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
|
104 |
+
'smpl_verts': smpl_verts.unsqueeze(0)
|
105 |
+
}
|
106 |
+
|
107 |
+
def compute_voxel_verts(self, body_pose, global_orient, betas, trans,
|
108 |
+
scale):
|
109 |
+
|
110 |
+
smpl_path = cached_download(osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl"), use_auth_token=os.environ['ICON'])
|
111 |
+
tetra_path = cached_download(osp.join(self.smpl_data.tedra_dir,
|
112 |
+
'tetra_neutral_adult_smpl.npz'), use_auth_token=os.environ['ICON'])
|
113 |
+
smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')
|
114 |
+
|
115 |
+
pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
|
116 |
+
smpl_model.set_params(rotation_matrix_to_angle_axis(pose),
|
117 |
+
beta=betas[0])
|
118 |
+
|
119 |
+
verts = np.concatenate(
|
120 |
+
[smpl_model.verts, smpl_model.verts_added],
|
121 |
+
axis=0) * scale.item() + trans.detach().cpu().numpy()
|
122 |
+
faces = np.loadtxt(cached_download(osp.join(self.smpl_data.tedra_dir,
|
123 |
+
'tetrahedrons_neutral_adult.txt'), use_auth_token=os.environ['ICON']),
|
124 |
+
dtype=np.int32) - 1
|
125 |
+
|
126 |
+
pad_v_num = int(8000 - verts.shape[0])
|
127 |
+
pad_f_num = int(25100 - faces.shape[0])
|
128 |
+
|
129 |
+
verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
|
130 |
+
mode='constant',
|
131 |
+
constant_values=0.0).astype(np.float32) * 0.5
|
132 |
+
faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
|
133 |
+
mode='constant',
|
134 |
+
constant_values=0.0).astype(np.int32)
|
135 |
+
|
136 |
+
verts[:, 2] *= -1.0
|
137 |
+
|
138 |
+
voxel_dict = {
|
139 |
+
'voxel_verts':
|
140 |
+
torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
|
141 |
+
'voxel_faces':
|
142 |
+
torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
|
143 |
+
'pad_v_num':
|
144 |
+
torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
|
145 |
+
'pad_f_num':
|
146 |
+
torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
|
147 |
+
}
|
148 |
+
|
149 |
+
return voxel_dict
|
150 |
+
|
151 |
+
def __getitem__(self, index):
|
152 |
+
|
153 |
+
img_path = self.subject_list[index]
|
154 |
+
img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
|
155 |
+
|
156 |
+
if self.seg_dir is None:
|
157 |
+
img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
|
158 |
+
img_path, self.det, self.hps_type, 512, self.device)
|
159 |
+
|
160 |
+
data_dict = {
|
161 |
+
'name': img_name,
|
162 |
+
'image': img_icon.to(self.device).unsqueeze(0),
|
163 |
+
'ori_image': img_ori,
|
164 |
+
'mask': img_mask,
|
165 |
+
'uncrop_param': uncrop_param
|
166 |
+
}
|
167 |
+
|
168 |
+
else:
|
169 |
+
img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
|
170 |
+
img_path, self.det, self.hps_type, 512, self.device,
|
171 |
+
seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))
|
172 |
+
data_dict = {
|
173 |
+
'name': img_name,
|
174 |
+
'image': img_icon.to(self.device).unsqueeze(0),
|
175 |
+
'ori_image': img_ori,
|
176 |
+
'mask': img_mask,
|
177 |
+
'uncrop_param': uncrop_param,
|
178 |
+
'segmentations': segmentations
|
179 |
+
}
|
180 |
+
|
181 |
+
with torch.no_grad():
|
182 |
+
# import ipdb; ipdb.set_trace()
|
183 |
+
preds_dict = self.hps.forward(img_hps)
|
184 |
+
|
185 |
+
data_dict['smpl_faces'] = torch.Tensor(
|
186 |
+
self.faces.astype(np.int16)).long().unsqueeze(0).to(
|
187 |
+
self.device)
|
188 |
+
|
189 |
+
if self.hps_type == 'pymaf':
|
190 |
+
output = preds_dict['smpl_out'][-1]
|
191 |
+
scale, tranX, tranY = output['theta'][0, :3]
|
192 |
+
data_dict['betas'] = output['pred_shape']
|
193 |
+
data_dict['body_pose'] = output['rotmat'][:, 1:]
|
194 |
+
data_dict['global_orient'] = output['rotmat'][:, 0:1]
|
195 |
+
data_dict['smpl_verts'] = output['verts']
|
196 |
+
|
197 |
+
elif self.hps_type == 'pare':
|
198 |
+
data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
|
199 |
+
data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
|
200 |
+
data_dict['betas'] = preds_dict['pred_shape']
|
201 |
+
data_dict['smpl_verts'] = preds_dict['smpl_vertices']
|
202 |
+
scale, tranX, tranY = preds_dict['pred_cam'][0, :3]
|
203 |
+
|
204 |
+
elif self.hps_type == 'pixie':
|
205 |
+
data_dict.update(preds_dict)
|
206 |
+
data_dict['body_pose'] = preds_dict['body_pose']
|
207 |
+
data_dict['global_orient'] = preds_dict['global_pose']
|
208 |
+
data_dict['betas'] = preds_dict['shape']
|
209 |
+
data_dict['smpl_verts'] = preds_dict['vertices']
|
210 |
+
scale, tranX, tranY = preds_dict['cam'][0, :3]
|
211 |
+
|
212 |
+
elif self.hps_type == 'hybrik':
|
213 |
+
data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
|
214 |
+
data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
|
215 |
+
data_dict['betas'] = preds_dict['pred_shape']
|
216 |
+
data_dict['smpl_verts'] = preds_dict['pred_vertices']
|
217 |
+
scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
|
218 |
+
scale = scale * 2
|
219 |
+
|
220 |
+
elif self.hps_type == 'bev':
|
221 |
+
data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[
|
222 |
+
[0], :10].to(self.device).float()
|
223 |
+
pred_thetas = batch_rodrigues(torch.from_numpy(
|
224 |
+
preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
|
225 |
+
data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
|
226 |
+
data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
|
227 |
+
data_dict['smpl_verts'] = torch.from_numpy(
|
228 |
+
preds_dict['verts'][[0]]).to(self.device).float()
|
229 |
+
tranX = preds_dict['cam_trans'][0, 0]
|
230 |
+
tranY = preds_dict['cam'][0, 1] + 0.28
|
231 |
+
scale = preds_dict['cam'][0, 0] * 1.1
|
232 |
+
|
233 |
+
data_dict['scale'] = scale
|
234 |
+
data_dict['trans'] = torch.tensor(
|
235 |
+
[tranX, tranY, 0.0]).to(self.device).float()
|
236 |
+
|
237 |
+
# data_dict info (key-shape):
|
238 |
+
# scale, tranX, tranY - tensor.float
|
239 |
+
# betas - [1,10] / [1, 200]
|
240 |
+
# body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
|
241 |
+
# global_orient - [1, 1, 3, 3]
|
242 |
+
# smpl_verts - [1, 6890, 3] / [1, 10475, 3]
|
243 |
+
|
244 |
+
return data_dict
|
245 |
+
|
246 |
+
def render_normal(self, verts, faces):
|
247 |
+
|
248 |
+
# render optimized mesh (normal, T_normal, image [-1,1])
|
249 |
+
self.render.load_meshes(verts, faces)
|
250 |
+
return self.render.get_rgb_image()
|
251 |
+
|
252 |
+
def render_depth(self, verts, faces):
|
253 |
+
|
254 |
+
# render optimized mesh (normal, T_normal, image [-1,1])
|
255 |
+
self.render.load_meshes(verts, faces)
|
256 |
+
return self.render.get_depth_map(cam_ids=[0, 2])
|
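A minimal sketch of how TestDataset is typically driven (the cfg keys mirror the constructor above; it assumes the PyMAF checkpoint and detector weights are available, and the example image ships with this Space):

import torch

cfg = {
    'image_path': './examples/22097467bffc92d4a5c4246f7d4edb75.png',
    'seg_dir': None,
    'has_det': True,
    'hps_type': 'pymaf',
}
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dataset = TestDataset(cfg, device)
data = dataset[0]
# image tensor plus the SMPL estimate produced by the HPS network
print(data['name'], data['image'].shape, data['smpl_verts'].shape)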
lib/dataset/__init__.py
ADDED
File without changes
|
lib/dataset/body_model.py
ADDED
@@ -0,0 +1,494 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import pickle
|
20 |
+
import torch
|
21 |
+
import os
|
22 |
+
|
23 |
+
|
24 |
+
class SMPLModel():
|
25 |
+
def __init__(self, model_path, age):
|
26 |
+
"""
|
27 |
+
SMPL model.
|
28 |
+
|
29 |
+
Parameter:
|
30 |
+
---------
|
31 |
+
model_path: Path to the SMPL model parameters, pre-processed by
|
32 |
+
`preprocess.py`.
|
33 |
+
|
34 |
+
"""
|
35 |
+
with open(model_path, 'rb') as f:
|
36 |
+
params = pickle.load(f, encoding='latin1')
|
37 |
+
|
38 |
+
self.J_regressor = params['J_regressor']
|
39 |
+
self.weights = np.asarray(params['weights'])
|
40 |
+
self.posedirs = np.asarray(params['posedirs'])
|
41 |
+
self.v_template = np.asarray(params['v_template'])
|
42 |
+
self.shapedirs = np.asarray(params['shapedirs'])
|
43 |
+
self.faces = np.asarray(params['f'])
|
44 |
+
self.kintree_table = np.asarray(params['kintree_table'])
|
45 |
+
|
46 |
+
self.pose_shape = [24, 3]
|
47 |
+
self.beta_shape = [10]
|
48 |
+
self.trans_shape = [3]
|
49 |
+
|
50 |
+
if age == 'kid':
|
51 |
+
v_template_smil = np.load(
|
52 |
+
os.path.join(os.path.dirname(model_path),
|
53 |
+
"smpl/smpl_kid_template.npy"))
|
54 |
+
v_template_smil -= np.mean(v_template_smil, axis=0)
|
55 |
+
v_template_diff = np.expand_dims(v_template_smil - self.v_template,
|
56 |
+
axis=2)
|
57 |
+
self.shapedirs = np.concatenate(
|
58 |
+
(self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
|
59 |
+
axis=2)
|
60 |
+
self.beta_shape[0] += 1
|
61 |
+
|
62 |
+
id_to_col = {
|
63 |
+
self.kintree_table[1, i]: i
|
64 |
+
for i in range(self.kintree_table.shape[1])
|
65 |
+
}
|
66 |
+
self.parent = {
|
67 |
+
i: id_to_col[self.kintree_table[0, i]]
|
68 |
+
for i in range(1, self.kintree_table.shape[1])
|
69 |
+
}
|
70 |
+
|
71 |
+
self.pose = np.zeros(self.pose_shape)
|
72 |
+
self.beta = np.zeros(self.beta_shape)
|
73 |
+
self.trans = np.zeros(self.trans_shape)
|
74 |
+
|
75 |
+
self.verts = None
|
76 |
+
self.J = None
|
77 |
+
self.R = None
|
78 |
+
self.G = None
|
79 |
+
|
80 |
+
self.update()
|
81 |
+
|
82 |
+
def set_params(self, pose=None, beta=None, trans=None):
|
83 |
+
"""
|
84 |
+
Set pose, shape, and/or translation parameters of the SMPL model. Vertices of the
|
85 |
+
model will be updated and returned.
|
86 |
+
|
87 |
+
Parameters:
|
88 |
+
---------
|
89 |
+
pose: Also known as 'theta', a [24,3] matrix indicating child joint rotation
|
90 |
+
relative to parent joint. For root joint it's global orientation.
|
91 |
+
Represented in an axis-angle format.
|
92 |
+
|
93 |
+
beta: Parameter for model shape. A vector of shape [10]. Coefficients for
|
94 |
+
PCA component. Only 10 components were released by MPI.
|
95 |
+
|
96 |
+
trans: Global translation of shape [3].
|
97 |
+
|
98 |
+
Return:
|
99 |
+
------
|
100 |
+
Updated vertices.
|
101 |
+
|
102 |
+
"""
|
103 |
+
if pose is not None:
|
104 |
+
self.pose = pose
|
105 |
+
if beta is not None:
|
106 |
+
self.beta = beta
|
107 |
+
if trans is not None:
|
108 |
+
self.trans = trans
|
109 |
+
self.update()
|
110 |
+
return self.verts
|
111 |
+
|
112 |
+
def update(self):
|
113 |
+
"""
|
114 |
+
Called automatically when parameters are updated.
|
115 |
+
|
116 |
+
"""
|
117 |
+
# how beta affect body shape
|
118 |
+
v_shaped = self.shapedirs.dot(self.beta) + self.v_template
|
119 |
+
# joints location
|
120 |
+
self.J = self.J_regressor.dot(v_shaped)
|
121 |
+
pose_cube = self.pose.reshape((-1, 1, 3))
|
122 |
+
# rotation matrix for each joint
|
123 |
+
self.R = self.rodrigues(pose_cube)
|
124 |
+
I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
|
125 |
+
(self.R.shape[0] - 1, 3, 3))
|
126 |
+
lrotmin = (self.R[1:] - I_cube).ravel()
|
127 |
+
# how pose affect body shape in zero pose
|
128 |
+
v_posed = v_shaped + self.posedirs.dot(lrotmin)
|
129 |
+
# world transformation of each joint
|
130 |
+
G = np.empty((self.kintree_table.shape[1], 4, 4))
|
131 |
+
G[0] = self.with_zeros(
|
132 |
+
np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
|
133 |
+
for i in range(1, self.kintree_table.shape[1]):
|
134 |
+
G[i] = G[self.parent[i]].dot(
|
135 |
+
self.with_zeros(
|
136 |
+
np.hstack([
|
137 |
+
self.R[i],
|
138 |
+
((self.J[i, :] - self.J[self.parent[i], :]).reshape(
|
139 |
+
[3, 1]))
|
140 |
+
])))
|
141 |
+
# remove the transformation due to the rest pose
|
142 |
+
G = G - self.pack(
|
143 |
+
np.matmul(
|
144 |
+
G,
|
145 |
+
np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
|
146 |
+
# transformation of each vertex
|
147 |
+
T = np.tensordot(self.weights, G, axes=[[1], [0]])
|
148 |
+
rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
|
149 |
+
v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
|
150 |
+
4])[:, :3]
|
151 |
+
self.verts = v + self.trans.reshape([1, 3])
|
152 |
+
self.G = G
|
153 |
+
|
154 |
+
def rodrigues(self, r):
|
155 |
+
"""
|
156 |
+
Rodrigues' rotation formula that turns axis-angle vector into rotation
|
157 |
+
matrix in a batched manner.
|
158 |
+
|
159 |
+
Parameter:
|
160 |
+
----------
|
161 |
+
r: Axis-angle rotation vector of shape [batch_size, 1, 3].
|
162 |
+
|
163 |
+
Return:
|
164 |
+
-------
|
165 |
+
Rotation matrix of shape [batch_size, 3, 3].
|
166 |
+
|
167 |
+
"""
|
168 |
+
theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
|
169 |
+
# avoid zero divide
|
170 |
+
theta = np.maximum(theta, np.finfo(np.float64).tiny)
|
171 |
+
r_hat = r / theta
|
172 |
+
cos = np.cos(theta)
|
173 |
+
z_stick = np.zeros(theta.shape[0])
|
174 |
+
m = np.dstack([
|
175 |
+
z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick,
|
176 |
+
-r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick
|
177 |
+
]).reshape([-1, 3, 3])
|
178 |
+
i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
|
179 |
+
[theta.shape[0], 3, 3])
|
180 |
+
A = np.transpose(r_hat, axes=[0, 2, 1])
|
181 |
+
B = r_hat
|
182 |
+
dot = np.matmul(A, B)
|
183 |
+
R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m
|
184 |
+
return R
|
185 |
+
|
186 |
+
def with_zeros(self, x):
|
187 |
+
"""
|
188 |
+
Append a [0, 0, 0, 1] vector to a [3, 4] matrix.
|
189 |
+
|
190 |
+
Parameter:
|
191 |
+
---------
|
192 |
+
x: Matrix to be appended.
|
193 |
+
|
194 |
+
Return:
|
195 |
+
------
|
196 |
+
Matrix after appending of shape [4,4]
|
197 |
+
|
198 |
+
"""
|
199 |
+
return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]])))
|
200 |
+
|
201 |
+
def pack(self, x):
|
202 |
+
"""
|
203 |
+
Append zero matrices of shape [4, 3] to vectors of [4, 1] shape in a batched
|
204 |
+
manner.
|
205 |
+
|
206 |
+
Parameter:
|
207 |
+
----------
|
208 |
+
x: Matrices to be appended of shape [batch_size, 4, 1]
|
209 |
+
|
210 |
+
Return:
|
211 |
+
------
|
212 |
+
Matrix of shape [batch_size, 4, 4] after appending.
|
213 |
+
|
214 |
+
"""
|
215 |
+
return np.dstack((np.zeros((x.shape[0], 4, 3)), x))
|
216 |
+
|
217 |
+
def save_to_obj(self, path):
|
218 |
+
"""
|
219 |
+
Save the SMPL model into .obj file.
|
220 |
+
|
221 |
+
Parameter:
|
222 |
+
---------
|
223 |
+
path: Path to save.
|
224 |
+
|
225 |
+
"""
|
226 |
+
with open(path, 'w') as fp:
|
227 |
+
for v in self.verts:
|
228 |
+
fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
|
229 |
+
for f in self.faces + 1:
|
230 |
+
fp.write('f %d %d %d\n' % (f[0], f[1], f[2]))
|
231 |
+
|
232 |
+
|
233 |
+
class TetraSMPLModel():
|
234 |
+
def __init__(self,
|
235 |
+
model_path,
|
236 |
+
model_addition_path,
|
237 |
+
age='adult',
|
238 |
+
v_template=None):
|
239 |
+
"""
|
240 |
+
Tetrahedral SMPL model (SMPL augmented with interior vertices and tetrahedra).
|
241 |
+
|
242 |
+
Parameter:
|
243 |
+
---------
|
244 |
+
model_path: Path to the SMPL model parameters, pre-processed by
|
245 |
+
`preprocess.py`.
|
246 |
+
|
247 |
+
"""
|
248 |
+
with open(model_path, 'rb') as f:
|
249 |
+
params = pickle.load(f, encoding='latin1')
|
250 |
+
|
251 |
+
self.J_regressor = params['J_regressor']
|
252 |
+
self.weights = np.asarray(params['weights'])
|
253 |
+
self.posedirs = np.asarray(params['posedirs'])
|
254 |
+
|
255 |
+
if v_template is not None:
|
256 |
+
self.v_template = v_template
|
257 |
+
else:
|
258 |
+
self.v_template = np.asarray(params['v_template'])
|
259 |
+
|
260 |
+
self.shapedirs = np.asarray(params['shapedirs'])
|
261 |
+
self.faces = np.asarray(params['f'])
|
262 |
+
self.kintree_table = np.asarray(params['kintree_table'])
|
263 |
+
|
264 |
+
params_added = np.load(model_addition_path)
|
265 |
+
self.v_template_added = params_added['v_template_added']
|
266 |
+
self.weights_added = params_added['weights_added']
|
267 |
+
self.shapedirs_added = params_added['shapedirs_added']
|
268 |
+
self.posedirs_added = params_added['posedirs_added']
|
269 |
+
self.tetrahedrons = params_added['tetrahedrons']
|
270 |
+
|
271 |
+
id_to_col = {
|
272 |
+
self.kintree_table[1, i]: i
|
273 |
+
for i in range(self.kintree_table.shape[1])
|
274 |
+
}
|
275 |
+
self.parent = {
|
276 |
+
i: id_to_col[self.kintree_table[0, i]]
|
277 |
+
for i in range(1, self.kintree_table.shape[1])
|
278 |
+
}
|
279 |
+
|
280 |
+
self.pose_shape = [24, 3]
|
281 |
+
self.beta_shape = [10]
|
282 |
+
self.trans_shape = [3]
|
283 |
+
|
284 |
+
if age == 'kid':
|
285 |
+
v_template_smil = np.load(
|
286 |
+
os.path.join(os.path.dirname(model_path),
|
287 |
+
"smpl/smpl_kid_template.npy"))
|
288 |
+
v_template_smil -= np.mean(v_template_smil, axis=0)
|
289 |
+
v_template_diff = np.expand_dims(v_template_smil - self.v_template,
|
290 |
+
axis=2)
|
291 |
+
self.shapedirs = np.concatenate(
|
292 |
+
(self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
|
293 |
+
axis=2)
|
294 |
+
self.beta_shape[0] += 1
|
295 |
+
|
296 |
+
self.pose = np.zeros(self.pose_shape)
|
297 |
+
self.beta = np.zeros(self.beta_shape)
|
298 |
+
self.trans = np.zeros(self.trans_shape)
|
299 |
+
|
300 |
+
self.verts = None
|
301 |
+
self.verts_added = None
|
302 |
+
self.J = None
|
303 |
+
self.R = None
|
304 |
+
self.G = None
|
305 |
+
|
306 |
+
self.update()
|
307 |
+
|
308 |
+
def set_params(self, pose=None, beta=None, trans=None):
|
309 |
+
"""
|
310 |
+
Set pose, shape, and/or translation parameters of the SMPL model. Vertices of the
|
311 |
+
model will be updated and returned.
|
312 |
+
|
313 |
+
Parameters:
|
314 |
+
---------
|
315 |
+
pose: Also known as 'theta', a [24,3] matrix indicating child joint rotation
|
316 |
+
relative to parent joint. For root joint it's global orientation.
|
317 |
+
Represented in an axis-angle format.
|
318 |
+
|
319 |
+
beta: Parameter for model shape. A vector of shape [10]. Coefficients for
|
320 |
+
PCA component. Only 10 components were released by MPI.
|
321 |
+
|
322 |
+
trans: Global translation of shape [3].
|
323 |
+
|
324 |
+
Return:
|
325 |
+
------
|
326 |
+
Updated vertices.
|
327 |
+
|
328 |
+
"""
|
329 |
+
|
330 |
+
if torch.is_tensor(pose):
|
331 |
+
pose = pose.detach().cpu().numpy()
|
332 |
+
if torch.is_tensor(beta):
|
333 |
+
beta = beta.detach().cpu().numpy()
|
334 |
+
|
335 |
+
if pose is not None:
|
336 |
+
self.pose = pose
|
337 |
+
if beta is not None:
|
338 |
+
self.beta = beta
|
339 |
+
if trans is not None:
|
340 |
+
self.trans = trans
|
341 |
+
self.update()
|
342 |
+
return self.verts
|
343 |
+
|
344 |
+
def update(self):
|
345 |
+
"""
|
346 |
+
Called automatically when parameters are updated.
|
347 |
+
|
348 |
+
"""
|
349 |
+
# how beta affect body shape
|
350 |
+
v_shaped = self.shapedirs.dot(self.beta) + self.v_template
|
351 |
+
v_shaped_added = self.shapedirs_added.dot(
|
352 |
+
self.beta) + self.v_template_added
|
353 |
+
# joints location
|
354 |
+
self.J = self.J_regressor.dot(v_shaped)
|
355 |
+
pose_cube = self.pose.reshape((-1, 1, 3))
|
356 |
+
# rotation matrix for each joint
|
357 |
+
self.R = self.rodrigues(pose_cube)
|
358 |
+
I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
|
359 |
+
(self.R.shape[0] - 1, 3, 3))
|
360 |
+
lrotmin = (self.R[1:] - I_cube).ravel()
|
361 |
+
# how pose affect body shape in zero pose
|
362 |
+
v_posed = v_shaped + self.posedirs.dot(lrotmin)
|
363 |
+
v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin)
|
364 |
+
# world transformation of each joint
|
365 |
+
G = np.empty((self.kintree_table.shape[1], 4, 4))
|
366 |
+
G[0] = self.with_zeros(
|
367 |
+
np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
|
368 |
+
for i in range(1, self.kintree_table.shape[1]):
|
369 |
+
G[i] = G[self.parent[i]].dot(
|
370 |
+
self.with_zeros(
|
371 |
+
np.hstack([
|
372 |
+
self.R[i],
|
373 |
+
((self.J[i, :] - self.J[self.parent[i], :]).reshape(
|
374 |
+
[3, 1]))
|
375 |
+
])))
|
376 |
+
# remove the transformation due to the rest pose
|
377 |
+
G = G - self.pack(
|
378 |
+
np.matmul(
|
379 |
+
G,
|
380 |
+
np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
|
381 |
+
self.G = G
|
382 |
+
# transformation of each vertex
|
383 |
+
T = np.tensordot(self.weights, G, axes=[[1], [0]])
|
384 |
+
rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
|
385 |
+
v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
|
386 |
+
4])[:, :3]
|
387 |
+
self.verts = v + self.trans.reshape([1, 3])
|
388 |
+
T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]])
|
389 |
+
rest_shape_added_h = np.hstack(
|
390 |
+
(v_posed_added, np.ones([v_posed_added.shape[0], 1])))
|
391 |
+
v_added = np.matmul(T_added,
|
392 |
+
rest_shape_added_h.reshape([-1, 4,
|
393 |
+
1])).reshape([-1, 4
|
394 |
+
])[:, :3]
|
395 |
+
self.verts_added = v_added + self.trans.reshape([1, 3])
|
396 |
+
|
397 |
+
def rodrigues(self, r):
|
398 |
+
"""
|
399 |
+
Rodrigues' rotation formula that turns axis-angle vector into rotation
|
400 |
+
matrix in a batched manner.
|
401 |
+
|
402 |
+
Parameter:
|
403 |
+
----------
|
404 |
+
r: Axis-angle rotation vector of shape [batch_size, 1, 3].
|
405 |
+
|
406 |
+
Return:
|
407 |
+
-------
|
408 |
+
Rotation matrix of shape [batch_size, 3, 3].
|
409 |
+
|
410 |
+
"""
|
411 |
+
theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
|
412 |
+
# avoid zero divide
|
413 |
+
theta = np.maximum(theta, np.finfo(np.float64).tiny)
|
414 |
+
r_hat = r / theta
|
415 |
+
cos = np.cos(theta)
|
416 |
+
z_stick = np.zeros(theta.shape[0])
|
417 |
+
m = np.dstack([
|
418 |
+
z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick,
|
419 |
+
-r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick
|
420 |
+
]).reshape([-1, 3, 3])
|
421 |
+
i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
|
422 |
+
[theta.shape[0], 3, 3])
|
423 |
+
A = np.transpose(r_hat, axes=[0, 2, 1])
|
424 |
+
B = r_hat
|
425 |
+
dot = np.matmul(A, B)
|
426 |
+
R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m
|
427 |
+
return R
|
428 |
+
|
429 |
+
def with_zeros(self, x):
|
430 |
+
"""
|
431 |
+
Append a [0, 0, 0, 1] vector to a [3, 4] matrix.
|
432 |
+
|
433 |
+
Parameter:
|
434 |
+
---------
|
435 |
+
x: Matrix to be appended.
|
436 |
+
|
437 |
+
Return:
|
438 |
+
------
|
439 |
+
Matrix after appending of shape [4,4]
|
440 |
+
|
441 |
+
"""
|
442 |
+
return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]])))
|
443 |
+
|
444 |
+
def pack(self, x):
|
445 |
+
"""
|
446 |
+
Append zero matrices of shape [4, 3] to vectors of [4, 1] shape in a batched
|
447 |
+
manner.
|
448 |
+
|
449 |
+
Parameter:
|
450 |
+
----------
|
451 |
+
x: Matrices to be appended of shape [batch_size, 4, 1]
|
452 |
+
|
453 |
+
Return:
|
454 |
+
------
|
455 |
+
Matrix of shape [batch_size, 4, 4] after appending.
|
456 |
+
|
457 |
+
"""
|
458 |
+
return np.dstack((np.zeros((x.shape[0], 4, 3)), x))
|
459 |
+
|
460 |
+
def save_mesh_to_obj(self, path):
|
461 |
+
"""
|
462 |
+
Save the SMPL model into .obj file.
|
463 |
+
|
464 |
+
Parameter:
|
465 |
+
---------
|
466 |
+
path: Path to save.
|
467 |
+
|
468 |
+
"""
|
469 |
+
with open(path, 'w') as fp:
|
470 |
+
for v in self.verts:
|
471 |
+
fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
|
472 |
+
for f in self.faces + 1:
|
473 |
+
fp.write('f %d %d %d\n' % (f[0], f[1], f[2]))
|
474 |
+
|
475 |
+
def save_tetrahedron_to_obj(self, path):
|
476 |
+
"""
|
477 |
+
Save the tetrahedron SMPL model into .obj file.
|
478 |
+
|
479 |
+
Parameter:
|
480 |
+
---------
|
481 |
+
path: Path to save.
|
482 |
+
|
483 |
+
"""
|
484 |
+
|
485 |
+
with open(path, 'w') as fp:
|
486 |
+
for v in self.verts:
|
487 |
+
fp.write('v %f %f %f 1 0 0\n' % (v[0], v[1], v[2]))
|
488 |
+
for va in self.verts_added:
|
489 |
+
fp.write('v %f %f %f 0 0 1\n' % (va[0], va[1], va[2]))
|
490 |
+
for t in self.tetrahedrons + 1:
|
491 |
+
fp.write('f %d %d %d\n' % (t[0], t[2], t[1]))
|
492 |
+
fp.write('f %d %d %d\n' % (t[0], t[3], t[2]))
|
493 |
+
fp.write('f %d %d %d\n' % (t[0], t[1], t[3]))
|
494 |
+
fp.write('f %d %d %d\n' % (t[1], t[2], t[3]))
|
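A minimal usage sketch for TetraSMPLModel (the .pkl/.npz paths below are placeholders for the files that the dataset code above fetches with cached_download):

import numpy as np

smpl_path = './data/SMPL_NEUTRAL.pkl'                # hypothetical local copy
tetra_path = './data/tetra_neutral_adult_smpl.npz'   # hypothetical local copy

model = TetraSMPLModel(smpl_path, tetra_path, age='adult')
model.set_params(pose=np.zeros([24, 3]),   # per-joint axis-angle, rest pose here
                 beta=np.zeros([10]),      # shape coefficients
                 trans=np.zeros([3]))      # global translation
print(model.verts.shape, model.verts_added.shape)   # surface + interior vertices
model.save_tetrahedron_to_obj('./tetra_smpl.obj')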
lib/dataset/hoppeMesh.py
ADDED
@@ -0,0 +1,116 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
from scipy.spatial import cKDTree
|
20 |
+
import trimesh
|
21 |
+
|
22 |
+
import logging
|
23 |
+
|
24 |
+
logging.getLogger("trimesh").setLevel(logging.ERROR)
|
25 |
+
|
26 |
+
|
27 |
+
def save_obj_mesh(mesh_path, verts, faces):
|
28 |
+
file = open(mesh_path, 'w')
|
29 |
+
for v in verts:
|
30 |
+
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
|
31 |
+
for f in faces:
|
32 |
+
f_plus = f + 1
|
33 |
+
file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
|
34 |
+
file.close()
|
35 |
+
|
36 |
+
|
37 |
+
def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
|
38 |
+
file = open(mesh_path, 'w')
|
39 |
+
|
40 |
+
for idx, v in enumerate(verts):
|
41 |
+
c = colors[idx]
|
42 |
+
file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
|
43 |
+
(v[0], v[1], v[2], c[0], c[1], c[2]))
|
44 |
+
for f in faces:
|
45 |
+
f_plus = f + 1
|
46 |
+
file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
|
47 |
+
file.close()
|
48 |
+
|
49 |
+
|
50 |
+
def save_ply(mesh_path, points, rgb):
|
51 |
+
'''
|
52 |
+
Save the visualization of sampling to a ply file.
|
53 |
+
Red points represent positive predictions.
|
54 |
+
Green points represent negative predictions.
|
55 |
+
:param mesh_path: File name to save
|
56 |
+
:param points: [N, 3] array of points
|
57 |
+
:param rgb: [N, 3] array of rgb values in the range [0~1]
|
58 |
+
:return:
|
59 |
+
'''
|
60 |
+
to_save = np.concatenate([points, rgb * 255], axis=-1)
|
61 |
+
return np.savetxt(
|
62 |
+
mesh_path,
|
63 |
+
to_save,
|
64 |
+
fmt='%.6f %.6f %.6f %d %d %d',
|
65 |
+
comments='',
|
66 |
+
header=(
|
67 |
+
'ply\nformat ascii 1.0\nelement vertex {:d}\n' +
|
68 |
+
'property float x\nproperty float y\nproperty float z\n' +
|
69 |
+
'property uchar red\nproperty uchar green\nproperty uchar blue\n' +
|
70 |
+
'end_header').format(points.shape[0]))
|
71 |
+
|
72 |
+
|
73 |
+
class HoppeMesh:
|
74 |
+
def __init__(self, verts, faces, vert_normals, face_normals):
|
75 |
+
'''
|
76 |
+
The HoppeSDF calculates signed distance towards a predefined oriented point cloud
|
77 |
+
http://hhoppe.com/recon.pdf
|
78 |
+
For clean and high-resolution point-cloud data, this is the fastest and most accurate approximation of the SDF
|
79 |
+
:param verts: [n, 3] oriented point positions (mesh vertices)
|
80 |
+
:param vert_normals: [n, 3] per-vertex normals of the oriented point cloud
|
81 |
+
'''
|
82 |
+
self.verts = verts # [n, 3]
|
83 |
+
self.faces = faces # [m, 3]
|
84 |
+
self.vert_normals = vert_normals # [n, 3]
|
85 |
+
self.face_normals = face_normals # [m, 3]
|
86 |
+
|
87 |
+
self.kd_tree = cKDTree(self.verts)
|
88 |
+
self.len = len(self.verts)
|
89 |
+
|
90 |
+
def query(self, points):
|
91 |
+
dists, idx = self.kd_tree.query(points, n_jobs=1)
|
92 |
+
# FIXME: because the eyebrows are removed, the cKDTree around the eyebrows
|
93 |
+
# is not accurate, causing a few false-inside labels here.
|
94 |
+
dirs = points - self.verts[idx]
|
95 |
+
signs = (dirs * self.vert_normals[idx]).sum(axis=1)
|
96 |
+
signs = (signs > 0) * 2 - 1
|
97 |
+
return signs * dists
|
98 |
+
|
99 |
+
def contains(self, points):
|
100 |
+
|
101 |
+
labels = trimesh.Trimesh(vertices=self.verts,
|
102 |
+
faces=self.faces).contains(points)
|
103 |
+
return labels
|
104 |
+
|
105 |
+
def export(self, path):
|
106 |
+
if self.colors is not None:
|
107 |
+
save_obj_mesh_with_color(path, self.verts, self.faces,
|
108 |
+
self.colors[:, 0:3] / 255.0)
|
109 |
+
else:
|
110 |
+
save_obj_mesh(path, self.verts, self.faces)
|
111 |
+
|
112 |
+
def export_ply(self, path):
|
113 |
+
save_ply(path, self.verts, self.colors[:, 0:3] / 255.0)
|
114 |
+
|
115 |
+
def triangles(self):
|
116 |
+
return self.verts[self.faces] # [n, 3, 3]
|
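A minimal sketch of the HoppeMesh signed-distance query, assuming a watertight input (a unit icosphere here, so the expected signs are easy to check):

import numpy as np
import trimesh

sphere = trimesh.creation.icosphere(subdivisions=3)
hoppe = HoppeMesh(verts=np.asarray(sphere.vertices),
                  faces=np.asarray(sphere.faces),
                  vert_normals=np.asarray(sphere.vertex_normals),
                  face_normals=np.asarray(sphere.face_normals))

pts = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])  # one inside, one outside
print(hoppe.query(pts))      # signed distances: negative inside, positive outside
print(hoppe.contains(pts))   # exact occupancy test via trimesh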
lib/dataset/mesh_util.py
ADDED
@@ -0,0 +1,894 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import cv2
|
20 |
+
import pymeshlab
|
21 |
+
import torch
|
22 |
+
import torchvision
|
23 |
+
import trimesh
|
24 |
+
from pytorch3d.io import load_obj
|
25 |
+
from termcolor import colored
|
26 |
+
from scipy.spatial import cKDTree
|
27 |
+
|
28 |
+
from pytorch3d.structures import Meshes
|
29 |
+
import torch.nn.functional as F
|
30 |
+
|
31 |
+
import os
|
32 |
+
from lib.pymaf.utils.imutils import uncrop
|
33 |
+
from lib.common.render_utils import Pytorch3dRasterizer, face_vertices
|
34 |
+
|
35 |
+
from pytorch3d.renderer.mesh import rasterize_meshes
|
36 |
+
from PIL import Image, ImageFont, ImageDraw
|
37 |
+
from kaolin.ops.mesh import check_sign
|
38 |
+
from kaolin.metrics.trianglemesh import point_to_mesh_distance
|
39 |
+
|
40 |
+
from pytorch3d.loss import (
|
41 |
+
mesh_laplacian_smoothing,
|
42 |
+
mesh_normal_consistency
|
43 |
+
)
|
44 |
+
|
45 |
+
from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
|
46 |
+
|
47 |
+
|
48 |
+
def tensor2variable(tensor, device):
|
49 |
+
# [1,23,3,3]
|
50 |
+
return torch.tensor(tensor, device=device, requires_grad=True)
|
51 |
+
|
52 |
+
|
53 |
+
def normal_loss(vec1, vec2):
|
54 |
+
|
55 |
+
# vec1_mask = vec1.sum(dim=1) != 0.0
|
56 |
+
# vec2_mask = vec2.sum(dim=1) != 0.0
|
57 |
+
# union_mask = vec1_mask * vec2_mask
|
58 |
+
vec_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-6)(vec1, vec2)
|
59 |
+
# vec_diff = ((vec_sim-1.0)**2)[union_mask].mean()
|
60 |
+
vec_diff = ((vec_sim-1.0)**2).mean()
|
61 |
+
|
62 |
+
return vec_diff
|
63 |
+
|
64 |
+
|
65 |
+
class GMoF(torch.nn.Module):
|
66 |
+
def __init__(self, rho=1):
|
67 |
+
super(GMoF, self).__init__()
|
68 |
+
self.rho = rho
|
69 |
+
|
70 |
+
def extra_repr(self):
|
71 |
+
return 'rho = {}'.format(self.rho)
|
72 |
+
|
73 |
+
def forward(self, residual):
|
74 |
+
dist = torch.div(residual, residual + self.rho ** 2)
|
75 |
+
return self.rho ** 2 * dist
|
76 |
+
|
77 |
+
|
78 |
+
def mesh_edge_loss(meshes, target_length: float = 0.0):
|
79 |
+
"""
|
80 |
+
Computes mesh edge length regularization loss averaged across all meshes
|
81 |
+
in a batch. Each mesh contributes equally to the final loss, regardless of
|
82 |
+
the number of edges per mesh in the batch by weighting each mesh with the
|
83 |
+
inverse number of edges. For example, if mesh 3 (out of N) has only E=4
|
84 |
+
edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to
|
85 |
+
contribute to the final loss.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
meshes: Meshes object with a batch of meshes.
|
89 |
+
target_length: Resting value for the edge length.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
loss: Average loss across the batch. Returns 0 if meshes contains
|
93 |
+
no meshes or all empty meshes.
|
94 |
+
"""
|
95 |
+
if meshes.isempty():
|
96 |
+
return torch.tensor(
|
97 |
+
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
|
98 |
+
)
|
99 |
+
|
100 |
+
N = len(meshes)
|
101 |
+
edges_packed = meshes.edges_packed() # (sum(E_n), 3)
|
102 |
+
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
103 |
+
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
|
104 |
+
num_edges_per_mesh = meshes.num_edges_per_mesh() # N
|
105 |
+
|
106 |
+
# Determine the weight for each edge based on the number of edges in the
|
107 |
+
# mesh it corresponds to.
|
108 |
+
# TODO (nikhilar) Find a faster way of computing the weights for each edge
|
109 |
+
# as this is currently a bottleneck for meshes with a large number of faces.
|
110 |
+
weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
|
111 |
+
weights = 1.0 / weights.float()
|
112 |
+
|
113 |
+
verts_edges = verts_packed[edges_packed]
|
114 |
+
v0, v1 = verts_edges.unbind(1)
|
115 |
+
loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
|
116 |
+
loss_vertex = loss * weights
|
117 |
+
# loss_outlier = torch.topk(loss, 100)[0].mean()
|
118 |
+
# loss_all = (loss_vertex.sum() + loss_outlier.mean()) / N
|
119 |
+
loss_all = loss_vertex.sum() / N
|
120 |
+
|
121 |
+
return loss_all
|
122 |
+
|
123 |
+
|
124 |
+
def remesh(obj_path, perc, device):
|
125 |
+
|
126 |
+
ms = pymeshlab.MeshSet()
|
127 |
+
ms.load_new_mesh(obj_path)
|
128 |
+
ms.laplacian_smooth()
|
129 |
+
ms.remeshing_isotropic_explicit_remeshing(
|
130 |
+
targetlen=pymeshlab.Percentage(perc), adaptive=True)
|
131 |
+
ms.save_current_mesh(obj_path.replace("recon", "remesh"))
|
132 |
+
polished_mesh = trimesh.load_mesh(obj_path.replace("recon", "remesh"))
|
133 |
+
verts_pr = torch.tensor(polished_mesh.vertices).float().unsqueeze(0).to(device)
|
134 |
+
faces_pr = torch.tensor(polished_mesh.faces).long().unsqueeze(0).to(device)
|
135 |
+
|
136 |
+
return verts_pr, faces_pr
|
137 |
+
|
138 |
+
|
139 |
+
def possion(mesh, obj_path):
|
140 |
+
|
141 |
+
mesh.export(obj_path)
|
142 |
+
ms = pymeshlab.MeshSet()
|
143 |
+
ms.load_new_mesh(obj_path)
|
144 |
+
ms.surface_reconstruction_screened_poisson(depth=10)
|
145 |
+
ms.set_current_mesh(1)
|
146 |
+
ms.save_current_mesh(obj_path)
|
147 |
+
|
148 |
+
return trimesh.load(obj_path)
|
149 |
+
|
150 |
+
|
151 |
+
def get_mask(tensor, dim):
|
152 |
+
|
153 |
+
mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0
|
154 |
+
mask = mask.type_as(tensor)
|
155 |
+
|
156 |
+
return mask
|
157 |
+
|
158 |
+
|
159 |
+
def blend_rgb_norm(rgb, norm, mask):
|
160 |
+
|
161 |
+
# [0,0,0] or [127,127,127] should be marked as mask
|
162 |
+
final = rgb * (1-mask) + norm * (mask)
|
163 |
+
|
164 |
+
return final.astype(np.uint8)
|
165 |
+
|
166 |
+
|
167 |
+
def unwrap(image, data):
|
168 |
+
|
169 |
+
img_uncrop = uncrop(np.array(Image.fromarray(image).resize(data['uncrop_param']['box_shape'][:2])),
|
170 |
+
data['uncrop_param']['center'],
|
171 |
+
data['uncrop_param']['scale'],
|
172 |
+
data['uncrop_param']['crop_shape'])
|
173 |
+
|
174 |
+
img_orig = cv2.warpAffine(img_uncrop,
|
175 |
+
np.linalg.inv(data['uncrop_param']['M'])[:2, :],
|
176 |
+
data['uncrop_param']['ori_shape'][::-1][1:],
|
177 |
+
flags=cv2.INTER_CUBIC)
|
178 |
+
|
179 |
+
return img_orig
|
180 |
+
|
181 |
+
|
182 |
+
# Losses to smooth / regularize the mesh shape
|
183 |
+
def update_mesh_shape_prior_losses(mesh, losses):
|
184 |
+
|
185 |
+
# and (b) the edge length of the predicted mesh
|
186 |
+
losses["edge"]['value'] = mesh_edge_loss(mesh)
|
187 |
+
# mesh normal consistency
|
188 |
+
losses["nc"]['value'] = mesh_normal_consistency(mesh)
|
189 |
+
# mesh laplacian smoothing
|
190 |
+
losses["laplacian"]['value'] = mesh_laplacian_smoothing(
|
191 |
+
mesh, method="uniform")
|
192 |
+
|
193 |
+
|
194 |
+
def rename(old_dict, old_name, new_name):
|
195 |
+
new_dict = {}
|
196 |
+
for key, value in zip(old_dict.keys(), old_dict.values()):
|
197 |
+
new_key = key if key != old_name else new_name
|
198 |
+
new_dict[new_key] = old_dict[key]
|
199 |
+
return new_dict
|
200 |
+
|
201 |
+
|
202 |
+
def load_checkpoint(model, cfg):
|
203 |
+
|
204 |
+
model_dict = model.state_dict()
|
205 |
+
main_dict = {}
|
206 |
+
normal_dict = {}
|
207 |
+
|
208 |
+
device = torch.device(f"cuda:{cfg['test_gpus'][0]}")
|
209 |
+
|
210 |
+
main_dict = torch.load(cached_download(cfg.resume_path, use_auth_token=os.environ['ICON']),
|
211 |
+
map_location=device)['state_dict']
|
212 |
+
|
213 |
+
main_dict = {
|
214 |
+
k: v
|
215 |
+
for k, v in main_dict.items()
|
216 |
+
if k in model_dict and v.shape == model_dict[k].shape and (
|
217 |
+
'reconEngine' not in k) and ("normal_filter" not in k) and (
|
218 |
+
'voxelization' not in k)
|
219 |
+
}
|
220 |
+
print(colored(f"Resume MLP weights from {cfg.resume_path}", 'green'))
|
221 |
+
|
222 |
+
normal_dict = torch.load(cached_download(cfg.normal_path, use_auth_token=os.environ['ICON']),
|
223 |
+
map_location=device)['state_dict']
|
224 |
+
|
225 |
+
for key in normal_dict.keys():
|
226 |
+
normal_dict = rename(normal_dict, key,
|
227 |
+
key.replace("netG", "netG.normal_filter"))
|
228 |
+
|
229 |
+
normal_dict = {
|
230 |
+
k: v
|
231 |
+
for k, v in normal_dict.items()
|
232 |
+
if k in model_dict and v.shape == model_dict[k].shape
|
233 |
+
}
|
234 |
+
print(colored(f"Resume normal model from {cfg.normal_path}", 'green'))
|
235 |
+
|
236 |
+
model_dict.update(main_dict)
|
237 |
+
model_dict.update(normal_dict)
|
238 |
+
model.load_state_dict(model_dict)
|
239 |
+
|
240 |
+
model.netG = model.netG.to(device)
|
241 |
+
model.reconEngine = model.reconEngine.to(device)
|
242 |
+
|
243 |
+
model.netG.training = False
|
244 |
+
model.netG.eval()
|
245 |
+
|
246 |
+
del main_dict
|
247 |
+
del normal_dict
|
248 |
+
del model_dict
|
249 |
+
|
250 |
+
return model
|
251 |
+
|
252 |
+
|
253 |
+
def read_smpl_constants(folder):
|
254 |
+
"""Load smpl vertex code"""
|
255 |
+
smpl_vtx_std = np.loadtxt(cached_download(os.path.join(folder, 'vertices.txt'), use_auth_token=os.environ['ICON']))
|
256 |
+
min_x = np.min(smpl_vtx_std[:, 0])
|
257 |
+
max_x = np.max(smpl_vtx_std[:, 0])
|
258 |
+
min_y = np.min(smpl_vtx_std[:, 1])
|
259 |
+
max_y = np.max(smpl_vtx_std[:, 1])
|
260 |
+
min_z = np.min(smpl_vtx_std[:, 2])
|
261 |
+
max_z = np.max(smpl_vtx_std[:, 2])
|
262 |
+
|
263 |
+
smpl_vtx_std[:, 0] = (smpl_vtx_std[:, 0] - min_x) / (max_x - min_x)
|
264 |
+
smpl_vtx_std[:, 1] = (smpl_vtx_std[:, 1] - min_y) / (max_y - min_y)
|
265 |
+
smpl_vtx_std[:, 2] = (smpl_vtx_std[:, 2] - min_z) / (max_z - min_z)
|
266 |
+
smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
|
267 |
+
"""Load smpl faces & tetrahedrons"""
|
268 |
+
smpl_faces = np.loadtxt(cached_download(os.path.join(folder, 'faces.txt'), use_auth_token=os.environ['ICON']),
|
269 |
+
dtype=np.int32) - 1
|
270 |
+
smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] +
|
271 |
+
smpl_vertex_code[smpl_faces[:, 1]] +
|
272 |
+
smpl_vertex_code[smpl_faces[:, 2]]) / 3.0
|
273 |
+
smpl_tetras = np.loadtxt(cached_download(os.path.join(folder, 'tetrahedrons.txt'), use_auth_token=os.environ['ICON']),
|
274 |
+
dtype=np.int32) - 1
|
275 |
+
|
276 |
+
return smpl_vertex_code, smpl_face_code, smpl_faces, smpl_tetras
|
277 |
+
|
278 |
+
|
279 |
+
def feat_select(feat, select):
|
280 |
+
|
281 |
+
# feat [B, featx2, N]
|
282 |
+
# select [B, 1, N]
|
283 |
+
# return [B, feat, N]
|
284 |
+
|
285 |
+
dim = feat.shape[1] // 2
|
286 |
+
idx = torch.tile((1-select), (1, dim, 1))*dim + \
|
287 |
+
torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select)
|
288 |
+
feat_select = torch.gather(feat, 1, idx.long())
|
289 |
+
|
290 |
+
return feat_select
|
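# Illustrative sketch for feat_select with made-up values: when select is 1 the
# first half of the channels is kept, otherwise the second half.
import torch
feat = torch.cat([torch.zeros(1, 4, 5), torch.ones(1, 4, 5)], dim=1)  # [1, 2*4, 5]
select = torch.tensor([[[1., 0., 1., 0., 1.]]])                       # [1, 1, 5]
out = feat_select(feat, select)                                        # [1, 4, 5]
# columns where select == 1 come out as 0 (front half), the rest as 1 (back half)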
291 |
+
|
292 |
+
|
293 |
+
def get_visibility(xy, z, faces):
|
294 |
+
"""get the visibility of vertices
|
295 |
+
|
296 |
+
Args:
|
297 |
+
xy (torch.tensor): [N,2]
|
298 |
+
z (torch.tensor): [N,1]
|
299 |
+
faces (torch.tensor): [N,3]
|
300 |
+
(the mesh is rasterised at a fixed 2**12 image resolution)
|
301 |
+
"""
|
302 |
+
|
303 |
+
xyz = torch.cat((xy, -z), dim=1)
|
304 |
+
xyz = (xyz + 1.0) / 2.0
|
305 |
+
faces = faces.long()
|
306 |
+
|
307 |
+
rasterizer = Pytorch3dRasterizer(image_size=2**12)
|
308 |
+
meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
|
309 |
+
raster_settings = rasterizer.raster_settings
|
310 |
+
|
311 |
+
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
312 |
+
meshes_screen,
|
313 |
+
image_size=raster_settings.image_size,
|
314 |
+
blur_radius=raster_settings.blur_radius,
|
315 |
+
faces_per_pixel=raster_settings.faces_per_pixel,
|
316 |
+
bin_size=raster_settings.bin_size,
|
317 |
+
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
318 |
+
perspective_correct=raster_settings.perspective_correct,
|
319 |
+
cull_backfaces=raster_settings.cull_backfaces,
|
320 |
+
)
|
321 |
+
|
322 |
+
vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
|
323 |
+
vis_mask = torch.zeros(size=(z.shape[0], 1))
|
324 |
+
vis_mask[vis_vertices_id] = 1.0
|
325 |
+
|
326 |
+
# print("------------------------\n")
|
327 |
+
# print(f"keep points : {vis_mask.sum()/len(vis_mask)}")
|
328 |
+
|
329 |
+
return vis_mask
|
330 |
+
|
331 |
+
|
332 |
+
def barycentric_coordinates_of_projection(points, vertices):
|
333 |
+
''' https://github.com/MPI-IS/mesh/blob/master/mesh/geometry/barycentric_coordinates_of_projection.py
|
334 |
+
'''
|
335 |
+
"""Given a point, gives projected coords of that point to a triangle
|
336 |
+
in barycentric coordinates.
|
337 |
+
See
|
338 |
+
**Heidrich**, Computing the Barycentric Coordinates of a Projected Point, JGT 05
|
339 |
+
at http://www.cs.ubc.ca/~heidrich/Papers/JGT.05.pdf
|
340 |
+
|
341 |
+
:param p: point to project. [B, 3]
|
342 |
+
:param v0: first vertex of triangles. [B, 3]
|
343 |
+
:returns: barycentric coordinates of ``p``'s projection in triangle defined by ``q``, ``u``, ``v``
|
344 |
+
vectorized so ``p``, ``q``, ``u``, ``v`` can all be ``3xN``
|
345 |
+
"""
|
346 |
+
#(p, q, u, v)
|
347 |
+
v0, v1, v2 = vertices[:, 0], vertices[:, 1], vertices[:, 2]
|
348 |
+
p = points
|
349 |
+
|
350 |
+
q = v0
|
351 |
+
u = v1 - v0
|
352 |
+
v = v2 - v0
|
353 |
+
n = torch.cross(u, v)
|
354 |
+
s = torch.sum(n * n, dim=1)
|
355 |
+
# If the triangle edges are collinear, cross-product is zero,
|
356 |
+
# which makes "s" 0, which gives us divide by zero. So we
|
357 |
+
# make the arbitrary choice to set s to a small epsilon (1e-6 here),
|
358 |
+
# so the division below stays finite
|
359 |
+
s[s == 0] = 1e-6
|
360 |
+
oneOver4ASquared = 1.0 / s
|
361 |
+
w = p - q
|
362 |
+
b2 = torch.sum(torch.cross(u, w) * n, dim=1) * oneOver4ASquared
|
363 |
+
b1 = torch.sum(torch.cross(w, v) * n, dim=1) * oneOver4ASquared
|
364 |
+
weights = torch.stack((1 - b1 - b2, b1, b2), dim=-1)
|
365 |
+
# check barycentric weights
|
366 |
+
# p_n = v0*weights[:,0:1] + v1*weights[:,1:2] + v2*weights[:,2:3]
|
367 |
+
return weights
|
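# A quick sanity check for the barycentric weights (illustrative values only):
# re-combining the triangle vertices with the weights recovers the projection of p.
import torch
tri = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]]])  # [1, 3, 3]
p = torch.tensor([[0.25, 0.25, 1.0]])                              # point off the plane
w = barycentric_coordinates_of_projection(p, tri)                   # [[0.5, 0.25, 0.25]]
p_proj = (w.unsqueeze(-1) * tri).sum(dim=1)                         # ~[0.25, 0.25, 0.0]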
368 |
+
|
369 |
+
|
370 |
+
def cal_sdf_batch(verts, faces, cmaps, vis, points):
|
371 |
+
|
372 |
+
# verts [B, N_vert, 3]
|
373 |
+
# faces [B, N_face, 3]
|
374 |
+
# triangles [B, N_face, 3, 3]
|
375 |
+
# points [B, N_point, 3]
|
376 |
+
# cmaps [B, N_vert, 3]
|
377 |
+
|
378 |
+
Bsize = points.shape[0]
|
379 |
+
|
380 |
+
normals = Meshes(verts, faces).verts_normals_padded()
|
381 |
+
|
382 |
+
triangles = face_vertices(verts, faces)
|
383 |
+
normals = face_vertices(normals, faces)
|
384 |
+
cmaps = face_vertices(cmaps, faces)
|
385 |
+
vis = face_vertices(vis, faces)
|
386 |
+
|
387 |
+
residues, pts_ind, _ = point_to_mesh_distance(points, triangles)
|
388 |
+
closest_triangles = torch.gather(
|
389 |
+
triangles, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
|
390 |
+
closest_normals = torch.gather(
|
391 |
+
normals, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
|
392 |
+
closest_cmaps = torch.gather(
|
393 |
+
cmaps, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
|
394 |
+
closest_vis = torch.gather(
|
395 |
+
vis, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 1)).view(-1, 3, 1)
|
396 |
+
bary_weights = barycentric_coordinates_of_projection(
|
397 |
+
points.view(-1, 3), closest_triangles)
|
398 |
+
|
399 |
+
pts_cmap = (closest_cmaps*bary_weights[:, :, None]).sum(1).unsqueeze(0)
|
400 |
+
pts_vis = (closest_vis*bary_weights[:,
|
401 |
+
:, None]).sum(1).unsqueeze(0).ge(1e-1)
|
402 |
+
pts_norm = (closest_normals*bary_weights[:, :, None]).sum(
|
403 |
+
1).unsqueeze(0) * torch.tensor([-1.0, 1.0, -1.0]).type_as(normals)
|
404 |
+
pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3))
|
405 |
+
|
406 |
+
pts_signs = 2.0 * (check_sign(verts, faces[0], points).float() - 0.5)
|
407 |
+
pts_sdf = (pts_dist * pts_signs).unsqueeze(-1)
|
408 |
+
|
409 |
+
return pts_sdf.view(Bsize, -1, 1), pts_norm.view(Bsize, -1, 3), pts_cmap.view(Bsize, -1, 3), pts_vis.view(Bsize, -1, 1)
|
410 |
+
|
411 |
+
|
412 |
+
def orthogonal(points, calibrations, transforms=None):
|
413 |
+
'''
|
414 |
+
Compute the orthogonal projections of 3D points into the image plane by given projection matrix
|
415 |
+
:param points: [B, 3, N] Tensor of 3D points
|
416 |
+
:param calibrations: [B, 3, 4] Tensor of projection matrix
|
417 |
+
:param transforms: [B, 2, 3] Tensor of image transform matrix
|
418 |
+
:return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
|
419 |
+
'''
|
420 |
+
rot = calibrations[:, :3, :3]
|
421 |
+
trans = calibrations[:, :3, 3:4]
|
422 |
+
pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
|
423 |
+
if transforms is not None:
|
424 |
+
scale = transforms[:2, :2]
|
425 |
+
shift = transforms[:2, 2:3]
|
426 |
+
pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
|
427 |
+
return pts
|
428 |
+
|
429 |
+
|
430 |
+
def projection(points, calib, format='numpy'):
|
431 |
+
if format == 'tensor':
|
432 |
+
return torch.mm(calib[:3, :3], points.T).T + calib[:3, 3]
|
433 |
+
else:
|
434 |
+
return np.matmul(calib[:3, :3], points.T).T + calib[:3, 3]
|
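# Hedged usage sketch for projection() in numpy mode; the identity calibration
# below is only for illustration.
import numpy as np
calib = np.eye(4)
pts = np.random.rand(100, 3)            # [N, 3] points
pts_img = projection(pts, calib)        # identical to pts for an identity calib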
435 |
+
|
436 |
+
|
437 |
+
def load_calib(calib_path):
|
438 |
+
calib_data = np.loadtxt(calib_path, dtype=float)
|
439 |
+
extrinsic = calib_data[:4, :4]
|
440 |
+
intrinsic = calib_data[4:8, :4]
|
441 |
+
calib_mat = np.matmul(intrinsic, extrinsic)
|
442 |
+
calib_mat = torch.from_numpy(calib_mat).float()
|
443 |
+
return calib_mat
|
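# Hedged example of the text layout load_calib() expects: rows 0-3 hold the 4x4
# extrinsic matrix and rows 4-7 the 4x4 intrinsic matrix (the file path is made up).
import numpy as np
np.savetxt("/tmp/calib_example.txt", np.concatenate([np.eye(4), np.eye(4)], axis=0))
calib = load_calib("/tmp/calib_example.txt")   # 4x4 identity, as a float tensor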
444 |
+
|
445 |
+
|
446 |
+
def load_obj_mesh_for_Hoppe(mesh_file):
|
447 |
+
vertex_data = []
|
448 |
+
face_data = []
|
449 |
+
|
450 |
+
if isinstance(mesh_file, str):
|
451 |
+
f = open(mesh_file, "r")
|
452 |
+
else:
|
453 |
+
f = mesh_file
|
454 |
+
for line in f:
|
455 |
+
if isinstance(line, bytes):
|
456 |
+
line = line.decode("utf-8")
|
457 |
+
if line.startswith('#'):
|
458 |
+
continue
|
459 |
+
values = line.split()
|
460 |
+
if not values:
|
461 |
+
continue
|
462 |
+
|
463 |
+
if values[0] == 'v':
|
464 |
+
v = list(map(float, values[1:4]))
|
465 |
+
vertex_data.append(v)
|
466 |
+
|
467 |
+
elif values[0] == 'f':
|
468 |
+
# quad mesh
|
469 |
+
if len(values) > 4:
|
470 |
+
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
|
471 |
+
face_data.append(f)
|
472 |
+
f = list(
|
473 |
+
map(lambda x: int(x.split('/')[0]),
|
474 |
+
[values[3], values[4], values[1]]))
|
475 |
+
face_data.append(f)
|
476 |
+
# tri mesh
|
477 |
+
else:
|
478 |
+
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
|
479 |
+
face_data.append(f)
|
480 |
+
|
481 |
+
vertices = np.array(vertex_data)
|
482 |
+
faces = np.array(face_data)
|
483 |
+
faces[faces > 0] -= 1
|
484 |
+
|
485 |
+
normals, _ = compute_normal(vertices, faces)
|
486 |
+
|
487 |
+
return vertices, normals, faces
|
488 |
+
|
489 |
+
|
490 |
+
def load_obj_mesh_with_color(mesh_file):
|
491 |
+
vertex_data = []
|
492 |
+
color_data = []
|
493 |
+
face_data = []
|
494 |
+
|
495 |
+
if isinstance(mesh_file, str):
|
496 |
+
f = open(mesh_file, "r")
|
497 |
+
else:
|
498 |
+
f = mesh_file
|
499 |
+
for line in f:
|
500 |
+
if isinstance(line, bytes):
|
501 |
+
line = line.decode("utf-8")
|
502 |
+
if line.startswith('#'):
|
503 |
+
continue
|
504 |
+
values = line.split()
|
505 |
+
if not values:
|
506 |
+
continue
|
507 |
+
|
508 |
+
if values[0] == 'v':
|
509 |
+
v = list(map(float, values[1:4]))
|
510 |
+
vertex_data.append(v)
|
511 |
+
c = list(map(float, values[4:7]))
|
512 |
+
color_data.append(c)
|
513 |
+
|
514 |
+
elif values[0] == 'f':
|
515 |
+
# quad mesh
|
516 |
+
if len(values) > 4:
|
517 |
+
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
|
518 |
+
face_data.append(f)
|
519 |
+
f = list(
|
520 |
+
map(lambda x: int(x.split('/')[0]),
|
521 |
+
[values[3], values[4], values[1]]))
|
522 |
+
face_data.append(f)
|
523 |
+
# tri mesh
|
524 |
+
else:
|
525 |
+
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
|
526 |
+
face_data.append(f)
|
527 |
+
|
528 |
+
vertices = np.array(vertex_data)
|
529 |
+
colors = np.array(color_data)
|
530 |
+
faces = np.array(face_data)
|
531 |
+
faces[faces > 0] -= 1
|
532 |
+
|
533 |
+
return vertices, colors, faces
|
534 |
+
|
535 |
+
|
536 |
+
def load_obj_mesh(mesh_file, with_normal=False, with_texture=False):
|
537 |
+
vertex_data = []
|
538 |
+
norm_data = []
|
539 |
+
uv_data = []
|
540 |
+
|
541 |
+
face_data = []
|
542 |
+
face_norm_data = []
|
543 |
+
face_uv_data = []
|
544 |
+
|
545 |
+
if isinstance(mesh_file, str):
|
546 |
+
f = open(mesh_file, "r")
|
547 |
+
else:
|
548 |
+
f = mesh_file
|
549 |
+
for line in f:
|
550 |
+
if isinstance(line, bytes):
|
551 |
+
line = line.decode("utf-8")
|
552 |
+
if line.startswith('#'):
|
553 |
+
continue
|
554 |
+
values = line.split()
|
555 |
+
if not values:
|
556 |
+
continue
|
557 |
+
|
558 |
+
if values[0] == 'v':
|
559 |
+
v = list(map(float, values[1:4]))
|
560 |
+
vertex_data.append(v)
|
561 |
+
elif values[0] == 'vn':
|
562 |
+
vn = list(map(float, values[1:4]))
|
563 |
+
norm_data.append(vn)
|
564 |
+
elif values[0] == 'vt':
|
565 |
+
vt = list(map(float, values[1:3]))
|
566 |
+
uv_data.append(vt)
|
567 |
+
|
568 |
+
elif values[0] == 'f':
|
569 |
+
# quad mesh
|
570 |
+
if len(values) > 4:
|
571 |
+
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
|
572 |
+
face_data.append(f)
|
573 |
+
f = list(
|
574 |
+
map(lambda x: int(x.split('/')[0]),
|
575 |
+
[values[3], values[4], values[1]]))
|
576 |
+
face_data.append(f)
|
577 |
+
# tri mesh
|
578 |
+
else:
|
579 |
+
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
|
580 |
+
face_data.append(f)
|
581 |
+
|
582 |
+
# deal with texture
|
583 |
+
if len(values[1].split('/')) >= 2:
|
584 |
+
# quad mesh
|
585 |
+
if len(values) > 4:
|
586 |
+
f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
|
587 |
+
face_uv_data.append(f)
|
588 |
+
f = list(
|
589 |
+
map(lambda x: int(x.split('/')[1]),
|
590 |
+
[values[3], values[4], values[1]]))
|
591 |
+
face_uv_data.append(f)
|
592 |
+
# tri mesh
|
593 |
+
elif len(values[1].split('/')[1]) != 0:
|
594 |
+
f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
|
595 |
+
face_uv_data.append(f)
|
596 |
+
# deal with normal
|
597 |
+
if len(values[1].split('/')) == 3:
|
598 |
+
# quad mesh
|
599 |
+
if len(values) > 4:
|
600 |
+
f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
|
601 |
+
face_norm_data.append(f)
|
602 |
+
f = list(
|
603 |
+
map(lambda x: int(x.split('/')[2]),
|
604 |
+
[values[3], values[4], values[1]]))
|
605 |
+
face_norm_data.append(f)
|
606 |
+
# tri mesh
|
607 |
+
elif len(values[1].split('/')[2]) != 0:
|
608 |
+
f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
|
609 |
+
face_norm_data.append(f)
|
610 |
+
|
611 |
+
vertices = np.array(vertex_data)
|
612 |
+
faces = np.array(face_data)
|
613 |
+
faces[faces > 0] -= 1
|
614 |
+
|
615 |
+
if with_texture and with_normal:
|
616 |
+
uvs = np.array(uv_data)
|
617 |
+
face_uvs = np.array(face_uv_data)
|
618 |
+
face_uvs[face_uvs > 0] -= 1
|
619 |
+
norms = np.array(norm_data)
|
620 |
+
if norms.shape[0] == 0:
|
621 |
+
norms, _ = compute_normal(vertices, faces)
|
622 |
+
face_normals = faces
|
623 |
+
else:
|
624 |
+
norms = normalize_v3(norms)
|
625 |
+
face_normals = np.array(face_norm_data)
|
626 |
+
face_normals[face_normals > 0] -= 1
|
627 |
+
return vertices, faces, norms, face_normals, uvs, face_uvs
|
628 |
+
|
629 |
+
if with_texture:
|
630 |
+
uvs = np.array(uv_data)
|
631 |
+
face_uvs = np.array(face_uv_data) - 1
|
632 |
+
return vertices, faces, uvs, face_uvs
|
633 |
+
|
634 |
+
if with_normal:
|
635 |
+
norms = np.array(norm_data)
|
636 |
+
norms = normalize_v3(norms)
|
637 |
+
face_normals = np.array(face_norm_data) - 1
|
638 |
+
return vertices, faces, norms, face_normals
|
639 |
+
|
640 |
+
return vertices, faces
|
641 |
+
|
642 |
+
|
643 |
+
def normalize_v3(arr):
|
644 |
+
''' Normalize a numpy array of 3 component vectors shape=(n,3) '''
|
645 |
+
lens = np.sqrt(arr[:, 0]**2 + arr[:, 1]**2 + arr[:, 2]**2)
|
646 |
+
eps = 0.00000001
|
647 |
+
lens[lens < eps] = eps
|
648 |
+
arr[:, 0] /= lens
|
649 |
+
arr[:, 1] /= lens
|
650 |
+
arr[:, 2] /= lens
|
651 |
+
return arr
|
652 |
+
|
653 |
+
|
654 |
+
def compute_normal(vertices, faces):
|
655 |
+
# Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal
|
656 |
+
vert_norms = np.zeros(vertices.shape, dtype=vertices.dtype)
|
657 |
+
# Create an indexed view into the vertex array using the array of three indices for triangles
|
658 |
+
tris = vertices[faces]
|
659 |
+
# Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle
|
660 |
+
face_norms = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0])
|
661 |
+
# n is now an array of normals per triangle. The length of each normal is dependent on the vertices,
|
662 |
+
# we need to normalize these, so that our next step weights each normal equally.
|
663 |
+
normalize_v3(face_norms)
|
664 |
+
# now we have a normalized array of normals, one per triangle, i.e., per triangle normals.
|
665 |
+
# But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle,
|
666 |
+
# the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards.
|
667 |
+
# The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array
|
668 |
+
vert_norms[faces[:, 0]] += face_norms
|
669 |
+
vert_norms[faces[:, 1]] += face_norms
|
670 |
+
vert_norms[faces[:, 2]] += face_norms
|
671 |
+
normalize_v3(vert_norms)
|
672 |
+
|
673 |
+
return vert_norms, face_norms
|
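# Minimal check for compute_normal on a single right triangle (illustrative only):
# both the vertex and the face normals come out as the +z axis.
import numpy as np
verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = np.array([[0, 1, 2]])
vert_norms, face_norms = compute_normal(verts, faces)   # all equal to [0, 0, 1]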
674 |
+
|
675 |
+
|
676 |
+
def save_obj_mesh(mesh_path, verts, faces):
|
677 |
+
file = open(mesh_path, 'w')
|
678 |
+
for v in verts:
|
679 |
+
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
|
680 |
+
for f in faces:
|
681 |
+
f_plus = f + 1
|
682 |
+
file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
|
683 |
+
file.close()
|
684 |
+
|
685 |
+
|
686 |
+
def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
|
687 |
+
file = open(mesh_path, 'w')
|
688 |
+
|
689 |
+
for idx, v in enumerate(verts):
|
690 |
+
c = colors[idx]
|
691 |
+
file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
|
692 |
+
(v[0], v[1], v[2], c[0], c[1], c[2]))
|
693 |
+
for f in faces:
|
694 |
+
f_plus = f + 1
|
695 |
+
file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
|
696 |
+
file.close()
|
697 |
+
|
698 |
+
|
699 |
+
def calculate_mIoU(outputs, labels):
|
700 |
+
|
701 |
+
SMOOTH = 1e-6
|
702 |
+
|
703 |
+
outputs = outputs.int()
|
704 |
+
labels = labels.int()
|
705 |
+
|
706 |
+
intersection = (
|
707 |
+
outputs
|
708 |
+
& labels).float().sum() # Will be zero if Truth=0 or Prediction=0
|
709 |
+
union = (outputs | labels).float().sum() # Will be zero if both are 0
|
710 |
+
|
711 |
+
iou = (intersection + SMOOTH) / (union + SMOOTH
|
712 |
+
) # We smooth our division to avoid 0/0
|
713 |
+
|
714 |
+
thresholded = torch.clamp(
|
715 |
+
20 * (iou - 0.5), 0,
|
716 |
+
10).ceil() / 10 # This is equal to comparing with thresholds
|
717 |
+
|
718 |
+
return thresholded.mean().detach().cpu().numpy(
|
719 |
+
) # Or thresholded.mean() if you are interested in average across the batch
|
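# Hedged usage sketch for calculate_mIoU with two binary occupancy tensors:
import torch
pred = torch.tensor([1, 1, 0, 0])
gt = torch.tensor([1, 0, 1, 0])
score = calculate_mIoU(pred, gt)   # IoU = 1/3, clipped to 0.0 by the 0.5 threshold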
720 |
+
|
721 |
+
|
722 |
+
def mask_filter(mask, number=1000):
|
723 |
+
"""only keep {number} True items within a mask
|
724 |
+
|
725 |
+
Args:
|
726 |
+
mask (bool array): [N, ]
|
727 |
+
number (int, optional): total True item. Defaults to 1000.
|
728 |
+
"""
|
729 |
+
true_ids = np.where(mask)[0]
|
730 |
+
keep_ids = np.random.choice(true_ids, size=number)
|
731 |
+
filter_mask = np.isin(np.arange(len(mask)), keep_ids)
|
732 |
+
|
733 |
+
return filter_mask
|
734 |
+
|
735 |
+
|
736 |
+
def query_mesh(path):
|
737 |
+
|
738 |
+
verts, faces_idx, _ = load_obj(path)
|
739 |
+
|
740 |
+
return verts, faces_idx.verts_idx
|
741 |
+
|
742 |
+
|
743 |
+
def add_alpha(colors, alpha=0.7):
|
744 |
+
|
745 |
+
colors_pad = np.pad(colors, ((0, 0), (0, 1)),
|
746 |
+
mode='constant',
|
747 |
+
constant_values=alpha)
|
748 |
+
|
749 |
+
return colors_pad
|
750 |
+
|
751 |
+
|
752 |
+
def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type='smpl'):
|
753 |
+
|
754 |
+
font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
|
755 |
+
font = ImageFont.truetype(font_path, 30)
|
756 |
+
grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0),
|
757 |
+
nrow=nrow)
|
758 |
+
grid_img = Image.fromarray(
|
759 |
+
((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 *
|
760 |
+
255.0).astype(np.uint8))
|
761 |
+
|
762 |
+
# add text
|
763 |
+
draw = ImageDraw.Draw(grid_img)
|
764 |
+
grid_size = 512
|
765 |
+
if loss is not None:
|
766 |
+
draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
|
767 |
+
|
768 |
+
if type == 'smpl':
|
769 |
+
for col_id, col_txt in enumerate(
|
770 |
+
['image', 'smpl-norm(render)', 'cloth-norm(pred)', 'diff-norm', 'diff-mask']):
|
771 |
+
draw.text((10+(col_id*grid_size), 5),
|
772 |
+
col_txt, (255, 0, 0), font=font)
|
773 |
+
elif type == 'cloth':
|
774 |
+
for col_id, col_txt in enumerate(
|
775 |
+
['image', 'cloth-norm(recon)', 'cloth-norm(pred)', 'diff-norm']):
|
776 |
+
draw.text((10+(col_id*grid_size), 5),
|
777 |
+
col_txt, (255, 0, 0), font=font)
|
778 |
+
for col_id, col_txt in enumerate(
|
779 |
+
['0', '90', '180', '270']):
|
780 |
+
draw.text((10+(col_id*grid_size), grid_size*2+5),
|
781 |
+
col_txt, (255, 0, 0), font=font)
|
782 |
+
else:
|
783 |
+
print(f"{type} should be 'smpl' or 'cloth'")
|
784 |
+
|
785 |
+
grid_img = grid_img.resize((grid_img.size[0], grid_img.size[1]),
|
786 |
+
Image.ANTIALIAS)
|
787 |
+
|
788 |
+
return grid_img
|
789 |
+
|
790 |
+
|
791 |
+
def clean_mesh(verts, faces):
|
792 |
+
|
793 |
+
device = verts.device
|
794 |
+
|
795 |
+
mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(),
|
796 |
+
faces.detach().cpu().numpy())
|
797 |
+
mesh_lst = mesh_lst.split(only_watertight=False)
|
798 |
+
comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst]
|
799 |
+
mesh_clean = mesh_lst[comp_num.index(max(comp_num))]
|
800 |
+
|
801 |
+
final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device)
|
802 |
+
final_faces = torch.as_tensor(mesh_clean.faces).int().to(device)
|
803 |
+
|
804 |
+
return final_verts, final_faces
|
805 |
+
|
806 |
+
|
807 |
+
def merge_mesh(verts_A, faces_A, verts_B, faces_B, color=False):
|
808 |
+
|
809 |
+
sep_mesh = trimesh.Trimesh(np.concatenate([verts_A, verts_B], axis=0),
|
810 |
+
np.concatenate(
|
811 |
+
[faces_A, faces_B + faces_A.max() + 1],
|
812 |
+
axis=0),
|
813 |
+
maintain_order=True,
|
814 |
+
process=False)
|
815 |
+
if color:
|
816 |
+
colors = np.ones_like(sep_mesh.vertices)
|
817 |
+
colors[:verts_A.shape[0]] *= np.array([255.0, 0.0, 0.0])
|
818 |
+
colors[verts_A.shape[0]:] *= np.array([0.0, 255.0, 0.0])
|
819 |
+
sep_mesh.visual.vertex_colors = colors
|
820 |
+
|
821 |
+
# union_mesh = trimesh.boolean.union([trimesh.Trimesh(verts_A, faces_A),
|
822 |
+
# trimesh.Trimesh(verts_B, faces_B)], engine='blender')
|
823 |
+
|
824 |
+
return sep_mesh
|
825 |
+
|
826 |
+
|
827 |
+
def mesh_move(mesh_lst, step, scale=1.0):
|
828 |
+
|
829 |
+
trans = np.array([1.0, 0.0, 0.0]) * step
|
830 |
+
|
831 |
+
resize_matrix = trimesh.transformations.scale_and_translate(
|
832 |
+
scale=(scale), translate=trans)
|
833 |
+
|
834 |
+
results = []
|
835 |
+
|
836 |
+
for mesh in mesh_lst:
|
837 |
+
mesh.apply_transform(resize_matrix)
|
838 |
+
results.append(mesh)
|
839 |
+
|
840 |
+
return results
|
841 |
+
|
842 |
+
|
843 |
+
class SMPLX():
|
844 |
+
def __init__(self):
|
845 |
+
|
846 |
+
REPO_ID = "Yuliang/SMPL"
|
847 |
+
|
848 |
+
self.smpl_verts_path = hf_hub_download(REPO_ID, filename='smpl_data/smpl_verts.npy', use_auth_token=os.environ['ICON'])
|
849 |
+
self.smplx_verts_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_verts.npy', use_auth_token=os.environ['ICON'])
|
850 |
+
self.faces_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_faces.npy', use_auth_token=os.environ['ICON'])
|
851 |
+
self.cmap_vert_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_cmap.npy', use_auth_token=os.environ['ICON'])
|
852 |
+
|
853 |
+
self.faces = np.load(self.faces_path)
|
854 |
+
self.verts = np.load(self.smplx_verts_path)
|
855 |
+
self.smpl_verts = np.load(self.smpl_verts_path)
|
856 |
+
|
857 |
+
self.model_dir = hf_hub_url(REPO_ID, filename='models')
|
858 |
+
self.tedra_dir = hf_hub_url(REPO_ID, filename='tedra_data')
|
859 |
+
|
860 |
+
def get_smpl_mat(self, vert_ids):
|
861 |
+
|
862 |
+
mat = torch.as_tensor(np.load(self.cmap_vert_path)).float()
|
863 |
+
return mat[vert_ids, :]
|
864 |
+
|
865 |
+
def smpl2smplx(self, vert_ids=None):
|
866 |
+
"""convert vert_ids in smpl to vert_ids in smplx
|
867 |
+
|
868 |
+
Args:
|
869 |
+
vert_ids ([int.array]): [n, knn_num]
|
870 |
+
"""
|
871 |
+
smplx_tree = cKDTree(self.verts, leafsize=1)
|
872 |
+
_, ind = smplx_tree.query(self.smpl_verts, k=1) # ind: [smpl_num, 1]
|
873 |
+
|
874 |
+
if vert_ids is not None:
|
875 |
+
smplx_vert_ids = ind[vert_ids]
|
876 |
+
else:
|
877 |
+
smplx_vert_ids = ind
|
878 |
+
|
879 |
+
return smplx_vert_ids
|
880 |
+
|
881 |
+
def smplx2smpl(self, vert_ids=None):
|
882 |
+
"""convert vert_ids in smplx to vert_ids in smpl
|
883 |
+
|
884 |
+
Args:
|
885 |
+
vert_ids ([int.array]): [n, knn_num]
|
886 |
+
"""
|
887 |
+
smpl_tree = cKDTree(self.smpl_verts, leafsize=1)
|
888 |
+
_, ind = smpl_tree.query(self.verts, k=1) # ind: [smplx_num, 1]
|
889 |
+
if vert_ids is not None:
|
890 |
+
smpl_vert_ids = ind[vert_ids]
|
891 |
+
else:
|
892 |
+
smpl_vert_ids = ind
|
893 |
+
|
894 |
+
return smpl_vert_ids
|
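# The smpl2smplx/smplx2smpl mappings above are plain nearest-neighbour lookups.
# A generic sketch of the same idea with made-up point clouds (not the SMPL data):
import numpy as np
from scipy.spatial import cKDTree
src = np.random.rand(100, 3)                 # e.g. SMPL vertices
dst = np.random.rand(120, 3)                 # e.g. SMPL-X vertices
_, nearest = cKDTree(dst).query(src, k=1)    # index of the closest dst vertex per src vertex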
lib/dataset/tbfo.ttf
ADDED
Binary file (571 kB).
|
|
lib/net/BasePIFuNet.py
ADDED
@@ -0,0 +1,84 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import torch.nn as nn
|
19 |
+
import pytorch_lightning as pl
|
20 |
+
|
21 |
+
from .geometry import index, orthogonal, perspective
|
22 |
+
|
23 |
+
|
24 |
+
class BasePIFuNet(pl.LightningModule):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
projection_mode='orthogonal',
|
28 |
+
error_term=nn.MSELoss(),
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
:param projection_mode:
|
32 |
+
Either orthogonal or perspective.
|
33 |
+
It will call the corresponding function for projection.
|
34 |
+
:param error_term:
|
35 |
+
nn Loss between the predicted [B, Res, N] and the label [B, Res, N]
|
36 |
+
"""
|
37 |
+
super(BasePIFuNet, self).__init__()
|
38 |
+
self.name = 'base'
|
39 |
+
|
40 |
+
self.error_term = error_term
|
41 |
+
|
42 |
+
self.index = index
|
43 |
+
self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
|
44 |
+
|
45 |
+
def forward(self, points, images, calibs, transforms=None):
|
46 |
+
'''
|
47 |
+
:param points: [B, 3, N] world space coordinates of points
|
48 |
+
:param images: [B, C, H, W] input images
|
49 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
50 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
51 |
+
:return: [B, Res, N] predictions for each point
|
52 |
+
'''
|
53 |
+
features = self.filter(images)
|
54 |
+
preds = self.query(features, points, calibs, transforms)
|
55 |
+
return preds
|
56 |
+
|
57 |
+
def filter(self, images):
|
58 |
+
'''
|
59 |
+
Filter the input images
|
60 |
+
store all intermediate features.
|
61 |
+
:param images: [B, C, H, W] input images
|
62 |
+
'''
|
63 |
+
return None
|
64 |
+
|
65 |
+
def query(self, features, points, calibs, transforms=None):
|
66 |
+
'''
|
67 |
+
Given 3D points, query the network predictions for each point.
|
68 |
+
Image features should be pre-computed before this call.
|
69 |
+
store all intermediate features.
|
70 |
+
query() function may behave differently during training/testing.
|
71 |
+
:param points: [B, 3, N] world space coordinates of points
|
72 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
73 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
74 |
+
:param labels: Optional [B, Res, N] gt labeling
|
75 |
+
:return: [B, Res, N] predictions for each point
|
76 |
+
'''
|
77 |
+
return None
|
78 |
+
|
79 |
+
def get_error(self, preds, labels):
|
80 |
+
'''
|
81 |
+
Get the network loss from the last query
|
82 |
+
:return: loss term
|
83 |
+
'''
|
84 |
+
return self.error_term(preds, labels)
|
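# A hedged sketch of how a concrete PIFu-style network specialises this base class;
# the layer sizes and names below are placeholders, not the real ICON modules.
import torch
import torch.nn as nn

class ToyPIFuNet(BasePIFuNet):
    def __init__(self):
        super().__init__(projection_mode='orthogonal')
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # stand-in image filter
        self.mlp = nn.Conv1d(8 + 1, 1, kernel_size=1)                # stand-in point decoder

    def filter(self, images):
        return self.backbone(images)                         # [B, 8, H, W] feature map

    def query(self, features, points, calibs, transforms=None):
        xyz = self.projection(points, calibs, transforms)    # [B, 3, N] image-space coords
        xy, z = xyz[:, :2, :], xyz[:, 2:3, :]
        feat = self.index(features, xy)                       # [B, 8, N] sampled features
        return self.mlp(torch.cat([feat, z], dim=1))          # [B, 1, N] per-point prediction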
lib/net/FBNet.py
ADDED
@@ -0,0 +1,387 @@
1 |
+
'''
|
2 |
+
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
|
3 |
+
BSD License. All rights reserved.
|
4 |
+
|
5 |
+
Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
* Redistributions of source code must retain the above copyright notice, this
|
9 |
+
list of conditions and the following disclaimer.
|
10 |
+
|
11 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
this list of conditions and the following disclaimer in the documentation
|
13 |
+
and/or other materials provided with the distribution.
|
14 |
+
|
15 |
+
THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
|
17 |
+
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
|
18 |
+
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
19 |
+
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
|
20 |
+
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
21 |
+
'''
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import functools
|
25 |
+
import numpy as np
|
26 |
+
import pytorch_lightning as pl
|
27 |
+
|
28 |
+
|
29 |
+
###############################################################################
|
30 |
+
# Functions
|
31 |
+
###############################################################################
|
32 |
+
def weights_init(m):
|
33 |
+
classname = m.__class__.__name__
|
34 |
+
if classname.find('Conv') != -1:
|
35 |
+
m.weight.data.normal_(0.0, 0.02)
|
36 |
+
elif classname.find('BatchNorm2d') != -1:
|
37 |
+
m.weight.data.normal_(1.0, 0.02)
|
38 |
+
m.bias.data.fill_(0)
|
39 |
+
|
40 |
+
|
41 |
+
def get_norm_layer(norm_type='instance'):
|
42 |
+
if norm_type == 'batch':
|
43 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
44 |
+
elif norm_type == 'instance':
|
45 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
|
46 |
+
else:
|
47 |
+
raise NotImplementedError('normalization layer [%s] is not found' %
|
48 |
+
norm_type)
|
49 |
+
return norm_layer
|
50 |
+
|
51 |
+
|
52 |
+
def define_G(input_nc,
|
53 |
+
output_nc,
|
54 |
+
ngf,
|
55 |
+
netG,
|
56 |
+
n_downsample_global=3,
|
57 |
+
n_blocks_global=9,
|
58 |
+
n_local_enhancers=1,
|
59 |
+
n_blocks_local=3,
|
60 |
+
norm='instance',
|
61 |
+
gpu_ids=[],
|
62 |
+
last_op=nn.Tanh()):
|
63 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
64 |
+
if netG == 'global':
|
65 |
+
netG = GlobalGenerator(input_nc,
|
66 |
+
output_nc,
|
67 |
+
ngf,
|
68 |
+
n_downsample_global,
|
69 |
+
n_blocks_global,
|
70 |
+
norm_layer,
|
71 |
+
last_op=last_op)
|
72 |
+
elif netG == 'local':
|
73 |
+
netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
|
74 |
+
n_blocks_global, n_local_enhancers,
|
75 |
+
n_blocks_local, norm_layer)
|
76 |
+
elif netG == 'encoder':
|
77 |
+
netG = Encoder(input_nc, output_nc, ngf, n_downsample_global,
|
78 |
+
norm_layer)
|
79 |
+
else:
|
80 |
+
raise NotImplementedError('generator not implemented!')
|
81 |
+
# print(netG)
|
82 |
+
if len(gpu_ids) > 0:
|
83 |
+
assert (torch.cuda.is_available())
|
84 |
+
netG.cuda(gpu_ids[0])
|
85 |
+
netG.apply(weights_init)
|
86 |
+
return netG
|
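# Hedged usage sketch for define_G; the channel counts here are placeholders and
# not necessarily the settings ICON uses for its normal networks.
import torch
netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='global')
fake = netG(torch.randn(1, 3, 256, 256))   # [1, 3, 256, 256], squashed to [-1, 1] by Tanh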
87 |
+
|
88 |
+
|
89 |
+
def print_network(net):
|
90 |
+
if isinstance(net, list):
|
91 |
+
net = net[0]
|
92 |
+
num_params = 0
|
93 |
+
for param in net.parameters():
|
94 |
+
num_params += param.numel()
|
95 |
+
print(net)
|
96 |
+
print('Total number of parameters: %d' % num_params)
|
97 |
+
|
98 |
+
|
99 |
+
##############################################################################
|
100 |
+
# Generator
|
101 |
+
##############################################################################
|
102 |
+
class LocalEnhancer(pl.LightningModule):
|
103 |
+
def __init__(self,
|
104 |
+
input_nc,
|
105 |
+
output_nc,
|
106 |
+
ngf=32,
|
107 |
+
n_downsample_global=3,
|
108 |
+
n_blocks_global=9,
|
109 |
+
n_local_enhancers=1,
|
110 |
+
n_blocks_local=3,
|
111 |
+
norm_layer=nn.BatchNorm2d,
|
112 |
+
padding_type='reflect'):
|
113 |
+
super(LocalEnhancer, self).__init__()
|
114 |
+
self.n_local_enhancers = n_local_enhancers
|
115 |
+
|
116 |
+
###### global generator model #####
|
117 |
+
ngf_global = ngf * (2**n_local_enhancers)
|
118 |
+
model_global = GlobalGenerator(input_nc, output_nc, ngf_global,
|
119 |
+
n_downsample_global, n_blocks_global,
|
120 |
+
norm_layer).model
|
121 |
+
model_global = [model_global[i] for i in range(len(model_global) - 3)
|
122 |
+
] # get rid of final convolution layers
|
123 |
+
self.model = nn.Sequential(*model_global)
|
124 |
+
|
125 |
+
###### local enhancer layers #####
|
126 |
+
for n in range(1, n_local_enhancers + 1):
|
127 |
+
# downsample
|
128 |
+
ngf_global = ngf * (2**(n_local_enhancers - n))
|
129 |
+
model_downsample = [
|
130 |
+
nn.ReflectionPad2d(3),
|
131 |
+
nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
|
132 |
+
norm_layer(ngf_global),
|
133 |
+
nn.ReLU(True),
|
134 |
+
nn.Conv2d(ngf_global,
|
135 |
+
ngf_global * 2,
|
136 |
+
kernel_size=3,
|
137 |
+
stride=2,
|
138 |
+
padding=1),
|
139 |
+
norm_layer(ngf_global * 2),
|
140 |
+
nn.ReLU(True)
|
141 |
+
]
|
142 |
+
# residual blocks
|
143 |
+
model_upsample = []
|
144 |
+
for i in range(n_blocks_local):
|
145 |
+
model_upsample += [
|
146 |
+
ResnetBlock(ngf_global * 2,
|
147 |
+
padding_type=padding_type,
|
148 |
+
norm_layer=norm_layer)
|
149 |
+
]
|
150 |
+
|
151 |
+
# upsample
|
152 |
+
model_upsample += [
|
153 |
+
nn.ConvTranspose2d(ngf_global * 2,
|
154 |
+
ngf_global,
|
155 |
+
kernel_size=3,
|
156 |
+
stride=2,
|
157 |
+
padding=1,
|
158 |
+
output_padding=1),
|
159 |
+
norm_layer(ngf_global),
|
160 |
+
nn.ReLU(True)
|
161 |
+
]
|
162 |
+
|
163 |
+
# final convolution
|
164 |
+
if n == n_local_enhancers:
|
165 |
+
model_upsample += [
|
166 |
+
nn.ReflectionPad2d(3),
|
167 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
|
168 |
+
nn.Tanh()
|
169 |
+
]
|
170 |
+
|
171 |
+
setattr(self, 'model' + str(n) + '_1',
|
172 |
+
nn.Sequential(*model_downsample))
|
173 |
+
setattr(self, 'model' + str(n) + '_2',
|
174 |
+
nn.Sequential(*model_upsample))
|
175 |
+
|
176 |
+
self.downsample = nn.AvgPool2d(3,
|
177 |
+
stride=2,
|
178 |
+
padding=[1, 1],
|
179 |
+
count_include_pad=False)
|
180 |
+
|
181 |
+
def forward(self, input):
|
182 |
+
# create input pyramid
|
183 |
+
input_downsampled = [input]
|
184 |
+
for i in range(self.n_local_enhancers):
|
185 |
+
input_downsampled.append(self.downsample(input_downsampled[-1]))
|
186 |
+
|
187 |
+
# output at coarsest level
|
188 |
+
output_prev = self.model(input_downsampled[-1])
|
189 |
+
# build up one layer at a time
|
190 |
+
for n_local_enhancers in range(1, self.n_local_enhancers + 1):
|
191 |
+
model_downsample = getattr(self,
|
192 |
+
'model' + str(n_local_enhancers) + '_1')
|
193 |
+
model_upsample = getattr(self,
|
194 |
+
'model' + str(n_local_enhancers) + '_2')
|
195 |
+
input_i = input_downsampled[self.n_local_enhancers -
|
196 |
+
n_local_enhancers]
|
197 |
+
output_prev = model_upsample(
|
198 |
+
model_downsample(input_i) + output_prev)
|
199 |
+
return output_prev
|
200 |
+
|
201 |
+
|
202 |
+
class GlobalGenerator(pl.LightningModule):
|
203 |
+
def __init__(self,
|
204 |
+
input_nc,
|
205 |
+
output_nc,
|
206 |
+
ngf=64,
|
207 |
+
n_downsampling=3,
|
208 |
+
n_blocks=9,
|
209 |
+
norm_layer=nn.BatchNorm2d,
|
210 |
+
padding_type='reflect',
|
211 |
+
last_op=nn.Tanh()):
|
212 |
+
assert (n_blocks >= 0)
|
213 |
+
super(GlobalGenerator, self).__init__()
|
214 |
+
activation = nn.ReLU(True)
|
215 |
+
|
216 |
+
model = [
|
217 |
+
nn.ReflectionPad2d(3),
|
218 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
|
219 |
+
norm_layer(ngf), activation
|
220 |
+
]
|
221 |
+
# downsample
|
222 |
+
for i in range(n_downsampling):
|
223 |
+
mult = 2**i
|
224 |
+
model += [
|
225 |
+
nn.Conv2d(ngf * mult,
|
226 |
+
ngf * mult * 2,
|
227 |
+
kernel_size=3,
|
228 |
+
stride=2,
|
229 |
+
padding=1),
|
230 |
+
norm_layer(ngf * mult * 2), activation
|
231 |
+
]
|
232 |
+
|
233 |
+
# resnet blocks
|
234 |
+
mult = 2**n_downsampling
|
235 |
+
for i in range(n_blocks):
|
236 |
+
model += [
|
237 |
+
ResnetBlock(ngf * mult,
|
238 |
+
padding_type=padding_type,
|
239 |
+
activation=activation,
|
240 |
+
norm_layer=norm_layer)
|
241 |
+
]
|
242 |
+
|
243 |
+
# upsample
|
244 |
+
for i in range(n_downsampling):
|
245 |
+
mult = 2**(n_downsampling - i)
|
246 |
+
model += [
|
247 |
+
nn.ConvTranspose2d(ngf * mult,
|
248 |
+
int(ngf * mult / 2),
|
249 |
+
kernel_size=3,
|
250 |
+
stride=2,
|
251 |
+
padding=1,
|
252 |
+
output_padding=1),
|
253 |
+
norm_layer(int(ngf * mult / 2)), activation
|
254 |
+
]
|
255 |
+
model += [
|
256 |
+
nn.ReflectionPad2d(3),
|
257 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
|
258 |
+
]
|
259 |
+
if last_op is not None:
|
260 |
+
model += [last_op]
|
261 |
+
self.model = nn.Sequential(*model)
|
262 |
+
|
263 |
+
def forward(self, input):
|
264 |
+
return self.model(input)
|
265 |
+
|
266 |
+
|
267 |
+
# Define a resnet block
|
268 |
+
class ResnetBlock(pl.LightningModule):
|
269 |
+
def __init__(self,
|
270 |
+
dim,
|
271 |
+
padding_type,
|
272 |
+
norm_layer,
|
273 |
+
activation=nn.ReLU(True),
|
274 |
+
use_dropout=False):
|
275 |
+
super(ResnetBlock, self).__init__()
|
276 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
|
277 |
+
activation, use_dropout)
|
278 |
+
|
279 |
+
def build_conv_block(self, dim, padding_type, norm_layer, activation,
|
280 |
+
use_dropout):
|
281 |
+
conv_block = []
|
282 |
+
p = 0
|
283 |
+
if padding_type == 'reflect':
|
284 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
285 |
+
elif padding_type == 'replicate':
|
286 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
287 |
+
elif padding_type == 'zero':
|
288 |
+
p = 1
|
289 |
+
else:
|
290 |
+
raise NotImplementedError('padding [%s] is not implemented' %
|
291 |
+
padding_type)
|
292 |
+
|
293 |
+
conv_block += [
|
294 |
+
nn.Conv2d(dim, dim, kernel_size=3, padding=p),
|
295 |
+
norm_layer(dim), activation
|
296 |
+
]
|
297 |
+
if use_dropout:
|
298 |
+
conv_block += [nn.Dropout(0.5)]
|
299 |
+
|
300 |
+
p = 0
|
301 |
+
if padding_type == 'reflect':
|
302 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
303 |
+
elif padding_type == 'replicate':
|
304 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
305 |
+
elif padding_type == 'zero':
|
306 |
+
p = 1
|
307 |
+
else:
|
308 |
+
raise NotImplementedError('padding [%s] is not implemented' %
|
309 |
+
padding_type)
|
310 |
+
conv_block += [
|
311 |
+
nn.Conv2d(dim, dim, kernel_size=3, padding=p),
|
312 |
+
norm_layer(dim)
|
313 |
+
]
|
314 |
+
|
315 |
+
return nn.Sequential(*conv_block)
|
316 |
+
|
317 |
+
def forward(self, x):
|
318 |
+
out = x + self.conv_block(x)
|
319 |
+
return out
|
320 |
+
|
321 |
+
|
322 |
+
class Encoder(pl.LightningModule):
|
323 |
+
def __init__(self,
|
324 |
+
input_nc,
|
325 |
+
output_nc,
|
326 |
+
ngf=32,
|
327 |
+
n_downsampling=4,
|
328 |
+
norm_layer=nn.BatchNorm2d):
|
329 |
+
super(Encoder, self).__init__()
|
330 |
+
self.output_nc = output_nc
|
331 |
+
|
332 |
+
model = [
|
333 |
+
nn.ReflectionPad2d(3),
|
334 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
|
335 |
+
norm_layer(ngf),
|
336 |
+
nn.ReLU(True)
|
337 |
+
]
|
338 |
+
# downsample
|
339 |
+
for i in range(n_downsampling):
|
340 |
+
mult = 2**i
|
341 |
+
model += [
|
342 |
+
nn.Conv2d(ngf * mult,
|
343 |
+
ngf * mult * 2,
|
344 |
+
kernel_size=3,
|
345 |
+
stride=2,
|
346 |
+
padding=1),
|
347 |
+
norm_layer(ngf * mult * 2),
|
348 |
+
nn.ReLU(True)
|
349 |
+
]
|
350 |
+
|
351 |
+
# upsample
|
352 |
+
for i in range(n_downsampling):
|
353 |
+
mult = 2**(n_downsampling - i)
|
354 |
+
model += [
|
355 |
+
nn.ConvTranspose2d(ngf * mult,
|
356 |
+
int(ngf * mult / 2),
|
357 |
+
kernel_size=3,
|
358 |
+
stride=2,
|
359 |
+
padding=1,
|
360 |
+
output_padding=1),
|
361 |
+
norm_layer(int(ngf * mult / 2)),
|
362 |
+
nn.ReLU(True)
|
363 |
+
]
|
364 |
+
|
365 |
+
model += [
|
366 |
+
nn.ReflectionPad2d(3),
|
367 |
+
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
|
368 |
+
nn.Tanh()
|
369 |
+
]
|
370 |
+
self.model = nn.Sequential(*model)
|
371 |
+
|
372 |
+
def forward(self, input, inst):
|
373 |
+
outputs = self.model(input)
|
374 |
+
|
375 |
+
# instance-wise average pooling
|
376 |
+
outputs_mean = outputs.clone()
|
377 |
+
inst_list = np.unique(inst.cpu().numpy().astype(int))
|
378 |
+
for i in inst_list:
|
379 |
+
for b in range(input.size()[0]):
|
380 |
+
indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4
|
381 |
+
for j in range(self.output_nc):
|
382 |
+
output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
|
383 |
+
indices[:, 2], indices[:, 3]]
|
384 |
+
mean_feat = torch.mean(output_ins).expand_as(output_ins)
|
385 |
+
outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
|
386 |
+
indices[:, 2], indices[:, 3]] = mean_feat
|
387 |
+
return outputs_mean
|
lib/net/HGFilters.py
ADDED
@@ -0,0 +1,197 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
from lib.net.net_util import *
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
|
23 |
+
class HourGlass(nn.Module):
|
24 |
+
def __init__(self, num_modules, depth, num_features, opt):
|
25 |
+
super(HourGlass, self).__init__()
|
26 |
+
self.num_modules = num_modules
|
27 |
+
self.depth = depth
|
28 |
+
self.features = num_features
|
29 |
+
self.opt = opt
|
30 |
+
|
31 |
+
self._generate_network(self.depth)
|
32 |
+
|
33 |
+
def _generate_network(self, level):
|
34 |
+
self.add_module('b1_' + str(level),
|
35 |
+
ConvBlock(self.features, self.features, self.opt))
|
36 |
+
|
37 |
+
self.add_module('b2_' + str(level),
|
38 |
+
ConvBlock(self.features, self.features, self.opt))
|
39 |
+
|
40 |
+
if level > 1:
|
41 |
+
self._generate_network(level - 1)
|
42 |
+
else:
|
43 |
+
self.add_module('b2_plus_' + str(level),
|
44 |
+
ConvBlock(self.features, self.features, self.opt))
|
45 |
+
|
46 |
+
self.add_module('b3_' + str(level),
|
47 |
+
ConvBlock(self.features, self.features, self.opt))
|
48 |
+
|
49 |
+
def _forward(self, level, inp):
|
50 |
+
# Upper branch
|
51 |
+
up1 = inp
|
52 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
53 |
+
|
54 |
+
# Lower branch
|
55 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
56 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
57 |
+
|
58 |
+
if level > 1:
|
59 |
+
low2 = self._forward(level - 1, low1)
|
60 |
+
else:
|
61 |
+
low2 = low1
|
62 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
63 |
+
|
64 |
+
low3 = low2
|
65 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
66 |
+
|
67 |
+
# NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample
|
68 |
+
# if the pretrained model behaves weirdly, switch with the commented line.
|
69 |
+
# NOTE: I also found that "bicubic" works better.
|
70 |
+
up2 = F.interpolate(low3,
|
71 |
+
scale_factor=2,
|
72 |
+
mode='bicubic',
|
73 |
+
align_corners=True)
|
74 |
+
# up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
75 |
+
|
76 |
+
return up1 + up2
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
return self._forward(self.depth, x)
|
80 |
+
|
81 |
+
|
82 |
+
class HGFilter(nn.Module):
|
83 |
+
def __init__(self, opt, num_modules, in_dim):
|
84 |
+
super(HGFilter, self).__init__()
|
85 |
+
self.num_modules = num_modules
|
86 |
+
|
87 |
+
self.opt = opt
|
88 |
+
[k, s, d, p] = self.opt.conv1
|
89 |
+
|
90 |
+
# self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3)
|
91 |
+
self.conv1 = nn.Conv2d(in_dim,
|
92 |
+
64,
|
93 |
+
kernel_size=k,
|
94 |
+
stride=s,
|
95 |
+
dilation=d,
|
96 |
+
padding=p)
|
97 |
+
|
98 |
+
if self.opt.norm == 'batch':
|
99 |
+
self.bn1 = nn.BatchNorm2d(64)
|
100 |
+
elif self.opt.norm == 'group':
|
101 |
+
self.bn1 = nn.GroupNorm(32, 64)
|
102 |
+
|
103 |
+
if self.opt.hg_down == 'conv64':
|
104 |
+
self.conv2 = ConvBlock(64, 64, self.opt)
|
105 |
+
self.down_conv2 = nn.Conv2d(64,
|
106 |
+
128,
|
107 |
+
kernel_size=3,
|
108 |
+
stride=2,
|
109 |
+
padding=1)
|
110 |
+
elif self.opt.hg_down == 'conv128':
|
111 |
+
self.conv2 = ConvBlock(64, 128, self.opt)
|
112 |
+
self.down_conv2 = nn.Conv2d(128,
|
113 |
+
128,
|
114 |
+
kernel_size=3,
|
115 |
+
stride=2,
|
116 |
+
padding=1)
|
117 |
+
elif self.opt.hg_down == 'ave_pool':
|
118 |
+
self.conv2 = ConvBlock(64, 128, self.opt)
|
119 |
+
else:
|
120 |
+
raise NameError('Unknown Fan Filter setting!')
|
121 |
+
|
122 |
+
self.conv3 = ConvBlock(128, 128, self.opt)
|
123 |
+
self.conv4 = ConvBlock(128, 256, self.opt)
|
124 |
+
|
125 |
+
# Stacking part
|
126 |
+
for hg_module in range(self.num_modules):
|
127 |
+
self.add_module('m' + str(hg_module),
|
128 |
+
HourGlass(1, opt.num_hourglass, 256, self.opt))
|
129 |
+
|
130 |
+
self.add_module('top_m_' + str(hg_module),
|
131 |
+
ConvBlock(256, 256, self.opt))
|
132 |
+
self.add_module(
|
133 |
+
'conv_last' + str(hg_module),
|
134 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
135 |
+
if self.opt.norm == 'batch':
|
136 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
137 |
+
elif self.opt.norm == 'group':
|
138 |
+
self.add_module('bn_end' + str(hg_module),
|
139 |
+
nn.GroupNorm(32, 256))
|
140 |
+
|
141 |
+
self.add_module(
|
142 |
+
'l' + str(hg_module),
|
143 |
+
nn.Conv2d(256,
|
144 |
+
opt.hourglass_dim,
|
145 |
+
kernel_size=1,
|
146 |
+
stride=1,
|
147 |
+
padding=0))
|
148 |
+
|
149 |
+
if hg_module < self.num_modules - 1:
|
150 |
+
self.add_module(
|
151 |
+
'bl' + str(hg_module),
|
152 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
153 |
+
self.add_module(
|
154 |
+
'al' + str(hg_module),
|
155 |
+
nn.Conv2d(opt.hourglass_dim,
|
156 |
+
256,
|
157 |
+
kernel_size=1,
|
158 |
+
stride=1,
|
159 |
+
padding=0))
|
160 |
+
|
161 |
+
def forward(self, x):
|
162 |
+
x = F.relu(self.bn1(self.conv1(x)), True)
|
163 |
+
tmpx = x
|
164 |
+
if self.opt.hg_down == 'ave_pool':
|
165 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
166 |
+
elif self.opt.hg_down in ['conv64', 'conv128']:
|
167 |
+
x = self.conv2(x)
|
168 |
+
x = self.down_conv2(x)
|
169 |
+
else:
|
170 |
+
raise NameError('Unknown Fan Filter setting!')
|
171 |
+
|
172 |
+
x = self.conv3(x)
|
173 |
+
x = self.conv4(x)
|
174 |
+
|
175 |
+
previous = x
|
176 |
+
|
177 |
+
outputs = []
|
178 |
+
for i in range(self.num_modules):
|
179 |
+
hg = self._modules['m' + str(i)](previous)
|
180 |
+
|
181 |
+
ll = hg
|
182 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
183 |
+
|
184 |
+
ll = F.relu(
|
185 |
+
self._modules['bn_end' + str(i)](
|
186 |
+
self._modules['conv_last' + str(i)](ll)), True)
|
187 |
+
|
188 |
+
# Predict heatmaps
|
189 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
190 |
+
outputs.append(tmp_out)
|
191 |
+
|
192 |
+
if i < self.num_modules - 1:
|
193 |
+
ll = self._modules['bl' + str(i)](ll)
|
194 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
195 |
+
previous = previous + ll + tmp_out_
|
196 |
+
|
197 |
+
return outputs
|