Commit 2d5f249, committed by Yuliang · 0 parent(s) (root commit)
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50):
  1. .gitattributes +37 -0
  2. .gitignore +15 -0
  3. LICENSE +53 -0
  4. README.md +12 -0
  5. app.py +129 -0
  6. apps/ICON.py +765 -0
  7. apps/Normal.py +213 -0
  8. apps/infer.py +467 -0
  9. assets/garment_teaser.png +3 -0
  10. assets/intermediate_results.png +3 -0
  11. assets/teaser.gif +3 -0
  12. assets/thumbnail.png +3 -0
  13. configs/icon-filter.yaml +25 -0
  14. configs/icon-nofilter.yaml +25 -0
  15. configs/pamir.yaml +24 -0
  16. configs/pifu.yaml +24 -0
  17. examples/22097467bffc92d4a5c4246f7d4edb75.png +3 -0
  18. examples/44c0f84c957b6b9bdf77662af5bb7078.png +3 -0
  19. examples/5a6a25963db2f667441d5076972c207c.png +3 -0
  20. examples/8da7ceb94669c2f65cbd28022e1f9876.png +3 -0
  21. examples/923d65f767c85a42212cae13fba3750b.png +3 -0
  22. examples/959c4c726a69901ce71b93a9242ed900.png +3 -0
  23. examples/c9856a2bc31846d684cbb965457fad59.png +3 -0
  24. examples/e1e7622af7074a022f5d96dc16672517.png +3 -0
  25. examples/fb9d20fdb93750584390599478ecf86e.png +3 -0
  26. examples/slack_trial2-000150.png +3 -0
  27. lib/__init__.py +0 -0
  28. lib/common/__init__.py +0 -0
  29. lib/common/cloth_extraction.py +170 -0
  30. lib/common/config.py +218 -0
  31. lib/common/render.py +388 -0
  32. lib/common/render_utils.py +221 -0
  33. lib/common/seg3d_lossless.py +604 -0
  34. lib/common/seg3d_utils.py +392 -0
  35. lib/common/smpl_vert_segmentation.json +0 -0
  36. lib/common/train_util.py +597 -0
  37. lib/dataset/Evaluator.py +264 -0
  38. lib/dataset/NormalDataset.py +212 -0
  39. lib/dataset/NormalModule.py +94 -0
  40. lib/dataset/PIFuDataModule.py +71 -0
  41. lib/dataset/PIFuDataset.py +589 -0
  42. lib/dataset/TestDataset.py +256 -0
  43. lib/dataset/__init__.py +0 -0
  44. lib/dataset/body_model.py +494 -0
  45. lib/dataset/hoppeMesh.py +116 -0
  46. lib/dataset/mesh_util.py +894 -0
  47. lib/dataset/tbfo.ttf +0 -0
  48. lib/net/BasePIFuNet.py +84 -0
  49. lib/net/FBNet.py +387 -0
  50. lib/net/HGFilters.py +197 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.obj filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.glb filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
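Note: the rules above route every binary asset type in this Space (archives, model weights, meshes, textures, images, videos) through Git LFS, as Hugging Face expects for large files. Any additional binary format would be tracked with the same pattern; for example, a hypothetical rule for PLY meshes (illustrative only, not part of this commit) would read:

*.ply filter=lfs diff=lfs merge=lfs -text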
.gitignore ADDED
@@ -0,0 +1,15 @@
+ data/*/*
+ data/thuman*
+ !data/tbfo.ttf
+ __pycache__
+ debug/
+ log/
+ .vscode
+ !.gitignore
+ force_push.sh
+ .idea
+ human_det/
+ kaolin/
+ neural_voxelization_layer/
+ pytorch3d/
+ force_push.sh
LICENSE ADDED
@@ -0,0 +1,53 @@
+ License
+
+ Software Copyright License for non-commercial scientific research purposes
+ Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the ICON model, data and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
+
+ Ownership / Licensees
+ The Software and the associated materials has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI"). Any copyright or patent right is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) hereinafter the “Licensor”.
+
+ License Grant
+ Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
+
+ • To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
+ • To use the Model & Software for the sole purpose of performing peaceful non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
+ • To modify, adapt, translate or create derivative works based upon the Model & Software.
+
+ Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
+
+ The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
+
+ No Distribution
+ The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
+
+ Disclaimer of Representations and Warranties
+ You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
+
+ Limitation of Liability
+ Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
+ Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
+ Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
+ The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
+
+ No Maintenance Services
+ You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
+
+ Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
+
+ Publications using the Model & Software
+ You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
+
+ Citation:
+
+ @inproceedings{xiu2022icon,
+ title={{ICON}: {I}mplicit {C}lothed humans {O}btained from {N}ormals},
+ author={Xiu, Yuliang and Yang, Jinlong and Tzionas, Dimitrios and Black, Michael J.},
+ booktitle={IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)},
+ month = jun,
+ year={2022}
+ }
+
+ Commercial licensing opportunities
+ For commercial uses of the Model & Software, please send email to [email protected]
+
+ This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: ICON - Clothed Human Digitization
+ metaTitle: "Making yourself an ICON, by Yuliang Xiu"
+ emoji: 🤼
+ colorFrom: indigo
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.1.1
+ app_file: app.py
+ pinned: true
+ python_version: 3.8.13
+ ---
app.py ADDED
@@ -0,0 +1,129 @@
+ # install
+
+
+ import glob
+ import gradio as gr
+ import os, random
+
+ import subprocess
+
+ if os.getenv('SYSTEM') == 'spaces':
+     subprocess.run('pip install pyembree'.split())
+     subprocess.run('pip install rembg'.split())
+     subprocess.run(
+         'pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html'.split())
+     subprocess.run(
+         'pip install git+https://github.com/YuliangXiu/kaolin.git'.split())
+     # subprocess.run('pip install https://download.is.tue.mpg.de/icon/kaolin-0.11.0-cp38-cp38-linux_x86_64.whl'.split())
+     subprocess.run('pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html'.split())
+     subprocess.run(
+         'pip install git+https://github.com/Project-Splinter/human_det.git'.split())
+     subprocess.run(
+         'pip install git+https://github.com/YuliangXiu/neural_voxelization_layer.git'.split())
+
+ from apps.infer import generate_model
+
+ # running
+
+ description = '''
+ # ICON Clothed Human Digitization
+ ### ICON: Implicit Clothed humans Obtained from Normals (CVPR 2022)
+
+ <table style="width:26%; padding:0; margin:0;">
+ <tr>
+ <th><iframe src="https://ghbtns.com/github-btn.html?user=yuliangxiu&repo=ICON&type=star&count=true&v=2&size=small" frameborder="0" scrolling="0" width="100" height="20"></iframe></th>
+ <th><img alt="Twitter Follow" src="https://img.shields.io/twitter/follow/yuliangxiu?style=social"></th>
+ <th><img alt="YouTube Video Views" src="https://img.shields.io/youtube/views/hZd6AYin2DE?style=social"></th>
+ </tr>
+ </table>
+
+ #### Acknowledgments:
+
+ - [StyleGAN-Human, ECCV 2022](https://stylegan-human.github.io/)
+ - [nagolinc/styleGanHuman_and_PIFu](https://huggingface.co/spaces/nagolinc/styleGanHuman_and_PIFu)
+ - [radames/PIFu-Clothed-Human-Digitization](https://huggingface.co/spaces/radames/PIFu-Clothed-Human-Digitization)
+
+ #### The reconstruction + refinement + video takes about 80 seconds for a single image.
+
+ <details>
+
+ <summary>More</summary>
+
+ #### Image Credits
+
+ * [Pinterest](https://www.pinterest.com/search/pins/?q=parkour&rs=sitelinks_searchbox)
+ * [Qianli Ma](https://qianlim.github.io/)
+
+ #### Related works
+
+ * [ICON @ MPI](https://icon.is.tue.mpg.de/)
+ * [MonoPort @ USC](https://xiuyuliang.cn/monoport)
+ * [Phorhum @ Google](https://phorhum.github.io/)
+ * [PIFuHD @ Meta](https://shunsukesaito.github.io/PIFuHD/)
+ * [PaMIR @ Tsinghua](http://www.liuyebin.com/pamir/pamir.html)
+
+ </details>
+ '''
+
+
+ def generate_image(seed, psi):
+     iface = gr.Interface.load("spaces/hysts/StyleGAN-Human")
+     img = iface(seed, psi)
+     return img
+
+ random.seed(1993)
+ model_types = ['icon-filter', 'pifu', 'pamir']
+ examples = [[item, random.choice(model_types)] for item in sorted(glob.glob('examples/*.png'))]
+
+ with gr.Blocks() as demo:
+     gr.Markdown(description)
+
+     out_lst = []
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 with gr.Column():
+                     seed = gr.inputs.Slider(
+                         0, 100, step=1, default=0, label='Seed (For Image Generation)')
+                     psi = gr.inputs.Slider(
+                         0, 2, step=0.05, default=0.7, label='Truncation psi (For Image Generation)')
+                     radio_choice = gr.Radio(model_types, label='Method (For Reconstruction)', value='icon-filter')
+                     inp = gr.Image(type="filepath", label="Input Image")
+             with gr.Row():
+                 btn_sample = gr.Button("Sample Image")
+                 btn_submit = gr.Button("Submit Image")
+
+             gr.Examples(examples=examples,
+                         inputs=[inp, radio_choice],
+                         cache_examples=True,
+                         fn=generate_model,
+                         outputs=out_lst)
+
+             out_vid_download = gr.File(label="Download the video, and feel free to share it on Twitter with #ICON")
+
+         with gr.Column():
+             overlap_inp = gr.Image(type="filepath", label="Image Normal Overlap")
+             out_smpl = gr.Model3D(
+                 clear_color=[0.0, 0.0, 0.0, 0.0], label="SMPL")
+             out_smpl_download = gr.File(label="Download SMPL mesh")
+             out_smpl_npy_download = gr.File(label="Download SMPL params")
+             out_recon = gr.Model3D(
+                 clear_color=[0.0, 0.0, 0.0, 0.0], label="ICON")
+             out_recon_download = gr.File(label="Download clothed human mesh")
+             out_final = gr.Model3D(
+                 clear_color=[0.0, 0.0, 0.0, 0.0], label="ICON++")
+             out_final_download = gr.File(label="Download refined clothed human mesh")
+
+     out_lst = [out_smpl, out_smpl_download, out_smpl_npy_download, out_recon, out_recon_download,
+                out_final, out_final_download, out_vid_download, overlap_inp]
+
+     btn_submit.click(fn=generate_model, inputs=[inp, radio_choice], outputs=out_lst)
+     btn_sample.click(fn=generate_image, inputs=[seed, psi], outputs=inp)
+
+ if __name__ == "__main__":
+
+     # demo.launch(debug=False, enable_queue=False,
+     #             auth=("[email protected]", "icon_2022"),
+     #             auth_message="Register at icon.is.tue.mpg.de to get HuggingFace username and password.")
+
+     demo.launch(debug=True, enable_queue=True)
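Note: both gr.Examples and btn_submit.click above call apps.infer.generate_model(image, method) and expect nine return values matching out_lst: the SMPL mesh and its two download files, the reconstructed mesh and its download file, the refined mesh and its download file, the turntable video, and the normal-overlap image. A minimal stand-in with the same interface, useful only for exercising the UI wiring without the heavy pipeline (the function name and paths are hypothetical placeholders, not part of this commit):

def generate_model_stub(in_path, model_type):
    # Placeholder artifacts only -- the real apps.infer.generate_model runs SMPL
    # fitting, ICON reconstruction, and refinement to produce these files.
    smpl_glb = "results/demo_smpl.glb"        # SMPL mesh (gr.Model3D + gr.File)
    smpl_npy = "results/demo_smpl.npy"        # SMPL parameters (gr.File)
    recon_glb = "results/demo_recon.glb"      # clothed mesh (gr.Model3D + gr.File)
    refine_glb = "results/demo_refine.glb"    # refined clothed mesh (gr.Model3D + gr.File)
    video_mp4 = "results/demo.mp4"            # self-rotation video (gr.File)
    overlap_png = "results/demo_overlap.png"  # image/normal overlap (gr.Image)
    return (smpl_glb, smpl_glb, smpl_npy, recon_glb, recon_glb,
            refine_glb, refine_glb, video_mp4, overlap_png)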
apps/ICON.py ADDED
@@ -0,0 +1,765 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+
18
+ import os
19
+
20
+ from lib.common.seg3d_lossless import Seg3dLossless
21
+ from lib.dataset.Evaluator import Evaluator
22
+ from lib.net import HGPIFuNet
23
+ from lib.common.train_util import *
24
+ from lib.common.render import Render
25
+ from lib.dataset.mesh_util import SMPLX, update_mesh_shape_prior_losses, get_visibility
26
+ import warnings
27
+ import logging
28
+ import torch
29
+ import lib.smplx as smplx
30
+ import numpy as np
31
+ from torch import nn
32
+ import os.path as osp
33
+
34
+ from skimage.transform import resize
35
+ import pytorch_lightning as pl
36
+ from huggingface_hub import cached_download
37
+
38
+ torch.backends.cudnn.benchmark = True
39
+
40
+ logging.getLogger("lightning").setLevel(logging.ERROR)
41
+
42
+ warnings.filterwarnings("ignore")
43
+
44
+
45
+ class ICON(pl.LightningModule):
46
+ def __init__(self, cfg):
47
+ super(ICON, self).__init__()
48
+
49
+ self.cfg = cfg
50
+ self.batch_size = self.cfg.batch_size
51
+ self.lr_G = self.cfg.lr_G
52
+
53
+ self.use_sdf = cfg.sdf
54
+ self.prior_type = cfg.net.prior_type
55
+ self.mcube_res = cfg.mcube_res
56
+ self.clean_mesh_flag = cfg.clean_mesh
57
+
58
+ self.netG = HGPIFuNet(
59
+ self.cfg,
60
+ self.cfg.projection_mode,
61
+ error_term=nn.SmoothL1Loss() if self.use_sdf else nn.MSELoss(),
62
+ )
63
+
64
+ # TODO: replace the renderer from opengl to pytorch3d
65
+ self.evaluator = Evaluator(
66
+ device=torch.device(f"cuda:{self.cfg.gpus[0]}"))
67
+
68
+ self.resolutions = (
69
+ np.logspace(
70
+ start=5,
71
+ stop=np.log2(self.mcube_res),
72
+ base=2,
73
+ num=int(np.log2(self.mcube_res) - 4),
74
+ endpoint=True,
75
+ )
76
+ + 1.0
77
+ )
78
+ self.resolutions = self.resolutions.astype(np.int16).tolist()
79
+
80
+ self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
81
+ self.pamir_keys = ["voxel_verts",
82
+ "voxel_faces", "pad_v_num", "pad_f_num"]
83
+
84
+ self.reconEngine = Seg3dLossless(
85
+ query_func=query_func,
86
+ b_min=[[-1.0, 1.0, -1.0]],
87
+ b_max=[[1.0, -1.0, 1.0]],
88
+ resolutions=self.resolutions,
89
+ align_corners=True,
90
+ balance_value=0.50,
91
+ device=torch.device(f"cuda:{self.cfg.test_gpus[0]}"),
92
+ visualize=False,
93
+ debug=False,
94
+ use_cuda_impl=False,
95
+ faster=True,
96
+ )
97
+
98
+ self.render = Render(
99
+ size=512, device=torch.device(f"cuda:{self.cfg.test_gpus[0]}")
100
+ )
101
+ self.smpl_data = SMPLX()
102
+
103
+ self.get_smpl_model = lambda smpl_type, gender, age, v_template: smplx.create(
104
+ self.smpl_data.model_dir,
105
+ kid_template_path=cached_download(osp.join(self.smpl_data.model_dir,
106
+ f"{smpl_type}/{smpl_type}_kid_template.npy"), use_auth_token=os.environ['ICON']),
107
+ model_type=smpl_type,
108
+ gender=gender,
109
+ age=age,
110
+ v_template=v_template,
111
+ use_face_contour=False,
112
+ ext="pkl",
113
+ )
114
+
115
+ self.in_geo = [item[0] for item in cfg.net.in_geo]
116
+ self.in_nml = [item[0] for item in cfg.net.in_nml]
117
+ self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
118
+ self.in_total = self.in_geo + self.in_nml
119
+ self.smpl_dim = cfg.net.smpl_dim
120
+
121
+ self.export_dir = None
122
+ self.result_eval = {}
123
+
124
+ def get_progress_bar_dict(self):
125
+ tqdm_dict = super().get_progress_bar_dict()
126
+ if "v_num" in tqdm_dict:
127
+ del tqdm_dict["v_num"]
128
+ return tqdm_dict
129
+
130
+ # Training related
131
+ def configure_optimizers(self):
132
+
133
+ # set optimizer
134
+ weight_decay = self.cfg.weight_decay
135
+ momentum = self.cfg.momentum
136
+
137
+ optim_params_G = [
138
+ {"params": self.netG.if_regressor.parameters(), "lr": self.lr_G}
139
+ ]
140
+
141
+ if self.cfg.net.use_filter:
142
+ optim_params_G.append(
143
+ {"params": self.netG.F_filter.parameters(), "lr": self.lr_G}
144
+ )
145
+
146
+ if self.cfg.net.prior_type == "pamir":
147
+ optim_params_G.append(
148
+ {"params": self.netG.ve.parameters(), "lr": self.lr_G}
149
+ )
150
+
151
+ if self.cfg.optim == "Adadelta":
152
+
153
+ optimizer_G = torch.optim.Adadelta(
154
+ optim_params_G, lr=self.lr_G, weight_decay=weight_decay
155
+ )
156
+
157
+ elif self.cfg.optim == "Adam":
158
+
159
+ optimizer_G = torch.optim.Adam(
160
+ optim_params_G, lr=self.lr_G, weight_decay=weight_decay
161
+ )
162
+
163
+ elif self.cfg.optim == "RMSprop":
164
+
165
+ optimizer_G = torch.optim.RMSprop(
166
+ optim_params_G,
167
+ lr=self.lr_G,
168
+ weight_decay=weight_decay,
169
+ momentum=momentum,
170
+ )
171
+
172
+ else:
173
+ raise NotImplementedError
174
+
175
+ # set scheduler
176
+ scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
177
+ optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
178
+ )
179
+
180
+ return [optimizer_G], [scheduler_G]
181
+
182
+ def training_step(self, batch, batch_idx):
183
+
184
+ if not self.cfg.fast_dev:
185
+ export_cfg(self.logger, self.cfg)
186
+
187
+ self.netG.train()
188
+
189
+ in_tensor_dict = {
190
+ "sample": batch["samples_geo"].permute(0, 2, 1),
191
+ "calib": batch["calib"],
192
+ "label": batch["labels_geo"].unsqueeze(1),
193
+ }
194
+
195
+ for name in self.in_total:
196
+ in_tensor_dict.update({name: batch[name]})
197
+
198
+ if self.prior_type == "icon":
199
+ for key in self.icon_keys:
200
+ in_tensor_dict.update({key: batch[key]})
201
+ elif self.prior_type == "pamir":
202
+ for key in self.pamir_keys:
203
+ in_tensor_dict.update({key: batch[key]})
204
+ else:
205
+ pass
206
+
207
+ preds_G, error_G = self.netG(in_tensor_dict)
208
+
209
+ acc, iou, prec, recall = self.evaluator.calc_acc(
210
+ preds_G.flatten(),
211
+ in_tensor_dict["label"].flatten(),
212
+ 0.5,
213
+ use_sdf=self.cfg.sdf,
214
+ )
215
+
216
+ # metrics processing
217
+ metrics_log = {
218
+ "train_loss": error_G.item(),
219
+ "train_acc": acc.item(),
220
+ "train_iou": iou.item(),
221
+ "train_prec": prec.item(),
222
+ "train_recall": recall.item(),
223
+ }
224
+
225
+ tf_log = tf_log_convert(metrics_log)
226
+ bar_log = bar_log_convert(metrics_log)
227
+
228
+ if batch_idx % int(self.cfg.freq_show_train) == 0:
229
+
230
+ with torch.no_grad():
231
+ self.render_func(in_tensor_dict, dataset="train")
232
+
233
+ metrics_return = {
234
+ k.replace("train_", ""): torch.tensor(v) for k, v in metrics_log.items()
235
+ }
236
+
237
+ metrics_return.update(
238
+ {"loss": error_G, "log": tf_log, "progress_bar": bar_log})
239
+
240
+ return metrics_return
241
+
242
+ def training_epoch_end(self, outputs):
243
+
244
+ if [] in outputs:
245
+ outputs = outputs[0]
246
+
247
+ # metrics processing
248
+ metrics_log = {
249
+ "train_avgloss": batch_mean(outputs, "loss"),
250
+ "train_avgiou": batch_mean(outputs, "iou"),
251
+ "train_avgprec": batch_mean(outputs, "prec"),
252
+ "train_avgrecall": batch_mean(outputs, "recall"),
253
+ "train_avgacc": batch_mean(outputs, "acc"),
254
+ }
255
+
256
+ tf_log = tf_log_convert(metrics_log)
257
+
258
+ return {"log": tf_log}
259
+
260
+ def validation_step(self, batch, batch_idx):
261
+
262
+ self.netG.eval()
263
+ self.netG.training = False
264
+
265
+ in_tensor_dict = {
266
+ "sample": batch["samples_geo"].permute(0, 2, 1),
267
+ "calib": batch["calib"],
268
+ "label": batch["labels_geo"].unsqueeze(1),
269
+ }
270
+
271
+ for name in self.in_total:
272
+ in_tensor_dict.update({name: batch[name]})
273
+
274
+ if self.prior_type == "icon":
275
+ for key in self.icon_keys:
276
+ in_tensor_dict.update({key: batch[key]})
277
+ elif self.prior_type == "pamir":
278
+ for key in self.pamir_keys:
279
+ in_tensor_dict.update({key: batch[key]})
280
+ else:
281
+ pass
282
+
283
+ preds_G, error_G = self.netG(in_tensor_dict)
284
+
285
+ acc, iou, prec, recall = self.evaluator.calc_acc(
286
+ preds_G.flatten(),
287
+ in_tensor_dict["label"].flatten(),
288
+ 0.5,
289
+ use_sdf=self.cfg.sdf,
290
+ )
291
+
292
+ if batch_idx % int(self.cfg.freq_show_val) == 0:
293
+ with torch.no_grad():
294
+ self.render_func(in_tensor_dict, dataset="val", idx=batch_idx)
295
+
296
+ metrics_return = {
297
+ "val_loss": error_G,
298
+ "val_acc": acc,
299
+ "val_iou": iou,
300
+ "val_prec": prec,
301
+ "val_recall": recall,
302
+ }
303
+
304
+ return metrics_return
305
+
306
+ def validation_epoch_end(self, outputs):
307
+
308
+ # metrics processing
309
+ metrics_log = {
310
+ "val_avgloss": batch_mean(outputs, "val_loss"),
311
+ "val_avgacc": batch_mean(outputs, "val_acc"),
312
+ "val_avgiou": batch_mean(outputs, "val_iou"),
313
+ "val_avgprec": batch_mean(outputs, "val_prec"),
314
+ "val_avgrecall": batch_mean(outputs, "val_recall"),
315
+ }
316
+
317
+ tf_log = tf_log_convert(metrics_log)
318
+
319
+ return {"log": tf_log}
320
+
321
+ def compute_vis_cmap(self, smpl_type, smpl_verts, smpl_faces):
322
+
323
+ (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
324
+ smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
325
+ if smpl_type == "smpl":
326
+ smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
327
+ else:
328
+ smplx_ind = np.arange(smpl_vis.shape[0])
329
+ smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)
330
+
331
+ return {
332
+ "smpl_vis": smpl_vis.unsqueeze(0).to(self.device),
333
+ "smpl_cmap": smpl_cmap.unsqueeze(0).to(self.device),
334
+ "smpl_verts": smpl_verts.unsqueeze(0),
335
+ }
336
+
337
+ @torch.enable_grad()
338
+ def optim_body(self, in_tensor_dict, batch):
339
+
340
+ smpl_model = self.get_smpl_model(
341
+ batch["type"][0], batch["gender"][0], batch["age"][0], None
342
+ ).to(self.device)
343
+ in_tensor_dict["smpl_faces"] = (
344
+ torch.tensor(smpl_model.faces.astype(np.int64))  # np.int alias removed in NumPy >= 1.24
345
+ .long()
346
+ .unsqueeze(0)
347
+ .to(self.device)
348
+ )
349
+
350
+ # The optimizer and variables
351
+ optimed_pose = torch.tensor(
352
+ batch["body_pose"][0], device=self.device, requires_grad=True
353
+ ) # [1,23,3,3]
354
+ optimed_trans = torch.tensor(
355
+ batch["transl"][0], device=self.device, requires_grad=True
356
+ ) # [3]
357
+ optimed_betas = torch.tensor(
358
+ batch["betas"][0], device=self.device, requires_grad=True
359
+ ) # [1,10]
360
+ optimed_orient = torch.tensor(
361
+ batch["global_orient"][0], device=self.device, requires_grad=True
362
+ ) # [1,1,3,3]
363
+
364
+ optimizer_smpl = torch.optim.SGD(
365
+ [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
366
+ lr=1e-3,
367
+ momentum=0.9,
368
+ )
369
+ scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
370
+ optimizer_smpl, mode="min", factor=0.5, verbose=0, min_lr=1e-5, patience=5
371
+ )
372
+ loop_smpl = range(50)
373
+ for i in loop_smpl:
374
+
375
+ optimizer_smpl.zero_grad()
376
+
377
+ # prior_loss, optimed_pose = dataset.vposer_prior(optimed_pose)
378
+ smpl_out = smpl_model(
379
+ betas=optimed_betas,
380
+ body_pose=optimed_pose,
381
+ global_orient=optimed_orient,
382
+ transl=optimed_trans,
383
+ return_verts=True,
384
+ )
385
+
386
+ smpl_verts = smpl_out.vertices[0] * 100.0
387
+ smpl_verts = projection(
388
+ smpl_verts, batch["calib"][0], format="tensor")
389
+ smpl_verts[:, 1] *= -1
390
+ # render optimized mesh (normal, T_normal, image [-1,1])
391
+ self.render.load_meshes(
392
+ smpl_verts, in_tensor_dict["smpl_faces"])
393
+ (
394
+ in_tensor_dict["T_normal_F"],
395
+ in_tensor_dict["T_normal_B"],
396
+ ) = self.render.get_rgb_image()
397
+
398
+ T_mask_F, T_mask_B = self.render.get_silhouette_image()
399
+
400
+ with torch.no_grad():
401
+ (
402
+ in_tensor_dict["normal_F"],
403
+ in_tensor_dict["normal_B"],
404
+ ) = self.netG.normal_filter(in_tensor_dict)
405
+
406
+ # mask = torch.abs(in_tensor['T_normal_F']).sum(dim=0, keepdims=True) > 0.0
407
+ diff_F_smpl = torch.abs(
408
+ in_tensor_dict["T_normal_F"] - in_tensor_dict["normal_F"]
409
+ )
410
+ diff_B_smpl = torch.abs(
411
+ in_tensor_dict["T_normal_B"] - in_tensor_dict["normal_B"]
412
+ )
413
+ loss = (diff_F_smpl + diff_B_smpl).mean()
414
+
415
+ # silhouette loss
416
+ smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
417
+ gt_arr = torch.cat(
418
+ [in_tensor_dict["normal_F"][0], in_tensor_dict["normal_B"][0]], dim=2
419
+ ).permute(1, 2, 0)
420
+ gt_arr = ((gt_arr + 1.0) * 0.5).to(self.device)
421
+ bg_color = (
422
+ torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
423
+ 0).unsqueeze(0).to(self.device)
424
+ )
425
+ gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
426
+ loss += torch.abs(smpl_arr - gt_arr).mean()
427
+
428
+ # Image.fromarray(((in_tensor_dict['T_normal_F'][0].permute(1,2,0)+1.0)*0.5*255.0).detach().cpu().numpy().astype(np.uint8)).show()
429
+
430
+ # loop_smpl.set_description(f"smpl = {loss:.3f}")
431
+
432
+ loss.backward(retain_graph=True)
433
+ optimizer_smpl.step()
434
+ scheduler_smpl.step(loss)
435
+ in_tensor_dict["smpl_verts"] = smpl_verts.unsqueeze(0)
436
+
437
+ in_tensor_dict.update(
438
+ self.compute_vis_cmap(
439
+ batch["type"][0],
440
+ in_tensor_dict["smpl_verts"][0],
441
+ in_tensor_dict["smpl_faces"][0],
442
+ )
443
+ )
444
+
445
+ features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
446
+
447
+ return features, inter, in_tensor_dict
448
+
449
+ @torch.enable_grad()
450
+ def optim_cloth(self, verts_pr, faces_pr, inter):
451
+
452
+ # convert from GT to SDF
453
+ verts_pr -= (self.resolutions[-1] - 1) / 2.0
454
+ verts_pr /= (self.resolutions[-1] - 1) / 2.0
455
+
456
+ losses = {
457
+ "cloth": {"weight": 5.0, "value": 0.0},
458
+ "edge": {"weight": 100.0, "value": 0.0},
459
+ "normal": {"weight": 0.2, "value": 0.0},
460
+ "laplacian": {"weight": 100.0, "value": 0.0},
461
+ "smpl": {"weight": 1.0, "value": 0.0},
462
+ "deform": {"weight": 20.0, "value": 0.0},
463
+ }
464
+
465
+ deform_verts = torch.full(
466
+ verts_pr.shape, 0.0, device=self.device, requires_grad=True
467
+ )
468
+ optimizer_cloth = torch.optim.SGD(
469
+ [deform_verts], lr=1e-1, momentum=0.9)
470
+ scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
471
+ optimizer_cloth, mode="min", factor=0.1, verbose=0, min_lr=1e-3, patience=5
472
+ )
473
+ # cloth optimization
474
+ loop_cloth = range(100)
475
+
476
+ for i in loop_cloth:
477
+
478
+ optimizer_cloth.zero_grad()
479
+
480
+ self.render.load_meshes(
481
+ verts_pr.unsqueeze(0).to(self.device),
482
+ faces_pr.unsqueeze(0).to(self.device).long(),
483
+ deform_verts,
484
+ )
485
+ P_normal_F, P_normal_B = self.render.get_rgb_image()
486
+
487
+ update_mesh_shape_prior_losses(self.render.mesh, losses)
488
+ diff_F_cloth = torch.abs(P_normal_F[0] - inter[:3])
489
+ diff_B_cloth = torch.abs(P_normal_B[0] - inter[3:])
490
+ losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
491
+ losses["deform"]["value"] = torch.topk(
492
+ torch.abs(deform_verts.flatten()), 30
493
+ )[0].mean()
494
+
495
+ # Weighted sum of the losses
496
+ cloth_loss = torch.tensor(0.0, device=self.device)
497
+ pbar_desc = ""
498
+
499
+ for k in losses.keys():
500
+ if k != "smpl":
501
+ cloth_loss_per_cls = losses[k]["value"] * \
502
+ losses[k]["weight"]
503
+ pbar_desc += f"{k}: {cloth_loss_per_cls:.3f} | "
504
+ cloth_loss += cloth_loss_per_cls
505
+
506
+ # loop_cloth.set_description(pbar_desc)
507
+ cloth_loss.backward(retain_graph=True)
508
+ optimizer_cloth.step()
509
+ scheduler_cloth.step(cloth_loss)
510
+
511
+ # convert from GT to SDF
512
+ deform_verts = deform_verts.flatten().detach()
513
+ deform_verts[torch.topk(torch.abs(deform_verts), 30)[
514
+ 1]] = deform_verts.mean()
515
+ deform_verts = deform_verts.view(-1, 3).cpu()
516
+
517
+ verts_pr += deform_verts
518
+ verts_pr *= (self.resolutions[-1] - 1) / 2.0
519
+ verts_pr += (self.resolutions[-1] - 1) / 2.0
520
+
521
+ return verts_pr
522
+
523
+ def test_step(self, batch, batch_idx):
524
+
525
+ # dict_keys(['dataset', 'subject', 'rotation', 'scale', 'calib',
526
+ # 'normal_F', 'normal_B', 'image', 'T_normal_F', 'T_normal_B',
527
+ # 'z-trans', 'verts', 'faces', 'samples_geo', 'labels_geo',
528
+ # 'smpl_verts', 'smpl_faces', 'smpl_vis', 'smpl_cmap', 'pts_signs',
529
+ # 'type', 'gender', 'age', 'body_pose', 'global_orient', 'betas', 'transl'])
530
+
531
+ if self.evaluator._normal_render is None:
532
+ self.evaluator.init_gl()
533
+
534
+ self.netG.eval()
535
+ self.netG.training = False
536
+ in_tensor_dict = {}
537
+
538
+ # export paths
539
+ mesh_name = batch["subject"][0]
540
+ mesh_rot = batch["rotation"][0].item()
541
+ ckpt_dir = self.cfg.name
542
+
543
+ for kid, key in enumerate(self.cfg.dataset.noise_type):
544
+ ckpt_dir += f"_{key}_{self.cfg.dataset.noise_scale[kid]}"
545
+
546
+ if self.cfg.optim_cloth:
547
+ ckpt_dir += "_optim_cloth"
548
+ if self.cfg.optim_body:
549
+ ckpt_dir += "_optim_body"
550
+
551
+ self.export_dir = osp.join(self.cfg.results_path, ckpt_dir, mesh_name)
552
+ os.makedirs(self.export_dir, exist_ok=True)
553
+
554
+ for name in self.in_total:
555
+ if name in batch.keys():
556
+ in_tensor_dict.update({name: batch[name]})
557
+
558
+ # update the new T_normal_F/B
559
+ in_tensor_dict.update(
560
+ self.evaluator.render_normal(
561
+ batch["smpl_verts"], batch["smpl_faces"])
562
+ )
563
+
564
+ # update the new smpl_vis
565
+ (xy, z) = batch["smpl_verts"][0].split([2, 1], dim=1)
566
+ smpl_vis = get_visibility(
567
+ xy,
568
+ z,
569
+ torch.as_tensor(self.smpl_data.faces).type_as(
570
+ batch["smpl_verts"]).long(),
571
+ )
572
+ in_tensor_dict.update({"smpl_vis": smpl_vis.unsqueeze(0)})
573
+
574
+ if self.prior_type == "icon":
575
+ for key in self.icon_keys:
576
+ in_tensor_dict.update({key: batch[key]})
577
+ elif self.prior_type == "pamir":
578
+ for key in self.pamir_keys:
579
+ in_tensor_dict.update({key: batch[key]})
580
+ else:
581
+ pass
582
+
583
+ with torch.no_grad():
584
+ if self.cfg.optim_body:
585
+ features, inter, in_tensor_dict = self.optim_body(
586
+ in_tensor_dict, batch)
587
+ else:
588
+ features, inter = self.netG.filter(
589
+ in_tensor_dict, return_inter=True)
590
+ sdf = self.reconEngine(
591
+ opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
592
+ )
593
+
594
+ # save inter results
595
+ image = (
596
+ in_tensor_dict["image"][0].permute(
597
+ 1, 2, 0).detach().cpu().numpy() + 1.0
598
+ ) * 0.5
599
+ smpl_F = (
600
+ in_tensor_dict["T_normal_F"][0].permute(
601
+ 1, 2, 0).detach().cpu().numpy()
602
+ + 1.0
603
+ ) * 0.5
604
+ smpl_B = (
605
+ in_tensor_dict["T_normal_B"][0].permute(
606
+ 1, 2, 0).detach().cpu().numpy()
607
+ + 1.0
608
+ ) * 0.5
609
+ image_inter = np.concatenate(
610
+ self.tensor2image(512, inter[0]) + [smpl_F, smpl_B, image], axis=1
611
+ )
612
+ Image.fromarray((image_inter * 255.0).astype(np.uint8)).save(
613
+ osp.join(self.export_dir, f"{mesh_rot}_inter.png")
614
+ )
615
+
616
+ verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
617
+
618
+ if self.clean_mesh_flag:
619
+ verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
620
+
621
+ if self.cfg.optim_cloth:
622
+ verts_pr = self.optim_cloth(verts_pr, faces_pr, inter[0].detach())
623
+
624
+ verts_gt = batch["verts"][0]
625
+ faces_gt = batch["faces"][0]
626
+
627
+ self.result_eval.update(
628
+ {
629
+ "verts_gt": verts_gt,
630
+ "faces_gt": faces_gt,
631
+ "verts_pr": verts_pr,
632
+ "faces_pr": faces_pr,
633
+ "recon_size": (self.resolutions[-1] - 1.0),
634
+ "calib": batch["calib"][0],
635
+ }
636
+ )
637
+
638
+ self.evaluator.set_mesh(self.result_eval, scale_factor=1.0)
639
+ self.evaluator.space_transfer()
640
+
641
+ chamfer, p2s = self.evaluator.calculate_chamfer_p2s(
642
+ sampled_points=1000)
643
+ normal_consist = self.evaluator.calculate_normal_consist(
644
+ save_demo_img=osp.join(self.export_dir, f"{mesh_rot}_nc.png")
645
+ )
646
+
647
+ test_log = {"chamfer": chamfer, "p2s": p2s, "NC": normal_consist}
648
+
649
+ return test_log
650
+
651
+ def test_epoch_end(self, outputs):
652
+
653
+ # make_test_gif("/".join(self.export_dir.split("/")[:-2]))
654
+
655
+ accu_outputs = accumulate(
656
+ outputs,
657
+ rot_num=3,
658
+ split={
659
+ "thuman2": (0, 5),
660
+ },
661
+ )
662
+
663
+ print(colored(self.cfg.name, "green"))
664
+ print(colored(self.cfg.dataset.noise_scale, "green"))
665
+
666
+ self.logger.experiment.add_hparams(
667
+ hparam_dict={"lr_G": self.lr_G, "bsize": self.batch_size},
668
+ metric_dict=accu_outputs,
669
+ )
670
+
671
+ np.save(
672
+ osp.join(self.export_dir, "../test_results.npy"),
673
+ accu_outputs,
674
+ allow_pickle=True,
675
+ )
676
+
677
+ return accu_outputs
678
+
679
+ def tensor2image(self, height, inter):
680
+
681
+ all = []
682
+ for dim in self.in_geo_dim:
683
+ img = resize(
684
+ np.tile(
685
+ ((inter[:dim].cpu().numpy() + 1.0) /
686
+ 2.0).transpose(1, 2, 0),
687
+ (1, 1, int(3 / dim)),
688
+ ),
689
+ (height, height),
690
+ anti_aliasing=True,
691
+ )
692
+
693
+ all.append(img)
694
+ inter = inter[dim:]
695
+
696
+ return all
697
+
698
+ def render_func(self, in_tensor_dict, dataset="title", idx=0):
699
+
700
+ for name in in_tensor_dict.keys():
701
+ in_tensor_dict[name] = in_tensor_dict[name][0:1]
702
+
703
+ self.netG.eval()
704
+ features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
705
+ sdf = self.reconEngine(
706
+ opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
707
+ )
708
+
709
+ if sdf is not None:
710
+ render = self.reconEngine.display(sdf)
711
+
712
+ image_pred = np.flip(render[:, :, ::-1], axis=0)
713
+ height = image_pred.shape[0]
714
+
715
+ image_gt = resize(
716
+ ((in_tensor_dict["image"].cpu().numpy()[0] + 1.0) / 2.0).transpose(
717
+ 1, 2, 0
718
+ ),
719
+ (height, height),
720
+ anti_aliasing=True,
721
+ )
722
+ image_inter = self.tensor2image(height, inter[0])
723
+ image = np.concatenate(
724
+ [image_pred, image_gt] + image_inter, axis=1)
725
+
726
+ step_id = self.global_step if dataset == "train" else self.global_step + idx
727
+ self.logger.experiment.add_image(
728
+ tag=f"Occupancy-{dataset}/{step_id}",
729
+ img_tensor=image.transpose(2, 0, 1),
730
+ global_step=step_id,
731
+ )
732
+
733
+ def test_single(self, batch):
734
+
735
+ self.netG.eval()
736
+ self.netG.training = False
737
+ in_tensor_dict = {}
738
+
739
+ for name in self.in_total:
740
+ if name in batch.keys():
741
+ in_tensor_dict.update({name: batch[name]})
742
+
743
+ if self.prior_type == "icon":
744
+ for key in self.icon_keys:
745
+ in_tensor_dict.update({key: batch[key]})
746
+ elif self.prior_type == "pamir":
747
+ for key in self.pamir_keys:
748
+ in_tensor_dict.update({key: batch[key]})
749
+ else:
750
+ pass
751
+
752
+ features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
753
+ sdf = self.reconEngine(
754
+ opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
755
+ )
756
+
757
+ verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
758
+
759
+ if self.clean_mesh_flag:
760
+ verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
761
+
762
+ verts_pr -= (self.resolutions[-1] - 1) / 2.0
763
+ verts_pr /= (self.resolutions[-1] - 1) / 2.0
764
+
765
+ return verts_pr, faces_pr, inter
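Note on the resolutions list built in ICON.__init__ above: np.logspace with base 2 produces a coarse-to-fine pyramid of grid sizes from 2^5 up to mcube_res, and the + 1.0 turns each level into a (2^k + 1)-point grid, so every coarse grid point coincides with a point of the next finer grid, which is what Seg3dLossless relies on when it refines the occupancy field level by level. A minimal check of the resulting values (a sketch, assuming mcube_res = 256 as set in apps/infer.py):

import numpy as np

mcube_res = 256
resolutions = np.logspace(start=5, stop=np.log2(mcube_res), base=2,
                          num=int(np.log2(mcube_res) - 4), endpoint=True) + 1.0
print(resolutions.astype(np.int16).tolist())  # [33, 65, 129, 257]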
apps/Normal.py ADDED
@@ -0,0 +1,213 @@
1
+ from lib.net import NormalNet
2
+ from lib.common.train_util import *
3
+ import logging
4
+ import torch
5
+ import numpy as np
6
+ from torch import nn
7
+ from skimage.transform import resize
8
+ import pytorch_lightning as pl
9
+
10
+ torch.backends.cudnn.benchmark = True
11
+
12
+ logging.getLogger("lightning").setLevel(logging.ERROR)
13
+
14
+
15
+ class Normal(pl.LightningModule):
16
+ def __init__(self, cfg):
17
+ super(Normal, self).__init__()
18
+ self.cfg = cfg
19
+ self.batch_size = self.cfg.batch_size
20
+ self.lr_N = self.cfg.lr_N
21
+
22
+ self.schedulers = []
23
+
24
+ self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())
25
+
26
+ self.in_nml = [item[0] for item in cfg.net.in_nml]
27
+
28
+ def get_progress_bar_dict(self):
29
+ tqdm_dict = super().get_progress_bar_dict()
30
+ if "v_num" in tqdm_dict:
31
+ del tqdm_dict["v_num"]
32
+ return tqdm_dict
33
+
34
+ # Training related
35
+ def configure_optimizers(self):
36
+
37
+ # set optimizer
38
+ weight_decay = self.cfg.weight_decay
39
+ momentum = self.cfg.momentum
40
+
41
+ optim_params_N_F = [
42
+ {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
43
+ optim_params_N_B = [
44
+ {"params": self.netG.netB.parameters(), "lr": self.lr_N}]
45
+
46
+ optimizer_N_F = torch.optim.Adam(
47
+ optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
48
+ )
49
+
50
+ optimizer_N_B = torch.optim.Adam(
51
+ optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
52
+ )
53
+
54
+ scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
55
+ optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
56
+ )
57
+
58
+ scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
59
+ optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
60
+ )
61
+
62
+ self.schedulers = [scheduler_N_F, scheduler_N_B]
63
+ optims = [optimizer_N_F, optimizer_N_B]
64
+
65
+ return optims, self.schedulers
66
+
67
+ def render_func(self, render_tensor):
68
+
69
+ height = render_tensor["image"].shape[2]
70
+ result_list = []
71
+
72
+ for name in render_tensor.keys():
73
+ result_list.append(
74
+ resize(
75
+ ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
76
+ 1, 2, 0
77
+ ),
78
+ (height, height),
79
+ anti_aliasing=True,
80
+ )
81
+ )
82
+ result_array = np.concatenate(result_list, axis=1)
83
+
84
+ return result_array
85
+
86
+ def training_step(self, batch, batch_idx, optimizer_idx):
87
+
88
+ export_cfg(self.logger, self.cfg)
89
+
90
+ # retrieve the data
91
+ in_tensor = {}
92
+ for name in self.in_nml:
93
+ in_tensor[name] = batch[name]
94
+
95
+ FB_tensor = {"normal_F": batch["normal_F"],
96
+ "normal_B": batch["normal_B"]}
97
+
98
+ self.netG.train()
99
+
100
+ preds_F, preds_B = self.netG(in_tensor)
101
+ error_NF, error_NB = self.netG.get_norm_error(
102
+ preds_F, preds_B, FB_tensor)
103
+
104
+ (opt_nf, opt_nb) = self.optimizers()
105
+
106
+ opt_nf.zero_grad()
107
+ opt_nb.zero_grad()
108
+
109
+ self.manual_backward(error_NF, opt_nf)
110
+ self.manual_backward(error_NB, opt_nb)
111
+
112
+ opt_nf.step()
113
+ opt_nb.step()
114
+
115
+ if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:
116
+
117
+ self.netG.eval()
118
+ with torch.no_grad():
119
+ nmlF, nmlB = self.netG(in_tensor)
120
+ in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
121
+ result_array = self.render_func(in_tensor)
122
+
123
+ self.logger.experiment.add_image(
124
+ tag=f"Normal-train/{self.global_step}",
125
+ img_tensor=result_array.transpose(2, 0, 1),
126
+ global_step=self.global_step,
127
+ )
128
+
129
+ # metrics processing
130
+ metrics_log = {
131
+ "train_loss-NF": error_NF.item(),
132
+ "train_loss-NB": error_NB.item(),
133
+ }
134
+
135
+ tf_log = tf_log_convert(metrics_log)
136
+ bar_log = bar_log_convert(metrics_log)
137
+
138
+ return {
139
+ "loss": error_NF + error_NB,
140
+ "loss-NF": error_NF,
141
+ "loss-NB": error_NB,
142
+ "log": tf_log,
143
+ "progress_bar": bar_log,
144
+ }
145
+
146
+ def training_epoch_end(self, outputs):
147
+
148
+ if [] in outputs:
149
+ outputs = outputs[0]
150
+
151
+ # metrics processing
152
+ metrics_log = {
153
+ "train_avgloss": batch_mean(outputs, "loss"),
154
+ "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
155
+ "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
156
+ }
157
+
158
+ tf_log = tf_log_convert(metrics_log)
159
+
160
+ tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
161
+ tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]
162
+
163
+ return {"log": tf_log}
164
+
165
+ def validation_step(self, batch, batch_idx):
166
+
167
+ # retrieve the data
168
+ in_tensor = {}
169
+ for name in self.in_nml:
170
+ in_tensor[name] = batch[name]
171
+
172
+ FB_tensor = {"normal_F": batch["normal_F"],
173
+ "normal_B": batch["normal_B"]}
174
+
175
+ self.netG.train()
176
+
177
+ preds_F, preds_B = self.netG(in_tensor)
178
+ error_NF, error_NB = self.netG.get_norm_error(
179
+ preds_F, preds_B, FB_tensor)
180
+
181
+ if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
182
+ batch_idx == 0
183
+ ):
184
+
185
+ with torch.no_grad():
186
+ nmlF, nmlB = self.netG(in_tensor)
187
+ in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
188
+ result_array = self.render_func(in_tensor)
189
+
190
+ self.logger.experiment.add_image(
191
+ tag=f"Normal-val/{self.global_step}",
192
+ img_tensor=result_array.transpose(2, 0, 1),
193
+ global_step=self.global_step,
194
+ )
195
+
196
+ return {
197
+ "val_loss": error_NF + error_NB,
198
+ "val_loss-NF": error_NF,
199
+ "val_loss-NB": error_NB,
200
+ }
201
+
202
+ def validation_epoch_end(self, outputs):
203
+
204
+ # metrics processing
205
+ metrics_log = {
206
+ "val_avgloss": batch_mean(outputs, "val_loss"),
207
+ "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
208
+ "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
209
+ }
210
+
211
+ tf_log = tf_log_convert(metrics_log)
212
+
213
+ return {"log": tf_log}
apps/infer.py ADDED
@@ -0,0 +1,467 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ import os
18
+
19
+ import logging
20
+ from lib.common.render import query_color
21
+ from lib.common.config import cfg
22
+ from lib.dataset.mesh_util import (
23
+ load_checkpoint,
24
+ update_mesh_shape_prior_losses,
25
+ blend_rgb_norm,
26
+ unwrap,
27
+ remesh,
28
+ tensor2variable,
29
+ )
30
+
31
+ from lib.dataset.TestDataset import TestDataset
32
+ from lib.net.local_affine import LocalAffine
33
+ from pytorch3d.structures import Meshes
34
+ from apps.ICON import ICON
35
+
36
+ from termcolor import colored
37
+ import numpy as np
38
+ from PIL import Image
39
+ import trimesh
40
+ import numpy as np
41
+ from tqdm import tqdm
42
+
43
+ import torch
44
+ torch.backends.cudnn.benchmark = True
45
+
46
+ logging.getLogger("trimesh").setLevel(logging.ERROR)
47
+
48
+
49
+ def generate_model(in_path, model_type):
50
+
51
+ torch.cuda.empty_cache()
52
+
53
+ config_dict = {'loop_smpl': 50,
54
+ 'loop_cloth': 100,
55
+ 'patience': 5,
56
+ 'vis_freq': 10,
57
+ 'out_dir': './results',
58
+ 'hps_type': 'pymaf',
59
+ 'config': f"./configs/{model_type}.yaml"}
60
+
61
+ # cfg read and merge
62
+ cfg.merge_from_file(config_dict['config'])
63
+ cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")
64
+
65
+ os.makedirs(config_dict['out_dir'], exist_ok=True)
66
+
67
+ cfg_show_list = [
68
+ "test_gpus",
69
+ [0],
70
+ "mcube_res",
71
+ 256,
72
+ "clean_mesh",
73
+ True,
74
+ ]
75
+
76
+ cfg.merge_from_list(cfg_show_list)
77
+ cfg.freeze()
78
+
79
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
80
+ device = torch.device(f"cuda:0")
81
+
82
+ # load model and dataloader
83
+ model = ICON(cfg)
84
+ model = load_checkpoint(model, cfg)
85
+
86
+ dataset_param = {
87
+ 'image_path': in_path,
88
+ 'seg_dir': None,
89
+ 'has_det': True, # w/ or w/o detection
90
+ 'hps_type': 'pymaf' # pymaf/pare/pixie
91
+ }
92
+
93
+ if config_dict['hps_type'] == "pixie" and "pamir" in config_dict['config']:
94
+ print(colored("PIXIE isn't compatible with PaMIR, switching to PyMAF", "red"))
95
+ dataset_param["hps_type"] = "pymaf"
96
+
97
+ dataset = TestDataset(dataset_param, device)
98
+
99
+ print(colored(f"Dataset Size: {len(dataset)}", "green"))
100
+
101
+ pbar = tqdm(dataset)
102
+
103
+ for data in pbar:
104
+
105
+ pbar.set_description(f"{data['name']}")
106
+
107
+ in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}
108
+
109
+ # The optimizer and variables
110
+ optimed_pose = torch.tensor(
111
+ data["body_pose"], device=device, requires_grad=True
112
+ ) # [1,23,3,3]
113
+ optimed_trans = torch.tensor(
114
+ data["trans"], device=device, requires_grad=True
115
+ ) # [3]
116
+ optimed_betas = torch.tensor(
117
+ data["betas"], device=device, requires_grad=True
118
+ ) # [1,10]
119
+ optimed_orient = torch.tensor(
120
+ data["global_orient"], device=device, requires_grad=True
121
+ ) # [1,1,3,3]
122
+
123
+ optimizer_smpl = torch.optim.SGD(
124
+ [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
125
+ lr=1e-3,
126
+ momentum=0.9,
127
+ )
128
+ scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
129
+ optimizer_smpl,
130
+ mode="min",
131
+ factor=0.5,
132
+ verbose=0,
133
+ min_lr=1e-5,
134
+ patience=config_dict['patience'],
135
+ )
136
+
137
+ losses = {
138
+ # Cloth: Normal_recon - Normal_pred
139
+ "cloth": {"weight": 1e1, "value": 0.0},
140
+ # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
141
+ "stiffness": {"weight": 1e5, "value": 0.0},
142
+ # Cloth: det(R) = 1
143
+ "rigid": {"weight": 1e5, "value": 0.0},
144
+ # Cloth: edge length
145
+ "edge": {"weight": 0, "value": 0.0},
146
+ # Cloth: normal consistency
147
+ "nc": {"weight": 0, "value": 0.0},
148
+ # Cloth: laplacian smoothing
149
+ "laplacian": {"weight": 1e2, "value": 0.0},
150
+ # Body: Normal_pred - Normal_smpl
151
+ "normal": {"weight": 1e0, "value": 0.0},
152
+ # Body: Silhouette_pred - Silhouette_smpl
153
+ "silhouette": {"weight": 1e0, "value": 0.0},
154
+ }
155
+
156
+ # smpl optimization
157
+
158
+ loop_smpl = tqdm(
159
+ range(config_dict['loop_smpl'] if cfg.net.prior_type != "pifu" else 1))
160
+
161
+ for i in loop_smpl:
162
+
163
+ optimizer_smpl.zero_grad()
164
+
165
+ if dataset_param["hps_type"] != "pixie":
166
+ smpl_out = dataset.smpl_model(
167
+ betas=optimed_betas,
168
+ body_pose=optimed_pose,
169
+ global_orient=optimed_orient,
170
+ pose2rot=False,
171
+ )
172
+
173
+ smpl_verts = ((smpl_out.vertices) +
174
+ optimed_trans) * data["scale"]
175
+ else:
176
+ smpl_verts, _, _ = dataset.smpl_model(
177
+ shape_params=optimed_betas,
178
+ expression_params=tensor2variable(data["exp"], device),
179
+ body_pose=optimed_pose,
180
+ global_pose=optimed_orient,
181
+ jaw_pose=tensor2variable(data["jaw_pose"], device),
182
+ left_hand_pose=tensor2variable(
183
+ data["left_hand_pose"], device),
184
+ right_hand_pose=tensor2variable(
185
+ data["right_hand_pose"], device),
186
+ )
187
+
188
+ smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
189
+
190
+ # render optimized mesh (normal, T_normal, image [-1,1])
191
+ in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
192
+ smpl_verts *
193
+ torch.tensor([1.0, -1.0, -1.0]
194
+ ).to(device), in_tensor["smpl_faces"]
195
+ )
196
+ T_mask_F, T_mask_B = dataset.render.get_silhouette_image()
197
+
198
+ with torch.no_grad():
199
+ in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
200
+ in_tensor
201
+ )
202
+
203
+ diff_F_smpl = torch.abs(
204
+ in_tensor["T_normal_F"] - in_tensor["normal_F"])
205
+ diff_B_smpl = torch.abs(
206
+ in_tensor["T_normal_B"] - in_tensor["normal_B"])
207
+
208
+ losses["normal"]["value"] = (diff_F_smpl + diff_B_smpl).mean()
209
+
210
+ # silhouette loss
211
+ smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
212
+ gt_arr = torch.cat(
213
+ [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
214
+ ).permute(1, 2, 0)
215
+ gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
216
+ bg_color = (
217
+ torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
218
+ 0).unsqueeze(0).to(device)
219
+ )
220
+ gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
221
+ diff_S = torch.abs(smpl_arr - gt_arr)
222
+ losses["silhouette"]["value"] = diff_S.mean()
223
+
224
+ # Weighted sum of the losses
225
+ smpl_loss = 0.0
226
+ pbar_desc = "Body Fitting --- "
227
+ for k in ["normal", "silhouette"]:
228
+ pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
229
+ smpl_loss += losses[k]["value"] * losses[k]["weight"]
230
+ pbar_desc += f"Total: {smpl_loss:.3f}"
231
+ loop_smpl.set_description(pbar_desc)
232
+
233
+ smpl_loss.backward()
234
+ optimizer_smpl.step()
235
+ scheduler_smpl.step(smpl_loss)
236
+ in_tensor["smpl_verts"] = smpl_verts * \
237
+ torch.tensor([1.0, 1.0, -1.0]).to(device)
238
+
239
+ # visualize the optimization process
240
+ # 1. SMPL Fitting
241
+ # 2. Clothes Refinement
242
+
243
+ os.makedirs(os.path.join(config_dict['out_dir'], cfg.name,
244
+ "refinement"), exist_ok=True)
245
+
246
+ # visualize the final results in self-rotation mode
247
+ os.makedirs(os.path.join(config_dict['out_dir'],
248
+ cfg.name, "vid"), exist_ok=True)
249
+
250
+ # final results rendered as image
251
+ # 1. Render the final fitted SMPL (xxx_smpl.png)
252
+ # 2. Render the final reconstructed clothed human (xxx_cloth.png)
253
+ # 3. Blend the original image with predicted cloth normal (xxx_overlap.png)
254
+
255
+ os.makedirs(os.path.join(config_dict['out_dir'],
256
+ cfg.name, "png"), exist_ok=True)
257
+
258
+ # final reconstruction meshes
259
+ # 1. SMPL mesh (xxx_smpl.obj)
260
+ # 2. SMPL params (xxx_smpl.npy)
261
+ # 3. clothed mesh (xxx_recon.obj)
262
+ # 4. remeshed clothed mesh (xxx_remesh.obj)
263
+ # 5. refined clothed mesh (xxx_refine.obj)
264
+
265
+ os.makedirs(os.path.join(config_dict['out_dir'],
266
+ cfg.name, "obj"), exist_ok=True)
267
+
268
+
269
+ norm_pred = (
270
+ ((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
271
+ .detach()
272
+ .cpu()
273
+ .numpy()
274
+ .astype(np.uint8)
275
+ )
276
+
277
+ norm_orig = unwrap(norm_pred, data)
278
+ mask_orig = unwrap(
279
+ np.repeat(
280
+ data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
281
+ ).astype(np.uint8),
282
+ data,
283
+ )
284
+ rgb_norm = blend_rgb_norm(data["ori_image"], norm_orig, mask_orig)
285
+
286
+ Image.fromarray(
287
+ np.concatenate(
288
+ [data["ori_image"].astype(np.uint8), rgb_norm], axis=1)
289
+ ).save(os.path.join(config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png"))
290
+
291
+ smpl_obj = trimesh.Trimesh(
292
+ in_tensor["smpl_verts"].detach().cpu()[0] *
293
+ torch.tensor([1.0, -1.0, 1.0]),
294
+ in_tensor['smpl_faces'].detach().cpu()[0],
295
+ process=False,
296
+ maintains_order=True
297
+ )
298
+ smpl_obj.export(
299
+ f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")
300
+
301
+ smpl_info = {'betas': optimed_betas,
302
+ 'pose': optimed_pose,
303
+ 'orient': optimed_orient,
304
+ 'trans': optimed_trans}
305
+
306
+ np.save(
307
+ f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy", smpl_info, allow_pickle=True)
308
+
309
+ # ------------------------------------------------------------------------------------------------------------------
310
+
311
+ # cloth optimization
312
+
313
+ # cloth recon
314
+ in_tensor.update(
315
+ dataset.compute_vis_cmap(
316
+ in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0]
317
+ )
318
+ )
319
+
320
+ if cfg.net.prior_type == "pamir":
321
+ in_tensor.update(
322
+ dataset.compute_voxel_verts(
323
+ optimed_pose,
324
+ optimed_orient,
325
+ optimed_betas,
326
+ optimed_trans,
327
+ data["scale"],
328
+ )
329
+ )
330
+
331
+ with torch.no_grad():
332
+ verts_pr, faces_pr, _ = model.test_single(in_tensor)
333
+
334
+ recon_obj = trimesh.Trimesh(
335
+ verts_pr, faces_pr, process=False, maintains_order=True
336
+ )
337
+ recon_obj.export(
338
+ os.path.join(config_dict['out_dir'], cfg.name,
339
+ f"obj/{data['name']}_recon.obj")
340
+ )
341
+
342
+ recon_obj.export(
343
+ os.path.join(config_dict['out_dir'], cfg.name,
344
+ f"obj/{data['name']}_recon.glb")
345
+ )
346
+
347
+ # Isotropic Explicit Remeshing for better geometry topology
348
+ verts_refine, faces_refine = remesh(os.path.join(config_dict['out_dir'], cfg.name,
349
+ f"obj/{data['name']}_recon.obj"), 0.5, device)
350
+
351
+ # define local_affine deform verts
352
+ mesh_pr = Meshes(verts_refine, faces_refine).to(device)
353
+ local_affine_model = LocalAffine(
354
+ mesh_pr.verts_padded().shape[1], mesh_pr.verts_padded().shape[0], mesh_pr.edges_packed()).to(device)
355
+ optimizer_cloth = torch.optim.Adam(
356
+ [{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True)
357
+
358
+ scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
359
+ optimizer_cloth,
360
+ mode="min",
361
+ factor=0.1,
362
+ verbose=0,
363
+ min_lr=1e-5,
364
+ patience=config_dict['patience'],
365
+ )
366
+
367
+ final = None
368
+
369
+ if config_dict['loop_cloth'] > 0:
370
+
371
+ loop_cloth = tqdm(range(config_dict['loop_cloth']))
372
+
373
+ for i in loop_cloth:
374
+
375
+ optimizer_cloth.zero_grad()
376
+
377
+ deformed_verts, stiffness, rigid = local_affine_model(
378
+ verts_refine.to(device), return_stiff=True)
379
+ mesh_pr = mesh_pr.update_padded(deformed_verts)
380
+
381
+ # losses for laplacian, edge, normal consistency
382
+ update_mesh_shape_prior_losses(mesh_pr, losses)
383
+
384
+ in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
385
+ mesh_pr.verts_padded(), mesh_pr.faces_padded())
386
+
387
+ diff_F_cloth = torch.abs(
388
+ in_tensor["P_normal_F"] - in_tensor["normal_F"])
389
+ diff_B_cloth = torch.abs(
390
+ in_tensor["P_normal_B"] - in_tensor["normal_B"])
391
+
392
+ losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
393
+ losses["stiffness"]["value"] = torch.mean(stiffness)
394
+ losses["rigid"]["value"] = torch.mean(rigid)
395
+
396
+ # Weighted sum of the losses
397
+ cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
398
+ pbar_desc = "Cloth Refinement --- "
399
+
400
+ for k in losses.keys():
401
+ if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
402
+ cloth_loss = cloth_loss + \
403
+ losses[k]["value"] * losses[k]["weight"]
404
+ pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.5f} | "
405
+
406
+ pbar_desc += f"Total: {cloth_loss:.5f}"
407
+ loop_cloth.set_description(pbar_desc)
408
+
409
+ # update params
410
+ cloth_loss.backward(retain_graph=True)
411
+ optimizer_cloth.step()
412
+ scheduler_cloth.step(cloth_loss)
413
+
414
+
415
+ final = trimesh.Trimesh(
416
+ mesh_pr.verts_packed().detach().squeeze(0).cpu(),
417
+ mesh_pr.faces_packed().detach().squeeze(0).cpu(),
418
+ process=False, maintains_order=True
419
+ )
420
+ final_colors = query_color(
421
+ mesh_pr.verts_packed().detach().squeeze(0).cpu(),
422
+ mesh_pr.faces_packed().detach().squeeze(0).cpu(),
423
+ in_tensor["image"],
424
+ device=device,
425
+ )
426
+ final.visual.vertex_colors = final_colors
427
+ final.export(
428
+ f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb")
429
+
430
+
431
+ # always export the visualization video regardless of cloth refinement
432
+ if final is not None:
433
+ verts_lst = [verts_pr, final.vertices]
434
+ faces_lst = [faces_pr, final.faces]
435
+ else:
436
+ verts_lst = [verts_pr]
437
+ faces_lst = [faces_pr]
438
+
439
+ # self-rotated video
440
+ dataset.render.load_meshes(
441
+ verts_lst, faces_lst)
442
+ dataset.render.get_rendered_video(
443
+ [data["ori_image"], rgb_norm],
444
+ os.path.join(config_dict['out_dir'], cfg.name,
445
+ f"vid/{data['name']}_cloth.mp4"),
446
+ )
447
+
448
+ smpl_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb"
449
+ smpl_npy_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy"
450
+ recon_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_recon.glb"
451
+ refine_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb"
452
+
453
+ video_path = os.path.join(config_dict['out_dir'], cfg.name, f"vid/{data['name']}_cloth.mp4")
454
+ overlap_path = os.path.join(config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png")
455
+
456
+ torch.cuda.empty_cache()
457
+ del model
458
+ del dataset
459
+ del local_affine_model
460
+ del optimizer_smpl
461
+ del optimizer_cloth
462
+ del scheduler_smpl
463
+ del scheduler_cloth
464
+ del losses
465
+ del in_tensor
466
+
467
+ return [smpl_path, smpl_path, smpl_npy_path, recon_path, recon_path, refine_path, refine_path, video_path, overlap_path]
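Taken together, the two loops above optimize weighted sums over the `losses` dictionary defined earlier. With the default weights (the edge and normal-consistency terms are disabled by their zero weights), the objectives reduce to:

smpl_loss  = 1e0 * losses["normal"]["value"] + 1e0 * losses["silhouette"]["value"]
cloth_loss = 1e1 * losses["cloth"]["value"] + 1e5 * losses["stiffness"]["value"] + 1e5 * losses["rigid"]["value"] + 1e2 * losses["laplacian"]["value"]

The first sum drives the SMPL body fit against the predicted clothed normals and silhouette; the second regularizes the per-vertex local affine deformation while pulling the refined mesh toward the same predicted normals.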
assets/garment_teaser.png ADDED

Git LFS Details

  • SHA256: 1bf1fde8dcec40a5b50a5eb3ba6cdeefad344348271b9b087d9f327efc5db845
  • Pointer size: 131 Bytes
  • Size of remote file: 594 kB
assets/intermediate_results.png ADDED

Git LFS Details

  • SHA256: 2daa92446130e9bf410ba55889740537c68f9c51f1799f89f2575581870c0d80
  • Pointer size: 131 Bytes
  • Size of remote file: 301 kB
assets/teaser.gif ADDED

Git LFS Details

  • SHA256: 0955111cbe83559ee8065b15dfed9f52da9e8190297c715d74d1a30cdee7cad5
  • Pointer size: 131 Bytes
  • Size of remote file: 382 kB
assets/thumbnail.png ADDED

Git LFS Details

  • SHA256: 5259d6e413242c63afe88027122eed783612ff9a9e48b9a9c51313f6bf66fb94
  • Pointer size: 130 Bytes
  • Size of remote file: 51.5 kB
configs/icon-filter.yaml ADDED
@@ -0,0 +1,25 @@
1
+ name: icon-filter
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/icon-filter.ckpt"
4
+ normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "icon" # icon/pamir/icon
14
+ use_filter: True
15
+ in_geo: (('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
18
+ gtype: 'HGPIFuNet'
19
+ norm_mlp: 'batch'
20
+ hourglass_dim: 6
21
+ smpl_dim: 7
22
+
23
+ # user defined
24
+ mcube_res: 512 # occupancy field resolution, higher --> more details
25
+ clean_mesh: False # if True, will remove floating pieces
configs/icon-nofilter.yaml ADDED
@@ -0,0 +1,25 @@
1
+ name: icon-nofilter
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/icon-nofilter.ckpt"
4
+ normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "icon" # icon/pamir/icon
14
+ use_filter: False
15
+ in_geo: (('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
18
+ gtype: 'HGPIFuNet'
19
+ norm_mlp: 'batch'
20
+ hourglass_dim: 6
21
+ smpl_dim: 7
22
+
23
+ # user defined
24
+ mcube_res: 512 # occupancy field resolution, higher --> more details
25
+ clean_mesh: False # if True, will remove floating pieces
configs/pamir.yaml ADDED
@@ -0,0 +1,24 @@
1
+ name: pamir
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/pamir.ckpt"
4
+ normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "pamir" # icon/pamir/icon
14
+ use_filter: True
15
+ in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ gtype: 'HGPIFuNet'
18
+ norm_mlp: 'batch'
19
+ hourglass_dim: 6
20
+ voxel_dim: 7
21
+
22
+ # user defined
23
+ mcube_res: 512 # occupancy field resolution, higher --> more details
24
+ clean_mesh: False # if True, will remove floating pieces
configs/pifu.yaml ADDED
@@ -0,0 +1,24 @@
1
+ name: pifu
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "https://huggingface.co/Yuliang/ICON/resolve/main/pifu.ckpt"
4
+ normal_path: "https://huggingface.co/Yuliang/ICON/resolve/main/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "pifu" # icon/pamir/icon
14
+ use_filter: True
15
+ in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ gtype: 'HGPIFuNet'
18
+ norm_mlp: 'batch'
19
+ hourglass_dim: 12
20
+
21
+
22
+ # user defined
23
+ mcube_res: 512 # occupancy field resolution, higher --> more details
24
+ clean_mesh: False # if True, will remove floating pieces
examples/22097467bffc92d4a5c4246f7d4edb75.png ADDED

Git LFS Details

  • SHA256: f37625631d1cea79fca0c77d6a809e827f86d2ddc51515abaade0801b9ef1a57
  • Pointer size: 131 Bytes
  • Size of remote file: 448 kB
examples/44c0f84c957b6b9bdf77662af5bb7078.png ADDED

Git LFS Details

  • SHA256: b5ccc3ff6e99b32fed04bdd8f72873e7d987e088e83bbb235152db0500fdc6dc
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
examples/5a6a25963db2f667441d5076972c207c.png ADDED

Git LFS Details

  • SHA256: a4e0773d094b45a7c496292e5352166d6f47e469c2c6101ffa9536e44007a4e3
  • Pointer size: 131 Bytes
  • Size of remote file: 523 kB
examples/8da7ceb94669c2f65cbd28022e1f9876.png ADDED

Git LFS Details

  • SHA256: 7be8a036e6f3d11db05f0c6a93de165dae4c2afc052d09f6660c43a0a0484e99
  • Pointer size: 131 Bytes
  • Size of remote file: 286 kB
examples/923d65f767c85a42212cae13fba3750b.png ADDED

Git LFS Details

  • SHA256: 11310b5ef67f69d9efe7f00cced6e4e4a7c55ade2d928c3005ec102615d93ac0
  • Pointer size: 131 Bytes
  • Size of remote file: 616 kB
examples/959c4c726a69901ce71b93a9242ed900.png ADDED

Git LFS Details

  • SHA256: fc0b5e48a0cf3fbe664e2fcc54212167f8b973efdca74bbe1e8f5dd2ab23883e
  • Pointer size: 131 Bytes
  • Size of remote file: 476 kB
examples/c9856a2bc31846d684cbb965457fad59.png ADDED

Git LFS Details

  • SHA256: b97743cb85d8b2db10f86b5216a67f0df0ff84b71665d2be451dcd517c557fb6
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
examples/e1e7622af7074a022f5d96dc16672517.png ADDED

Git LFS Details

  • SHA256: badb5a8c2d9591aa4c71915795cb3d229678cad3612f2ee36d399174de32004e
  • Pointer size: 131 Bytes
  • Size of remote file: 652 kB
examples/fb9d20fdb93750584390599478ecf86e.png ADDED

Git LFS Details

  • SHA256: ae80334944bb3c9496565dbe28e0ec30d2150344b600b6aac5c917c8c6ef4f1f
  • Pointer size: 131 Bytes
  • Size of remote file: 623 kB
examples/slack_trial2-000150.png ADDED

Git LFS Details

  • SHA256: 3a370b52849ea33b117608dc3179398cb4b36293ba49c93d91dea10887a54ff2
  • Pointer size: 130 Bytes
  • Size of remote file: 71.3 kB
lib/__init__.py ADDED
File without changes
lib/common/__init__.py ADDED
File without changes
lib/common/cloth_extraction.py ADDED
@@ -0,0 +1,170 @@
1
+ import numpy as np
2
+ import json
3
+ import os
4
+ import itertools
5
+ import trimesh
6
+ from matplotlib.path import Path
7
+ from collections import Counter
8
+ from sklearn.neighbors import KNeighborsClassifier
9
+
10
+
11
+ def load_segmentation(path, shape):
12
+ """
13
+ Get a segmentation mask for a given image
14
+ Arguments:
15
+ path: path to the segmentation json file
16
+ shape: shape of the output mask
17
+ Returns:
18
+ Returns a segmentation mask
19
+ """
20
+ with open(path) as json_file:
21
+ dict = json.load(json_file)
22
+ segmentations = []
23
+ for key, val in dict.items():
24
+ if not key.startswith('item'):
25
+ continue
26
+
27
+ # Each item can have multiple polygons. Combine them to one
28
+ # segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
29
+ # segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
30
+
31
+ coordinates = []
32
+ for segmentation_coord in val['segmentation']:
33
+ # The format before is [x1,y1, x2, y2, ....]
34
+ x = segmentation_coord[::2]
35
+ y = segmentation_coord[1::2]
36
+ xy = np.vstack((x, y)).T
37
+ coordinates.append(xy)
38
+
39
+ segmentations.append(
40
+ {'type': val['category_name'], 'type_id': val['category_id'], 'coordinates': coordinates})
41
+
42
+ return segmentations
43
+
44
+
45
+ def smpl_to_recon_labels(recon, smpl, k=1):
46
+ """
47
+ Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
48
+ Arguments:
49
+ recon: trimesh object (fully clothed model)
50
+ smpl: trimesh object (smpl model)
51
+ k: number of nearest neighbours to use
52
+ Returns:
53
+ Returns a dictionary containing the bodypart and the corresponding indices
54
+ """
55
+ smpl_vert_segmentation = json.load(
56
+ open(os.path.join(os.path.dirname(__file__), 'smpl_vert_segmentation.json')))
57
+ n = smpl.vertices.shape[0]
58
+ y = np.array([None] * n)
59
+ for key, val in smpl_vert_segmentation.items():
60
+ y[val] = key
61
+
62
+ classifier = KNeighborsClassifier(n_neighbors=k)
63
+ classifier.fit(smpl.vertices, y)
64
+
65
+ y_pred = classifier.predict(recon.vertices)
66
+
67
+ recon_labels = {}
68
+ for key in smpl_vert_segmentation.keys():
69
+ recon_labels[key] = list(np.argwhere(
70
+ y_pred == key).flatten().astype(int))
71
+
72
+ return recon_labels
73
+
74
+
75
+ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
76
+ """
77
+ Extract a portion of a mesh using 2d segmentation coordinates
78
+ Arguments:
79
+ recon: fully clothed mesh
80
+ seg_coord: segmentation coordinates in 2D (NDC)
81
+ K: intrinsic matrix of the projection
82
+ R: rotation matrix of the projection
83
+ t: translation vector of the projection
84
+ Returns:
85
+ Returns a submesh using the segmentation coordinates
86
+ """
87
+ seg_coord = segmentation['coord_normalized']
88
+ mesh = trimesh.Trimesh(recon.vertices, recon.faces)
89
+ extrinsic = np.zeros((3, 4))
90
+ extrinsic[:3, :3] = R
91
+ extrinsic[:, 3] = t
92
+ P = K[:3, :3] @ extrinsic
93
+
94
+ P_inv = np.linalg.pinv(P)
95
+
96
+ # Each segmentation can contain multiple polygons
97
+ # We need to check them separately
98
+ points_so_far = []
99
+ faces = recon.faces
100
+ for polygon in seg_coord:
101
+ n = len(polygon)
102
+ coords_h = np.hstack((polygon, np.ones((n, 1))))
103
+ # Apply the inverse projection on homogeneous 2D coordinates to get the corresponding 3D coordinates
104
+ XYZ = P_inv @ coords_h[:, :, None]
105
+ XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
106
+ XYZ = XYZ[:, :3] / XYZ[:, 3, None]
107
+
108
+ p = Path(XYZ[:, :2])
109
+
110
+ grid = p.contains_points(recon.vertices[:, :2])
111
+ indices = np.argwhere(grid == True)
112
+ points_so_far += list(indices.flatten())
113
+
114
+ if smpl is not None:
115
+ num_verts = recon.vertices.shape[0]
116
+ recon_labels = smpl_to_recon_labels(recon, smpl)
117
+ body_parts_to_remove = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
118
+ 'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand', 'rightHand']
119
+ type = segmentation['type_id']
120
+
121
+ # Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
122
+ # https://github.com/switchablenorms/DeepFashion2
123
+ # Short sleeve clothes
124
+ if type == 1 or type == 3 or type == 10:
125
+ body_parts_to_remove += ['leftForeArm', 'rightForeArm']
126
+ # No sleeves at all or lower body clothes
127
+ elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9:
128
+ body_parts_to_remove += ['leftForeArm',
129
+ 'rightForeArm', 'leftArm', 'rightArm']
130
+ # Shorts
131
+ elif type == 7:
132
+ body_parts_to_remove += ['leftLeg', 'rightLeg',
133
+ 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
134
+
135
+ verts_to_remove = list(itertools.chain.from_iterable(
136
+ [recon_labels[part] for part in body_parts_to_remove]))
137
+
138
+ label_mask = np.zeros(num_verts, dtype=bool)
139
+ label_mask[verts_to_remove] = True
140
+
141
+ seg_mask = np.zeros(num_verts, dtype=bool)
142
+ seg_mask[points_so_far] = True
143
+
144
+ # Remove points that belong to other bodyparts
145
+ # If a vertex in points_so_far is included in the body parts to remove, it should be removed
146
+ extra_verts_to_remove = np.logical_and(seg_mask, label_mask)
147
+
148
+ combine_mask = np.zeros(num_verts, dtype=bool)
149
+ combine_mask[points_so_far] = True
150
+ combine_mask[extra_verts_to_remove] = False
151
+
152
+ all_indices = np.argwhere(combine_mask == True).flatten()
153
+
154
+ i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
155
+ i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
156
+ i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]
157
+
158
+ faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
159
+ mask = np.zeros(len(recon.faces), dtype=bool)
160
+ if len(faces_to_keep) > 0:
161
+ mask[faces_to_keep] = True
162
+
163
+ mesh.update_faces(mask)
164
+ mesh.remove_unreferenced_vertices()
165
+
166
+ # mesh.rezero()
167
+
168
+ return mesh
169
+
170
+ return None
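A minimal usage sketch for these helpers, assuming a reconstructed mesh, a fitted SMPL mesh, and a DeepFashion2-style segmentation JSON are already on disk (the paths, the NDC rescaling, and the identity camera below are illustrative placeholders, not part of this commit):

import numpy as np
import trimesh
from lib.common.cloth_extraction import load_segmentation, extract_cloth

recon = trimesh.load("data/example_recon.obj", process=False)   # hypothetical path
smpl = trimesh.load("data/example_smpl.obj", process=False)     # hypothetical path
segmentations = load_segmentation("data/example_seg.json", shape=(512, 512))

seg = segmentations[0]
# extract_cloth reads polygons from 'coord_normalized' (NDC); here we assume the
# pixel coordinates of a 512x512 image are rescaled to [-1, 1] by the caller
seg["coord_normalized"] = [xy / 256.0 - 1.0 for xy in seg["coordinates"]]

K = np.eye(3)                   # placeholder intrinsics
R, t = np.eye(3), np.zeros(3)   # placeholder extrinsics

cloth = extract_cloth(recon, seg, K, R, t, smpl=smpl)
if cloth is not None:
    cloth.export("data/example_cloth.obj")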
lib/common/config.py ADDED
@@ -0,0 +1,218 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ from yacs.config import CfgNode as CN
19
+ import os
20
+
21
+ _C = CN(new_allowed=True)
22
+
23
+ # needed by trainer
24
+ _C.name = 'default'
25
+ _C.gpus = [0]
26
+ _C.test_gpus = [1]
27
+ _C.root = "./data/"
28
+ _C.ckpt_dir = './data/ckpt/'
29
+ _C.resume_path = ''
30
+ _C.normal_path = ''
31
+ _C.corr_path = ''
32
+ _C.results_path = './data/results/'
33
+ _C.projection_mode = 'orthogonal'
34
+ _C.num_views = 1
35
+ _C.sdf = False
36
+ _C.sdf_clip = 5.0
37
+
38
+ _C.lr_G = 1e-3
39
+ _C.lr_C = 1e-3
40
+ _C.lr_N = 2e-4
41
+ _C.weight_decay = 0.0
42
+ _C.momentum = 0.0
43
+ _C.optim = 'RMSprop'
44
+ _C.schedule = [5, 10, 15]
45
+ _C.gamma = 0.1
46
+
47
+ _C.overfit = False
48
+ _C.resume = False
49
+ _C.test_mode = False
50
+ _C.test_uv = False
51
+ _C.draw_geo_thres = 0.60
52
+ _C.num_sanity_val_steps = 2
53
+ _C.fast_dev = 0
54
+ _C.get_fit = False
55
+ _C.agora = False
56
+ _C.optim_cloth = False
57
+ _C.optim_body = False
58
+ _C.mcube_res = 256
59
+ _C.clean_mesh = True
60
+ _C.remesh = False
61
+
62
+ _C.batch_size = 4
63
+ _C.num_threads = 8
64
+
65
+ _C.num_epoch = 10
66
+ _C.freq_plot = 0.01
67
+ _C.freq_show_train = 0.1
68
+ _C.freq_show_val = 0.2
69
+ _C.freq_eval = 0.5
70
+ _C.accu_grad_batch = 4
71
+
72
+ _C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt']
73
+
74
+ _C.net = CN()
75
+ _C.net.gtype = 'HGPIFuNet'
76
+ _C.net.ctype = 'resnet18'
77
+ _C.net.classifierIMF = 'MultiSegClassifier'
78
+ _C.net.netIMF = 'resnet18'
79
+ _C.net.norm = 'group'
80
+ _C.net.norm_mlp = 'group'
81
+ _C.net.norm_color = 'group'
82
+ _C.net.hg_down = 'ave_pool'
83
+ _C.net.num_views = 1
84
+
85
+ # kernel_size, stride, dilation, padding
86
+
87
+ _C.net.conv1 = [7, 2, 1, 3]
88
+ _C.net.conv3x3 = [3, 1, 1, 1]
89
+
90
+ _C.net.num_stack = 4
91
+ _C.net.num_hourglass = 2
92
+ _C.net.hourglass_dim = 256
93
+ _C.net.voxel_dim = 32
94
+ _C.net.resnet_dim = 120
95
+ _C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
96
+ _C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
97
+ _C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
98
+ _C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
99
+ _C.net.res_layers = [2, 3, 4]
100
+ _C.net.filter_dim = 256
101
+ _C.net.smpl_dim = 3
102
+
103
+ _C.net.cly_dim = 3
104
+ _C.net.soft_dim = 64
105
+ _C.net.z_size = 200.0
106
+ _C.net.N_freqs = 10
107
+ _C.net.geo_w = 0.1
108
+ _C.net.norm_w = 0.1
109
+ _C.net.dc_w = 0.1
110
+ _C.net.C_cat_to_G = False
111
+
112
+ _C.net.skip_hourglass = True
113
+ _C.net.use_tanh = True
114
+ _C.net.soft_onehot = True
115
+ _C.net.no_residual = True
116
+ _C.net.use_attention = False
117
+
118
+ _C.net.prior_type = "sdf"
119
+ _C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
120
+ _C.net.use_filter = True
121
+ _C.net.use_cc = False
122
+ _C.net.use_PE = False
123
+ _C.net.use_IGR = False
124
+ _C.net.in_geo = ()
125
+ _C.net.in_nml = ()
126
+
127
+ _C.dataset = CN()
128
+ _C.dataset.root = ''
129
+ _C.dataset.set_splits = [0.95, 0.04]
130
+ _C.dataset.types = [
131
+ "3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy"
132
+ ]
133
+ _C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
134
+ _C.dataset.rp_type = "pifu900"
135
+ _C.dataset.th_type = 'train'
136
+ _C.dataset.input_size = 512
137
+ _C.dataset.rotation_num = 3
138
+ _C.dataset.num_precomp = 10 # Number of segmentation classifiers
139
+ _C.dataset.num_multiseg = 500 # Number of categories per classifier
140
+ _C.dataset.num_knn = 10 # for loss/error
141
+ _C.dataset.num_knn_dis = 20 # for accuracy
142
+ _C.dataset.num_verts_max = 20000
143
+ _C.dataset.zray_type = False
144
+ _C.dataset.online_smpl = False
145
+ _C.dataset.noise_type = ['z-trans', 'pose', 'beta']
146
+ _C.dataset.noise_scale = [0.0, 0.0, 0.0]
147
+ _C.dataset.num_sample_geo = 10000
148
+ _C.dataset.num_sample_color = 0
149
+ _C.dataset.num_sample_seg = 0
150
+ _C.dataset.num_sample_knn = 10000
151
+
152
+ _C.dataset.sigma_geo = 5.0
153
+ _C.dataset.sigma_color = 0.10
154
+ _C.dataset.sigma_seg = 0.10
155
+ _C.dataset.thickness_threshold = 20.0
156
+ _C.dataset.ray_sample_num = 2
157
+ _C.dataset.semantic_p = False
158
+ _C.dataset.remove_outlier = False
159
+
160
+ _C.dataset.train_bsize = 1.0
161
+ _C.dataset.val_bsize = 1.0
162
+ _C.dataset.test_bsize = 1.0
163
+
164
+
165
+ def get_cfg_defaults():
166
+ """Get a yacs CfgNode object with default values for my_project."""
167
+ # Return a clone so that the defaults will not be altered
168
+ # This is for the "local variable" use pattern
169
+ return _C.clone()
170
+
171
+
172
+ # Alternatively, provide a way to import the defaults as
173
+ # a global singleton:
174
+ cfg = _C # users can `from config import cfg`
175
+
176
+ # cfg = get_cfg_defaults()
177
+ # cfg.merge_from_file('./configs/example.yaml')
178
+
179
+ # # Now override from a list (opts could come from the command line)
180
+ # opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
181
+ # cfg.merge_from_list(opts)
182
+
183
+
184
+ def update_cfg(cfg_file):
185
+ # cfg = get_cfg_defaults()
186
+ _C.merge_from_file(cfg_file)
187
+ # return cfg.clone()
188
+ return _C
189
+
190
+
191
+ def parse_args(args):
192
+ cfg_file = args.cfg_file
193
+ if args.cfg_file is not None:
194
+ cfg = update_cfg(args.cfg_file)
195
+ else:
196
+ cfg = get_cfg_defaults()
197
+
198
+ # if args.misc is not None:
199
+ # cfg.merge_from_list(args.misc)
200
+
201
+ return cfg
202
+
203
+
204
+ def parse_args_extend(args):
205
+ if args.resume:
206
+ if not os.path.exists(args.log_dir):
207
+ raise ValueError(
208
+ 'Experiment are set to resume mode, but log directory does not exist.'
209
+ )
210
+
211
+ # load log's cfg
212
+ cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
213
+ cfg = update_cfg(cfg_file)
214
+
215
+ if args.misc is not None:
216
+ cfg.merge_from_list(args.misc)
217
+ else:
218
+ parse_args(args)
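As a small sketch of how these defaults combine with the YAML files added in configs/ (the override list at the end is only an illustration):

from lib.common.config import update_cfg

# merge one of the committed YAML files into the yacs defaults above
cfg = update_cfg("./configs/icon-filter.yaml")
print(cfg.name, cfg.net.prior_type, cfg.mcube_res)   # icon-filter icon 512

# command-line style overrides can then be merged on top
cfg.merge_from_list(["mcube_res", 256, "clean_mesh", True])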
lib/common/render.py ADDED
@@ -0,0 +1,388 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ from pytorch3d.renderer import (
18
+ BlendParams,
19
+ blending,
20
+ look_at_view_transform,
21
+ FoVOrthographicCameras,
22
+ PointLights,
23
+ RasterizationSettings,
24
+ PointsRasterizationSettings,
25
+ PointsRenderer,
26
+ AlphaCompositor,
27
+ PointsRasterizer,
28
+ MeshRenderer,
29
+ MeshRasterizer,
30
+ SoftPhongShader,
31
+ SoftSilhouetteShader,
32
+ TexturesVertex,
33
+ )
34
+ from pytorch3d.renderer.mesh import TexturesVertex
35
+ from pytorch3d.structures import Meshes
36
+
37
+ import os
38
+
39
+ from lib.dataset.mesh_util import SMPLX, get_visibility
40
+ import lib.common.render_utils as util
41
+ import torch
42
+ import numpy as np
43
+ from PIL import Image
44
+ from tqdm import tqdm
45
+ import cv2
46
+ import math
47
+ from termcolor import colored
48
+
49
+
50
+ def image2vid(images, vid_path):
51
+
52
+ w, h = images[0].size
53
+ videodims = (w, h)
54
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
55
+ video = cv2.VideoWriter(vid_path, fourcc, 30, videodims)
56
+ for image in images:
57
+ video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
58
+ video.release()
59
+
60
+
61
+ def query_color(verts, faces, image, device):
62
+ """query colors from points and image
63
+
64
+ Args:
65
+ verts ([B, 3]): [query verts]
66
+ faces ([M, 3]): [query faces]
67
+ image ([B, 3, H, W]): [full image]
68
+
69
+ Returns:
70
+ [np.float]: [return colors]
71
+ """
72
+
73
+ verts = verts.float().to(device)
74
+ faces = faces.long().to(device)
75
+
76
+ (xy, z) = verts.split([2, 1], dim=1)
77
+ visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
78
+ uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
79
+ uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
80
+ colors = (torch.nn.functional.grid_sample(image, uv, align_corners=True)[
81
+ 0, :, :, 0].permute(1, 0) + 1.0) * 0.5 * 255.0
82
+ colors[visibility == 0.0] = ((Meshes(verts.unsqueeze(0), faces.unsqueeze(
83
+ 0)).verts_normals_padded().squeeze(0) + 1.0) * 0.5 * 255.0)[visibility == 0.0]
84
+
85
+ return colors.detach().cpu()
86
+
87
+
88
+ class cleanShader(torch.nn.Module):
89
+ def __init__(self, device="cpu", cameras=None, blend_params=None):
90
+ super().__init__()
91
+ self.cameras = cameras
92
+ self.blend_params = blend_params if blend_params is not None else BlendParams()
93
+
94
+ def forward(self, fragments, meshes, **kwargs):
95
+ cameras = kwargs.get("cameras", self.cameras)
96
+ if cameras is None:
97
+ msg = "Cameras must be specified either at initialization \
98
+ or in the forward pass of TexturedSoftPhongShader"
99
+
100
+ raise ValueError(msg)
101
+
102
+ # get renderer output
103
+ blend_params = kwargs.get("blend_params", self.blend_params)
104
+ texels = meshes.sample_textures(fragments)
105
+ images = blending.softmax_rgb_blend(
106
+ texels, fragments, blend_params, znear=-256, zfar=256
107
+ )
108
+
109
+ return images
110
+
111
+
112
+ class Render:
113
+ def __init__(self, size=512, device=torch.device("cuda:0")):
114
+ self.device = device
115
+ self.mesh_y_center = 100.0
116
+ self.dis = 100.0
117
+ self.scale = 1.0
118
+ self.size = size
119
+ self.cam_pos = [(0, 100, 100)]
120
+
121
+ self.mesh = None
122
+ self.deform_mesh = None
123
+ self.pcd = None
124
+ self.renderer = None
125
+ self.meshRas = None
126
+ self.type = None
127
+ self.knn = None
128
+ self.knn_inverse = None
129
+
130
+ self.smpl_seg = None
131
+ self.smpl_cmap = None
132
+
133
+ self.smplx = SMPLX()
134
+
135
+ self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)
136
+
137
+ def get_camera(self, cam_id):
138
+
139
+ R, T = look_at_view_transform(
140
+ eye=[self.cam_pos[cam_id]],
141
+ at=((0, self.mesh_y_center, 0),),
142
+ up=((0, 1, 0),),
143
+ )
144
+
145
+ camera = FoVOrthographicCameras(
146
+ device=self.device,
147
+ R=R,
148
+ T=T,
149
+ znear=100.0,
150
+ zfar=-100.0,
151
+ max_y=100.0,
152
+ min_y=-100.0,
153
+ max_x=100.0,
154
+ min_x=-100.0,
155
+ scale_xyz=(self.scale * np.ones(3),),
156
+ )
157
+
158
+ return camera
159
+
160
+ def init_renderer(self, camera, type="clean_mesh", bg="gray"):
161
+
162
+ if "mesh" in type:
163
+
164
+ # rasterizer
165
+ self.raster_settings_mesh = RasterizationSettings(
166
+ image_size=self.size,
167
+ blur_radius=np.log(1.0 / 1e-4) * 1e-7,
168
+ faces_per_pixel=30,
169
+ )
170
+ self.meshRas = MeshRasterizer(
171
+ cameras=camera, raster_settings=self.raster_settings_mesh
172
+ )
173
+
174
+ if bg == "black":
175
+ blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
176
+ elif bg == "white":
177
+ blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
178
+ elif bg == "gray":
179
+ blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))
180
+
181
+ if type == "ori_mesh":
182
+
183
+ lights = PointLights(
184
+ device=self.device,
185
+ ambient_color=((0.8, 0.8, 0.8),),
186
+ diffuse_color=((0.2, 0.2, 0.2),),
187
+ specular_color=((0.0, 0.0, 0.0),),
188
+ location=[[0.0, 200.0, 0.0]],
189
+ )
190
+
191
+ self.renderer = MeshRenderer(
192
+ rasterizer=self.meshRas,
193
+ shader=SoftPhongShader(
194
+ device=self.device,
195
+ cameras=camera,
196
+ lights=lights,
197
+ blend_params=blendparam,
198
+ ),
199
+ )
200
+
201
+ if type == "silhouette":
202
+ self.raster_settings_silhouette = RasterizationSettings(
203
+ image_size=self.size,
204
+ blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
205
+ faces_per_pixel=50,
206
+ cull_backfaces=True,
207
+ )
208
+
209
+ self.silhouetteRas = MeshRasterizer(
210
+ cameras=camera, raster_settings=self.raster_settings_silhouette
211
+ )
212
+ self.renderer = MeshRenderer(
213
+ rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
214
+ )
215
+
216
+ if type == "pointcloud":
217
+ self.raster_settings_pcd = PointsRasterizationSettings(
218
+ image_size=self.size, radius=0.006, points_per_pixel=10
219
+ )
220
+
221
+ self.pcdRas = PointsRasterizer(
222
+ cameras=camera, raster_settings=self.raster_settings_pcd
223
+ )
224
+ self.renderer = PointsRenderer(
225
+ rasterizer=self.pcdRas,
226
+ compositor=AlphaCompositor(background_color=(0, 0, 0)),
227
+ )
228
+
229
+ if type == "clean_mesh":
230
+
231
+ self.renderer = MeshRenderer(
232
+ rasterizer=self.meshRas,
233
+ shader=cleanShader(
234
+ device=self.device, cameras=camera, blend_params=blendparam
235
+ ),
236
+ )
237
+
238
+ def VF2Mesh(self, verts, faces):
239
+
240
+ if not torch.is_tensor(verts):
241
+ verts = torch.tensor(verts)
242
+ if not torch.is_tensor(faces):
243
+ faces = torch.tensor(faces)
244
+
245
+ if verts.ndimension() == 2:
246
+ verts = verts.unsqueeze(0).float()
247
+ if faces.ndimension() == 2:
248
+ faces = faces.unsqueeze(0).long()
249
+
250
+ verts = verts.to(self.device)
251
+ faces = faces.to(self.device)
252
+
253
+ mesh = Meshes(verts, faces).to(self.device)
254
+
255
+ mesh.textures = TexturesVertex(
256
+ verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5
257
+ )
258
+
259
+ return mesh
260
+
261
+ def load_meshes(self, verts, faces):
262
+ """load mesh into the pytorch3d renderer
263
+
264
+ Args:
265
+ verts ([N,3]): verts
266
+ faces ([N,3]): faces
267
+ offset ([N,3]): offset
268
+ """
269
+
270
+ # camera setting
271
+ self.scale = 100.0
272
+ self.mesh_y_center = 0.0
273
+
274
+ self.cam_pos = [
275
+ (0, self.mesh_y_center, 100.0),
276
+ (100.0, self.mesh_y_center, 0),
277
+ (0, self.mesh_y_center, -100.0),
278
+ (-100.0, self.mesh_y_center, 0),
279
+ ]
280
+
281
+ self.type = "color"
282
+
283
+ if isinstance(verts, list):
284
+ self.meshes = []
285
+ for V, F in zip(verts, faces):
286
+ self.meshes.append(self.VF2Mesh(V, F))
287
+ else:
288
+ self.meshes = [self.VF2Mesh(verts, faces)]
289
+
290
+ def get_depth_map(self, cam_ids=[0, 2]):
291
+
292
+ depth_maps = []
293
+ for cam_id in cam_ids:
294
+ self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
295
+ fragments = self.meshRas(self.meshes[0])
296
+ depth_map = fragments.zbuf[..., 0].squeeze(0)
297
+ if cam_id == 2:
298
+ depth_map = torch.fliplr(depth_map)
299
+ depth_maps.append(depth_map)
300
+
301
+ return depth_maps
302
+
303
+ def get_rgb_image(self, cam_ids=[0, 2]):
304
+
305
+ images = []
306
+ for cam_id in range(len(self.cam_pos)):
307
+ if cam_id in cam_ids:
308
+ self.init_renderer(self.get_camera(
309
+ cam_id), "clean_mesh", "gray")
310
+ if len(cam_ids) == 4:
311
+ rendered_img = (
312
+ self.renderer(self.meshes[0])[
313
+ 0:1, :, :, :3].permute(0, 3, 1, 2)
314
+ - 0.5
315
+ ) * 2.0
316
+ else:
317
+ rendered_img = (
318
+ self.renderer(self.meshes[0])[
319
+ 0:1, :, :, :3].permute(0, 3, 1, 2)
320
+ - 0.5
321
+ ) * 2.0
322
+ if cam_id == 2 and len(cam_ids) == 2:
323
+ rendered_img = torch.flip(rendered_img, dims=[3])
324
+ images.append(rendered_img)
325
+
326
+ return images
327
+
328
+ def get_rendered_video(self, images, save_path):
329
+
330
+ self.cam_pos = []
331
+ for angle in range(0, 360, 3):
332
+ self.cam_pos.append(
333
+ (
334
+ 100.0 * math.cos(np.pi / 180 * angle),
335
+ self.mesh_y_center,
336
+ 100.0 * math.sin(np.pi / 180 * angle),
337
+ )
338
+ )
339
+
340
+ old_shape = np.array(images[0].shape[:2])
341
+ new_shape = np.around(
342
+ (self.size / old_shape[0]) * old_shape).astype(int)
343
+
344
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
345
+ video = cv2.VideoWriter(
346
+ save_path, fourcc, 30, (self.size * len(self.meshes) +
347
+ new_shape[1] * len(images), self.size)
348
+ )
349
+
350
+ pbar = tqdm(range(len(self.cam_pos)))
351
+ pbar.set_description(colored(f"exporting video {os.path.basename(save_path)}...", "blue"))
352
+ for cam_id in pbar:
353
+ self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
354
+
355
+ img_lst = [
356
+ np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(np.uint8)[
357
+ :, :, [2, 1, 0]
358
+ ]
359
+ for img in images
360
+ ]
361
+
362
+ for mesh in self.meshes:
363
+ rendered_img = (
364
+ (self.renderer(mesh)[0, :, :, :3] * 255.0)
365
+ .detach()
366
+ .cpu()
367
+ .numpy()
368
+ .astype(np.uint8)
369
+ )
370
+
371
+ img_lst.append(rendered_img)
372
+ final_img = np.concatenate(img_lst, axis=1)
373
+ video.write(final_img)
374
+
375
+ video.release()
376
+
377
+ def get_silhouette_image(self, cam_ids=[0, 2]):
378
+
379
+ images = []
380
+ for cam_id in range(len(self.cam_pos)):
381
+ if cam_id in cam_ids:
382
+ self.init_renderer(self.get_camera(cam_id), "silhouette")
383
+ rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
384
+ if cam_id == 2 and len(cam_ids) == 2:
385
+ rendered_img = torch.flip(rendered_img, dims=[2])
386
+ images.append(rendered_img)
387
+
388
+ return images
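The Render class above can in principle be driven on its own; a rough sketch, with a placeholder mesh standing in for the reconstruction (it still needs the SMPL-X assets that SMPLX() loads):

import torch
import trimesh
from lib.common.render import Render

device = torch.device("cuda:0")
render = Render(size=512, device=device)

mesh = trimesh.load("data/example_body.obj")      # hypothetical input mesh
render.load_meshes(mesh.vertices, mesh.faces)

# front/back views of the normal-colored mesh in [-1, 1] and the corresponding silhouettes
images = render.get_rgb_image(cam_ids=[0, 2])
masks = render.get_silhouette_image(cam_ids=[0, 2])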
lib/common/render_utils.py ADDED
@@ -0,0 +1,221 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import torch
19
+ from torch import nn
20
+ import trimesh
21
+ import math
22
+ from typing import NewType
23
+ from pytorch3d.structures import Meshes
24
+ from pytorch3d.renderer.mesh import rasterize_meshes
25
+
26
+ Tensor = NewType('Tensor', torch.Tensor)
27
+
28
+
29
+ def solid_angles(points: Tensor,
30
+ triangles: Tensor,
31
+ thresh: float = 1e-8) -> Tensor:
32
+ ''' Compute solid angle between the input points and triangles
33
+ Follows the method described in:
34
+ The Solid Angle of a Plane Triangle
35
+ A. VAN OOSTEROM AND J. STRACKEE
36
+ IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
37
+ VOL. BME-30, NO. 2, FEBRUARY 1983
38
+ Parameters
39
+ -----------
40
+ points: BxQx3
41
+ Tensor of input query points
42
+ triangles: BxFx3x3
43
+ Target triangles
44
+ thresh: float
45
+ float threshold
46
+ Returns
47
+ -------
48
+ solid_angles: BxQxF
49
+ A tensor containing the solid angle between all query points
50
+ and input triangles
51
+ '''
52
+ # Center the triangles on the query points. Size should be BxQxFx3x3
53
+ centered_tris = triangles[:, None] - points[:, :, None, None]
54
+
55
+ # BxQxFx3
56
+ norms = torch.norm(centered_tris, dim=-1)
57
+
58
+ # Should be BxQxFx3
59
+ cross_prod = torch.cross(centered_tris[:, :, :, 1],
60
+ centered_tris[:, :, :, 2],
61
+ dim=-1)
62
+ # Should be BxQxF
63
+ numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
64
+ del cross_prod
65
+
66
+ dot01 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 1]).sum(dim=-1)
67
+ dot12 = (centered_tris[:, :, :, 1] * centered_tris[:, :, :, 2]).sum(dim=-1)
68
+ dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
69
+ del centered_tris
70
+
71
+ denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
72
+ dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
73
+ del dot01, dot12, dot02, norms
74
+
75
+ # Should be BxQ
76
+ solid_angle = torch.atan2(numerator, denominator)
77
+ del numerator, denominator
78
+
79
+ torch.cuda.empty_cache()
80
+
81
+ return 2 * solid_angle
82
+
83
+
84
+ def winding_numbers(points: Tensor,
85
+ triangles: Tensor,
86
+ thresh: float = 1e-8) -> Tensor:
87
+ ''' Uses winding_numbers to compute inside/outside
88
+ Robust inside-outside segmentation using generalized winding numbers
89
+ Alec Jacobson,
90
+ Ladislav Kavan,
91
+ Olga Sorkine-Hornung
92
+ Fast Winding Numbers for Soups and Clouds SIGGRAPH 2018
93
+ Gavin Barill
94
+ NEIL G. Dickson
95
+ Ryan Schmidt
96
+ David I.W. Levin
97
+ and Alec Jacobson
98
+ Parameters
99
+ -----------
100
+ points: BxQx3
101
+ Tensor of input query points
102
+ triangles: BxFx3x3
103
+ Target triangles
104
+ thresh: float
105
+ float threshold
106
+ Returns
107
+ -------
108
+ winding_numbers: BxQ
109
+ A tensor containing the Generalized winding numbers
110
+ '''
111
+ # The generalized winding number is the sum of solid angles of the point
112
+ # with respect to all triangles.
113
+ return 1 / (4 * math.pi) * solid_angles(points, triangles,
114
+ thresh=thresh).sum(dim=-1)
115
+
116
+
117
+ def batch_contains(verts, faces, points):
118
+
119
+ B = verts.shape[0]
120
+ N = points.shape[1]
121
+
122
+ verts = verts.detach().cpu()
123
+ faces = faces.detach().cpu()
124
+ points = points.detach().cpu()
125
+ contains = torch.zeros(B, N)
126
+
127
+ for i in range(B):
128
+ contains[i] = torch.as_tensor(
129
+ trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
130
+
131
+ return 2.0 * (contains - 0.5)
132
+
133
+
134
+ def dict2obj(d):
135
+ # if isinstance(d, list):
136
+ # d = [dict2obj(x) for x in d]
137
+ if not isinstance(d, dict):
138
+ return d
139
+
140
+ class C(object):
141
+ pass
142
+
143
+ o = C()
144
+ for k in d:
145
+ o.__dict__[k] = dict2obj(d[k])
146
+ return o
147
+
148
+
149
+ def face_vertices(vertices, faces):
150
+ """
151
+ :param vertices: [batch size, number of vertices, 3]
152
+ :param faces: [batch size, number of faces, 3]
153
+ :return: [batch size, number of faces, 3, 3]
154
+ """
155
+
156
+ bs, nv = vertices.shape[:2]
157
+ bs, nf = faces.shape[:2]
158
+ device = vertices.device
159
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
160
+ nv)[:, None, None]
161
+ vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
162
+
163
+ return vertices[faces.long()]
164
+
165
+
166
+ class Pytorch3dRasterizer(nn.Module):
167
+ """ Borrowed from https://github.com/facebookresearch/pytorch3d
168
+ Notice:
169
+ x,y,z are in image space, normalized
170
+ can only render square images for now
171
+ """
172
+
173
+ def __init__(self, image_size=224):
174
+ """
175
+ use fixed raster_settings for rendering faces
176
+ """
177
+ super().__init__()
178
+ raster_settings = {
179
+ 'image_size': image_size,
180
+ 'blur_radius': 0.0,
181
+ 'faces_per_pixel': 1,
182
+ 'bin_size': None,
183
+ 'max_faces_per_bin': None,
184
+ 'perspective_correct': True,
185
+ 'cull_backfaces': True,
186
+ }
187
+ raster_settings = dict2obj(raster_settings)
188
+ self.raster_settings = raster_settings
189
+
190
+ def forward(self, vertices, faces, attributes=None):
191
+ fixed_vertices = vertices.clone()
192
+ fixed_vertices[..., :2] = -fixed_vertices[..., :2]
193
+ meshes_screen = Meshes(verts=fixed_vertices.float(),
194
+ faces=faces.long())
195
+ raster_settings = self.raster_settings
196
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
197
+ meshes_screen,
198
+ image_size=raster_settings.image_size,
199
+ blur_radius=raster_settings.blur_radius,
200
+ faces_per_pixel=raster_settings.faces_per_pixel,
201
+ bin_size=raster_settings.bin_size,
202
+ max_faces_per_bin=raster_settings.max_faces_per_bin,
203
+ perspective_correct=raster_settings.perspective_correct,
204
+ )
205
+ vismask = (pix_to_face > -1).float()
206
+ D = attributes.shape[-1]
207
+ attributes = attributes.clone()
208
+ attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
209
+ 3, attributes.shape[-1])
210
+ N, H, W, K, _ = bary_coords.shape
211
+ mask = pix_to_face == -1
212
+ pix_to_face = pix_to_face.clone()
213
+ pix_to_face[mask] = 0
214
+ idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
215
+ pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
216
+ pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
217
+ pixel_vals[mask] = 0 # Replace masked values in output.
218
+ pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
219
+ pixel_vals = torch.cat(
220
+ [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
221
+ return pixel_vals
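The winding-number utilities above implement a soft inside/outside test; a minimal sketch on a closed box, assuming the values come out near 1 for interior points and near 0 for exterior ones (up to face orientation):

import torch
import trimesh
from lib.common.render_utils import face_vertices, winding_numbers

box = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
verts = torch.tensor(box.vertices, dtype=torch.float32).unsqueeze(0)  # [1, V, 3]
faces = torch.tensor(box.faces, dtype=torch.long).unsqueeze(0)        # [1, F, 3]

# one query point inside the box, one outside
points = torch.tensor([[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]])           # [1, 2, 3]

triangles = face_vertices(verts, faces)                               # [1, F, 3, 3]
wn = winding_numbers(points, triangles)                               # roughly [1.0, 0.0]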
lib/common/seg3d_lossless.py ADDED
@@ -0,0 +1,604 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+
19
+ from .seg3d_utils import (
20
+ create_grid3D,
21
+ plot_mask3D,
22
+ SmoothConv3D,
23
+ )
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import numpy as np
28
+ import torch.nn.functional as F
29
+ import mcubes
30
+ from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
31
+ import logging
32
+
33
+ logging.getLogger("lightning").setLevel(logging.ERROR)
34
+
35
+
36
+ class Seg3dLossless(nn.Module):
37
+ def __init__(self,
38
+ query_func,
39
+ b_min,
40
+ b_max,
41
+ resolutions,
42
+ channels=1,
43
+ balance_value=0.5,
44
+ align_corners=False,
45
+ visualize=False,
46
+ debug=False,
47
+ use_cuda_impl=False,
48
+ faster=False,
49
+ use_shadow=False,
50
+ **kwargs):
51
+ """
52
+ align_corners: same with how you process gt. (grid_sample / interpolate)
53
+ """
54
+ super().__init__()
55
+ self.query_func = query_func
56
+ self.register_buffer(
57
+ 'b_min',
58
+ torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
59
+ self.register_buffer(
60
+ 'b_max',
61
+ torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
62
+
63
+ # ti.init(arch=ti.cuda)
64
+ # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
65
+
66
+ if type(resolutions[0]) is int:
67
+ resolutions = torch.tensor([(res, res, res)
68
+ for res in resolutions])
69
+ else:
70
+ resolutions = torch.tensor(resolutions)
71
+ self.register_buffer('resolutions', resolutions)
72
+ self.batchsize = self.b_min.size(0)
73
+ assert self.batchsize == 1
74
+ self.balance_value = balance_value
75
+ self.channels = channels
76
+ assert self.channels == 1
77
+ self.align_corners = align_corners
78
+ self.visualize = visualize
79
+ self.debug = debug
80
+ self.use_cuda_impl = use_cuda_impl
81
+ self.faster = faster
82
+ self.use_shadow = use_shadow
83
+
84
+ for resolution in resolutions:
85
+ assert resolution[0] % 2 == 1 and resolution[1] % 2 == 1, \
86
+ f"resolution {resolution} need to be odd becuase of align_corner."
87
+
88
+ # init first resolution
89
+ init_coords = create_grid3D(0,
90
+ resolutions[-1] - 1,
91
+ steps=resolutions[0]) # [N, 3]
92
+ init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
93
+ 1) # [bz, N, 3]
94
+ self.register_buffer('init_coords', init_coords)
95
+
96
+ # some useful tensors
97
+ calculated = torch.zeros(
98
+ (self.resolutions[-1][2], self.resolutions[-1][1],
99
+ self.resolutions[-1][0]),
100
+ dtype=torch.bool)
101
+ self.register_buffer('calculated', calculated)
102
+
103
+ gird8_offsets = torch.stack(
104
+ torch.meshgrid([
105
+ torch.tensor([-1, 0, 1]),
106
+ torch.tensor([-1, 0, 1]),
107
+ torch.tensor([-1, 0, 1])
108
+ ])).int().view(3, -1).t() # [27, 3]
109
+ self.register_buffer('gird8_offsets', gird8_offsets)
110
+
111
+ # smooth convs
112
+ self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
113
+ out_channels=1,
114
+ kernel_size=3)
115
+ self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
116
+ out_channels=1,
117
+ kernel_size=5)
118
+ self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
119
+ out_channels=1,
120
+ kernel_size=7)
121
+ self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
122
+ out_channels=1,
123
+ kernel_size=9)
124
+
125
+ def batch_eval(self, coords, **kwargs):
126
+ """
127
+ coords: in the coordinates of last resolution
128
+ **kwargs: for query_func
129
+ """
130
+ coords = coords.detach()
131
+ # normalize coords to fit in [b_min, b_max]
132
+ if self.align_corners:
133
+ coords2D = coords.float() / (self.resolutions[-1] - 1)
134
+ else:
135
+ step = 1.0 / self.resolutions[-1].float()
136
+ coords2D = coords.float() / self.resolutions[-1] + step / 2
137
+ coords2D = coords2D * (self.b_max - self.b_min) + self.b_min
138
+ # query function
139
+ occupancys = self.query_func(**kwargs, points=coords2D)
140
+ if type(occupancys) is list:
141
+ occupancys = torch.stack(occupancys) # [bz, C, N]
142
+ assert len(occupancys.size()) == 3, \
143
+ "query_func should return a occupancy with shape of [bz, C, N]"
144
+ return occupancys
145
+
146
+ def forward(self, **kwargs):
147
+ if self.faster:
148
+ return self._forward_faster(**kwargs)
149
+ else:
150
+ return self._forward(**kwargs)
151
+
152
+ def _forward_faster(self, **kwargs):
153
+ """
154
+ In faster mode, we make following changes to exchange accuracy for speed:
155
+ 1. no conflict checking: 4.88 fps -> 6.56 fps
156
+ 2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution
157
+ 3. last step no examine
158
+ """
159
+ final_W = self.resolutions[-1][0]
160
+ final_H = self.resolutions[-1][1]
161
+ final_D = self.resolutions[-1][2]
162
+
163
+ for resolution in self.resolutions:
164
+ W, H, D = resolution
165
+ stride = (self.resolutions[-1] - 1) / (resolution - 1)
166
+
167
+ # first step
168
+ if torch.equal(resolution, self.resolutions[0]):
169
+ coords = self.init_coords.clone() # torch.long
170
+ occupancys = self.batch_eval(coords, **kwargs)
171
+ occupancys = occupancys.view(self.batchsize, self.channels, D,
172
+ H, W)
173
+ if (occupancys > 0.5).sum() == 0:
174
+ # return F.interpolate(
175
+ # occupancys, size=(final_D, final_H, final_W),
176
+ # mode="linear", align_corners=True)
177
+ return None
178
+
179
+ if self.visualize:
180
+ self.plot(occupancys, coords, final_D, final_H, final_W)
181
+
182
+ with torch.no_grad():
183
+ coords_accum = coords / stride
184
+
185
+ # last step
186
+ elif torch.equal(resolution, self.resolutions[-1]):
187
+
188
+ with torch.no_grad():
189
+ # here true is correct!
190
+ valid = F.interpolate(
191
+ (occupancys > self.balance_value).float(),
192
+ size=(D, H, W),
193
+ mode="trilinear",
194
+ align_corners=True)
195
+
196
+ # here true is correct!
197
+ occupancys = F.interpolate(occupancys.float(),
198
+ size=(D, H, W),
199
+ mode="trilinear",
200
+ align_corners=True)
201
+
202
+ # is_boundary = (valid > 0.0) & (valid < 1.0)
203
+ is_boundary = valid == 0.5
204
+
205
+ # next steps
206
+ else:
207
+ coords_accum *= 2
208
+
209
+ with torch.no_grad():
210
+ # here true is correct!
211
+ valid = F.interpolate(
212
+ (occupancys > self.balance_value).float(),
213
+ size=(D, H, W),
214
+ mode="trilinear",
215
+ align_corners=True)
216
+
217
+ # here true is correct!
218
+ occupancys = F.interpolate(occupancys.float(),
219
+ size=(D, H, W),
220
+ mode="trilinear",
221
+ align_corners=True)
222
+
223
+ is_boundary = (valid > 0.0) & (valid < 1.0)
224
+
225
+ with torch.no_grad():
226
+ if torch.equal(resolution, self.resolutions[1]):
227
+ is_boundary = (self.smooth_conv9x9(is_boundary.float())
228
+ > 0)[0, 0]
229
+ elif torch.equal(resolution, self.resolutions[2]):
230
+ is_boundary = (self.smooth_conv7x7(is_boundary.float())
231
+ > 0)[0, 0]
232
+ else:
233
+ is_boundary = (self.smooth_conv3x3(is_boundary.float())
234
+ > 0)[0, 0]
235
+
236
+ coords_accum = coords_accum.long()
237
+ is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
238
+ coords_accum[0, :, 0]] = False
239
+ point_coords = is_boundary.permute(
240
+ 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
241
+ point_indices = (point_coords[:, :, 2] * H * W +
242
+ point_coords[:, :, 1] * W +
243
+ point_coords[:, :, 0])
244
+
245
+ R, C, D, H, W = occupancys.shape
246
+
247
+ # inferred value
248
+ coords = point_coords * stride
249
+
250
+ if coords.size(1) == 0:
251
+ continue
252
+ occupancys_topk = self.batch_eval(coords, **kwargs)
253
+
254
+ # put mask point predictions to the right places on the upsampled grid.
255
+ R, C, D, H, W = occupancys.shape
256
+ point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
257
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
258
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
259
+
260
+ with torch.no_grad():
261
+ voxels = coords / stride
262
+ coords_accum = torch.cat([voxels, coords_accum],
263
+ dim=1).unique(dim=1)
264
+
265
+ return occupancys[0, 0]
266
+
267
+ def _forward(self, **kwargs):
268
+ """
269
+ output occupancy field would be:
270
+ (bz, C, res, res)
271
+ """
272
+ final_W = self.resolutions[-1][0]
273
+ final_H = self.resolutions[-1][1]
274
+ final_D = self.resolutions[-1][2]
275
+
276
+ calculated = self.calculated.clone()
277
+
278
+ for resolution in self.resolutions:
279
+ W, H, D = resolution
280
+ stride = (self.resolutions[-1] - 1) / (resolution - 1)
281
+
282
+ if self.visualize:
283
+ this_stage_coords = []
284
+
285
+ # first step
286
+ if torch.equal(resolution, self.resolutions[0]):
287
+ coords = self.init_coords.clone() # torch.long
288
+ occupancys = self.batch_eval(coords, **kwargs)
289
+ occupancys = occupancys.view(self.batchsize, self.channels, D,
290
+ H, W)
291
+
292
+ if self.visualize:
293
+ self.plot(occupancys, coords, final_D, final_H, final_W)
294
+
295
+ with torch.no_grad():
296
+ coords_accum = coords / stride
297
+ calculated[coords[0, :, 2], coords[0, :, 1],
298
+ coords[0, :, 0]] = True
299
+
300
+ # next steps
301
+ else:
302
+ coords_accum *= 2
303
+
304
+ with torch.no_grad():
305
+ # here true is correct!
306
+ valid = F.interpolate(
307
+ (occupancys > self.balance_value).float(),
308
+ size=(D, H, W),
309
+ mode="trilinear",
310
+ align_corners=True)
311
+
312
+ # here true is correct!
313
+ occupancys = F.interpolate(occupancys.float(),
314
+ size=(D, H, W),
315
+ mode="trilinear",
316
+ align_corners=True)
317
+
318
+ is_boundary = (valid > 0.0) & (valid < 1.0)
319
+
320
+ with torch.no_grad():
321
+ # TODO
322
+ if self.use_shadow and torch.equal(resolution,
323
+ self.resolutions[-1]):
324
+ # larger z means smaller depth here
325
+ depth_res = resolution[2].item()
326
+ depth_index = torch.linspace(0,
327
+ depth_res - 1,
328
+ steps=depth_res).type_as(occupancys)
330
+ depth_index_max = torch.max(
331
+ (occupancys > self.balance_value) *
332
+ (depth_index + 1),
333
+ dim=-1,
334
+ keepdim=True)[0] - 1
335
+ shadow = depth_index < depth_index_max
336
+ is_boundary[shadow] = False
337
+ is_boundary = is_boundary[0, 0]
338
+ else:
339
+ is_boundary = (self.smooth_conv3x3(is_boundary.float())
340
+ > 0)[0, 0]
341
+ # is_boundary = is_boundary[0, 0]
342
+
343
+ is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
344
+ coords_accum[0, :, 0]] = False
345
+ point_coords = is_boundary.permute(
346
+ 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
347
+ point_indices = (point_coords[:, :, 2] * H * W +
348
+ point_coords[:, :, 1] * W +
349
+ point_coords[:, :, 0])
350
+
351
+ R, C, D, H, W = occupancys.shape
352
+ # interpolated value
353
+ occupancys_interp = torch.gather(
354
+ occupancys.reshape(R, C, D * H * W), 2,
355
+ point_indices.unsqueeze(1))
356
+
357
+ # inferred value
358
+ coords = point_coords * stride
359
+
360
+ if coords.size(1) == 0:
361
+ continue
362
+ occupancys_topk = self.batch_eval(coords, **kwargs)
363
+ if self.visualize:
364
+ this_stage_coords.append(coords)
365
+
366
+ # put mask point predictions to the right places on the upsampled grid.
367
+ R, C, D, H, W = occupancys.shape
368
+ point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
369
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
370
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
371
+
372
+ with torch.no_grad():
373
+ # conflicts
374
+ conflicts = ((occupancys_interp - self.balance_value) *
375
+ (occupancys_topk - self.balance_value) < 0)[0,
376
+ 0]
377
+
378
+ if self.visualize:
379
+ self.plot(occupancys, coords, final_D, final_H,
380
+ final_W)
381
+
382
+ voxels = coords / stride
383
+ coords_accum = torch.cat([voxels, coords_accum],
384
+ dim=1).unique(dim=1)
385
+ calculated[coords[0, :, 2], coords[0, :, 1],
386
+ coords[0, :, 0]] = True
387
+
388
+ while conflicts.sum() > 0:
389
+ if self.use_shadow and torch.equal(resolution,
390
+ self.resolutions[-1]):
391
+ break
392
+
393
+ with torch.no_grad():
394
+ conflicts_coords = coords[0, conflicts, :]
395
+
396
+ if self.debug:
397
+ self.plot(occupancys,
398
+ conflicts_coords.unsqueeze(0),
399
+ final_D,
400
+ final_H,
401
+ final_W,
402
+ title='conflicts')
403
+
404
+ conflicts_boundary = (conflicts_coords.int() +
405
+ self.gird8_offsets.unsqueeze(1) *
406
+ stride.int()).reshape(
407
+ -1, 3).long().unique(dim=0)
408
+ conflicts_boundary[:, 0] = (
409
+ conflicts_boundary[:, 0].clamp(
410
+ 0,
411
+ calculated.size(2) - 1))
412
+ conflicts_boundary[:, 1] = (
413
+ conflicts_boundary[:, 1].clamp(
414
+ 0,
415
+ calculated.size(1) - 1))
416
+ conflicts_boundary[:, 2] = (
417
+ conflicts_boundary[:, 2].clamp(
418
+ 0,
419
+ calculated.size(0) - 1))
420
+
421
+ coords = conflicts_boundary[calculated[
422
+ conflicts_boundary[:, 2], conflicts_boundary[:, 1],
423
+ conflicts_boundary[:, 0]] == False]
424
+
425
+ if self.debug:
426
+ self.plot(occupancys,
427
+ coords.unsqueeze(0),
428
+ final_D,
429
+ final_H,
430
+ final_W,
431
+ title='coords')
432
+
433
+ coords = coords.unsqueeze(0)
434
+ point_coords = coords / stride
435
+ point_indices = (point_coords[:, :, 2] * H * W +
436
+ point_coords[:, :, 1] * W +
437
+ point_coords[:, :, 0])
438
+
439
+ R, C, D, H, W = occupancys.shape
440
+ # interpolated value
441
+ occupancys_interp = torch.gather(
442
+ occupancys.reshape(R, C, D * H * W), 2,
443
+ point_indices.unsqueeze(1))
444
+
445
+ # inferred value
446
+ coords = point_coords * stride
447
+
448
+ if coords.size(1) == 0:
449
+ break
450
+ occupancys_topk = self.batch_eval(coords, **kwargs)
451
+ if self.visualize:
452
+ this_stage_coords.append(coords)
453
+
454
+ with torch.no_grad():
455
+ # conflicts
456
+ conflicts = ((occupancys_interp - self.balance_value) *
457
+ (occupancys_topk - self.balance_value) <
458
+ 0)[0, 0]
459
+
460
+ # put mask point predictions to the right places on the upsampled grid.
461
+ point_indices = point_indices.unsqueeze(1).expand(
462
+ -1, C, -1)
463
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
464
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
465
+
466
+ with torch.no_grad():
467
+ voxels = coords / stride
468
+ coords_accum = torch.cat([voxels, coords_accum],
469
+ dim=1).unique(dim=1)
470
+ calculated[coords[0, :, 2], coords[0, :, 1],
471
+ coords[0, :, 0]] = True
472
+
473
+ if self.visualize:
474
+ this_stage_coords = torch.cat(this_stage_coords, dim=1)
475
+ self.plot(occupancys, this_stage_coords, final_D, final_H,
476
+ final_W)
477
+
478
+ return occupancys[0, 0]
479
+
480
+ def plot(self,
481
+ occupancys,
482
+ coords,
483
+ final_D,
484
+ final_H,
485
+ final_W,
486
+ title='',
487
+ **kwargs):
488
+ final = F.interpolate(occupancys.float(),
489
+ size=(final_D, final_H, final_W),
490
+ mode="trilinear",
491
+ align_corners=True) # here true is correct!
492
+ x = coords[0, :, 0].to("cpu")
493
+ y = coords[0, :, 1].to("cpu")
494
+ z = coords[0, :, 2].to("cpu")
495
+
496
+ plot_mask3D(final[0, 0].to("cpu"), title, (x, y, z), **kwargs)
497
+
498
+ def find_vertices(self, sdf, direction="front"):
499
+ '''
500
+ - direction: "front" | "back" | "left" | "right"
501
+ '''
502
+ resolution = sdf.size(2)
503
+ if direction == "front":
504
+ pass
505
+ elif direction == "left":
506
+ sdf = sdf.permute(2, 1, 0)
507
+ elif direction == "back":
508
+ inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
509
+ sdf = sdf[inv_idx, :, :]
510
+ elif direction == "right":
511
+ inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
512
+ sdf = sdf[:, :, inv_idx]
513
+ sdf = sdf.permute(2, 1, 0)
514
+
515
+ inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
516
+ sdf = sdf[inv_idx, :, :]
517
+ sdf_all = sdf.permute(2, 1, 0)
518
+
519
+ # shadow
520
+ grad_v = (sdf_all > 0.5) * torch.linspace(
521
+ resolution, 1, steps=resolution).to(sdf.device)
522
+ grad_c = torch.ones_like(sdf_all) * torch.linspace(
523
+ 0, resolution - 1, steps=resolution).to(sdf.device)
524
+ max_v, max_c = grad_v.max(dim=2)
525
+ shadow = grad_c > max_c.view(resolution, resolution, 1)
526
+ keep = (sdf_all > 0.5) & (~shadow)
527
+
528
+ p1 = keep.nonzero(as_tuple=False).t() # [3, N]
529
+ p2 = p1.clone() # z
530
+ p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
531
+ p3 = p1.clone() # y
532
+ p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
533
+ p4 = p1.clone() # x
534
+ p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
535
+
536
+ v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
537
+ v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]]
538
+ v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
539
+ v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
540
+
541
+ X = p1[0, :].long() # [N,]
542
+ Y = p1[1, :].long() # [N,]
543
+ Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + \
544
+ p1[2, :].float() * (v2 - 0.5) / (v2 - v1) # [N,]
545
+ Z = Z.clamp(0, resolution)
546
+
547
+ # normal
548
+ norm_z = v2 - v1
549
+ norm_y = v3 - v1
550
+ norm_x = v4 - v1
551
+ # print (v2.min(dim=0)[0], v2.max(dim=0)[0], v3.min(dim=0)[0], v3.max(dim=0)[0])
552
+
553
+ norm = torch.stack([norm_x, norm_y, norm_z], dim=1)
554
+ norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True)
555
+
556
+ return X, Y, Z, norm
557
+
558
+ def render_normal(self, resolution, X, Y, Z, norm):
559
+ image = torch.ones((1, 3, resolution, resolution),
560
+ dtype=torch.float32).to(norm.device)
561
+ color = (norm + 1) / 2.0
562
+ color = color.clamp(0, 1)
563
+ image[0, :, Y, X] = color.t()
564
+ return image
565
+
566
+ def display(self, sdf):
567
+
568
+ # render
569
+ X, Y, Z, norm = self.find_vertices(sdf, direction="front")
570
+ image1 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
571
+ X, Y, Z, norm = self.find_vertices(sdf, direction="left")
572
+ image2 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
573
+ X, Y, Z, norm = self.find_vertices(sdf, direction="right")
574
+ image3 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
575
+ X, Y, Z, norm = self.find_vertices(sdf, direction="back")
576
+ image4 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
577
+
578
+ image = torch.cat([image1, image2, image3, image4], axis=3)
579
+ image = image.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0
580
+
581
+ return np.uint8(image)
582
+
583
+ def export_mesh(self, occupancys):
584
+
585
+ final = occupancys[1:, 1:, 1:].contiguous()
586
+
587
+ if final.shape[0] > 256:
588
+ # for voxelgrid larger than 256^3, the required GPU memory will be > 9GB
589
+ # thus we use CPU marching_cube to avoid "CUDA out of memory"
590
+ occu_arr = final.detach().cpu().numpy() # non-smooth surface
591
+ # occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface
592
+ vertices, triangles = mcubes.marching_cubes(
593
+ occu_arr, self.balance_value)
594
+ verts = torch.as_tensor(vertices[:, [2, 1, 0]])
595
+ faces = torch.as_tensor(triangles.astype(
596
+ np.int64), dtype=torch.long)[:, [0, 2, 1]]
597
+ else:
598
+ torch.cuda.empty_cache()
599
+ vertices, triangles = voxelgrids_to_trianglemeshes(
600
+ final.unsqueeze(0))
601
+ verts = vertices[0][:, [2, 1, 0]].cpu()
602
+ faces = triangles[0][:, [0, 2, 1]].cpu()
603
+
604
+ return verts, faces
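
The coarse-to-fine loop in Seg3dLossless._forward above re-queries the network only at boundary voxels and scatters the new predictions back into the upsampled grid. A minimal sketch of that single refinement step, assuming only PyTorch; refine_once and query_func are hypothetical names, not part of this file:

import torch
import torch.nn.functional as F

def refine_once(occupancys, query_func, size, stride, balance_value=0.5):
    # occupancys: [R, C, d, h, w] predictions from the previous (coarser) level
    D, H, W = size
    # upsample both the binarized mask and the raw values to the current level
    valid = F.interpolate((occupancys > balance_value).float(),
                          size=(D, H, W), mode="trilinear", align_corners=True)
    occupancys = F.interpolate(occupancys.float(),
                               size=(D, H, W), mode="trilinear", align_corners=True)
    R, C = occupancys.shape[:2]
    # voxels whose interpolated mask is fractional straddle the iso-surface
    is_boundary = ((valid > 0.0) & (valid < 1.0))[0, 0]
    point_coords = is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)  # [1, N, 3] as (x, y, z)
    point_indices = (point_coords[:, :, 2] * H * W +
                     point_coords[:, :, 1] * W +
                     point_coords[:, :, 0])                      # flat index into D * H * W
    refined = query_func(point_coords * stride)                  # [1, C, N], queried at final-resolution coords
    point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
    return (occupancys.reshape(R, C, D * H * W)
            .scatter_(2, point_indices, refined)
            .view(R, C, D, H, W))
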
lib/common/seg3d_utils.py ADDED
@@ -0,0 +1,392 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
+ def plot_mask2D(mask,
25
+ title="",
26
+ point_coords=None,
27
+ figsize=10,
28
+ point_marker_size=5):
29
+ '''
30
+ Simple plotting tool to show intermediate mask predictions and points
31
+ where PointRend is applied.
32
+
33
+ Args:
34
+ mask (Tensor): mask prediction of shape HxW
35
+ title (str): title for the plot
36
+ point_coords ((Tensor, Tensor)): x and y point coordinates
37
+ figsize (int): size of the figure to plot
38
+ point_marker_size (int): marker size for points
39
+ '''
40
+
41
+ H, W = mask.shape
42
+ plt.figure(figsize=(figsize, figsize))
43
+ if title:
44
+ title += ", "
45
+ plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
46
+ plt.ylabel(H, fontsize=30)
47
+ plt.xlabel(W, fontsize=30)
48
+ plt.xticks([], [])
49
+ plt.yticks([], [])
50
+ plt.imshow(mask.detach(),
51
+ interpolation="nearest",
52
+ cmap=plt.get_cmap('gray'))
53
+ if point_coords is not None:
54
+ plt.scatter(x=point_coords[0],
55
+ y=point_coords[1],
56
+ color="red",
57
+ s=point_marker_size,
58
+ clip_on=True)
59
+ plt.xlim(-0.5, W - 0.5)
60
+ plt.ylim(H - 0.5, -0.5)
61
+ plt.show()
62
+
63
+
64
+ def plot_mask3D(mask=None,
65
+ title="",
66
+ point_coords=None,
67
+ figsize=1500,
68
+ point_marker_size=8,
69
+ interactive=True):
70
+ '''
71
+ Simple plotting tool to show intermediate mask predictions and points
72
+ where PointRend is applied.
73
+
74
+ Args:
75
+ mask (Tensor): mask prediction of shape DxHxW
76
+ title (str): title for the plot
77
+ point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates
78
+ figsize (int): size of the figure to plot
79
+ point_marker_size (int): marker size for points
80
+ '''
81
+ import trimesh
82
+ import vtkplotter
83
+ from skimage import measure
84
+
85
+ vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
86
+ vis_list = []
87
+
88
+ if mask is not None:
89
+ mask = mask.detach().to("cpu").numpy()
90
+ mask = mask.transpose(2, 1, 0)
91
+
92
+ # marching cube to find surface
93
+ verts, faces, normals, values = measure.marching_cubes_lewiner(
94
+ mask, 0.5, gradient_direction='ascent')
95
+
96
+ # create a mesh
97
+ mesh = trimesh.Trimesh(verts, faces)
98
+ mesh.visual.face_colors = [200, 200, 250, 100]
99
+ vis_list.append(mesh)
100
+
101
+ if point_coords is not None:
102
+ point_coords = torch.stack(point_coords, 1).to("cpu").numpy()
103
+
104
+ # import numpy as np
105
+ # select_x = np.logical_and(point_coords[:, 0] >= 16, point_coords[:, 0] <= 112)
106
+ # select_y = np.logical_and(point_coords[:, 1] >= 48, point_coords[:, 1] <= 272)
107
+ # select_z = np.logical_and(point_coords[:, 2] >= 16, point_coords[:, 2] <= 112)
108
+ # select = np.logical_and(np.logical_and(select_x, select_y), select_z)
109
+ # point_coords = point_coords[select, :]
110
+
111
+ pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
112
+ vis_list.append(pc)
113
+
114
+ vp.show(*vis_list,
115
+ bg="white",
116
+ axes=1,
117
+ interactive=interactive,
118
+ azimuth=30,
119
+ elevation=30)
120
+
121
+
122
+ def create_grid3D(min, max, steps):
123
+ if type(min) is int:
124
+ min = (min, min, min) # (x, y, z)
125
+ if type(max) is int:
126
+ max = (max, max, max) # (x, y)
127
+ if type(steps) is int:
128
+ steps = (steps, steps, steps) # (x, y, z)
129
+ arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
130
+ arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
131
+ arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
132
+ gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX])
133
+ coords = torch.stack([gridW, girdH,
134
+ gridD]) # [3, steps[0], steps[1], steps[2]]
135
+ coords = coords.view(3, -1).t() # [N, 3]
136
+ return coords
137
+
138
+
139
+ def create_grid2D(min, max, steps):
140
+ if type(min) is int:
141
+ min = (min, min) # (x, y)
142
+ if type(max) is int:
143
+ max = (max, max) # (x, y)
144
+ if type(steps) is int:
145
+ steps = (steps, steps) # (x, y)
146
+ arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
147
+ arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
148
+ girdH, gridW = torch.meshgrid([arrangeY, arrangeX])
149
+ coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
150
+ coords = coords.view(2, -1).t() # [N, 2]
151
+ return coords
152
+
153
+
154
+ class SmoothConv2D(nn.Module):
155
+ def __init__(self, in_channels, out_channels, kernel_size=3):
156
+ super().__init__()
157
+ assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
158
+ self.padding = (kernel_size - 1) // 2
159
+
160
+ weight = torch.ones(
161
+ (in_channels, out_channels, kernel_size, kernel_size),
162
+ dtype=torch.float32) / (kernel_size**2)
163
+ self.register_buffer('weight', weight)
164
+
165
+ def forward(self, input):
166
+ return F.conv2d(input, self.weight, padding=self.padding)
167
+
168
+
169
+ class SmoothConv3D(nn.Module):
170
+ def __init__(self, in_channels, out_channels, kernel_size=3):
171
+ super().__init__()
172
+ assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
173
+ self.padding = (kernel_size - 1) // 2
174
+
175
+ weight = torch.ones(
176
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
177
+ dtype=torch.float32) / (kernel_size**3)
178
+ self.register_buffer('weight', weight)
179
+
180
+ def forward(self, input):
181
+ return F.conv3d(input, self.weight, padding=self.padding)
182
+
183
+
184
+ def build_smooth_conv3D(in_channels=1,
185
+ out_channels=1,
186
+ kernel_size=3,
187
+ padding=1):
188
+ smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
189
+ out_channels=out_channels,
190
+ kernel_size=kernel_size,
191
+ padding=padding)
192
+ smooth_conv.weight.data = torch.ones(
193
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
194
+ dtype=torch.float32) / (kernel_size**3)
195
+ smooth_conv.bias.data = torch.zeros(out_channels)
196
+ return smooth_conv
197
+
198
+
199
+ def build_smooth_conv2D(in_channels=1,
200
+ out_channels=1,
201
+ kernel_size=3,
202
+ padding=1):
203
+ smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
204
+ out_channels=out_channels,
205
+ kernel_size=kernel_size,
206
+ padding=padding)
207
+ smooth_conv.weight.data = torch.ones(
208
+ (in_channels, out_channels, kernel_size, kernel_size),
209
+ dtype=torch.float32) / (kernel_size**2)
210
+ smooth_conv.bias.data = torch.zeros(out_channels)
211
+ return smooth_conv
212
+
213
+
214
+ def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
215
+ **kwargs):
216
+ """
217
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
218
+ Args:
219
+ uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
220
+ values for a set of points on a regular H x W x D grid.
221
+ num_points (int): The number of points P to select.
222
+ Returns:
223
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
224
+ [0, H x W x D) of the most uncertain points.
225
+ point_coords (Tensor): A tensor of shape (N, P, 3) that contains the (x, y, z) grid
226
+ coordinates of the most uncertain points from the H x W x D grid.
227
+ """
228
+ R, _, D, H, W = uncertainty_map.shape
229
+ # h_step = 1.0 / float(H)
230
+ # w_step = 1.0 / float(W)
231
+ # d_step = 1.0 / float(D)
232
+
233
+ num_points = min(D * H * W, num_points)
234
+ point_scores, point_indices = torch.topk(uncertainty_map.view(
235
+ R, D * H * W),
236
+ k=num_points,
237
+ dim=1)
238
+ point_coords = torch.zeros(R,
239
+ num_points,
240
+ 3,
241
+ dtype=torch.float,
242
+ device=uncertainty_map.device)
243
+ # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
244
+ # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
245
+ # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
246
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
247
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
248
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
249
+ print(f"resolution {D} x {H} x {W}", point_scores.min(),
250
+ point_scores.max())
251
+ return point_indices, point_coords
252
+
253
+
254
+ def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
255
+ clip_min):
256
+ """
257
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
258
+ Args:
259
+ uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
260
+ values for a set of points on a regular H x W x D grid.
261
+ num_points (int): The number of points P to select.
262
+ Returns:
263
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
264
+ [0, H x W x D) of the most uncertain points.
265
+ point_coords (Tensor): A tensor of shape (N, P, 3) that contains the (x, y, z) grid
266
+ coordinates of the most uncertain points from the H x W x D grid.
267
+ """
268
+ R, _, D, H, W = uncertainty_map.shape
269
+ # h_step = 1.0 / float(H)
270
+ # w_step = 1.0 / float(W)
271
+ # d_step = 1.0 / float(D)
272
+
273
+ assert R == 1, "batchsize > 1 is not implemented!"
274
+ uncertainty_map = uncertainty_map.view(D * H * W)
275
+ indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
276
+ num_points = min(num_points, indices.size(0))
277
+ point_scores, point_indices = torch.topk(uncertainty_map[indices],
278
+ k=num_points,
279
+ dim=0)
280
+ point_indices = indices[point_indices].unsqueeze(0)
281
+
282
+ point_coords = torch.zeros(R,
283
+ num_points,
284
+ 3,
285
+ dtype=torch.float,
286
+ device=uncertainty_map.device)
287
+ # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
288
+ # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
289
+ # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
290
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
291
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
292
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
293
+ # print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
294
+ return point_indices, point_coords
295
+
296
+
297
+ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
298
+ **kwargs):
299
+ """
300
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
301
+ Args:
302
+ uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
303
+ values for a set of points on a regular H x W grid.
304
+ num_points (int): The number of points P to select.
305
+ Returns:
306
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
307
+ [0, H x W) of the most uncertain points.
308
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains the (x, y) grid
309
+ coordinates of the most uncertain points from the H x W grid.
310
+ """
311
+ R, _, H, W = uncertainty_map.shape
312
+ # h_step = 1.0 / float(H)
313
+ # w_step = 1.0 / float(W)
314
+
315
+ num_points = min(H * W, num_points)
316
+ point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
317
+ k=num_points,
318
+ dim=1)
319
+ point_coords = torch.zeros(R,
320
+ num_points,
321
+ 2,
322
+ dtype=torch.long,
323
+ device=uncertainty_map.device)
324
+ # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
325
+ # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
326
+ point_coords[:, :, 0] = (point_indices % W).to(torch.long)
327
+ point_coords[:, :, 1] = (point_indices // W).to(torch.long)
328
+ # print (point_scores.min(), point_scores.max())
329
+ return point_indices, point_coords
330
+
331
+
332
+ def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
333
+ clip_min):
334
+ """
335
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
336
+ Args:
337
+ uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
338
+ values for a set of points on a regular H x W grid.
339
+ num_points (int): The number of points P to select.
340
+ Returns:
341
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
342
+ [0, H x W) of the most uncertain points.
343
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains the (x, y) grid
344
+ coordinates of the most uncertain points from the H x W grid.
345
+ """
346
+ R, _, H, W = uncertainty_map.shape
347
+ # h_step = 1.0 / float(H)
348
+ # w_step = 1.0 / float(W)
349
+
350
+ assert R == 1, "batchsize > 1 is not implemented!"
351
+ uncertainty_map = uncertainty_map.view(H * W)
352
+ indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
353
+ num_points = min(num_points, indices.size(0))
354
+ point_scores, point_indices = torch.topk(uncertainty_map[indices],
355
+ k=num_points,
356
+ dim=0)
357
+ point_indices = indices[point_indices].unsqueeze(0)
358
+
359
+ point_coords = torch.zeros(R,
360
+ num_points,
361
+ 2,
362
+ dtype=torch.long,
363
+ device=uncertainty_map.device)
364
+ # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
365
+ # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
366
+ point_coords[:, :, 0] = (point_indices % W).to(torch.long)
367
+ point_coords[:, :, 1] = (point_indices // W).to(torch.long)
368
+ # print (point_scores.min(), point_scores.max())
369
+ return point_indices, point_coords
370
+
371
+
372
+ def calculate_uncertainty(logits, classes=None, balance_value=0.5):
373
+ """
374
+ We estimate uncertainty as the L1 distance between `balance_value` and the prediction in `logits` for the
375
+ foreground class in `classes`.
376
+ Args:
377
+ logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
378
+ class-agnostic, where R is the total number of predicted masks in all images and C is
379
+ the number of foreground classes. The values are logits.
380
+ classes (list): A list of length R that contains either the predicted or the ground-truth class
381
+ for each predicted mask.
382
+ Returns:
383
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
384
+ the most uncertain locations having the highest uncertainty score.
385
+ """
386
+ if logits.shape[1] == 1:
387
+ gt_class_logits = logits
388
+ else:
389
+ gt_class_logits = logits[
390
+ torch.arange(logits.shape[0], device=logits.device),
391
+ classes].unsqueeze(1)
392
+ return -torch.abs(gt_class_logits - balance_value)
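
A short usage sketch for the helpers defined above, assuming the file is importable as lib.common.seg3d_utils and that plain CPU tensors are used:

import torch
from lib.common.seg3d_utils import (create_grid3D, SmoothConv3D,
                                    calculate_uncertainty,
                                    get_uncertain_point_coords_on_grid3D)

logits = torch.rand(1, 1, 8, 8, 8)            # fake occupancies on an 8^3 grid
uncertainty = calculate_uncertainty(logits)   # largest (closest to 0) near the 0.5 level set
idx, xyz = get_uncertain_point_coords_on_grid3D(uncertainty, num_points=32)
print(idx.shape, xyz.shape)                   # torch.Size([1, 32]) torch.Size([1, 32, 3])

smooth = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
mask = (logits > 0.5).float()
dilated = smooth(mask) > 0                    # True wherever any 3x3x3 neighbor is inside

coords = create_grid3D(0, 7, steps=8)         # [512, 3] integer (x, y, z) lattice
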
lib/common/smpl_vert_segmentation.json ADDED
The diff for this file is too large to render. See raw diff
 
lib/common/train_util.py ADDED
@@ -0,0 +1,597 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import yaml
19
+ import os.path as osp
20
+ import torch
21
+ import numpy as np
22
+ import torch.nn.functional as F
23
+ from ..dataset.mesh_util import *
24
+ from ..net.geometry import orthogonal
25
+ from pytorch3d.renderer.mesh import rasterize_meshes
26
+ from .render_utils import Pytorch3dRasterizer
27
+ from pytorch3d.structures import Meshes
28
+ import cv2
29
+ from PIL import Image
30
+ from tqdm import tqdm
31
+ import os
32
+ from termcolor import colored
33
+
34
+
35
+ def reshape_sample_tensor(sample_tensor, num_views):
36
+ if num_views == 1:
37
+ return sample_tensor
38
+ # Need to repeat sample_tensor along the batch dim num_views times
39
+ sample_tensor = sample_tensor.unsqueeze(dim=1)
40
+ sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
41
+ sample_tensor = sample_tensor.view(
42
+ sample_tensor.shape[0] * sample_tensor.shape[1],
43
+ sample_tensor.shape[2], sample_tensor.shape[3])
44
+ return sample_tensor
45
+
46
+
47
+ def gen_mesh_eval(opt, net, cuda, data, resolution=None):
48
+ resolution = opt.resolution if resolution is None else resolution
49
+ image_tensor = data['img'].to(device=cuda)
50
+ calib_tensor = data['calib'].to(device=cuda)
51
+
52
+ net.filter(image_tensor)
53
+
54
+ b_min = data['b_min']
55
+ b_max = data['b_max']
56
+ try:
57
+ verts, faces, _, _ = reconstruction_faster(net,
58
+ cuda,
59
+ calib_tensor,
60
+ resolution,
61
+ b_min,
62
+ b_max,
63
+ use_octree=False)
64
+
65
+ except Exception as e:
66
+ print(e)
67
+ print('Can not create marching cubes at this time.')
68
+ verts, faces = None, None
69
+ return verts, faces
70
+
71
+
72
+ def gen_mesh(opt, net, cuda, data, save_path, resolution=None):
73
+ resolution = opt.resolution if resolution is None else resolution
74
+ image_tensor = data['img'].to(device=cuda)
75
+ calib_tensor = data['calib'].to(device=cuda)
76
+
77
+ net.filter(image_tensor)
78
+
79
+ b_min = data['b_min']
80
+ b_max = data['b_max']
81
+ try:
82
+ save_img_path = save_path[:-4] + '.png'
83
+ save_img_list = []
84
+ for v in range(image_tensor.shape[0]):
85
+ save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
86
+ (1, 2, 0)) * 0.5 +
87
+ 0.5)[:, :, ::-1] * 255.0
88
+ save_img_list.append(save_img)
89
+ save_img = np.concatenate(save_img_list, axis=1)
90
+ Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
91
+
92
+ verts, faces, _, _ = reconstruction_faster(net, cuda, calib_tensor,
93
+ resolution, b_min, b_max)
94
+ verts_tensor = torch.from_numpy(
95
+ verts.T).unsqueeze(0).to(device=cuda).float()
96
+ xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
97
+ uv = xyz_tensor[:, :2, :]
98
+ color = net.index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
99
+ color = color * 0.5 + 0.5
100
+ save_obj_mesh_with_color(save_path, verts, faces, color)
101
+ except Exception as e:
102
+ print(e)
103
+ print('Can not create marching cubes at this time.')
104
+ verts, faces, color = None, None, None
105
+ return verts, faces, color
106
+
107
+
108
+ def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
109
+ image_tensor = data['img'].to(device=cuda)
110
+ calib_tensor = data['calib'].to(device=cuda)
111
+
112
+ netG.filter(image_tensor)
113
+ netC.filter(image_tensor)
114
+ netC.attach(netG.get_im_feat())
115
+
116
+ b_min = data['b_min']
117
+ b_max = data['b_max']
118
+ try:
119
+ save_img_path = save_path[:-4] + '.png'
120
+ save_img_list = []
121
+ for v in range(image_tensor.shape[0]):
122
+ save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
123
+ (1, 2, 0)) * 0.5 +
124
+ 0.5)[:, :, ::-1] * 255.0
125
+ save_img_list.append(save_img)
126
+ save_img = np.concatenate(save_img_list, axis=1)
127
+ Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
128
+
129
+ verts, faces, _, _ = reconstruction_faster(netG,
130
+ cuda,
131
+ calib_tensor,
132
+ opt.resolution,
133
+ b_min,
134
+ b_max,
135
+ use_octree=use_octree)
136
+
137
+ # Now Getting colors
138
+ verts_tensor = torch.from_numpy(
139
+ verts.T).unsqueeze(0).to(device=cuda).float()
140
+ verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
141
+ color = np.zeros(verts.shape)
142
+ interval = 10000
143
+ for i in range(len(color) // interval):
144
+ left = i * interval
145
+ right = i * interval + interval
146
+ if i == len(color) // interval - 1:
147
+ right = -1
148
+ netC.query(verts_tensor[:, :, left:right], calib_tensor)
149
+ rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
150
+ color[left:right] = rgb.T
151
+
152
+ save_obj_mesh_with_color(save_path, verts, faces, color)
153
+ except Exception as e:
154
+ print(e)
155
+ print('Can not create marching cubes at this time.')
156
+ verts, faces, color = None, None, None
157
+ return verts, faces, color
158
+
159
+
160
+ def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
161
+ """Sets the learning rate to the initial LR decayed by schedule"""
162
+ if epoch in schedule:
163
+ lr *= gamma
164
+ for param_group in optimizer.param_groups:
165
+ param_group['lr'] = lr
166
+ return lr
167
+
168
+
169
+ def compute_acc(pred, gt, thresh=0.5):
170
+ '''
171
+ return:
172
+ IOU, precision, and recall
173
+ '''
174
+ with torch.no_grad():
175
+ vol_pred = pred > thresh
176
+ vol_gt = gt > thresh
177
+
178
+ union = vol_pred | vol_gt
179
+ inter = vol_pred & vol_gt
180
+
181
+ true_pos = inter.sum().float()
182
+
183
+ union = union.sum().float()
184
+ if union == 0:
185
+ union = 1
186
+ vol_pred = vol_pred.sum().float()
187
+ if vol_pred == 0:
188
+ vol_pred = 1
189
+ vol_gt = vol_gt.sum().float()
190
+ if vol_gt == 0:
191
+ vol_gt = 1
192
+ return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
193
+
194
+
195
+ # def calc_metrics(opt, net, cuda, dataset, num_tests,
196
+ # resolution=128, sampled_points=1000, use_kaolin=True):
197
+ # if num_tests > len(dataset):
198
+ # num_tests = len(dataset)
199
+ # with torch.no_grad():
200
+ # chamfer_arr, p2s_arr = [], []
201
+ # for idx in tqdm(range(num_tests)):
202
+ # data = dataset[idx * len(dataset) // num_tests]
203
+
204
+ # verts, faces = gen_mesh_eval(opt, net, cuda, data, resolution)
205
+ # if verts is None:
206
+ # continue
207
+
208
+ # mesh_gt = trimesh.load(data['mesh_path'])
209
+ # mesh_gt = mesh_gt.split(only_watertight=False)
210
+ # comp_num = [mesh.vertices.shape[0] for mesh in mesh_gt]
211
+ # mesh_gt = mesh_gt[comp_num.index(max(comp_num))]
212
+
213
+ # mesh_pred = trimesh.Trimesh(verts, faces)
214
+
215
+ # gt_surface_pts, _ = trimesh.sample.sample_surface_even(
216
+ # mesh_gt, sampled_points)
217
+ # pred_surface_pts, _ = trimesh.sample.sample_surface_even(
218
+ # mesh_pred, sampled_points)
219
+
220
+ # if use_kaolin and has_kaolin:
221
+ # kal_mesh_gt = kal.rep.TriangleMesh.from_tensors(
222
+ # torch.tensor(mesh_gt.vertices).float().to(device=cuda),
223
+ # torch.tensor(mesh_gt.faces).long().to(device=cuda))
224
+ # kal_mesh_pred = kal.rep.TriangleMesh.from_tensors(
225
+ # torch.tensor(mesh_pred.vertices).float().to(device=cuda),
226
+ # torch.tensor(mesh_pred.faces).long().to(device=cuda))
227
+
228
+ # kal_distance_0 = kal.metrics.mesh.point_to_surface(
229
+ # torch.tensor(pred_surface_pts).float().to(device=cuda), kal_mesh_gt)
230
+ # kal_distance_1 = kal.metrics.mesh.point_to_surface(
231
+ # torch.tensor(gt_surface_pts).float().to(device=cuda), kal_mesh_pred)
232
+
233
+ # dist_gt_pred = torch.sqrt(kal_distance_0).cpu().numpy()
234
+ # dist_pred_gt = torch.sqrt(kal_distance_1).cpu().numpy()
235
+ # else:
236
+ # try:
237
+ # _, dist_pred_gt, _ = trimesh.proximity.closest_point(mesh_pred, gt_surface_pts)
238
+ # _, dist_gt_pred, _ = trimesh.proximity.closest_point(mesh_gt, pred_surface_pts)
239
+ # except Exception as e:
240
+ # print (e)
241
+ # continue
242
+
243
+ # chamfer_dist = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean())
244
+ # p2s_dist = dist_pred_gt.mean()
245
+
246
+ # chamfer_arr.append(chamfer_dist)
247
+ # p2s_arr.append(p2s_dist)
248
+
249
+ # return np.average(chamfer_arr), np.average(p2s_arr)
250
+
251
+
252
+ def calc_error(opt, net, cuda, dataset, num_tests):
253
+ if num_tests > len(dataset):
254
+ num_tests = len(dataset)
255
+ with torch.no_grad():
256
+ erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
257
+ for idx in tqdm(range(num_tests)):
258
+ data = dataset[idx * len(dataset) // num_tests]
259
+ # retrieve the data
260
+ image_tensor = data['img'].to(device=cuda)
261
+ calib_tensor = data['calib'].to(device=cuda)
262
+ sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
263
+ if opt.num_views > 1:
264
+ sample_tensor = reshape_sample_tensor(sample_tensor,
265
+ opt.num_views)
266
+ label_tensor = data['labels'].to(device=cuda).unsqueeze(0)
267
+
268
+ res, error = net.forward(image_tensor,
269
+ sample_tensor,
270
+ calib_tensor,
271
+ labels=label_tensor)
272
+
273
+ IOU, prec, recall = compute_acc(res, label_tensor)
274
+
275
+ # print(
276
+ # '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
277
+ # .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
278
+ erorr_arr.append(error.item())
279
+ IOU_arr.append(IOU.item())
280
+ prec_arr.append(prec.item())
281
+ recall_arr.append(recall.item())
282
+
283
+ return np.average(erorr_arr), np.average(IOU_arr), np.average(
284
+ prec_arr), np.average(recall_arr)
285
+
286
+
287
+ def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
288
+ if num_tests > len(dataset):
289
+ num_tests = len(dataset)
290
+ with torch.no_grad():
291
+ error_color_arr = []
292
+
293
+ for idx in tqdm(range(num_tests)):
294
+ data = dataset[idx * len(dataset) // num_tests]
295
+ # retrieve the data
296
+ image_tensor = data['img'].to(device=cuda)
297
+ calib_tensor = data['calib'].to(device=cuda)
298
+ color_sample_tensor = data['color_samples'].to(
299
+ device=cuda).unsqueeze(0)
300
+
301
+ if opt.num_views > 1:
302
+ color_sample_tensor = reshape_sample_tensor(
303
+ color_sample_tensor, opt.num_views)
304
+
305
+ rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)
306
+
307
+ netG.filter(image_tensor)
308
+ _, errorC = netC.forward(image_tensor,
309
+ netG.get_im_feat(),
310
+ color_sample_tensor,
311
+ calib_tensor,
312
+ labels=rgb_tensor)
313
+
314
+ # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
315
+ # .format(idx, num_tests, errorG.item(), errorC.item()))
316
+ error_color_arr.append(errorC.item())
317
+
318
+ return np.average(error_color_arr)
319
+
320
+
321
+ # pytorch lightning training related functions
322
+
323
+
324
+ def query_func(opt, netG, features, points, proj_matrix=None):
325
+ '''
326
+ - points: size of (bz, N, 3)
327
+ - proj_matrix: size of (bz, 4, 4)
328
+ return: size of (bz, 1, N)
329
+ '''
330
+ assert len(points) == 1
331
+ samples = points.repeat(opt.num_views, 1, 1)
332
+ samples = samples.permute(0, 2, 1) # [bz, 3, N]
333
+
334
+ # view specific query
335
+ if proj_matrix is not None:
336
+ samples = orthogonal(samples, proj_matrix)
337
+
338
+ calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples)
339
+
340
+ preds = netG.query(features=features,
341
+ points=samples,
342
+ calibs=calib_tensor,
343
+ regressor=netG.if_regressor)
344
+
345
+ if type(preds) is list:
346
+ preds = preds[0]
347
+
348
+ return preds
349
+
350
+
351
+ def isin(ar1, ar2):
352
+ return (ar1[..., None] == ar2).any(-1)
353
+
354
+
355
+ def in1d(ar1, ar2):
356
+ mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
357
+ mask[ar2.unique()] = True
358
+ return mask[ar1]
359
+
360
+
361
+ def get_visibility(xy, z, faces):
362
+ """get the visibility of vertices
363
+
364
+ Args:
365
+ xy (torch.tensor): [N,2]
366
+ z (torch.tensor): [N,1]
367
+ faces (torch.tensor): [N,3]
368
+ size (int): resolution of rendered image
369
+ """
370
+
371
+ xyz = torch.cat((xy, -z), dim=1)
372
+ xyz = (xyz + 1.0) / 2.0
373
+ faces = faces.long()
374
+
375
+ rasterizer = Pytorch3dRasterizer(image_size=2**12)
376
+ meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
377
+ raster_settings = rasterizer.raster_settings
378
+
379
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
380
+ meshes_screen,
381
+ image_size=raster_settings.image_size,
382
+ blur_radius=raster_settings.blur_radius,
383
+ faces_per_pixel=raster_settings.faces_per_pixel,
384
+ bin_size=raster_settings.bin_size,
385
+ max_faces_per_bin=raster_settings.max_faces_per_bin,
386
+ perspective_correct=raster_settings.perspective_correct,
387
+ cull_backfaces=raster_settings.cull_backfaces,
388
+ )
389
+
390
+ vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
391
+ vis_mask = torch.zeros(size=(z.shape[0], 1))
392
+ vis_mask[vis_vertices_id] = 1.0
393
+
394
+ # print("------------------------\n")
395
+ # print(f"keep points : {vis_mask.sum()/len(vis_mask)}")
396
+
397
+ return vis_mask
398
+
399
+
400
+ def batch_mean(res, key):
401
+ # recursive mean for multilevel dicts
402
+ return torch.stack([
403
+ x[key] if isinstance(x, dict) else batch_mean(x, key) for x in res
404
+ ]).mean()
405
+
406
+
407
+ def tf_log_convert(log_dict):
408
+ new_log_dict = log_dict.copy()
409
+ for k, v in log_dict.items():
410
+ new_log_dict[k.replace("_", "/")] = v
411
+ del new_log_dict[k]
412
+
413
+ return new_log_dict
414
+
415
+
416
+ def bar_log_convert(log_dict, name=None, rot=None):
417
+ from decimal import Decimal
418
+
419
+ new_log_dict = {}
420
+
421
+ if name is not None:
422
+ new_log_dict['name'] = name[0]
423
+ if rot is not None:
424
+ new_log_dict['rot'] = rot[0]
425
+
426
+ for k, v in log_dict.items():
427
+ color = "yellow"
428
+ if 'loss' in k:
429
+ color = "red"
430
+ k = k.replace("loss", "L")
431
+ elif 'acc' in k:
432
+ color = "green"
433
+ k = k.replace("acc", "A")
434
+ elif 'iou' in k:
435
+ color = "green"
436
+ k = k.replace("iou", "I")
437
+ elif 'prec' in k:
438
+ color = "green"
439
+ k = k.replace("prec", "P")
440
+ elif 'recall' in k:
441
+ color = "green"
442
+ k = k.replace("recall", "R")
443
+
444
+ if 'lr' not in k:
445
+ new_log_dict[colored(k.split("_")[1],
446
+ color)] = colored(f"{v:.3f}", color)
447
+ else:
448
+ new_log_dict[colored(k.split("_")[1],
449
+ color)] = colored(f"{Decimal(str(v)):.1E}",
450
+ color)
451
+
452
+ if 'loss' in new_log_dict.keys():
453
+ del new_log_dict['loss']
454
+
455
+ return new_log_dict
456
+
457
+
458
+ def accumulate(outputs, rot_num, split):
459
+
460
+ hparam_log_dict = {}
461
+
462
+ metrics = outputs[0].keys()
463
+ datasets = split.keys()
464
+
465
+ for dataset in datasets:
466
+ for metric in metrics:
467
+ keyword = f"hparam/{dataset}-{metric}"
468
+ if keyword not in hparam_log_dict.keys():
469
+ hparam_log_dict[keyword] = 0
470
+ for idx in range(split[dataset][0] * rot_num,
471
+ split[dataset][1] * rot_num):
472
+ hparam_log_dict[keyword] += outputs[idx][metric]
473
+ hparam_log_dict[keyword] /= (split[dataset][1] -
474
+ split[dataset][0]) * rot_num
475
+
476
+ print(colored(hparam_log_dict, "green"))
477
+
478
+ return hparam_log_dict
479
+
480
+
481
+ def calc_error_N(outputs, targets):
482
+ """calculate the error of normal (IGR)
483
+
484
+ Args:
485
+ outputs (torch.tensor): [B, 3, N]
486
+ target (torch.tensor): [B, N, 3]
487
+
488
+ # manifold loss and grad_loss in IGR paper
489
+ grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
490
+ normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
491
+
492
+ Returns:
493
+ torch.tensor: error of valid normals on the surface
494
+ """
495
+ # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
496
+ outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
497
+ targets = targets.reshape(-1, 3)[:, 2:3]
498
+ with_normals = targets.sum(dim=1).abs() > 0.0
499
+
500
+ # eikonal loss
501
+ grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
502
+ # normals loss
503
+ normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()
504
+
505
+ return grad_loss * 0.0 + normal_loss
506
+
507
+
508
+ def calc_knn_acc(preds, carn_verts, labels, pick_num):
509
+ """calculate knn accuracy
510
+
511
+ Args:
512
+ preds (torch.tensor): [B, 3, N]
513
+ carn_verts (torch.tensor): [SMPLX_V_num, 3]
514
+ labels (torch.tensor): [B, N_knn, N]
515
+ """
516
+ N_knn_full = labels.shape[1]
517
+ preds = preds.permute(0, 2, 1).reshape(-1, 3)
518
+ labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn]
519
+ labels = labels[:, :pick_num]
520
+
521
+ dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num]
522
+ knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn]
523
+ cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
524
+ bool_col = torch.zeros_like(cat_mat)[:, 0]
525
+ for i in range(pick_num * 2 - 1):
526
+ bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
527
+ acc = (bool_col > 0).sum() / len(bool_col)
528
+
529
+ return acc
530
+
531
+
532
+ def calc_acc_seg(output, target, num_multiseg):
533
+ from pytorch_lightning.metrics import Accuracy
534
+ return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
535
+ target.flatten().cpu())
536
+
537
+
538
+ def add_watermark(imgs, titles):
539
+
540
+ # Write some Text
541
+
542
+ font = cv2.FONT_HERSHEY_SIMPLEX
543
+ bottomLeftCornerOfText = (350, 50)
544
+ bottomRightCornerOfText = (800, 50)
545
+ fontScale = 1
546
+ fontColor = (1.0, 1.0, 1.0)
547
+ lineType = 2
548
+
549
+ for i in range(len(imgs)):
550
+
551
+ title = titles[i + 1]
552
+ cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
553
+ fontColor, lineType)
554
+
555
+ if i == 0:
556
+ cv2.putText(imgs[i], str(titles[i][0]), bottomRightCornerOfText,
557
+ font, fontScale, fontColor, lineType)
558
+
559
+ result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)
560
+
561
+ return result
562
+
563
+
564
+ def make_test_gif(img_dir):
565
+
566
+ if img_dir is not None and len(os.listdir(img_dir)) > 0:
567
+ for dataset in os.listdir(img_dir):
568
+ for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
569
+ img_lst = []
570
+ im1 = None
571
+ for file in sorted(
572
+ os.listdir(osp.join(img_dir, dataset, subject))):
573
+ if file[-3:] not in ['obj', 'gif']:
574
+ img_path = os.path.join(img_dir, dataset, subject,
575
+ file)
576
+ if im1 is None:
577
+ im1 = Image.open(img_path)
578
+ else:
579
+ img_lst.append(Image.open(img_path))
580
+
581
+ print(os.path.join(img_dir, dataset, subject, "out.gif"))
582
+ im1.save(os.path.join(img_dir, dataset, subject, "out.gif"),
583
+ save_all=True,
584
+ append_images=img_lst,
585
+ duration=500,
586
+ loop=0)
587
+
588
+
589
+ def export_cfg(logger, cfg):
590
+
591
+ cfg_export_file = osp.join(logger.save_dir, logger.name,
592
+ f"version_{logger.version}", "cfg.yaml")
593
+
594
+ if not osp.exists(cfg_export_file):
595
+ os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
596
+ with open(cfg_export_file, "w+") as file:
597
+ _ = yaml.dump(cfg, file)
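
A quick check of compute_acc from the file above; the numbers are worked out by hand, and the import assumes the module's heavier dependencies (pytorch3d, etc.) are installed:

import torch
from lib.common.train_util import compute_acc

pred = torch.tensor([0.9, 0.2, 0.8, 0.1])   # predicted occupancies
gt = torch.tensor([1.0, 0.0, 0.0, 1.0])     # ground-truth labels
iou, precision, recall = compute_acc(pred, gt, thresh=0.5)
# intersection = 1, union = 3, |pred| = 2, |gt| = 2
print(iou.item(), precision.item(), recall.item())   # 0.333..., 0.5, 0.5
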
lib/dataset/Evaluator.py ADDED
@@ -0,0 +1,264 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+
19
+ from lib.renderer.gl.normal_render import NormalRender
20
+ from lib.dataset.mesh_util import projection
21
+ from lib.common.render import Render
22
+ from PIL import Image
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+ import trimesh
27
+ import os.path as osp
28
+ from PIL import Image
29
+
30
+
31
+ class Evaluator:
32
+
33
+ _normal_render = None
34
+
35
+ @staticmethod
36
+ def init_gl():
37
+ Evaluator._normal_render = NormalRender(width=512, height=512)
38
+
39
+ def __init__(self, device):
40
+ self.device = device
41
+ self.render = Render(size=512, device=self.device)
42
+ self.error_term = nn.MSELoss()
43
+
44
+ self.offset = 0.0
45
+ self.scale_factor = None
46
+
47
+ def set_mesh(self, result_dict, scale_factor=1.0, offset=0.0):
48
+
49
+ for key in result_dict.keys():
50
+ if torch.is_tensor(result_dict[key]):
51
+ result_dict[key] = result_dict[key].detach().cpu().numpy()
52
+
53
+ for k, v in result_dict.items():
54
+ setattr(self, k, v)
55
+
56
+ self.scale_factor = scale_factor
57
+ self.offset = offset
58
+
59
+ def _render_normal(self, mesh, deg, norms=None):
60
+ view_mat = np.identity(4)
61
+ rz = deg / 180.0 * np.pi
62
+ model_mat = np.identity(4)
63
+ model_mat[:3, :3] = self._normal_render.euler_to_rot_mat(0, rz, 0)
64
+ model_mat[1, 3] = self.offset
65
+ view_mat[2, 2] *= -1
66
+
67
+ self._normal_render.set_matrices(view_mat, model_mat)
68
+ if norms is None:
69
+ norms = mesh.vertex_normals
70
+ self._normal_render.set_normal_mesh(self.scale_factor * mesh.vertices,
71
+ mesh.faces, norms, mesh.faces)
72
+ self._normal_render.draw()
73
+ normal_img = self._normal_render.get_color()
74
+ return normal_img
75
+
76
+ def render_mesh_list(self, mesh_lst):
77
+
78
+ self.offset = 0.0
79
+ self.scale_factor = 1.0
80
+
81
+ full_list = []
82
+ for mesh in mesh_lst:
83
+ row_lst = []
84
+ for deg in np.arange(0, 360, 90):
85
+ normal = self._render_normal(mesh, deg)
86
+ row_lst.append(normal)
87
+ full_list.append(np.concatenate(row_lst, axis=1))
88
+
89
+ res_array = np.concatenate(full_list, axis=0)
90
+
91
+ return res_array
92
+
93
+ def _get_reproj_normal_error(self, deg):
94
+
95
+ tgt_normal = self._render_normal(self.tgt_mesh, deg)
96
+ src_normal = self._render_normal(self.src_mesh, deg)
97
+ error = (((src_normal[:, :, :3] -
98
+ tgt_normal[:, :, :3])**2).sum(axis=2).mean(axis=(0, 1)))
99
+
100
+ return error, [src_normal, tgt_normal]
101
+
102
+ def render_normal(self, verts, faces):
103
+
104
+ verts = verts[0].detach().cpu().numpy()
105
+ faces = faces[0].detach().cpu().numpy()
106
+
107
+ mesh_F = trimesh.Trimesh(verts * np.array([1.0, -1.0, 1.0]), faces)
108
+ mesh_B = trimesh.Trimesh(verts * np.array([1.0, -1.0, -1.0]), faces)
109
+
110
+ self.scale_factor = 1.0
111
+
112
+ normal_F = self._render_normal(mesh_F, 0)
113
+ normal_B = self._render_normal(mesh_B,
114
+ 0,
115
+ norms=mesh_B.vertex_normals *
116
+ np.array([-1.0, -1.0, 1.0]))
117
+
118
+ mask = normal_F[:, :, 3:4]
119
+ normal_F = (torch.as_tensor(2.0 * (normal_F - 0.5) * mask).permute(
120
+ 2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))
121
+ normal_B = (torch.as_tensor(2.0 * (normal_B - 0.5) * mask).permute(
122
+ 2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))
123
+
124
+ return {"T_normal_F": normal_F, "T_normal_B": normal_B}
125
+
126
+ def calculate_normal_consist(
127
+ self,
128
+ frontal=True,
129
+ back=True,
130
+ left=True,
131
+ right=True,
132
+ save_demo_img=None,
133
+ return_demo=False,
134
+ ):
135
+
136
+ # reproj error
137
+ # if save_demo_img is not None, save a visualization at the given path (etc, "./test.png")
138
+ if self._normal_render is None:
139
+ print(
140
+ "In order to use normal render, "
141
+ "you have to call init_gl() before initialing any evaluator objects."
142
+ )
143
+ return -1
144
+
145
+ side_cnt = 0
146
+ total_error = 0
147
+ demo_list = []
148
+
149
+ if frontal:
150
+ side_cnt += 1
151
+ error, normal_lst = self._get_reproj_normal_error(0)
152
+ total_error += error
153
+ demo_list.append(np.concatenate(normal_lst, axis=0))
154
+ if back:
155
+ side_cnt += 1
156
+ error, normal_lst = self._get_reproj_normal_error(180)
157
+ total_error += error
158
+ demo_list.append(np.concatenate(normal_lst, axis=0))
159
+ if left:
160
+ side_cnt += 1
161
+ error, normal_lst = self._get_reproj_normal_error(90)
162
+ total_error += error
163
+ demo_list.append(np.concatenate(normal_lst, axis=0))
164
+ if right:
165
+ side_cnt += 1
166
+ error, normal_lst = self._get_reproj_normal_error(270)
167
+ total_error += error
168
+ demo_list.append(np.concatenate(normal_lst, axis=0))
169
+ if save_demo_img is not None:
170
+ res_array = np.concatenate(demo_list, axis=1)
171
+ res_img = Image.fromarray((res_array * 255).astype(np.uint8))
172
+ res_img.save(save_demo_img)
173
+
174
+ if return_demo:
175
+ res_array = np.concatenate(demo_list, axis=1)
176
+ return res_array
177
+ else:
178
+ return total_error
179
+
180
+ def space_transfer(self):
181
+
182
+ # convert from GT to SDF
183
+ self.verts_pr -= self.recon_size / 2.0
184
+ self.verts_pr /= self.recon_size / 2.0
185
+
186
+ self.verts_gt = projection(self.verts_gt, self.calib)
187
+ self.verts_gt[:, 1] *= -1
188
+
189
+ self.tgt_mesh = trimesh.Trimesh(self.verts_gt, self.faces_gt)
190
+ self.src_mesh = trimesh.Trimesh(self.verts_pr, self.faces_pr)
191
+
192
+ # (self.tgt_mesh+self.src_mesh).show()
193
+
194
+ def export_mesh(self, dir, name):
195
+ self.tgt_mesh.visual.vertex_colors = np.array([255, 0, 0])
196
+ self.src_mesh.visual.vertex_colors = np.array([0, 255, 0])
197
+
198
+ (self.tgt_mesh + self.src_mesh).export(
199
+ osp.join(dir, f"{name}_gt_pr.obj"))
200
+
201
+ def calculate_chamfer_p2s(self, sampled_points=1000):
202
+ """calculate the geometry metrics [chamfer, p2s, chamfer_H, p2s_H]
203
+
204
+ Args:
205
+ verts_gt (torch.cuda.tensor): [N, 3]
206
+ faces_gt (torch.cuda.tensor): [M, 3]
207
+ verts_pr (torch.cuda.tensor): [N', 3]
208
+ faces_pr (torch.cuda.tensor): [M', 3]
209
+ sampled_points (int, optional): use smaller number for faster testing. Defaults to 1000.
210
+
211
+ Returns:
212
+ tuple: chamfer, p2s, chamfer_H, p2s_H
213
+ """
214
+
215
+ gt_surface_pts, _ = trimesh.sample.sample_surface_even(
216
+ self.tgt_mesh, sampled_points)
217
+ pred_surface_pts, _ = trimesh.sample.sample_surface_even(
218
+ self.src_mesh, sampled_points)
219
+
220
+ _, dist_pred_gt, _ = trimesh.proximity.closest_point(
221
+ self.src_mesh, gt_surface_pts)
222
+ _, dist_gt_pred, _ = trimesh.proximity.closest_point(
223
+ self.tgt_mesh, pred_surface_pts)
224
+
225
+ dist_pred_gt[np.isnan(dist_pred_gt)] = 0
226
+ dist_gt_pred[np.isnan(dist_gt_pred)] = 0
227
+ chamfer_dist = 0.5 * (dist_pred_gt.mean() +
228
+ dist_gt_pred.mean()).item() * 100
229
+ p2s_dist = dist_pred_gt.mean().item() * 100
230
+
231
+ return chamfer_dist, p2s_dist
232
+
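
The bidirectional Chamfer distance computed above averages two point-to-surface terms (GT samples against the predicted surface, and predicted samples against the GT surface), while P2S keeps only the first of the two. Below is a self-contained sketch of the same idea on two toy spheres; it is an illustration with arbitrary sample counts, not part of the repository.

```python
import trimesh

src = trimesh.creation.icosphere(subdivisions=3)                 # "prediction"
tgt = trimesh.creation.icosphere(subdivisions=3)
tgt.apply_scale(1.05)                                            # "ground truth"

gt_pts, _ = trimesh.sample.sample_surface_even(tgt, 1000)
pr_pts, _ = trimesh.sample.sample_surface_even(src, 1000)

_, d_gt_to_pr, _ = trimesh.proximity.closest_point(src, gt_pts)  # GT -> prediction
_, d_pr_to_gt, _ = trimesh.proximity.closest_point(tgt, pr_pts)  # prediction -> GT

chamfer = 0.5 * (d_gt_to_pr.mean() + d_pr_to_gt.mean()) * 100    # in cm, as above
p2s = d_gt_to_pr.mean() * 100
print(f"chamfer: {chamfer:.3f} cm, p2s: {p2s:.3f} cm")
```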
233
+ def calc_acc(self, output, target, thres=0.5, use_sdf=False):
234
+
235
+ # # remove the surface points with thres
236
+ # non_surf_ids = (target != thres)
237
+ # output = output[non_surf_ids]
238
+ # target = target[non_surf_ids]
239
+
240
+ with torch.no_grad():
241
+ output = output.masked_fill(output < thres, 0.0)
242
+ output = output.masked_fill(output > thres, 1.0)
243
+
244
+ if use_sdf:
245
+ target = target.masked_fill(target < thres, 0.0)
246
+ target = target.masked_fill(target > thres, 1.0)
247
+
248
+ acc = output.eq(target).float().mean()
249
+
250
+ # iou, precison, recall
251
+ output = output > thres
252
+ target = target > thres
253
+
254
+ union = output | target
255
+ inter = output & target
256
+
257
+ _max = torch.tensor(1.0).to(output.device)
258
+
259
+ union = max(union.sum().float(), _max)
260
+ true_pos = max(inter.sum().float(), _max)
261
+ vol_pred = max(output.sum().float(), _max)
262
+ vol_gt = max(target.sum().float(), _max)
263
+
264
+ return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt
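
`calc_acc` reports accuracy, IoU, precision and recall after binarizing both the predicted and the target occupancies at `thres`. The snippet below reproduces that logic on dummy tensors as a quick sanity check; the tensors and threshold are made up for illustration.

```python
import torch

def occupancy_metrics(output, target, thres=0.5):
    pred = output > thres                       # binarize prediction
    gt = target > thres                         # binarize label
    acc = pred.eq(gt).float().mean()

    union = (pred | gt).sum().float().clamp(min=1.0)
    inter = (pred & gt).sum().float().clamp(min=1.0)
    vol_pred = pred.sum().float().clamp(min=1.0)
    vol_gt = gt.sum().float().clamp(min=1.0)

    # IoU, precision (TP / predicted volume), recall (TP / GT volume)
    return acc, inter / union, inter / vol_pred, inter / vol_gt

# toy check on random occupancies in [0, 1]
print(occupancy_metrics(torch.rand(10000), (torch.rand(10000) > 0.5).float()))
```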
lib/dataset/NormalDataset.py ADDED
@@ -0,0 +1,212 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import os.path as osp
19
+ import numpy as np
20
+ from PIL import Image
21
+ import torchvision.transforms as transforms
22
+
23
+
24
+ class NormalDataset():
25
+ def __init__(self, cfg, split='train'):
26
+
27
+ self.split = split
28
+ self.root = cfg.root
29
+ self.overfit = cfg.overfit
30
+
31
+ self.opt = cfg.dataset
32
+ self.datasets = self.opt.types
33
+ self.input_size = self.opt.input_size
34
+ self.set_splits = self.opt.set_splits
35
+ self.scales = self.opt.scales
36
+ self.pifu = self.opt.pifu
37
+
38
+ # input data types and dimensions
39
+ self.in_nml = [item[0] for item in cfg.net.in_nml]
40
+ self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
41
+ self.in_total = self.in_nml + ['normal_F', 'normal_B']
42
+ self.in_total_dim = self.in_nml_dim + [3, 3]
43
+
44
+ if self.split != 'train':
45
+ self.rotations = range(0, 360, 120)
46
+ else:
47
+ self.rotations = np.arange(0, 360, 360 /
48
+ self.opt.rotation_num).astype(np.int32)
49
+
50
+ self.datasets_dict = {}
51
+ for dataset_id, dataset in enumerate(self.datasets):
52
+ dataset_dir = osp.join(self.root, dataset, "smplx")
53
+ self.datasets_dict[dataset] = {
54
+ "subjects":
55
+ np.loadtxt(osp.join(self.root, dataset, "all.txt"), dtype=str),
56
+ "path":
57
+ dataset_dir,
58
+ "scale":
59
+ self.scales[dataset_id]
60
+ }
61
+
62
+ self.subject_list = self.get_subject_list(split)
63
+
64
+ # PIL to tensor
65
+ self.image_to_tensor = transforms.Compose([
66
+ transforms.Resize(self.input_size),
67
+ transforms.ToTensor(),
68
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
69
+ ])
70
+
71
+ # PIL to tensor
72
+ self.mask_to_tensor = transforms.Compose([
73
+ transforms.Resize(self.input_size),
74
+ transforms.ToTensor(),
75
+ transforms.Normalize((0.0, ), (1.0, ))
76
+ ])
77
+
78
+ def get_subject_list(self, split):
79
+
80
+ subject_list = []
81
+
82
+ for dataset in self.datasets:
83
+
84
+ if self.pifu:
85
+ txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
86
+ else:
87
+ txt = osp.join(self.root, dataset, f'{split}.txt')
88
+
89
+ if osp.exists(txt):
90
+ print(f"load from {txt}")
91
+ subject_list += sorted(np.loadtxt(txt, dtype=str).tolist())
92
+
93
+ if self.pifu:
94
+ miss_pifu = sorted(
95
+ np.loadtxt(osp.join(self.root, dataset,
96
+ "miss_pifu.txt"),
97
+ dtype=str).tolist())
98
+ subject_list = [
99
+ subject for subject in subject_list
100
+ if subject not in miss_pifu
101
+ ]
102
+ subject_list = [
103
+ "renderpeople/" + subject for subject in subject_list
104
+ ]
105
+
106
+ else:
107
+ train_txt = osp.join(self.root, dataset, 'train.txt')
108
+ val_txt = osp.join(self.root, dataset, 'val.txt')
109
+ test_txt = osp.join(self.root, dataset, 'test.txt')
110
+
111
+ print(
112
+ f"generate lists of [train, val, test] \n {train_txt} \n {val_txt} \n {test_txt} \n"
113
+ )
114
+
115
+ split_txt = osp.join(self.root, dataset, f'{split}.txt')
116
+
117
+ subjects = self.datasets_dict[dataset]['subjects']
118
+ train_split = int(len(subjects) * self.set_splits[0])
119
+ val_split = int(
120
+ len(subjects) * self.set_splits[1]) + train_split
121
+
122
+ with open(train_txt, "w") as f:
123
+ f.write("\n".join(dataset + "/" + item
124
+ for item in subjects[:train_split]))
125
+ with open(val_txt, "w") as f:
126
+ f.write("\n".join(
127
+ dataset + "/" + item
128
+ for item in subjects[train_split:val_split]))
129
+ with open(test_txt, "w") as f:
130
+ f.write("\n".join(dataset + "/" + item
131
+ for item in subjects[val_split:]))
132
+
133
+ subject_list += sorted(
134
+ np.loadtxt(split_txt, dtype=str).tolist())
135
+
136
+ bug_list = sorted(
137
+ np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())
138
+
139
+ subject_list = [
140
+ subject for subject in subject_list if (subject not in bug_list)
141
+ ]
142
+
143
+ return subject_list
144
+
145
+ def __len__(self):
146
+ return len(self.subject_list) * len(self.rotations)
147
+
148
+ def __getitem__(self, index):
149
+
150
+ # only pick the first data if overfitting
151
+ if self.overfit:
152
+ index = 0
153
+
154
+ rid = index % len(self.rotations)
155
+ mid = index // len(self.rotations)
156
+
157
+ rotation = self.rotations[rid]
158
+
159
+ # choose specific test sets
160
+ subject = self.subject_list[mid]
161
+
162
+ subject_render = "/".join(
163
+ [subject.split("/")[0] + "_12views",
164
+ subject.split("/")[1]])
165
+
166
+ # setup paths
167
+ data_dict = {
168
+ 'dataset':
169
+ subject.split("/")[0],
170
+ 'subject':
171
+ subject,
172
+ 'rotation':
173
+ rotation,
174
+ 'image_path':
175
+ osp.join(self.root, subject_render, 'render',
176
+ f'{rotation:03d}.png')
177
+ }
178
+
179
+ # image/normal/depth loader
180
+ for name, channel in zip(self.in_total, self.in_total_dim):
181
+
182
+ if name != 'image':
183
+ data_dict.update({
184
+ f'{name}_path':
185
+ osp.join(self.root, subject_render, name,
186
+ f'{rotation:03d}.png')
187
+ })
188
+ data_dict.update({
189
+ name:
190
+ self.imagepath2tensor(data_dict[f'{name}_path'],
191
+ channel,
192
+ inv='depth_B' in name)
193
+ })
194
+
195
+ path_keys = [
196
+ key for key in data_dict.keys() if '_path' in key or '_dir' in key
197
+ ]
198
+ for key in path_keys:
199
+ del data_dict[key]
200
+
201
+ return data_dict
202
+
203
+ def imagepath2tensor(self, path, channel=3, inv=False):
204
+
205
+ rgba = Image.open(path).convert('RGBA')
206
+ mask = rgba.split()[-1]
207
+ image = rgba.convert('RGB')
208
+ image = self.image_to_tensor(image)
209
+ mask = self.mask_to_tensor(mask)
210
+ image = (image * mask)[:channel]
211
+
212
+ return (image * (0.5 - inv) * 2.0).float()
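
`imagepath2tensor` packs several steps into one expression: the alpha channel becomes a [0, 1] mask, the RGB image is normalized to [-1, 1], the background is zeroed by the mask, and the final factor `(0.5 - inv) * 2.0` evaluates to +1 or -1, flipping the sign when `inv` is set (used for `depth_B` inputs). A rough standalone equivalent, assuming a hypothetical `example.png` with an alpha channel:

```python
import torchvision.transforms as transforms
from PIL import Image

to_image = transforms.Compose([
    transforms.Resize(512),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),   # RGB -> [-1, 1]
])
to_mask = transforms.Compose([transforms.Resize(512), transforms.ToTensor()])

rgba = Image.open("example.png").convert("RGBA")   # hypothetical RGBA input
mask = to_mask(rgba.split()[-1])                   # alpha channel as a [0, 1] mask
image = to_image(rgba.convert("RGB")) * mask       # background forced to zero

inv = False                                        # True inverts the sign, e.g. for depth_B
result = (image * (0.5 - inv) * 2.0).float()       # the factor is +1.0 or -1.0
```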
lib/dataset/NormalModule.py ADDED
@@ -0,0 +1,94 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import numpy as np
19
+ from torch.utils.data import DataLoader
20
+ from .NormalDataset import NormalDataset
21
+
22
+ # pytorch lightning related libs
23
+ import pytorch_lightning as pl
24
+
25
+
26
+ class NormalModule(pl.LightningDataModule):
27
+ def __init__(self, cfg):
28
+ super(NormalModule, self).__init__()
29
+ self.cfg = cfg
30
+ self.overfit = self.cfg.overfit
31
+
32
+ if self.overfit:
33
+ self.batch_size = 1
34
+ else:
35
+ self.batch_size = self.cfg.batch_size
36
+
37
+ self.data_size = {}
38
+
39
+ def prepare_data(self):
40
+
41
+ pass
42
+
43
+ @staticmethod
44
+ def worker_init_fn(worker_id):
45
+ np.random.seed(np.random.get_state()[1][0] + worker_id)
46
+
47
+ def setup(self, stage):
48
+
49
+ if stage == 'fit' or stage is None:
50
+ self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
51
+ self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
52
+ self.data_size = {
53
+ 'train': len(self.train_dataset),
54
+ 'val': len(self.val_dataset)
55
+ }
56
+
57
+ if stage == 'test' or stage is None:
58
+ self.test_dataset = NormalDataset(cfg=self.cfg, split="test")
59
+
60
+ def train_dataloader(self):
61
+
62
+ train_data_loader = DataLoader(self.train_dataset,
63
+ batch_size=self.batch_size,
64
+ shuffle=not self.overfit,
65
+ num_workers=self.cfg.num_threads,
66
+ pin_memory=True,
67
+ worker_init_fn=self.worker_init_fn)
68
+
69
+ return train_data_loader
70
+
71
+ def val_dataloader(self):
72
+
73
+ if self.overfit:
74
+ current_dataset = self.train_dataset
75
+ else:
76
+ current_dataset = self.val_dataset
77
+
78
+ val_data_loader = DataLoader(current_dataset,
79
+ batch_size=self.batch_size,
80
+ shuffle=False,
81
+ num_workers=self.cfg.num_threads,
82
+ pin_memory=True)
83
+
84
+ return val_data_loader
85
+
86
+ def test_dataloader(self):
87
+
88
+ test_data_loader = DataLoader(self.test_dataset,
89
+ batch_size=1,
90
+ shuffle=False,
91
+ num_workers=self.cfg.num_threads,
92
+ pin_memory=True)
93
+
94
+ return test_data_loader
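
Because `NormalModule` is a `LightningDataModule`, a `pl.Trainer` normally drives `setup()` and the dataloader hooks, but the module can also be exercised directly. A minimal usage sketch, assuming `cfg` is the parsed training config used throughout the repo and the dataset folders exist locally:

```python
# from lib.dataset.NormalModule import NormalModule   # as defined above

dm = NormalModule(cfg)          # cfg: parsed YAML config (assumed available)
dm.setup(stage="fit")           # builds the train/val NormalDataset splits
print(dm.data_size)             # e.g. {'train': N, 'val': M}

batch = next(iter(dm.train_dataloader()))   # dict with 'image', 'normal_F', 'normal_B', ...
print({k: tuple(v.shape) for k, v in batch.items() if hasattr(v, "shape")})
```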
lib/dataset/PIFuDataModule.py ADDED
@@ -0,0 +1,71 @@
1
+ import numpy as np
2
+ from torch.utils.data import DataLoader
3
+ from .PIFuDataset import PIFuDataset
4
+ import pytorch_lightning as pl
5
+
6
+
7
+ class PIFuDataModule(pl.LightningDataModule):
8
+ def __init__(self, cfg):
9
+ super(PIFuDataModule, self).__init__()
10
+ self.cfg = cfg
11
+ self.overfit = self.cfg.overfit
12
+
13
+ if self.overfit:
14
+ self.batch_size = 1
15
+ else:
16
+ self.batch_size = self.cfg.batch_size
17
+
18
+ self.data_size = {}
19
+
20
+ def prepare_data(self):
21
+
22
+ pass
23
+
24
+ @staticmethod
25
+ def worker_init_fn(worker_id):
26
+ np.random.seed(np.random.get_state()[1][0] + worker_id)
27
+
28
+ def setup(self, stage):
29
+
30
+ if stage == 'fit':
31
+ self.train_dataset = PIFuDataset(cfg=self.cfg, split="train")
32
+ self.val_dataset = PIFuDataset(cfg=self.cfg, split="val")
33
+ self.data_size = {'train': len(self.train_dataset),
34
+ 'val': len(self.val_dataset)}
35
+
36
+ if stage == 'test':
37
+ self.test_dataset = PIFuDataset(cfg=self.cfg, split="test")
38
+
39
+ def train_dataloader(self):
40
+
41
+ train_data_loader = DataLoader(
42
+ self.train_dataset,
43
+ batch_size=self.batch_size, shuffle=True,
44
+ num_workers=self.cfg.num_threads, pin_memory=True,
45
+ worker_init_fn=self.worker_init_fn)
46
+
47
+ return train_data_loader
48
+
49
+ def val_dataloader(self):
50
+
51
+ if self.overfit:
52
+ current_dataset = self.train_dataset
53
+ else:
54
+ current_dataset = self.val_dataset
55
+
56
+ val_data_loader = DataLoader(
57
+ current_dataset,
58
+ batch_size=1, shuffle=False,
59
+ num_workers=self.cfg.num_threads, pin_memory=True,
60
+ worker_init_fn=self.worker_init_fn)
61
+
62
+ return val_data_loader
63
+
64
+ def test_dataloader(self):
65
+
66
+ test_data_loader = DataLoader(
67
+ self.test_dataset,
68
+ batch_size=1, shuffle=False,
69
+ num_workers=self.cfg.num_threads, pin_memory=True)
70
+
71
+ return test_data_loader
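
Both data modules pass a `worker_init_fn` that reseeds NumPy per DataLoader worker; with the fork start method, forked workers otherwise inherit identical NumPy state and emit the same "random" augmentation noise. A small, repository-independent demonstration of the pattern:

```python
import numpy as np
from torch.utils.data import Dataset, DataLoader

class NoisyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # NumPy randomness drawn inside the worker process
        return np.random.rand()

def worker_init_fn(worker_id):
    # same pattern as above: offset the base NumPy seed by the worker id
    np.random.seed(np.random.get_state()[1][0] + worker_id)

for init in (None, worker_init_fn):
    loader = DataLoader(NoisyDataset(), batch_size=1, num_workers=2,
                        worker_init_fn=init)
    # with fork-based workers and init=None, the two workers tend to repeat values
    print([round(float(x), 4) for x in loader])
```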
lib/dataset/PIFuDataset.py ADDED
@@ -0,0 +1,589 @@
1
+ from lib.renderer.mesh import load_fit_body
2
+ from lib.dataset.hoppeMesh import HoppeMesh
3
+ from lib.dataset.body_model import TetraSMPLModel
4
+ from lib.common.render import Render
5
+ from lib.dataset.mesh_util import SMPLX, projection, cal_sdf_batch, get_visibility
6
+ from lib.pare.pare.utils.geometry import rotation_matrix_to_angle_axis
7
+ from termcolor import colored
8
+ import os.path as osp
9
+ import numpy as np
10
+ from PIL import Image
11
+ import random
12
+ import os
13
+ import trimesh
14
+ import torch
15
+ from kaolin.ops.mesh import check_sign
16
+ import torchvision.transforms as transforms
17
+ from huggingface_hub import hf_hub_download, cached_download
18
+
19
+
20
+ class PIFuDataset():
21
+ def __init__(self, cfg, split='train', vis=False):
22
+
23
+ self.split = split
24
+ self.root = cfg.root
25
+ self.bsize = cfg.batch_size
26
+ self.overfit = cfg.overfit
27
+
28
+ # for debug, only used in visualize_sampling3D
29
+ self.vis = vis
30
+
31
+ self.opt = cfg.dataset
32
+ self.datasets = self.opt.types
33
+ self.input_size = self.opt.input_size
34
+ self.scales = self.opt.scales
35
+ self.workers = cfg.num_threads
36
+ self.prior_type = cfg.net.prior_type
37
+
38
+ self.noise_type = self.opt.noise_type
39
+ self.noise_scale = self.opt.noise_scale
40
+
41
+ noise_joints = [4, 5, 7, 8, 13, 14, 16, 17, 18, 19, 20, 21]
42
+
43
+ self.noise_smpl_idx = []
44
+ self.noise_smplx_idx = []
45
+
46
+ for idx in noise_joints:
47
+ self.noise_smpl_idx.append(idx * 3)
48
+ self.noise_smpl_idx.append(idx * 3 + 1)
49
+ self.noise_smpl_idx.append(idx * 3 + 2)
50
+
51
+ self.noise_smplx_idx.append((idx-1) * 3)
52
+ self.noise_smplx_idx.append((idx-1) * 3 + 1)
53
+ self.noise_smplx_idx.append((idx-1) * 3 + 2)
54
+
55
+ self.use_sdf = cfg.sdf
56
+ self.sdf_clip = cfg.sdf_clip
57
+
58
+ # [(feat_name, channel_num),...]
59
+ self.in_geo = [item[0] for item in cfg.net.in_geo]
60
+ self.in_nml = [item[0] for item in cfg.net.in_nml]
61
+
62
+ self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
63
+ self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
64
+
65
+ self.in_total = self.in_geo + self.in_nml
66
+ self.in_total_dim = self.in_geo_dim + self.in_nml_dim
67
+
68
+ if self.split == 'train':
69
+ self.rotations = np.arange(
70
+ 0, 360, 360 / self.opt.rotation_num).astype(np.int32)
71
+ else:
72
+ self.rotations = range(0, 360, 120)
73
+
74
+ self.datasets_dict = {}
75
+
76
+ for dataset_id, dataset in enumerate(self.datasets):
77
+
78
+ mesh_dir = None
79
+ smplx_dir = None
80
+
81
+ dataset_dir = osp.join(self.root, dataset)
82
+
83
+ if dataset in ['thuman2']:
84
+ mesh_dir = osp.join(dataset_dir, "scans")
85
+ smplx_dir = osp.join(dataset_dir, "fits")
86
+ smpl_dir = osp.join(dataset_dir, "smpl")
87
+
88
+ self.datasets_dict[dataset] = {
89
+ "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
90
+ "smplx_dir": smplx_dir,
91
+ "smpl_dir": smpl_dir,
92
+ "mesh_dir": mesh_dir,
93
+ "scale": self.scales[dataset_id]
94
+ }
95
+
96
+ self.subject_list = self.get_subject_list(split)
97
+ self.smplx = SMPLX()
98
+
99
+ # PIL to tensor
100
+ self.image_to_tensor = transforms.Compose([
101
+ transforms.Resize(self.input_size),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
104
+ ])
105
+
106
+ # PIL to tensor
107
+ self.mask_to_tensor = transforms.Compose([
108
+ transforms.Resize(self.input_size),
109
+ transforms.ToTensor(),
110
+ transforms.Normalize((0.0, ), (1.0, ))
111
+ ])
112
+
113
+ self.device = torch.device(f"cuda:{cfg.gpus[0]}")
114
+ self.render = Render(size=512, device=self.device)
115
+
116
+ def render_normal(self, verts, faces):
117
+
118
+ # render optimized mesh (normal, T_normal, image [-1,1])
119
+ self.render.load_meshes(verts, faces)
120
+ return self.render.get_rgb_image()
121
+
122
+ def get_subject_list(self, split):
123
+
124
+ subject_list = []
125
+
126
+ for dataset in self.datasets:
127
+
128
+ split_txt = osp.join(self.root, dataset, f'{split}.txt')
129
+
130
+ if osp.exists(split_txt):
131
+ print(f"load from {split_txt}")
132
+ subject_list += np.loadtxt(split_txt, dtype=str).tolist()
133
+ else:
134
+ full_txt = osp.join(self.root, dataset, 'all.txt')
135
+ print(f"split {full_txt} into train/val/test")
136
+
137
+ full_lst = np.loadtxt(full_txt, dtype=str)
138
+ full_lst = [dataset+"/"+item for item in full_lst]
139
+ [train_lst, test_lst, val_lst] = np.split(
140
+ full_lst, [500, 500+5, ])
141
+
142
+ np.savetxt(full_txt.replace(
143
+ "all", "train"), train_lst, fmt="%s")
144
+ np.savetxt(full_txt.replace("all", "test"), test_lst, fmt="%s")
145
+ np.savetxt(full_txt.replace("all", "val"), val_lst, fmt="%s")
146
+
147
+ print(f"load from {split_txt}")
148
+ subject_list += np.loadtxt(split_txt, dtype=str).tolist()
149
+
150
+ if self.split != 'test':
151
+ subject_list += subject_list[:self.bsize -
152
+ len(subject_list) % self.bsize]
153
+ print(colored(f"total: {len(subject_list)}", "yellow"))
154
+ random.shuffle(subject_list)
155
+
156
+ # subject_list = ["thuman2/0008"]
157
+ return subject_list
158
+
159
+ def __len__(self):
160
+ return len(self.subject_list) * len(self.rotations)
161
+
162
+ def __getitem__(self, index):
163
+
164
+ # only pick the first data if overfitting
165
+ if self.overfit:
166
+ index = 0
167
+
168
+ rid = index % len(self.rotations)
169
+ mid = index // len(self.rotations)
170
+
171
+ rotation = self.rotations[rid]
172
+ subject = self.subject_list[mid].split("/")[1]
173
+ dataset = self.subject_list[mid].split("/")[0]
174
+ render_folder = "/".join([dataset +
175
+ f"_{self.opt.rotation_num}views", subject])
176
+
177
+ # setup paths
178
+ data_dict = {
179
+ 'dataset': dataset,
180
+ 'subject': subject,
181
+ 'rotation': rotation,
182
+ 'scale': self.datasets_dict[dataset]["scale"],
183
+ 'mesh_path': osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}/{subject}.obj"),
184
+ 'smplx_path': osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}/smplx_param.pkl"),
185
+ 'smpl_path': osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.pkl"),
186
+ 'calib_path': osp.join(self.root, render_folder, 'calib', f'{rotation:03d}.txt'),
187
+ 'vis_path': osp.join(self.root, render_folder, 'vis', f'{rotation:03d}.pt'),
188
+ 'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png')
189
+ }
190
+
191
+ # load training data
192
+ data_dict.update(self.load_calib(data_dict))
193
+
194
+ # image/normal/depth loader
195
+ for name, channel in zip(self.in_total, self.in_total_dim):
196
+
197
+ if f'{name}_path' not in data_dict.keys():
198
+ data_dict.update({
199
+ f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png')
200
+ })
201
+
202
+ # tensor update
203
+ data_dict.update({
204
+ name: self.imagepath2tensor(
205
+ data_dict[f'{name}_path'], channel, inv=False)
206
+ })
207
+
208
+ data_dict.update(self.load_mesh(data_dict))
209
+ data_dict.update(self.get_sampling_geo(
210
+ data_dict, is_valid=self.split == "val", is_sdf=self.use_sdf))
211
+ data_dict.update(self.load_smpl(data_dict, self.vis))
212
+
213
+ if self.prior_type == 'pamir':
214
+ data_dict.update(self.load_smpl_voxel(data_dict))
215
+
216
+ if (self.split != 'test') and (not self.vis):
217
+
218
+ del data_dict['verts']
219
+ del data_dict['faces']
220
+
221
+ if not self.vis:
222
+ del data_dict['mesh']
223
+
224
+ path_keys = [
225
+ key for key in data_dict.keys() if '_path' in key or '_dir' in key
226
+ ]
227
+ for key in path_keys:
228
+ del data_dict[key]
229
+
230
+ return data_dict
231
+
232
+ def imagepath2tensor(self, path, channel=3, inv=False):
233
+
234
+ rgba = Image.open(path).convert('RGBA')
235
+ mask = rgba.split()[-1]
236
+ image = rgba.convert('RGB')
237
+ image = self.image_to_tensor(image)
238
+ mask = self.mask_to_tensor(mask)
239
+ image = (image * mask)[:channel]
240
+
241
+ return (image * (0.5 - inv) * 2.0).float()
242
+
243
+ def load_calib(self, data_dict):
244
+ calib_data = np.loadtxt(data_dict['calib_path'], dtype=float)
245
+ extrinsic = calib_data[:4, :4]
246
+ intrinsic = calib_data[4:8, :4]
247
+ calib_mat = np.matmul(intrinsic, extrinsic)
248
+ calib_mat = torch.from_numpy(calib_mat).float()
249
+ return {'calib': calib_mat}
250
+
251
+ def load_mesh(self, data_dict):
252
+ mesh_path = data_dict['mesh_path']
253
+ scale = data_dict['scale']
254
+
255
+ mesh_ori = trimesh.load(mesh_path,
256
+ skip_materials=True,
257
+ process=False,
258
+ maintain_order=True)
259
+ verts = mesh_ori.vertices * scale
260
+ faces = mesh_ori.faces
261
+
262
+ vert_normals = np.array(mesh_ori.vertex_normals)
263
+ face_normals = np.array(mesh_ori.face_normals)
264
+
265
+ mesh = HoppeMesh(verts, faces, vert_normals, face_normals)
266
+
267
+ return {
268
+ 'mesh': mesh,
269
+ 'verts': torch.as_tensor(mesh.verts).float(),
270
+ 'faces': torch.as_tensor(mesh.faces).long()
271
+ }
272
+
273
+ def add_noise(self,
274
+ beta_num,
275
+ smpl_pose,
276
+ smpl_betas,
277
+ noise_type,
278
+ noise_scale,
279
+ type,
280
+ hashcode):
281
+
282
+ np.random.seed(hashcode)
283
+
284
+ if type == 'smplx':
285
+ noise_idx = self.noise_smplx_idx
286
+ else:
287
+ noise_idx = self.noise_smpl_idx
288
+
289
+ if 'beta' in noise_type and noise_scale[noise_type.index("beta")] > 0.0:
290
+ smpl_betas += (np.random.rand(beta_num) -
291
+ 0.5) * 2.0 * noise_scale[noise_type.index("beta")]
292
+ smpl_betas = smpl_betas.astype(np.float32)
293
+
294
+ if 'pose' in noise_type and noise_scale[noise_type.index("pose")] > 0.0:
295
+ smpl_pose[noise_idx] += (
296
+ np.random.rand(len(noise_idx)) -
297
+ 0.5) * 2.0 * np.pi * noise_scale[noise_type.index("pose")]
298
+ smpl_pose = smpl_pose.astype(np.float32)
299
+ if type == 'smplx':
300
+ return torch.as_tensor(smpl_pose[None, ...]), torch.as_tensor(smpl_betas[None, ...])
301
+ else:
302
+ return smpl_pose, smpl_betas
303
+
304
+ def compute_smpl_verts(self, data_dict, noise_type=None, noise_scale=None):
305
+
306
+ dataset = data_dict['dataset']
307
+ smplx_dict = {}
308
+
309
+ smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
310
+ smplx_pose = smplx_param["body_pose"] # [1,63]
311
+ smplx_betas = smplx_param["betas"] # [1,10]
312
+ smplx_pose, smplx_betas = self.add_noise(
313
+ smplx_betas.shape[1],
314
+ smplx_pose[0],
315
+ smplx_betas[0],
316
+ noise_type,
317
+ noise_scale,
318
+ type='smplx',
319
+ hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))
320
+
321
+ smplx_out, _ = load_fit_body(fitted_path=data_dict['smplx_path'],
322
+ scale=self.datasets_dict[dataset]['scale'],
323
+ smpl_type='smplx',
324
+ smpl_gender='male',
325
+ noise_dict=dict(betas=smplx_betas, body_pose=smplx_pose))
326
+
327
+ smplx_dict.update({"type": "smplx",
328
+ "gender": 'male',
329
+ "body_pose": torch.as_tensor(smplx_pose),
330
+ "betas": torch.as_tensor(smplx_betas)})
331
+
332
+ return smplx_out.vertices, smplx_dict
333
+
334
+ def compute_voxel_verts(self,
335
+ data_dict,
336
+ noise_type=None,
337
+ noise_scale=None):
338
+
339
+ smpl_param = np.load(data_dict['smpl_path'], allow_pickle=True)
340
+ smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
341
+
342
+ smpl_pose = rotation_matrix_to_angle_axis(
343
+ torch.as_tensor(smpl_param['full_pose'][0])).numpy()
344
+ smpl_betas = smpl_param["betas"]
345
+
346
+ smpl_path = cached_download(osp.join(self.smplx.model_dir, "smpl/SMPL_MALE.pkl"), use_auth_token=os.environ['ICON'])
347
+ tetra_path = cached_download(osp.join(self.smplx.tedra_dir,
348
+ "tetra_male_adult_smpl.npz"), use_auth_token=os.environ['ICON'])
349
+
350
+ smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')
351
+
352
+ smpl_pose, smpl_betas = self.add_noise(
353
+ smpl_model.beta_shape[0],
354
+ smpl_pose.flatten(),
355
+ smpl_betas[0],
356
+ noise_type,
357
+ noise_scale,
358
+ type='smpl',
359
+ hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))
360
+
361
+ smpl_model.set_params(pose=smpl_pose.reshape(-1, 3),
362
+ beta=smpl_betas,
363
+ trans=smpl_param["transl"])
364
+
365
+ verts = (np.concatenate([smpl_model.verts, smpl_model.verts_added],
366
+ axis=0) * smplx_param["scale"] + smplx_param["translation"]
367
+ ) * self.datasets_dict[data_dict['dataset']]['scale']
368
+ faces = np.loadtxt(cached_download(osp.join(self.smplx.tedra_dir, "tetrahedrons_male_adult.txt"), use_auth_token=os.environ['ICON']),
369
+ dtype=np.int32) - 1
370
+
371
+ pad_v_num = int(8000 - verts.shape[0])
372
+ pad_f_num = int(25100 - faces.shape[0])
373
+
374
+ verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
375
+ mode='constant',
376
+ constant_values=0.0).astype(np.float32)
377
+ faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
378
+ mode='constant',
379
+ constant_values=0.0).astype(np.int32)
380
+
381
+
382
+ return verts, faces, pad_v_num, pad_f_num
383
+
384
+ def load_smpl(self, data_dict, vis=False):
385
+
386
+ smplx_verts, smplx_dict = self.compute_smpl_verts(
387
+ data_dict, self.noise_type,
388
+ self.noise_scale) # compute using smpl model
389
+
390
+ smplx_verts = projection(smplx_verts, data_dict['calib']).float()
391
+ smplx_faces = torch.as_tensor(self.smplx.faces).long()
392
+ smplx_vis = torch.load(data_dict['vis_path']).float()
393
+ smplx_cmap = torch.as_tensor(
394
+ np.load(self.smplx.cmap_vert_path)).float()
395
+
396
+ # get smpl_signs
397
+ query_points = projection(data_dict['samples_geo'],
398
+ data_dict['calib']).float()
399
+
400
+ pts_signs = 2.0 * (check_sign(smplx_verts.unsqueeze(0),
401
+ smplx_faces,
402
+ query_points.unsqueeze(0)).float() - 0.5).squeeze(0)
403
+
404
+ return_dict = {
405
+ 'smpl_verts': smplx_verts,
406
+ 'smpl_faces': smplx_faces,
407
+ 'smpl_vis': smplx_vis,
408
+ 'smpl_cmap': smplx_cmap,
409
+ 'pts_signs': pts_signs
410
+ }
411
+ if smplx_dict is not None:
412
+ return_dict.update(smplx_dict)
413
+
414
+ if vis:
415
+
416
+ (xy, z) = torch.as_tensor(smplx_verts).to(
417
+ self.device).split([2, 1], dim=1)
418
+ smplx_vis = get_visibility(xy, z, torch.as_tensor(
419
+ smplx_faces).to(self.device).long())
420
+
421
+ T_normal_F, T_normal_B = self.render_normal(
422
+ (smplx_verts*torch.tensor([1.0, -1.0, 1.0])).to(self.device),
423
+ smplx_faces.to(self.device))
424
+
425
+ return_dict.update({"T_normal_F": T_normal_F.squeeze(0),
426
+ "T_normal_B": T_normal_B.squeeze(0)})
427
+ query_points = projection(data_dict['samples_geo'],
428
+ data_dict['calib']).float()
429
+
430
+ smplx_sdf, smplx_norm, smplx_cmap, smplx_vis = cal_sdf_batch(
431
+ smplx_verts.unsqueeze(0).to(self.device),
432
+ smplx_faces.unsqueeze(0).to(self.device),
433
+ smplx_cmap.unsqueeze(0).to(self.device),
434
+ smplx_vis.unsqueeze(0).to(self.device),
435
+ query_points.unsqueeze(0).contiguous().to(self.device))
436
+
437
+ return_dict.update({
438
+ 'smpl_feat':
439
+ torch.cat(
440
+ (smplx_sdf[0].detach().cpu(),
441
+ smplx_cmap[0].detach().cpu(),
442
+ smplx_norm[0].detach().cpu(),
443
+ smplx_vis[0].detach().cpu()),
444
+ dim=1)
445
+ })
446
+
447
+ return return_dict
448
+
449
+ def load_smpl_voxel(self, data_dict):
450
+
451
+ smpl_verts, smpl_faces, pad_v_num, pad_f_num = self.compute_voxel_verts(
452
+ data_dict, self.noise_type,
453
+ self.noise_scale) # compute using smpl model
454
+ smpl_verts = projection(smpl_verts, data_dict['calib'])
455
+
456
+ smpl_verts *= 0.5
457
+
458
+ return {
459
+ 'voxel_verts': smpl_verts,
460
+ 'voxel_faces': smpl_faces,
461
+ 'pad_v_num': pad_v_num,
462
+ 'pad_f_num': pad_f_num
463
+ }
464
+
465
+ def get_sampling_geo(self, data_dict, is_valid=False, is_sdf=False):
466
+
467
+ mesh = data_dict['mesh']
468
+ calib = data_dict['calib']
469
+
470
+ # Samples are around the true surface with an offset
471
+ n_samples_surface = 4 * self.opt.num_sample_geo
472
+ vert_ids = np.arange(mesh.verts.shape[0])
473
+ thickness_sample_ratio = np.ones_like(vert_ids).astype(np.float32)
474
+
475
+ thickness_sample_ratio /= thickness_sample_ratio.sum()
476
+
477
+ samples_surface_ids = np.random.choice(vert_ids,
478
+ n_samples_surface,
479
+ replace=True,
480
+ p=thickness_sample_ratio)
481
+
482
+ samples_normal_ids = np.random.choice(vert_ids,
483
+ self.opt.num_sample_geo // 2,
484
+ replace=False,
485
+ p=thickness_sample_ratio)
486
+
487
+ surf_samples = mesh.verts[samples_normal_ids, :]
488
+ surf_normals = mesh.vert_normals[samples_normal_ids, :]
489
+
490
+ samples_surface = mesh.verts[samples_surface_ids, :]
491
+
492
+ # Sampling offsets are random noise with constant scale (15cm - 20cm)
493
+ offset = np.random.normal(scale=self.opt.sigma_geo,
494
+ size=(n_samples_surface, 1))
495
+ samples_surface += mesh.vert_normals[samples_surface_ids, :] * offset
496
+
497
+ # Uniform samples in [-1, 1]
498
+ calib_inv = np.linalg.inv(calib)
499
+ n_samples_space = self.opt.num_sample_geo // 4
500
+ samples_space_img = 2.0 * np.random.rand(n_samples_space, 3) - 1.0
501
+ samples_space = projection(samples_space_img, calib_inv)
502
+
503
+ # z-ray direction samples
504
+ if self.opt.zray_type and not is_valid:
505
+ n_samples_rayz = self.opt.ray_sample_num
506
+ samples_surface_cube = projection(samples_surface, calib)
507
+ samples_surface_cube_repeat = np.repeat(samples_surface_cube,
508
+ n_samples_rayz,
509
+ axis=0)
510
+
511
+ thickness_repeat = np.repeat(0.5 *
512
+ np.ones_like(samples_surface_ids),
513
+ n_samples_rayz,
514
+ axis=0)
515
+
516
+ noise_repeat = np.random.normal(scale=0.40,
517
+ size=(n_samples_surface *
518
+ n_samples_rayz, ))
519
+ samples_surface_cube_repeat[:,
520
+ -1] += thickness_repeat * noise_repeat
521
+ samples_surface_rayz = projection(samples_surface_cube_repeat,
522
+ calib_inv)
523
+
524
+ samples = np.concatenate(
525
+ [samples_surface, samples_space, samples_surface_rayz], 0)
526
+ else:
527
+ samples = np.concatenate([samples_surface, samples_space], 0)
528
+
529
+ np.random.shuffle(samples)
530
+
531
+ # labels: in->1.0; out->0.0.
532
+ if is_sdf:
533
+ sdfs = mesh.get_sdf(samples)
534
+ inside_samples = samples[sdfs < 0]
535
+ outside_samples = samples[sdfs >= 0]
536
+
537
+ inside_sdfs = sdfs[sdfs < 0]
538
+ outside_sdfs = sdfs[sdfs >= 0]
539
+ else:
540
+ inside = mesh.contains(samples)
541
+ inside_samples = samples[inside >= 0.5]
542
+ outside_samples = samples[inside < 0.5]
543
+
544
+ nin = inside_samples.shape[0]
545
+
546
+ if nin > self.opt.num_sample_geo // 2:
547
+ inside_samples = inside_samples[:self.opt.num_sample_geo // 2]
548
+ outside_samples = outside_samples[:self.opt.num_sample_geo // 2]
549
+ if is_sdf:
550
+ inside_sdfs = inside_sdfs[:self.opt.num_sample_geo // 2]
551
+ outside_sdfs = outside_sdfs[:self.opt.num_sample_geo // 2]
552
+ else:
553
+ outside_samples = outside_samples[:(self.opt.num_sample_geo - nin)]
554
+ if is_sdf:
555
+ outside_sdfs = outside_sdfs[:(self.opt.num_sample_geo - nin)]
556
+
557
+ if is_sdf:
558
+ samples = np.concatenate(
559
+ [inside_samples, outside_samples, surf_samples], 0)
560
+
561
+ labels = np.concatenate([
562
+ inside_sdfs, outside_sdfs, 0.0 * np.ones(surf_samples.shape[0])
563
+ ])
564
+
565
+ normals = np.zeros_like(samples)
566
+ normals[-self.opt.num_sample_geo // 2:, :] = surf_normals
567
+
568
+ # convert sdf from [-14, 130] to [0, 1]
569
+ # outside: 0, inside: 1
570
+ # Note: Marching cubes is defined on occupancy space (inside=1.0, outside=0.0)
571
+
572
+ labels = -labels.clip(min=-self.sdf_clip, max=self.sdf_clip)
573
+ labels += self.sdf_clip
574
+ labels /= (self.sdf_clip * 2)
575
+
576
+ else:
577
+ samples = np.concatenate([inside_samples, outside_samples])
578
+ labels = np.concatenate([
579
+ np.ones(inside_samples.shape[0]),
580
+ np.zeros(outside_samples.shape[0])
581
+ ])
582
+
583
+ normals = np.zeros_like(samples)
584
+
585
+ samples = torch.from_numpy(samples).float()
586
+ labels = torch.from_numpy(labels).float()
587
+ normals = torch.from_numpy(normals).float()
588
+
589
+ return {'samples_geo': samples, 'labels_geo': labels}
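
`get_sampling_geo` mixes near-surface queries (vertices jittered along their normals by Gaussian noise of scale `sigma_geo`) with uniform samples in the unit cube, then labels every point by containment in the watertight scan. A stripped-down illustration of that sampling scheme on a trimesh primitive, with all numbers chosen arbitrarily:

```python
import numpy as np
import trimesh

mesh = trimesh.creation.icosphere(subdivisions=3)     # stand-in for a watertight scan
num_sample_geo, sigma_geo = 4096, 0.05                # arbitrary values

# near-surface queries: jitter vertices along their normals
ids = np.random.choice(len(mesh.vertices), 4 * num_sample_geo, replace=True)
offset = np.random.normal(scale=sigma_geo, size=(4 * num_sample_geo, 1))
near_surface = mesh.vertices[ids] + mesh.vertex_normals[ids] * offset

# uniform queries inside the [-1, 1] cube
uniform = np.random.uniform(-1.0, 1.0, size=(num_sample_geo // 4, 3))

samples = np.concatenate([near_surface, uniform], axis=0)
labels = mesh.contains(samples).astype(np.float32)    # inside -> 1.0, outside -> 0.0
print(samples.shape, labels.mean())
```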
lib/dataset/TestDataset.py ADDED
@@ -0,0 +1,256 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import os
19
+
20
+ import lib.smplx as smplx
21
+ from lib.pymaf.utils.geometry import rotation_matrix_to_angle_axis, batch_rodrigues
22
+ from lib.pymaf.utils.imutils import process_image
23
+ from lib.pymaf.core import path_config
24
+ from lib.pymaf.models import pymaf_net
25
+ from lib.common.config import cfg
26
+ from lib.common.render import Render
27
+ from lib.dataset.body_model import TetraSMPLModel
28
+ from lib.dataset.mesh_util import get_visibility, SMPLX
29
+ import os.path as osp
30
+ import torch
31
+ import numpy as np
32
+ import random
33
+ import human_det
34
+ from termcolor import colored
35
+ from PIL import ImageFile
36
+ from huggingface_hub import cached_download
37
+
38
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
39
+
40
+
41
+ class TestDataset():
42
+ def __init__(self, cfg, device):
43
+
44
+ random.seed(1993)
45
+
46
+ self.image_path = cfg['image_path']
47
+ self.seg_dir = cfg['seg_dir']
48
+ self.has_det = cfg['has_det']
49
+ self.hps_type = cfg['hps_type']
50
+ self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
51
+ self.smpl_gender = 'neutral'
52
+
53
+ self.device = device
54
+
55
+ if self.has_det:
56
+ self.det = human_det.Detection()
57
+ else:
58
+ self.det = None
59
+
60
+
61
+ self.subject_list = [self.image_path]
62
+
63
+ # smpl related
64
+ self.smpl_data = SMPLX()
65
+
66
+ self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(
67
+ model_path=self.smpl_data.model_dir,
68
+ gender=smpl_gender,
69
+ model_type=smpl_type,
70
+ ext='npz')
71
+
72
+ # Load SMPL model
73
+ self.smpl_model = self.get_smpl_model(
74
+ self.smpl_type, self.smpl_gender).to(self.device)
75
+ self.faces = self.smpl_model.faces
76
+
77
+ self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS,
78
+ pretrained=True).to(self.device)
79
+ self.hps.load_state_dict(torch.load(
80
+ path_config.CHECKPOINT_FILE)['model'],
81
+ strict=True)
82
+ self.hps.eval()
83
+
84
+ print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))
85
+
86
+ self.render = Render(size=512, device=device)
87
+
88
+ def __len__(self):
89
+ return len(self.subject_list)
90
+
91
+ def compute_vis_cmap(self, smpl_verts, smpl_faces):
92
+
93
+ (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
94
+ smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
95
+ if self.smpl_type == 'smpl':
96
+ smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
97
+ else:
98
+ smplx_ind = np.arange(smpl_vis.shape[0])
99
+ smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)
100
+
101
+ return {
102
+ 'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
103
+ 'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
104
+ 'smpl_verts': smpl_verts.unsqueeze(0)
105
+ }
106
+
107
+ def compute_voxel_verts(self, body_pose, global_orient, betas, trans,
108
+ scale):
109
+
110
+ smpl_path = cached_download(osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl"), use_auth_token=os.environ['ICON'])
111
+ tetra_path = cached_download(osp.join(self.smpl_data.tedra_dir,
112
+ 'tetra_neutral_adult_smpl.npz'), use_auth_token=os.environ['ICON'])
113
+ smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')
114
+
115
+ pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
116
+ smpl_model.set_params(rotation_matrix_to_angle_axis(pose),
117
+ beta=betas[0])
118
+
119
+ verts = np.concatenate(
120
+ [smpl_model.verts, smpl_model.verts_added],
121
+ axis=0) * scale.item() + trans.detach().cpu().numpy()
122
+ faces = np.loadtxt(cached_download(osp.join(self.smpl_data.tedra_dir,
123
+ 'tetrahedrons_neutral_adult.txt'), use_auth_token=os.environ['ICON']),
124
+ dtype=np.int32) - 1
125
+
126
+ pad_v_num = int(8000 - verts.shape[0])
127
+ pad_f_num = int(25100 - faces.shape[0])
128
+
129
+ verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
130
+ mode='constant',
131
+ constant_values=0.0).astype(np.float32) * 0.5
132
+ faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
133
+ mode='constant',
134
+ constant_values=0.0).astype(np.int32)
135
+
136
+ verts[:, 2] *= -1.0
137
+
138
+ voxel_dict = {
139
+ 'voxel_verts':
140
+ torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
141
+ 'voxel_faces':
142
+ torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
143
+ 'pad_v_num':
144
+ torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
145
+ 'pad_f_num':
146
+ torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
147
+ }
148
+
149
+ return voxel_dict
150
+
151
+ def __getitem__(self, index):
152
+
153
+ img_path = self.subject_list[index]
154
+ img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
155
+
156
+ if self.seg_dir is None:
157
+ img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
158
+ img_path, self.det, self.hps_type, 512, self.device)
159
+
160
+ data_dict = {
161
+ 'name': img_name,
162
+ 'image': img_icon.to(self.device).unsqueeze(0),
163
+ 'ori_image': img_ori,
164
+ 'mask': img_mask,
165
+ 'uncrop_param': uncrop_param
166
+ }
167
+
168
+ else:
169
+ img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
170
+ img_path, self.det, self.hps_type, 512, self.device,
171
+ seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))
172
+ data_dict = {
173
+ 'name': img_name,
174
+ 'image': img_icon.to(self.device).unsqueeze(0),
175
+ 'ori_image': img_ori,
176
+ 'mask': img_mask,
177
+ 'uncrop_param': uncrop_param,
178
+ 'segmentations': segmentations
179
+ }
180
+
181
+ with torch.no_grad():
182
+ # import ipdb; ipdb.set_trace()
183
+ preds_dict = self.hps.forward(img_hps)
184
+
185
+ data_dict['smpl_faces'] = torch.Tensor(
186
+ self.faces.astype(np.int16)).long().unsqueeze(0).to(
187
+ self.device)
188
+
189
+ if self.hps_type == 'pymaf':
190
+ output = preds_dict['smpl_out'][-1]
191
+ scale, tranX, tranY = output['theta'][0, :3]
192
+ data_dict['betas'] = output['pred_shape']
193
+ data_dict['body_pose'] = output['rotmat'][:, 1:]
194
+ data_dict['global_orient'] = output['rotmat'][:, 0:1]
195
+ data_dict['smpl_verts'] = output['verts']
196
+
197
+ elif self.hps_type == 'pare':
198
+ data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
199
+ data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
200
+ data_dict['betas'] = preds_dict['pred_shape']
201
+ data_dict['smpl_verts'] = preds_dict['smpl_vertices']
202
+ scale, tranX, tranY = preds_dict['pred_cam'][0, :3]
203
+
204
+ elif self.hps_type == 'pixie':
205
+ data_dict.update(preds_dict)
206
+ data_dict['body_pose'] = preds_dict['body_pose']
207
+ data_dict['global_orient'] = preds_dict['global_pose']
208
+ data_dict['betas'] = preds_dict['shape']
209
+ data_dict['smpl_verts'] = preds_dict['vertices']
210
+ scale, tranX, tranY = preds_dict['cam'][0, :3]
211
+
212
+ elif self.hps_type == 'hybrik':
213
+ data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
214
+ data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
215
+ data_dict['betas'] = preds_dict['pred_shape']
216
+ data_dict['smpl_verts'] = preds_dict['pred_vertices']
217
+ scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
218
+ scale = scale * 2
219
+
220
+ elif self.hps_type == 'bev':
221
+ data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[
222
+ [0], :10].to(self.device).float()
223
+ pred_thetas = batch_rodrigues(torch.from_numpy(
224
+ preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
225
+ data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
226
+ data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
227
+ data_dict['smpl_verts'] = torch.from_numpy(
228
+ preds_dict['verts'][[0]]).to(self.device).float()
229
+ tranX = preds_dict['cam_trans'][0, 0]
230
+ tranY = preds_dict['cam'][0, 1] + 0.28
231
+ scale = preds_dict['cam'][0, 0] * 1.1
232
+
233
+ data_dict['scale'] = scale
234
+ data_dict['trans'] = torch.tensor(
235
+ [tranX, tranY, 0.0]).to(self.device).float()
236
+
237
+ # data_dict info (key-shape):
238
+ # scale, tranX, tranY - tensor.float
239
+ # betas - [1,10] / [1, 200]
240
+ # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
241
+ # global_orient - [1, 1, 3, 3]
242
+ # smpl_verts - [1, 6890, 3] / [1, 10475, 3]
243
+
244
+ return data_dict
245
+
246
+ def render_normal(self, verts, faces):
247
+
248
+ # render optimized mesh (normal, T_normal, image [-1,1])
249
+ self.render.load_meshes(verts, faces)
250
+ return self.render.get_rgb_image()
251
+
252
+ def render_depth(self, verts, faces):
253
+
254
+ # render optimized mesh (normal, T_normal, image [-1,1])
255
+ self.render.load_meshes(verts, faces)
256
+ return self.render.get_depth_map(cam_ids=[0, 2])
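
The comment block at the end of `__getitem__` documents the tensor shapes each HPS backend is normalized to. One plausible way to consume them is to turn the regressed parameters back into a posed mesh with the `smplx` model created above; in the sketch below, `dataset` and `data_dict` are assumed to be a `TestDataset` instance and one of its items, and the final placement line is only a common weak-perspective convention, not necessarily the exact transform used downstream.

```python
import torch

with torch.no_grad():
    smpl_out = dataset.smpl_model(
        betas=data_dict['betas'],                  # [1, 10]
        body_pose=data_dict['body_pose'],          # [1, 23, 3, 3] rotation matrices
        global_orient=data_dict['global_orient'],  # [1, 1, 3, 3]
        pose2rot=False,                            # inputs are already rotation matrices
    )

verts = smpl_out.vertices                          # [1, 6890, 3]
# weak-perspective placement into the crop (sign/flip conventions may differ)
verts = (verts + data_dict['trans']) * data_dict['scale']
```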
lib/dataset/__init__.py ADDED
File without changes
lib/dataset/body_model.py ADDED
@@ -0,0 +1,494 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import numpy as np
19
+ import pickle
20
+ import torch
21
+ import os
22
+
23
+
24
+ class SMPLModel():
25
+ def __init__(self, model_path, age):
26
+ """
27
+ SMPL model.
28
+
29
+ Parameter:
30
+ ---------
31
+ model_path: Path to the SMPL model parameters, pre-processed by
32
+ `preprocess.py`.
33
+
34
+ """
35
+ with open(model_path, 'rb') as f:
36
+ params = pickle.load(f, encoding='latin1')
37
+
38
+ self.J_regressor = params['J_regressor']
39
+ self.weights = np.asarray(params['weights'])
40
+ self.posedirs = np.asarray(params['posedirs'])
41
+ self.v_template = np.asarray(params['v_template'])
42
+ self.shapedirs = np.asarray(params['shapedirs'])
43
+ self.faces = np.asarray(params['f'])
44
+ self.kintree_table = np.asarray(params['kintree_table'])
45
+
46
+ self.pose_shape = [24, 3]
47
+ self.beta_shape = [10]
48
+ self.trans_shape = [3]
49
+
50
+ if age == 'kid':
51
+ v_template_smil = np.load(
52
+ os.path.join(os.path.dirname(model_path),
53
+ "smpl/smpl_kid_template.npy"))
54
+ v_template_smil -= np.mean(v_template_smil, axis=0)
55
+ v_template_diff = np.expand_dims(v_template_smil - self.v_template,
56
+ axis=2)
57
+ self.shapedirs = np.concatenate(
58
+ (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
59
+ axis=2)
60
+ self.beta_shape[0] += 1
61
+
62
+ id_to_col = {
63
+ self.kintree_table[1, i]: i
64
+ for i in range(self.kintree_table.shape[1])
65
+ }
66
+ self.parent = {
67
+ i: id_to_col[self.kintree_table[0, i]]
68
+ for i in range(1, self.kintree_table.shape[1])
69
+ }
70
+
71
+ self.pose = np.zeros(self.pose_shape)
72
+ self.beta = np.zeros(self.beta_shape)
73
+ self.trans = np.zeros(self.trans_shape)
74
+
75
+ self.verts = None
76
+ self.J = None
77
+ self.R = None
78
+ self.G = None
79
+
80
+ self.update()
81
+
82
+ def set_params(self, pose=None, beta=None, trans=None):
83
+ """
84
+ Set pose, shape, and/or translation parameters of SMPL model. Vertices of the
85
+ model will be updated and returned.
86
+
87
+ Parameters:
88
+ ---------
89
+ pose: Also known as 'theta', a [24,3] matrix indicating child joint rotation
90
+ relative to parent joint. For root joint it's global orientation.
91
+ Represented in an axis-angle format.
92
+
93
+ beta: Parameter for model shape. A vector of shape [10]. Coefficients for
94
+ PCA component. Only 10 components were released by MPI.
95
+
96
+ trans: Global translation of shape [3].
97
+
98
+ Return:
99
+ ------
100
+ Updated vertices.
101
+
102
+ """
103
+ if pose is not None:
104
+ self.pose = pose
105
+ if beta is not None:
106
+ self.beta = beta
107
+ if trans is not None:
108
+ self.trans = trans
109
+ self.update()
110
+ return self.verts
111
+
112
+ def update(self):
113
+ """
114
+ Called automatically when parameters are updated.
115
+
116
+ """
117
+ # how beta affect body shape
118
+ v_shaped = self.shapedirs.dot(self.beta) + self.v_template
119
+ # joints location
120
+ self.J = self.J_regressor.dot(v_shaped)
121
+ pose_cube = self.pose.reshape((-1, 1, 3))
122
+ # rotation matrix for each joint
123
+ self.R = self.rodrigues(pose_cube)
124
+ I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
125
+ (self.R.shape[0] - 1, 3, 3))
126
+ lrotmin = (self.R[1:] - I_cube).ravel()
127
+ # how pose affect body shape in zero pose
128
+ v_posed = v_shaped + self.posedirs.dot(lrotmin)
129
+ # world transformation of each joint
130
+ G = np.empty((self.kintree_table.shape[1], 4, 4))
131
+ G[0] = self.with_zeros(
132
+ np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
133
+ for i in range(1, self.kintree_table.shape[1]):
134
+ G[i] = G[self.parent[i]].dot(
135
+ self.with_zeros(
136
+ np.hstack([
137
+ self.R[i],
138
+ ((self.J[i, :] - self.J[self.parent[i], :]).reshape(
139
+ [3, 1]))
140
+ ])))
141
+ # remove the transformation due to the rest pose
142
+ G = G - self.pack(
143
+ np.matmul(
144
+ G,
145
+ np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
146
+ # transformation of each vertex
147
+ T = np.tensordot(self.weights, G, axes=[[1], [0]])
148
+ rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
149
+ v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
150
+ 4])[:, :3]
151
+ self.verts = v + self.trans.reshape([1, 3])
152
+ self.G = G
153
+
154
+ def rodrigues(self, r):
155
+ """
156
+ Rodrigues' rotation formula that turns axis-angle vector into rotation
157
+ matrix in a batched manner.
158
+
159
+ Parameter:
160
+ ----------
161
+ r: Axis-angle rotation vector of shape [batch_size, 1, 3].
162
+
163
+ Return:
164
+ -------
165
+ Rotation matrix of shape [batch_size, 3, 3].
166
+
167
+ """
168
+ theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
169
+ # avoid zero divide
170
+ theta = np.maximum(theta, np.finfo(np.float64).tiny)
171
+ r_hat = r / theta
172
+ cos = np.cos(theta)
173
+ z_stick = np.zeros(theta.shape[0])
174
+ m = np.dstack([
175
+ z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick,
176
+ -r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick
177
+ ]).reshape([-1, 3, 3])
178
+ i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
179
+ [theta.shape[0], 3, 3])
180
+ A = np.transpose(r_hat, axes=[0, 2, 1])
181
+ B = r_hat
182
+ dot = np.matmul(A, B)
183
+ R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m
184
+ return R
185
+
186
+ def with_zeros(self, x):
187
+ """
188
+ Append a [0, 0, 0, 1] vector to a [3, 4] matrix.
189
+
190
+ Parameter:
191
+ ---------
192
+ x: Matrix to be appended.
193
+
194
+ Return:
195
+ ------
196
+ Matrix after appending of shape [4,4]
197
+
198
+ """
199
+ return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]])))
200
+
201
+ def pack(self, x):
202
+ """
203
+ Append zero matrices of shape [4, 3] to vectors of [4, 1] shape in a batched
204
+ manner.
205
+
206
+ Parameter:
207
+ ----------
208
+ x: Matrices to be appended of shape [batch_size, 4, 1]
209
+
210
+ Return:
211
+ ------
212
+ Matrix of shape [batch_size, 4, 4] after appending.
213
+
214
+ """
215
+ return np.dstack((np.zeros((x.shape[0], 4, 3)), x))
216
+
217
+ def save_to_obj(self, path):
218
+ """
219
+ Save the SMPL model into .obj file.
220
+
221
+ Parameter:
222
+ ---------
223
+ path: Path to save.
224
+
225
+ """
226
+ with open(path, 'w') as fp:
227
+ for v in self.verts:
228
+ fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
229
+ for f in self.faces + 1:
230
+ fp.write('f %d %d %d\n' % (f[0], f[1], f[2]))
231
+
232
+
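
The `rodrigues` helper above is the standard axis-angle to rotation-matrix conversion, R = cos(theta) I + (1 - cos(theta)) r r^T + sin(theta) [r]_x, applied per joint. A tiny standalone check of that formula against an obvious case (90 degrees about the z-axis), independent of the class:

```python
import numpy as np

def rodrigues(r):
    """Axis-angle vectors [B, 1, 3] -> rotation matrices [B, 3, 3]."""
    theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
    theta = np.maximum(theta, np.finfo(np.float64).tiny)   # avoid division by zero
    r_hat = r / theta
    cos, sin = np.cos(theta), np.sin(theta)
    z = np.zeros(r.shape[0])
    # skew-symmetric cross-product matrix [r_hat]_x
    m = np.dstack([z, -r_hat[:, 0, 2], r_hat[:, 0, 1],
                   r_hat[:, 0, 2], z, -r_hat[:, 0, 0],
                   -r_hat[:, 0, 1], r_hat[:, 0, 0], z]).reshape(-1, 3, 3)
    eye = np.broadcast_to(np.eye(3), (r.shape[0], 3, 3))
    outer = np.matmul(np.transpose(r_hat, (0, 2, 1)), r_hat)
    return cos * eye + (1 - cos) * outer + sin * m

# 90 degrees about z should send the x-axis to the y-axis
R = rodrigues(np.array([[[0.0, 0.0, np.pi / 2]]]))
print(np.round(R[0] @ np.array([1.0, 0.0, 0.0]), 6))   # -> [0. 1. 0.]
```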
233
+ class TetraSMPLModel():
234
+ def __init__(self,
235
+ model_path,
236
+ model_addition_path,
237
+ age='adult',
238
+ v_template=None):
239
+ """
240
+ SMPL model.
241
+
242
+ Parameter:
243
+ ---------
244
+ model_path: Path to the SMPL model parameters, pre-processed by
245
+ `preprocess.py`.
246
+
247
+ """
248
+ with open(model_path, 'rb') as f:
249
+ params = pickle.load(f, encoding='latin1')
250
+
251
+ self.J_regressor = params['J_regressor']
252
+ self.weights = np.asarray(params['weights'])
253
+ self.posedirs = np.asarray(params['posedirs'])
254
+
255
+ if v_template is not None:
256
+ self.v_template = v_template
257
+ else:
258
+ self.v_template = np.asarray(params['v_template'])
259
+
260
+ self.shapedirs = np.asarray(params['shapedirs'])
261
+ self.faces = np.asarray(params['f'])
262
+ self.kintree_table = np.asarray(params['kintree_table'])
263
+
264
+ params_added = np.load(model_addition_path)
265
+ self.v_template_added = params_added['v_template_added']
266
+ self.weights_added = params_added['weights_added']
267
+ self.shapedirs_added = params_added['shapedirs_added']
268
+ self.posedirs_added = params_added['posedirs_added']
269
+ self.tetrahedrons = params_added['tetrahedrons']
270
+
271
+ id_to_col = {
272
+ self.kintree_table[1, i]: i
273
+ for i in range(self.kintree_table.shape[1])
274
+ }
275
+ self.parent = {
276
+ i: id_to_col[self.kintree_table[0, i]]
277
+ for i in range(1, self.kintree_table.shape[1])
278
+ }
279
+
280
+ self.pose_shape = [24, 3]
281
+ self.beta_shape = [10]
282
+ self.trans_shape = [3]
283
+
284
+ if age == 'kid':
285
+ v_template_smil = np.load(
286
+ os.path.join(os.path.dirname(model_path),
287
+ "smpl/smpl_kid_template.npy"))
288
+ v_template_smil -= np.mean(v_template_smil, axis=0)
289
+ v_template_diff = np.expand_dims(v_template_smil - self.v_template,
290
+ axis=2)
291
+ self.shapedirs = np.concatenate(
292
+ (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
293
+ axis=2)
294
+ self.beta_shape[0] += 1
295
+
296
+ self.pose = np.zeros(self.pose_shape)
297
+ self.beta = np.zeros(self.beta_shape)
298
+ self.trans = np.zeros(self.trans_shape)
299
+
300
+ self.verts = None
301
+ self.verts_added = None
302
+ self.J = None
303
+ self.R = None
304
+ self.G = None
305
+
306
+ self.update()
307
+
308
+ def set_params(self, pose=None, beta=None, trans=None):
309
+ """
310
+ Set pose, shape, and/or translation parameters of SMPL model. Vertices of the
311
+ model will be updated and returned.
312
+
313
+ Parameters:
314
+ ---------
315
+ pose: Also known as 'theta', a [24,3] matrix indicating child joint rotation
316
+ relative to parent joint. For root joint it's global orientation.
317
+ Represented in an axis-angle format.
318
+
319
+ beta: Parameter for model shape. A vector of shape [10]. Coefficients for
320
+ PCA component. Only 10 components were released by MPI.
321
+
322
+ trans: Global translation of shape [3].
323
+
324
+ Return:
325
+ ------
326
+ Updated vertices.
327
+
328
+ """
329
+
330
+ if torch.is_tensor(pose):
331
+ pose = pose.detach().cpu().numpy()
332
+ if torch.is_tensor(beta):
333
+ beta = beta.detach().cpu().numpy()
334
+
335
+ if pose is not None:
336
+ self.pose = pose
337
+ if beta is not None:
338
+ self.beta = beta
339
+ if trans is not None:
340
+ self.trans = trans
341
+ self.update()
342
+ return self.verts
343
+
344
+ def update(self):
345
+ """
346
+ Called automatically when parameters are updated.
347
+
348
+ """
349
+ # how beta affect body shape
350
+ v_shaped = self.shapedirs.dot(self.beta) + self.v_template
351
+ v_shaped_added = self.shapedirs_added.dot(
352
+ self.beta) + self.v_template_added
353
+ # joints location
354
+ self.J = self.J_regressor.dot(v_shaped)
355
+ pose_cube = self.pose.reshape((-1, 1, 3))
356
+ # rotation matrix for each joint
357
+ self.R = self.rodrigues(pose_cube)
358
+ I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
359
+ (self.R.shape[0] - 1, 3, 3))
360
+ lrotmin = (self.R[1:] - I_cube).ravel()
361
+ # how pose affect body shape in zero pose
362
+ v_posed = v_shaped + self.posedirs.dot(lrotmin)
363
+ v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin)
364
+ # world transformation of each joint
365
+ G = np.empty((self.kintree_table.shape[1], 4, 4))
366
+ G[0] = self.with_zeros(
367
+ np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
368
+ for i in range(1, self.kintree_table.shape[1]):
369
+ G[i] = G[self.parent[i]].dot(
370
+ self.with_zeros(
371
+ np.hstack([
372
+ self.R[i],
373
+ ((self.J[i, :] - self.J[self.parent[i], :]).reshape(
374
+ [3, 1]))
375
+ ])))
376
+ # remove the transformation due to the rest pose
377
+ G = G - self.pack(
378
+ np.matmul(
379
+ G,
380
+ np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
381
+ self.G = G
382
+ # transformation of each vertex
383
+ T = np.tensordot(self.weights, G, axes=[[1], [0]])
384
+ rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
385
+ v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1,
386
+ 4])[:, :3]
387
+ self.verts = v + self.trans.reshape([1, 3])
388
+ T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]])
389
+ rest_shape_added_h = np.hstack(
390
+ (v_posed_added, np.ones([v_posed_added.shape[0], 1])))
391
+ v_added = np.matmul(T_added,
392
+ rest_shape_added_h.reshape([-1, 4,
393
+ 1])).reshape([-1, 4
394
+ ])[:, :3]
395
+ self.verts_added = v_added + self.trans.reshape([1, 3])
396
+
397
+ def rodrigues(self, r):
398
+ """
399
+ Rodrigues' rotation formula that turns axis-angle vector into rotation
400
+ matrix in a batched manner.
401
+
402
+ Parameter:
403
+ ----------
404
+ r: Axis-angle rotation vector of shape [batch_size, 1, 3].
405
+
406
+ Return:
407
+ -------
408
+ Rotation matrix of shape [batch_size, 3, 3].
409
+
410
+ """
411
+ theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
412
+ # avoid zero divide
413
+ theta = np.maximum(theta, np.finfo(np.float64).tiny)
414
+ r_hat = r / theta
415
+ cos = np.cos(theta)
416
+ z_stick = np.zeros(theta.shape[0])
417
+ m = np.dstack([
418
+ z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick,
419
+ -r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick
420
+ ]).reshape([-1, 3, 3])
421
+ i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
422
+ [theta.shape[0], 3, 3])
423
+ A = np.transpose(r_hat, axes=[0, 2, 1])
424
+ B = r_hat
425
+ dot = np.matmul(A, B)
426
+ R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m
427
+ return R
428
+
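The batched Rodrigues conversion above can be sanity-checked against an independent axis-angle-to-matrix routine. The sketch below is illustrative only and not part of this commit; scipy is an assumed dependency, and `model` stands for an instance of the class defined here.

# Illustrative check: the batched Rodrigues formula should agree with
# scipy's axis-angle -> rotation-matrix conversion.
import numpy as np
from scipy.spatial.transform import Rotation

r = 0.3 * np.random.randn(5, 1, 3)                    # five axis-angle vectors, shape [5, 1, 3]
R_ref = Rotation.from_rotvec(r[:, 0, :]).as_matrix()  # reference rotation matrices, [5, 3, 3]
# np.testing.assert_allclose(model.rodrigues(r), R_ref, atol=1e-6)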
429
+ def with_zeros(self, x):
430
+ """
431
+ Append a [0, 0, 0, 1] vector to a [3, 4] matrix.
432
+
433
+ Parameter:
434
+ ---------
435
+ x: Matrix to be appended.
436
+
437
+ Return:
438
+ ------
439
+ Matrix of shape [4, 4] after appending.
440
+
441
+ """
442
+ return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]])))
443
+
444
+ def pack(self, x):
445
+ """
446
+ Append zero matrices of shape [4, 3] to vectors of [4, 1] shape in a batched
447
+ manner.
448
+
449
+ Parameter:
450
+ ----------
451
+ x: Matrices to be appended of shape [batch_size, 4, 1]
452
+
453
+ Return:
454
+ ------
455
+ Matrix of shape [batch_size, 4, 4] after appending.
456
+
457
+ """
458
+ return np.dstack((np.zeros((x.shape[0], 4, 3)), x))
459
+
460
+ def save_mesh_to_obj(self, path):
461
+ """
462
+ Save the SMPL model into .obj file.
463
+
464
+ Parameter:
465
+ ---------
466
+ path: Path to save.
467
+
468
+ """
469
+ with open(path, 'w') as fp:
470
+ for v in self.verts:
471
+ fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
472
+ for f in self.faces + 1:
473
+ fp.write('f %d %d %d\n' % (f[0], f[1], f[2]))
474
+
475
+ def save_tetrahedron_to_obj(self, path):
476
+ """
477
+ Save the tetrahedron SMPL model into .obj file.
478
+
479
+ Parameter:
480
+ ---------
481
+ path: Path to save.
482
+
483
+ """
484
+
485
+ with open(path, 'w') as fp:
486
+ for v in self.verts:
487
+ fp.write('v %f %f %f 1 0 0\n' % (v[0], v[1], v[2]))
488
+ for va in self.verts_added:
489
+ fp.write('v %f %f %f 0 0 1\n' % (va[0], va[1], va[2]))
490
+ for t in self.tetrahedrons + 1:
491
+ fp.write('f %d %d %d\n' % (t[0], t[2], t[1]))
492
+ fp.write('f %d %d %d\n' % (t[0], t[3], t[2]))
493
+ fp.write('f %d %d %d\n' % (t[0], t[1], t[3]))
494
+ fp.write('f %d %d %d\n' % (t[1], t[2], t[3]))
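For orientation, a usage sketch of the class above. The class name, file paths, and keyword values are placeholders inferred from the fields read in __init__ (the SMPL params pickle plus the tetrahedron additions loaded from model_addition_path); adapt them to locally available SMPL assets.

import numpy as np

# model = TetraSMPLModel('smpl/SMPL_NEUTRAL.pkl',               # hypothetical paths / class name
#                        'tedra_data/tetra_neutral_adult_smpl.npz')
pose = np.zeros((24, 3))   # rest pose: one axis-angle rotation per joint
beta = np.zeros(10)        # mean shape
trans = np.zeros(3)
# verts = model.set_params(pose=pose, beta=beta, trans=trans)
# model.save_mesh_to_obj('smpl_rest.obj')            # surface mesh
# model.save_tetrahedron_to_obj('smpl_tetra.obj')    # tetrahedral mesh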
lib/dataset/hoppeMesh.py ADDED
@@ -0,0 +1,116 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import numpy as np
19
+ from scipy.spatial import cKDTree
20
+ import trimesh
21
+
22
+ import logging
23
+
24
+ logging.getLogger("trimesh").setLevel(logging.ERROR)
25
+
26
+
27
+ def save_obj_mesh(mesh_path, verts, faces):
28
+ file = open(mesh_path, 'w')
29
+ for v in verts:
30
+ file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
31
+ for f in faces:
32
+ f_plus = f + 1
33
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
34
+ file.close()
35
+
36
+
37
+ def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
38
+ file = open(mesh_path, 'w')
39
+
40
+ for idx, v in enumerate(verts):
41
+ c = colors[idx]
42
+ file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
43
+ (v[0], v[1], v[2], c[0], c[1], c[2]))
44
+ for f in faces:
45
+ f_plus = f + 1
46
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
47
+ file.close()
48
+
49
+
50
+ def save_ply(mesh_path, points, rgb):
51
+ '''
52
+ Save the visualization of sampling to a ply file.
53
+ Red points represent positive predictions.
54
+ Green points represent negative predictions.
55
+ :param mesh_path: File name to save
56
+ :param points: [N, 3] array of points
57
+ :param rgb: [N, 3] array of rgb values in the range [0~1]
58
+ :return:
59
+ '''
60
+ to_save = np.concatenate([points, rgb * 255], axis=-1)
61
+ return np.savetxt(
62
+ mesh_path,
63
+ to_save,
64
+ fmt='%.6f %.6f %.6f %d %d %d',
65
+ comments='',
66
+ header=(
67
+ 'ply\nformat ascii 1.0\nelement vertex {:d}\n' +
68
+ 'property float x\nproperty float y\nproperty float z\n' +
69
+ 'property uchar red\nproperty uchar green\nproperty uchar blue\n' +
70
+ 'end_header').format(points.shape[0]))
71
+
72
+
73
+ class HoppeMesh:
74
+ def __init__(self, verts, faces, vert_normals, face_normals):
75
+ '''
76
+ The HoppeSDF calculates signed distance towards a predefined oriented point cloud
77
+ http://hhoppe.com/recon.pdf
78
+ For clean and high-resolution point-cloud data, this is the fastest and most accurate approximation of the SDF
79
+ :param verts: [n, 3] vertex positions of the oriented point cloud
80
+ :param vert_normals: [n, 3] per-vertex normals used to orient the signed distance
81
+ '''
82
+ self.verts = verts # [n, 3]
83
+ self.faces = faces # [m, 3]
84
+ self.vert_normals = vert_normals # [n, 3]
85
+ self.face_normals = face_normals # [m, 3]
86
+
87
+ self.kd_tree = cKDTree(self.verts)
88
+ self.len = len(self.verts)
+ self.colors = None  # optional per-vertex colors; export() falls back to no color, export_ply() requires them
89
+
90
+ def query(self, points):
91
+ dists, idx = self.kd_tree.query(points, n_jobs=1)
92
+ # FIXME: because the eyebrows are removed, the cKDTree around the eyebrows
93
+ # is not accurate, which causes a few false-inside labels here.
94
+ dirs = points - self.verts[idx]
95
+ signs = (dirs * self.vert_normals[idx]).sum(axis=1)
96
+ signs = (signs > 0) * 2 - 1
97
+ return signs * dists
98
+
99
+ def contains(self, points):
100
+
101
+ labels = trimesh.Trimesh(vertices=self.verts,
102
+ faces=self.faces).contains(points)
103
+ return labels
104
+
105
+ def export(self, path):
106
+ if self.colors is not None:
107
+ save_obj_mesh_with_color(path, self.verts, self.faces,
108
+ self.colors[:, 0:3] / 255.0)
109
+ else:
110
+ save_obj_mesh(path, self.verts, self.faces)
111
+
112
+ def export_ply(self, path):
113
+ save_ply(path, self.verts, self.colors[:, 0:3] / 255.0)
114
+
115
+ def triangles(self):
116
+ return self.verts[self.faces] # [n, 3, 3]
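A minimal, self-contained sketch of driving the class above with a trimesh primitive (illustrative only; note that query() relies on the older SciPy n_jobs keyword of cKDTree.query, and the sign of the returned distances follows the vertex-normal orientation):

import numpy as np
import trimesh

box = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
hoppe = HoppeMesh(verts=np.asarray(box.vertices),
                  faces=np.asarray(box.faces),
                  vert_normals=np.asarray(box.vertex_normals),
                  face_normals=np.asarray(box.face_normals))
pts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
print(hoppe.query(pts))      # approximate signed distances (Hoppe SDF)
print(hoppe.contains(pts))   # exact inside/outside labels via trimesh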
lib/dataset/mesh_util.py ADDED
@@ -0,0 +1,894 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import numpy as np
19
+ import cv2
20
+ import pymeshlab
21
+ import torch
22
+ import torchvision
23
+ import trimesh
24
+ from pytorch3d.io import load_obj
25
+ from termcolor import colored
26
+ from scipy.spatial import cKDTree
27
+
28
+ from pytorch3d.structures import Meshes
29
+ import torch.nn.functional as F
30
+
31
+ import os
32
+ from lib.pymaf.utils.imutils import uncrop
33
+ from lib.common.render_utils import Pytorch3dRasterizer, face_vertices
34
+
35
+ from pytorch3d.renderer.mesh import rasterize_meshes
36
+ from PIL import Image, ImageFont, ImageDraw
37
+ from kaolin.ops.mesh import check_sign
38
+ from kaolin.metrics.trianglemesh import point_to_mesh_distance
39
+
40
+ from pytorch3d.loss import (
41
+ mesh_laplacian_smoothing,
42
+ mesh_normal_consistency
43
+ )
44
+
45
+ from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
46
+
47
+
48
+ def tensor2variable(tensor, device):
49
+ # [1,23,3,3]
50
+ return torch.tensor(tensor, device=device, requires_grad=True)
51
+
52
+
53
+ def normal_loss(vec1, vec2):
54
+
55
+ # vec1_mask = vec1.sum(dim=1) != 0.0
56
+ # vec2_mask = vec2.sum(dim=1) != 0.0
57
+ # union_mask = vec1_mask * vec2_mask
58
+ vec_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-6)(vec1, vec2)
59
+ # vec_diff = ((vec_sim-1.0)**2)[union_mask].mean()
60
+ vec_diff = ((vec_sim-1.0)**2).mean()
61
+
62
+ return vec_diff
63
+
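A quick illustration of the cosine-based penalty above: identical normal maps give a loss near zero, while flipped normals hit the maximum of (-1 - 1)**2 = 4 (sketch only, assuming normal_loss from this module is importable):

import torch
import torch.nn.functional as F

n = F.normalize(torch.randn(2, 3, 64, 64), dim=1)   # unit normals, [B, 3, H, W]
print(normal_loss(n, n))    # ~0.0
print(normal_loss(n, -n))   # ~4.0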
64
+
65
+ class GMoF(torch.nn.Module):
66
+ def __init__(self, rho=1):
67
+ super(GMoF, self).__init__()
68
+ self.rho = rho
69
+
70
+ def extra_repr(self):
71
+ return 'rho = {}'.format(self.rho)
72
+
73
+ def forward(self, residual):
74
+ dist = torch.div(residual, residual + self.rho ** 2)
75
+ return self.rho ** 2 * dist
76
+
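GMoF is the Geman-McClure robust penalty: small (already squared) residuals pass through roughly unchanged, while outliers saturate towards rho**2. A small illustration (printed values are approximate):

import torch

gmof = GMoF(rho=100.0)                              # rho**2 = 10000
residuals = torch.tensor([1.0, 100.0, 1e4, 1e8])    # squared residuals
print(gmof(residuals))                              # ~[1.0, 99.0, 5000.0, 10000.0]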
77
+
78
+ def mesh_edge_loss(meshes, target_length: float = 0.0):
79
+ """
80
+ Computes mesh edge length regularization loss averaged across all meshes
81
+ in a batch. Each mesh contributes equally to the final loss, regardless of
82
+ the number of edges per mesh in the batch by weighting each mesh with the
83
+ inverse number of edges. For example, if mesh 3 (out of N) has only E=4
84
+ edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to
85
+ contribute to the final loss.
86
+
87
+ Args:
88
+ meshes: Meshes object with a batch of meshes.
89
+ target_length: Resting value for the edge length.
90
+
91
+ Returns:
92
+ loss: Average loss across the batch. Returns 0 if meshes contains
93
+ no meshes or all empty meshes.
94
+ """
95
+ if meshes.isempty():
96
+ return torch.tensor(
97
+ [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
98
+ )
99
+
100
+ N = len(meshes)
101
+ edges_packed = meshes.edges_packed() # (sum(E_n), 3)
102
+ verts_packed = meshes.verts_packed() # (sum(V_n), 3)
103
+ edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
104
+ num_edges_per_mesh = meshes.num_edges_per_mesh() # N
105
+
106
+ # Determine the weight for each edge based on the number of edges in the
107
+ # mesh it corresponds to.
108
+ # TODO (nikhilar) Find a faster way of computing the weights for each edge
109
+ # as this is currently a bottleneck for meshes with a large number of faces.
110
+ weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
111
+ weights = 1.0 / weights.float()
112
+
113
+ verts_edges = verts_packed[edges_packed]
114
+ v0, v1 = verts_edges.unbind(1)
115
+ loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
116
+ loss_vertex = loss * weights
117
+ # loss_outlier = torch.topk(loss, 100)[0].mean()
118
+ # loss_all = (loss_vertex.sum() + loss_outlier.mean()) / N
119
+ loss_all = loss_vertex.sum() / N
120
+
121
+ return loss_all
122
+
123
+
124
+ def remesh(obj_path, perc, device):
125
+
126
+ ms = pymeshlab.MeshSet()
127
+ ms.load_new_mesh(obj_path)
128
+ ms.laplacian_smooth()
129
+ ms.remeshing_isotropic_explicit_remeshing(
130
+ targetlen=pymeshlab.Percentage(perc), adaptive=True)
131
+ ms.save_current_mesh(obj_path.replace("recon", "remesh"))
132
+ polished_mesh = trimesh.load_mesh(obj_path.replace("recon", "remesh"))
133
+ verts_pr = torch.tensor(polished_mesh.vertices).float().unsqueeze(0).to(device)
134
+ faces_pr = torch.tensor(polished_mesh.faces).long().unsqueeze(0).to(device)
135
+
136
+ return verts_pr, faces_pr
137
+
138
+
139
+ def possion(mesh, obj_path):
140
+
141
+ mesh.export(obj_path)
142
+ ms = pymeshlab.MeshSet()
143
+ ms.load_new_mesh(obj_path)
144
+ ms.surface_reconstruction_screened_poisson(depth=10)
145
+ ms.set_current_mesh(1)
146
+ ms.save_current_mesh(obj_path)
147
+
148
+ return trimesh.load(obj_path)
149
+
150
+
151
+ def get_mask(tensor, dim):
152
+
153
+ mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0
154
+ mask = mask.type_as(tensor)
155
+
156
+ return mask
157
+
158
+
159
+ def blend_rgb_norm(rgb, norm, mask):
160
+
161
+ # [0,0,0] or [127,127,127] should be marked as mask
162
+ final = rgb * (1-mask) + norm * (mask)
163
+
164
+ return final.astype(np.uint8)
165
+
166
+
167
+ def unwrap(image, data):
168
+
169
+ img_uncrop = uncrop(np.array(Image.fromarray(image).resize(data['uncrop_param']['box_shape'][:2])),
170
+ data['uncrop_param']['center'],
171
+ data['uncrop_param']['scale'],
172
+ data['uncrop_param']['crop_shape'])
173
+
174
+ img_orig = cv2.warpAffine(img_uncrop,
175
+ np.linalg.inv(data['uncrop_param']['M'])[:2, :],
176
+ data['uncrop_param']['ori_shape'][::-1][1:],
177
+ flags=cv2.INTER_CUBIC)
178
+
179
+ return img_orig
180
+
181
+
182
+ # Losses to smooth / regularize the mesh shape
183
+ def update_mesh_shape_prior_losses(mesh, losses):
184
+
185
+ # and (b) the edge length of the predicted mesh
186
+ losses["edge"]['value'] = mesh_edge_loss(mesh)
187
+ # mesh normal consistency
188
+ losses["nc"]['value'] = mesh_normal_consistency(mesh)
189
+ # mesh laplacian smoothing
190
+ losses["laplacian"]['value'] = mesh_laplacian_smoothing(
191
+ mesh, method="uniform")
192
+
193
+
194
+ def rename(old_dict, old_name, new_name):
195
+ new_dict = {}
196
+ for key, value in zip(old_dict.keys(), old_dict.values()):
197
+ new_key = key if key != old_name else new_name
198
+ new_dict[new_key] = old_dict[key]
199
+ return new_dict
200
+
201
+
202
+ def load_checkpoint(model, cfg):
203
+
204
+ model_dict = model.state_dict()
205
+ main_dict = {}
206
+ normal_dict = {}
207
+
208
+ device = torch.device(f"cuda:{cfg['test_gpus'][0]}")
209
+
210
+ main_dict = torch.load(cached_download(cfg.resume_path, use_auth_token=os.environ['ICON']),
211
+ map_location=device)['state_dict']
212
+
213
+ main_dict = {
214
+ k: v
215
+ for k, v in main_dict.items()
216
+ if k in model_dict and v.shape == model_dict[k].shape and (
217
+ 'reconEngine' not in k) and ("normal_filter" not in k) and (
218
+ 'voxelization' not in k)
219
+ }
220
+ print(colored(f"Resume MLP weights from {cfg.resume_path}", 'green'))
221
+
222
+ normal_dict = torch.load(cached_download(cfg.normal_path, use_auth_token=os.environ['ICON']),
223
+ map_location=device)['state_dict']
224
+
225
+ for key in normal_dict.keys():
226
+ normal_dict = rename(normal_dict, key,
227
+ key.replace("netG", "netG.normal_filter"))
228
+
229
+ normal_dict = {
230
+ k: v
231
+ for k, v in normal_dict.items()
232
+ if k in model_dict and v.shape == model_dict[k].shape
233
+ }
234
+ print(colored(f"Resume normal model from {cfg.normal_path}", 'green'))
235
+
236
+ model_dict.update(main_dict)
237
+ model_dict.update(normal_dict)
238
+ model.load_state_dict(model_dict)
239
+
240
+ model.netG = model.netG.to(device)
241
+ model.reconEngine = model.reconEngine.to(device)
242
+
243
+ model.netG.training = False
244
+ model.netG.eval()
245
+
246
+ del main_dict
247
+ del normal_dict
248
+ del model_dict
249
+
250
+ return model
251
+
252
+
253
+ def read_smpl_constants(folder):
254
+ """Load smpl vertex code"""
255
+ smpl_vtx_std = np.loadtxt(cached_download(os.path.join(folder, 'vertices.txt'), use_auth_token=os.environ['ICON']))
256
+ min_x = np.min(smpl_vtx_std[:, 0])
257
+ max_x = np.max(smpl_vtx_std[:, 0])
258
+ min_y = np.min(smpl_vtx_std[:, 1])
259
+ max_y = np.max(smpl_vtx_std[:, 1])
260
+ min_z = np.min(smpl_vtx_std[:, 2])
261
+ max_z = np.max(smpl_vtx_std[:, 2])
262
+
263
+ smpl_vtx_std[:, 0] = (smpl_vtx_std[:, 0] - min_x) / (max_x - min_x)
264
+ smpl_vtx_std[:, 1] = (smpl_vtx_std[:, 1] - min_y) / (max_y - min_y)
265
+ smpl_vtx_std[:, 2] = (smpl_vtx_std[:, 2] - min_z) / (max_z - min_z)
266
+ smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
267
+ """Load smpl faces & tetrahedrons"""
268
+ smpl_faces = np.loadtxt(cached_download(os.path.join(folder, 'faces.txt'), use_auth_token=os.environ['ICON']),
269
+ dtype=np.int32) - 1
270
+ smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] +
271
+ smpl_vertex_code[smpl_faces[:, 1]] +
272
+ smpl_vertex_code[smpl_faces[:, 2]]) / 3.0
273
+ smpl_tetras = np.loadtxt(cached_download(os.path.join(folder, 'tetrahedrons.txt'), use_auth_token=os.environ['ICON']),
274
+ dtype=np.int32) - 1
275
+
276
+ return smpl_vertex_code, smpl_face_code, smpl_faces, smpl_tetras
277
+
278
+
279
+ def feat_select(feat, select):
280
+
281
+ # feat [B, featx2, N]
282
+ # select [B, 1, N]
283
+ # return [B, feat, N]
284
+
285
+ dim = feat.shape[1] // 2
286
+ idx = torch.tile((1-select), (1, dim, 1))*dim + \
287
+ torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select)
288
+ feat_select = torch.gather(feat, 1, idx.long())
289
+
290
+ return feat_select
291
+
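The gather above keeps the first half of the channels wherever select is 1 and the second half wherever it is 0 (e.g. choosing between two feature banks per query point). A tiny worked example:

import torch

feat = torch.cat([torch.zeros(1, 2, 4), torch.ones(1, 2, 4)], dim=1)  # [1, 4, 4]
select = torch.tensor([[[1., 1., 0., 0.]]])                           # [1, 1, 4]
print(feat_select(feat, select)[0])
# tensor([[0., 0., 1., 1.],
#         [0., 0., 1., 1.]])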
292
+
293
+ def get_visibility(xy, z, faces):
294
+ """get the visibility of vertices
295
+
296
+ Args:
297
+ xy (torch.tensor): [N,2]
298
+ z (torch.tensor): [N,1]
299
+ faces (torch.tensor): [N,3]
300
+ (the mesh is rasterized internally at a fixed image size of 2**12)
301
+ """
302
+
303
+ xyz = torch.cat((xy, -z), dim=1)
304
+ xyz = (xyz + 1.0) / 2.0
305
+ faces = faces.long()
306
+
307
+ rasterizer = Pytorch3dRasterizer(image_size=2**12)
308
+ meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
309
+ raster_settings = rasterizer.raster_settings
310
+
311
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
312
+ meshes_screen,
313
+ image_size=raster_settings.image_size,
314
+ blur_radius=raster_settings.blur_radius,
315
+ faces_per_pixel=raster_settings.faces_per_pixel,
316
+ bin_size=raster_settings.bin_size,
317
+ max_faces_per_bin=raster_settings.max_faces_per_bin,
318
+ perspective_correct=raster_settings.perspective_correct,
319
+ cull_backfaces=raster_settings.cull_backfaces,
320
+ )
321
+
322
+ vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
323
+ vis_mask = torch.zeros(size=(z.shape[0], 1))
324
+ vis_mask[vis_vertices_id] = 1.0
325
+
326
+ # print("------------------------\n")
327
+ # print(f"keep points : {vis_mask.sum()/len(vis_mask)}")
328
+
329
+ return vis_mask
330
+
331
+
332
+ def barycentric_coordinates_of_projection(points, vertices):
333
+ ''' https://github.com/MPI-IS/mesh/blob/master/mesh/geometry/barycentric_coordinates_of_projection.py
334
+ '''
335
+ """Given a point, gives projected coords of that point to a triangle
336
+ in barycentric coordinates.
337
+ See
338
+ **Heidrich**, Computing the Barycentric Coordinates of a Projected Point, JGT 05
339
+ at http://www.cs.ubc.ca/~heidrich/Papers/JGT.05.pdf
340
+
341
+ :param p: point to project. [B, 3]
342
+ :param v0: first vertex of triangles. [B, 3]
343
+ :returns: barycentric coordinates of ``p``'s projection in triangle defined by ``q``, ``u``, ``v``
344
+ vectorized so ``p``, ``q``, ``u``, ``v`` can all be ``3xN``
345
+ """
346
+ #(p, q, u, v)
347
+ v0, v1, v2 = vertices[:, 0], vertices[:, 1], vertices[:, 2]
348
+ p = points
349
+
350
+ q = v0
351
+ u = v1 - v0
352
+ v = v2 - v0
353
+ n = torch.cross(u, v)
354
+ s = torch.sum(n * n, dim=1)
355
+ # If the triangle edges are collinear, cross-product is zero,
356
+ # which makes "s" 0, which gives us divide by zero. So we
357
+ # make the arbitrary choice to set s to epsv (=numpy.spacing(1)),
358
+ # the closest thing to zero
359
+ s[s == 0] = 1e-6
360
+ oneOver4ASquared = 1.0 / s
361
+ w = p - q
362
+ b2 = torch.sum(torch.cross(u, w) * n, dim=1) * oneOver4ASquared
363
+ b1 = torch.sum(torch.cross(w, v) * n, dim=1) * oneOver4ASquared
364
+ weights = torch.stack((1 - b1 - b2, b1, b2), dim=-1)
365
+ # check barycentric weights
366
+ # p_n = v0*weights[:,0:1] + v1*weights[:,1:2] + v2*weights[:,2:3]
367
+ return weights
368
+
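A quick numerical check of the routine above: the returned weights sum to one and reconstruct a query point that already lies inside the triangle (sketch only):

import torch

tri = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]]])  # [1, 3, 3]
p = torch.tensor([[0.25, 0.25, 0.0]])                             # [1, 3]
w = barycentric_coordinates_of_projection(p, tri)
print(w.sum(dim=1))                     # ~1.0
print((tri * w[..., None]).sum(dim=1))  # ~[0.25, 0.25, 0.0]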
369
+
370
+ def cal_sdf_batch(verts, faces, cmaps, vis, points):
371
+
372
+ # verts [B, N_vert, 3]
373
+ # faces [B, N_face, 3]
374
+ # triangles [B, N_face, 3, 3]
375
+ # points [B, N_point, 3]
376
+ # cmaps [B, N_vert, 3]
377
+
378
+ Bsize = points.shape[0]
379
+
380
+ normals = Meshes(verts, faces).verts_normals_padded()
381
+
382
+ triangles = face_vertices(verts, faces)
383
+ normals = face_vertices(normals, faces)
384
+ cmaps = face_vertices(cmaps, faces)
385
+ vis = face_vertices(vis, faces)
386
+
387
+ residues, pts_ind, _ = point_to_mesh_distance(points, triangles)
388
+ closest_triangles = torch.gather(
389
+ triangles, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
390
+ closest_normals = torch.gather(
391
+ normals, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
392
+ closest_cmaps = torch.gather(
393
+ cmaps, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
394
+ closest_vis = torch.gather(
395
+ vis, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 1)).view(-1, 3, 1)
396
+ bary_weights = barycentric_coordinates_of_projection(
397
+ points.view(-1, 3), closest_triangles)
398
+
399
+ pts_cmap = (closest_cmaps*bary_weights[:, :, None]).sum(1).unsqueeze(0)
400
+ pts_vis = (closest_vis*bary_weights[:,
401
+ :, None]).sum(1).unsqueeze(0).ge(1e-1)
402
+ pts_norm = (closest_normals*bary_weights[:, :, None]).sum(
403
+ 1).unsqueeze(0) * torch.tensor([-1.0, 1.0, -1.0]).type_as(normals)
404
+ pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3))
405
+
406
+ pts_signs = 2.0 * (check_sign(verts, faces[0], points).float() - 0.5)
407
+ pts_sdf = (pts_dist * pts_signs).unsqueeze(-1)
408
+
409
+ return pts_sdf.view(Bsize, -1, 1), pts_norm.view(Bsize, -1, 3), pts_cmap.view(Bsize, -1, 3), pts_vis.view(Bsize, -1, 1)
410
+
411
+
412
+ def orthogonal(points, calibrations, transforms=None):
413
+ '''
414
+ Compute the orthogonal projections of 3D points into the image plane by given projection matrix
415
+ :param points: [B, 3, N] Tensor of 3D points
416
+ :param calibrations: [B, 3, 4] Tensor of projection matrix
417
+ :param transforms: [B, 2, 3] Tensor of image transform matrix
418
+ :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
419
+ '''
420
+ rot = calibrations[:, :3, :3]
421
+ trans = calibrations[:, :3, 3:4]
422
+ pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
423
+ if transforms is not None:
424
+ scale = transforms[:2, :2]
425
+ shift = transforms[:2, 2:3]
426
+ pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
427
+ return pts
428
+
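A minimal usage sketch of the orthogonal projection above: with an identity calibration and no extra image transform, the points pass through unchanged.

import torch

points = torch.randn(2, 3, 100)                         # [B, 3, N]
calib = torch.eye(4)[:3].unsqueeze(0).repeat(2, 1, 1)   # [B, 3, 4] identity calibration
assert torch.allclose(orthogonal(points, calib), points)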
429
+
430
+ def projection(points, calib, format='numpy'):
431
+ if format == 'tensor':
432
+ return torch.mm(calib[:3, :3], points.T).T + calib[:3, 3]
433
+ else:
434
+ return np.matmul(calib[:3, :3], points.T).T + calib[:3, 3]
435
+
436
+
437
+ def load_calib(calib_path):
438
+ calib_data = np.loadtxt(calib_path, dtype=float)
439
+ extrinsic = calib_data[:4, :4]
440
+ intrinsic = calib_data[4:8, :4]
441
+ calib_mat = np.matmul(intrinsic, extrinsic)
442
+ calib_mat = torch.from_numpy(calib_mat).float()
443
+ return calib_mat
444
+
445
+
446
+ def load_obj_mesh_for_Hoppe(mesh_file):
447
+ vertex_data = []
448
+ face_data = []
449
+
450
+ if isinstance(mesh_file, str):
451
+ f = open(mesh_file, "r")
452
+ else:
453
+ f = mesh_file
454
+ for line in f:
455
+ if isinstance(line, bytes):
456
+ line = line.decode("utf-8")
457
+ if line.startswith('#'):
458
+ continue
459
+ values = line.split()
460
+ if not values:
461
+ continue
462
+
463
+ if values[0] == 'v':
464
+ v = list(map(float, values[1:4]))
465
+ vertex_data.append(v)
466
+
467
+ elif values[0] == 'f':
468
+ # quad mesh
469
+ if len(values) > 4:
470
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
471
+ face_data.append(f)
472
+ f = list(
473
+ map(lambda x: int(x.split('/')[0]),
474
+ [values[3], values[4], values[1]]))
475
+ face_data.append(f)
476
+ # tri mesh
477
+ else:
478
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
479
+ face_data.append(f)
480
+
481
+ vertices = np.array(vertex_data)
482
+ faces = np.array(face_data)
483
+ faces[faces > 0] -= 1
484
+
485
+ normals, _ = compute_normal(vertices, faces)
486
+
487
+ return vertices, normals, faces
488
+
489
+
490
+ def load_obj_mesh_with_color(mesh_file):
491
+ vertex_data = []
492
+ color_data = []
493
+ face_data = []
494
+
495
+ if isinstance(mesh_file, str):
496
+ f = open(mesh_file, "r")
497
+ else:
498
+ f = mesh_file
499
+ for line in f:
500
+ if isinstance(line, bytes):
501
+ line = line.decode("utf-8")
502
+ if line.startswith('#'):
503
+ continue
504
+ values = line.split()
505
+ if not values:
506
+ continue
507
+
508
+ if values[0] == 'v':
509
+ v = list(map(float, values[1:4]))
510
+ vertex_data.append(v)
511
+ c = list(map(float, values[4:7]))
512
+ color_data.append(c)
513
+
514
+ elif values[0] == 'f':
515
+ # quad mesh
516
+ if len(values) > 4:
517
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
518
+ face_data.append(f)
519
+ f = list(
520
+ map(lambda x: int(x.split('/')[0]),
521
+ [values[3], values[4], values[1]]))
522
+ face_data.append(f)
523
+ # tri mesh
524
+ else:
525
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
526
+ face_data.append(f)
527
+
528
+ vertices = np.array(vertex_data)
529
+ colors = np.array(color_data)
530
+ faces = np.array(face_data)
531
+ faces[faces > 0] -= 1
532
+
533
+ return vertices, colors, faces
534
+
535
+
536
+ def load_obj_mesh(mesh_file, with_normal=False, with_texture=False):
537
+ vertex_data = []
538
+ norm_data = []
539
+ uv_data = []
540
+
541
+ face_data = []
542
+ face_norm_data = []
543
+ face_uv_data = []
544
+
545
+ if isinstance(mesh_file, str):
546
+ f = open(mesh_file, "r")
547
+ else:
548
+ f = mesh_file
549
+ for line in f:
550
+ if isinstance(line, bytes):
551
+ line = line.decode("utf-8")
552
+ if line.startswith('#'):
553
+ continue
554
+ values = line.split()
555
+ if not values:
556
+ continue
557
+
558
+ if values[0] == 'v':
559
+ v = list(map(float, values[1:4]))
560
+ vertex_data.append(v)
561
+ elif values[0] == 'vn':
562
+ vn = list(map(float, values[1:4]))
563
+ norm_data.append(vn)
564
+ elif values[0] == 'vt':
565
+ vt = list(map(float, values[1:3]))
566
+ uv_data.append(vt)
567
+
568
+ elif values[0] == 'f':
569
+ # quad mesh
570
+ if len(values) > 4:
571
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
572
+ face_data.append(f)
573
+ f = list(
574
+ map(lambda x: int(x.split('/')[0]),
575
+ [values[3], values[4], values[1]]))
576
+ face_data.append(f)
577
+ # tri mesh
578
+ else:
579
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
580
+ face_data.append(f)
581
+
582
+ # deal with texture
583
+ if len(values[1].split('/')) >= 2:
584
+ # quad mesh
585
+ if len(values) > 4:
586
+ f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
587
+ face_uv_data.append(f)
588
+ f = list(
589
+ map(lambda x: int(x.split('/')[1]),
590
+ [values[3], values[4], values[1]]))
591
+ face_uv_data.append(f)
592
+ # tri mesh
593
+ elif len(values[1].split('/')[1]) != 0:
594
+ f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
595
+ face_uv_data.append(f)
596
+ # deal with normal
597
+ if len(values[1].split('/')) == 3:
598
+ # quad mesh
599
+ if len(values) > 4:
600
+ f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
601
+ face_norm_data.append(f)
602
+ f = list(
603
+ map(lambda x: int(x.split('/')[2]),
604
+ [values[3], values[4], values[1]]))
605
+ face_norm_data.append(f)
606
+ # tri mesh
607
+ elif len(values[1].split('/')[2]) != 0:
608
+ f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
609
+ face_norm_data.append(f)
610
+
611
+ vertices = np.array(vertex_data)
612
+ faces = np.array(face_data)
613
+ faces[faces > 0] -= 1
614
+
615
+ if with_texture and with_normal:
616
+ uvs = np.array(uv_data)
617
+ face_uvs = np.array(face_uv_data)
618
+ face_uvs[face_uvs > 0] -= 1
619
+ norms = np.array(norm_data)
620
+ if norms.shape[0] == 0:
621
+ norms, _ = compute_normal(vertices, faces)
622
+ face_normals = faces
623
+ else:
624
+ norms = normalize_v3(norms)
625
+ face_normals = np.array(face_norm_data)
626
+ face_normals[face_normals > 0] -= 1
627
+ return vertices, faces, norms, face_normals, uvs, face_uvs
628
+
629
+ if with_texture:
630
+ uvs = np.array(uv_data)
631
+ face_uvs = np.array(face_uv_data) - 1
632
+ return vertices, faces, uvs, face_uvs
633
+
634
+ if with_normal:
635
+ norms = np.array(norm_data)
636
+ norms = normalize_v3(norms)
637
+ face_normals = np.array(face_norm_data) - 1
638
+ return vertices, faces, norms, face_normals
639
+
640
+ return vertices, faces
641
+
642
+
643
+ def normalize_v3(arr):
644
+ ''' Normalize a numpy array of 3 component vectors shape=(n,3) '''
645
+ lens = np.sqrt(arr[:, 0]**2 + arr[:, 1]**2 + arr[:, 2]**2)
646
+ eps = 0.00000001
647
+ lens[lens < eps] = eps
648
+ arr[:, 0] /= lens
649
+ arr[:, 1] /= lens
650
+ arr[:, 2] /= lens
651
+ return arr
652
+
653
+
654
+ def compute_normal(vertices, faces):
655
+ # Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal
656
+ vert_norms = np.zeros(vertices.shape, dtype=vertices.dtype)
657
+ # Create an indexed view into the vertex array using the array of three indices for triangles
658
+ tris = vertices[faces]
659
+ # Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle
660
+ face_norms = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0])
661
+ # n is now an array of normals per triangle. The length of each normal depends on the vertices,
662
+ # we need to normalize these, so that our next step weights each normal equally.
663
+ normalize_v3(face_norms)
664
+ # now we have a normalized array of normals, one per triangle, i.e., per triangle normals.
665
+ # But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle,
666
+ # the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards.
667
+ # The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array
668
+ vert_norms[faces[:, 0]] += face_norms
669
+ vert_norms[faces[:, 1]] += face_norms
670
+ vert_norms[faces[:, 2]] += face_norms
671
+ normalize_v3(vert_norms)
672
+
673
+ return vert_norms, face_norms
674
+
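For a single flat triangle the routine above returns the same unit normal for every vertex and for the face, which makes a convenient sanity check:

import numpy as np

verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = np.array([[0, 1, 2]])
vn, fn = compute_normal(verts, faces)
print(vn)   # three copies of [0., 0., 1.]
print(fn)   # [[0., 0., 1.]]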
675
+
676
+ def save_obj_mesh(mesh_path, verts, faces):
677
+ file = open(mesh_path, 'w')
678
+ for v in verts:
679
+ file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
680
+ for f in faces:
681
+ f_plus = f + 1
682
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
683
+ file.close()
684
+
685
+
686
+ def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
687
+ file = open(mesh_path, 'w')
688
+
689
+ for idx, v in enumerate(verts):
690
+ c = colors[idx]
691
+ file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
692
+ (v[0], v[1], v[2], c[0], c[1], c[2]))
693
+ for f in faces:
694
+ f_plus = f + 1
695
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
696
+ file.close()
697
+
698
+
699
+ def calculate_mIoU(outputs, labels):
700
+
701
+ SMOOTH = 1e-6
702
+
703
+ outputs = outputs.int()
704
+ labels = labels.int()
705
+
706
+ intersection = (
707
+ outputs
708
+ & labels).float().sum() # Will be zero if Truth=0 or Prediction=0
709
+ union = (outputs | labels).float().sum() # Will be zero if both are 0
710
+
711
+ iou = (intersection + SMOOTH) / (union + SMOOTH
712
+ ) # We smooth our division to avoid 0/0
713
+
714
+ thresholded = torch.clamp(
715
+ 20 * (iou - 0.5), 0,
716
+ 10).ceil() / 10 # This is equal to comparing with thresholds
717
+
718
+ return thresholded.mean().detach().cpu().numpy(
719
+ ) # Or thresholded.mean() if you are interested in average across the batch
720
+
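With the smoothing and thresholding above, perfectly overlapping masks score 1.0 and disjoint masks score 0.0 (illustrative check):

import torch

a = torch.tensor([1, 1, 0, 0])
b = torch.tensor([0, 0, 1, 1])
print(calculate_mIoU(a, a))   # 1.0
print(calculate_mIoU(a, b))   # 0.0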
721
+
722
+ def mask_filter(mask, number=1000):
723
+ """only keep {number} True items within a mask
724
+
725
+ Args:
726
+ mask (bool array): [N, ]
727
+ number (int, optional): total True item. Defaults to 1000.
728
+ """
729
+ true_ids = np.where(mask)[0]
730
+ keep_ids = np.random.choice(true_ids, size=number)
731
+ filter_mask = np.isin(np.arange(len(mask)), keep_ids)
732
+
733
+ return filter_mask
734
+
735
+
736
+ def query_mesh(path):
737
+
738
+ verts, faces_idx, _ = load_obj(path)
739
+
740
+ return verts, faces_idx.verts_idx
741
+
742
+
743
+ def add_alpha(colors, alpha=0.7):
744
+
745
+ colors_pad = np.pad(colors, ((0, 0), (0, 1)),
746
+ mode='constant',
747
+ constant_values=alpha)
748
+
749
+ return colors_pad
750
+
751
+
752
+ def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type='smpl'):
753
+
754
+ font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
755
+ font = ImageFont.truetype(font_path, 30)
756
+ grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0),
757
+ nrow=nrow)
758
+ grid_img = Image.fromarray(
759
+ ((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 *
760
+ 255.0).astype(np.uint8))
761
+
762
+ # add text
763
+ draw = ImageDraw.Draw(grid_img)
764
+ grid_size = 512
765
+ if loss is not None:
766
+ draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
767
+
768
+ if type == 'smpl':
769
+ for col_id, col_txt in enumerate(
770
+ ['image', 'smpl-norm(render)', 'cloth-norm(pred)', 'diff-norm', 'diff-mask']):
771
+ draw.text((10+(col_id*grid_size), 5),
772
+ col_txt, (255, 0, 0), font=font)
773
+ elif type == 'cloth':
774
+ for col_id, col_txt in enumerate(
775
+ ['image', 'cloth-norm(recon)', 'cloth-norm(pred)', 'diff-norm']):
776
+ draw.text((10+(col_id*grid_size), 5),
777
+ col_txt, (255, 0, 0), font=font)
778
+ for col_id, col_txt in enumerate(
779
+ ['0', '90', '180', '270']):
780
+ draw.text((10+(col_id*grid_size), grid_size*2+5),
781
+ col_txt, (255, 0, 0), font=font)
782
+ else:
783
+ print(f"{type} should be 'smpl' or 'cloth'")
784
+
785
+ grid_img = grid_img.resize((grid_img.size[0], grid_img.size[1]),
786
+ Image.ANTIALIAS)
787
+
788
+ return grid_img
789
+
790
+
791
+ def clean_mesh(verts, faces):
792
+
793
+ device = verts.device
794
+
795
+ mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(),
796
+ faces.detach().cpu().numpy())
797
+ mesh_lst = mesh_lst.split(only_watertight=False)
798
+ comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst]
799
+ mesh_clean = mesh_lst[comp_num.index(max(comp_num))]
800
+
801
+ final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device)
802
+ final_faces = torch.as_tensor(mesh_clean.faces).int().to(device)
803
+
804
+ return final_verts, final_faces
805
+
806
+
807
+ def merge_mesh(verts_A, faces_A, verts_B, faces_B, color=False):
808
+
809
+ sep_mesh = trimesh.Trimesh(np.concatenate([verts_A, verts_B], axis=0),
810
+ np.concatenate(
811
+ [faces_A, faces_B + faces_A.max() + 1],
812
+ axis=0),
813
+ maintain_order=True,
814
+ process=False)
815
+ if color:
816
+ colors = np.ones_like(sep_mesh.vertices)
817
+ colors[:verts_A.shape[0]] *= np.array([255.0, 0.0, 0.0])
818
+ colors[verts_A.shape[0]:] *= np.array([0.0, 255.0, 0.0])
819
+ sep_mesh.visual.vertex_colors = colors
820
+
821
+ # union_mesh = trimesh.boolean.union([trimesh.Trimesh(verts_A, faces_A),
822
+ # trimesh.Trimesh(verts_B, faces_B)], engine='blender')
823
+
824
+ return sep_mesh
825
+
826
+
827
+ def mesh_move(mesh_lst, step, scale=1.0):
828
+
829
+ trans = np.array([1.0, 0.0, 0.0]) * step
830
+
831
+ resize_matrix = trimesh.transformations.scale_and_translate(
832
+ scale=(scale), translate=trans)
833
+
834
+ results = []
835
+
836
+ for mesh in mesh_lst:
837
+ mesh.apply_transform(resize_matrix)
838
+ results.append(mesh)
839
+
840
+ return results
841
+
842
+
843
+ class SMPLX():
844
+ def __init__(self):
845
+
846
+ REPO_ID = "Yuliang/SMPL"
847
+
848
+ self.smpl_verts_path = hf_hub_download(REPO_ID, filename='smpl_data/smpl_verts.npy', use_auth_token=os.environ['ICON'])
849
+ self.smplx_verts_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_verts.npy', use_auth_token=os.environ['ICON'])
850
+ self.faces_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_faces.npy', use_auth_token=os.environ['ICON'])
851
+ self.cmap_vert_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_cmap.npy', use_auth_token=os.environ['ICON'])
852
+
853
+ self.faces = np.load(self.faces_path)
854
+ self.verts = np.load(self.smplx_verts_path)
855
+ self.smpl_verts = np.load(self.smpl_verts_path)
856
+
857
+ self.model_dir = hf_hub_url(REPO_ID, filename='models')
858
+ self.tedra_dir = hf_hub_url(REPO_ID, filename='tedra_data')
859
+
860
+ def get_smpl_mat(self, vert_ids):
861
+
862
+ mat = torch.as_tensor(np.load(self.cmap_vert_path)).float()
863
+ return mat[vert_ids, :]
864
+
865
+ def smpl2smplx(self, vert_ids=None):
866
+ """convert vert_ids in smpl to vert_ids in smplx
867
+
868
+ Args:
869
+ vert_ids ([int.array]): [n, knn_num]
870
+ """
871
+ smplx_tree = cKDTree(self.verts, leafsize=1)
872
+ _, ind = smplx_tree.query(self.smpl_verts, k=1) # ind: [smpl_num, 1]
873
+
874
+ if vert_ids is not None:
875
+ smplx_vert_ids = ind[vert_ids]
876
+ else:
877
+ smplx_vert_ids = ind
878
+
879
+ return smplx_vert_ids
880
+
881
+ def smplx2smpl(self, vert_ids=None):
882
+ """convert vert_ids in smplx to vert_ids in smpl
883
+
884
+ Args:
885
+ vert_ids ([int.array]): [n, knn_num]
886
+ """
887
+ smpl_tree = cKDTree(self.smpl_verts, leafsize=1)
888
+ _, ind = smpl_tree.query(self.verts, k=1) # ind: [smplx_num, 1]
889
+ if vert_ids is not None:
890
+ smpl_vert_ids = ind[vert_ids]
891
+ else:
892
+ smpl_vert_ids = ind
893
+
894
+ return smpl_vert_ids
lib/dataset/tbfo.ttf ADDED
Binary file (571 kB). View file
 
lib/net/BasePIFuNet.py ADDED
@@ -0,0 +1,84 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import torch.nn as nn
19
+ import pytorch_lightning as pl
20
+
21
+ from .geometry import index, orthogonal, perspective
22
+
23
+
24
+ class BasePIFuNet(pl.LightningModule):
25
+ def __init__(
26
+ self,
27
+ projection_mode='orthogonal',
28
+ error_term=nn.MSELoss(),
29
+ ):
30
+ """
31
+ :param projection_mode:
32
+ Either orthogonal or perspective.
33
+ It will call the corresponding function for projection.
34
+ :param error_term:
35
+ nn Loss between the predicted [B, Res, N] and the label [B, Res, N]
36
+ """
37
+ super(BasePIFuNet, self).__init__()
38
+ self.name = 'base'
39
+
40
+ self.error_term = error_term
41
+
42
+ self.index = index
43
+ self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
44
+
45
+ def forward(self, points, images, calibs, transforms=None):
46
+ '''
47
+ :param points: [B, 3, N] world space coordinates of points
48
+ :param images: [B, C, H, W] input images
49
+ :param calibs: [B, 3, 4] calibration matrices for each image
50
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
51
+ :return: [B, Res, N] predictions for each point
52
+ '''
53
+ features = self.filter(images)
54
+ preds = self.query(features, points, calibs, transforms)
55
+ return preds
56
+
57
+ def filter(self, images):
58
+ '''
59
+ Filter the input images
60
+ store all intermediate features.
61
+ :param images: [B, C, H, W] input images
62
+ '''
63
+ return None
64
+
65
+ def query(self, features, points, calibs, transforms=None):
66
+ '''
67
+ Given 3D points, query the network predictions for each point.
68
+ Image features should be pre-computed before this call.
69
+ store all intermediate features.
70
+ query() function may behave differently during training/testing.
71
+ :param points: [B, 3, N] world space coordinates of points
72
+ :param calibs: [B, 3, 4] calibration matrices for each image
73
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
74
+ :param labels: Optional [B, Res, N] gt labeling
75
+ :return: [B, Res, N] predictions for each point
76
+ '''
77
+ return None
78
+
79
+ def get_error(self, preds, labels):
80
+ '''
81
+ Get the network loss from the last query
82
+ :return: loss term
83
+ '''
84
+ return self.error_term(preds, labels)
lib/net/FBNet.py ADDED
@@ -0,0 +1,387 @@
1
+ '''
2
+ Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
3
+ BSD License. All rights reserved.
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ * Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ * Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
17
+ IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
18
+ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
19
+ WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
20
+ OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
21
+ '''
22
+ import torch
23
+ import torch.nn as nn
24
+ import functools
25
+ import numpy as np
26
+ import pytorch_lightning as pl
27
+
28
+
29
+ ###############################################################################
30
+ # Functions
31
+ ###############################################################################
32
+ def weights_init(m):
33
+ classname = m.__class__.__name__
34
+ if classname.find('Conv') != -1:
35
+ m.weight.data.normal_(0.0, 0.02)
36
+ elif classname.find('BatchNorm2d') != -1:
37
+ m.weight.data.normal_(1.0, 0.02)
38
+ m.bias.data.fill_(0)
39
+
40
+
41
+ def get_norm_layer(norm_type='instance'):
42
+ if norm_type == 'batch':
43
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
44
+ elif norm_type == 'instance':
45
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
46
+ else:
47
+ raise NotImplementedError('normalization layer [%s] is not found' %
48
+ norm_type)
49
+ return norm_layer
50
+
51
+
52
+ def define_G(input_nc,
53
+ output_nc,
54
+ ngf,
55
+ netG,
56
+ n_downsample_global=3,
57
+ n_blocks_global=9,
58
+ n_local_enhancers=1,
59
+ n_blocks_local=3,
60
+ norm='instance',
61
+ gpu_ids=[],
62
+ last_op=nn.Tanh()):
63
+ norm_layer = get_norm_layer(norm_type=norm)
64
+ if netG == 'global':
65
+ netG = GlobalGenerator(input_nc,
66
+ output_nc,
67
+ ngf,
68
+ n_downsample_global,
69
+ n_blocks_global,
70
+ norm_layer,
71
+ last_op=last_op)
72
+ elif netG == 'local':
73
+ netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
74
+ n_blocks_global, n_local_enhancers,
75
+ n_blocks_local, norm_layer)
76
+ elif netG == 'encoder':
77
+ netG = Encoder(input_nc, output_nc, ngf, n_downsample_global,
78
+ norm_layer)
79
+ else:
80
+ raise NotImplementedError('generator [%s] is not implemented' % netG)
81
+ # print(netG)
82
+ if len(gpu_ids) > 0:
83
+ assert (torch.cuda.is_available())
84
+ netG.cuda(gpu_ids[0])
85
+ netG.apply(weights_init)
86
+ return netG
87
+
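A usage sketch of the factory above: building the 'global' generator and running a dummy forward pass on the CPU (the channel counts and input size here are illustrative, not values mandated by this repo):

import torch

netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='global', gpu_ids=[])
with torch.no_grad():
    out = netG(torch.randn(1, 3, 512, 512))
print(out.shape)   # torch.Size([1, 3, 512, 512])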
88
+
89
+ def print_network(net):
90
+ if isinstance(net, list):
91
+ net = net[0]
92
+ num_params = 0
93
+ for param in net.parameters():
94
+ num_params += param.numel()
95
+ print(net)
96
+ print('Total number of parameters: %d' % num_params)
97
+
98
+
99
+ ##############################################################################
100
+ # Generator
101
+ ##############################################################################
102
+ class LocalEnhancer(pl.LightningModule):
103
+ def __init__(self,
104
+ input_nc,
105
+ output_nc,
106
+ ngf=32,
107
+ n_downsample_global=3,
108
+ n_blocks_global=9,
109
+ n_local_enhancers=1,
110
+ n_blocks_local=3,
111
+ norm_layer=nn.BatchNorm2d,
112
+ padding_type='reflect'):
113
+ super(LocalEnhancer, self).__init__()
114
+ self.n_local_enhancers = n_local_enhancers
115
+
116
+ ###### global generator model #####
117
+ ngf_global = ngf * (2**n_local_enhancers)
118
+ model_global = GlobalGenerator(input_nc, output_nc, ngf_global,
119
+ n_downsample_global, n_blocks_global,
120
+ norm_layer).model
121
+ model_global = [model_global[i] for i in range(len(model_global) - 3)
122
+ ] # get rid of final convolution layers
123
+ self.model = nn.Sequential(*model_global)
124
+
125
+ ###### local enhancer layers #####
126
+ for n in range(1, n_local_enhancers + 1):
127
+ # downsample
128
+ ngf_global = ngf * (2**(n_local_enhancers - n))
129
+ model_downsample = [
130
+ nn.ReflectionPad2d(3),
131
+ nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
132
+ norm_layer(ngf_global),
133
+ nn.ReLU(True),
134
+ nn.Conv2d(ngf_global,
135
+ ngf_global * 2,
136
+ kernel_size=3,
137
+ stride=2,
138
+ padding=1),
139
+ norm_layer(ngf_global * 2),
140
+ nn.ReLU(True)
141
+ ]
142
+ # residual blocks
143
+ model_upsample = []
144
+ for i in range(n_blocks_local):
145
+ model_upsample += [
146
+ ResnetBlock(ngf_global * 2,
147
+ padding_type=padding_type,
148
+ norm_layer=norm_layer)
149
+ ]
150
+
151
+ # upsample
152
+ model_upsample += [
153
+ nn.ConvTranspose2d(ngf_global * 2,
154
+ ngf_global,
155
+ kernel_size=3,
156
+ stride=2,
157
+ padding=1,
158
+ output_padding=1),
159
+ norm_layer(ngf_global),
160
+ nn.ReLU(True)
161
+ ]
162
+
163
+ # final convolution
164
+ if n == n_local_enhancers:
165
+ model_upsample += [
166
+ nn.ReflectionPad2d(3),
167
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
168
+ nn.Tanh()
169
+ ]
170
+
171
+ setattr(self, 'model' + str(n) + '_1',
172
+ nn.Sequential(*model_downsample))
173
+ setattr(self, 'model' + str(n) + '_2',
174
+ nn.Sequential(*model_upsample))
175
+
176
+ self.downsample = nn.AvgPool2d(3,
177
+ stride=2,
178
+ padding=[1, 1],
179
+ count_include_pad=False)
180
+
181
+ def forward(self, input):
182
+ # create input pyramid
183
+ input_downsampled = [input]
184
+ for i in range(self.n_local_enhancers):
185
+ input_downsampled.append(self.downsample(input_downsampled[-1]))
186
+
187
+ # output at coarest level
188
+ output_prev = self.model(input_downsampled[-1])
189
+ # build up one layer at a time
190
+ for n_local_enhancers in range(1, self.n_local_enhancers + 1):
191
+ model_downsample = getattr(self,
192
+ 'model' + str(n_local_enhancers) + '_1')
193
+ model_upsample = getattr(self,
194
+ 'model' + str(n_local_enhancers) + '_2')
195
+ input_i = input_downsampled[self.n_local_enhancers -
196
+ n_local_enhancers]
197
+ output_prev = model_upsample(
198
+ model_downsample(input_i) + output_prev)
199
+ return output_prev
200
+
201
+
202
+ class GlobalGenerator(pl.LightningModule):
203
+ def __init__(self,
204
+ input_nc,
205
+ output_nc,
206
+ ngf=64,
207
+ n_downsampling=3,
208
+ n_blocks=9,
209
+ norm_layer=nn.BatchNorm2d,
210
+ padding_type='reflect',
211
+ last_op=nn.Tanh()):
212
+ assert (n_blocks >= 0)
213
+ super(GlobalGenerator, self).__init__()
214
+ activation = nn.ReLU(True)
215
+
216
+ model = [
217
+ nn.ReflectionPad2d(3),
218
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
219
+ norm_layer(ngf), activation
220
+ ]
221
+ # downsample
222
+ for i in range(n_downsampling):
223
+ mult = 2**i
224
+ model += [
225
+ nn.Conv2d(ngf * mult,
226
+ ngf * mult * 2,
227
+ kernel_size=3,
228
+ stride=2,
229
+ padding=1),
230
+ norm_layer(ngf * mult * 2), activation
231
+ ]
232
+
233
+ # resnet blocks
234
+ mult = 2**n_downsampling
235
+ for i in range(n_blocks):
236
+ model += [
237
+ ResnetBlock(ngf * mult,
238
+ padding_type=padding_type,
239
+ activation=activation,
240
+ norm_layer=norm_layer)
241
+ ]
242
+
243
+ # upsample
244
+ for i in range(n_downsampling):
245
+ mult = 2**(n_downsampling - i)
246
+ model += [
247
+ nn.ConvTranspose2d(ngf * mult,
248
+ int(ngf * mult / 2),
249
+ kernel_size=3,
250
+ stride=2,
251
+ padding=1,
252
+ output_padding=1),
253
+ norm_layer(int(ngf * mult / 2)), activation
254
+ ]
255
+ model += [
256
+ nn.ReflectionPad2d(3),
257
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
258
+ ]
259
+ if last_op is not None:
260
+ model += [last_op]
261
+ self.model = nn.Sequential(*model)
262
+
263
+ def forward(self, input):
264
+ return self.model(input)
265
+
266
+
267
+ # Define a resnet block
268
+ class ResnetBlock(pl.LightningModule):
269
+ def __init__(self,
270
+ dim,
271
+ padding_type,
272
+ norm_layer,
273
+ activation=nn.ReLU(True),
274
+ use_dropout=False):
275
+ super(ResnetBlock, self).__init__()
276
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
277
+ activation, use_dropout)
278
+
279
+ def build_conv_block(self, dim, padding_type, norm_layer, activation,
280
+ use_dropout):
281
+ conv_block = []
282
+ p = 0
283
+ if padding_type == 'reflect':
284
+ conv_block += [nn.ReflectionPad2d(1)]
285
+ elif padding_type == 'replicate':
286
+ conv_block += [nn.ReplicationPad2d(1)]
287
+ elif padding_type == 'zero':
288
+ p = 1
289
+ else:
290
+ raise NotImplementedError('padding [%s] is not implemented' %
291
+ padding_type)
292
+
293
+ conv_block += [
294
+ nn.Conv2d(dim, dim, kernel_size=3, padding=p),
295
+ norm_layer(dim), activation
296
+ ]
297
+ if use_dropout:
298
+ conv_block += [nn.Dropout(0.5)]
299
+
300
+ p = 0
301
+ if padding_type == 'reflect':
302
+ conv_block += [nn.ReflectionPad2d(1)]
303
+ elif padding_type == 'replicate':
304
+ conv_block += [nn.ReplicationPad2d(1)]
305
+ elif padding_type == 'zero':
306
+ p = 1
307
+ else:
308
+ raise NotImplementedError('padding [%s] is not implemented' %
309
+ padding_type)
310
+ conv_block += [
311
+ nn.Conv2d(dim, dim, kernel_size=3, padding=p),
312
+ norm_layer(dim)
313
+ ]
314
+
315
+ return nn.Sequential(*conv_block)
316
+
317
+ def forward(self, x):
318
+ out = x + self.conv_block(x)
319
+ return out
320
+
321
+
322
+ class Encoder(pl.LightningModule):
323
+ def __init__(self,
324
+ input_nc,
325
+ output_nc,
326
+ ngf=32,
327
+ n_downsampling=4,
328
+ norm_layer=nn.BatchNorm2d):
329
+ super(Encoder, self).__init__()
330
+ self.output_nc = output_nc
331
+
332
+ model = [
333
+ nn.ReflectionPad2d(3),
334
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
335
+ norm_layer(ngf),
336
+ nn.ReLU(True)
337
+ ]
338
+ # downsample
339
+ for i in range(n_downsampling):
340
+ mult = 2**i
341
+ model += [
342
+ nn.Conv2d(ngf * mult,
343
+ ngf * mult * 2,
344
+ kernel_size=3,
345
+ stride=2,
346
+ padding=1),
347
+ norm_layer(ngf * mult * 2),
348
+ nn.ReLU(True)
349
+ ]
350
+
351
+ # upsample
352
+ for i in range(n_downsampling):
353
+ mult = 2**(n_downsampling - i)
354
+ model += [
355
+ nn.ConvTranspose2d(ngf * mult,
356
+ int(ngf * mult / 2),
357
+ kernel_size=3,
358
+ stride=2,
359
+ padding=1,
360
+ output_padding=1),
361
+ norm_layer(int(ngf * mult / 2)),
362
+ nn.ReLU(True)
363
+ ]
364
+
365
+ model += [
366
+ nn.ReflectionPad2d(3),
367
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
368
+ nn.Tanh()
369
+ ]
370
+ self.model = nn.Sequential(*model)
371
+
372
+ def forward(self, input, inst):
373
+ outputs = self.model(input)
374
+
375
+ # instance-wise average pooling
376
+ outputs_mean = outputs.clone()
377
+ inst_list = np.unique(inst.cpu().numpy().astype(int))
378
+ for i in inst_list:
379
+ for b in range(input.size()[0]):
380
+ indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4
381
+ for j in range(self.output_nc):
382
+ output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
383
+ indices[:, 2], indices[:, 3]]
384
+ mean_feat = torch.mean(output_ins).expand_as(output_ins)
385
+ outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
386
+ indices[:, 2], indices[:, 3]] = mean_feat
387
+ return outputs_mean
lib/net/HGFilters.py ADDED
@@ -0,0 +1,197 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ from lib.net.net_util import *
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+
23
+ class HourGlass(nn.Module):
24
+ def __init__(self, num_modules, depth, num_features, opt):
25
+ super(HourGlass, self).__init__()
26
+ self.num_modules = num_modules
27
+ self.depth = depth
28
+ self.features = num_features
29
+ self.opt = opt
30
+
31
+ self._generate_network(self.depth)
32
+
33
+ def _generate_network(self, level):
34
+ self.add_module('b1_' + str(level),
35
+ ConvBlock(self.features, self.features, self.opt))
36
+
37
+ self.add_module('b2_' + str(level),
38
+ ConvBlock(self.features, self.features, self.opt))
39
+
40
+ if level > 1:
41
+ self._generate_network(level - 1)
42
+ else:
43
+ self.add_module('b2_plus_' + str(level),
44
+ ConvBlock(self.features, self.features, self.opt))
45
+
46
+ self.add_module('b3_' + str(level),
47
+ ConvBlock(self.features, self.features, self.opt))
48
+
49
+ def _forward(self, level, inp):
50
+ # Upper branch
51
+ up1 = inp
52
+ up1 = self._modules['b1_' + str(level)](up1)
53
+
54
+ # Lower branch
55
+ low1 = F.avg_pool2d(inp, 2, stride=2)
56
+ low1 = self._modules['b2_' + str(level)](low1)
57
+
58
+ if level > 1:
59
+ low2 = self._forward(level - 1, low1)
60
+ else:
61
+ low2 = low1
62
+ low2 = self._modules['b2_plus_' + str(level)](low2)
63
+
64
+ low3 = low2
65
+ low3 = self._modules['b3_' + str(level)](low3)
66
+
67
+        # NOTE: with newer PyTorch (>= 1.3), training results can degrade due to an implementation change in F.grid_sample;
68
+        # if the pretrained model behaves oddly, switch to the commented-out line below.
69
+        # NOTE: "bicubic" was also found to work better than "nearest" here.
70
+ up2 = F.interpolate(low3,
71
+ scale_factor=2,
72
+ mode='bicubic',
73
+ align_corners=True)
74
+        # up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
75
+
76
+ return up1 + up2
77
+
78
+ def forward(self, x):
79
+ return self._forward(self.depth, x)
80
+
81
+
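The NOTE inside HourGlass._forward concerns the upsampling of the lower branch. A standalone sketch of the two interpolation modes it mentions (the tensor shape is arbitrary and only for illustration):

    import torch
    import torch.nn.functional as F

    low3 = torch.randn(1, 256, 16, 16)
    # Smooth bicubic upsampling, as used in HourGlass._forward above.
    up_bicubic = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True)
    # The commented-out alternative: cheaper nearest-neighbour upsampling.
    up_nearest = F.interpolate(low3, scale_factor=2, mode='nearest')
    assert up_bicubic.shape == up_nearest.shape == torch.Size([1, 256, 32, 32])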
82
+ class HGFilter(nn.Module):
83
+ def __init__(self, opt, num_modules, in_dim):
84
+ super(HGFilter, self).__init__()
85
+ self.num_modules = num_modules
86
+
87
+ self.opt = opt
88
+ [k, s, d, p] = self.opt.conv1
89
+
90
+ # self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3)
91
+ self.conv1 = nn.Conv2d(in_dim,
92
+ 64,
93
+ kernel_size=k,
94
+ stride=s,
95
+ dilation=d,
96
+ padding=p)
97
+
98
+ if self.opt.norm == 'batch':
99
+ self.bn1 = nn.BatchNorm2d(64)
100
+ elif self.opt.norm == 'group':
101
+ self.bn1 = nn.GroupNorm(32, 64)
102
+
103
+ if self.opt.hg_down == 'conv64':
104
+ self.conv2 = ConvBlock(64, 64, self.opt)
105
+ self.down_conv2 = nn.Conv2d(64,
106
+ 128,
107
+ kernel_size=3,
108
+ stride=2,
109
+ padding=1)
110
+ elif self.opt.hg_down == 'conv128':
111
+ self.conv2 = ConvBlock(64, 128, self.opt)
112
+ self.down_conv2 = nn.Conv2d(128,
113
+ 128,
114
+ kernel_size=3,
115
+ stride=2,
116
+ padding=1)
117
+ elif self.opt.hg_down == 'ave_pool':
118
+ self.conv2 = ConvBlock(64, 128, self.opt)
119
+ else:
120
+ raise NameError('Unknown Fan Filter setting!')
121
+
122
+ self.conv3 = ConvBlock(128, 128, self.opt)
123
+ self.conv4 = ConvBlock(128, 256, self.opt)
124
+
125
+ # Stacking part
126
+ for hg_module in range(self.num_modules):
127
+ self.add_module('m' + str(hg_module),
128
+ HourGlass(1, opt.num_hourglass, 256, self.opt))
129
+
130
+ self.add_module('top_m_' + str(hg_module),
131
+ ConvBlock(256, 256, self.opt))
132
+ self.add_module(
133
+ 'conv_last' + str(hg_module),
134
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
135
+ if self.opt.norm == 'batch':
136
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
137
+ elif self.opt.norm == 'group':
138
+ self.add_module('bn_end' + str(hg_module),
139
+ nn.GroupNorm(32, 256))
140
+
141
+ self.add_module(
142
+ 'l' + str(hg_module),
143
+ nn.Conv2d(256,
144
+ opt.hourglass_dim,
145
+ kernel_size=1,
146
+ stride=1,
147
+ padding=0))
148
+
149
+ if hg_module < self.num_modules - 1:
150
+ self.add_module(
151
+ 'bl' + str(hg_module),
152
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
153
+ self.add_module(
154
+ 'al' + str(hg_module),
155
+ nn.Conv2d(opt.hourglass_dim,
156
+ 256,
157
+ kernel_size=1,
158
+ stride=1,
159
+ padding=0))
160
+
161
+ def forward(self, x):
162
+ x = F.relu(self.bn1(self.conv1(x)), True)
163
+ tmpx = x
164
+ if self.opt.hg_down == 'ave_pool':
165
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
166
+ elif self.opt.hg_down in ['conv64', 'conv128']:
167
+ x = self.conv2(x)
168
+ x = self.down_conv2(x)
169
+ else:
170
+ raise NameError('Unknown Fan Filter setting!')
171
+
172
+ x = self.conv3(x)
173
+ x = self.conv4(x)
174
+
175
+ previous = x
176
+
177
+ outputs = []
178
+ for i in range(self.num_modules):
179
+ hg = self._modules['m' + str(i)](previous)
180
+
181
+ ll = hg
182
+ ll = self._modules['top_m_' + str(i)](ll)
183
+
184
+ ll = F.relu(
185
+ self._modules['bn_end' + str(i)](
186
+ self._modules['conv_last' + str(i)](ll)), True)
187
+
188
+ # Predict heatmaps
189
+ tmp_out = self._modules['l' + str(i)](ll)
190
+ outputs.append(tmp_out)
191
+
192
+ if i < self.num_modules - 1:
193
+ ll = self._modules['bl' + str(i)](ll)
194
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
195
+ previous = previous + ll + tmp_out_
196
+
197
+ return outputs
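HGFilter reads several fields from opt (conv1 as [kernel, stride, dilation, padding], norm, hg_down, num_hourglass, hourglass_dim); the actual values live in the configs/*.yaml files, and ConvBlock in lib/net/net_util may read further fields. A hedged sketch of how such an options object could be assembled, with assumed values only:

    from types import SimpleNamespace

    # Illustrative option values -- not taken from this repository's configs;
    # ConvBlock in lib/net/net_util may require additional fields on opt.
    opt = SimpleNamespace(conv1=[7, 2, 1, 3], norm='group', hg_down='ave_pool',
                          num_hourglass=2, hourglass_dim=256)
    # Expected usage (left commented because of the extra opt fields noted above):
    #   hgf = HGFilter(opt, num_modules=4, in_dim=3)
    #   feats = hgf(torch.randn(1, 3, 512, 512))  # list of num_modules tensors,
    #                                             # each with opt.hourglass_dim channels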