Stylique commited on
Commit
0de41d8
·
verified ·
1 Parent(s): ff75948

Upload 260 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +25 -0
  2. LICENSE.txt +21 -0
  3. ORIGINAL_README.md +79 -0
  4. README.md +6 -5
  5. app.py +225 -0
  6. assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 +3 -0
  7. assets/result_clr_scale4_pexels-zdmit-6780091.mp4 +3 -0
  8. blender/blender_render_human_ortho.py +837 -0
  9. blender/check_render.py +46 -0
  10. blender/count.py +44 -0
  11. blender/distribute.py +149 -0
  12. blender/rename_smpl_files.py +25 -0
  13. blender/render.sh +4 -0
  14. blender/render_human.py +88 -0
  15. blender/render_single.sh +7 -0
  16. blender/utils.py +128 -0
  17. configs/inference-768-6view.yaml +72 -0
  18. configs/remesh.yaml +18 -0
  19. configs/train-768-6view-onlyscan_face.yaml +145 -0
  20. configs/train-768-6view-onlyscan_face_smplx.yaml +154 -0
  21. core/opt.py +197 -0
  22. core/remesh.py +359 -0
  23. econdataset.py +370 -0
  24. examples/02986d0998ce01aa0aa67a99fbd1e09a.png +3 -0
  25. examples/16171.png +3 -0
  26. examples/26d2e846349647ff04c536816e0e8ca1.png +3 -0
  27. examples/30755.png +3 -0
  28. examples/3930.png +3 -0
  29. examples/4656716-3016170581.png +3 -0
  30. examples/663dcd6db19490de0b790da430bd5681.png +3 -0
  31. examples/7332.png +3 -0
  32. examples/85891251f52a2399e660a63c2a7fdf40.png +3 -0
  33. examples/a689a48d23d6b8d58d67ff5146c6e088.png +3 -0
  34. examples/b0d178743c7e3e09700aaee8d2b1ec47.png +3 -0
  35. examples/case5.png +3 -0
  36. examples/d40776a1e1582179d97907d36f84d776.png +3 -0
  37. examples/durant.png +3 -0
  38. examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png +3 -0
  39. examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png +3 -0
  40. examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png +3 -0
  41. examples/pexels-barbara-olsen-7869640.png +3 -0
  42. examples/pexels-julia-m-cameron-4145040.png +3 -0
  43. examples/pexels-marta-wave-6437749.png +3 -0
  44. examples/pexels-photo-6311555-removebg.png +3 -0
  45. examples/pexels-zdmit-6780091.png +3 -0
  46. inference.py +223 -0
  47. lib/__init__.py +0 -0
  48. lib/common/__init__.py +0 -0
  49. lib/common/cloth_extraction.py +182 -0
  50. lib/common/config.py +218 -0
.gitattributes CHANGED
@@ -33,3 +33,28 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ assets/result_clr_scale4_pexels-zdmit-6780091.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ examples/02986d0998ce01aa0aa67a99fbd1e09a.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/16171.png filter=lfs diff=lfs merge=lfs -text
40
+ examples/26d2e846349647ff04c536816e0e8ca1.png filter=lfs diff=lfs merge=lfs -text
41
+ examples/30755.png filter=lfs diff=lfs merge=lfs -text
42
+ examples/3930.png filter=lfs diff=lfs merge=lfs -text
43
+ examples/4656716-3016170581.png filter=lfs diff=lfs merge=lfs -text
44
+ examples/663dcd6db19490de0b790da430bd5681.png filter=lfs diff=lfs merge=lfs -text
45
+ examples/7332.png filter=lfs diff=lfs merge=lfs -text
46
+ examples/85891251f52a2399e660a63c2a7fdf40.png filter=lfs diff=lfs merge=lfs -text
47
+ examples/a689a48d23d6b8d58d67ff5146c6e088.png filter=lfs diff=lfs merge=lfs -text
48
+ examples/b0d178743c7e3e09700aaee8d2b1ec47.png filter=lfs diff=lfs merge=lfs -text
49
+ examples/case5.png filter=lfs diff=lfs merge=lfs -text
50
+ examples/d40776a1e1582179d97907d36f84d776.png filter=lfs diff=lfs merge=lfs -text
51
+ examples/durant.png filter=lfs diff=lfs merge=lfs -text
52
+ examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png filter=lfs diff=lfs merge=lfs -text
53
+ examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png filter=lfs diff=lfs merge=lfs -text
54
+ examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png filter=lfs diff=lfs merge=lfs -text
55
+ examples/pexels-barbara-olsen-7869640.png filter=lfs diff=lfs merge=lfs -text
56
+ examples/pexels-julia-m-cameron-4145040.png filter=lfs diff=lfs merge=lfs -text
57
+ examples/pexels-marta-wave-6437749.png filter=lfs diff=lfs merge=lfs -text
58
+ examples/pexels-photo-6311555-removebg.png filter=lfs diff=lfs merge=lfs -text
59
+ examples/pexels-zdmit-6780091.png filter=lfs diff=lfs merge=lfs -text
60
+ lib/dataset/tbfo.ttf filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ORIGINAL_README.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PSHuman
2
+
3
+ This is the official implementation of *PSHuman: Photorealistic Single-image 3D Human Reconstruction using Cross-Scale Multiview Diffusion*.
4
+
5
+ ### [Project Page](https://penghtyx.github.io/PSHuman/) | [Arxiv](https://arxiv.org/pdf/2409.10141) | [Weights](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views)
6
+
7
+ https://github.com/user-attachments/assets/b62e3305-38a7-4b51-aed8-1fde967cca70
8
+
9
+ https://github.com/user-attachments/assets/76100d2e-4a1a-41ad-815c-816340ac6500
10
+
11
+
12
+ Given a single image of a clothed person, **PSHuman** facilitates detailed geometry and realistic 3D human appearance across various poses within one minute.
13
+
14
+ ### 📝 Update
15
+ - __[2024.11.30]__: Release the SMPL-free [version](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views), which does not requires SMPL condition for multview generation and perfome well in general posed human.
16
+
17
+
18
+ ### Installation
19
+ ```
20
+ conda create -n pshuman python=3.10
21
+ conda activate pshuman
22
+
23
+ # torch
24
+ pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
25
+
26
+ # other depedency
27
+ pip install -r requirement.txt
28
+ ```
29
+
30
+ This project is also based on SMPLX. We borrowed the related models from [ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU), and re-orginized them, which can be downloaded from [Onedrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/plibp_connect_ust_hk/EZQphP-2y5BGhEIe8jb03i4BIcqiJ2mUW2JmGC5s0VKOdw?e=qVzBBD).
31
+
32
+
33
+
34
+ ### Inference
35
+ 1. Given a human image, we use [Clipdrop](https://github.com/xxlong0/Wonder3D?tab=readme-ov-file) or ```rembg``` to remove the background. For the latter, we provide a simple scrip.
36
+ ```
37
+ python utils/remove_bg.py --path $DATA_PATH$
38
+ ```
39
+ Then, put the RGBA images in the ```$DATA_PATH$```.
40
+
41
+ 2. By running [inference.py](inference.py), the textured mesh and rendered video will be saved in ```out```.
42
+ ```
43
+ CUDA_VISIBLE_DEVICES=$GPU python inference.py --config configs/inference-768-6view.yaml \
44
+ pretrained_model_name_or_path='pengHTYX/PSHuman_Unclip_768_6views' \
45
+ validation_dataset.crop_size=740 \
46
+ with_smpl=false \
47
+ validation_dataset.root_dir=$DATA_PATH$ \
48
+ seed=600 \
49
+ num_views=7 \
50
+ save_mode='rgb'
51
+
52
+ ```
53
+ You can adjust the ```crop_size``` (720 or 740) and ```seed``` (42 or 600) to obtain best results for some cases.
54
+
55
+ ### Training
56
+ For the data preparing and preprocessing, please refer to our [paper](https://arxiv.org/pdf/2409.10141). Once the data is ready, we begin the training by running
57
+ ```
58
+ bash scripts/train_768.sh
59
+ ```
60
+ You should modified some parameters, such as ```data_common.root_dir``` and ```data_common.object_list```.
61
+
62
+ ### Related projects
63
+ We collect code from following projects. We thanks for the contributions from the open-source community!
64
+
65
+ [ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU) recover human mesh from single human image.
66
+ [Era3D](https://github.com/pengHTYX/Era3D) and [Unique3D](https://github.com/AiuniAI/Unique3D) generate consistent multiview images with single color image.
67
+ [Continuous-Remeshing](https://github.com/Profactor/continuous-remeshing) for Inverse Rendering.
68
+
69
+
70
+ ### Citation
71
+ If you find this codebase useful, please consider cite our work.
72
+ ```
73
+ @article{li2024pshuman,
74
+ title={PSHuman: Photorealistic Single-view Human Reconstruction using Cross-Scale Diffusion},
75
+ author={Li, Peng and Zheng, Wangguandong and Liu, Yuan and Yu, Tao and Li, Yangguang and Qi, Xingqun and Li, Mengfei and Chi, Xiaowei and Xia, Siyu and Xue, Wei and others},
76
+ journal={arXiv preprint arXiv:2409.10141},
77
+ year={2024}
78
+ }
79
+ ```
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: PSHuman
3
- emoji: 📊
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: PSHuman
3
+ emoji: 🏃
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.8.0
8
  app_file: app.py
9
  pinned: false
10
+ short_description: PHOTOREALISTIC HUMAN RECONSTRUCTION w/ CROSS-SCALE DIFF
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from rembg import remove
8
+ import sys
9
+ import uuid
10
+ import subprocess
11
+ from glob import glob
12
+ import requests
13
+ from huggingface_hub import snapshot_download
14
+
15
+ # Download models
16
+ os.makedirs("ckpts", exist_ok=True)
17
+
18
+ snapshot_download(
19
+ repo_id = "pengHTYX/PSHuman_Unclip_768_6views",
20
+ local_dir = "./ckpts"
21
+ )
22
+
23
+ os.makedirs("smpl_related", exist_ok=True)
24
+ snapshot_download(
25
+ repo_id = "fffiloni/PSHuman-SMPL-related",
26
+ local_dir = "./smpl_related"
27
+ )
28
+
29
+ # Folder containing example images
30
+ examples_folder = "examples"
31
+
32
+ # Retrieve all file paths in the folder
33
+ images_examples = [
34
+ os.path.join(examples_folder, file)
35
+ for file in os.listdir(examples_folder)
36
+ if os.path.isfile(os.path.join(examples_folder, file))
37
+ ]
38
+
39
+ def remove_background(input_pil, remove_bg):
40
+
41
+ # Create a temporary folder for downloaded and processed images
42
+ temp_dir = tempfile.mkdtemp()
43
+ unique_id = str(uuid.uuid4())
44
+ image_path = os.path.join(temp_dir, f'input_image_{unique_id}.png')
45
+
46
+ try:
47
+ # Check if input_url is already a PIL Image
48
+ if isinstance(input_pil, Image.Image):
49
+ image = input_pil
50
+ else:
51
+ # Otherwise, assume it's a file path and open it
52
+ image = Image.open(input_pil)
53
+
54
+ # Flip the image horizontally
55
+ image = image.transpose(Image.FLIP_LEFT_RIGHT)
56
+
57
+ # Save the resized image
58
+ image.save(image_path)
59
+ except Exception as e:
60
+ shutil.rmtree(temp_dir)
61
+ raise gr.Error(f"Error downloading or saving the image: {str(e)}")
62
+
63
+ if remove_bg is True:
64
+ # Run background removal
65
+ removed_bg_path = os.path.join(temp_dir, f'output_image_rmbg_{unique_id}.png')
66
+ try:
67
+ img = Image.open(image_path)
68
+ result = remove(img)
69
+ result.save(removed_bg_path)
70
+
71
+ # Remove the input image to keep the temp directory clean
72
+ os.remove(image_path)
73
+ except Exception as e:
74
+ shutil.rmtree(temp_dir)
75
+ raise gr.Error(f"Error removing background: {str(e)}")
76
+
77
+ return removed_bg_path, temp_dir
78
+ else:
79
+ return image_path, temp_dir
80
+
81
+ def run_inference(temp_dir, removed_bg_path):
82
+ # Define the inference configuration
83
+ inference_config = "configs/inference-768-6view.yaml"
84
+ pretrained_model = "./ckpts"
85
+ crop_size = 740
86
+ seed = 600
87
+ num_views = 7
88
+ save_mode = "rgb"
89
+
90
+ try:
91
+ # Run the inference command
92
+ subprocess.run(
93
+ [
94
+ "python", "inference.py",
95
+ "--config", inference_config,
96
+ f"pretrained_model_name_or_path={pretrained_model}",
97
+ f"validation_dataset.crop_size={crop_size}",
98
+ f"with_smpl=false",
99
+ f"validation_dataset.root_dir={temp_dir}",
100
+ f"seed={seed}",
101
+ f"num_views={num_views}",
102
+ f"save_mode={save_mode}"
103
+ ],
104
+ check=True
105
+ )
106
+
107
+
108
+ # Retrieve the file name without the extension
109
+ removed_bg_file_name = os.path.splitext(os.path.basename(removed_bg_path))[0]
110
+
111
+ # List objects in the "out" folder
112
+ out_folder_path = "out"
113
+ out_folder_objects = os.listdir(out_folder_path)
114
+ print(f"Objects in '{out_folder_path}':")
115
+ for obj in out_folder_objects:
116
+ print(f" - {obj}")
117
+
118
+ # List objects in the "out/{removed_bg_file_name}" folder
119
+ specific_out_folder_path = os.path.join(out_folder_path, removed_bg_file_name)
120
+ if os.path.exists(specific_out_folder_path) and os.path.isdir(specific_out_folder_path):
121
+ specific_out_folder_objects = os.listdir(specific_out_folder_path)
122
+ print(f"\nObjects in '{specific_out_folder_path}':")
123
+ for obj in specific_out_folder_objects:
124
+ print(f" - {obj}")
125
+ else:
126
+ print(f"\nThe folder '{specific_out_folder_path}' does not exist.")
127
+
128
+ output_video = glob(os.path.join(f"out/{removed_bg_file_name}", "*.mp4"))
129
+ output_objects = glob(os.path.join(f"out/{removed_bg_file_name}", "*.obj"))
130
+ return output_video, output_objects
131
+
132
+ except subprocess.CalledProcessError as e:
133
+ return f"Error during inference: {str(e)}"
134
+
135
+ def process_image(input_pil, remove_bg, progress=gr.Progress(track_tqdm=True)):
136
+
137
+ torch.cuda.empty_cache()
138
+
139
+ # Remove background
140
+ result = remove_background(input_pil, remove_bg)
141
+
142
+ if isinstance(result, str) and result.startswith("Error"):
143
+ raise gr.Error(f"{result}") # Return the error message if something went wrong
144
+
145
+ removed_bg_path, temp_dir = result # Unpack only if successful
146
+
147
+ # Run inference
148
+ output_video, output_objects = run_inference(temp_dir, removed_bg_path)
149
+
150
+ if isinstance(output_video, str) and output_video.startswith("Error"):
151
+ shutil.rmtree(temp_dir)
152
+ raise gr.Error(f"{output_video}") # Return the error message if inference failed
153
+
154
+
155
+ shutil.rmtree(temp_dir) # Cleanup temporary folder
156
+ print(output_video)
157
+ torch.cuda.empty_cache()
158
+ return output_video[0], output_objects[0], output_objects[1]
159
+
160
+ css="""
161
+ div#col-container{
162
+ margin: 0 auto;
163
+ max-width: 982px;
164
+ }
165
+ div#video-out-elm{
166
+ height: 323px;
167
+ }
168
+ """
169
+ def gradio_interface():
170
+ with gr.Blocks(css=css) as app:
171
+ with gr.Column(elem_id="col-container"):
172
+ gr.Markdown("# PSHuman: Photorealistic Single-image 3D Human Reconstruction using Cross-Scale Multiview Diffusion and Explicit Remeshing")
173
+ gr.HTML("""
174
+ <div style="display:flex;column-gap:4px;">
175
+ <a href="https://github.com/pengHTYX/PSHuman">
176
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
177
+ </a>
178
+ <a href="https://penghtyx.github.io/PSHuman/">
179
+ <img src='https://img.shields.io/badge/Project-Page-green'>
180
+ </a>
181
+ <a href="https://arxiv.org/pdf/2409.10141">
182
+ <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
183
+ </a>
184
+ <a href="https://huggingface.co/spaces/fffiloni/PSHuman?duplicate=true">
185
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
186
+ </a>
187
+ <a href="https://huggingface.co/fffiloni">
188
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
189
+ </a>
190
+ </div>
191
+ """)
192
+ with gr.Group():
193
+ with gr.Row():
194
+ with gr.Column(scale=2):
195
+
196
+ input_image = gr.Image(
197
+ label="Image input",
198
+ type="pil",
199
+ image_mode="RGBA",
200
+ height=480
201
+ )
202
+
203
+ remove_bg = gr.Checkbox(label="Need to remove BG ?", value=False)
204
+
205
+ submit_button = gr.Button("Process")
206
+
207
+ with gr.Column(scale=4):
208
+ output_video= gr.Video(label="Output Video", elem_id="video-out-elm")
209
+ with gr.Row():
210
+ output_object_mesh = gr.Model3D(label=".OBJ Mesh", height=240)
211
+ output_object_color = gr.Model3D(label=".OBJ colored", height=240)
212
+
213
+ gr.Examples(
214
+ examples = examples_folder,
215
+ inputs = [input_image],
216
+ examples_per_page = 11
217
+ )
218
+
219
+ submit_button.click(process_image, inputs=[input_image, remove_bg], outputs=[output_video, output_object_mesh, output_object_color])
220
+
221
+ return app
222
+
223
+ # Launch the Gradio app
224
+ app = gradio_interface()
225
+ app.launch(show_api=False, show_error=True, ssr_mode=False)
assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ec184ff46d1c066f6b1e5384a0964cc7c12c1b0a870c171949c1fa0ff2dd164
3
+ size 319966
assets/result_clr_scale4_pexels-zdmit-6780091.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d59a4ab1bb37dce928c04bedebf8acdbddf9e1474e4551c22d3dca8a8d2f7647
3
+ size 628781
blender/blender_render_human_ortho.py ADDED
@@ -0,0 +1,837 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Blender script to render images of 3D models.
2
+
3
+ This script is used to render images of 3D models. It takes in a list of paths
4
+ to .glb files and renders images of each model. The images are from rotating the
5
+ object around the origin. The images are saved to the output directory.
6
+
7
+ Example usage:
8
+ blender -b -P blender_script.py -- \
9
+ --object_path my_object.glb \
10
+ --output_dir ./views \
11
+ --engine CYCLES \
12
+ --scale 0.8 \
13
+ --num_images 12 \
14
+ --camera_dist 1.2
15
+
16
+ Here, input_model_paths.json is a json file containing a list of paths to .glb.
17
+ """
18
+ import argparse
19
+ import json
20
+ import math
21
+ import os
22
+ import random
23
+ import sys
24
+ import time
25
+ import glob
26
+ import urllib.request
27
+ import uuid
28
+ from typing import Tuple
29
+ from mathutils import Vector, Matrix
30
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
31
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
32
+ import cv2
33
+ import numpy as np
34
+ from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
35
+
36
+ import bpy
37
+ from mathutils import Vector
38
+
39
+ import OpenEXR
40
+ import Imath
41
+ from PIL import Image
42
+
43
+ # import blenderproc as bproc
44
+
45
+ bpy.app.debug_value=256
46
+
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument(
49
+ "--object_path",
50
+ type=str,
51
+ required=True,
52
+ help="Path to the object file",
53
+ )
54
+ parser.add_argument("--smpl_path", type=str, required=True, help="Path to the object file")
55
+ parser.add_argument("--output_dir", type=str, default="/views_whole_sphere-test2")
56
+ parser.add_argument(
57
+ "--engine", type=str, default="BLENDER_EEVEE", choices=["CYCLES", "BLENDER_EEVEE"]
58
+ )
59
+ parser.add_argument("--scale", type=float, default=1.0)
60
+ parser.add_argument("--num_images", type=int, default=8)
61
+ parser.add_argument("--random_images", type=int, default=3)
62
+ parser.add_argument("--random_ortho", type=int, default=1)
63
+ parser.add_argument("--device", type=str, default="CUDA")
64
+ parser.add_argument("--resolution", type=int, default=512)
65
+
66
+
67
+ argv = sys.argv[sys.argv.index("--") + 1 :]
68
+ args = parser.parse_args(argv)
69
+
70
+
71
+
72
+ print('===================', args.engine, '===================')
73
+
74
+ context = bpy.context
75
+ scene = context.scene
76
+ render = scene.render
77
+
78
+ cam = scene.objects["Camera"]
79
+ cam.data.type = 'ORTHO'
80
+ cam.data.ortho_scale = 1.
81
+ cam.data.lens = 35
82
+ cam.data.sensor_height = 32
83
+ cam.data.sensor_width = 32
84
+
85
+ cam_constraint = cam.constraints.new(type="TRACK_TO")
86
+ cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
87
+ cam_constraint.up_axis = "UP_Y"
88
+
89
+ # setup lighting
90
+ # bpy.ops.object.light_add(type="AREA")
91
+ # light2 = bpy.data.lights["Area"]
92
+ # light2.energy = 3000
93
+ # bpy.data.objects["Area"].location[2] = 0.5
94
+ # bpy.data.objects["Area"].scale[0] = 100
95
+ # bpy.data.objects["Area"].scale[1] = 100
96
+ # bpy.data.objects["Area"].scale[2] = 100
97
+
98
+ render.engine = args.engine
99
+ render.image_settings.file_format = "PNG"
100
+ render.image_settings.color_mode = "RGBA"
101
+ render.resolution_x = args.resolution
102
+ render.resolution_y = args.resolution
103
+ render.resolution_percentage = 100
104
+ render.threads_mode = 'FIXED' # 使用固定线程数模式
105
+ render.threads = 32 # 设置线程数
106
+
107
+ scene.cycles.device = "GPU"
108
+ scene.cycles.samples = 128 # 128
109
+ scene.cycles.diffuse_bounces = 1
110
+ scene.cycles.glossy_bounces = 1
111
+ scene.cycles.transparent_max_bounces = 3 # 3
112
+ scene.cycles.transmission_bounces = 3 # 3
113
+ # scene.cycles.filter_width = 0.01
114
+ bpy.context.scene.cycles.adaptive_threshold = 0
115
+ scene.cycles.use_denoising = True
116
+ scene.render.film_transparent = True
117
+
118
+ bpy.context.preferences.addons["cycles"].preferences.get_devices()
119
+ # Set the device_type
120
+ bpy.context.preferences.addons["cycles"].preferences.compute_device_type = 'CUDA' # or "OPENCL"
121
+ bpy.context.scene.cycles.tile_size = 8192
122
+
123
+
124
+ # eevee = scene.eevee
125
+ # eevee.use_soft_shadows = True
126
+ # eevee.use_ssr = True
127
+ # eevee.use_ssr_refraction = True
128
+ # eevee.taa_render_samples = 64
129
+ # eevee.use_gtao = True
130
+ # eevee.gtao_distance = 1
131
+ # eevee.use_volumetric_shadows = True
132
+ # eevee.volumetric_tile_size = '2'
133
+ # eevee.gi_diffuse_bounces = 1
134
+ # eevee.gi_cubemap_resolution = '128'
135
+ # eevee.gi_visibility_resolution = '16'
136
+ # eevee.gi_irradiance_smoothing = 0
137
+
138
+
139
+ # for depth & normal
140
+ context.view_layer.use_pass_normal = True
141
+ context.view_layer.use_pass_z = True
142
+ context.scene.use_nodes = True
143
+
144
+
145
+ tree = bpy.context.scene.node_tree
146
+ nodes = bpy.context.scene.node_tree.nodes
147
+ links = bpy.context.scene.node_tree.links
148
+
149
+ # Clear default nodes
150
+ for n in nodes:
151
+ nodes.remove(n)
152
+
153
+ # # Create input render layer node.
154
+ render_layers = nodes.new('CompositorNodeRLayers')
155
+
156
+ scale_normal = nodes.new(type="CompositorNodeMixRGB")
157
+ scale_normal.blend_type = 'MULTIPLY'
158
+ scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1)
159
+ links.new(render_layers.outputs['Normal'], scale_normal.inputs[1])
160
+ bias_normal = nodes.new(type="CompositorNodeMixRGB")
161
+ bias_normal.blend_type = 'ADD'
162
+ bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0)
163
+ links.new(scale_normal.outputs[0], bias_normal.inputs[1])
164
+ normal_file_output = nodes.new(type="CompositorNodeOutputFile")
165
+ normal_file_output.label = 'Normal Output'
166
+ links.new(bias_normal.outputs[0], normal_file_output.inputs[0])
167
+
168
+ normal_file_output.format.file_format = "OPEN_EXR" # default is "PNG"
169
+ normal_file_output.format.color_mode = "RGB" # default is "BW"
170
+
171
+ depth_file_output = nodes.new(type="CompositorNodeOutputFile")
172
+ depth_file_output.label = 'Depth Output'
173
+ links.new(render_layers.outputs['Depth'], depth_file_output.inputs[0])
174
+ depth_file_output.format.file_format = "OPEN_EXR" # default is "PNG"
175
+ depth_file_output.format.color_mode = "RGB" # default is "BW"
176
+
177
+ def prepare_depth_outputs():
178
+ tree = bpy.context.scene.node_tree
179
+ links = tree.links
180
+ render_node = tree.nodes['Render Layers']
181
+ depth_out_node = tree.nodes.new(type="CompositorNodeOutputFile")
182
+ depth_map_node = tree.nodes.new(type="CompositorNodeMapRange")
183
+ depth_out_node.base_path = ''
184
+ depth_out_node.format.file_format = 'OPEN_EXR'
185
+ depth_out_node.format.color_depth = '32'
186
+
187
+ depth_map_node.inputs[1].default_value = 0.54
188
+ depth_map_node.inputs[2].default_value = 1.96
189
+ depth_map_node.inputs[3].default_value = 0
190
+ depth_map_node.inputs[4].default_value = 1
191
+ depth_map_node.use_clamp = True
192
+ links.new(render_node.outputs[2],depth_map_node.inputs[0])
193
+ links.new(depth_map_node.outputs[0], depth_out_node.inputs[0])
194
+ return depth_out_node, depth_map_node
195
+
196
+ depth_file_output, depth_map_node = prepare_depth_outputs()
197
+
198
+
199
+ def exr_to_png(exr_path):
200
+ depth_path = exr_path.replace('.exr', '.png')
201
+ exr_image = OpenEXR.InputFile(exr_path)
202
+ dw = exr_image.header()['dataWindow']
203
+ (width, height) = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
204
+
205
+ def read_exr(s, width, height):
206
+ mat = np.fromstring(s, dtype=np.float32)
207
+ mat = mat.reshape(height, width)
208
+ return mat
209
+
210
+ dmap, _, _ = [read_exr(s, width, height) for s in exr_image.channels('BGR', Imath.PixelType(Imath.PixelType.FLOAT))]
211
+ dmap = np.clip(np.asarray(dmap,np.float64),a_max=1.0, a_min=0.0) * 65535
212
+ dmap = Image.fromarray(dmap.astype(np.uint16))
213
+ dmap.save(depth_path)
214
+ exr_image.close()
215
+ # os.system('rm {}'.format(exr_path))
216
+
217
+ def extract_depth(directory):
218
+ fns = glob.glob(f'{directory}/*.exr')
219
+ for fn in fns: exr_to_png(fn)
220
+ os.system(f'rm {directory}/*.exr')
221
+
222
+ def sample_point_on_sphere(radius: float) -> Tuple[float, float, float]:
223
+ theta = random.random() * 2 * math.pi
224
+ phi = math.acos(2 * random.random() - 1)
225
+ return (
226
+ radius * math.sin(phi) * math.cos(theta),
227
+ radius * math.sin(phi) * math.sin(theta),
228
+ radius * math.cos(phi),
229
+ )
230
+
231
+ def sample_spherical(radius=3.0, maxz=3.0, minz=0.):
232
+ correct = False
233
+ while not correct:
234
+ vec = np.random.uniform(-1, 1, 3)
235
+ vec[2] = np.abs(vec[2])
236
+ vec = vec / np.linalg.norm(vec, axis=0) * radius
237
+ if maxz > vec[2] > minz:
238
+ correct = True
239
+ return vec
240
+
241
+ def sample_spherical(radius_min=1.5, radius_max=2.0, maxz=1.6, minz=-0.75):
242
+ correct = False
243
+ while not correct:
244
+ vec = np.random.uniform(-1, 1, 3)
245
+ # vec[2] = np.abs(vec[2])
246
+ radius = np.random.uniform(radius_min, radius_max, 1)
247
+ vec = vec / np.linalg.norm(vec, axis=0) * radius[0]
248
+ if maxz > vec[2] > minz:
249
+ correct = True
250
+ return vec
251
+
252
+ def randomize_camera():
253
+ elevation = random.uniform(0., 90.)
254
+ azimuth = random.uniform(0., 360)
255
+ distance = random.uniform(0.8, 1.6)
256
+ return set_camera_location(elevation, azimuth, distance)
257
+
258
+ def set_camera_location(elevation, azimuth, distance):
259
+ # from https://blender.stackexchange.com/questions/18530/
260
+ x, y, z = sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2)
261
+ camera = bpy.data.objects["Camera"]
262
+ camera.location = x, y, z
263
+
264
+ direction = - camera.location
265
+ rot_quat = direction.to_track_quat('-Z', 'Y')
266
+ camera.rotation_euler = rot_quat.to_euler()
267
+ return camera
268
+
269
+ def set_camera_mvdream(azimuth, elevation, distance):
270
+ # theta, phi = np.deg2rad(azimuth), np.deg2rad(elevation)
271
+ azimuth, elevation = np.deg2rad(azimuth), np.deg2rad(elevation)
272
+ point = (
273
+ distance * math.cos(azimuth) * math.cos(elevation),
274
+ distance * math.sin(azimuth) * math.cos(elevation),
275
+ distance * math.sin(elevation),
276
+ )
277
+ camera = bpy.data.objects["Camera"]
278
+ camera.location = point
279
+
280
+ direction = -camera.location
281
+ rot_quat = direction.to_track_quat('-Z', 'Y')
282
+ camera.rotation_euler = rot_quat.to_euler()
283
+ return camera
284
+
285
+ def reset_scene() -> None:
286
+ """Resets the scene to a clean state.
287
+
288
+ Returns:
289
+ None
290
+ """
291
+ # delete everything that isn't part of a camera or a light
292
+ for obj in bpy.data.objects:
293
+ if obj.type not in {"CAMERA", "LIGHT"}:
294
+ bpy.data.objects.remove(obj, do_unlink=True)
295
+
296
+ # delete all the materials
297
+ for material in bpy.data.materials:
298
+ bpy.data.materials.remove(material, do_unlink=True)
299
+
300
+ # delete all the textures
301
+ for texture in bpy.data.textures:
302
+ bpy.data.textures.remove(texture, do_unlink=True)
303
+
304
+ # delete all the images
305
+ for image in bpy.data.images:
306
+ bpy.data.images.remove(image, do_unlink=True)
307
+ def process_ply(obj):
308
+ # obj = bpy.context.selected_objects[0]
309
+
310
+ # 创建一个新的材质
311
+ material = bpy.data.materials.new(name="VertexColors")
312
+ material.use_nodes = True
313
+ obj.data.materials.append(material)
314
+
315
+ # 获取材质的节点树
316
+ nodes = material.node_tree.nodes
317
+ links = material.node_tree.links
318
+
319
+ # 删除原有的'Principled BSDF'节点
320
+ principled_bsdf_node = nodes.get("Principled BSDF")
321
+ if principled_bsdf_node:
322
+ nodes.remove(principled_bsdf_node)
323
+
324
+ # 创建一个新的'Emission'节点
325
+ emission_node = nodes.new(type="ShaderNodeEmission")
326
+ emission_node.location = 0, 0
327
+
328
+ # 创建一个'Attribute'节点
329
+ attribute_node = nodes.new(type="ShaderNodeAttribute")
330
+ attribute_node.location = -300, 0
331
+ attribute_node.attribute_name = "Col" # 顶点颜色属性名称
332
+
333
+ # 创建一个'Output'节点
334
+ output_node = nodes.get("Material Output")
335
+
336
+ # 连接节点
337
+ links.new(attribute_node.outputs["Color"], emission_node.inputs["Color"])
338
+ links.new(emission_node.outputs["Emission"], output_node.inputs["Surface"])
339
+
340
+ # # load the glb model
341
+ # def load_object(object_path: str) -> None:
342
+
343
+ # if object_path.endswith(".glb"):
344
+ # bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False)
345
+ # elif object_path.endswith(".fbx"):
346
+ # bpy.ops.import_scene.fbx(filepath=object_path)
347
+ # elif object_path.endswith(".obj"):
348
+ # bpy.ops.import_scene.obj(filepath=object_path)
349
+ # elif object_path.endswith(".ply"):
350
+ # bpy.ops.import_mesh.ply(filepath=object_path)
351
+ # obj = bpy.context.selected_objects[0]
352
+ # obj.rotation_euler[0] = 1.5708
353
+ # # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y')
354
+ # process_ply(obj)
355
+ # else:
356
+ # raise ValueError(f"Unsupported file type: {object_path}")
357
+
358
+
359
+
360
+ def scene_bbox(
361
+ single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False
362
+ ) -> Tuple[Vector, Vector]:
363
+ """Returns the bounding box of the scene.
364
+
365
+ Taken from Shap-E rendering script
366
+ (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
367
+
368
+ Args:
369
+ single_obj (Optional[bpy.types.Object], optional): If not None, only computes
370
+ the bounding box for the given object. Defaults to None.
371
+ ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
372
+ to False.
373
+
374
+ Raises:
375
+ RuntimeError: If there are no objects in the scene.
376
+
377
+ Returns:
378
+ Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
379
+ """
380
+ bbox_min = (math.inf,) * 3
381
+ bbox_max = (-math.inf,) * 3
382
+ found = False
383
+ for obj in get_scene_meshes() if single_obj is None else [single_obj]:
384
+ found = True
385
+ for coord in obj.bound_box:
386
+ coord = Vector(coord)
387
+ if not ignore_matrix:
388
+ coord = obj.matrix_world @ coord
389
+ bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
390
+ bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
391
+
392
+ if not found:
393
+ raise RuntimeError("no objects in scene to compute bounding box for")
394
+
395
+ return Vector(bbox_min), Vector(bbox_max)
396
+
397
+
398
+ def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]:
399
+ """Returns all root objects in the scene.
400
+
401
+ Yields:
402
+ Generator[bpy.types.Object, None, None]: Generator of all root objects in the
403
+ scene.
404
+ """
405
+ for obj in bpy.context.scene.objects.values():
406
+ if not obj.parent:
407
+ yield obj
408
+
409
+
410
+ def get_scene_meshes() -> Generator[bpy.types.Object, None, None]:
411
+ """Returns all meshes in the scene.
412
+
413
+ Yields:
414
+ Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
415
+ """
416
+ for obj in bpy.context.scene.objects.values():
417
+ if isinstance(obj.data, (bpy.types.Mesh)):
418
+ yield obj
419
+
420
+
421
+ # Build intrinsic camera parameters from Blender camera data
422
+ #
423
+ # See notes on this in
424
+ # blender.stackexchange.com/questions/15102/what-is-blenders-camera-projection-matrix-model
425
+ def get_calibration_matrix_K_from_blender(camd):
426
+ f_in_mm = camd.lens
427
+ scene = bpy.context.scene
428
+ resolution_x_in_px = scene.render.resolution_x
429
+ resolution_y_in_px = scene.render.resolution_y
430
+ scale = scene.render.resolution_percentage / 100
431
+ sensor_width_in_mm = camd.sensor_width
432
+ sensor_height_in_mm = camd.sensor_height
433
+ pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
434
+ if (camd.sensor_fit == 'VERTICAL'):
435
+ # the sensor height is fixed (sensor fit is horizontal),
436
+ # the sensor width is effectively changed with the pixel aspect ratio
437
+ s_u = resolution_x_in_px * scale / sensor_width_in_mm / pixel_aspect_ratio
438
+ s_v = resolution_y_in_px * scale / sensor_height_in_mm
439
+ else: # 'HORIZONTAL' and 'AUTO'
440
+ # the sensor width is fixed (sensor fit is horizontal),
441
+ # the sensor height is effectively changed with the pixel aspect ratio
442
+ pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
443
+ s_u = resolution_x_in_px * scale / sensor_width_in_mm
444
+ s_v = resolution_y_in_px * scale * pixel_aspect_ratio / sensor_height_in_mm
445
+
446
+ # Parameters of intrinsic calibration matrix K
447
+ alpha_u = f_in_mm * s_u
448
+ alpha_v = f_in_mm * s_v
449
+ u_0 = resolution_x_in_px * scale / 2
450
+ v_0 = resolution_y_in_px * scale / 2
451
+ skew = 0 # only use rectangular pixels
452
+
453
+ K = Matrix(
454
+ ((alpha_u, skew, u_0),
455
+ ( 0 , alpha_v, v_0),
456
+ ( 0 , 0, 1 )))
457
+ return K
458
+
459
+
460
+ def get_calibration_matrix_K_from_blender_for_ortho(camd, ortho_scale):
461
+ scene = bpy.context.scene
462
+ resolution_x_in_px = scene.render.resolution_x
463
+ resolution_y_in_px = scene.render.resolution_y
464
+ scale = scene.render.resolution_percentage / 100
465
+ pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
466
+
467
+ fx = resolution_x_in_px / ortho_scale
468
+ fy = resolution_y_in_px / ortho_scale / pixel_aspect_ratio
469
+
470
+ cx = resolution_x_in_px / 2
471
+ cy = resolution_y_in_px / 2
472
+
473
+ K = Matrix(
474
+ ((fx, 0, cx),
475
+ (0, fy, cy),
476
+ (0 , 0, 1)))
477
+ return K
478
+
479
+
480
+ def get_3x4_RT_matrix_from_blender(cam):
481
+ bpy.context.view_layer.update()
482
+ location, rotation = cam.matrix_world.decompose()[0:2]
483
+ R = np.asarray(rotation.to_matrix())
484
+ t = np.asarray(location)
485
+
486
+ cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
487
+ R = R.T
488
+ t = -R @ t
489
+ R_world2cv = cam_rec @ R
490
+ t_world2cv = cam_rec @ t
491
+
492
+ RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
493
+ return RT
494
+
495
+ def delete_invisible_objects() -> None:
496
+ """Deletes all invisible objects in the scene.
497
+
498
+ Returns:
499
+ None
500
+ """
501
+ bpy.ops.object.select_all(action="DESELECT")
502
+ for obj in scene.objects:
503
+ if obj.hide_viewport or obj.hide_render:
504
+ obj.hide_viewport = False
505
+ obj.hide_render = False
506
+ obj.hide_select = False
507
+ obj.select_set(True)
508
+ bpy.ops.object.delete()
509
+
510
+ # Delete invisible collections
511
+ invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
512
+ for col in invisible_collections:
513
+ bpy.data.collections.remove(col)
514
+
515
+
516
+ def normalize_scene():
517
+ """Normalizes the scene by scaling and translating it to fit in a unit cube centered
518
+ at the origin.
519
+
520
+ Mostly taken from the Point-E / Shap-E rendering script
521
+ (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
522
+ but fix for multiple root objects: (see bug report here:
523
+ https://github.com/openai/shap-e/pull/60).
524
+
525
+ Returns:
526
+ None
527
+ """
528
+ if len(list(get_scene_root_objects())) > 1:
529
+ print('we have more than one root objects!!')
530
+ # create an empty object to be used as a parent for all root objects
531
+ parent_empty = bpy.data.objects.new("ParentEmpty", None)
532
+ bpy.context.scene.collection.objects.link(parent_empty)
533
+
534
+ # parent all root objects to the empty object
535
+ for obj in get_scene_root_objects():
536
+ if obj != parent_empty:
537
+ obj.parent = parent_empty
538
+
539
+ bbox_min, bbox_max = scene_bbox()
540
+ dxyz = bbox_max - bbox_min
541
+ dist = np.sqrt(dxyz[0]**2+ dxyz[1]**2+dxyz[2]**2)
542
+ scale = 1 / dist
543
+ for obj in get_scene_root_objects():
544
+ obj.scale = obj.scale * scale
545
+
546
+ # Apply scale to matrix_world.
547
+ bpy.context.view_layer.update()
548
+ bbox_min, bbox_max = scene_bbox()
549
+ offset = -(bbox_min + bbox_max) / 2
550
+ for obj in get_scene_root_objects():
551
+ obj.matrix_world.translation += offset
552
+ bpy.ops.object.select_all(action="DESELECT")
553
+
554
+ # unparent the camera
555
+ bpy.data.objects["Camera"].parent = None
556
+ return scale, offset
557
+
558
+ def download_object(object_url: str) -> str:
559
+ """Download the object and return the path."""
560
+ # uid = uuid.uuid4()
561
+ uid = object_url.split("/")[-1].split(".")[0]
562
+ tmp_local_path = os.path.join("tmp-objects", f"{uid}.glb" + ".tmp")
563
+ local_path = os.path.join("tmp-objects", f"{uid}.glb")
564
+ # wget the file and put it in local_path
565
+ os.makedirs(os.path.dirname(tmp_local_path), exist_ok=True)
566
+ urllib.request.urlretrieve(object_url, tmp_local_path)
567
+ os.rename(tmp_local_path, local_path)
568
+ # get the absolute path
569
+ local_path = os.path.abspath(local_path)
570
+ return local_path
571
+
572
+
573
+ def render_and_save(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False):
574
+ # print(view_id)
575
+ # render the image
576
+ render_path = os.path.join(args.output_dir, 'image', f"{view_id:03d}.png")
577
+ scene.render.filepath = render_path
578
+
579
+ if not ortho:
580
+ cam.data.lens = len_val
581
+
582
+ depth_map_node.inputs[1].default_value = distance - 1
583
+ depth_map_node.inputs[2].default_value = distance + 1
584
+ depth_file_output.base_path = os.path.join(args.output_dir, object_uid, 'depth')
585
+
586
+ depth_file_output.file_slots[0].path = f"{view_id:03d}"
587
+ normal_file_output.file_slots[0].path = f"{view_id:03d}"
588
+
589
+ if not os.path.exists(os.path.join(args.output_dir, 'normal', f"{view_id+1:03d}.png")):
590
+ bpy.ops.render.render(write_still=True)
591
+
592
+
593
+ if os.path.exists(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr")):
594
+ os.rename(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr"),
595
+ os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}.exr"))
596
+
597
+ if os.path.exists(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr")):
598
+ normal = cv2.imread(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED)
599
+ normal_unit16 = (normal * 65535).astype(np.uint16)
600
+ cv2.imwrite(os.path.join(args.output_dir, 'normal', f"{view_id:03d}.png"), normal_unit16)
601
+ os.remove(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr"))
602
+
603
+ # save camera KRT matrix
604
+ if ortho:
605
+ K = get_calibration_matrix_K_from_blender_for_ortho(cam.data, ortho_scale=cam.data.ortho_scale)
606
+ else:
607
+ K = get_calibration_matrix_K_from_blender(cam.data)
608
+
609
+ RT = get_3x4_RT_matrix_from_blender(cam)
610
+ para_path = os.path.join(args.output_dir, 'camera', f"{view_id:03d}.npy")
611
+ # np.save(RT_path, RT)
612
+ paras = {}
613
+ paras['intrinsic'] = np.array(K, np.float32)
614
+ paras['extrinsic'] = np.array(RT, np.float32)
615
+ paras['fov'] = cam.data.angle
616
+ paras['azimuth'] = azimuth
617
+ paras['elevation'] = elevation
618
+ paras['distance'] = distance
619
+ paras['focal'] = cam.data.lens
620
+ paras['sensor_width'] = cam.data.sensor_width
621
+ paras['near'] = distance - 1
622
+ paras['far'] = distance + 1
623
+ paras['camera'] = 'persp' if not ortho else 'ortho'
624
+ np.save(para_path, paras)
625
+
626
+ def render_and_save_smpl(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False):
627
+
628
+
629
+ if not ortho:
630
+ cam.data.lens = len_val
631
+
632
+ render_path = os.path.join(args.output_dir, 'smpl_image', f"{view_id:03d}.png")
633
+ scene.render.filepath = render_path
634
+
635
+ normal_file_output.file_slots[0].path = f"{view_id:03d}"
636
+ if not os.path.exists(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}.png")):
637
+ bpy.ops.render.render(write_still=True)
638
+
639
+ if os.path.exists(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr")):
640
+ normal = cv2.imread(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED)
641
+ normal_unit16 = (normal * 65535).astype(np.uint16)
642
+ cv2.imwrite(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}.png"), normal_unit16)
643
+ os.remove(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr"))
644
+
645
+
646
+
647
+ def scene_meshes():
648
+ for obj in bpy.context.scene.objects.values():
649
+ if isinstance(obj.data, (bpy.types.Mesh)):
650
+ yield obj
651
+
652
+ def load_object(object_path: str) -> None:
653
+ """Loads a glb model into the scene."""
654
+ if object_path.endswith(".glb"):
655
+ bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False)
656
+ elif object_path.endswith(".fbx"):
657
+ bpy.ops.import_scene.fbx(filepath=object_path)
658
+ elif object_path.endswith(".obj"):
659
+ bpy.ops.import_scene.obj(filepath=object_path)
660
+ obj = bpy.context.selected_objects[0]
661
+ obj.rotation_euler[0] = 6.28319
662
+ # obj.rotation_euler[2] = 1.5708
663
+ elif object_path.endswith(".ply"):
664
+ bpy.ops.import_mesh.ply(filepath=object_path)
665
+ obj = bpy.context.selected_objects[0]
666
+ obj.rotation_euler[0] = 1.5708
667
+ obj.rotation_euler[2] = 1.5708
668
+ # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y')
669
+ process_ply(obj)
670
+ else:
671
+ raise ValueError(f"Unsupported file type: {object_path}")
672
+
673
+ def save_images(object_file: str, smpl_file: str) -> None:
674
+ """Saves rendered images of the object in the scene."""
675
+ object_uid = '' # os.path.basename(object_file).split(".")[0]
676
+ # # if we already render this object, we skip it
677
+ if os.path.exists(os.path.join(args.output_dir, 'meta.npy')): return
678
+ os.makedirs(args.output_dir, exist_ok=True)
679
+ os.makedirs(os.path.join(args.output_dir, 'camera'), exist_ok=True)
680
+
681
+ reset_scene()
682
+ load_object(object_file)
683
+
684
+ lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT']
685
+ for light in lights:
686
+ bpy.data.objects.remove(light, do_unlink=True)
687
+
688
+ # bproc.init()
689
+
690
+ world_tree = bpy.context.scene.world.node_tree
691
+ back_node = world_tree.nodes['Background']
692
+ env_light = 0.5
693
+ back_node.inputs['Color'].default_value = Vector([env_light, env_light, env_light, 1.0])
694
+ back_node.inputs['Strength'].default_value = 1.0
695
+
696
+ #Make light just directional, disable shadows.
697
+ light_data = bpy.data.lights.new(name=f'Light', type='SUN')
698
+ light = bpy.data.objects.new(name=f'Light', object_data=light_data)
699
+ bpy.context.collection.objects.link(light)
700
+ light = bpy.data.lights['Light']
701
+ light.use_shadow = False
702
+ # Possibly disable specular shading:
703
+ light.specular_factor = 1.0
704
+ light.energy = 5.0
705
+
706
+ #Add another light source so stuff facing away from light is not completely dark
707
+ light_data = bpy.data.lights.new(name=f'Light2', type='SUN')
708
+ light = bpy.data.objects.new(name=f'Light2', object_data=light_data)
709
+ bpy.context.collection.objects.link(light)
710
+ light2 = bpy.data.lights['Light2']
711
+ light2.use_shadow = False
712
+ light2.specular_factor = 1.0
713
+ light2.energy = 3 #0.015
714
+ bpy.data.objects['Light2'].rotation_euler = bpy.data.objects['Light2'].rotation_euler
715
+ bpy.data.objects['Light2'].rotation_euler[0] += 180
716
+
717
+ #Add another light source so stuff facing away from light is not completely dark
718
+ light_data = bpy.data.lights.new(name=f'Light3', type='SUN')
719
+ light = bpy.data.objects.new(name=f'Light3', object_data=light_data)
720
+ bpy.context.collection.objects.link(light)
721
+ light3 = bpy.data.lights['Light3']
722
+ light3.use_shadow = False
723
+ light3.specular_factor = 1.0
724
+ light3.energy = 3 #0.015
725
+ bpy.data.objects['Light3'].rotation_euler = bpy.data.objects['Light3'].rotation_euler
726
+ bpy.data.objects['Light3'].rotation_euler[0] += 90
727
+
728
+ #Add another light source so stuff facing away from light is not completely dark
729
+ light_data = bpy.data.lights.new(name=f'Light4', type='SUN')
730
+ light = bpy.data.objects.new(name=f'Light4', object_data=light_data)
731
+ bpy.context.collection.objects.link(light)
732
+ light4 = bpy.data.lights['Light4']
733
+ light4.use_shadow = False
734
+ light4.specular_factor = 1.0
735
+ light4.energy = 3 #0.015
736
+ bpy.data.objects['Light4'].rotation_euler = bpy.data.objects['Light4'].rotation_euler
737
+ bpy.data.objects['Light4'].rotation_euler[0] += -90
738
+
739
+ scale, offset = normalize_scene()
740
+
741
+
742
+ try:
743
+ # some objects' normals are affected by textures
744
+ mesh_objects = [obj for obj in scene_meshes()]
745
+ main_bsdf_name = 'BsdfPrincipled'
746
+ normal_name = 'Normal'
747
+ for obj in mesh_objects:
748
+ for mat in obj.data.materials:
749
+ for node in mat.node_tree.nodes:
750
+ if main_bsdf_name in node.bl_idname:
751
+ principled_bsdf = node
752
+ # remove links, we don't want add normal textures
753
+ if principled_bsdf.inputs[normal_name].links:
754
+ mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0])
755
+ except:
756
+ print("don't know why")
757
+ # create an empty object to track
758
+ empty = bpy.data.objects.new("Empty", None)
759
+ scene.collection.objects.link(empty)
760
+ cam_constraint.target = empty
761
+
762
+ subject_width = 1.0
763
+
764
+ normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'normal')
765
+ for i in range(args.num_images):
766
+ # change the camera to orthogonal
767
+ cam.data.type = 'ORTHO'
768
+ cam.data.ortho_scale = subject_width
769
+ distance = 1.5
770
+ azimuth = i * 360 / args.num_images
771
+ bpy.context.view_layer.update()
772
+ set_camera_mvdream(azimuth, 0, distance)
773
+ render_and_save(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True)
774
+ extract_depth(os.path.join(args.output_dir, object_uid, 'depth'))
775
+ # #### smpl
776
+ reset_scene()
777
+ load_object(smpl_file)
778
+
779
+ lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT']
780
+ for light in lights:
781
+ bpy.data.objects.remove(light, do_unlink=True)
782
+
783
+ scale, offset = normalize_scene()
784
+
785
+ try:
786
+ # some objects' normals are affected by textures
787
+ mesh_objects = [obj for obj in scene_meshes()]
788
+ main_bsdf_name = 'BsdfPrincipled'
789
+ normal_name = 'Normal'
790
+ for obj in mesh_objects:
791
+ for mat in obj.data.materials:
792
+ for node in mat.node_tree.nodes:
793
+ if main_bsdf_name in node.bl_idname:
794
+ principled_bsdf = node
795
+ # remove links, we don't want add normal textures
796
+ if principled_bsdf.inputs[normal_name].links:
797
+ mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0])
798
+ except:
799
+ print("don't know why")
800
+ # create an empty object to track
801
+ empty = bpy.data.objects.new("Empty", None)
802
+ scene.collection.objects.link(empty)
803
+ cam_constraint.target = empty
804
+
805
+ subject_width = 1.0
806
+
807
+ normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'smpl_normal')
808
+ for i in range(args.num_images):
809
+ # change the camera to orthogonal
810
+ cam.data.type = 'ORTHO'
811
+ cam.data.ortho_scale = subject_width
812
+ distance = 1.5
813
+ azimuth = i * 360 / args.num_images
814
+ bpy.context.view_layer.update()
815
+ set_camera_mvdream(azimuth, 0, distance)
816
+ render_and_save_smpl(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True)
817
+
818
+
819
+ np.save(os.path.join(args.output_dir, object_uid, 'meta.npy'), np.asarray([scale, offset[0], offset[1], offset[1]],np.float32))
820
+
821
+
822
+ if __name__ == "__main__":
823
+ try:
824
+ start_i = time.time()
825
+ if args.object_path.startswith("http"):
826
+ local_path = download_object(args.object_path)
827
+ else:
828
+ local_path = args.object_path
829
+ save_images(local_path, args.smpl_path)
830
+ end_i = time.time()
831
+ print("Finished", local_path, "in", end_i - start_i, "seconds")
832
+ # delete the object if it was downloaded
833
+ if args.object_path.startswith("http"):
834
+ os.remove(local_path)
835
+ except Exception as e:
836
+ print("Failed to render", args.object_path)
837
+ print(e)
blender/check_render.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+ import json
4
+ from icecream import ic
5
+
6
+
7
+ def check_render(dataset, st=None, end=None):
8
+ total_lists = []
9
+ with open(dataset+'.json', 'r') as f:
10
+ glb_list = json.load(f)
11
+ for x in glb_list:
12
+ total_lists.append(x.split('/')[-2] )
13
+
14
+ if st is not None:
15
+ end = min(end, len(total_lists))
16
+ total_lists = total_lists[st:end]
17
+ glb_list = glb_list[st:end]
18
+
19
+ save_dir = '/data/lipeng/human_8view_with_smplx/'+dataset
20
+ unrendered = set(total_lists) - set(os.listdir(save_dir))
21
+
22
+ num_finish = 0
23
+ num_failed = len(unrendered)
24
+ failed_case = []
25
+ for case in os.listdir(save_dir):
26
+ if not os.path.exists(os.path.join(save_dir, case, 'smpl_normal', '007.png')):
27
+ failed_case.append(case)
28
+ num_failed += 1
29
+ else:
30
+ num_finish += 1
31
+ ic(num_failed)
32
+ ic(num_finish)
33
+
34
+
35
+ need_render = []
36
+ for full_path in glb_list:
37
+ for case in failed_case:
38
+ if case in full_path:
39
+ need_render.append(full_path)
40
+
41
+ with open('need_render.json', 'w') as f:
42
+ json.dump(need_render, f, indent=4)
43
+
44
+ if __name__ == '__main__':
45
+ dataset = 'THuman2.1'
46
+ check_render(dataset)
blender/count.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ def find_files(directory, extensions):
4
+ results = []
5
+ for foldername, subfolders, filenames in os.walk(directory):
6
+ for filename in filenames:
7
+ if filename.endswith(extensions):
8
+ file_path = os.path.abspath(os.path.join(foldername, filename))
9
+ results.append(file_path)
10
+ return results
11
+
12
+ def count_customhumans(root):
13
+ directory_path = ['CustomHumans/mesh']
14
+
15
+ extensions = ('.ply', '.obj')
16
+
17
+ lists = []
18
+ for dataset_path in directory_path:
19
+ dir = os.path.join(root, dataset_path)
20
+ file_paths = find_files(dir, extensions)
21
+ # import pdb;pdb.set_trace()
22
+ dataset_name = dataset_path.split('/')[0]
23
+ for file_path in file_paths:
24
+ lists.append(file_path.replace(root, ""))
25
+ with open(f'{dataset_name}.json', 'w') as f:
26
+ json.dump(lists, f, indent=4)
27
+
28
+ def count_thuman21(root):
29
+ directory_path = ['THuman2.1/mesh']
30
+ extensions = ('.ply', '.obj')
31
+ lists = []
32
+ for dataset_path in directory_path:
33
+ dir = os.path.join(root, dataset_path)
34
+ file_paths = find_files(dir, extensions)
35
+ dataset_name = dataset_path.split('/')[0]
36
+ for file_path in file_paths:
37
+ lists.append(file_path.replace(root, ""))
38
+ with open(f'{dataset_name}.json', 'w') as f:
39
+ json.dump(lists, f, indent=4)
40
+
41
+ if __name__ == '__main__':
42
+ root = '/data/lipeng/human_scan/'
43
+ # count_customhumans(root)
44
+ count_thuman21(root)
blender/distribute.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import multiprocessing
4
+ import shutil
5
+ import subprocess
6
+ import time
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+ import os
10
+
11
+ import boto3
12
+
13
+
14
+ from glob import glob
15
+
16
+ import argparse
17
+
18
+ parser = argparse.ArgumentParser(description='distributed rendering')
19
+
20
+ parser.add_argument('--workers_per_gpu', type=int, default=10,
21
+ help='number of workers per gpu.')
22
+ parser.add_argument('--input_models_path', type=str, default='/data/lipeng/human_scan/',
23
+ help='Root directory of the input 3D scans.')
24
+ parser.add_argument('--num_gpus', type=int, default=-1,
25
+ help='number of gpus to use. -1 means all available gpus.')
26
+ parser.add_argument('--gpu_list',nargs='+', type=int,
27
+ help='the available gpus')
28
+ parser.add_argument('--resolution', type=int, default=512,
29
+ help='')
30
+ parser.add_argument('--random_images', type=int, default=0)
31
+ parser.add_argument('--start_i', type=int, default=0,
32
+ help='the index of first object to be rendered.')
33
+ parser.add_argument('--end_i', type=int, default=-1,
34
+ help='the index of the last object to be rendered.')
35
+
36
+ parser.add_argument('--data_dir', type=str, default='/data/lipeng/human_scan/',
37
+ help='Root directory containing the 3D human scans.')
38
+
39
+ parser.add_argument('--json_path', type=str, default='2K2K.json')
40
+
41
+ parser.add_argument('--save_dir', type=str, default='/data/lipeng/human_8view',
42
+ help='Directory where the rendered views are saved.')
43
+
44
+ parser.add_argument('--ortho_scale', type=float, default=1.,
45
+ help='ortho rendering usage; how large the object is')
46
+
47
+
48
+ args = parser.parse_args()
49
+
50
+ def parse_obj_list(xs):
51
+ cases = []
52
+ # print(xs[:2])
53
+
54
+ for x in xs:
55
+ if 'THuman3.0' in x:
56
+ # print(apath)
57
+ splits = x.split('/')
58
+ x = os.path.join('THuman3.0', splits[-2])
59
+ elif 'THuman2.1' in x:
60
+ splits = x.split('/')
61
+ x = os.path.join('THuman2.1', splits[-2])
62
+ elif 'CustomHumans' in x:
63
+ splits = x.split('/')
64
+ x = os.path.join('CustomHumans', splits[-2])
65
+ elif '1M' in x:
66
+ splits = x.split('/')
67
+ x = os.path.join('2K2K', splits[-2])
68
+ elif 'realistic_8k_model' in x:
69
+ splits = x.split('/')
70
+ x = os.path.join('realistic_8k_model', splits[-1].split('.')[0])
71
+ cases.append(f'{args.save_dir}/{x}')
72
+ return cases
73
+
74
+
75
+ with open(args.json_path, 'r') as f:
76
+ glb_list = json.load(f)
77
+
78
+ # glb_list = ['THuman2.1/mesh/1618/1618.obj']
79
+ # glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj']
80
+ # glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj']
81
+ # glb_list = ['1M/01968/01968.ply', '1M/00103/00103.ply']
82
+ # glb_list = ['realistic_8k_model/01aab099a2fe4af7be120110a385105d.glb']
83
+
84
+ total_num_glbs = len(glb_list)
85
+
86
+
87
+
88
+ def worker(
89
+ queue: multiprocessing.JoinableQueue,
90
+ count: multiprocessing.Value,
91
+ gpu: int,
92
+ s3: Optional[boto3.client],
93
+ ) -> None:
94
+ print("Worker started")
95
+ while True:
96
+ case, save_p = queue.get()
97
+ src_path = os.path.join(args.data_dir, case)
98
+ smpl_path = src_path.replace('mesh', 'smplx', 1)
99
+
100
+ command = ('blender -b -P blender_render_human_ortho.py'
101
+ f' -- --object_path {src_path}'
102
+ f' --smpl_path {smpl_path}'
103
+ f' --output_dir {save_p} --engine CYCLES'
104
+ f' --resolution {args.resolution}'
105
+ f' --random_images {args.random_images}'
106
+ )
107
+
108
+ print(command)
109
+ subprocess.run(command, shell=True)
110
+
111
+ with count.get_lock():
112
+ count.value += 1
113
+
114
+ queue.task_done()
115
+
116
+
117
+ if __name__ == "__main__":
118
+ # args = tyro.cli(Args)
119
+
120
+ s3 = None
121
+ queue = multiprocessing.JoinableQueue()
122
+ count = multiprocessing.Value("i", 0)
123
+
124
+ # Start worker processes on each of the GPUs
125
+ for gpu_i in range(args.num_gpus):
126
+ for worker_i in range(args.workers_per_gpu):
127
+ worker_i = gpu_i * args.workers_per_gpu + worker_i
128
+ process = multiprocessing.Process(
129
+ target=worker, args=(queue, count, args.gpu_list[gpu_i], s3)
130
+ )
131
+ process.daemon = True
132
+ process.start()
133
+
134
+ # Add items to the queue
135
+
136
+ save_dirs = parse_obj_list(glb_list)
137
+ args.end_i = len(save_dirs) if args.end_i > len(save_dirs) or args.end_i==-1 else args.end_i
138
+
139
+ for case_sub, save_dir in zip(glb_list[args.start_i:args.end_i], save_dirs[args.start_i:args.end_i]):
140
+ queue.put([case_sub, save_dir])
141
+
142
+
143
+
144
+ # Wait for all tasks to be completed
145
+ queue.join()
146
+
147
+ # Add sentinels to the queue to stop the worker processes
148
+ for i in range(args.num_gpus * args.workers_per_gpu):
149
+ queue.put(None)
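One caveat: the shutdown path pushes bare `None` sentinels onto the queue, while `worker()` unconditionally unpacks `queue.get()` into `(case, save_p)`, so a worker that actually receives a sentinel would raise a `TypeError` (harmless for daemon processes, but noisy). A minimal sketch of an explicit sentinel check, shown as a separate illustration rather than a patch to the script above:

```python
def worker_loop(queue):
    """Drain (case, save_dir) pairs until a None sentinel arrives."""
    while True:
        item = queue.get()
        if item is None:           # sentinel: acknowledge it and exit cleanly
            queue.task_done()
            break
        case, save_dir = item
        # ... launch the blender command for (case, save_dir) as in worker() above ...
        queue.task_done()
```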
blender/rename_smpl_files.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ from glob import glob
4
+
5
+ def rename_customhumans():
6
+ root = '/data/lipeng/human_scan/CustomHumans/smplx/'
7
+ file_paths = glob(os.path.join(root, '*/*_smpl.obj'))
8
+ for file_path in tqdm(file_paths):
9
+ new_path = file_path.replace('_smpl', '')
10
+ os.rename(file_path, new_path)
11
+
12
+ def rename_thuman21():
13
+ root = '/data/lipeng/human_scan/THuman2.1/smplx/'
14
+ file_paths = glob(os.path.join(root, '*/*.obj'))
15
+ for file_path in tqdm(file_paths):
16
+ obj_name = file_path.split('/')[-2]
17
+ folder_name = os.path.dirname(file_path)
18
+ new_path = os.path.join(folder_name, obj_name+'.obj')
19
+ # print(new_path)
20
+ # print(file_path)
21
+ os.rename(file_path, new_path)
22
+
23
+ if __name__ == '__main__':
24
+ rename_thuman21()
25
+ rename_customhumans()
blender/render.sh ADDED
@@ -0,0 +1,4 @@
1
+ #### install environment
2
+ # ~/pkgs/blender-3.6.4/3.6/python/bin/python3.10 -m pip install openexr opencv-python
3
+
4
+ python render_human.py
blender/render_human.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import json
3
+ import math
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ import threading
6
+ from tqdm import tqdm
7
+
8
+ # from glcontext import egl
9
+ # egl.create_context()
10
+ # exit(0)
11
+
12
+ LOCAL_RANK = 0
13
+
14
+ num_processes = 4
15
+ NODE_RANK = int(os.getenv("SLURM_PROCID", "0"))
16
+ WORLD_SIZE = 1
17
+ NODE_NUM=1
18
+ # NODE_RANK = int(os.getenv("SLURM_NODEID"))
19
+ IS_MAIN = False
20
+ if NODE_RANK == 0 and LOCAL_RANK == 0:
21
+ IS_MAIN = True
22
+
23
+ GLOBAL_RANK = NODE_RANK * (WORLD_SIZE//NODE_NUM) + LOCAL_RANK
24
+
25
+
26
+ # json_path = "object_lists/Thuman2.0.json"
27
+ # json_path = "object_lists/THuman3.0.json"
28
+ json_path = "object_lists/CustomHumans.json"
29
+ data_dir = '/aifs4su/mmcode/lipeng'
30
+ save_dir = '/aifs4su/mmcode/lipeng/human_8view_new'
31
+ def parse_obj_list(x):
32
+ if 'THuman3.0' in x:
33
+ # print(apath)
34
+ splits = x.split('/')
35
+ x = os.path.join('THuman3.0', splits[-2])
36
+ elif 'Thuman2.0' in x:
37
+ splits = x.split('/')
38
+ x = os.path.join('Thuman2.0', splits[-2])
39
+ elif 'CustomHumans' in x:
40
+ splits = x.split('/')
41
+ x = os.path.join('CustomHumans', splits[-2])
42
+ # print(splits[-2])
43
+ elif '1M' in x:
44
+ splits = x.split('/')
45
+ x = os.path.join('2K2K', splits[-2])
46
+ elif 'realistic_8k_model' in x:
47
+ splits = x.split('/')
48
+ x = os.path.join('realistic_8k_model', splits[-1].split('.')[0])
49
+ return f'{save_dir}/{x}'
50
+
51
+ with open(json_path, 'r') as f:
52
+ glb_list = json.load(f)
53
+
54
+ # glb_list = ['Thuman2.0/0011/0011.obj']
55
+ # glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj']
56
+ # glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj']
57
+ # glb_list = ['realistic_8k_model/1d41f2a72f994306b80e632f1cc8233f.glb']
58
+
59
+ total_num_glbs = len(glb_list)
60
+
61
+ num_glbs_local = int(math.ceil(total_num_glbs / WORLD_SIZE))
62
+ start_idx = GLOBAL_RANK * num_glbs_local
63
+ end_idx = start_idx + num_glbs_local
64
+ # print(start_idx, end_idx)
65
+ local_glbs = glb_list[start_idx:end_idx]
66
+ if IS_MAIN:
67
+ pbar = tqdm(total=len(local_glbs))
68
+ lock = threading.Lock()
69
+
70
+ def process_human(glb_path):
71
+ src_path = os.path.join(data_dir, glb_path)
72
+ save_path = parse_obj_list(glb_path)
73
+ # print(save_path)
74
+ command = ('blender -b -P blender_render_human_script.py'
75
+ f' -- --object_path {src_path}'
76
+ f' --output_dir {save_path} ')
77
+ # 1>/dev/null
78
+ # print(command)
79
+ os.system(command)
80
+
81
+ if IS_MAIN:
82
+ with lock:
83
+ pbar.update(1)
84
+
85
+ with ProcessPoolExecutor(max_workers=num_processes) as executor:
86
+ executor.map(process_human, local_glbs)
87
+
88
+
blender/render_single.sh ADDED
@@ -0,0 +1,7 @@
1
+ # debug single sample
2
+ blender -b -P blender_render_human_ortho.py \
3
+ -- --object_path /data/lipeng/human_scan/THuman2.1/mesh/0011/0011.obj \
4
+ --smpl_path /data/lipeng/human_scan/THuman2.1/smplx/0011/0011.obj \
5
+ --output_dir debug --engine CYCLES \
6
+ --resolution 768 \
7
+ --random_images 0
blender/utils.py ADDED
@@ -0,0 +1,128 @@
1
+ import datetime
2
+ import pytz
3
+ import traceback
4
+ from torchvision.utils import make_grid
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import numpy as np
7
+ import torch
8
+ import json
9
+ import os
10
+ from tqdm import tqdm
11
+ import cv2
12
+ import imageio
13
+ def get_time_for_log():
14
+ return datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime(
15
+ "%Y%m%d %H:%M:%S")
16
+
17
+
18
+ def get_trace_for_log():
19
+ return str(traceback.format_exc())
20
+
21
+ def make_grid_(imgs, save_file, nrow=10, pad_value=1):
22
+ if isinstance(imgs, list):
23
+ if isinstance(imgs[0], Image.Image):
24
+ imgs = [torch.from_numpy(np.array(img)/255.) for img in imgs]
25
+ elif isinstance(imgs[0], np.ndarray):
26
+ imgs = [torch.from_numpy(img/255.) for img in imgs]
27
+ imgs = torch.stack(imgs, 0).permute(0, 3, 1, 2)
28
+ if isinstance(imgs, np.ndarray):
29
+ imgs = torch.from_numpy(imgs)
30
+
31
+ img_grid = make_grid(imgs, nrow=nrow, padding=2, pad_value=pad_value)
32
+ img_grid = img_grid.permute(1, 2, 0).numpy()
33
+ img_grid = (img_grid * 255).astype(np.uint8)
34
+ img_grid = Image.fromarray(img_grid)
35
+ img_grid.save(save_file)
36
+
37
+ def draw_caption(img, text, pos, size=100, color=(128, 128, 128)):
38
+ draw = ImageDraw.Draw(img)
39
+ # font = ImageFont.truetype(size= size)
40
+ font = ImageFont.load_default()
41
+ font = font.font_variant(size=size)
42
+ draw.text(pos, text, color, font=font)
43
+ return img
44
+
45
+
46
+ def txt2json(txt_file, json_file):
47
+ with open(txt_file, 'r') as f:
48
+ items = f.readlines()
49
+ items = [x.strip() for x in items]
50
+
51
+ with open(json_file, 'w') as f:
52
+ json.dump(items, f)
53
+
54
+ def process_thuman_texture():
55
+ path = '/aifs4su/mmcode/lipeng/Thuman2.0'
56
+ cases = os.listdir(path)
57
+ for case in tqdm(cases):
58
+ mtl = os.path.join(path, case, 'material0.mtl')
59
+ with open(mtl, 'r') as f:
60
+ lines = f.read()
61
+ lines = lines.replace('png', 'jpeg')
62
+ with open(mtl, 'w') as f:
63
+ f.write(lines)
64
+
65
+
66
+ #### for debug
67
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
68
+
69
+
70
+ def get_intrinsic_from_fov(fov, H, W, bs=-1):
71
+ focal_length = 0.5 * H / np.tan(0.5 * fov)
72
+ intrinsic = np.identity(3, dtype=np.float32)
73
+ intrinsic[0, 0] = focal_length
74
+ intrinsic[1, 1] = focal_length
75
+ intrinsic[0, 2] = W / 2.0
76
+ intrinsic[1, 2] = H / 2.0
77
+
78
+ if bs > 0:
79
+ intrinsic = intrinsic[None].repeat(bs, axis=0)
80
+
81
+ return torch.from_numpy(intrinsic)
82
+
83
+ def read_data(data_dir, i):
84
+ """
85
+ Return:
86
+ rgb: (H, W, 3) torch.float32
87
+ depth: (H, W, 1) torch.float32
88
+ mask: (H, W, 1) torch.float32
89
+ c2w: (4, 4) torch.float32
90
+ intrinsic: (3, 3) torch.float32
91
+ """
92
+ background_color = torch.tensor([0.0, 0.0, 0.0])
93
+
94
+ rgb_name = os.path.join(data_dir, f'render_%04d.webp' % i)
95
+ depth_name = os.path.join(data_dir, f'depth_%04d.exr' % i)
96
+
97
+ img = torch.from_numpy(
98
+ np.asarray(
99
+ Image.fromarray(imageio.v2.imread(rgb_name))
100
+ .convert("RGBA")
101
+ )
102
+ / 255.0
103
+ ).float()
104
+ mask = img[:, :, -1:]
105
+ rgb = img[:, :, :3] * mask + background_color[
106
+ None, None, :
107
+ ] * (1 - mask)
108
+
109
+ depth = torch.from_numpy(
110
+ cv2.imread(depth_name, cv2.IMREAD_UNCHANGED)[..., 0, None]
111
+ )
112
+ mask[depth > 100.0] = 0.0
113
+ depth[~(mask > 0.5)] = 0.0 # set invalid depth to 0
114
+
115
+ meta_path = os.path.join(data_dir, 'meta.json')
116
+ with open(meta_path, 'r') as f:
117
+ meta = json.load(f)
118
+
119
+ c2w = torch.as_tensor(
120
+ meta['locations'][i]["transform_matrix"],
121
+ dtype=torch.float32,
122
+ )
123
+
124
+ H, W = rgb.shape[:2]
125
+ fovy = meta["camera_angle_x"]
126
+ intrinsic = get_intrinsic_from_fov(fovy, H=H, W=W)
127
+
128
+ return rgb, depth, mask, c2w, intrinsic
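`get_intrinsic_from_fov` builds a square pinhole intrinsic with focal length `0.5 * H / tan(0.5 * fov)` for both axes and the principal point at the image center. A quick sanity check, assuming `blender/` is on the import path so the function can be imported from this module:

```python
import math
import torch

from utils import get_intrinsic_from_fov  # assumes blender/ is the working directory

# for a 90-degree fov, tan(fov / 2) == 1, so the focal length is exactly H / 2
K = get_intrinsic_from_fov(math.radians(90.0), H=768, W=768)
assert torch.allclose(K[0, 0], torch.tensor(384.0))            # fx == fy == 384
assert float(K[0, 2]) == 384.0 and float(K[1, 2]) == 384.0     # principal point at center
print(K)
```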
configs/inference-768-6view.yaml ADDED
@@ -0,0 +1,72 @@
1
+ pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-1-unclip'
2
+ revision: null
3
+
4
+ num_views: 7
5
+ with_smpl: false
6
+ validation_dataset:
7
+ prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
8
+ root_dir: 'examples/shhq'
9
+ num_views: ${num_views}
10
+ bg_color: 'white'
11
+ img_wh: [768, 768]
12
+ num_validation_samples: 1000
13
+ crop_size: 740
14
+ margin_size: 50
15
+ smpl_folder: 'smpl_image_pymaf'
16
+
17
+
18
+ save_dir: 'mv_results'
19
+ save_mode: 'rgba' # 'concat', 'rgba', 'rgb'
20
+ seed: 42
21
+ validation_batch_size: 1
22
+ dataloader_num_workers: 1
23
+ local_rank: -1
24
+
25
+ pipe_kwargs:
26
+ num_views: ${num_views}
27
+
28
+ validation_guidance_scales: 3.0
29
+ pipe_validation_kwargs:
30
+ num_inference_steps: 40
31
+ eta: 1.0
32
+
33
+ validation_grid_nrow: ${num_views}
34
+
35
+ unet_from_pretrained_kwargs:
36
+ unclip: true
37
+ sdxl: false
38
+ num_views: ${num_views}
39
+ sample_size: 96
40
+ zero_init_conv_in: false # modify
41
+
42
+ projection_camera_embeddings_input_dim: 2 # 2 for elevation and 6 for focal_length
43
+ zero_init_camera_projection: false
44
+ num_regress_blocks: 3
45
+
46
+ cd_attention_last: false
47
+ cd_attention_mid: false
48
+ multiview_attention: true
49
+ sparse_mv_attention: true
50
+ selfattn_block: self_rowwise
51
+ mvcd_attention: true
52
+
53
+ recon_opt:
54
+ res_path: out
55
+ save_glb: true
56
+ # camera setting
57
+ num_view: 6
58
+ scale: 4
59
+ mode: ortho
60
+ resolution: 1024
61
+ cam_path: 'mvdiffusion/data/six_human_pose'
62
+ # optimization
63
+ iters: 700
64
+ clr_iters: 200
65
+ debug: false
66
+ snapshot_step: 50
67
+ lr_clr: 2e-3
68
+ gpu_id: 0
69
+
70
+ replace_hand: false
71
+
72
+ enable_xformers_memory_efficient_attention: true
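The `${num_views}` references in this file are OmegaConf-style interpolations; how the inference script loads them is not shown here, so the snippet below is only a sketch under the assumption that OmegaConf is used:

```python
from omegaconf import OmegaConf

# load the config and resolve ${num_views}-style interpolations (assumed OmegaConf usage)
cfg = OmegaConf.load('configs/inference-768-6view.yaml')

print(cfg.num_views)                      # 7
print(cfg.validation_dataset.num_views)   # resolved from ${num_views}
print(OmegaConf.to_container(cfg.recon_opt, resolve=True))
```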
configs/remesh.yaml ADDED
@@ -0,0 +1,18 @@
1
+ res_path: out
2
+ save_glb: False
3
+ imgs_path: examples/debug
4
+ mv_path: ./
5
+ # camera setting
6
+ num_view: 6
7
+ scale: 4
8
+ mode: ortho
9
+ resolution: 1024
10
+ cam_path: 'mvdiffusion/data/six_human_pose'
11
+ # optimization
12
+ iters: 700
13
+ clr_iters: 200
14
+ debug: false
15
+ snapshot_step: 50
16
+ lr_clr: 2e-3
17
+ gpu_id: 0
18
+ replace_hand: false
configs/train-768-6view-onlyscan_face.yaml ADDED
@@ -0,0 +1,145 @@
1
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip
2
+ pretrained_unet_path: null
3
+ revision: null
4
+ with_smpl: false
5
+ data_common:
6
+ root_dir: /aifs4su/mmcode/lipeng/human_8view_new/
7
+ predict_relative_views: [0, 1, 2, 4, 6, 7]
8
+ num_validation_samples: 8
9
+ img_wh: [768, 768]
10
+ read_normal: true
11
+ read_color: true
12
+ read_depth: false
13
+ exten: .png
14
+ prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
15
+ object_list:
16
+ - data_lists/human_only_scan.json
17
+ invalid_list:
18
+ -
19
+ train_dataset:
20
+ root_dir: ${data_common.root_dir}
21
+ azi_interval: 45.0
22
+ random_views: 3
23
+ predict_relative_views: ${data_common.predict_relative_views}
24
+ bg_color: three_choices
25
+ object_list: ${data_common.object_list}
26
+ invalid_list: ${data_common.invalid_list}
27
+ img_wh: ${data_common.img_wh}
28
+ validation: false
29
+ num_validation_samples: ${data_common.num_validation_samples}
30
+ read_normal: ${data_common.read_normal}
31
+ read_color: ${data_common.read_color}
32
+ read_depth: ${data_common.read_depth}
33
+ load_cache: false
34
+ exten: ${data_common.exten}
35
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
36
+ side_views_rate: 0.3
37
+ elevation_list: null
38
+ validation_dataset:
39
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
40
+ root_dir: examples/debug
41
+ num_views: ${num_views}
42
+ bg_color: white
43
+ img_wh: ${data_common.img_wh}
44
+ num_validation_samples: 1000
45
+ crop_size: 740
46
+ validation_train_dataset:
47
+ root_dir: ${data_common.root_dir}
48
+ azi_interval: 45.0
49
+ random_views: 3
50
+ predict_relative_views: ${data_common.predict_relative_views}
51
+ bg_color: white
52
+ object_list: ${data_common.object_list}
53
+ invalid_list: ${data_common.invalid_list}
54
+ img_wh: ${data_common.img_wh}
55
+ validation: false
56
+ num_validation_samples: ${data_common.num_validation_samples}
57
+ read_normal: ${data_common.read_normal}
58
+ read_color: ${data_common.read_color}
59
+ read_depth: ${data_common.read_depth}
60
+ num_samples: ${data_common.num_validation_samples}
61
+ load_cache: false
62
+ exten: ${data_common.exten}
63
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
64
+ elevation_list: null
65
+ output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5
66
+ checkpoint_prefix: ../human_checkpoint_backup/
67
+ seed: 42
68
+ train_batch_size: 2
69
+ validation_batch_size: 1
70
+ validation_train_batch_size: 1
71
+ max_train_steps: 30000
72
+ gradient_accumulation_steps: 2
73
+ gradient_checkpointing: true
74
+ learning_rate: 0.0001
75
+ scale_lr: false
76
+ lr_scheduler: piecewise_constant
77
+ step_rules: 1:2000,0.5
78
+ lr_warmup_steps: 10
79
+ snr_gamma: 5.0
80
+ use_8bit_adam: false
81
+ allow_tf32: true
82
+ use_ema: true
83
+ dataloader_num_workers: 32
84
+ adam_beta1: 0.9
85
+ adam_beta2: 0.999
86
+ adam_weight_decay: 0.01
87
+ adam_epsilon: 1.0e-08
88
+ max_grad_norm: 1.0
89
+ prediction_type: null
90
+ logging_dir: logs
91
+ vis_dir: vis
92
+ mixed_precision: fp16
93
+ report_to: wandb
94
+ local_rank: 0
95
+ checkpointing_steps: 2500
96
+ checkpoints_total_limit: 2
97
+ resume_from_checkpoint: latest
98
+ enable_xformers_memory_efficient_attention: true
99
+ validation_steps: 2500 #
100
+ validation_sanity_check: true
101
+ tracker_project_name: PSHuman
102
+ trainable_modules: null
103
+
104
+
105
+ use_classifier_free_guidance: true
106
+ condition_drop_rate: 0.05
107
+ scale_input_latents: true
108
+ regress_elevation: false
109
+ regress_focal_length: false
110
+ elevation_loss_weight: 1.0
111
+ focal_loss_weight: 0.0
112
+ pipe_kwargs:
113
+ num_views: ${num_views}
114
+ pipe_validation_kwargs:
115
+ eta: 1.0
116
+
117
+ unet_from_pretrained_kwargs:
118
+ unclip: true
119
+ num_views: ${num_views}
120
+ sample_size: 96
121
+ zero_init_conv_in: true
122
+ regress_elevation: ${regress_elevation}
123
+ regress_focal_length: ${regress_focal_length}
124
+ num_regress_blocks: 2
125
+ camera_embedding_type: e_de_da_sincos
126
+ projection_camera_embeddings_input_dim: 2
127
+ zero_init_camera_projection: true # modified
128
+ init_mvattn_with_selfattn: false
129
+ cd_attention_last: false
130
+ cd_attention_mid: false
131
+ multiview_attention: true
132
+ sparse_mv_attention: true
133
+ selfattn_block: self_rowwise
134
+ mvcd_attention: true
135
+ addition_downsample: false
136
+ use_face_adapter: false
137
+
138
+ validation_guidance_scales:
139
+ - 3.0
140
+ validation_grid_nrow: ${num_views}
141
+ camera_embedding_lr_mult: 1.0
142
+ plot_pose_acc: false
143
+ num_views: 7
144
+ pred_type: joint
145
+ drop_type: drop_as_a_whole
configs/train-768-6view-onlyscan_face_smplx.yaml ADDED
@@ -0,0 +1,154 @@
1
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip
2
+ pretrained_unet_path: null
3
+ revision: null
4
+ with_smpl: true
5
+ data_common:
6
+ root_dir: /aifs4su/mmcode/lipeng/human_8view_with_smplx/
7
+ predict_relative_views: [0, 1, 2, 4, 6, 7]
8
+ num_validation_samples: 8
9
+ img_wh: [768, 768]
10
+ read_normal: true
11
+ read_color: true
12
+ read_depth: false
13
+ exten: .png
14
+ prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
15
+ object_list:
16
+ - data_lists/human_only_scan_with_smplx.json # modified
17
+ invalid_list:
18
+ -
19
+ with_smpl: ${with_smpl}
20
+
21
+ train_dataset:
22
+ root_dir: ${data_common.root_dir}
23
+ azi_interval: 45.0
24
+ random_views: 0
25
+ predict_relative_views: ${data_common.predict_relative_views}
26
+ bg_color: three_choices
27
+ object_list: ${data_common.object_list}
28
+ invalid_list: ${data_common.invalid_list}
29
+ img_wh: ${data_common.img_wh}
30
+ validation: false
31
+ num_validation_samples: ${data_common.num_validation_samples}
32
+ read_normal: ${data_common.read_normal}
33
+ read_color: ${data_common.read_color}
34
+ read_depth: ${data_common.read_depth}
35
+ load_cache: false
36
+ exten: ${data_common.exten}
37
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
38
+ side_views_rate: 0.3
39
+ elevation_list: null
40
+ with_smpl: ${with_smpl}
41
+
42
+ validation_dataset:
43
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
44
+ root_dir: examples/debug
45
+ num_views: ${num_views}
46
+ bg_color: white
47
+ img_wh: ${data_common.img_wh}
48
+ num_validation_samples: 1000
49
+ margin_size: 10
50
+ # crop_size: 720
51
+
52
+ validation_train_dataset:
53
+ root_dir: ${data_common.root_dir}
54
+ azi_interval: 45.0
55
+ random_views: 0
56
+ predict_relative_views: ${data_common.predict_relative_views}
57
+ bg_color: white
58
+ object_list: ${data_common.object_list}
59
+ invalid_list: ${data_common.invalid_list}
60
+ img_wh: ${data_common.img_wh}
61
+ validation: false
62
+ num_validation_samples: ${data_common.num_validation_samples}
63
+ read_normal: ${data_common.read_normal}
64
+ read_color: ${data_common.read_color}
65
+ read_depth: ${data_common.read_depth}
66
+ num_samples: ${data_common.num_validation_samples}
67
+ load_cache: false
68
+ exten: ${data_common.exten}
69
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
70
+ elevation_list: null
71
+ with_smpl: ${with_smpl}
72
+
73
+ output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5-smplx
74
+ checkpoint_prefix: ../human_checkpoint_backup/
75
+ seed: 42
76
+ train_batch_size: 2
77
+ validation_batch_size: 1
78
+ validation_train_batch_size: 1
79
+ max_train_steps: 30000
80
+ gradient_accumulation_steps: 2
81
+ gradient_checkpointing: true
82
+ learning_rate: 0.0001
83
+ scale_lr: false
84
+ lr_scheduler: piecewise_constant
85
+ step_rules: 1:2000,0.5
86
+ lr_warmup_steps: 10
87
+ snr_gamma: 5.0
88
+ use_8bit_adam: false
89
+ allow_tf32: true
90
+ use_ema: true
91
+ dataloader_num_workers: 32
92
+ adam_beta1: 0.9
93
+ adam_beta2: 0.999
94
+ adam_weight_decay: 0.01
95
+ adam_epsilon: 1.0e-08
96
+ max_grad_norm: 1.0
97
+ prediction_type: null
98
+ logging_dir: logs
99
+ vis_dir: vis
100
+ mixed_precision: fp16
101
+ report_to: wandb
102
+ local_rank: 0
103
+ checkpointing_steps: 5000
104
+ checkpoints_total_limit: 2
105
+ resume_from_checkpoint: latest
106
+ enable_xformers_memory_efficient_attention: true
107
+ validation_steps: 2500 #
108
+ validation_sanity_check: true
109
+ tracker_project_name: PSHuman
110
+ trainable_modules: null
111
+
112
+ use_classifier_free_guidance: true
113
+ condition_drop_rate: 0.05
114
+ scale_input_latents: true
115
+ regress_elevation: false
116
+ regress_focal_length: false
117
+ elevation_loss_weight: 1.0
118
+ focal_loss_weight: 0.0
119
+ pipe_kwargs:
120
+ num_views: ${num_views}
121
+ pipe_validation_kwargs:
122
+ eta: 1.0
123
+
124
+ unet_from_pretrained_kwargs:
125
+ unclip: true
126
+ num_views: ${num_views}
127
+ sample_size: 96
128
+ zero_init_conv_in: true
129
+ regress_elevation: ${regress_elevation}
130
+ regress_focal_length: ${regress_focal_length}
131
+ num_regress_blocks: 2
132
+ camera_embedding_type: e_de_da_sincos
133
+ projection_camera_embeddings_input_dim: 2
134
+ zero_init_camera_projection: true # modified
135
+ init_mvattn_with_selfattn: false
136
+ cd_attention_last: false
137
+ cd_attention_mid: false
138
+ multiview_attention: true
139
+ sparse_mv_attention: true
140
+ selfattn_block: self_rowwise
141
+ mvcd_attention: true
142
+ addition_downsample: false
143
+ use_face_adapter: false
144
+ in_channels: 12
145
+
146
+
147
+ validation_guidance_scales:
148
+ - 3.0
149
+ validation_grid_nrow: ${num_views}
150
+ camera_embedding_lr_mult: 1.0
151
+ plot_pose_acc: false
152
+ num_views: 7
153
+ pred_type: joint
154
+ drop_type: drop_as_a_whole
core/opt.py ADDED
@@ -0,0 +1,197 @@
1
+ from copy import deepcopy
2
+ import time
3
+ import torch
4
+ import torch_scatter
5
+ from core.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges
6
+
7
+ @torch.no_grad()
8
+ def remesh(
9
+ vertices_etc:torch.Tensor, #V,D
10
+ faces:torch.Tensor, #F,3 long
11
+ min_edgelen:torch.Tensor, #V
12
+ max_edgelen:torch.Tensor, #V
13
+ flip:bool,
14
+ max_vertices=1e6
15
+ ):
16
+
17
+ # dummies
18
+ vertices_etc,faces = prepend_dummies(vertices_etc,faces)
19
+ vertices = vertices_etc[:,:3] #V,3
20
+ nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
21
+ min_edgelen = torch.concat((nan_tensor,min_edgelen))
22
+ max_edgelen = torch.concat((nan_tensor,max_edgelen))
23
+
24
+ # collapse
25
+ edges,face_to_edge = calc_edges(faces) #E,2 F,3
26
+ edge_length = calc_edge_length(vertices,edges) #E
27
+ face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
28
+ vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
29
+ face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
30
+ shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
31
+ priority = face_collapse.float() + shortness
32
+ vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)
33
+
34
+ # split
35
+ if vertices.shape[0]<max_vertices:
36
+ edges,face_to_edge = calc_edges(faces) #E,2 F,3
37
+ vertices = vertices_etc[:,:3] #V,3
38
+ edge_length = calc_edge_length(vertices,edges) #E
39
+ splits = edge_length > max_edgelen[edges].mean(dim=-1)
40
+ vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)
41
+
42
+ vertices_etc,faces = pack(vertices_etc,faces)
43
+ vertices = vertices_etc[:,:3]
44
+
45
+ if flip:
46
+ edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
47
+ flip_edges(vertices,faces,edges,edge_to_face,with_border=False)
48
+
49
+ return remove_dummies(vertices_etc,faces)
50
+
51
+ def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
52
+ """lerp with adam's bias correction"""
53
+ c_prev = 1-weight**(step-1)
54
+ c = 1-weight**step
55
+ a_weight = weight*c_prev/c
56
+ b_weight = (1-weight)/c
57
+ a.mul_(a_weight).add_(b, alpha=b_weight)
58
+
59
+
60
+ class MeshOptimizer:
61
+ """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""
62
+
63
+ def __init__(self,
64
+ vertices:torch.Tensor, #V,3
65
+ faces:torch.Tensor, #F,3
66
+ lr=0.3, #learning rate
67
+ betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
68
+ gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
69
+ nu_ref=0.3, #reference velocity for edge length controller
70
+ edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
71
+ edge_len_tol=.5, #edge length tolerance for split and collapse
72
+ gain=.2, #gain value for edge length controller
73
+ laplacian_weight=.02, #for laplacian smoothing/regularization
74
+ ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
75
+ grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
76
+ remesh_interval=1, #larger intervals are faster but with worse mesh quality
77
+ local_edgelen=True, #set to False to use a global scalar reference edge length instead
78
+ remesh_milestones= [500], #list of steps at which to remesh
79
+ # total_steps=1000, #total number of steps
80
+ ):
81
+ self._vertices = vertices
82
+ self._faces = faces
83
+ self._lr = lr
84
+ self._betas = betas
85
+ self._gammas = gammas
86
+ self._nu_ref = nu_ref
87
+ self._edge_len_lims = edge_len_lims
88
+ self._edge_len_tol = edge_len_tol
89
+ self._gain = gain
90
+ self._laplacian_weight = laplacian_weight
91
+ self._ramp = ramp
92
+ self._grad_lim = grad_lim
93
+ # self._remesh_interval = remesh_interval
94
+ # self._remseh_milestones = [ for remesh_milestones]
95
+ self._local_edgelen = local_edgelen
96
+ self._step = 0
97
+ self._start = time.time()
98
+
99
+ V = self._vertices.shape[0]
100
+ # prepare continuous tensor for all vertex-based data
101
+ self._vertices_etc = torch.zeros([V,9],device=vertices.device)
102
+ self._split_vertices_etc()
103
+ self.vertices.copy_(vertices) #initialize vertices
104
+ self._vertices.requires_grad_()
105
+ self._ref_len.fill_(edge_len_lims[1])
106
+
107
+ @property
108
+ def vertices(self):
109
+ return self._vertices
110
+
111
+ @property
112
+ def faces(self):
113
+ return self._faces
114
+
115
+ def _split_vertices_etc(self):
116
+ self._vertices = self._vertices_etc[:,:3]
117
+ self._m2 = self._vertices_etc[:,3]
118
+ self._nu = self._vertices_etc[:,4]
119
+ self._m1 = self._vertices_etc[:,5:8]
120
+ self._ref_len = self._vertices_etc[:,8]
121
+
122
+ with_gammas = any(g!=0 for g in self._gammas)
123
+ self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]
124
+
125
+ def zero_grad(self):
126
+ self._vertices.grad = None
127
+
128
+ @torch.no_grad()
129
+ def step(self):
130
+
131
+ eps = 1e-8
132
+
133
+ self._step += 1
134
+ # spatial smoothing
135
+ edges,_ = calc_edges(self._faces) #E,2
136
+ E = edges.shape[0]
137
+ edge_smooth = self._smooth[edges] #E,2,S
138
+ neighbor_smooth = torch.zeros_like(self._smooth) #V,S
139
+ torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)
140
+ #apply optional smoothing of m1,m2,nu
141
+ if self._gammas[0]:
142
+ self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
143
+ if self._gammas[1]:
144
+ self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
145
+ if self._gammas[2]:
146
+ self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])
147
+
148
+ #add laplace smoothing to gradients
149
+ laplace = self._vertices - neighbor_smooth[:,:3]
150
+ grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)
151
+
152
+ #gradient clipping
153
+ if self._step>1:
154
+ grad_lim = self._m1.abs().mul_(self._grad_lim)
155
+ grad.clamp_(min=-grad_lim,max=grad_lim)
156
+
157
+ # moment updates
158
+ lerp_unbiased(self._m1, grad, self._betas[0], self._step)
159
+ lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)
160
+
161
+ velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
162
+ speed = velocity.norm(dim=-1) #V
163
+
164
+ if self._betas[2]:
165
+ lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
166
+ else:
167
+ self._nu.copy_(speed) #V
168
+ # update vertices
169
+ ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
170
+ self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)
171
+
172
+ # update target edge length
173
+ if self._step < 500:
174
+ self._remesh_interval = 4
175
+ elif self._step < 800:
176
+ self._remesh_interval = 2
177
+ else:
178
+ self._remesh_interval = 1
179
+
180
+ if self._step % self._remesh_interval == 0:
181
+ if self._local_edgelen:
182
+ len_change = (1 + (self._nu - self._nu_ref) * self._gain)
183
+ else:
184
+ len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
185
+ self._ref_len *= len_change
186
+ self._ref_len.clamp_(*self._edge_len_lims)
187
+
188
+ def remesh(self, flip:bool=True)->tuple[torch.Tensor,torch.Tensor]:
189
+ min_edge_len = self._ref_len * (1 - self._edge_len_tol)
190
+ max_edge_len = self._ref_len * (1 + self._edge_len_tol)
191
+
192
+ self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip)
193
+
194
+ self._split_vertices_etc()
195
+ self._vertices.requires_grad_()
196
+
197
+ return self._vertices, self._faces
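Per the class docstring, `MeshOptimizer` is driven like a torch optimizer, with a `remesh()` call after each `step()`. A minimal sketch of that loop, assuming an initial `(vertices, faces)` mesh on the GPU and a user-supplied differentiable loss (`compute_loss` is a hypothetical callable, not part of this module):

```python
import torch
from core.opt import MeshOptimizer

def optimize_mesh(vertices, faces, compute_loss, steps=700):
    """vertices: V,3 float cuda tensor; faces: F,3 long cuda tensor;
    compute_loss: callable mapping (vertices, faces) -> scalar loss tensor."""
    opt = MeshOptimizer(vertices, faces)
    for _ in range(steps):
        opt.zero_grad()
        loss = compute_loss(opt.vertices, opt.faces)
        loss.backward()
        opt.step()                       # adaptive per-vertex update
        vertices, faces = opt.remesh()   # split / collapse / flip edges as needed
    return vertices.detach(), faces
```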
core/remesh.py ADDED
@@ -0,0 +1,359 @@
1
+ import torch
2
+ import torch.nn.functional as tfunc
3
+ import torch_scatter
4
+
5
+ def prepend_dummies(
6
+ vertices:torch.Tensor, #V,D
7
+ faces:torch.Tensor, #F,3 long
8
+ )->tuple[torch.Tensor,torch.Tensor]:
9
+ """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
10
+ V,D = vertices.shape
11
+ vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
12
+ faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
13
+ return vertices,faces
14
+
15
+ def remove_dummies(
16
+ vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
17
+ faces:torch.Tensor, #F,3 long - first face all zeros
18
+ )->tuple[torch.Tensor,torch.Tensor]:
19
+ """remove dummy elements added with prepend_dummies()"""
20
+ return vertices[1:],faces[1:]-1
21
+
22
+
23
+ def calc_edges(
24
+ faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros
25
+ with_edge_to_face: bool = False
26
+ ) -> tuple[torch.Tensor, ...]:
27
+ """
28
+ returns tuple of
29
+ - edges E,2 long, 0 for unused, lower vertex index first
30
+ - face_to_edge F,3 long
31
+ - (optional) edge_to_face shape=E,[left,right],[face,side]
32
+
33
+ o-<-----e1 e0,e1...edge, e0<e1
34
+ | /A L,R....left and right face
35
+ | L / | both triangles ordered counter clockwise
36
+ | / R | normals pointing out of screen
37
+ V/ |
38
+ e0---->-o
39
+ """
40
+
41
+ F = faces.shape[0]
42
+
43
+ # make full edges, lower vertex index first
44
+ face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F,3,2
45
+ full_edges = face_edges.reshape(F*3,2)
46
+ sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 TODO min/max faster?
47
+
48
+ # make unique edges
49
+ edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
50
+ E = edges.shape[0]
51
+ face_to_edge = full_to_unique.reshape(F,3) #F,3
52
+
53
+ if not with_edge_to_face:
54
+ return edges, face_to_edge
55
+
56
+ is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
57
+ edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
58
+ scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
59
+ edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
60
+ edge_to_face[0] = 0
61
+ return edges, face_to_edge, edge_to_face
62
+
63
+ def calc_edge_length(
64
+ vertices:torch.Tensor, #V,3 first may be dummy
65
+ edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
66
+ )->torch.Tensor: #E
67
+
68
+ full_vertices = vertices[edges] #E,2,3
69
+ a,b = full_vertices.unbind(dim=1) #E,3
70
+ return torch.norm(a-b,p=2,dim=-1)
71
+
72
+ def calc_face_normals(
73
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
74
+ faces:torch.Tensor, #F,3 long, first face may be all zero
75
+ normalize:bool=False,
76
+ )->torch.Tensor: #F,3
77
+ """
78
+ n
79
+ |
80
+ c0 corners ordered counterclockwise when
81
+ / \ looking onto surface (in neg normal direction)
82
+ c1---c2
83
+ """
84
+ full_vertices = vertices[faces] #F,C=3,3
85
+ v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
86
+ face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
87
+ if normalize:
88
+ face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) #TODO inplace?
89
+ return face_normals #F,3
90
+
91
+ def calc_vertex_normals(
92
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
93
+ faces:torch.Tensor, #F,3 long, first face may be all zero
94
+ face_normals:torch.Tensor=None, #F,3, not normalized
95
+ )->torch.Tensor: #V,3
96
+
97
+ F = faces.shape[0]
98
+
99
+ if face_normals is None:
100
+ face_normals = calc_face_normals(vertices,faces)
101
+
102
+ vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
103
+ vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
104
+ vertex_normals = vertex_normals.sum(dim=1) #V,3
105
+ return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)
106
+
107
+ def calc_face_ref_normals(
108
+ faces:torch.Tensor, #F,3 long, 0 for unused
109
+ vertex_normals:torch.Tensor, #V,3 first unused
110
+ normalize:bool=False,
111
+ )->torch.Tensor: #F,3
112
+ """calculate reference normals for face flip detection"""
113
+ full_normals = vertex_normals[faces] #F,C=3,3
114
+ ref_normals = full_normals.sum(dim=1) #F,3
115
+ if normalize:
116
+ ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
117
+ return ref_normals
118
+
119
+ def pack(
120
+ vertices:torch.Tensor, #V,3 first unused and nan
121
+ faces:torch.Tensor, #F,3 long, 0 for unused
122
+ )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
123
+ """removes unused elements in vertices and faces"""
124
+ V = vertices.shape[0]
125
+
126
+ # remove unused faces
127
+ used_faces = faces[:,0]!=0
128
+ used_faces[0] = True
129
+ faces = faces[used_faces] #sync
130
+
131
+ # remove unused vertices
132
+ used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
133
+ used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') #TODO int faster?
134
+ used_vertices = used_vertices.any(dim=1)
135
+ used_vertices[0] = True
136
+ vertices = vertices[used_vertices] #sync
137
+
138
+ # update used faces
139
+ ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
140
+ V1 = used_vertices.sum()
141
+ ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
142
+ faces = ind[faces]
143
+
144
+ return vertices,faces
145
+
146
+ def split_edges(
147
+ vertices:torch.Tensor, #V,3 first unused
148
+ faces:torch.Tensor, #F,3 long, 0 for unused
149
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
150
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
151
+ splits, #E bool
152
+ pack_faces:bool=True,
153
+ )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
154
+
155
+ # c2 c2 c...corners = faces
156
+ # . . . . s...side_vert, 0 means no split
157
+ # . . .N2 . S...shrunk_face
158
+ # . . . . Ni...new_faces
159
+ # s2 s1 s2|c2...s1|c1
160
+ # . . . . .
161
+ # . . . S . .
162
+ # . . . . N1 .
163
+ # c0...(s0=0)....c1 s0|c0...........c1
164
+ #
165
+ # pseudo-code:
166
+ # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
167
+ # split = side_vert!=0 example:[False,True,True]
168
+ # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
169
+ # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
170
+ # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]
171
+
172
+ V = vertices.shape[0]
173
+ F = faces.shape[0]
174
+ S = splits.sum().item() #sync
175
+
176
+ if S==0:
177
+ return vertices,faces
178
+
179
+ edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
180
+ edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
181
+ side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
182
+ split_edges = edges[splits] #S sync
183
+
184
+ #vertices
185
+ split_vertices = vertices[split_edges].mean(dim=1) #S,3
186
+ vertices = torch.concat((vertices,split_vertices),dim=0)
187
+
188
+ #faces
189
+ side_split = side_vert!=0 #F,3
190
+ shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
191
+ new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
192
+ faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
193
+ if pack_faces:
194
+ mask = faces[:,0]!=0
195
+ mask[0] = True
196
+ faces = faces[mask] #F',3 sync
197
+
198
+ return vertices,faces
199
+
200
+ def collapse_edges(
201
+ vertices:torch.Tensor, #V,3 first unused
202
+ faces:torch.Tensor, #F,3 long 0 for unused
203
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
204
+ priorities:torch.Tensor, #E float
205
+ stable:bool=False, #only for unit testing
206
+ )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
207
+
208
+ V = vertices.shape[0]
209
+
210
+ # check spacing
211
+ _,order = priorities.sort(stable=stable) #E
212
+ rank = torch.zeros_like(order)
213
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
214
+ vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
215
+ edge_rank = rank #E
216
+ for i in range(3):
217
+ torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
218
+ edge_rank,_ = vert_rank[edges].max(dim=-1) #E
219
+ candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2
220
+
221
+ # check connectivity
222
+ vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
223
+ vert_connections[candidates[:,0]] = 1 #start
224
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
225
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
226
+ vert_connections[candidates] = 0 #clear start and end
227
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
228
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
229
+ collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end
230
+
231
+ # mean vertices
232
+ vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) #TODO dim?
233
+
234
+ # update faces
235
+ dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
236
+ dest[collapses[:,1]] = dest[collapses[:,0]]
237
+ faces = dest[faces] #F,3 TODO optimize?
238
+ c0,c1,c2 = faces.unbind(dim=-1)
239
+ collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
240
+ faces[collapsed] = 0
241
+
242
+ return vertices,faces
243
+
244
+ def calc_face_collapses(
245
+ vertices:torch.Tensor, #V,3 first unused
246
+ faces:torch.Tensor, #F,3 long, 0 for unused
247
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
248
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
249
+ edge_length:torch.Tensor, #E
250
+ face_normals:torch.Tensor, #F,3
251
+ vertex_normals:torch.Tensor, #V,3 first unused
252
+ min_edge_length:torch.Tensor=None, #V
253
+ area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
254
+ shortest_probability = 0.8
255
+ )->torch.Tensor: #E edges to collapse
256
+
257
+ E = edges.shape[0]
258
+ F = faces.shape[0]
259
+
260
+ # face flips
261
+ ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
262
+ face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F
263
+
264
+ # small faces
265
+ if min_edge_length is not None:
266
+ min_face_length = min_edge_length[faces].mean(dim=-1) #F
267
+ min_area = min_face_length**2 * area_ratio #F
268
+ face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
269
+ face_collapses[0] = False
270
+
271
+ # faces to edges
272
+ face_length = edge_length[face_to_edge] #F,3
273
+
274
+ if shortest_probability<1:
275
+ #select shortest edge with shortest_probability chance
276
+ randlim = round(2/(1-shortest_probability))
277
+ rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
278
+ sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
279
+ local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
280
+ else:
281
+ local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face
282
+
283
+ edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
284
+ edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
285
+ edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) #TODO legal for bool?
286
+
287
+ return edge_collapses.bool()
288
+
289
+ def flip_edges(
290
+ vertices:torch.Tensor, #V,3 first unused
291
+ faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
292
+ edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
293
+ edge_to_face:torch.Tensor, #E,[left,right],[face,side]
294
+ with_border:bool=True, #handle border edges (D=4 instead of D=6)
295
+ with_normal_check:bool=True, #check face normal flips
296
+ stable:bool=False, #only for unit testing
297
+ ):
298
+ V = vertices.shape[0]
299
+ E = edges.shape[0]
300
+ device=vertices.device
301
+ vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
302
+ vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
303
+ neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
304
+ neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
305
+ edge_is_inside = neighbors.all(dim=-1) #E
306
+
307
+ if with_border:
308
+ # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
309
+ # need to use float for masks in order to use scatter(reduce='multiply')
310
+ vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
311
+ src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
312
+ vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
313
+ vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
314
+ vertex_degree -= 2 * vertex_is_inside #V long
315
+
316
+ neighbor_degrees = vertex_degree[neighbors] #E,LR=2
317
+ edge_degrees = vertex_degree[edges] #E,2
318
+ #
319
+ # loss = Sum_over_affected_vertices((new_degree-6)**2)
320
+ # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
321
+ # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
322
+ # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
323
+ #
324
+ loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
325
+ candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
326
+ loss_change = loss_change[candidates] #E'
327
+ if loss_change.shape[0]==0:
328
+ return
329
+
330
+ edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
331
+ _,order = loss_change.sort(descending=True, stable=stable) #E'
332
+ rank = torch.zeros_like(order)
333
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
334
+ vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
335
+ torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
336
+ vertex_rank,_ = vertex_rank.max(dim=-1) #V
337
+ neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
338
+ flip = rank==neighborhood_rank #E'
339
+
340
+ if with_normal_check:
341
+ # cl-<-----e1 e0,e1...edge, e0<e1
342
+ # | /A L,R....left and right face
343
+ # | L / | both triangles ordered counter clockwise
344
+ # | / R | normals pointing out of screen
345
+ # V/ |
346
+ # e0---->-cr
347
+ v = vertices[edges_neighbors] #E",4,3
348
+ v = v - v[:,0:1] #make relative to e0
349
+ e1 = v[:,1]
350
+ cl = v[:,2]
351
+ cr = v[:,3]
352
+ n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
353
+ flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
354
+ flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face
355
+
356
+ flip_edges_neighbors = edges_neighbors[flip] #E",4
357
+ flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
358
+ flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
359
+ faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
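These helpers work on plain `(V,3)` vertex and `(F,3)` long face tensors, with the dummy element from `prepend_dummies` expected by the edge bookkeeping. A small smoke test on a single triangle, assuming `core/` is importable and `torch_scatter` is installed:

```python
import torch
from core.remesh import (prepend_dummies, remove_dummies, calc_edges,
                         calc_edge_length, calc_vertex_normals)

# one unit right triangle in the z = 0 plane
vertices = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = torch.tensor([[0, 1, 2]], dtype=torch.long)

# vertex normals work on the raw mesh and should all point along +z here
normals = calc_vertex_normals(vertices, faces)
assert torch.allclose(normals, torch.tensor([[0., 0., 1.]]).expand(3, 3))

# edge bookkeeping expects the dummy vertex/face prepended first
v_d, f_d = prepend_dummies(vertices, faces)
edges, face_to_edge = calc_edges(f_d)     # includes the dummy (0, 0) edge
lengths = calc_edge_length(v_d, edges)    # the dummy edge has nan length
print(edges.shape, lengths)
vertices, faces = remove_dummies(v_d, f_d)
```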
econdataset.py ADDED
@@ -0,0 +1,370 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ from lib.hybrik.models.simple3dpose import HybrIKBaseSMPLCam
18
+ from lib.pixielib.utils.config import cfg as pixie_cfg
19
+ from lib.pixielib.pixie import PIXIE
20
+ import lib.smplx as smplx
21
+ # from lib.pare.pare.core.tester import PARETester
22
+ from lib.pymaf.utils.geometry import rot6d_to_rotmat, batch_rodrigues, rotation_matrix_to_angle_axis
23
+ from lib.pymaf.utils.imutils import process_image
24
+ from lib.common.imutils import econ_process_image
25
+ from lib.pymaf.core import path_config
26
+ from lib.pymaf.models import pymaf_net
27
+ from lib.common.config import cfg
28
+ from lib.common.render import Render
29
+ from lib.dataset.body_model import TetraSMPLModel
30
+ from lib.dataset.mesh_util import get_visibility
31
+ from utils.smpl_util import SMPLX
32
+ import os.path as osp
33
+ import os
34
+ import torch
35
+ import numpy as np
36
+ import random
37
+ from termcolor import colored
38
+ from PIL import ImageFile
39
+ from torchvision.models import detection
40
+
41
+
42
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
43
+
44
+
45
+ class SMPLDataset():
46
+
47
+ def __init__(self, cfg, device):
48
+
49
+ random.seed(1993)
50
+
51
+ self.image_dir = cfg['image_dir']
52
+ self.seg_dir = cfg['seg_dir']
53
+ self.hps_type = cfg['hps_type']
54
+ self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
55
+ self.smpl_gender = 'neutral'
56
+ self.colab = cfg['colab']
57
+
58
+ self.device = device
59
+
60
+ keep_lst = [f"{self.image_dir}/{i}" for i in sorted(os.listdir(self.image_dir))]
61
+ img_fmts = ['jpg', 'png', 'jpeg', "JPG", 'bmp']
62
+ keep_lst = [item for item in keep_lst if item.split(".")[-1] in img_fmts]
63
+
64
+ self.subject_list = [item for item in keep_lst if item.split(".")[-1] in img_fmts]
65
+
66
+ if self.colab:
67
+ self.subject_list = [self.subject_list[0]]
68
+
69
+ # smpl related
70
+ self.smpl_data = SMPLX()
71
+
72
+ # smpl-smplx correspondence
73
+ self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
74
+ self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [68 + 61, 72 + 68]
75
+ self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(model_path=self.smpl_data.
76
+ model_dir,
77
+ gender=smpl_gender,
78
+ model_type=smpl_type,
79
+ ext='npz')
80
+
81
+ # Load SMPL model
82
+ self.smpl_model = self.get_smpl_model(self.smpl_type, self.smpl_gender).to(self.device)
83
+ self.faces = self.smpl_model.faces
84
+
85
+ if self.hps_type == 'pymaf':
86
+ self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)
87
+ self.hps.load_state_dict(torch.load(path_config.CHECKPOINT_FILE)['model'], strict=True)
88
+ self.hps.eval()
89
+
90
+ elif self.hps_type == 'pare':
91
+ self.hps = PARETester(path_config.CFG, path_config.CKPT).model
92
+ elif self.hps_type == 'pixie':
93
+ self.hps = PIXIE(config=pixie_cfg, device=self.device)
94
+ self.smpl_model = self.hps.smplx
95
+ elif self.hps_type == 'hybrik':
96
+ smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
97
+ self.hps = HybrIKBaseSMPLCam(cfg_file=path_config.HYBRIK_CFG,
98
+ smpl_path=smpl_path,
99
+ data_path=path_config.hybrik_data_dir)
100
+ self.hps.load_state_dict(torch.load(path_config.HYBRIK_CKPT, map_location='cpu'),
101
+ strict=False)
102
+ self.hps.to(self.device)
103
+ elif self.hps_type == 'bev':
104
+ try:
105
+ import bev
106
+ except:
107
+ print('Could not find bev, installing via pip install --upgrade simple-romp')
108
+ os.system('pip install simple-romp==1.0.3')
109
+ import bev
110
+ settings = bev.main.default_settings
111
+ # change the argparse settings of bev here if you prefer other settings.
112
+ settings.mode = 'image'
113
+ settings.GPU = int(str(self.device).split(':')[1])
114
+ settings.show_largest = True
115
+ # settings.show = True # uncommit this to show the original BEV predictions
116
+ self.hps = bev.BEV(settings)
117
+
118
+ self.detector=detection.maskrcnn_resnet50_fpn(pretrained=True)
119
+ self.detector.eval()
120
+ print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))
121
+
122
+ self.render = Render(size=512, device=device)
123
+
124
+ def __len__(self):
125
+ return len(self.subject_list)
126
+
127
+ def compute_vis_cmap(self, smpl_verts, smpl_faces):
128
+
129
+ (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
130
+ smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
131
+ smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type)
132
+
133
+ return {
134
+ 'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
135
+ 'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
136
+ 'smpl_verts': smpl_verts.unsqueeze(0)
137
+ }
138
+
139
+ def compute_voxel_verts(self, body_pose, global_orient, betas, trans, scale):
140
+
141
+ smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
142
+ tetra_path = osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz')
143
+ smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')
144
+
145
+ pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
146
+ smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])
147
+
148
+ verts = np.concatenate([smpl_model.verts, smpl_model.verts_added],
149
+ axis=0) * scale.item() + trans.detach().cpu().numpy()
150
+ faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir, 'tetrahedrons_neutral_adult.txt'),
151
+ dtype=np.int32) - 1
152
+
153
+ pad_v_num = int(8000 - verts.shape[0])
154
+ pad_f_num = int(25100 - faces.shape[0])
155
+
156
+ verts = np.pad(verts,
157
+ ((0, pad_v_num),
158
+ (0, 0)), mode='constant', constant_values=0.0).astype(np.float32) * 0.5
159
+ faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode='constant',
160
+ constant_values=0.0).astype(np.int32)
161
+
162
+ verts[:, 2] *= -1.0
163
+
164
+ voxel_dict = {
165
+ 'voxel_verts': torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
166
+ 'voxel_faces': torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
167
+ 'pad_v_num': torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
168
+ 'pad_f_num': torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
169
+ }
170
+
171
+ return voxel_dict
172
+
173
+ def __getitem__(self, index):
174
+
175
+ img_path = self.subject_list[index]
176
+ img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
177
+ print(img_name)
178
+ # smplx_param_path=f'./data/thuman2/smplx/{img_name[:-2]}.pkl'
179
+ # smplx_param = np.load(smplx_param_path, allow_pickle=True)
180
+
181
+ if self.seg_dir is None:
182
+ img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
183
+ img_path, self.hps_type, 512, self.device)
184
+
185
+ data_dict = {
186
+ 'name': img_name,
187
+ 'image': img_icon.to(self.device).unsqueeze(0),
188
+ 'ori_image': img_ori,
189
+ 'mask': img_mask,
190
+ 'uncrop_param': uncrop_param
191
+ }
192
+
193
+ else:
194
+ img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
195
+ img_path,
196
+ self.hps_type,
197
+ 512,
198
+ self.device,
199
+ seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))
200
+ data_dict = {
201
+ 'name': img_name,
202
+ 'image': img_icon.to(self.device).unsqueeze(0),
203
+ 'ori_image': img_ori,
204
+ 'mask': img_mask,
205
+ 'uncrop_param': uncrop_param,
206
+ 'segmentations': segmentations
207
+ }
208
+
209
+ arr_dict = econ_process_image(img_path, self.hps_type, True, 512, self.detector)
210
+ data_dict['hands_visibility'] = arr_dict['hands_visibility']
211
+
212
+ with torch.no_grad():
213
+ # import ipdb; ipdb.set_trace()
214
+ preds_dict = self.hps.forward(img_hps)
215
+
216
+ data_dict['smpl_faces'] = torch.Tensor(self.faces.astype(np.int64)).long().unsqueeze(0).to(
217
+ self.device)
218
+
219
+ if self.hps_type == 'pymaf':
220
+ output = preds_dict['smpl_out'][-1]
221
+ scale, tranX, tranY = output['theta'][0, :3]
222
+ data_dict['betas'] = output['pred_shape']
223
+ data_dict['body_pose'] = output['rotmat'][:, 1:]
224
+ data_dict['global_orient'] = output['rotmat'][:, 0:1]
225
+ data_dict['smpl_verts'] = output['verts']  # not sure whether the scale matches
226
+ data_dict["type"] = "smpl"
227
+
228
+ elif self.hps_type == 'pare':
229
+ data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
230
+ data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
231
+ data_dict['betas'] = preds_dict['pred_shape']
232
+ data_dict['smpl_verts'] = preds_dict['smpl_vertices']
233
+ scale, tranX, tranY = preds_dict['pred_cam'][0, :3]
234
+ data_dict["type"] = "smpl"
235
+
236
+ elif self.hps_type == 'pixie':
237
+ data_dict.update(preds_dict)
238
+ data_dict['body_pose'] = preds_dict['body_pose']
239
+ data_dict['global_orient'] = preds_dict['global_pose']
240
+ data_dict['betas'] = preds_dict['shape']
241
+ data_dict['smpl_verts'] = preds_dict['vertices']
242
+ scale, tranX, tranY = preds_dict['cam'][0, :3]
243
+ data_dict["type"] = "smplx"
244
+
245
+ elif self.hps_type == 'hybrik':
246
+ data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
247
+ data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
248
+ data_dict['betas'] = preds_dict['pred_shape']
249
+ data_dict['smpl_verts'] = preds_dict['pred_vertices']
250
+ scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
251
+ scale = scale * 2
252
+ data_dict["type"] = "smpl"
253
+
254
+ elif self.hps_type == 'bev':
255
+ data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[[0], :10].to(
256
+ self.device).float()
257
+ pred_thetas = batch_rodrigues(
258
+ torch.from_numpy(preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
259
+ data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
260
+ data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
261
+ data_dict['smpl_verts'] = torch.from_numpy(preds_dict['verts'][[0]]).to(
262
+ self.device).float()
263
+ tranX = preds_dict['cam_trans'][0, 0]
264
+ tranY = preds_dict['cam'][0, 1] + 0.28
265
+ scale = preds_dict['cam'][0, 0] * 1.1
266
+ data_dict["type"] = "smpl"
267
+
268
+ data_dict['scale'] = scale
269
+ data_dict['trans'] = torch.tensor([tranX, tranY, 0.0]).unsqueeze(0).to(self.device).float()
270
+
271
+ # data_dict info (key-shape):
272
+ # scale, tranX, tranY - tensor.float
273
+ # betas - [1,10] / [1, 200]
274
+ # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
275
+ # global_orient - [1, 1, 3, 3]
276
+ # smpl_verts - [1, 6890, 3] / [1, 10475, 3]
277
+
278
+ # from rot_mat to rot_6d for better optimization
279
+ N_body = data_dict["body_pose"].shape[1]
280
+ data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body, -1)
281
+ data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1, -1)
282
+
283
+ return data_dict
284
+
285
+ def render_normal(self, verts, faces):
286
+
287
+ # render optimized mesh (normal, T_normal, image [-1,1])
288
+ self.render.load_meshes(verts, faces)
289
+ return self.render.get_rgb_image()
290
+
291
+ def render_depth(self, verts, faces):
292
+
293
+ # render optimized mesh and return its depth maps (front/back cameras)
294
+ self.render.load_meshes(verts, faces)
295
+ return self.render.get_depth_map(cam_ids=[0, 2])
296
+
297
+ def visualize_alignment(self, data):
298
+
299
+ import vedo
300
+ import trimesh
301
+
302
+ if self.hps_type != 'pixie':
303
+ smpl_out = self.smpl_model(betas=data['betas'],
304
+ body_pose=data['body_pose'],
305
+ global_orient=data['global_orient'],
306
+ pose2rot=False)
307
+ smpl_verts = ((smpl_out.vertices + data['trans']) *
308
+ data['scale']).detach().cpu().numpy()[0]
309
+ else:
310
+ smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'],
311
+ expression_params=data['exp'],
312
+ body_pose=data['body_pose'],
313
+ global_pose=data['global_orient'],
314
+ jaw_pose=data['jaw_pose'],
315
+ left_hand_pose=data['left_hand_pose'],
316
+ right_hand_pose=data['right_hand_pose'])
317
+
318
+ smpl_verts = ((smpl_verts + data['trans']) * data['scale']).detach().cpu().numpy()[0]
319
+
320
+ smpl_verts *= np.array([1.0, -1.0, -1.0])
321
+ faces = data['smpl_faces'][0].detach().cpu().numpy()
322
+
323
+ image_P = data['image']
324
+ image_F, image_B = self.render_normal(smpl_verts, faces)
325
+
326
+ # create plot
327
+ vp = vedo.Plotter(title="", size=(1500, 1500))
328
+ vis_list = []
329
+
330
+ image_F = (0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
331
+ image_B = (0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
332
+ image_P = (0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
333
+
334
+ vis_list.append(
335
+ vedo.Picture(image_P * 0.5 + image_F * 0.5).scale(2.0 / image_P.shape[0]).pos(
336
+ -1.0, -1.0, 1.0))
337
+ vis_list.append(vedo.Picture(image_F).scale(2.0 / image_F.shape[0]).pos(-1.0, -1.0, -0.5))
338
+ vis_list.append(vedo.Picture(image_B).scale(2.0 / image_B.shape[0]).pos(-1.0, -1.0, -1.0))
339
+
340
+ # create a mesh
341
+ mesh = trimesh.Trimesh(smpl_verts, faces, process=False)
342
+ mesh.visual.vertex_colors = [200, 200, 0]
343
+ vis_list.append(mesh)
344
+
345
+ vp.show(*vis_list, bg="white", axes=1, interactive=True)
346
+
347
+
348
+ if __name__ == '__main__':
349
+
350
+ cfg.merge_from_file("./configs/icon-filter.yaml")
351
+ cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml')
352
+
353
+ cfg_show_list = ['test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False]
354
+
355
+ cfg.merge_from_list(cfg_show_list)
356
+ cfg.freeze()
357
+
358
+
359
+ device = torch.device('cuda:0')
360
+
361
+ dataset = SMPLDataset(
362
+ {
363
+ 'image_dir': "./examples",
+ 'seg_dir': None,
+ 'colab': False,
364
+ 'has_det': True, # w/ or w/o detection
365
+ 'hps_type': 'bev' # pymaf/pare/pixie/hybrik/bev
366
+ },
367
+ device)
368
+
369
+ for i in range(len(dataset)):
370
+ dataset.visualize_alignment(dataset[i])
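
For reference, a minimal usage sketch of the `SMPLDataset` defined above, outside the `__main__` demo. It assumes the SMPL/SMPL-X assets and the chosen HPS checkpoint are already installed as the repository expects; the config keys are the ones read in `__init__`, and the paths are placeholders.

```python
import torch
from econdataset import SMPLDataset

# Minimal sketch; paths and the HPS choice are illustrative.
cfg = {
    'image_dir': './examples',   # folder with input RGB images
    'seg_dir': None,             # optional cloth-segmentation JSONs
    'colab': False,
    'has_det': True,
    'hps_type': 'pixie',         # pymaf / pare / pixie / hybrik / bev
}

dataset = SMPLDataset(cfg, device=torch.device('cuda:0'))

for i in range(len(dataset)):
    data = dataset[i]
    # body_pose / global_orient come back flattened to rot-6d for optimization
    print(data['name'], data['type'], data['betas'].shape, data['body_pose'].shape)
```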
examples/02986d0998ce01aa0aa67a99fbd1e09a.png ADDED

Git LFS Details

  • SHA256: cc56331155a0a728073fbddd0df4ecc78f7d096af6032a8838df3c8e9ace8d14
  • Pointer size: 131 Bytes
  • Size of remote file: 291 kB
examples/16171.png ADDED

Git LFS Details

  • SHA256: 7cb871d9357ce4a5f24670919bfb4b67d4be977d43dc1d976a8c2fe2bdb11b87
  • Pointer size: 131 Bytes
  • Size of remote file: 527 kB
examples/26d2e846349647ff04c536816e0e8ca1.png ADDED

Git LFS Details

  • SHA256: 0e8be830e21a98ba5c0156ade2b3cd79145ddda7af1b34e5580a54221efb67a5
  • Pointer size: 131 Bytes
  • Size of remote file: 519 kB
examples/30755.png ADDED

Git LFS Details

  • SHA256: b8f53a4c9eb63ae9114a63b1076cef12be8988429706c34762d56f1d895e0ec5
  • Pointer size: 131 Bytes
  • Size of remote file: 473 kB
examples/3930.png ADDED

Git LFS Details

  • SHA256: 7b3f8074aa15d00fad0284ad7e34ad82dbc5f48cfe40d917fe537634b66e5008
  • Pointer size: 131 Bytes
  • Size of remote file: 407 kB
examples/4656716-3016170581.png ADDED

Git LFS Details

  • SHA256: f5a979c7818eba83bfad14695d2f5b80b5ef15f082b530cbcf2bb0b7665d34ca
  • Pointer size: 131 Bytes
  • Size of remote file: 404 kB
examples/663dcd6db19490de0b790da430bd5681.png ADDED

Git LFS Details

  • SHA256: b499922b6df6d6874fea68c571ff3271f68aa6bc40420396f4898e5c58d74dc8
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
examples/7332.png ADDED

Git LFS Details

  • SHA256: bb881e50789818c8419e7eac0c1af34eea44a4e9bdad72ebeb45ec03be3337a3
  • Pointer size: 131 Bytes
  • Size of remote file: 460 kB
examples/85891251f52a2399e660a63c2a7fdf40.png ADDED

Git LFS Details

  • SHA256: c0dfcec6c2e12a0b66feb58c11aeb50c3629533f131a84efac4d6f18c325106e
  • Pointer size: 131 Bytes
  • Size of remote file: 235 kB
examples/a689a48d23d6b8d58d67ff5146c6e088.png ADDED

Git LFS Details

  • SHA256: 83474c20b3609ba21f7a75d353e15011cd4b6638970648641e2f343dc1655f71
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
examples/b0d178743c7e3e09700aaee8d2b1ec47.png ADDED

Git LFS Details

  • SHA256: 5c365055d601312a490bdbb24246ff2c17b077d7bce440d057cffd63ad98270f
  • Pointer size: 131 Bytes
  • Size of remote file: 516 kB
examples/case5.png ADDED

Git LFS Details

  • SHA256: a214d479935aa2e2842d76d51b221ae009cc0dfa3bd2f5acf5acd010b6760f97
  • Pointer size: 131 Bytes
  • Size of remote file: 943 kB
examples/d40776a1e1582179d97907d36f84d776.png ADDED

Git LFS Details

  • SHA256: 8dc62a7b9ab270eb6d84f1b492968614c41be9472bfd709157f17ee9a72f3c26
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
examples/durant.png ADDED

Git LFS Details

  • SHA256: 4f6b517055c7d6f169d3f0017d3f16d4347adc9648131dd098d9ea1af94331a8
  • Pointer size: 131 Bytes
  • Size of remote file: 485 kB
examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png ADDED

Git LFS Details

  • SHA256: ca83675bb83d7413474f32bd612b14ba954a3ababb09f97133e2771a0809dc73
  • Pointer size: 131 Bytes
  • Size of remote file: 308 kB
examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png ADDED

Git LFS Details

  • SHA256: 78318b90aa1fbf0723cc2901ff817031642d21a4f5ceb5f0474bda9a989264be
  • Pointer size: 131 Bytes
  • Size of remote file: 604 kB
examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png ADDED

Git LFS Details

  • SHA256: b49f2e141cad2521057a3dc9d9a2f0a4cf61fe1037a3129bc14f1fee13cf51fb
  • Pointer size: 131 Bytes
  • Size of remote file: 294 kB
examples/pexels-barbara-olsen-7869640.png ADDED

Git LFS Details

  • SHA256: 3a85f6ffaf769809c964c50a7bf8ca7e54d146d2c7e4eb6cc06bea2c2259a784
  • Pointer size: 131 Bytes
  • Size of remote file: 500 kB
examples/pexels-julia-m-cameron-4145040.png ADDED

Git LFS Details

  • SHA256: f57b747babe17cbc2aba4ceb19eeecbddef68a7ddaa232636c4008d318e8d632
  • Pointer size: 131 Bytes
  • Size of remote file: 128 kB
examples/pexels-marta-wave-6437749.png ADDED

Git LFS Details

  • SHA256: 25b5fd1cda24ebbc574e2146d24808024cdaea0f2843c04a1733889a4f149516
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
examples/pexels-photo-6311555-removebg.png ADDED

Git LFS Details

  • SHA256: a8343230a24d124d3e6c3c4dc75fb2880471e9988b291ef1009a012a024c7e1e
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
examples/pexels-zdmit-6780091.png ADDED

Git LFS Details

  • SHA256: c128ea45f41a2e1382e4afde39d843ec37e601de22e1f856c2c29e09e5ed14d0
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
inference.py ADDED
@@ -0,0 +1,223 @@
1
+ import argparse
2
+ import sys
3
+ import os
4
+
5
+ from typing import Dict, Optional, Tuple, List
6
+ from omegaconf import OmegaConf
7
+ from PIL import Image
8
+ from dataclasses import dataclass
9
+ from collections import defaultdict
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from torchvision.utils import make_grid, save_image
13
+ from accelerate.utils import set_seed
14
+ from tqdm.auto import tqdm
15
+ import torch.nn.functional as F
16
+ from einops import rearrange
17
+ from rembg import remove, new_session
18
+ import pdb
19
+ from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
20
+ from econdataset import SMPLDataset
21
+ from reconstruct import ReMesh
22
+ providers = [
23
+ ('CUDAExecutionProvider', {
24
+ 'device_id': 0,
25
+ 'arena_extend_strategy': 'kSameAsRequested',
26
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
27
+ 'cudnn_conv_algo_search': 'HEURISTIC',
28
+ })
29
+ ]
30
+ session = new_session(providers=providers)
31
+
32
+ weight_dtype = torch.float16
33
+ def tensor_to_numpy(tensor):
34
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
35
+
36
+
37
+ @dataclass
38
+ class TestConfig:
39
+ pretrained_model_name_or_path: str
40
+ revision: Optional[str]
41
+ validation_dataset: Dict
42
+ save_dir: str
43
+ seed: Optional[int]
44
+ validation_batch_size: int
45
+ dataloader_num_workers: int
46
+ # save_single_views: bool
47
+ save_mode: str
48
+ local_rank: int
49
+
50
+ pipe_kwargs: Dict
51
+ pipe_validation_kwargs: Dict
52
+ unet_from_pretrained_kwargs: Dict
53
+ validation_guidance_scales: float
54
+ validation_grid_nrow: int
55
+
56
+ num_views: int
57
+ enable_xformers_memory_efficient_attention: bool
58
+ with_smpl: Optional[bool]
59
+
60
+ recon_opt: Dict
61
+
62
+
63
+ def convert_to_numpy(tensor):
64
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
65
+
66
+ def convert_to_pil(tensor):
67
+ return Image.fromarray(convert_to_numpy(tensor))
68
+
69
+ def save_image(tensor, fp):
70
+ ndarr = convert_to_numpy(tensor)
71
+ # pdb.set_trace()
72
+ save_image_numpy(ndarr, fp)
73
+ return ndarr
74
+
75
+ def save_image_numpy(ndarr, fp):
76
+ im = Image.fromarray(ndarr)
77
+ im.save(fp)
78
+
79
+ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
80
+ pipeline.set_progress_bar_config(disable=True)
81
+
82
+ if cfg.seed is None:
83
+ generator = None
84
+ else:
85
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
86
+
87
+ images_cond, pred_cat = [], defaultdict(list)
88
+ for case_id, batch in tqdm(enumerate(dataloader)):
89
+ images_cond.append(batch['imgs_in'][:, 0])
90
+
91
+ imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
92
+ num_views = imgs_in.shape[1]
93
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")  # (B*Nv, 3, H, W)
94
+ if cfg.with_smpl:
95
+ smpl_in = torch.cat([batch['smpl_imgs_in']]*2, dim=0)
96
+ smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W")
97
+ else:
98
+ smpl_in = None
99
+
100
+ normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
101
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
102
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
103
+
104
+ with torch.autocast("cuda"):
105
+ # B*Nv images
106
+ guidance_scale = cfg.validation_guidance_scales
107
+ unet_out = pipeline(
108
+ imgs_in, None, prompt_embeds=prompt_embeddings,
109
+ dino_feature=None, smpl_in=smpl_in,
110
+ generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1,
111
+ **cfg.pipe_validation_kwargs
112
+ )
113
+
114
+ out = unet_out.images
115
+ bsz = out.shape[0] // 2
116
+
117
+ normals_pred = out[:bsz]
118
+ images_pred = out[bsz:]
119
+ if cfg.save_mode == 'concat': ## save concatenated color and normal---------------------
120
+ pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1)) # b, 3, h, w
121
+ cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
122
+ os.makedirs(cur_dir, exist_ok=True)
123
+ for i in range(bsz//num_views):
124
+ scene = batch['filename'][i].split('.')[0]
125
+
126
+ img_in_ = images_cond[-1][i].to(out.device)
127
+ vis_ = [img_in_]
128
+ for j in range(num_views):
129
+ idx = i*num_views + j
130
+ normal = normals_pred[idx]
131
+ color = images_pred[idx]
132
+
133
+ vis_.append(color)
134
+ vis_.append(normal)
135
+
136
+ out_filename = f"{cur_dir}/{scene}.png"
137
+ vis_ = torch.stack(vis_, dim=0)
138
+ vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
139
+ save_image(vis_, out_filename)
140
+ elif cfg.save_mode == 'rgb':
141
+ for i in range(bsz//num_views):
142
+ scene = batch['filename'][i].split('.')[0]
143
+
144
+ img_in_ = images_cond[-1][i].to(out.device)
145
+ normals, colors = [], []
146
+ for j in range(num_views):
147
+ idx = i*num_views + j
148
+ normal = normals_pred[idx]
149
+ if j == 0:
150
+ color = imgs_in[0].to(out.device)
151
+ else:
152
+ color = images_pred[idx]
153
+ if j in [3, 4]:
154
+ normal = torch.flip(normal, dims=[2])
155
+ color = torch.flip(color, dims=[2])
156
+
157
+ colors.append(color)
158
+ if j == 6:
159
+ normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
160
+ normals.append(normal)
161
+
162
+ ## save color and normal---------------------
163
+ # normal_filename = f"normals_{view}_masked.png"
164
+ # rgb_filename = f"color_{view}_masked.png"
165
+ # save_image(normal, os.path.join(scene_dir, normal_filename))
166
+ # save_image(color, os.path.join(scene_dir, rgb_filename))
167
+ normals[0][:, :256, 256:512] = normals[-1]
168
+
169
+ colors = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
170
+ normals = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
171
+ pose = econdata.__getitem__(case_id)
172
+ carving.optimize_case(scene, pose, colors, normals)
173
+ torch.cuda.empty_cache()
174
+
175
+
176
+
177
+ def load_pshuman_pipeline(cfg):
178
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
179
+ pipeline.unet.enable_xformers_memory_efficient_attention()
180
+ if torch.cuda.is_available():
181
+ pipeline.to('cuda')
182
+ return pipeline
183
+
184
+ def main(
185
+ cfg: TestConfig
186
+ ):
187
+
188
+ # If passed along, set the training seed now.
189
+ if cfg.seed is not None:
190
+ set_seed(cfg.seed)
191
+ pipeline = load_pshuman_pipeline(cfg)
192
+
193
+
194
+ if cfg.with_smpl:
195
+ from mvdiffusion.data.testdata_with_smpl import SingleImageDataset
196
+ else:
197
+ from mvdiffusion.data.single_image_dataset import SingleImageDataset
198
+
199
+ # Get the dataset
200
+ validation_dataset = SingleImageDataset(
201
+ **cfg.validation_dataset
202
+ )
203
+ validation_dataloader = torch.utils.data.DataLoader(
204
+ validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers
205
+ )
206
+ dataset_param = {'image_dir': validation_dataset.root_dir, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
207
+ econdata = SMPLDataset(dataset_param, device='cuda')
208
+
209
+ carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
210
+ run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)
211
+
212
+
213
+ if __name__ == '__main__':
214
+ parser = argparse.ArgumentParser()
215
+ parser.add_argument('--config', type=str, required=True)
216
+ args, extras = parser.parse_known_args()
217
+ from utils.misc import load_config
218
+
219
+ # parse YAML config to OmegaConf
220
+ cfg = load_config(args.config, cli_args=extras)
221
+ schema = OmegaConf.structured(TestConfig)
222
+ cfg = OmegaConf.merge(schema, cfg)
223
+ main(cfg)
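
The `__main__` block above validates the YAML file against the `TestConfig` dataclass through OmegaConf's structured-config merge. A minimal sketch of that pattern follows; the `MiniConfig` fields are an illustrative subset, not the full schema.

```python
from dataclasses import dataclass
from typing import Optional
from omegaconf import OmegaConf

@dataclass
class MiniConfig:  # illustrative subset of TestConfig
    pretrained_model_name_or_path: str = ''
    seed: Optional[int] = None
    validation_guidance_scales: float = 3.0
    with_smpl: Optional[bool] = True

schema = OmegaConf.structured(MiniConfig)
yaml_cfg = OmegaConf.create({'seed': 42, 'validation_guidance_scales': 2.0})
cfg = OmegaConf.merge(schema, yaml_cfg)  # type-checked; unknown keys are rejected
print(cfg.seed, cfg.validation_guidance_scales, cfg.with_smpl)
```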
lib/__init__.py ADDED
File without changes
lib/common/__init__.py ADDED
File without changes
lib/common/cloth_extraction.py ADDED
@@ -0,0 +1,182 @@
1
+ import numpy as np
2
+ import json
3
+ import os
4
+ import itertools
5
+ import trimesh
6
+ from matplotlib.path import Path
7
+ from collections import Counter
8
+ from sklearn.neighbors import KNeighborsClassifier
9
+
10
+
11
+ def load_segmentation(path, shape):
12
+ """
13
+ Load the segmentation polygons for a given image
14
+ Arguments:
15
+ path: path to the segmentation json file
16
+ shape: shape of the source image (currently unused)
17
+ Returns:
18
+ Returns a list of segmentation dicts (type, type_id, polygon coordinates)
19
+ """
20
+ with open(path) as json_file:
21
+ seg_dict = json.load(json_file)  # avoid shadowing the built-in `dict`
22
+ segmentations = []
23
+ for key, val in seg_dict.items():
24
+ if not key.startswith('item'):
25
+ continue
26
+
27
+ # Each item can have multiple polygons. Combine them to one
28
+ # segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
29
+ # segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
30
+
31
+ coordinates = []
32
+ for segmentation_coord in val['segmentation']:
33
+ # The format before is [x1,y1, x2, y2, ....]
34
+ x = segmentation_coord[::2]
35
+ y = segmentation_coord[1::2]
36
+ xy = np.vstack((x, y)).T
37
+ coordinates.append(xy)
38
+
39
+ segmentations.append({
40
+ 'type': val['category_name'],
41
+ 'type_id': val['category_id'],
42
+ 'coordinates': coordinates
43
+ })
44
+
45
+ return segmentations
46
+
47
+
48
+ def smpl_to_recon_labels(recon, smpl, k=1):
49
+ """
50
+ Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
51
+ Arguments:
52
+ recon: trimesh object (fully clothed model)
53
+ smpl: trimesh object (smpl model)
54
+ k: number of nearest neighbours to use
55
+ Returns:
56
+ Returns a dictionary containing the bodypart and the corresponding indices
57
+ """
58
+ smpl_vert_segmentation = json.load(
59
+ open(
60
+ os.path.join(os.path.dirname(__file__),
61
+ 'smpl_vert_segmentation.json')))
62
+ n = smpl.vertices.shape[0]
63
+ y = np.array([None] * n)
64
+ for key, val in smpl_vert_segmentation.items():
65
+ y[val] = key
66
+
67
+ classifier = KNeighborsClassifier(n_neighbors=k)
68
+ classifier.fit(smpl.vertices, y)
69
+
70
+ y_pred = classifier.predict(recon.vertices)
71
+
72
+ recon_labels = {}
73
+ for key in smpl_vert_segmentation.keys():
74
+ recon_labels[key] = list(
75
+ np.argwhere(y_pred == key).flatten().astype(int))
76
+
77
+ return recon_labels
78
+
79
+
80
+ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
81
+ """
82
+ Extract a portion of a mesh using 2d segmentation coordinates
83
+ Arguments:
84
+ recon: fully clothed mesh
85
+ seg_coord: segmentation coordinates in 2D (NDC)
86
+ K: intrinsic matrix of the projection
87
+ R: rotation matrix of the projection
88
+ t: translation vector of the projection
89
+ Returns:
90
+ Returns a submesh using the segmentation coordinates
91
+ """
92
+ seg_coord = segmentation['coord_normalized']
93
+ mesh = trimesh.Trimesh(recon.vertices, recon.faces)
94
+ extrinsic = np.zeros((3, 4))
95
+ extrinsic[:3, :3] = R
96
+ extrinsic[:, 3] = t
97
+ P = K[:3, :3] @ extrinsic
98
+
99
+ P_inv = np.linalg.pinv(P)
100
+
101
+ # Each segmentation can contain multiple polygons
102
+ # We need to check them separately
103
+ points_so_far = []
104
+ faces = recon.faces
105
+ for polygon in seg_coord:
106
+ n = len(polygon)
107
+ coords_h = np.hstack((polygon, np.ones((n, 1))))
108
+ # Apply the inverse projection on homogeneous 2D coordinates to get the corresponding 3D coordinates
109
+ XYZ = P_inv @ coords_h[:, :, None]
110
+ XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
111
+ XYZ = XYZ[:, :3] / XYZ[:, 3, None]
112
+
113
+ p = Path(XYZ[:, :2])
114
+
115
+ grid = p.contains_points(recon.vertices[:, :2])
116
+ indices = np.argwhere(grid)
117
+ points_so_far += list(indices.flatten())
118
+
119
+ if smpl is not None:
120
+ num_verts = recon.vertices.shape[0]
121
+ recon_labels = smpl_to_recon_labels(recon, smpl)
122
+ body_parts_to_remove = [
123
+ 'rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
124
+ 'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand',
125
+ 'rightHand'
126
+ ]
127
+ type = segmentation['type_id']
128
+
129
+ # Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
130
+ # https://github.com/switchablenorms/DeepFashion2
131
+ # Short sleeve clothes
132
+ if type == 1 or type == 3 or type == 10:
133
+ body_parts_to_remove += ['leftForeArm', 'rightForeArm']
134
+ # No sleeves at all or lower body clothes
135
+ elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9:
136
+ body_parts_to_remove += [
137
+ 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm'
138
+ ]
139
+ # Shorts
140
+ elif type == 7:
141
+ body_parts_to_remove += [
142
+ 'leftLeg', 'rightLeg', 'leftForeArm', 'rightForeArm',
143
+ 'leftArm', 'rightArm'
144
+ ]
145
+
146
+ verts_to_remove = list(
147
+ itertools.chain.from_iterable(
148
+ [recon_labels[part] for part in body_parts_to_remove]))
149
+
150
+ label_mask = np.zeros(num_verts, dtype=bool)
151
+ label_mask[verts_to_remove] = True
152
+
153
+ seg_mask = np.zeros(num_verts, dtype=bool)
154
+ seg_mask[points_so_far] = True
155
+
156
+ # Remove points that belong to other bodyparts
157
+ # If a vertex in points_so_far is also included in the bodyparts to remove, then it should be removed
159
+ extra_verts_to_remove = np.logical_and(seg_mask, label_mask)
159
+
160
+ combine_mask = np.zeros(num_verts, dtype=bool)
161
+ combine_mask[points_so_far] = True
162
+ combine_mask[extra_verts_to_remove] = False
163
+
164
+ all_indices = np.argwhere(combine_mask).flatten()
165
+
166
+ i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
167
+ i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
168
+ i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]
169
+
170
+ faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
171
+ mask = np.zeros(len(recon.faces), dtype=bool)
172
+ if len(faces_to_keep) > 0:
173
+ mask[faces_to_keep] = True
174
+
175
+ mesh.update_faces(mask)
176
+ mesh.remove_unreferenced_vertices()
177
+
178
+ # mesh.rezero()
179
+
180
+ return mesh
181
+
182
+ return None
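
A minimal usage sketch for the helpers above (the mesh paths and the segmentation JSON are placeholders; `smpl_vert_segmentation.json` must sit next to this module, as `smpl_to_recon_labels` expects):

```python
import trimesh
from lib.common.cloth_extraction import load_segmentation, smpl_to_recon_labels

# Placeholder paths -- replace with the actual pipeline outputs.
recon = trimesh.load('recon.obj', process=False)    # fully clothed reconstruction
smpl = trimesh.load('smpl_fit.obj', process=False)  # aligned SMPL fit

# Transfer SMPL body-part labels onto the reconstructed mesh via nearest neighbours.
labels = smpl_to_recon_labels(recon, smpl, k=1)
print({part: len(idx) for part, idx in labels.items()})

# Parse a DeepFashion2-style segmentation file into per-item polygon lists.
segs = load_segmentation('image_seg.json', shape=None)
for seg in segs:
    print(seg['type'], seg['type_id'], len(seg['coordinates']))
```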
lib/common/config.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ from yacs.config import CfgNode as CN
18
+ import os
19
+
20
+ _C = CN(new_allowed=True)
21
+
22
+ # needed by trainer
23
+ _C.name = 'default'
24
+ _C.gpus = [0]
25
+ _C.test_gpus = [1]
26
+ _C.root = "./data/"
27
+ _C.ckpt_dir = './data/ckpt/'
28
+ _C.resume_path = ''
29
+ _C.normal_path = ''
30
+ _C.corr_path = ''
31
+ _C.results_path = './data/results/'
32
+ _C.projection_mode = 'orthogonal'
33
+ _C.num_views = 1
34
+ _C.sdf = False
35
+ _C.sdf_clip = 5.0
36
+
37
+ _C.lr_G = 1e-3
38
+ _C.lr_C = 1e-3
39
+ _C.lr_N = 2e-4
40
+ _C.weight_decay = 0.0
41
+ _C.momentum = 0.0
42
+ _C.optim = 'Adam'
43
+ _C.schedule = [5, 10, 15]
44
+ _C.gamma = 0.1
45
+
46
+ _C.overfit = False
47
+ _C.resume = False
48
+ _C.test_mode = False
49
+ _C.test_uv = False
50
+ _C.draw_geo_thres = 0.60
51
+ _C.num_sanity_val_steps = 2
52
+ _C.fast_dev = 0
53
+ _C.get_fit = False
54
+ _C.agora = False
55
+ _C.optim_cloth = False
56
+ _C.optim_body = False
57
+ _C.mcube_res = 256
58
+ _C.clean_mesh = True
59
+ _C.remesh = False
60
+
61
+ _C.batch_size = 4
62
+ _C.num_threads = 8
63
+
64
+ _C.num_epoch = 10
65
+ _C.freq_plot = 0.01
66
+ _C.freq_show_train = 0.1
67
+ _C.freq_show_val = 0.2
68
+ _C.freq_eval = 0.5
69
+ _C.accu_grad_batch = 4
70
+
71
+ _C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt']
72
+
73
+ _C.net = CN()
74
+ _C.net.gtype = 'HGPIFuNet'
75
+ _C.net.ctype = 'resnet18'
76
+ _C.net.classifierIMF = 'MultiSegClassifier'
77
+ _C.net.netIMF = 'resnet18'
78
+ _C.net.norm = 'group'
79
+ _C.net.norm_mlp = 'group'
80
+ _C.net.norm_color = 'group'
81
+ _C.net.hg_down = 'conv128' #'ave_pool'
82
+ _C.net.num_views = 1
83
+
84
+ # kernel_size, stride, dilation, padding
85
+
86
+ _C.net.conv1 = [7, 2, 1, 3]
87
+ _C.net.conv3x3 = [3, 1, 1, 1]
88
+
89
+ _C.net.num_stack = 4
90
+ _C.net.num_hourglass = 2
91
+ _C.net.hourglass_dim = 256
92
+ _C.net.voxel_dim = 32
93
+ _C.net.resnet_dim = 120
94
+ _C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
95
+ _C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
96
+ _C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
97
+ _C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
98
+ _C.net.res_layers = [2, 3, 4]
99
+ _C.net.filter_dim = 256
100
+ _C.net.smpl_dim = 3
101
+
102
+ _C.net.cly_dim = 3
103
+ _C.net.soft_dim = 64
104
+ _C.net.z_size = 200.0
105
+ _C.net.N_freqs = 10
106
+ _C.net.geo_w = 0.1
107
+ _C.net.norm_w = 0.1
108
+ _C.net.dc_w = 0.1
109
+ _C.net.C_cat_to_G = False
110
+
111
+ _C.net.skip_hourglass = True
112
+ _C.net.use_tanh = False
113
+ _C.net.soft_onehot = True
114
+ _C.net.no_residual = False
115
+ _C.net.use_attention = False
116
+
117
+ _C.net.prior_type = "sdf"
118
+ _C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
119
+ _C.net.use_filter = True
120
+ _C.net.use_cc = False
121
+ _C.net.use_PE = False
122
+ _C.net.use_IGR = False
123
+ _C.net.in_geo = ()
124
+ _C.net.in_nml = ()
125
+
126
+ _C.dataset = CN()
127
+ _C.dataset.root = ''
128
+ _C.dataset.set_splits = [0.95, 0.04]
129
+ _C.dataset.types = [
130
+ "3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy"
131
+ ]
132
+ _C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
133
+ _C.dataset.rp_type = "pifu900"
134
+ _C.dataset.th_type = 'train'
135
+ _C.dataset.input_size = 512
136
+ _C.dataset.rotation_num = 3
137
+ _C.dataset.num_sample_ray = 128  # volume rendering
138
+ _C.dataset.num_precomp = 10 # Number of segmentation classifiers
139
+ _C.dataset.num_multiseg = 500 # Number of categories per classifier
140
+ _C.dataset.num_knn = 10 # for loss/error
141
+ _C.dataset.num_knn_dis = 20 # for accuracy
142
+ _C.dataset.num_verts_max = 20000
143
+ _C.dataset.zray_type = False
144
+ _C.dataset.online_smpl = False
145
+ _C.dataset.noise_type = ['z-trans', 'pose', 'beta']
146
+ _C.dataset.noise_scale = [0.0, 0.0, 0.0]
147
+ _C.dataset.num_sample_geo = 10000
148
+ _C.dataset.num_sample_color = 0
149
+ _C.dataset.num_sample_seg = 0
150
+ _C.dataset.num_sample_knn = 10000
151
+
152
+ _C.dataset.sigma_geo = 5.0
153
+ _C.dataset.sigma_color = 0.10
154
+ _C.dataset.sigma_seg = 0.10
155
+ _C.dataset.thickness_threshold = 20.0
156
+ _C.dataset.ray_sample_num = 2
157
+ _C.dataset.semantic_p = False
158
+ _C.dataset.remove_outlier = False
159
+
160
+ _C.dataset.train_bsize = 1.0
161
+ _C.dataset.val_bsize = 1.0
162
+ _C.dataset.test_bsize = 1.0
163
+
164
+
165
+ def get_cfg_defaults():
166
+ """Get a yacs CfgNode object with default values for my_project."""
167
+ # Return a clone so that the defaults will not be altered
168
+ # This is for the "local variable" use pattern
169
+ return _C.clone()
170
+
171
+
172
+ # Alternatively, provide a way to import the defaults as
173
+ # a global singleton:
174
+ cfg = _C # users can `from config import cfg`
175
+
176
+ # cfg = get_cfg_defaults()
177
+ # cfg.merge_from_file('./configs/example.yaml')
178
+
179
+ # # Now override from a list (opts could come from the command line)
180
+ # opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
181
+ # cfg.merge_from_list(opts)
182
+
183
+
184
+ def update_cfg(cfg_file):
185
+ # cfg = get_cfg_defaults()
186
+ _C.merge_from_file(cfg_file)
187
+ # return cfg.clone()
188
+ return _C
189
+
190
+
191
+ def parse_args(args):
192
+ cfg_file = args.cfg_file
193
+ if args.cfg_file is not None:
194
+ cfg = update_cfg(args.cfg_file)
195
+ else:
196
+ cfg = get_cfg_defaults()
197
+
198
+ # if args.misc is not None:
199
+ # cfg.merge_from_list(args.misc)
200
+
201
+ return cfg
202
+
203
+
204
+ def parse_args_extend(args):
205
+ if args.resume:
206
+ if not os.path.exists(args.log_dir):
207
+ raise ValueError(
208
+ 'Experiment is set to resume mode, but the log directory does not exist.'
209
+ )
210
+
211
+ # load log's cfg
212
+ cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
213
+ cfg = update_cfg(cfg_file)
214
+
215
+ if args.misc is not None:
216
+ cfg.merge_from_list(args.misc)
217
+ else:
218
+ parse_args(args)
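
As a sketch of the usage pattern spelled out in the inline comments above (the YAML path and the override values are placeholders):

```python
from lib.common.config import get_cfg_defaults

cfg = get_cfg_defaults()                       # clone of the default CfgNode
cfg.merge_from_file('./configs/example.yaml')  # placeholder YAML path

# Command-line style overrides as alternating key/value pairs.
cfg.merge_from_list(['mcube_res', 512, 'clean_mesh', False])
cfg.freeze()

print(cfg.name, cfg.mcube_res, cfg.net.prior_type)
```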