heheyas committed cfb7702
1 Parent(s): f5c8d4d

init

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +50 -0
- configs/ae/video.yaml +35 -0
- configs/embedder/clip_image.yaml +8 -0
- configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +104 -0
- configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml +105 -0
- configs/example_training/imagenet-f8_cond.yaml +185 -0
- configs/example_training/toy/cifar10_cond.yaml +98 -0
- configs/example_training/toy/mnist.yaml +79 -0
- configs/example_training/toy/mnist_cond.yaml +98 -0
- configs/example_training/toy/mnist_cond_discrete_eps.yaml +103 -0
- configs/example_training/toy/mnist_cond_l1_loss.yaml +99 -0
- configs/example_training/toy/mnist_cond_with_ema.yaml +100 -0
- configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +182 -0
- configs/example_training/txt2img-clipl.yaml +184 -0
- configs/inference/sd_2_1.yaml +60 -0
- configs/inference/sd_2_1_768.yaml +60 -0
- configs/inference/sd_xl_base.yaml +93 -0
- configs/inference/sd_xl_refiner.yaml +86 -0
- configs/inference/svd.yaml +131 -0
- configs/inference/svd_image_decoder.yaml +114 -0
- configs/inference/svd_mv.yaml +202 -0
- mesh_recon/configs/neuralangelo-ortho-wmask.yaml +145 -0
- mesh_recon/configs/v3d.yaml +144 -0
- mesh_recon/configs/videonvs.yaml +144 -0
- mesh_recon/datasets/__init__.py +17 -0
- mesh_recon/datasets/blender.py +143 -0
- mesh_recon/datasets/colmap.py +332 -0
- mesh_recon/datasets/colmap_utils.py +295 -0
- mesh_recon/datasets/dtu.py +201 -0
- mesh_recon/datasets/fixed_poses/000_back_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_back_left_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_back_right_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_front_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_front_left_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_front_right_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_left_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_right_RT.txt +3 -0
- mesh_recon/datasets/fixed_poses/000_top_RT.txt +3 -0
- mesh_recon/datasets/ortho.py +287 -0
- mesh_recon/datasets/utils.py +0 -0
- mesh_recon/datasets/v3d.py +284 -0
- mesh_recon/datasets/videonvs.py +256 -0
- mesh_recon/datasets/videonvs_co3d.py +252 -0
- mesh_recon/launch.py +144 -0
- mesh_recon/mesh.py +845 -0
- mesh_recon/models/__init__.py +16 -0
- mesh_recon/models/base.py +32 -0
- mesh_recon/models/geometry.py +238 -0
- mesh_recon/models/nerf.py +161 -0
- mesh_recon/models/network_utils.py +215 -0
.gitignore
ADDED
@@ -0,0 +1,50 @@
+# extensions
+*.egg-info
+*.py[cod]
+
+# envs
+.pt13
+.pt2
+
+# directories
+/checkpoints
+/dist
+/outputs
+/build
+/src
+logs/
+ckpts/
+tmp/
+lightning_logs/
+images/
+images*/
+kb_configs/
+debug_lvis.log
+*.log
+.cache/
+redirects/
+submits/
+extern/
+assets/images
+output/
+assets/scene
+assets/GSO
+assets/SD
+spirals
+*.zip
+paper/
+spirals_co3d/
+scene_spirals/
+blenders/
+colmap_results/
+depth_spirals/
+recon/SIBR_viewers/
+recon/assets/
+mesh_recon/exp
+mesh_recon/runs
+mesh_recon/renders
+mesh_recon/refined
+*.png
+*.pdf
+*.npz
+*.npy
configs/ae/video.yaml
ADDED
@@ -0,0 +1,35 @@
+target: sgm.models.autoencoder.AutoencodingEngine
+params:
+  loss_config:
+    target: torch.nn.Identity
+  regularizer_config:
+    target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+  encoder_config:
+    target: sgm.modules.diffusionmodules.model.Encoder
+    params:
+      attn_type: vanilla
+      double_z: True
+      z_channels: 4
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [1, 2, 4, 4]
+      num_res_blocks: 2
+      attn_resolutions: []
+      dropout: 0.0
+  decoder_config:
+    target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
+    params:
+      attn_type: vanilla
+      double_z: True
+      z_channels: 4
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [1, 2, 4, 4]
+      num_res_blocks: 2
+      attn_resolutions: []
+      dropout: 0.0
+      video_kernel_size: [3, 1, 1]
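Note: every config in this commit follows the same `target`/`params` convention used across the sgm codebase: `target` names an importable class and `params` holds its constructor keyword arguments. As a minimal sketch (assuming the `sgm.util.instantiate_from_config` helper and OmegaConf, as in the upstream Stability AI generative-models code this repository packages), a config such as the one above would be turned into a live module roughly like this:

# Hypothetical usage sketch, not part of this commit.
# Assumes sgm.util.instantiate_from_config exists as in the upstream
# generative-models repo; adjust the import if this repo renames it.
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("configs/ae/video.yaml")
# Builds sgm.models.autoencoder.AutoencodingEngine; its constructor in turn
# instantiates the nested encoder/decoder/regularizer/loss sub-configs.
autoencoder = instantiate_from_config(config)
autoencoder.eval()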
configs/embedder/clip_image.yaml
ADDED
@@ -0,0 +1,8 @@
+target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+params:
+  n_cond_frames: 1
+  n_copies: 1
+  open_clip_embedding_config:
+    target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+    params:
+      freeze: True
configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml
ADDED
@@ -0,0 +1,104 @@
+model:
+  base_learning_rate: 4.5e-6
+  target: sgm.models.autoencoder.AutoencodingEngine
+  params:
+    input_key: jpg
+    monitor: val/rec_loss
+
+    loss_config:
+      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+      params:
+        perceptual_weight: 0.25
+        disc_start: 20001
+        disc_weight: 0.5
+        learn_logvar: True
+
+        regularization_weights:
+          kl_loss: 1.0
+
+    regularizer_config:
+      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+    encoder_config:
+      target: sgm.modules.diffusionmodules.model.Encoder
+      params:
+        attn_type: none
+        double_z: True
+        z_channels: 4
+        resolution: 256
+        in_channels: 3
+        out_ch: 3
+        ch: 128
+        ch_mult: [1, 2, 4]
+        num_res_blocks: 4
+        attn_resolutions: []
+        dropout: 0.0
+
+    decoder_config:
+      target: sgm.modules.diffusionmodules.model.Decoder
+      params: ${model.params.encoder_config.params}
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          - DATA-PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            params:
+              h_key: height
+              w_key: width
+
+      loader:
+        batch_size: 8
+        num_workers: 4
+
+
+lightning:
+  strategy:
+    target: pytorch_lightning.strategies.DDPStrategy
+    params:
+      find_unused_parameters: True
+
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 50000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      enable_autocast: False
+      batch_frequency: 1000
+      max_images: 8
+      increase_log_steps: True
+
+  trainer:
+    devices: 0,
+    limit_val_batches: 50
+    benchmark: True
+    accumulate_grad_batches: 1
+    val_check_interval: 10000
configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml
ADDED
@@ -0,0 +1,105 @@
+model:
+  base_learning_rate: 4.5e-6
+  target: sgm.models.autoencoder.AutoencodingEngine
+  params:
+    input_key: jpg
+    monitor: val/loss/rec
+    disc_start_iter: 0
+
+    encoder_config:
+      target: sgm.modules.diffusionmodules.model.Encoder
+      params:
+        attn_type: vanilla-xformers
+        double_z: true
+        z_channels: 8
+        resolution: 256
+        in_channels: 3
+        out_ch: 3
+        ch: 128
+        ch_mult: [1, 2, 4, 4]
+        num_res_blocks: 2
+        attn_resolutions: []
+        dropout: 0.0
+
+    decoder_config:
+      target: sgm.modules.diffusionmodules.model.Decoder
+      params: ${model.params.encoder_config.params}
+
+    regularizer_config:
+      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+    loss_config:
+      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+      params:
+        perceptual_weight: 0.25
+        disc_start: 20001
+        disc_weight: 0.5
+        learn_logvar: True
+
+        regularization_weights:
+          kl_loss: 1.0
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          - DATA-PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            params:
+              h_key: height
+              w_key: width
+
+      loader:
+        batch_size: 8
+        num_workers: 4
+
+
+lightning:
+  strategy:
+    target: pytorch_lightning.strategies.DDPStrategy
+    params:
+      find_unused_parameters: True
+
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 50000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      enable_autocast: False
+      batch_frequency: 1000
+      max_images: 8
+      increase_log_steps: True
+
+  trainer:
+    devices: 0,
+    limit_val_batches: 50
+    benchmark: True
+    accumulate_grad_batches: 1
+    val_check_interval: 10000
configs/example_training/imagenet-f8_cond.yaml
ADDED
@@ -0,0 +1,185 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+    log_keys:
+      - cls
+
+    scheduler_config:
+      target: sgm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [10000]
+        cycle_lengths: [10000000000000]
+        f_start: [1.e-6]
+        f_max: [1.]
+        f_min: [1.]
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 256
+        attention_resolutions: [1, 2, 4]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4]
+        num_head_channels: 64
+        num_classes: sequential
+        adm_in_channels: 1024
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: cls
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ClassEmbedder
+            params:
+              add_sequence_dim: True
+              embed_dim: 1024
+              n_classes: 1000
+
+          - is_trainable: False
+            ucg_rate: 0.2
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        ckpt_path: CKPT_PATH
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+          params:
+            num_idx: 1000
+
+            discretization_config:
+              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 5.0
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          # USER: adapt this path the root of your custom dataset
+          - DATA_PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg # USER: you might wanna adapt this for your custom dataset
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            params:
+              h_key: height # USER: you might wanna adapt this for your custom dataset
+              w_key: width # USER: you might wanna adapt this for your custom dataset
+
+      loader:
+        batch_size: 64
+        num_workers: 6
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      enable_autocast: False
+      batch_frequency: 1000
+      max_images: 8
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 8
+        n_rows: 2
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 1000
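Note: the training configs above contain USER placeholders (DATA_PATH, CKPT_PATH) that must be filled in before launching. A minimal sketch of doing that programmatically with OmegaConf, which this config layout already relies on; the file paths below are placeholders, and the training entry point itself (a Lightning main.py in the upstream codebase) is not part of this excerpt:

# Hypothetical sketch, not part of this commit: fill in the USER placeholders
# of configs/example_training/imagenet-f8_cond.yaml and save a local copy.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/example_training/imagenet-f8_cond.yaml")

overrides = OmegaConf.create({
    "model": {"params": {"first_stage_config": {"params": {"ckpt_path": "/path/to/autoencoder.ckpt"}}}},
    "data": {"params": {"train": {"datapipeline": {"urls": ["/path/to/shards/{00000..00099}.tar"]}}}},
})
cfg = OmegaConf.merge(cfg, overrides)
OmegaConf.save(cfg, "configs/example_training/imagenet-f8_cond.local.yaml")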
configs/example_training/toy/cifar10_cond.yaml
ADDED
@@ -0,0 +1,98 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+          params:
+            sigma_data: 1.0
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        in_channels: 3
+        out_channels: 3
+        model_channels: 32
+        attention_resolutions: []
+        num_res_blocks: 4
+        channel_mult: [1, 2, 2]
+        num_head_channels: 32
+        num_classes: sequential
+        adm_in_channels: 128
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: cls
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ClassEmbedder
+            params:
+              embed_dim: 128
+              n_classes: 10
+
+    first_stage_config:
+      target: sgm.models.autoencoder.IdentityFirstStage
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 3.0
+
+data:
+  target: sgm.data.cifar10.CIFAR10Loader
+  params:
+    batch_size: 512
+    num_workers: 1
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      batch_frequency: 1000
+      max_images: 64
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 64
+        n_rows: 8
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 20
configs/example_training/toy/mnist.yaml
ADDED
@@ -0,0 +1,79 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+          params:
+            sigma_data: 1.0
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        in_channels: 1
+        out_channels: 1
+        model_channels: 32
+        attention_resolutions: []
+        num_res_blocks: 4
+        channel_mult: [1, 2, 2]
+        num_head_channels: 32
+
+    first_stage_config:
+      target: sgm.models.autoencoder.IdentityFirstStage
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+data:
+  target: sgm.data.mnist.MNISTLoader
+  params:
+    batch_size: 512
+    num_workers: 1
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      batch_frequency: 1000
+      max_images: 64
+      increase_log_steps: False
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 64
+        n_rows: 8
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 10
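Note: the toy configs like the one above are the quickest way to smoke-test the training stack. Roughly, the pieces they describe fit together as sketched below (assuming `sgm.util.instantiate_from_config` and the `sgm.data.mnist.MNISTLoader` data module referenced by the config; the repository's own Lightning entry point, not shown in this excerpt, wires the same pieces with more options):

# Hypothetical sketch, not part of this commit: a bare-bones training run
# over configs/example_training/toy/mnist.yaml.
import pytorch_lightning as pl
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

cfg = OmegaConf.load("configs/example_training/toy/mnist.yaml")
model = instantiate_from_config(cfg.model)            # DiffusionEngine
model.learning_rate = cfg.model.base_learning_rate    # lr is read from the config externally
data = instantiate_from_config(cfg.data)              # MNISTLoader data module

# cfg.lightning.trainer carries plain pl.Trainer kwargs; "devices: 0," is a
# device string there, so it is replaced with an explicit choice here.
trainer_kwargs = {k: v for k, v in cfg.lightning.trainer.items() if k != "devices"}
trainer = pl.Trainer(accelerator="gpu", devices=1, **trainer_kwargs)
trainer.fit(model, data)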
configs/example_training/toy/mnist_cond.yaml
ADDED
@@ -0,0 +1,98 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+          params:
+            sigma_data: 1.0
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        in_channels: 1
+        out_channels: 1
+        model_channels: 32
+        attention_resolutions: []
+        num_res_blocks: 4
+        channel_mult: [1, 2, 2]
+        num_head_channels: 32
+        num_classes: sequential
+        adm_in_channels: 128
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: cls
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ClassEmbedder
+            params:
+              embed_dim: 128
+              n_classes: 10
+
+    first_stage_config:
+      target: sgm.models.autoencoder.IdentityFirstStage
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 3.0
+
+data:
+  target: sgm.data.mnist.MNISTLoader
+  params:
+    batch_size: 512
+    num_workers: 1
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      batch_frequency: 1000
+      max_images: 16
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 16
+        n_rows: 4
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 20
configs/example_training/toy/mnist_cond_discrete_eps.yaml
ADDED
@@ -0,0 +1,103 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        in_channels: 1
+        out_channels: 1
+        model_channels: 32
+        attention_resolutions: []
+        num_res_blocks: 4
+        channel_mult: [1, 2, 2]
+        num_head_channels: 32
+        num_classes: sequential
+        adm_in_channels: 128
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: cls
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ClassEmbedder
+            params:
+              embed_dim: 128
+              n_classes: 10
+
+    first_stage_config:
+      target: sgm.models.autoencoder.IdentityFirstStage
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+          params:
+            num_idx: 1000
+
+            discretization_config:
+              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 5.0
+
+data:
+  target: sgm.data.mnist.MNISTLoader
+  params:
+    batch_size: 512
+    num_workers: 1
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      batch_frequency: 1000
+      max_images: 16
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 16
+        n_rows: 4
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 20
configs/example_training/toy/mnist_cond_l1_loss.yaml
ADDED
@@ -0,0 +1,99 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+          params:
+            sigma_data: 1.0
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        in_channels: 1
+        out_channels: 1
+        model_channels: 32
+        attention_resolutions: []
+        num_res_blocks: 4
+        channel_mult: [1, 2, 2]
+        num_head_channels: 32
+        num_classes: sequential
+        adm_in_channels: 128
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: cls
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ClassEmbedder
+            params:
+              embed_dim: 128
+              n_classes: 10
+
+    first_stage_config:
+      target: sgm.models.autoencoder.IdentityFirstStage
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_type: l1
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 3.0
+
+data:
+  target: sgm.data.mnist.MNISTLoader
+  params:
+    batch_size: 512
+    num_workers: 1
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      batch_frequency: 1000
+      max_images: 64
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 64
+        n_rows: 8
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 20
configs/example_training/toy/mnist_cond_with_ema.yaml
ADDED
@@ -0,0 +1,100 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    use_ema: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+          params:
+            sigma_data: 1.0
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        in_channels: 1
+        out_channels: 1
+        model_channels: 32
+        attention_resolutions: []
+        num_res_blocks: 4
+        channel_mult: [1, 2, 2]
+        num_head_channels: 32
+        num_classes: sequential
+        adm_in_channels: 128
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: cls
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ClassEmbedder
+            params:
+              embed_dim: 128
+              n_classes: 10
+
+    first_stage_config:
+      target: sgm.models.autoencoder.IdentityFirstStage
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 3.0
+
+data:
+  target: sgm.data.mnist.MNISTLoader
+  params:
+    batch_size: 512
+    num_workers: 1
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      batch_frequency: 1000
+      max_images: 64
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 64
+        n_rows: 8
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 20
configs/example_training/txt2img-clipl-legacy-ucg-training.yaml
ADDED
@@ -0,0 +1,182 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+    log_keys:
+      - txt
+
+    scheduler_config:
+      target: sgm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [10000]
+        cycle_lengths: [10000000000000]
+        f_start: [1.e-6]
+        f_max: [1.]
+        f_min: [1.]
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [1, 2, 4]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        num_classes: sequential
+        adm_in_channels: 1792
+        num_heads: 1
+        transformer_depth: 1
+        context_dim: 768
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: txt
+            ucg_rate: 0.1
+            legacy_ucg_value: ""
+            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+            params:
+              always_return_pooled: True
+
+          - is_trainable: False
+            ucg_rate: 0.1
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            ucg_rate: 0.1
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        ckpt_path: CKPT_PATH
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 4, 4 ]
+          num_res_blocks: 2
+          attn_resolutions: [ ]
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+          params:
+            num_idx: 1000
+
+            discretization_config:
+              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 7.5
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          # USER: adapt this path the root of your custom dataset
+          - DATA_PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg # USER: you might wanna adapt this for your custom dataset
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            # USER: you might wanna use non-default parameters due to your custom dataset
+
+      loader:
+        batch_size: 64
+        num_workers: 6
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      enable_autocast: False
+      batch_frequency: 1000
+      max_images: 8
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 8
+        n_rows: 2
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 1000
configs/example_training/txt2img-clipl.yaml
ADDED
@@ -0,0 +1,184 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+    log_keys:
+      - txt
+
+    scheduler_config:
+      target: sgm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [10000]
+        cycle_lengths: [10000000000000]
+        f_start: [1.e-6]
+        f_max: [1.]
+        f_min: [1.]
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [1, 2, 4]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        num_classes: sequential
+        adm_in_channels: 1792
+        num_heads: 1
+        transformer_depth: 1
+        context_dim: 768
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: txt
+            ucg_rate: 0.1
+            legacy_ucg_value: ""
+            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+            params:
+              always_return_pooled: True
+
+          - is_trainable: False
+            ucg_rate: 0.1
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            ucg_rate: 0.1
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        ckpt_path: CKPT_PATH
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+          params:
+            num_idx: 1000
+
+            discretization_config:
+              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 7.5
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          # USER: adapt this path the root of your custom dataset
+          - DATA_PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000
+
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg # USER: you might wanna adapt this for your custom dataset
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+            # USER: you might wanna use non-default parameters due to your custom dataset
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            # USER: you might wanna use non-default parameters due to your custom dataset
+
+      loader:
+        batch_size: 64
+        num_workers: 6
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+  image_logger:
+    target: main.ImageLogger
+    params:
+      disabled: False
+      enable_autocast: False
+      batch_frequency: 1000
+      max_images: 8
+      increase_log_steps: True
+      log_first_step: False
+      log_images_kwargs:
+        use_ema_scope: False
+        N: 8
+        n_rows: 2
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 1000
configs/inference/sd_2_1.yaml
ADDED
@@ -0,0 +1,60 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
            params:
              freeze: true
              layer: penultimate

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
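Note: every block in these inference configs follows the same `target`/`params` convention, where `target` is a dotted import path and `params` are its constructor arguments. The snippet below is only a minimal sketch of how such a block can be instantiated with OmegaConf and importlib; it assumes the `sgm` package and its checkpoints are available separately, and `instantiate_from_config` here is a local illustration, not necessarily the repository's own loading code.

    # Minimal sketch (assumption): build the object described by a target/params block.
    import importlib
    from omegaconf import OmegaConf

    def instantiate_from_config(cfg: dict):
        module_name, cls_name = cfg["target"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(**cfg.get("params", {}))

    conf = OmegaConf.load("configs/inference/sd_2_1.yaml")
    model = instantiate_from_config(OmegaConf.to_container(conf.model, resolve=True))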
configs/inference/sd_2_1_768.yaml
ADDED
@@ -0,0 +1,60 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
            params:
              freeze: true
              layer: penultimate

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
configs/inference/sd_xl_base.yaml
ADDED
@@ -0,0 +1,93 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11

          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False

          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
configs/inference/sd_xl_refiner.yaml
ADDED
@@ -0,0 +1,86 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2560
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 4
        context_dim: [1280, 1280, 1280, 1280]
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              legacy: False
              freeze: True
              layer: penultimate
              always_return_pooled: True

          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: aesthetic_score
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
configs/inference/svd.yaml
ADDED
@@ -0,0 +1,131 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: cond_frames_without_noise
            target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
            params:
              n_cond_frames: 1
              n_copies: 1
              open_clip_embedding_config:
                target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
                params:
                  freeze: True

          - input_key: fps_id
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: motion_bucket_id
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: cond_frames
            is_trainable: False
            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            params:
              disable_encoder_autocast: True
              n_cond_frames: 1
              n_copies: 1
              is_ae: True
              encoder_config:
                target: sgm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  embed_dim: 4
                  monitor: val/rec_loss
                  ddconfig:
                    attn_type: vanilla-xformers
                    double_z: True
                    z_channels: 4
                    resolution: 256
                    in_channels: 3
                    out_ch: 3
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    num_res_blocks: 2
                    attn_resolutions: []
                    dropout: 0.0
                  lossconfig:
                    target: torch.nn.Identity

          - input_key: cond_aug
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size: [3, 1, 1]
configs/inference/svd_image_decoder.yaml
ADDED
@@ -0,0 +1,114 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: cond_frames_without_noise
            target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
            params:
              n_cond_frames: 1
              n_copies: 1
              open_clip_embedding_config:
                target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
                params:
                  freeze: True

          - input_key: fps_id
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: motion_bucket_id
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: cond_frames
            is_trainable: False
            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            params:
              disable_encoder_autocast: True
              n_cond_frames: 1
              n_copies: 1
              is_ae: True
              encoder_config:
                target: sgm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  embed_dim: 4
                  monitor: val/rec_loss
                  ddconfig:
                    attn_type: vanilla-xformers
                    double_z: True
                    z_channels: 4
                    resolution: 256
                    in_channels: 3
                    out_ch: 3
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    num_res_blocks: 2
                    attn_resolutions: []
                    dropout: 0.0
                  lossconfig:
                    target: torch.nn.Identity

          - input_key: cond_aug
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
configs/inference/svd_mv.yaml
ADDED
@@ -0,0 +1,202 @@
model:
  base_learning_rate: 1.0e-05
  target: sgm.models.video_diffusion.DiffusionEngine
  params:
    ckpt_path: ckpts/svd_xt.safetensors
    scale_factor: 0.18215
    disable_first_stage_autocast: true
    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps:
        - 1
        cycle_lengths:
        - 10000000000000
        f_start:
        - 1.0e-06
        f_max:
        - 1.0
        f_min:
        - 1.0
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: true
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: true
        use_spatial_context: true
        merge_strategy: learned_with_images
        video_kernel_size:
        - 3
        - 1
        - 1
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          ucg_rate: 0.2
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: true
        - input_key: fps_id
          is_trainable: true
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - input_key: motion_bucket_id
          is_trainable: true
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - input_key: cond_frames
          is_trainable: false
          ucg_rate: 0.2
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: true
            n_cond_frames: 1
            n_copies: 1
            is_ae: true
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: true
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult:
                  - 1
                  - 2
                  - 4
                  - 4
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity
        - input_key: cond_aug
          is_trainable: true
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult:
            - 1
            - 2
            - 4
            - 4
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult:
            - 1
            - 2
            - 4
            - 4
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size:
            - 3
            - 1
            - 1
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 30
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 2.5
            min_scale: 1.0
            num_frames: 24
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        batch2model_keys:
        - num_video_frames
        - image_only_indicator
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
          params:
            p_mean: 0.3
            p_std: 1.2
data:
  target: sgm.data.objaverse.ObjaverseSpiralDataset
  params:
    root_dir: /mnt/mfs/zilong.chen/Downloads/objaverse-ndd-samples
    random_front: true
    batch_size: 2
    num_workers: 16
    cond_aug_mean: -0.0
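For reference, the `EDMSampling` block above (p_mean: 0.3, p_std: 1.2) corresponds, as far as I understand the EDM formulation used by these modules, to drawing training noise levels from a log-normal distribution, i.e. log(sigma) ~ Normal(p_mean, p_std^2). A tiny sketch, purely illustrative and not the repository's class:

    # Sketch (assumption): log-normal noise-level sampling with the parameters above.
    import torch

    def sample_sigmas(n: int, p_mean: float = 0.3, p_std: float = 1.2) -> torch.Tensor:
        # log(sigma) ~ Normal(p_mean, p_std^2), so sigma = exp(p_mean + p_std * eps)
        return (p_mean + p_std * torch.randn(n)).exp()

    sigmas = sample_sigmas(4)  # four random noise levels for a batch of size 4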
mesh_recon/configs/neuralangelo-ortho-wmask.yaml
ADDED
@@ -0,0 +1,145 @@
name: ${basename:${dataset.scene}}
tag: ""
seed: 42

dataset:
  name: ortho
  root_dir: /home/xiaoxiao/Workplace/wonder3Dplus/outputs/joint-twice/aigc/cropsize-224-cfg1.0
  cam_pose_dir: null
  scene: scene_name
  imSize: [1024, 1024] # should use larger res, otherwise the exported mesh has wrong colors
  camera_type: ortho
  apply_mask: true
  camera_params: null
  view_weights: [1.0, 0.8, 0.2, 1.0, 0.4, 0.7] # ['front', 'front_right', 'right', 'back', 'left', 'front_left']
  # view_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

model:
  name: neus
  radius: 1.0
  num_samples_per_ray: 1024
  train_num_rays: 256
  max_train_num_rays: 8192
  grid_prune: true
  grid_prune_occ_thre: 0.001
  dynamic_ray_sampling: true
  batch_image_sampling: true
  randomized: true
  ray_chunk: 2048
  cos_anneal_end: 20000
  learned_background: false
  background_color: black
  variance:
    init_val: 0.3
    modulate: false
  geometry:
    name: volume-sdf
    radius: ${model.radius}
    feature_dim: 13
    grad_type: finite_difference
    finite_difference_eps: progressive
    isosurface:
      method: mc
      resolution: 192
      chunk: 2097152
      threshold: 0.
    xyz_encoding_config:
      otype: ProgressiveBandHashGrid
      n_levels: 10 # 12 modify
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 32
      per_level_scale: 1.3195079107728942
      include_xyz: true
      start_level: 4
      start_step: 0
      update_steps: 1000
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 1
      sphere_init: true
      sphere_init_radius: 0.5
      weight_norm: true
  texture:
    name: volume-radiance
    input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
    dir_encoding_config:
      otype: SphericalHarmonics
      degree: 4
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 2
    color_activation: sigmoid

system:
  name: ortho-neus-system
  loss:
    lambda_rgb_mse: 0.5
    lambda_rgb_l1: 0.
    lambda_mask: 1.0
    lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects
    lambda_normal: 1.0 # cannot be too large
    lambda_3d_normal_smooth: 1.0
    # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
    lambda_curvature: 0.
    lambda_sparsity: 0.5
    lambda_distortion: 0.0
    lambda_distortion_bg: 0.0
    lambda_opaque: 0.0
    sparsity_scale: 100.0
    geo_aware: true
    rgb_p_ratio: 0.8
    normal_p_ratio: 0.8
    mask_p_ratio: 0.9
  optimizer:
    name: AdamW
    args:
      lr: 0.01
      betas: [0.9, 0.99]
      eps: 1.e-15
    params:
      geometry:
        lr: 0.001
      texture:
        lr: 0.01
      variance:
        lr: 0.001
  constant_steps: 500
  scheduler:
    name: SequentialLR
    interval: step
    milestones:
      - ${system.constant_steps}
    schedulers:
      - name: ConstantLR
        args:
          factor: 1.0
          total_iters: ${system.constant_steps}
      - name: ExponentialLR
        args:
          gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}

checkpoint:
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}

export:
  chunk_size: 2097152
  export_vertex_color: True
  ortho_scale: 1.35 # modify

trainer:
  max_steps: 3000
  log_every_n_steps: 100
  num_sanity_val_steps: 0
  val_check_interval: 4000
  limit_train_batches: 1.0
  limit_val_batches: 2
  enable_progress_bar: true
  precision: 16
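The interpolations used in this config (`${basename:...}`, `${add:...}`, `${sub:...}`, `${calc_exp_lr_decay_rate:...}`) are custom OmegaConf resolvers that must be registered before the file is loaded. The sketch below gives plausible definitions under that assumption; the repository's own implementations may differ in detail.

    # Sketch (assumption): register the custom resolvers this config relies on.
    import os
    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(str(p)))
    OmegaConf.register_new_resolver("add", lambda a, b: a + b)
    OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
    # per-step decay factor so the LR reaches `factor` of its start value after n steps
    OmegaConf.register_new_resolver(
        "calc_exp_lr_decay_rate", lambda factor, n: float(factor) ** (1.0 / float(n))
    )

    cfg = OmegaConf.load("mesh_recon/configs/neuralangelo-ortho-wmask.yaml")
    resolved = OmegaConf.to_container(cfg, resolve=True)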
mesh_recon/configs/v3d.yaml
ADDED
@@ -0,0 +1,144 @@
name: ${basename:${dataset.scene}}
tag: ""
seed: 42

dataset:
  name: v3d
  root_dir: ./spirals
  cam_pose_dir: null
  scene: pizza_man
  apply_mask: true
  train_split: train
  test_split: train
  val_split: train
  img_wh: [1024, 1024]

model:
  name: neus
  radius: 1.0 ## check this
  num_samples_per_ray: 1024
  train_num_rays: 256
  max_train_num_rays: 8192
  grid_prune: true
  grid_prune_occ_thre: 0.001
  dynamic_ray_sampling: true
  batch_image_sampling: true
  randomized: true
  ray_chunk: 2048
  cos_anneal_end: 20000
  learned_background: false
  background_color: black
  variance:
    init_val: 0.3
    modulate: false
  geometry:
    name: volume-sdf
    radius: ${model.radius}
    feature_dim: 13
    grad_type: finite_difference
    finite_difference_eps: progressive
    isosurface:
      method: mc
      resolution: 384
      chunk: 2097152
      threshold: 0.
    xyz_encoding_config:
      otype: ProgressiveBandHashGrid
      n_levels: 10 # 12 modify
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 32
      per_level_scale: 1.3195079107728942
      include_xyz: true
      start_level: 4
      start_step: 0
      update_steps: 1000
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 1
      sphere_init: true
      sphere_init_radius: 0.5
      weight_norm: true
  texture:
    name: volume-radiance
    input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
    dir_encoding_config:
      otype: SphericalHarmonics
      degree: 4
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 2
    color_activation: sigmoid

system:
  name: videonvs-neus-system
  loss:
    lambda_rgb_mse: 0.5
    lambda_rgb_l1: 0.
    lambda_mask: 1.0
    lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects
    lambda_normal: 0.0 # cannot be too large
    lambda_3d_normal_smooth: 1.0
    # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
    lambda_curvature: 0.
    lambda_sparsity: 0.5
    lambda_distortion: 0.0
    lambda_distortion_bg: 0.0
    lambda_opaque: 0.0
    sparsity_scale: 100.0
    geo_aware: true
    rgb_p_ratio: 0.8
    normal_p_ratio: 0.8
    mask_p_ratio: 0.9
  optimizer:
    name: AdamW
    args:
      lr: 0.01
      betas: [0.9, 0.99]
      eps: 1.e-15
    params:
      geometry:
        lr: 0.001
      texture:
        lr: 0.01
      variance:
        lr: 0.001
  constant_steps: 500
  scheduler:
    name: SequentialLR
    interval: step
    milestones:
      - ${system.constant_steps}
    schedulers:
      - name: ConstantLR
        args:
          factor: 1.0
          total_iters: ${system.constant_steps}
      - name: ExponentialLR
        args:
          gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}

checkpoint:
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}

export:
  chunk_size: 2097152
  export_vertex_color: True
  ortho_scale: null # modify

trainer:
  max_steps: 3000
  log_every_n_steps: 100
  num_sanity_val_steps: 0
  val_check_interval: 3000
  limit_train_batches: 1.0
  limit_val_batches: 2
  enable_progress_bar: true
  precision: 16
mesh_recon/configs/videonvs.yaml
ADDED
@@ -0,0 +1,144 @@
name: ${basename:${dataset.scene}}
tag: ""
seed: 42

dataset:
  name: videonvs
  root_dir: ./spirals
  cam_pose_dir: null
  scene: pizza_man
  apply_mask: true
  train_split: train
  test_split: train
  val_split: train
  img_wh: [1024, 1024]

model:
  name: neus
  radius: 1.0 ## check this
  num_samples_per_ray: 1024
  train_num_rays: 256
  max_train_num_rays: 8192
  grid_prune: true
  grid_prune_occ_thre: 0.001
  dynamic_ray_sampling: true
  batch_image_sampling: true
  randomized: true
  ray_chunk: 2048
  cos_anneal_end: 20000
  learned_background: false
  background_color: black
  variance:
    init_val: 0.3
    modulate: false
  geometry:
    name: volume-sdf
    radius: ${model.radius}
    feature_dim: 13
    grad_type: finite_difference
    finite_difference_eps: progressive
    isosurface:
      method: mc
      resolution: 384
      chunk: 2097152
      threshold: 0.
    xyz_encoding_config:
      otype: ProgressiveBandHashGrid
      n_levels: 10 # 12 modify
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 32
      per_level_scale: 1.3195079107728942
      include_xyz: true
      start_level: 4
      start_step: 0
      update_steps: 1000
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 1
      sphere_init: true
      sphere_init_radius: 0.5
      weight_norm: true
  texture:
    name: volume-radiance
    input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
    dir_encoding_config:
      otype: SphericalHarmonics
      degree: 4
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 2
    color_activation: sigmoid

system:
  name: videonvs-neus-system
  loss:
    lambda_rgb_mse: 0.5
    lambda_rgb_l1: 0.
    lambda_mask: 1.0
    lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects
    lambda_normal: 1.0 # cannot be too large
    lambda_3d_normal_smooth: 1.0
    # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
    lambda_curvature: 0.
    lambda_sparsity: 0.5
    lambda_distortion: 0.0
    lambda_distortion_bg: 0.0
    lambda_opaque: 0.0
    sparsity_scale: 100.0
    geo_aware: true
    rgb_p_ratio: 0.8
    normal_p_ratio: 0.8
    mask_p_ratio: 0.9
  optimizer:
    name: AdamW
    args:
      lr: 0.01
      betas: [0.9, 0.99]
      eps: 1.e-15
    params:
      geometry:
        lr: 0.001
      texture:
        lr: 0.01
      variance:
        lr: 0.001
  constant_steps: 500
  scheduler:
    name: SequentialLR
    interval: step
    milestones:
      - ${system.constant_steps}
    schedulers:
      - name: ConstantLR
        args:
          factor: 1.0
          total_iters: ${system.constant_steps}
      - name: ExponentialLR
        args:
          gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}

checkpoint:
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}

export:
  chunk_size: 2097152
  export_vertex_color: True
  ortho_scale: null # modify

trainer:
  max_steps: 3000
  log_every_n_steps: 100
  num_sanity_val_steps: 0
  val_check_interval: 3000
  limit_train_batches: 1.0
  limit_val_batches: 2
  enable_progress_bar: true
  precision: 16
mesh_recon/datasets/__init__.py
ADDED
@@ -0,0 +1,17 @@
datasets = {}


def register(name):
    def decorator(cls):
        datasets[name] = cls
        return cls

    return decorator


def make(name, config):
    dataset = datasets[name](config)
    return dataset


from . import blender, colmap, dtu, ortho, videonvs, videonvs_co3d, v3d
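The module above is a small name-to-class registry: each dataset module registers its Lightning datamodule with `@datasets.register("<name>")` when it is imported, and `datasets.make(name, config)` builds one from a config. A short usage sketch under the assumption that the data referenced by the config actually exists on disk (the config values below are placeholders):

    # Usage sketch for the registry above; paths and fields are placeholders.
    from omegaconf import OmegaConf

    import datasets  # the mesh_recon/datasets package defined above

    cfg = OmegaConf.create({
        "root_dir": "data/some_scene",   # placeholder; must contain the expected files
        "img_wh": [800, 800],
        "near_plane": 2.0,
        "far_plane": 6.0,
        "train_split": "train",
        "val_split": "val",
        "test_split": "test",
    })
    dm = datasets.make("blender", cfg)   # returns the datamodule registered under "blender"
    # dm.setup("fit") and dm.train_dataloader() then read the data from root_dir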
mesh_recon/datasets/blender.py
ADDED
@@ -0,0 +1,143 @@
import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank


class BlenderDatasetBase:
    def setup(self, config, split):
        self.config = config
        self.split = split
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        with open(
            os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), "r"
        ) as f:
            meta = json.load(f)

        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        self.near, self.far = self.config.near_plane, self.config.far_plane

        self.focal = (
            0.5 * w / math.tan(0.5 * meta["camera_angle_x"])
        )  # scaled focal length

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(
            self.rank
        )  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for i, frame in enumerate(meta["frames"]):
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png")
            img = Image.open(img_path)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


@datasets.register("blender")
class VideoNVSDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
|
| 11 |
+
import pytorch_lightning as pl
|
| 12 |
+
|
| 13 |
+
import datasets
|
| 14 |
+
from datasets.colmap_utils import \
|
| 15 |
+
read_cameras_binary, read_images_binary, read_points3d_binary
|
| 16 |
+
from models.ray_utils import get_ray_directions
|
| 17 |
+
from utils.misc import get_rank
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_center(pts):
|
| 21 |
+
center = pts.mean(0)
|
| 22 |
+
dis = (pts - center[None,:]).norm(p=2, dim=-1)
|
| 23 |
+
mean, std = dis.mean(), dis.std()
|
| 24 |
+
q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75)
|
| 25 |
+
valid = (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std) & (dis > mean - (q75 - q25) * 1.5) & (dis < mean + (q75 - q25) * 1.5)
|
| 26 |
+
center = pts[valid].mean(0)
|
| 27 |
+
return center
|
| 28 |
+
|
| 29 |
+
def normalize_poses(poses, pts, up_est_method, center_est_method):
|
| 30 |
+
if center_est_method == 'camera':
|
| 31 |
+
# estimation scene center as the average of all camera positions
|
| 32 |
+
center = poses[...,3].mean(0)
|
| 33 |
+
elif center_est_method == 'lookat':
|
| 34 |
+
# estimation scene center as the average of the intersection of selected pairs of camera rays
|
| 35 |
+
cams_ori = poses[...,3]
|
| 36 |
+
cams_dir = poses[:,:3,:3] @ torch.as_tensor([0.,0.,-1.])
|
| 37 |
+
cams_dir = F.normalize(cams_dir, dim=-1)
|
| 38 |
+
A = torch.stack([cams_dir, -cams_dir.roll(1,0)], dim=-1)
|
| 39 |
+
b = -cams_ori + cams_ori.roll(1,0)
|
| 40 |
+
t = torch.linalg.lstsq(A, b).solution
|
| 41 |
+
center = (torch.stack([cams_dir, cams_dir.roll(1,0)], dim=-1) * t[:,None,:] + torch.stack([cams_ori, cams_ori.roll(1,0)], dim=-1)).mean((0,2))
|
| 42 |
+
elif center_est_method == 'point':
|
| 43 |
+
# first estimation scene center as the average of all camera positions
|
| 44 |
+
# later we'll use the center of all points bounded by the cameras as the final scene center
|
| 45 |
+
center = poses[...,3].mean(0)
|
| 46 |
+
else:
|
| 47 |
+
raise NotImplementedError(f'Unknown center estimation method: {center_est_method}')
|
| 48 |
+
|
| 49 |
+
if up_est_method == 'ground':
|
| 50 |
+
# estimate up direction as the normal of the estimated ground plane
|
| 51 |
+
# use RANSAC to estimate the ground plane in the point cloud
|
| 52 |
+
import pyransac3d as pyrsc
|
| 53 |
+
ground = pyrsc.Plane()
|
| 54 |
+
plane_eq, inliers = ground.fit(pts.numpy(), thresh=0.01) # TODO: determine thresh based on scene scale
|
| 55 |
+
plane_eq = torch.as_tensor(plane_eq) # A, B, C, D in Ax + By + Cz + D = 0
|
| 56 |
+
z = F.normalize(plane_eq[:3], dim=-1) # plane normal as up direction
|
| 57 |
+
signed_distance = (torch.cat([pts, torch.ones_like(pts[...,0:1])], dim=-1) * plane_eq).sum(-1)
|
| 58 |
+
if signed_distance.mean() < 0:
|
| 59 |
+
z = -z # flip the direction if points lie under the plane
|
| 60 |
+
elif up_est_method == 'camera':
|
| 61 |
+
# estimate up direction as the average of all camera up directions
|
| 62 |
+
z = F.normalize((poses[...,3] - center).mean(0), dim=0)
|
| 63 |
+
else:
|
| 64 |
+
raise NotImplementedError(f'Unknown up estimation method: {up_est_method}')
|
| 65 |
+
|
| 66 |
+
# new axis
|
| 67 |
+
y_ = torch.as_tensor([z[1], -z[0], 0.])
|
| 68 |
+
x = F.normalize(y_.cross(z), dim=0)
|
| 69 |
+
y = z.cross(x)
|
| 70 |
+
|
| 71 |
+
if center_est_method == 'point':
|
| 72 |
+
# rotation
|
| 73 |
+
Rc = torch.stack([x, y, z], dim=1)
|
| 74 |
+
R = Rc.T
|
| 75 |
+
poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
|
| 76 |
+
inv_trans = torch.cat([torch.cat([R, torch.as_tensor([[0.,0.,0.]]).T], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
|
| 77 |
+
poses_norm = (inv_trans @ poses_homo)[:,:3]
|
| 78 |
+
pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
|
| 79 |
+
|
| 80 |
+
# translation and scaling
|
| 81 |
+
poses_min, poses_max = poses_norm[...,3].min(0)[0], poses_norm[...,3].max(0)[0]
|
| 82 |
+
pts_fg = pts[(poses_min[0] < pts[:,0]) & (pts[:,0] < poses_max[0]) & (poses_min[1] < pts[:,1]) & (pts[:,1] < poses_max[1])]
|
| 83 |
+
center = get_center(pts_fg)
|
| 84 |
+
tc = center.reshape(3, 1)
|
| 85 |
+
t = -tc
|
| 86 |
+
poses_homo = torch.cat([poses_norm, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses_norm.shape[0], -1, -1)], dim=1)
|
| 87 |
+
inv_trans = torch.cat([torch.cat([torch.eye(3), t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
|
| 88 |
+
poses_norm = (inv_trans @ poses_homo)[:,:3]
|
| 89 |
+
scale = poses_norm[...,3].norm(p=2, dim=-1).min()
|
| 90 |
+
poses_norm[...,3] /= scale
|
| 91 |
+
pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
|
| 92 |
+
pts = pts / scale
|
| 93 |
+
else:
|
| 94 |
+
# rotation and translation
|
| 95 |
+
Rc = torch.stack([x, y, z], dim=1)
|
| 96 |
+
tc = center.reshape(3, 1)
|
| 97 |
+
R, t = Rc.T, -Rc.T @ tc
|
| 98 |
+
poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
|
| 99 |
+
inv_trans = torch.cat([torch.cat([R, t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
|
| 100 |
+
poses_norm = (inv_trans @ poses_homo)[:,:3] # (N_images, 4, 4)
|
| 101 |
+
|
| 102 |
+
# scaling
|
| 103 |
+
scale = poses_norm[...,3].norm(p=2, dim=-1).min()
|
| 104 |
+
poses_norm[...,3] /= scale
|
| 105 |
+
|
| 106 |
+
# apply the transformation to the point cloud
|
| 107 |
+
pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
|
| 108 |
+
pts = pts / scale
|
| 109 |
+
|
| 110 |
+
return poses_norm, pts
|
| 111 |
+
|
| 112 |
+
def create_spheric_poses(cameras, n_steps=120):
|
| 113 |
+
center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
|
| 114 |
+
mean_d = (cameras - center[None,:]).norm(p=2, dim=-1).mean()
|
| 115 |
+
mean_h = cameras[:,2].mean()
|
| 116 |
+
r = (mean_d**2 - mean_h**2).sqrt()
|
| 117 |
+
up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device)
|
| 118 |
+
|
| 119 |
+
all_c2w = []
|
| 120 |
+
for theta in torch.linspace(0, 2 * math.pi, n_steps):
|
| 121 |
+
cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h])
|
| 122 |
+
l = F.normalize(center - cam_pos, p=2, dim=0)
|
| 123 |
+
s = F.normalize(l.cross(up), p=2, dim=0)
|
| 124 |
+
u = F.normalize(s.cross(l), p=2, dim=0)
|
| 125 |
+
c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1)
|
| 126 |
+
all_c2w.append(c2w)
|
| 127 |
+
|
| 128 |
+
all_c2w = torch.stack(all_c2w, dim=0)
|
| 129 |
+
|
| 130 |
+
return all_c2w
|
| 131 |
+
|
| 132 |
+
class ColmapDatasetBase():
|
| 133 |
+
# the data only has to be processed once
|
| 134 |
+
initialized = False
|
| 135 |
+
properties = {}
|
| 136 |
+
|
| 137 |
+
def setup(self, config, split):
|
| 138 |
+
self.config = config
|
| 139 |
+
self.split = split
|
| 140 |
+
self.rank = get_rank()
|
| 141 |
+
|
| 142 |
+
if not ColmapDatasetBase.initialized:
|
| 143 |
+
camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin'))
|
| 144 |
+
|
| 145 |
+
H = int(camdata[1].height)
|
| 146 |
+
W = int(camdata[1].width)
|
| 147 |
+
|
| 148 |
+
if 'img_wh' in self.config:
|
| 149 |
+
w, h = self.config.img_wh
|
| 150 |
+
assert round(W / w * h) == H
|
| 151 |
+
elif 'img_downscale' in self.config:
|
| 152 |
+
w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
|
| 153 |
+
else:
|
| 154 |
+
raise KeyError("Either img_wh or img_downscale should be specified.")
|
| 155 |
+
|
| 156 |
+
img_wh = (w, h)
|
| 157 |
+
factor = w / W
|
| 158 |
+
|
| 159 |
+
if camdata[1].model == 'SIMPLE_RADIAL':
|
| 160 |
+
fx = fy = camdata[1].params[0] * factor
|
| 161 |
+
cx = camdata[1].params[1] * factor
|
| 162 |
+
cy = camdata[1].params[2] * factor
|
| 163 |
+
elif camdata[1].model in ['PINHOLE', 'OPENCV']:
|
| 164 |
+
fx = camdata[1].params[0] * factor
|
| 165 |
+
fy = camdata[1].params[1] * factor
|
| 166 |
+
cx = camdata[1].params[2] * factor
|
| 167 |
+
cy = camdata[1].params[3] * factor
|
| 168 |
+
else:
|
| 169 |
+
raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!")
|
| 170 |
+
|
| 171 |
+
directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank)
|
| 172 |
+
|
| 173 |
+
imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin'))
|
| 174 |
+
|
| 175 |
+
mask_dir = os.path.join(self.config.root_dir, 'masks')
|
| 176 |
+
has_mask = os.path.exists(mask_dir) # TODO: support partial masks
|
| 177 |
+
apply_mask = has_mask and self.config.apply_mask
|
| 178 |
+
|
| 179 |
+
all_c2w, all_images, all_fg_masks = [], [], []
|
| 180 |
+
|
| 181 |
+
for i, d in enumerate(imdata.values()):
|
| 182 |
+
R = d.qvec2rotmat()
|
| 183 |
+
t = d.tvec.reshape(3, 1)
|
| 184 |
+
c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float()
|
| 185 |
+
c2w[:,1:3] *= -1. # COLMAP => OpenGL
|
| 186 |
+
all_c2w.append(c2w)
|
| 187 |
+
if self.split in ['train', 'val']:
|
| 188 |
+
img_path = os.path.join(self.config.root_dir, 'images', d.name)
|
| 189 |
+
img = Image.open(img_path)
|
| 190 |
+
img = img.resize(img_wh, Image.BICUBIC)
|
| 191 |
+
img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
|
| 192 |
+
img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu()
|
| 193 |
+
if has_mask:
|
| 194 |
+
mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])]
|
| 195 |
+
mask_paths = list(filter(os.path.exists, mask_paths))
|
| 196 |
+
assert len(mask_paths) == 1
|
| 197 |
+
mask = Image.open(mask_paths[0]).convert('L') # (H, W, 1)
|
| 198 |
+
mask = mask.resize(img_wh, Image.BICUBIC)
|
| 199 |
+
mask = TF.to_tensor(mask)[0]
|
| 200 |
+
else:
|
| 201 |
+
mask = torch.ones_like(img[...,0], device=img.device)
|
| 202 |
+
all_fg_masks.append(mask) # (h, w)
|
| 203 |
+
all_images.append(img)
|
| 204 |
+
|
| 205 |
+
all_c2w = torch.stack(all_c2w, dim=0)
|
| 206 |
+
|
| 207 |
+
pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin'))
|
| 208 |
+
pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float()
|
| 209 |
+
all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method)
|
| 210 |
+
|
| 211 |
+
ColmapDatasetBase.properties = {
|
| 212 |
+
'w': w,
|
| 213 |
+
'h': h,
|
| 214 |
+
'img_wh': img_wh,
|
| 215 |
+
'factor': factor,
|
| 216 |
+
'has_mask': has_mask,
|
| 217 |
+
'apply_mask': apply_mask,
|
| 218 |
+
'directions': directions,
|
| 219 |
+
'pts3d': pts3d,
|
| 220 |
+
'all_c2w': all_c2w,
|
| 221 |
+
'all_images': all_images,
|
| 222 |
+
'all_fg_masks': all_fg_masks
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
ColmapDatasetBase.initialized = True
|
| 226 |
+
|
| 227 |
+
for k, v in ColmapDatasetBase.properties.items():
|
| 228 |
+
setattr(self, k, v)
|
| 229 |
+
|
| 230 |
+
if self.split == 'test':
|
| 231 |
+
self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
|
| 232 |
+
self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
|
| 233 |
+
self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
|
| 234 |
+
else:
|
| 235 |
+
self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float()
|
| 236 |
+
|
| 237 |
+
"""
|
| 238 |
+
# for debug use
|
| 239 |
+
from models.ray_utils import get_rays
|
| 240 |
+
rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True)
|
| 241 |
+
pts_out = []
|
| 242 |
+
pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 3).tolist()]))
|
| 243 |
+
|
| 244 |
+
t_vals = torch.linspace(0, 1, 8)
|
| 245 |
+
z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals
|
| 246 |
+
|
| 247 |
+
ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :])
|
| 248 |
+
pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()]))
|
| 249 |
+
|
| 250 |
+
ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :])
|
| 251 |
+
pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
|
| 252 |
+
|
| 253 |
+
ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :])
|
| 254 |
+
pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
|
| 255 |
+
|
| 256 |
+
ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :])
|
| 257 |
+
pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
|
| 258 |
+
|
| 259 |
+
open('cameras.txt', 'w').write('\n'.join(pts_out))
|
| 260 |
+
open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()]))
|
| 261 |
+
|
| 262 |
+
exit(1)
|
| 263 |
+
"""
|
| 264 |
+
|
| 265 |
+
self.all_c2w = self.all_c2w.float().to(self.rank)
|
| 266 |
+
if self.config.load_data_on_gpu:
|
| 267 |
+
self.all_images = self.all_images.to(self.rank)
|
| 268 |
+
self.all_fg_masks = self.all_fg_masks.to(self.rank)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class ColmapDataset(Dataset, ColmapDatasetBase):
|
| 272 |
+
def __init__(self, config, split):
|
| 273 |
+
self.setup(config, split)
|
| 274 |
+
|
| 275 |
+
def __len__(self):
|
| 276 |
+
return len(self.all_images)
|
| 277 |
+
|
| 278 |
+
def __getitem__(self, index):
|
| 279 |
+
return {
|
| 280 |
+
'index': index
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class ColmapIterableDataset(IterableDataset, ColmapDatasetBase):
|
| 285 |
+
def __init__(self, config, split):
|
| 286 |
+
self.setup(config, split)
|
| 287 |
+
|
| 288 |
+
def __iter__(self):
|
| 289 |
+
while True:
|
| 290 |
+
yield {}
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
@datasets.register('colmap')
|
| 294 |
+
class ColmapDataModule(pl.LightningDataModule):
|
| 295 |
+
def __init__(self, config):
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.config = config
|
| 298 |
+
|
| 299 |
+
def setup(self, stage=None):
|
| 300 |
+
if stage in [None, 'fit']:
|
| 301 |
+
self.train_dataset = ColmapIterableDataset(self.config, 'train')
|
| 302 |
+
if stage in [None, 'fit', 'validate']:
|
| 303 |
+
self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train'))
|
| 304 |
+
if stage in [None, 'test']:
|
| 305 |
+
self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test'))
|
| 306 |
+
if stage in [None, 'predict']:
|
| 307 |
+
self.predict_dataset = ColmapDataset(self.config, 'train')
|
| 308 |
+
|
| 309 |
+
def prepare_data(self):
|
| 310 |
+
pass
|
| 311 |
+
|
| 312 |
+
def general_loader(self, dataset, batch_size):
|
| 313 |
+
sampler = None
|
| 314 |
+
return DataLoader(
|
| 315 |
+
dataset,
|
| 316 |
+
num_workers=os.cpu_count(),
|
| 317 |
+
batch_size=batch_size,
|
| 318 |
+
pin_memory=True,
|
| 319 |
+
sampler=sampler
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
def train_dataloader(self):
|
| 323 |
+
return self.general_loader(self.train_dataset, batch_size=1)
|
| 324 |
+
|
| 325 |
+
def val_dataloader(self):
|
| 326 |
+
return self.general_loader(self.val_dataset, batch_size=1)
|
| 327 |
+
|
| 328 |
+
def test_dataloader(self):
|
| 329 |
+
return self.general_loader(self.test_dataset, batch_size=1)
|
| 330 |
+
|
| 331 |
+
def predict_dataloader(self):
|
| 332 |
+
return self.general_loader(self.predict_dataset, batch_size=1)
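Note on usage: below is a minimal sketch of driving this data module by hand, assuming an OmegaConf-style config with the fields the class reads (root_dir, img_downscale, apply_mask, load_data_on_gpu, up_est_method, center_est_method, n_test_traj_steps). Every value is a placeholder, not a setting shipped with this commit.

from omegaconf import OmegaConf

# Hypothetical config; all values are placeholders.
config = OmegaConf.create({
    'root_dir': '/path/to/colmap/scene',  # must contain sparse/0/{cameras,images,points3D}.bin and images/
    'img_downscale': 4,
    'apply_mask': False,
    'load_data_on_gpu': False,
    'up_est_method': 'camera',       # or 'ground' (needs pyransac3d)
    'center_est_method': 'lookat',   # or 'camera' / 'point'
    'n_test_traj_steps': 120,
})

dm = ColmapDataModule(config)
dm.setup('fit')
train_loader = dm.train_dataloader()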
|
mesh_recon/datasets/colmap_utils.py
ADDED
|
@@ -0,0 +1,295 @@
|
| 1 |
+
# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Redistribution and use in source and binary forms, with or without
|
| 5 |
+
# modification, are permitted provided that the following conditions are met:
|
| 6 |
+
#
|
| 7 |
+
# * Redistributions of source code must retain the above copyright
|
| 8 |
+
# notice, this list of conditions and the following disclaimer.
|
| 9 |
+
#
|
| 10 |
+
# * Redistributions in binary form must reproduce the above copyright
|
| 11 |
+
# notice, this list of conditions and the following disclaimer in the
|
| 12 |
+
# documentation and/or other materials provided with the distribution.
|
| 13 |
+
#
|
| 14 |
+
# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
|
| 15 |
+
# its contributors may be used to endorse or promote products derived
|
| 16 |
+
# from this software without specific prior written permission.
|
| 17 |
+
#
|
| 18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 19 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 20 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
| 21 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
|
| 22 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
| 23 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
| 24 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
| 25 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
| 26 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
| 27 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
| 28 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
#
|
| 30 |
+
# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
|
| 31 |
+
|
| 32 |
+
import os
|
| 33 |
+
import collections
|
| 34 |
+
import numpy as np
|
| 35 |
+
import struct
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
CameraModel = collections.namedtuple(
|
| 39 |
+
"CameraModel", ["model_id", "model_name", "num_params"])
|
| 40 |
+
Camera = collections.namedtuple(
|
| 41 |
+
"Camera", ["id", "model", "width", "height", "params"])
|
| 42 |
+
BaseImage = collections.namedtuple(
|
| 43 |
+
"Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
|
| 44 |
+
Point3D = collections.namedtuple(
|
| 45 |
+
"Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
|
| 46 |
+
|
| 47 |
+
class Image(BaseImage):
|
| 48 |
+
def qvec2rotmat(self):
|
| 49 |
+
return qvec2rotmat(self.qvec)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
CAMERA_MODELS = {
|
| 53 |
+
CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
|
| 54 |
+
CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
|
| 55 |
+
CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
|
| 56 |
+
CameraModel(model_id=3, model_name="RADIAL", num_params=5),
|
| 57 |
+
CameraModel(model_id=4, model_name="OPENCV", num_params=8),
|
| 58 |
+
CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
|
| 59 |
+
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
|
| 60 |
+
CameraModel(model_id=7, model_name="FOV", num_params=5),
|
| 61 |
+
CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
|
| 62 |
+
CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
|
| 63 |
+
CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
|
| 64 |
+
}
|
| 65 |
+
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
|
| 66 |
+
for camera_model in CAMERA_MODELS])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
|
| 70 |
+
"""Read and unpack the next bytes from a binary file.
|
| 71 |
+
:param fid:
|
| 72 |
+
:param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
|
| 73 |
+
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
|
| 74 |
+
:param endian_character: Any of {@, =, <, >, !}
|
| 75 |
+
:return: Tuple of read and unpacked values.
|
| 76 |
+
"""
|
| 77 |
+
data = fid.read(num_bytes)
|
| 78 |
+
return struct.unpack(endian_character + format_char_sequence, data)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def read_cameras_text(path):
|
| 82 |
+
"""
|
| 83 |
+
see: src/base/reconstruction.cc
|
| 84 |
+
void Reconstruction::WriteCamerasText(const std::string& path)
|
| 85 |
+
void Reconstruction::ReadCamerasText(const std::string& path)
|
| 86 |
+
"""
|
| 87 |
+
cameras = {}
|
| 88 |
+
with open(path, "r") as fid:
|
| 89 |
+
while True:
|
| 90 |
+
line = fid.readline()
|
| 91 |
+
if not line:
|
| 92 |
+
break
|
| 93 |
+
line = line.strip()
|
| 94 |
+
if len(line) > 0 and line[0] != "#":
|
| 95 |
+
elems = line.split()
|
| 96 |
+
camera_id = int(elems[0])
|
| 97 |
+
model = elems[1]
|
| 98 |
+
width = int(elems[2])
|
| 99 |
+
height = int(elems[3])
|
| 100 |
+
params = np.array(tuple(map(float, elems[4:])))
|
| 101 |
+
cameras[camera_id] = Camera(id=camera_id, model=model,
|
| 102 |
+
width=width, height=height,
|
| 103 |
+
params=params)
|
| 104 |
+
return cameras
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def read_cameras_binary(path_to_model_file):
|
| 108 |
+
"""
|
| 109 |
+
see: src/base/reconstruction.cc
|
| 110 |
+
void Reconstruction::WriteCamerasBinary(const std::string& path)
|
| 111 |
+
void Reconstruction::ReadCamerasBinary(const std::string& path)
|
| 112 |
+
"""
|
| 113 |
+
cameras = {}
|
| 114 |
+
with open(path_to_model_file, "rb") as fid:
|
| 115 |
+
num_cameras = read_next_bytes(fid, 8, "Q")[0]
|
| 116 |
+
for camera_line_index in range(num_cameras):
|
| 117 |
+
camera_properties = read_next_bytes(
|
| 118 |
+
fid, num_bytes=24, format_char_sequence="iiQQ")
|
| 119 |
+
camera_id = camera_properties[0]
|
| 120 |
+
model_id = camera_properties[1]
|
| 121 |
+
model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
|
| 122 |
+
width = camera_properties[2]
|
| 123 |
+
height = camera_properties[3]
|
| 124 |
+
num_params = CAMERA_MODEL_IDS[model_id].num_params
|
| 125 |
+
params = read_next_bytes(fid, num_bytes=8*num_params,
|
| 126 |
+
format_char_sequence="d"*num_params)
|
| 127 |
+
cameras[camera_id] = Camera(id=camera_id,
|
| 128 |
+
model=model_name,
|
| 129 |
+
width=width,
|
| 130 |
+
height=height,
|
| 131 |
+
params=np.array(params))
|
| 132 |
+
assert len(cameras) == num_cameras
|
| 133 |
+
return cameras
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def read_images_text(path):
|
| 137 |
+
"""
|
| 138 |
+
see: src/base/reconstruction.cc
|
| 139 |
+
void Reconstruction::ReadImagesText(const std::string& path)
|
| 140 |
+
void Reconstruction::WriteImagesText(const std::string& path)
|
| 141 |
+
"""
|
| 142 |
+
images = {}
|
| 143 |
+
with open(path, "r") as fid:
|
| 144 |
+
while True:
|
| 145 |
+
line = fid.readline()
|
| 146 |
+
if not line:
|
| 147 |
+
break
|
| 148 |
+
line = line.strip()
|
| 149 |
+
if len(line) > 0 and line[0] != "#":
|
| 150 |
+
elems = line.split()
|
| 151 |
+
image_id = int(elems[0])
|
| 152 |
+
qvec = np.array(tuple(map(float, elems[1:5])))
|
| 153 |
+
tvec = np.array(tuple(map(float, elems[5:8])))
|
| 154 |
+
camera_id = int(elems[8])
|
| 155 |
+
image_name = elems[9]
|
| 156 |
+
elems = fid.readline().split()
|
| 157 |
+
xys = np.column_stack([tuple(map(float, elems[0::3])),
|
| 158 |
+
tuple(map(float, elems[1::3]))])
|
| 159 |
+
point3D_ids = np.array(tuple(map(int, elems[2::3])))
|
| 160 |
+
images[image_id] = Image(
|
| 161 |
+
id=image_id, qvec=qvec, tvec=tvec,
|
| 162 |
+
camera_id=camera_id, name=image_name,
|
| 163 |
+
xys=xys, point3D_ids=point3D_ids)
|
| 164 |
+
return images
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def read_images_binary(path_to_model_file):
|
| 168 |
+
"""
|
| 169 |
+
see: src/base/reconstruction.cc
|
| 170 |
+
void Reconstruction::ReadImagesBinary(const std::string& path)
|
| 171 |
+
void Reconstruction::WriteImagesBinary(const std::string& path)
|
| 172 |
+
"""
|
| 173 |
+
images = {}
|
| 174 |
+
with open(path_to_model_file, "rb") as fid:
|
| 175 |
+
num_reg_images = read_next_bytes(fid, 8, "Q")[0]
|
| 176 |
+
for image_index in range(num_reg_images):
|
| 177 |
+
binary_image_properties = read_next_bytes(
|
| 178 |
+
fid, num_bytes=64, format_char_sequence="idddddddi")
|
| 179 |
+
image_id = binary_image_properties[0]
|
| 180 |
+
qvec = np.array(binary_image_properties[1:5])
|
| 181 |
+
tvec = np.array(binary_image_properties[5:8])
|
| 182 |
+
camera_id = binary_image_properties[8]
|
| 183 |
+
image_name = ""
|
| 184 |
+
current_char = read_next_bytes(fid, 1, "c")[0]
|
| 185 |
+
while current_char != b"\x00": # look for the ASCII 0 entry
|
| 186 |
+
image_name += current_char.decode("utf-8")
|
| 187 |
+
current_char = read_next_bytes(fid, 1, "c")[0]
|
| 188 |
+
num_points2D = read_next_bytes(fid, num_bytes=8,
|
| 189 |
+
format_char_sequence="Q")[0]
|
| 190 |
+
x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
|
| 191 |
+
format_char_sequence="ddq"*num_points2D)
|
| 192 |
+
xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
|
| 193 |
+
tuple(map(float, x_y_id_s[1::3]))])
|
| 194 |
+
point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
|
| 195 |
+
images[image_id] = Image(
|
| 196 |
+
id=image_id, qvec=qvec, tvec=tvec,
|
| 197 |
+
camera_id=camera_id, name=image_name,
|
| 198 |
+
xys=xys, point3D_ids=point3D_ids)
|
| 199 |
+
return images
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def read_points3D_text(path):
|
| 203 |
+
"""
|
| 204 |
+
see: src/base/reconstruction.cc
|
| 205 |
+
void Reconstruction::ReadPoints3DText(const std::string& path)
|
| 206 |
+
void Reconstruction::WritePoints3DText(const std::string& path)
|
| 207 |
+
"""
|
| 208 |
+
points3D = {}
|
| 209 |
+
with open(path, "r") as fid:
|
| 210 |
+
while True:
|
| 211 |
+
line = fid.readline()
|
| 212 |
+
if not line:
|
| 213 |
+
break
|
| 214 |
+
line = line.strip()
|
| 215 |
+
if len(line) > 0 and line[0] != "#":
|
| 216 |
+
elems = line.split()
|
| 217 |
+
point3D_id = int(elems[0])
|
| 218 |
+
xyz = np.array(tuple(map(float, elems[1:4])))
|
| 219 |
+
rgb = np.array(tuple(map(int, elems[4:7])))
|
| 220 |
+
error = float(elems[7])
|
| 221 |
+
image_ids = np.array(tuple(map(int, elems[8::2])))
|
| 222 |
+
point2D_idxs = np.array(tuple(map(int, elems[9::2])))
|
| 223 |
+
points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
|
| 224 |
+
error=error, image_ids=image_ids,
|
| 225 |
+
point2D_idxs=point2D_idxs)
|
| 226 |
+
return points3D
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def read_points3d_binary(path_to_model_file):
|
| 230 |
+
"""
|
| 231 |
+
see: src/base/reconstruction.cc
|
| 232 |
+
void Reconstruction::ReadPoints3DBinary(const std::string& path)
|
| 233 |
+
void Reconstruction::WritePoints3DBinary(const std::string& path)
|
| 234 |
+
"""
|
| 235 |
+
points3D = {}
|
| 236 |
+
with open(path_to_model_file, "rb") as fid:
|
| 237 |
+
num_points = read_next_bytes(fid, 8, "Q")[0]
|
| 238 |
+
for point_line_index in range(num_points):
|
| 239 |
+
binary_point_line_properties = read_next_bytes(
|
| 240 |
+
fid, num_bytes=43, format_char_sequence="QdddBBBd")
|
| 241 |
+
point3D_id = binary_point_line_properties[0]
|
| 242 |
+
xyz = np.array(binary_point_line_properties[1:4])
|
| 243 |
+
rgb = np.array(binary_point_line_properties[4:7])
|
| 244 |
+
error = np.array(binary_point_line_properties[7])
|
| 245 |
+
track_length = read_next_bytes(
|
| 246 |
+
fid, num_bytes=8, format_char_sequence="Q")[0]
|
| 247 |
+
track_elems = read_next_bytes(
|
| 248 |
+
fid, num_bytes=8*track_length,
|
| 249 |
+
format_char_sequence="ii"*track_length)
|
| 250 |
+
image_ids = np.array(tuple(map(int, track_elems[0::2])))
|
| 251 |
+
point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
|
| 252 |
+
points3D[point3D_id] = Point3D(
|
| 253 |
+
id=point3D_id, xyz=xyz, rgb=rgb,
|
| 254 |
+
error=error, image_ids=image_ids,
|
| 255 |
+
point2D_idxs=point2D_idxs)
|
| 256 |
+
return points3D
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def read_model(path, ext):
|
| 260 |
+
if ext == ".txt":
|
| 261 |
+
cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
|
| 262 |
+
images = read_images_text(os.path.join(path, "images" + ext))
|
| 263 |
+
points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
|
| 264 |
+
else:
|
| 265 |
+
cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
|
| 266 |
+
images = read_images_binary(os.path.join(path, "images" + ext))
|
| 267 |
+
points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
|
| 268 |
+
return cameras, images, points3D
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def qvec2rotmat(qvec):
|
| 272 |
+
return np.array([
|
| 273 |
+
[1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
|
| 274 |
+
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
|
| 275 |
+
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
|
| 276 |
+
[2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
|
| 277 |
+
1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
|
| 278 |
+
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
|
| 279 |
+
[2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
|
| 280 |
+
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
|
| 281 |
+
1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def rotmat2qvec(R):
|
| 285 |
+
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
|
| 286 |
+
K = np.array([
|
| 287 |
+
[Rxx - Ryy - Rzz, 0, 0, 0],
|
| 288 |
+
[Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
|
| 289 |
+
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
|
| 290 |
+
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
|
| 291 |
+
eigvals, eigvecs = np.linalg.eigh(K)
|
| 292 |
+
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
|
| 293 |
+
if qvec[0] < 0:
|
| 294 |
+
qvec *= -1
|
| 295 |
+
return qvec
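For orientation, a small sketch of reading a reconstruction with these helpers and turning one registered image's pose into a camera-to-world matrix; the model path is a placeholder.

import numpy as np

cameras, images, points3D = read_model('/path/to/sparse/0', ext='.bin')  # placeholder path
img = next(iter(images.values()))
R = qvec2rotmat(img.qvec)                      # world-to-camera rotation
t = img.tvec.reshape(3, 1)                     # world-to-camera translation
c2w = np.concatenate([R.T, -R.T @ t], axis=1)  # 3x4 camera-to-world, as built in colmap.py above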
|
mesh_recon/datasets/dtu.py
ADDED
|
@@ -0,0 +1,201 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| 11 |
+
import torchvision.transforms.functional as TF
|
| 12 |
+
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
|
| 15 |
+
import datasets
|
| 16 |
+
from models.ray_utils import get_ray_directions
|
| 17 |
+
from utils.misc import get_rank
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_K_Rt_from_P(P=None):
|
| 21 |
+
out = cv2.decomposeProjectionMatrix(P)
|
| 22 |
+
K = out[0]
|
| 23 |
+
R = out[1]
|
| 24 |
+
t = out[2]
|
| 25 |
+
|
| 26 |
+
K = K / K[2, 2]
|
| 27 |
+
intrinsics = np.eye(4)
|
| 28 |
+
intrinsics[:3, :3] = K
|
| 29 |
+
|
| 30 |
+
pose = np.eye(4, dtype=np.float32)
|
| 31 |
+
pose[:3, :3] = R.transpose()
|
| 32 |
+
pose[:3, 3] = (t[:3] / t[3])[:, 0]
|
| 33 |
+
|
| 34 |
+
return intrinsics, pose
|
| 35 |
+
|
| 36 |
+
def create_spheric_poses(cameras, n_steps=120):
|
| 37 |
+
center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
|
| 38 |
+
cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2)
|
| 39 |
+
eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors
|
| 40 |
+
rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1)
|
| 41 |
+
up = rot_axis
|
| 42 |
+
rot_dir = torch.cross(rot_axis, cam_center)
|
| 43 |
+
max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max()
|
| 44 |
+
|
| 45 |
+
all_c2w = []
|
| 46 |
+
for theta in torch.linspace(-max_angle, max_angle, n_steps):
|
| 47 |
+
cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta)
|
| 48 |
+
l = F.normalize(center - cam_pos, p=2, dim=0)
|
| 49 |
+
s = F.normalize(l.cross(up), p=2, dim=0)
|
| 50 |
+
u = F.normalize(s.cross(l), p=2, dim=0)
|
| 51 |
+
c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1)
|
| 52 |
+
all_c2w.append(c2w)
|
| 53 |
+
|
| 54 |
+
all_c2w = torch.stack(all_c2w, dim=0)
|
| 55 |
+
|
| 56 |
+
return all_c2w
|
| 57 |
+
|
| 58 |
+
class DTUDatasetBase():
|
| 59 |
+
def setup(self, config, split):
|
| 60 |
+
self.config = config
|
| 61 |
+
self.split = split
|
| 62 |
+
self.rank = get_rank()
|
| 63 |
+
|
| 64 |
+
cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file))
|
| 65 |
+
|
| 66 |
+
img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png'))
|
| 67 |
+
H, W = img_sample.shape[0], img_sample.shape[1]
|
| 68 |
+
|
| 69 |
+
if 'img_wh' in self.config:
|
| 70 |
+
w, h = self.config.img_wh
|
| 71 |
+
assert round(W / w * h) == H
|
| 72 |
+
elif 'img_downscale' in self.config:
|
| 73 |
+
w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
|
| 74 |
+
else:
|
| 75 |
+
raise KeyError("Either img_wh or img_downscale should be specified.")
|
| 76 |
+
|
| 77 |
+
self.w, self.h = w, h
|
| 78 |
+
self.img_wh = (w, h)
|
| 79 |
+
self.factor = w / W
|
| 80 |
+
|
| 81 |
+
mask_dir = os.path.join(self.config.root_dir, 'mask')
|
| 82 |
+
self.has_mask = True
|
| 83 |
+
self.apply_mask = self.config.apply_mask
|
| 84 |
+
|
| 85 |
+
self.directions = []
|
| 86 |
+
self.all_c2w, self.all_images, self.all_fg_masks = [], [], []
|
| 87 |
+
|
| 88 |
+
n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1
|
| 89 |
+
|
| 90 |
+
for i in range(n_images):
|
| 91 |
+
world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}']
|
| 92 |
+
P = (world_mat @ scale_mat)[:3,:4]
|
| 93 |
+
K, c2w = load_K_Rt_from_P(P)
|
| 94 |
+
fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor
|
| 95 |
+
directions = get_ray_directions(w, h, fx, fy, cx, cy)
|
| 96 |
+
self.directions.append(directions)
|
| 97 |
+
|
| 98 |
+
c2w = torch.from_numpy(c2w).float()
|
| 99 |
+
|
| 100 |
+
# blender follows opengl camera coordinates (right up back)
|
| 101 |
+
# NeuS DTU data coordinate system (right down front) is different from blender
|
| 102 |
+
# https://github.com/Totoro97/NeuS/issues/9
|
| 103 |
+
# for c2w, flip the sign of input camera coordinate yz
|
| 104 |
+
c2w_ = c2w.clone()
|
| 105 |
+
c2w_[:3,1:3] *= -1. # flip input sign
|
| 106 |
+
self.all_c2w.append(c2w_[:3,:4])
|
| 107 |
+
|
| 108 |
+
if self.split in ['train', 'val']:
|
| 109 |
+
img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png')
|
| 110 |
+
img = Image.open(img_path)
|
| 111 |
+
img = img.resize(self.img_wh, Image.BICUBIC)
|
| 112 |
+
img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
|
| 113 |
+
|
| 114 |
+
mask_path = os.path.join(mask_dir, f'{i:03d}.png')
|
| 115 |
+
mask = Image.open(mask_path).convert('L') # (H, W, 1)
|
| 116 |
+
mask = mask.resize(self.img_wh, Image.BICUBIC)
|
| 117 |
+
mask = TF.to_tensor(mask)[0]
|
| 118 |
+
|
| 119 |
+
self.all_fg_masks.append(mask) # (h, w)
|
| 120 |
+
self.all_images.append(img)
|
| 121 |
+
|
| 122 |
+
self.all_c2w = torch.stack(self.all_c2w, dim=0)
|
| 123 |
+
|
| 124 |
+
if self.split == 'test':
|
| 125 |
+
self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
|
| 126 |
+
self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
|
| 127 |
+
self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
|
| 128 |
+
self.directions = self.directions[0]
|
| 129 |
+
else:
|
| 130 |
+
self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0)
|
| 131 |
+
self.directions = torch.stack(self.directions, dim=0)
|
| 132 |
+
|
| 133 |
+
self.directions = self.directions.float().to(self.rank)
|
| 134 |
+
self.all_c2w, self.all_images, self.all_fg_masks = \
|
| 135 |
+
self.all_c2w.float().to(self.rank), \
|
| 136 |
+
self.all_images.float().to(self.rank), \
|
| 137 |
+
self.all_fg_masks.float().to(self.rank)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class DTUDataset(Dataset, DTUDatasetBase):
|
| 141 |
+
def __init__(self, config, split):
|
| 142 |
+
self.setup(config, split)
|
| 143 |
+
|
| 144 |
+
def __len__(self):
|
| 145 |
+
return len(self.all_images)
|
| 146 |
+
|
| 147 |
+
def __getitem__(self, index):
|
| 148 |
+
return {
|
| 149 |
+
'index': index
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class DTUIterableDataset(IterableDataset, DTUDatasetBase):
|
| 154 |
+
def __init__(self, config, split):
|
| 155 |
+
self.setup(config, split)
|
| 156 |
+
|
| 157 |
+
def __iter__(self):
|
| 158 |
+
while True:
|
| 159 |
+
yield {}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@datasets.register('dtu')
|
| 163 |
+
class DTUDataModule(pl.LightningDataModule):
|
| 164 |
+
def __init__(self, config):
|
| 165 |
+
super().__init__()
|
| 166 |
+
self.config = config
|
| 167 |
+
|
| 168 |
+
def setup(self, stage=None):
|
| 169 |
+
if stage in [None, 'fit']:
|
| 170 |
+
self.train_dataset = DTUIterableDataset(self.config, 'train')
|
| 171 |
+
if stage in [None, 'fit', 'validate']:
|
| 172 |
+
self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train'))
|
| 173 |
+
if stage in [None, 'test']:
|
| 174 |
+
self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test'))
|
| 175 |
+
if stage in [None, 'predict']:
|
| 176 |
+
self.predict_dataset = DTUDataset(self.config, 'train')
|
| 177 |
+
|
| 178 |
+
def prepare_data(self):
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
def general_loader(self, dataset, batch_size):
|
| 182 |
+
sampler = None
|
| 183 |
+
return DataLoader(
|
| 184 |
+
dataset,
|
| 185 |
+
num_workers=os.cpu_count(),
|
| 186 |
+
batch_size=batch_size,
|
| 187 |
+
pin_memory=True,
|
| 188 |
+
sampler=sampler
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def train_dataloader(self):
|
| 192 |
+
return self.general_loader(self.train_dataset, batch_size=1)
|
| 193 |
+
|
| 194 |
+
def val_dataloader(self):
|
| 195 |
+
return self.general_loader(self.val_dataset, batch_size=1)
|
| 196 |
+
|
| 197 |
+
def test_dataloader(self):
|
| 198 |
+
return self.general_loader(self.test_dataset, batch_size=1)
|
| 199 |
+
|
| 200 |
+
def predict_dataloader(self):
|
| 201 |
+
return self.general_loader(self.predict_dataset, batch_size=1)
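As a sanity check of load_K_Rt_from_P above, here is a small synthetic example (purely illustrative values): build a 3x4 projection matrix from known intrinsics and a world-to-camera pose, then decompose it back.

import numpy as np
import cv2

K = np.array([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
R = np.eye(3)                          # world-to-camera rotation
t = np.array([[0.], [0.], [2.]])       # world-to-camera translation
P = K @ np.concatenate([R, t], axis=1)

intrinsics, pose = load_K_Rt_from_P(P)
# intrinsics[:3, :3] recovers K and pose is the camera-to-world matrix,
# so pose[:3, 3] should be close to -R.T @ t[:, 0] = [0, 0, -2].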
|
mesh_recon/datasets/fixed_poses/000_back_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
-1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
|
| 2 |
+
0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07
|
| 3 |
+
0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
|
mesh_recon/datasets/fixed_poses/000_back_left_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
-7.071069478988647461e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
|
| 2 |
+
0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
|
| 3 |
+
-7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
|
mesh_recon/datasets/fixed_poses/000_back_right_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
-7.071069478988647461e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
|
| 2 |
+
0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
|
| 3 |
+
7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
|
mesh_recon/datasets/fixed_poses/000_front_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
|
| 2 |
+
0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07
|
| 3 |
+
0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
|
mesh_recon/datasets/fixed_poses/000_front_left_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
7.071067690849304199e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
|
| 2 |
+
0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
|
| 3 |
+
-7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
|
mesh_recon/datasets/fixed_poses/000_front_right_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
7.071067690849304199e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
|
| 2 |
+
0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
|
| 3 |
+
7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
|
mesh_recon/datasets/fixed_poses/000_left_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
-2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16
|
| 2 |
+
0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
|
| 3 |
+
-1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
|
mesh_recon/datasets/fixed_poses/000_right_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
-2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16
|
| 2 |
+
0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
|
| 3 |
+
1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
|
mesh_recon/datasets/fixed_poses/000_top_RT.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
|
| 2 |
+
0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
|
| 3 |
+
0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 -1.299999952316284180e+00
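Each of these text files holds a single 3x4 world-to-camera matrix for one fixed viewpoint, stored in the OpenGL/Blender convention (that is how ortho.py below interprets them before calling RT_opengl2opencv). A minimal sketch of loading one and inverting it to camera-to-world; the relative path assumes the mesh_recon working directory.

import numpy as np

RT = np.loadtxt('datasets/fixed_poses/000_front_RT.txt')           # (3, 4) world-to-camera
RT_h = np.concatenate([RT, np.array([[0., 0., 0., 1.]])], axis=0)  # homogeneous 4x4
c2w = np.linalg.inv(RT_h)[:3, :]                                   # (3, 4) camera-to-world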
|
mesh_recon/datasets/ortho.py
ADDED
|
@@ -0,0 +1,287 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| 11 |
+
import torchvision.transforms.functional as TF
|
| 12 |
+
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
|
| 15 |
+
import datasets
|
| 16 |
+
from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions
|
| 17 |
+
from utils.misc import get_rank
|
| 18 |
+
|
| 19 |
+
from glob import glob
|
| 20 |
+
import PIL.Image
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def camNormal2worldNormal(rot_c2w, camNormal):
|
| 24 |
+
H,W,_ = camNormal.shape
|
| 25 |
+
normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
|
| 26 |
+
|
| 27 |
+
return normal_img
|
| 28 |
+
|
| 29 |
+
def worldNormal2camNormal(rot_w2c, worldNormal):
|
| 30 |
+
H,W,_ = worldNormal.shape
|
| 31 |
+
normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
|
| 32 |
+
|
| 33 |
+
return normal_img
|
| 34 |
+
|
| 35 |
+
def trans_normal(normal, RT_w2c, RT_w2c_target):
|
| 36 |
+
|
| 37 |
+
normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
|
| 38 |
+
normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
|
| 39 |
+
|
| 40 |
+
return normal_target_cam
|
| 41 |
+
|
| 42 |
+
def img2normal(img):
|
| 43 |
+
return (img/255.)*2-1
|
| 44 |
+
|
| 45 |
+
def normal2img(normal):
|
| 46 |
+
return np.uint8((normal*0.5+0.5)*255)
|
| 47 |
+
|
| 48 |
+
def norm_normalize(normal, dim=-1):
|
| 49 |
+
|
| 50 |
+
normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
|
| 51 |
+
|
| 52 |
+
return normal
|
| 53 |
+
|
| 54 |
+
def RT_opengl2opencv(RT):
|
| 55 |
+
# Build the coordinate transform matrix from world to computer vision camera
|
| 56 |
+
# R_world2cv = R_bcam2cv@R_world2bcam
|
| 57 |
+
# T_world2cv = R_bcam2cv@T_world2bcam
|
| 58 |
+
|
| 59 |
+
R = RT[:3, :3]
|
| 60 |
+
t = RT[:3, 3]
|
| 61 |
+
|
| 62 |
+
R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
|
| 63 |
+
|
| 64 |
+
R_world2cv = R_bcam2cv @ R
|
| 65 |
+
t_world2cv = R_bcam2cv @ t
|
| 66 |
+
|
| 67 |
+
RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
|
| 68 |
+
|
| 69 |
+
return RT
|
| 70 |
+
|
| 71 |
+
def normal_opengl2opencv(normal):
|
| 72 |
+
H,W,C = np.shape(normal)
|
| 73 |
+
# normal_img = np.reshape(normal, (H*W,C))
|
| 74 |
+
R_bcam2cv = np.array([1, -1, -1], np.float32)
|
| 75 |
+
normal_cv = normal * R_bcam2cv[None, None, :]
|
| 76 |
+
|
| 77 |
+
print(np.shape(normal_cv))
|
| 78 |
+
|
| 79 |
+
return normal_cv
|
| 80 |
+
|
| 81 |
+
def inv_RT(RT):
|
| 82 |
+
RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0)
|
| 83 |
+
RT_inv = np.linalg.inv(RT_h)
|
| 84 |
+
|
| 85 |
+
return RT_inv[:3, :]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None,
|
| 89 |
+
normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None):
|
| 90 |
+
|
| 91 |
+
all_images = []
|
| 92 |
+
all_normals = []
|
| 93 |
+
all_normals_world = []
|
| 94 |
+
all_masks = []
|
| 95 |
+
all_color_masks = []
|
| 96 |
+
all_poses = []
|
| 97 |
+
all_w2cs = []
|
| 98 |
+
directions = []
|
| 99 |
+
ray_origins = []
|
| 100 |
+
|
| 101 |
+
RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0]) # world2cam matrix
|
| 102 |
+
RT_front_cv = RT_opengl2opencv(RT_front) # convert normal from opengl to opencv
|
| 103 |
+
for idx, view in enumerate(view_types):
|
| 104 |
+
print(os.path.join(root_dir,test_object))
|
| 105 |
+
normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view))
|
| 106 |
+
# Load key frame
|
| 107 |
+
if load_color: # use bgr
|
| 108 |
+
image = np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3]
|
| 109 |
+
|
| 110 |
+
normal = np.array(PIL.Image.open(normal_filepath).resize(imSize))
|
| 111 |
+
mask = normal[:, :, 3]
|
| 112 |
+
normal = normal[:, :, :3]
|
| 113 |
+
|
| 114 |
+
color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3]
|
| 115 |
+
invalid_color_mask = color_mask < 255*0.5
|
| 116 |
+
threshold = np.ones_like(image[:, :, 0]) * 250
|
| 117 |
+
invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold)
|
| 118 |
+
invalid_color_mask_final = invalid_color_mask & invalid_white_mask
|
| 119 |
+
color_mask = (1 - invalid_color_mask_final) > 0
|
| 120 |
+
|
| 121 |
+
# if erode_mask:
|
| 122 |
+
# kernel = np.ones((3, 3), np.uint8)
|
| 123 |
+
# mask = cv2.erode(mask, kernel, iterations=1)
|
| 124 |
+
|
| 125 |
+
RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view))) # world2cam matrix
|
| 126 |
+
|
| 127 |
+
normal = img2normal(normal)
|
| 128 |
+
|
| 129 |
+
normal[mask==0] = [0,0,0]
|
| 130 |
+
mask = mask > (0.5 * 255)
|
| 131 |
+
if load_color:
|
| 132 |
+
all_images.append(image)
|
| 133 |
+
|
| 134 |
+
all_masks.append(mask)
|
| 135 |
+
all_color_masks.append(color_mask)
|
| 136 |
+
RT_cv = RT_opengl2opencv(RT) # convert normal from opengl to opencv
|
| 137 |
+
all_poses.append(inv_RT(RT_cv)) # cam2world
|
| 138 |
+
all_w2cs.append(RT_cv)
|
| 139 |
+
|
| 140 |
+
# convert the normal map from the OpenGL to the OpenCV camera convention
|
| 141 |
+
normal_cam_cv = normal_opengl2opencv(normal)
|
| 142 |
+
|
| 143 |
+
if normal_system == 'front':
|
| 144 |
+
print("the loaded normals are defined in the system of front view")
|
| 145 |
+
normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv)
|
| 146 |
+
elif normal_system == 'self':
|
| 147 |
+
print("the loaded normals are in their independent camera systems")
|
| 148 |
+
normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv)
|
| 149 |
+
all_normals.append(normal_cam_cv)
|
| 150 |
+
all_normals_world.append(normal_world)
|
| 151 |
+
|
| 152 |
+
if camera_type == 'ortho':
|
| 153 |
+
origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1])
|
| 154 |
+
elif camera_type == 'pinhole':
|
| 155 |
+
dirs = get_ray_directions(W=imSize[0], H=imSize[1],
|
| 156 |
+
fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3])
|
| 157 |
+
origins = dirs # placeholder; pinhole rays share a single origin, kept only so the arrays line up with the ortho case
|
| 158 |
+
else:
|
| 159 |
+
raise ValueError(f"unsupported camera type: {camera_type}")
|
| 160 |
+
ray_origins.append(origins)
|
| 161 |
+
directions.append(dirs)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if not load_color:
|
| 165 |
+
all_images = [normal2img(x) for x in all_normals_world]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \
|
| 169 |
+
np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class OrthoDatasetBase():
|
| 173 |
+
def setup(self, config, split):
|
| 174 |
+
self.config = config
|
| 175 |
+
self.split = split
|
| 176 |
+
self.rank = get_rank()
|
| 177 |
+
|
| 178 |
+
self.data_dir = self.config.root_dir
|
| 179 |
+
self.object_name = self.config.scene
|
| 180 |
+
self.scene = self.config.scene
|
| 181 |
+
self.imSize = self.config.imSize
|
| 182 |
+
self.load_color = True
|
| 183 |
+
self.img_wh = [self.imSize[0], self.imSize[1]]
|
| 184 |
+
self.w = self.img_wh[0]
|
| 185 |
+
self.h = self.img_wh[1]
|
| 186 |
+
self.camera_type = self.config.camera_type
|
| 187 |
+
self.camera_params = self.config.camera_params # [fx, fy, cx, cy]
|
| 188 |
+
|
| 189 |
+
self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
|
| 190 |
+
|
| 191 |
+
self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1)
|
| 192 |
+
self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w)
|
| 193 |
+
|
| 194 |
+
if self.config.cam_pose_dir is None:
|
| 195 |
+
self.cam_pose_dir = "./datasets/fixed_poses"
|
| 196 |
+
else:
|
| 197 |
+
self.cam_pose_dir = self.config.cam_pose_dir
|
| 198 |
+
|
| 199 |
+
self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \
|
| 200 |
+
self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction(
|
| 201 |
+
self.data_dir, self.object_name, self.imSize, self.view_types,
|
| 202 |
+
self.load_color, self.cam_pose_dir, normal_system='front',
|
| 203 |
+
camera_type=self.camera_type, cam_params=self.camera_params)
|
| 204 |
+
|
| 205 |
+
self.has_mask = True
|
| 206 |
+
self.apply_mask = self.config.apply_mask
|
| 207 |
+
|
| 208 |
+
self.all_c2w = torch.from_numpy(self.pose_all_np)
|
| 209 |
+
self.all_images = torch.from_numpy(self.images_np) / 255.
|
| 210 |
+
self.all_fg_masks = torch.from_numpy(self.masks_np)
|
| 211 |
+
self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np)
|
| 212 |
+
self.all_normals_world = torch.from_numpy(self.normals_world_np)
|
| 213 |
+
self.origins = torch.from_numpy(self.origins_np)
|
| 214 |
+
self.directions = torch.from_numpy(self.directions_np)
|
| 215 |
+
|
| 216 |
+
self.directions = self.directions.float().to(self.rank)
|
| 217 |
+
self.origins = self.origins.float().to(self.rank)
|
| 218 |
+
self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank)
|
| 219 |
+
self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \
|
| 220 |
+
self.all_c2w.float().to(self.rank), \
|
| 221 |
+
self.all_images.float().to(self.rank), \
|
| 222 |
+
self.all_fg_masks.float().to(self.rank), \
|
| 223 |
+
self.all_normals_world.float().to(self.rank)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class OrthoDataset(Dataset, OrthoDatasetBase):
|
| 227 |
+
def __init__(self, config, split):
|
| 228 |
+
self.setup(config, split)
|
| 229 |
+
|
| 230 |
+
def __len__(self):
|
| 231 |
+
return len(self.all_images)
|
| 232 |
+
|
| 233 |
+
def __getitem__(self, index):
|
| 234 |
+
return {
|
| 235 |
+
'index': index
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class OrthoIterableDataset(IterableDataset, OrthoDatasetBase):
|
| 240 |
+
def __init__(self, config, split):
|
| 241 |
+
self.setup(config, split)
|
| 242 |
+
|
| 243 |
+
def __iter__(self):
|
| 244 |
+
while True:
|
| 245 |
+
yield {}
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@datasets.register('ortho')
|
| 249 |
+
class OrthoDataModule(pl.LightningDataModule):
|
| 250 |
+
def __init__(self, config):
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.config = config
|
| 253 |
+
|
| 254 |
+
def setup(self, stage=None):
|
| 255 |
+
if stage in [None, 'fit']:
|
| 256 |
+
self.train_dataset = OrthoIterableDataset(self.config, 'train')
|
| 257 |
+
if stage in [None, 'fit', 'validate']:
|
| 258 |
+
self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train'))
|
| 259 |
+
if stage in [None, 'test']:
|
| 260 |
+
self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test'))
|
| 261 |
+
if stage in [None, 'predict']:
|
| 262 |
+
self.predict_dataset = OrthoDataset(self.config, 'train')
|
| 263 |
+
|
| 264 |
+
def prepare_data(self):
|
| 265 |
+
pass
|
| 266 |
+
|
| 267 |
+
def general_loader(self, dataset, batch_size):
|
| 268 |
+
sampler = None
|
| 269 |
+
return DataLoader(
|
| 270 |
+
dataset,
|
| 271 |
+
num_workers=os.cpu_count(),
|
| 272 |
+
batch_size=batch_size,
|
| 273 |
+
pin_memory=True,
|
| 274 |
+
sampler=sampler
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def train_dataloader(self):
|
| 278 |
+
return self.general_loader(self.train_dataset, batch_size=1)
|
| 279 |
+
|
| 280 |
+
def val_dataloader(self):
|
| 281 |
+
return self.general_loader(self.val_dataset, batch_size=1)
|
| 282 |
+
|
| 283 |
+
def test_dataloader(self):
|
| 284 |
+
return self.general_loader(self.test_dataset, batch_size=1)
|
| 285 |
+
|
| 286 |
+
def predict_dataloader(self):
|
| 287 |
+
return self.general_loader(self.predict_dataset, batch_size=1)
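A small synthetic round-trip check of the normal-map helpers defined at the top of this file: mapping camera-space normals to world space with camNormal2worldNormal and back with worldNormal2camNormal should recover the input. The rotation and normal map below are made up for illustration.

import numpy as np

H, W = 4, 4
normals_cam = np.zeros((H, W, 3), dtype=np.float32)
normals_cam[..., 2] = 1.0                       # every normal points at the camera

theta = np.deg2rad(30.0)
R_w2c = np.array([[np.cos(theta), -np.sin(theta), 0.],
                  [np.sin(theta),  np.cos(theta), 0.],
                  [0., 0., 1.]], dtype=np.float32)

normals_world = camNormal2worldNormal(R_w2c.T, normals_cam)   # rot_c2w is the transpose of rot_w2c
normals_back = worldNormal2camNormal(R_w2c, normals_world)
assert np.allclose(normals_back, normals_cam, atol=1e-6)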
|
mesh_recon/datasets/utils.py
ADDED
|
File without changes
|
mesh_recon/datasets/v3d.py
ADDED
|
@@ -0,0 +1,284 @@
|
import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid, save_image
from einops import rearrange
from mediapy import read_video
from pathlib import Path
from rembg import remove, new_session

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
from datasets.ortho import (
    inv_RT,
    camNormal2worldNormal,
    RT_opengl2opencv,
    normal_opengl2opencv,
)
from utils.dpt import DPT


def get_c2w_from_up_and_look_at(
    up,
    look_at,
    pos,
    opengl=False,
):
    up = up / np.linalg.norm(up)
    z = look_at - pos
    z = z / np.linalg.norm(z)
    y = -up
    x = np.cross(y, z)
    x /= np.linalg.norm(x)
    y = np.cross(z, x)

    c2w = np.zeros([4, 4], dtype=np.float32)
    c2w[:3, 0] = x
    c2w[:3, 1] = y
    c2w[:3, 2] = z
    c2w[:3, 3] = pos
    c2w[3, 3] = 1.0

    # opencv to opengl
    if opengl:
        c2w[..., 1:3] *= -1

    return c2w


def get_uniform_poses(num_frames, radius, elevation, opengl=False):
    T = num_frames
    azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T])
    elevations = np.full_like(azimuths, np.deg2rad(elevation))
    cam_dists = np.full_like(azimuths, radius)

    campos = np.stack(
        [
            cam_dists * np.cos(elevations) * np.cos(azimuths),
            cam_dists * np.cos(elevations) * np.sin(azimuths),
            cam_dists * np.sin(elevations),
        ],
        axis=-1,
    )

    center = np.array([0, 0, 0], dtype=np.float32)
    up = np.array([0, 0, 1], dtype=np.float32)
    poses = []
    for t in range(T):
        poses.append(get_c2w_from_up_and_look_at(up, center, campos[t], opengl=opengl))

    return np.stack(poses, axis=0)


def blender2midas(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


def midas2blender(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


class BlenderDatasetBase:
    def setup(self, config, split):
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        dpt = DPT(device=self.rank, mode="normal")

        # with open(
        #     os.path.join(
        #         self.config.root_dir, self.config.scene, f"transforms_train.json"
        #     ),
        #     "r",
        # ) as f:
        #     meta = json.load(f)

        # if "w" in meta and "h" in meta:
        #     W, H = int(meta["w"]), int(meta["h"])
        # else:
        #     W, H = 800, 800
        frames = read_video(Path(self.config.root_dir) / f"{self.config.scene}")
        rembg_session = new_session()
        num_frames, H, W = frames.shape[:3]

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # self.near, self.far = self.config.near_plane, self.config.far_plane

        self.focal = 0.5 * w / math.tan(0.5 * np.deg2rad(60))  # scaled focal length

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(
            self.rank
        )  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        radius = 2.0
        elevation = 0.0
        poses = get_uniform_poses(num_frames, radius, elevation, opengl=True)
        for i, (c2w, frame) in enumerate(zip(poses, frames)):
            c2w = torch.from_numpy(np.array(c2w)[:3, :4])
            self.all_c2w.append(c2w)

            img = Image.fromarray(frame)
            img = remove(img, session=rembg_session)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        # self.normals = self.normals.cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]
        normals = rearrange(self.normals, "b h w c -> b c h w")
        normals = normals * 0.5 + 0.5
        normals = torch.from_numpy(normals)
        # save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # exit(0)

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack([self.directions] * len(self.all_images))
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


@datasets.register("v3d")
class BlenderDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
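Note: a quick, illustrative sanity check of the orbit-pose construction in v3d.py above (not part of the commit). `get_uniform_poses` spaces the azimuths evenly over 360 degrees at fixed elevation and radius, and `get_c2w_from_up_and_look_at` builds each camera-to-world matrix so its z-axis points from the camera position toward the origin.

    import numpy as np

    poses = get_uniform_poses(num_frames=8, radius=2.0, elevation=0.0, opengl=False)
    assert poses.shape == (8, 4, 4)
    print(poses[0, :3, 3])   # camera 0 position, ~[2, 0, 0]
    print(poses[0, :3, 2])   # its viewing (z) axis, ~[-1, 0, 0], i.e. pointing at the origin
    print(np.linalg.norm(poses[:, :3, 3], axis=-1))  # every camera keeps distance ~2.0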
mesh_recon/datasets/videonvs.py
ADDED
@@ -0,0 +1,256 @@
import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid, save_image
from einops import rearrange

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
from datasets.ortho import (
    inv_RT,
    camNormal2worldNormal,
    RT_opengl2opencv,
    normal_opengl2opencv,
)
from utils.dpt import DPT


def blender2midas(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


def midas2blender(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


class BlenderDatasetBase:
    def setup(self, config, split):
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        dpt = DPT(device=self.rank, mode="normal")

        with open(
            os.path.join(
                self.config.root_dir, self.config.scene, f"transforms_train.json"
            ),
            "r",
        ) as f:
            meta = json.load(f)

        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # self.near, self.far = self.config.near_plane, self.config.far_plane

        self.focal = (
            0.5 * w / math.tan(0.5 * meta["camera_angle_x"])
        )  # scaled focal length

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(
            self.rank
        )  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for i, frame in enumerate(meta["frames"]):
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(
                self.config.root_dir,
                self.config.scene,
                f"{frame['file_path']}.png",
            )
            img = Image.open(img_path)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        # self.normals = self.normals.cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]
        normals = rearrange(self.normals, "b h w c -> b c h w")
        normals = normals * 0.5 + 0.5
        normals = torch.from_numpy(normals)
        save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # exit(0)

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack([self.directions] * len(self.all_images))
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )

        # normals = rearrange(self.all_normals_world, "b h w c -> b c h w")
        # normals = normals * 0.5 + 0.5
        # # normals = torch.from_numpy(normals)
        # save_image(make_grid(normals, nrow=4), "tmp/normals_world.png")
        # # exit(0)

        # # normals = (normals + 1) / 2.0
        # # for debug
        # index = [0, 9]
        # self.all_poses = self.all_poses[index]
        # self.all_c2w = self.all_c2w[index]
        # self.all_normals_world = self.all_normals_world[index]
        # self.all_w2cs = self.all_w2cs[index]
        # self.rgb_masks = self.all_rgb_masks[index]
        # self.fg_masks = self.all_fg_masks[index]
        # self.all_images = self.all_images[index]
        # self.directions = self.directions[index]
        # self.origins = self.origins[index]

        # images = rearrange(self.all_images, "b h w c -> b c h w")
        # normals = rearrange(normals, "b h w c -> b c h w")
        # save_image(make_grid(images, nrow=4), "tmp/images.png")
        # save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # breakpoint()

        # self.normals = self.normals * 2.0 - 1.0


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


@datasets.register("videonvs")
class BlenderDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
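Note: the focal length above comes from the standard pinhole relation for Blender-style transforms files, focal = 0.5 * W / tan(0.5 * camera_angle_x). A small worked example with illustrative values (not part of the commit):

    import math

    w = 800
    camera_angle_x = math.radians(60)
    focal = 0.5 * w / math.tan(0.5 * camera_angle_x)
    print(focal)  # ~692.8 px, used for both fx and fy with the principal point at (w // 2, h // 2)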
mesh_recon/datasets/videonvs_co3d.py
ADDED
@@ -0,0 +1,252 @@
import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid, save_image
from einops import rearrange
from rembg import remove, new_session

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
from datasets.ortho import (
    inv_RT,
    camNormal2worldNormal,
    RT_opengl2opencv,
    normal_opengl2opencv,
)
from utils.dpt import DPT


def blender2midas(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


def midas2blender(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


class BlenderDatasetBase:
    def setup(self, config, split):
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        dpt = DPT(device=self.rank, mode="normal")

        self.directions = []
        with open(
            os.path.join(self.config.root_dir, self.config.scene, f"transforms.json"),
            "r",
        ) as f:
            meta = json.load(f)

        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # self.near, self.far = self.config.near_plane, self.config.far_plane
        _session = new_session()
        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for i, frame in enumerate(meta["frames"]):
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(
                self.config.root_dir,
                self.config.scene,
                f"{frame['file_path']}",
            )
            img = Image.open(img_path)
            img = remove(img, session=_session)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)
            fx = frame["fl_x"]
            fy = frame["fl_y"]
            cx = frame["cx"]
            cy = frame["cy"]

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

            self.directions.append(get_ray_directions(self.w, self.h, fx, fy, cx, cy))

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        # self.normals = self.normals.cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]
        normals = rearrange(self.normals, "b h w c -> b c h w")
        normals = normals * 0.5 + 0.5
        normals = torch.from_numpy(normals)
        save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # exit(0)

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack(self.directions).to(self.rank)
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )

        # normals = rearrange(self.all_normals_world, "b h w c -> b c h w")
        # normals = normals * 0.5 + 0.5
        # # normals = torch.from_numpy(normals)
        # save_image(make_grid(normals, nrow=4), "tmp/normals_world.png")
        # # exit(0)

        # # normals = (normals + 1) / 2.0
        # # for debug
        # index = [0, 9]
        # self.all_poses = self.all_poses[index]
        # self.all_c2w = self.all_c2w[index]
        # self.all_normals_world = self.all_normals_world[index]
        # self.all_w2cs = self.all_w2cs[index]
        # self.rgb_masks = self.all_rgb_masks[index]
        # self.fg_masks = self.all_fg_masks[index]
        # self.all_images = self.all_images[index]
        # self.directions = self.directions[index]
        # self.origins = self.origins[index]

        # images = rearrange(self.all_images, "b h w c -> b c h w")
        # normals = rearrange(normals, "b h w c -> b c h w")
        # save_image(make_grid(images, nrow=4), "tmp/images.png")
        # save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # breakpoint()

        # self.normals = self.normals * 2.0 - 1.0


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


@datasets.register("videonvs-scene")
class VideoNVSScene(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
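Note: unlike videonvs.py, this loader reads per-frame intrinsics (fl_x, fl_y, cx, cy) and stacks one ray-direction map per image instead of sharing a single one. A hypothetical frame entry consistent with the keys read above (all values are placeholders, not taken from the commit):

    # illustrative transforms.json frame entry, written as a Python dict
    frame = {
        "file_path": "images/frame_000.png",
        "fl_x": 1111.0, "fl_y": 1111.0,   # per-frame focal lengths in pixels
        "cx": 400.0, "cy": 400.0,         # principal point
        "transform_matrix": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 2], [0, 0, 0, 1]],  # OpenGL camera-to-world
    }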
mesh_recon/launch.py
ADDED
@@ -0,0 +1,144 @@
import sys
import argparse
import os
import time
import logging
from datetime import datetime


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument("--gpu", default="0", help="GPU(s) to be used")
    parser.add_argument(
        "--resume", default=None, help="path to the weights to be resumed"
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--predict", action="store_true")
    # group.add_argument('--export', action='store_true') # TODO: a separate export action

    parser.add_argument("--exp_dir", default="./exp")
    parser.add_argument("--runs_dir", default="./runs")
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )

    args, extras = parser.parse_known_args()

    # set CUDA_VISIBLE_DEVICES then import pytorch-lightning
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    n_gpus = len(args.gpu.split(","))

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
    )
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    config.trial_name = config.get("trial_name") or (
        config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S")
    )
    config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name)
    config.save_dir = config.get("save_dir") or os.path.join(
        config.exp_dir, config.trial_name, "save"
    )
    config.ckpt_dir = config.get("ckpt_dir") or os.path.join(
        config.exp_dir, config.trial_name, "ckpt"
    )
    config.code_dir = config.get("code_dir") or os.path.join(
        config.exp_dir, config.trial_name, "code"
    )
    config.config_dir = config.get("config_dir") or os.path.join(
        config.exp_dir, config.trial_name, "config"
    )

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if "seed" not in config:
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    dm = datasets.make(config.dataset.name, config.dataset)
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=None if not args.resume_weights_only else args.resume,
    )

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     config.code_dir, use_version=False
            # ),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]

    loggers = []
    if args.train:
        loggers += [
            TensorBoardLogger(
                args.runs_dir, name=config.name, version=config.trial_name
            ),
            CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"),
        ]

    if sys.platform == "win32":
        # does not support multi-gpu on windows
        strategy = "dp"
        assert n_gpus == 1
    else:
        strategy = "ddp_find_unused_parameters_false"

    trainer = Trainer(
        devices=n_gpus,
        accelerator="gpu",
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer
    )

    if args.train:
        if args.resume and not args.resume_weights_only:
            # FIXME: different behavior in pytorch-lighting>1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)


if __name__ == "__main__":
    main()
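Note: for reference, a typical training invocation implied by the argument parser above might look like the following. The config path points at one of the files added in this commit, and the trailing dotted-key override assumes load_config forwards the parse_known_args() extras to OmegaConf as command-line overrides.

    # train on GPU 0; scene name is a placeholder
    python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.scene=example_scene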
mesh_recon/mesh.py
ADDED
@@ -0,0 +1,845 @@
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import trimesh
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from kiui.op import safe_normalize, dot
|
| 8 |
+
from kiui.typing import *
|
| 9 |
+
|
| 10 |
+
class Mesh:
|
| 11 |
+
"""
|
| 12 |
+
A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
|
| 13 |
+
|
| 14 |
+
Note:
|
| 15 |
+
This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
|
| 16 |
+
"""
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
v: Optional[Tensor] = None,
|
| 20 |
+
f: Optional[Tensor] = None,
|
| 21 |
+
vn: Optional[Tensor] = None,
|
| 22 |
+
fn: Optional[Tensor] = None,
|
| 23 |
+
vt: Optional[Tensor] = None,
|
| 24 |
+
ft: Optional[Tensor] = None,
|
| 25 |
+
vc: Optional[Tensor] = None, # vertex color
|
| 26 |
+
albedo: Optional[Tensor] = None,
|
| 27 |
+
metallicRoughness: Optional[Tensor] = None,
|
| 28 |
+
device: Optional[torch.device] = None,
|
| 29 |
+
):
|
| 30 |
+
"""Init a mesh directly using all attributes.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
|
| 34 |
+
f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
|
| 35 |
+
vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
|
| 36 |
+
fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
|
| 37 |
+
vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
|
| 38 |
+
ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
|
| 39 |
+
vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
|
| 40 |
+
albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
|
| 41 |
+
metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
|
| 42 |
+
device (Optional[torch.device]): torch device. Defaults to None.
|
| 43 |
+
"""
|
| 44 |
+
self.device = device
|
| 45 |
+
self.v = v
|
| 46 |
+
self.vn = vn
|
| 47 |
+
self.vt = vt
|
| 48 |
+
self.f = f
|
| 49 |
+
self.fn = fn
|
| 50 |
+
self.ft = ft
|
| 51 |
+
# will first see if there is vertex color to use
|
| 52 |
+
self.vc = vc
|
| 53 |
+
# only support a single albedo image
|
| 54 |
+
self.albedo = albedo
|
| 55 |
+
# pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
|
| 56 |
+
# ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
|
| 57 |
+
self.metallicRoughness = metallicRoughness
|
| 58 |
+
|
| 59 |
+
self.ori_center = 0
|
| 60 |
+
self.ori_scale = 1
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
|
| 64 |
+
"""load mesh from path.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
path (str): path to mesh file, supports ply, obj, glb.
|
| 68 |
+
clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
|
| 69 |
+
resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
|
| 70 |
+
renormal (bool, optional): re-calc the vertex normals. Defaults to True.
|
| 71 |
+
retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
|
| 72 |
+
bound (float, optional): bound to resize. Defaults to 0.9.
|
| 73 |
+
front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
|
| 74 |
+
device (torch.device, optional): torch device. Defaults to None.
|
| 75 |
+
|
| 76 |
+
Note:
|
| 77 |
+
a ``device`` keyword argument can be provided to specify the torch device.
|
| 78 |
+
If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Mesh: the loaded Mesh object.
|
| 82 |
+
"""
|
| 83 |
+
# obj supports face uv
|
| 84 |
+
if path.endswith(".obj"):
|
| 85 |
+
mesh = cls.load_obj(path, **kwargs)
|
| 86 |
+
# trimesh only supports vertex uv, but can load more formats
|
| 87 |
+
else:
|
| 88 |
+
mesh = cls.load_trimesh(path, **kwargs)
|
| 89 |
+
|
| 90 |
+
# clean
|
| 91 |
+
if clean:
|
| 92 |
+
from kiui.mesh_utils import clean_mesh
|
| 93 |
+
vertices = mesh.v.detach().cpu().numpy()
|
| 94 |
+
triangles = mesh.f.detach().cpu().numpy()
|
| 95 |
+
vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
|
| 96 |
+
mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
|
| 97 |
+
mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
|
| 98 |
+
|
| 99 |
+
print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
|
| 100 |
+
# auto-normalize
|
| 101 |
+
if resize:
|
| 102 |
+
mesh.auto_size(bound=bound)
|
| 103 |
+
# auto-fix normal
|
| 104 |
+
if renormal or mesh.vn is None:
|
| 105 |
+
mesh.auto_normal()
|
| 106 |
+
print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
|
| 107 |
+
# auto-fix texcoords
|
| 108 |
+
if retex or (mesh.albedo is not None and mesh.vt is None):
|
| 109 |
+
mesh.auto_uv(cache_path=path)
|
| 110 |
+
print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
|
| 111 |
+
|
| 112 |
+
# rotate front dir to +z
|
| 113 |
+
if front_dir != "+z":
|
| 114 |
+
# axis switch
|
| 115 |
+
if "-z" in front_dir:
|
| 116 |
+
T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
|
| 117 |
+
elif "+x" in front_dir:
|
| 118 |
+
T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
|
| 119 |
+
elif "-x" in front_dir:
|
| 120 |
+
T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
|
| 121 |
+
elif "+y" in front_dir:
|
| 122 |
+
T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
|
| 123 |
+
elif "-y" in front_dir:
|
| 124 |
+
T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
|
| 125 |
+
else:
|
| 126 |
+
T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
|
| 127 |
+
# rotation (how many 90 degrees)
|
| 128 |
+
if '1' in front_dir:
|
| 129 |
+
T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
|
| 130 |
+
elif '2' in front_dir:
|
| 131 |
+
T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
|
| 132 |
+
elif '3' in front_dir:
|
| 133 |
+
T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
|
| 134 |
+
mesh.v @= T
|
| 135 |
+
mesh.vn @= T
|
| 136 |
+
|
| 137 |
+
return mesh
|
| 138 |
+
|
| 139 |
+
# load from obj file
|
| 140 |
+
@classmethod
|
| 141 |
+
def load_obj(cls, path, albedo_path=None, device=None):
|
| 142 |
+
"""load an ``obj`` mesh.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
path (str): path to mesh.
|
| 146 |
+
albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
|
| 147 |
+
device (torch.device, optional): torch device. Defaults to None.
|
| 148 |
+
|
| 149 |
+
Note:
|
| 150 |
+
We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
|
| 151 |
+
The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
Mesh: the loaded Mesh object.
|
| 155 |
+
"""
|
| 156 |
+
assert os.path.splitext(path)[-1] == ".obj"
|
| 157 |
+
|
| 158 |
+
mesh = cls()
|
| 159 |
+
|
| 160 |
+
# device
|
| 161 |
+
if device is None:
|
| 162 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 163 |
+
|
| 164 |
+
mesh.device = device
|
| 165 |
+
|
| 166 |
+
# load obj
|
| 167 |
+
with open(path, "r") as f:
|
| 168 |
+
lines = f.readlines()
|
| 169 |
+
|
| 170 |
+
def parse_f_v(fv):
|
| 171 |
+
# pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
|
| 172 |
+
# supported forms:
|
| 173 |
+
# f v1 v2 v3
|
| 174 |
+
# f v1/vt1 v2/vt2 v3/vt3
|
| 175 |
+
# f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
|
| 176 |
+
# f v1//vn1 v2//vn2 v3//vn3
|
| 177 |
+
xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
|
| 178 |
+
xs.extend([-1] * (3 - len(xs)))
|
| 179 |
+
return xs[0], xs[1], xs[2]
|
| 180 |
+
|
| 181 |
+
vertices, texcoords, normals = [], [], []
|
| 182 |
+
faces, tfaces, nfaces = [], [], []
|
| 183 |
+
mtl_path = None
|
| 184 |
+
|
| 185 |
+
for line in lines:
|
| 186 |
+
split_line = line.split()
|
| 187 |
+
# empty line
|
| 188 |
+
if len(split_line) == 0:
|
| 189 |
+
continue
|
| 190 |
+
prefix = split_line[0].lower()
|
| 191 |
+
# mtllib
|
| 192 |
+
if prefix == "mtllib":
|
| 193 |
+
mtl_path = split_line[1]
|
| 194 |
+
# usemtl
|
| 195 |
+
elif prefix == "usemtl":
|
| 196 |
+
pass # ignored
|
| 197 |
+
# v/vn/vt
|
| 198 |
+
elif prefix == "v":
|
| 199 |
+
vertices.append([float(v) for v in split_line[1:]])
|
| 200 |
+
elif prefix == "vn":
|
| 201 |
+
normals.append([float(v) for v in split_line[1:]])
|
| 202 |
+
elif prefix == "vt":
|
| 203 |
+
val = [float(v) for v in split_line[1:]]
|
| 204 |
+
texcoords.append([val[0], 1.0 - val[1]])
|
| 205 |
+
elif prefix == "f":
|
| 206 |
+
vs = split_line[1:]
|
| 207 |
+
nv = len(vs)
|
| 208 |
+
v0, t0, n0 = parse_f_v(vs[0])
|
| 209 |
+
for i in range(nv - 2): # triangulate (assume vertices are ordered)
|
| 210 |
+
v1, t1, n1 = parse_f_v(vs[i + 1])
|
| 211 |
+
v2, t2, n2 = parse_f_v(vs[i + 2])
|
| 212 |
+
faces.append([v0, v1, v2])
|
| 213 |
+
tfaces.append([t0, t1, t2])
|
| 214 |
+
nfaces.append([n0, n1, n2])
|
| 215 |
+
|
| 216 |
+
mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
|
| 217 |
+
mesh.vt = (
|
| 218 |
+
torch.tensor(texcoords, dtype=torch.float32, device=device)
|
| 219 |
+
if len(texcoords) > 0
|
| 220 |
+
else None
|
| 221 |
+
)
|
| 222 |
+
mesh.vn = (
|
| 223 |
+
torch.tensor(normals, dtype=torch.float32, device=device)
|
| 224 |
+
if len(normals) > 0
|
| 225 |
+
else None
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
|
| 229 |
+
mesh.ft = (
|
| 230 |
+
torch.tensor(tfaces, dtype=torch.int32, device=device)
|
| 231 |
+
if len(texcoords) > 0
|
| 232 |
+
else None
|
| 233 |
+
)
|
| 234 |
+
mesh.fn = (
|
| 235 |
+
torch.tensor(nfaces, dtype=torch.int32, device=device)
|
| 236 |
+
if len(normals) > 0
|
| 237 |
+
else None
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# see if there is vertex color
|
| 241 |
+
use_vertex_color = False
|
| 242 |
+
if mesh.v.shape[1] == 6:
|
| 243 |
+
use_vertex_color = True
|
| 244 |
+
mesh.vc = mesh.v[:, 3:]
|
| 245 |
+
mesh.v = mesh.v[:, :3]
|
| 246 |
+
print(f"[load_obj] use vertex color: {mesh.vc.shape}")
|
| 247 |
+
|
| 248 |
+
# try to load texture image
|
| 249 |
+
if not use_vertex_color:
|
| 250 |
+
# try to retrieve mtl file
|
| 251 |
+
mtl_path_candidates = []
|
| 252 |
+
if mtl_path is not None:
|
| 253 |
+
mtl_path_candidates.append(mtl_path)
|
| 254 |
+
mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
|
| 255 |
+
mtl_path_candidates.append(path.replace(".obj", ".mtl"))
|
| 256 |
+
|
| 257 |
+
mtl_path = None
|
| 258 |
+
for candidate in mtl_path_candidates:
|
| 259 |
+
if os.path.exists(candidate):
|
| 260 |
+
mtl_path = candidate
|
| 261 |
+
break
|
| 262 |
+
|
| 263 |
+
# if albedo_path is not provided, try retrieve it from mtl
|
| 264 |
+
metallic_path = None
|
| 265 |
+
roughness_path = None
|
| 266 |
+
if mtl_path is not None and albedo_path is None:
|
| 267 |
+
with open(mtl_path, "r") as f:
|
| 268 |
+
lines = f.readlines()
|
| 269 |
+
|
| 270 |
+
for line in lines:
|
| 271 |
+
split_line = line.split()
|
| 272 |
+
# empty line
|
| 273 |
+
if len(split_line) == 0:
|
| 274 |
+
continue
|
| 275 |
+
prefix = split_line[0]
|
| 276 |
+
|
| 277 |
+
if "map_Kd" in prefix:
|
| 278 |
+
# assume relative path!
|
| 279 |
+
albedo_path = os.path.join(os.path.dirname(path), split_line[1])
|
| 280 |
+
print(f"[load_obj] use texture from: {albedo_path}")
|
| 281 |
+
elif "map_Pm" in prefix:
|
| 282 |
+
metallic_path = os.path.join(os.path.dirname(path), split_line[1])
|
| 283 |
+
elif "map_Pr" in prefix:
|
| 284 |
+
roughness_path = os.path.join(os.path.dirname(path), split_line[1])
|
| 285 |
+
|
| 286 |
+
# still not found albedo_path, or the path doesn't exist
|
| 287 |
+
if albedo_path is None or not os.path.exists(albedo_path):
|
| 288 |
+
# init an empty texture
|
| 289 |
+
print(f"[load_obj] init empty albedo!")
|
| 290 |
+
# albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
|
| 291 |
+
albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color
|
| 292 |
+
else:
|
| 293 |
+
albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
|
| 294 |
+
albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
|
| 295 |
+
albedo = albedo.astype(np.float32) / 255
|
| 296 |
+
print(f"[load_obj] load texture: {albedo.shape}")
|
| 297 |
+
|
| 298 |
+
mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
|
| 299 |
+
|
| 300 |
+
# try to load metallic and roughness
|
| 301 |
+
if metallic_path is not None and roughness_path is not None:
|
| 302 |
+
print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
|
| 303 |
+
metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
|
| 304 |
+
metallic = metallic.astype(np.float32) / 255
|
| 305 |
+
roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
|
| 306 |
+
roughness = roughness.astype(np.float32) / 255
|
| 307 |
+
metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
|
| 308 |
+
|
| 309 |
+
mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
|
| 310 |
+
|
| 311 |
+
return mesh
|
| 312 |
+
|
| 313 |
+
@classmethod
|
| 314 |
+
def load_trimesh(cls, path, device=None):
|
| 315 |
+
"""load a mesh using ``trimesh.load()``.
|
| 316 |
+
|
| 317 |
+
Can load various formats like ``glb`` and serves as a fallback.
|
| 318 |
+
|
| 319 |
+
Note:
|
| 320 |
+
We will try to merge all meshes if the glb contains more than one,
|
| 321 |
+
but **this may cause the texture to be lost**, since we only support one texture image!
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
path (str): path to the mesh file.
|
| 325 |
+
device (torch.device, optional): torch device. Defaults to None.
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
Mesh: the loaded Mesh object.
|
| 329 |
+
"""
|
| 330 |
+
mesh = cls()
|
| 331 |
+
|
| 332 |
+
# device
|
| 333 |
+
if device is None:
|
| 334 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 335 |
+
|
| 336 |
+
mesh.device = device
|
| 337 |
+
|
| 338 |
+
# use trimesh to load ply/glb
|
| 339 |
+
_data = trimesh.load(path)
|
| 340 |
+
if isinstance(_data, trimesh.Scene):
|
| 341 |
+
if len(_data.geometry) == 1:
|
| 342 |
+
_mesh = list(_data.geometry.values())[0]
|
| 343 |
+
else:
|
| 344 |
+
print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
|
| 345 |
+
_concat = []
|
| 346 |
+
# loop the scene graph and apply transform to each mesh
|
| 347 |
+
scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}}
|
| 348 |
+
for k, v in scene_graph.items():
|
| 349 |
+
name = v['geometry']
|
| 350 |
+
if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
|
| 351 |
+
transform = v['transform']
|
| 352 |
+
_concat.append(_data.geometry[name].apply_transform(transform))
|
| 353 |
+
_mesh = trimesh.util.concatenate(_concat)
|
| 354 |
+
else:
|
| 355 |
+
_mesh = _data
|
| 356 |
+
|
| 357 |
+
if _mesh.visual.kind == 'vertex':
|
| 358 |
+
vertex_colors = _mesh.visual.vertex_colors
|
| 359 |
+
vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
|
| 360 |
+
mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
|
| 361 |
+
print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
|
| 362 |
+
elif _mesh.visual.kind == 'texture':
|
| 363 |
+
_material = _mesh.visual.material
|
| 364 |
+
if isinstance(_material, trimesh.visual.material.PBRMaterial):
|
| 365 |
+
texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
|
| 366 |
+
# load metallicRoughness if present
|
| 367 |
+
if _material.metallicRoughnessTexture is not None:
|
| 368 |
+
metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
|
| 369 |
+
mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
|
| 370 |
+
elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
|
| 371 |
+
texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
|
| 372 |
+
else:
|
| 373 |
+
raise NotImplementedError(f"material type {type(_material)} not supported!")
|
| 374 |
+
mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
|
| 375 |
+
print(f"[load_trimesh] load texture: {texture.shape}")
|
| 376 |
+
else:
|
| 377 |
+
texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
|
| 378 |
+
mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
|
| 379 |
+
print(f"[load_trimesh] failed to load texture.")
|
| 380 |
+
|
| 381 |
+
vertices = _mesh.vertices
|
| 382 |
+
|
| 383 |
+
try:
|
| 384 |
+
texcoords = _mesh.visual.uv
|
| 385 |
+
texcoords[:, 1] = 1 - texcoords[:, 1]
|
| 386 |
+
except Exception as e:
|
| 387 |
+
texcoords = None
|
| 388 |
+
|
| 389 |
+
try:
|
| 390 |
+
normals = _mesh.vertex_normals
|
| 391 |
+
except Exception as e:
|
| 392 |
+
normals = None
|
| 393 |
+
|
| 394 |
+
# trimesh only supports vertex uv...
|
| 395 |
+
faces = tfaces = nfaces = _mesh.faces
|
| 396 |
+
|
| 397 |
+
mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
|
| 398 |
+
mesh.vt = (
|
| 399 |
+
torch.tensor(texcoords, dtype=torch.float32, device=device)
|
| 400 |
+
if texcoords is not None
|
| 401 |
+
else None
|
| 402 |
+
)
|
| 403 |
+
mesh.vn = (
|
| 404 |
+
torch.tensor(normals, dtype=torch.float32, device=device)
|
| 405 |
+
if normals is not None
|
| 406 |
+
else None
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
|
| 410 |
+
mesh.ft = (
|
| 411 |
+
torch.tensor(tfaces, dtype=torch.int32, device=device)
|
| 412 |
+
if texcoords is not None
|
| 413 |
+
else None
|
| 414 |
+
)
|
| 415 |
+
mesh.fn = (
|
| 416 |
+
torch.tensor(nfaces, dtype=torch.int32, device=device)
|
| 417 |
+
if normals is not None
|
| 418 |
+
else None
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
return mesh
|
| 422 |
+
|
| 423 |
+
# sample surface (using trimesh)
|
| 424 |
+
def sample_surface(self, count: int):
|
| 425 |
+
"""sample points on the surface of the mesh.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
count (int): number of points to sample.
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
torch.Tensor: the sampled points, float [count, 3].
|
| 432 |
+
"""
|
| 433 |
+
_mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
|
| 434 |
+
points, face_idx = trimesh.sample.sample_surface(_mesh, count)
|
| 435 |
+
points = torch.from_numpy(points).float().to(self.device)
|
| 436 |
+
return points
|
| 437 |
+
|
| 438 |
+
# aabb
|
| 439 |
+
def aabb(self):
|
| 440 |
+
"""get the axis-aligned bounding box of the mesh.
|
| 441 |
+
|
| 442 |
+
Returns:
|
| 443 |
+
Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
|
| 444 |
+
"""
|
| 445 |
+
return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
|
| 446 |
+
|
| 447 |
+
# unit size
|
| 448 |
+
@torch.no_grad()
|
| 449 |
+
def auto_size(self, bound=0.9):
|
| 450 |
+
"""auto resize the mesh.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
|
| 454 |
+
"""
|
| 455 |
+
vmin, vmax = self.aabb()
|
| 456 |
+
self.ori_center = (vmax + vmin) / 2
|
| 457 |
+
self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
|
| 458 |
+
self.v = (self.v - self.ori_center) * self.ori_scale
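As a quick check of the rescaling above (a standalone sketch, not part of the commit): auto_size maps the longest axis-aligned extent of the mesh to the interval [-bound, bound], with the other axes scaled proportionally.

# Illustrative check of the auto_size math on a toy vertex set.
import torch

v = torch.tensor([[0.0, 0.0, 0.0], [4.0, 1.0, 2.0]])
bound = 0.9
vmin, vmax = v.min(dim=0).values, v.max(dim=0).values
center = (vmax + vmin) / 2
scale = 2 * bound / torch.max(vmax - vmin).item()
v_scaled = (v - center) * scale
print(v_scaled.min(dim=0).values, v_scaled.max(dim=0).values)
# the longest axis (x) now spans [-0.9, 0.9]; the other axes span proportionally less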
|
| 459 |
+
|
| 460 |
+
def auto_normal(self):
|
| 461 |
+
"""auto calculate the vertex normals.
|
| 462 |
+
"""
|
| 463 |
+
i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
|
| 464 |
+
v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
|
| 465 |
+
|
| 466 |
+
face_normals = torch.cross(v1 - v0, v2 - v0)
|
| 467 |
+
|
| 468 |
+
# Splat face normals to vertices
|
| 469 |
+
vn = torch.zeros_like(self.v)
|
| 470 |
+
vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
| 471 |
+
vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
| 472 |
+
vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
| 473 |
+
|
| 474 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
| 475 |
+
vn = torch.where(
|
| 476 |
+
dot(vn, vn) > 1e-20,
|
| 477 |
+
vn,
|
| 478 |
+
torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
|
| 479 |
+
)
|
| 480 |
+
vn = safe_normalize(vn)
|
| 481 |
+
|
| 482 |
+
self.vn = vn
|
| 483 |
+
self.fn = self.f
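The accumulation in auto_normal is an area-weighted average: un-normalized face cross products are scattered onto their three vertices and then normalized. A minimal standalone sketch (not part of the commit) of the same pattern on two coplanar triangles:

# Illustrative: area-weighted vertex normals via scatter_add_.
import torch
import torch.nn.functional as F

v = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [1., 1., 0.]])
f = torch.tensor([[0, 1, 2], [1, 3, 2]], dtype=torch.long)

i0, i1, i2 = f[:, 0], f[:, 1], f[:, 2]
face_n = torch.cross(v[i1] - v[i0], v[i2] - v[i0], dim=-1)  # un-normalized => area-weighted

vn = torch.zeros_like(v)
for idx in (i0, i1, i2):
    vn.scatter_add_(0, idx[:, None].repeat(1, 3), face_n)
vn = F.normalize(vn, dim=-1)
print(vn)  # every vertex of this flat patch gets normal (0, 0, 1)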
|
| 484 |
+
|
| 485 |
+
def auto_uv(self, cache_path=None, vmap=True):
|
| 486 |
+
"""auto calculate the uv coordinates.
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
|
| 490 |
+
vmap (bool, optional): remap vertices based on uv coordinates, so each v corresponds to a unique vt (necessary for formats like gltf).
|
| 491 |
+
Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
|
| 492 |
+
"""
|
| 493 |
+
# try to load cache
|
| 494 |
+
if cache_path is not None:
|
| 495 |
+
cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
|
| 496 |
+
if cache_path is not None and os.path.exists(cache_path):
|
| 497 |
+
data = np.load(cache_path)
|
| 498 |
+
vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
|
| 499 |
+
else:
|
| 500 |
+
import xatlas
|
| 501 |
+
|
| 502 |
+
v_np = self.v.detach().cpu().numpy()
|
| 503 |
+
f_np = self.f.detach().int().cpu().numpy()
|
| 504 |
+
atlas = xatlas.Atlas()
|
| 505 |
+
atlas.add_mesh(v_np, f_np)
|
| 506 |
+
chart_options = xatlas.ChartOptions()
|
| 507 |
+
# chart_options.max_iterations = 4
|
| 508 |
+
atlas.generate(chart_options=chart_options)
|
| 509 |
+
vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
|
| 510 |
+
|
| 511 |
+
# save to cache
|
| 512 |
+
if cache_path is not None:
|
| 513 |
+
np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
|
| 514 |
+
|
| 515 |
+
vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
|
| 516 |
+
ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
|
| 517 |
+
self.vt = vt
|
| 518 |
+
self.ft = ft
|
| 519 |
+
|
| 520 |
+
if vmap:
|
| 521 |
+
vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
|
| 522 |
+
self.align_v_to_vt(vmapping)
|
| 523 |
+
|
| 524 |
+
def align_v_to_vt(self, vmapping=None):
|
| 525 |
+
""" remap v/f and vn/fn to vt/ft.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
|
| 529 |
+
"""
|
| 530 |
+
if vmapping is None:
|
| 531 |
+
ft = self.ft.view(-1).long()
|
| 532 |
+
f = self.f.view(-1).long()
|
| 533 |
+
vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
|
| 534 |
+
vmapping[ft] = f # scatter, randomly choose one if index is not unique
|
| 535 |
+
|
| 536 |
+
self.v = self.v[vmapping]
|
| 537 |
+
self.f = self.ft
|
| 538 |
+
|
| 539 |
+
if self.vn is not None:
|
| 540 |
+
self.vn = self.vn[vmapping]
|
| 541 |
+
self.fn = self.ft
|
| 542 |
+
|
| 543 |
+
def to(self, device):
|
| 544 |
+
"""move all tensor attributes to device.
|
| 545 |
+
|
| 546 |
+
Args:
|
| 547 |
+
device (torch.device): target device.
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
Mesh: self.
|
| 551 |
+
"""
|
| 552 |
+
self.device = device
|
| 553 |
+
for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
|
| 554 |
+
tensor = getattr(self, name)
|
| 555 |
+
if tensor is not None:
|
| 556 |
+
setattr(self, name, tensor.to(device))
|
| 557 |
+
return self
|
| 558 |
+
|
| 559 |
+
def write(self, path):
|
| 560 |
+
"""write the mesh to a path.
|
| 561 |
+
|
| 562 |
+
Args:
|
| 563 |
+
path (str): path to write, supports ply, obj and glb.
|
| 564 |
+
"""
|
| 565 |
+
if path.endswith(".ply"):
|
| 566 |
+
self.write_ply(path)
|
| 567 |
+
elif path.endswith(".obj"):
|
| 568 |
+
self.write_obj(path)
|
| 569 |
+
elif path.endswith(".glb") or path.endswith(".gltf"):
|
| 570 |
+
self.write_glb(path)
|
| 571 |
+
else:
|
| 572 |
+
raise NotImplementedError(f"format {path} not supported!")
|
| 573 |
+
|
| 574 |
+
def write_ply(self, path):
|
| 575 |
+
"""write the mesh in ply format. Only for geometry!
|
| 576 |
+
|
| 577 |
+
Args:
|
| 578 |
+
path (str): path to write.
|
| 579 |
+
"""
|
| 580 |
+
|
| 581 |
+
if self.albedo is not None:
|
| 582 |
+
print(f'[WARN] ply format does not support exporting texture, will ignore!')
|
| 583 |
+
|
| 584 |
+
v_np = self.v.detach().cpu().numpy()
|
| 585 |
+
f_np = self.f.detach().cpu().numpy()
|
| 586 |
+
|
| 587 |
+
_mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
|
| 588 |
+
_mesh.export(path)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def write_glb(self, path):
|
| 592 |
+
"""write the mesh in glb/gltf format.
|
| 593 |
+
This will create a scene with a single mesh.
|
| 594 |
+
|
| 595 |
+
Args:
|
| 596 |
+
path (str): path to write.
|
| 597 |
+
"""
|
| 598 |
+
|
| 599 |
+
# assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
|
| 600 |
+
if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
|
| 601 |
+
self.align_v_to_vt()
|
| 602 |
+
|
| 603 |
+
import pygltflib
|
| 604 |
+
|
| 605 |
+
f_np = self.f.detach().cpu().numpy().astype(np.uint32)
|
| 606 |
+
f_np_blob = f_np.flatten().tobytes()
|
| 607 |
+
|
| 608 |
+
v_np = self.v.detach().cpu().numpy().astype(np.float32)
|
| 609 |
+
v_np_blob = v_np.tobytes()
|
| 610 |
+
|
| 611 |
+
blob = f_np_blob + v_np_blob
|
| 612 |
+
byteOffset = len(blob)
|
| 613 |
+
|
| 614 |
+
# base mesh
|
| 615 |
+
gltf = pygltflib.GLTF2(
|
| 616 |
+
scene=0,
|
| 617 |
+
scenes=[pygltflib.Scene(nodes=[0])],
|
| 618 |
+
nodes=[pygltflib.Node(mesh=0)],
|
| 619 |
+
meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
|
| 620 |
+
# indices to accessors (0 is triangles)
|
| 621 |
+
attributes=pygltflib.Attributes(
|
| 622 |
+
POSITION=1,
|
| 623 |
+
),
|
| 624 |
+
indices=0,
|
| 625 |
+
)])],
|
| 626 |
+
buffers=[
|
| 627 |
+
pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
|
| 628 |
+
],
|
| 629 |
+
# buffer view (based on dtype)
|
| 630 |
+
bufferViews=[
|
| 631 |
+
# triangles; as flatten (element) array
|
| 632 |
+
pygltflib.BufferView(
|
| 633 |
+
buffer=0,
|
| 634 |
+
byteLength=len(f_np_blob),
|
| 635 |
+
target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
|
| 636 |
+
),
|
| 637 |
+
# positions; as vec3 array
|
| 638 |
+
pygltflib.BufferView(
|
| 639 |
+
buffer=0,
|
| 640 |
+
byteOffset=len(f_np_blob),
|
| 641 |
+
byteLength=len(v_np_blob),
|
| 642 |
+
byteStride=12, # vec3
|
| 643 |
+
target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
|
| 644 |
+
),
|
| 645 |
+
],
|
| 646 |
+
accessors=[
|
| 647 |
+
# 0 = triangles
|
| 648 |
+
pygltflib.Accessor(
|
| 649 |
+
bufferView=0,
|
| 650 |
+
componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
|
| 651 |
+
count=f_np.size,
|
| 652 |
+
type=pygltflib.SCALAR,
|
| 653 |
+
max=[int(f_np.max())],
|
| 654 |
+
min=[int(f_np.min())],
|
| 655 |
+
),
|
| 656 |
+
# 1 = positions
|
| 657 |
+
pygltflib.Accessor(
|
| 658 |
+
bufferView=1,
|
| 659 |
+
componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
|
| 660 |
+
count=len(v_np),
|
| 661 |
+
type=pygltflib.VEC3,
|
| 662 |
+
max=v_np.max(axis=0).tolist(),
|
| 663 |
+
min=v_np.min(axis=0).tolist(),
|
| 664 |
+
),
|
| 665 |
+
],
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
# append texture info
|
| 669 |
+
if self.vt is not None:
|
| 670 |
+
|
| 671 |
+
vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
|
| 672 |
+
vt_np_blob = vt_np.tobytes()
|
| 673 |
+
|
| 674 |
+
albedo = self.albedo.detach().cpu().numpy()
|
| 675 |
+
albedo = (albedo * 255).astype(np.uint8)
|
| 676 |
+
albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
|
| 677 |
+
albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
|
| 678 |
+
|
| 679 |
+
# update primitive
|
| 680 |
+
gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
|
| 681 |
+
gltf.meshes[0].primitives[0].material = 0
|
| 682 |
+
|
| 683 |
+
# update materials
|
| 684 |
+
gltf.materials.append(pygltflib.Material(
|
| 685 |
+
pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
|
| 686 |
+
baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
|
| 687 |
+
metallicFactor=0.0,
|
| 688 |
+
roughnessFactor=1.0,
|
| 689 |
+
),
|
| 690 |
+
alphaMode=pygltflib.OPAQUE,
|
| 691 |
+
alphaCutoff=None,
|
| 692 |
+
doubleSided=True,
|
| 693 |
+
))
|
| 694 |
+
|
| 695 |
+
gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
|
| 696 |
+
gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
|
| 697 |
+
gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
|
| 698 |
+
|
| 699 |
+
# update buffers
|
| 700 |
+
gltf.bufferViews.append(
|
| 701 |
+
# index = 2, texcoords; as vec2 array
|
| 702 |
+
pygltflib.BufferView(
|
| 703 |
+
buffer=0,
|
| 704 |
+
byteOffset=byteOffset,
|
| 705 |
+
byteLength=len(vt_np_blob),
|
| 706 |
+
byteStride=8, # vec2
|
| 707 |
+
target=pygltflib.ARRAY_BUFFER,
|
| 708 |
+
)
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
gltf.accessors.append(
|
| 712 |
+
# 2 = texcoords
|
| 713 |
+
pygltflib.Accessor(
|
| 714 |
+
bufferView=2,
|
| 715 |
+
componentType=pygltflib.FLOAT,
|
| 716 |
+
count=len(vt_np),
|
| 717 |
+
type=pygltflib.VEC2,
|
| 718 |
+
max=vt_np.max(axis=0).tolist(),
|
| 719 |
+
min=vt_np.min(axis=0).tolist(),
|
| 720 |
+
)
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
blob += vt_np_blob
|
| 724 |
+
byteOffset += len(vt_np_blob)
|
| 725 |
+
|
| 726 |
+
gltf.bufferViews.append(
|
| 727 |
+
# index = 3, albedo texture; as none target
|
| 728 |
+
pygltflib.BufferView(
|
| 729 |
+
buffer=0,
|
| 730 |
+
byteOffset=byteOffset,
|
| 731 |
+
byteLength=len(albedo_blob),
|
| 732 |
+
)
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
blob += albedo_blob
|
| 736 |
+
byteOffset += len(albedo_blob)
|
| 737 |
+
|
| 738 |
+
gltf.buffers[0].byteLength = byteOffset
|
| 739 |
+
|
| 740 |
+
# append metallic roughness
|
| 741 |
+
if self.metallicRoughness is not None:
|
| 742 |
+
metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
|
| 743 |
+
metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
|
| 744 |
+
metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
|
| 745 |
+
metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
|
| 746 |
+
|
| 747 |
+
# update texture definition
|
| 748 |
+
gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
|
| 749 |
+
gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
|
| 750 |
+
gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
|
| 751 |
+
|
| 752 |
+
gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
|
| 753 |
+
gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
|
| 754 |
+
gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
|
| 755 |
+
|
| 756 |
+
# update buffers
|
| 757 |
+
gltf.bufferViews.append(
|
| 758 |
+
# index = 4, metallicRoughness texture; as none target
|
| 759 |
+
pygltflib.BufferView(
|
| 760 |
+
buffer=0,
|
| 761 |
+
byteOffset=byteOffset,
|
| 762 |
+
byteLength=len(metallicRoughness_blob),
|
| 763 |
+
)
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
blob += metallicRoughness_blob
|
| 767 |
+
byteOffset += len(metallicRoughness_blob)
|
| 768 |
+
|
| 769 |
+
gltf.buffers[0].byteLength = byteOffset
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
# set actual data
|
| 773 |
+
gltf.set_binary_blob(blob)
|
| 774 |
+
|
| 775 |
+
# glb = b"".join(gltf.save_to_bytes())
|
| 776 |
+
gltf.save(path)
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
def write_obj(self, path):
|
| 780 |
+
"""write the mesh in obj format. Will also write the texture and mtl files.
|
| 781 |
+
|
| 782 |
+
Args:
|
| 783 |
+
path (str): path to write.
|
| 784 |
+
"""
|
| 785 |
+
|
| 786 |
+
mtl_path = path.replace(".obj", ".mtl")
|
| 787 |
+
albedo_path = path.replace(".obj", "_albedo.png")
|
| 788 |
+
metallic_path = path.replace(".obj", "_metallic.png")
|
| 789 |
+
roughness_path = path.replace(".obj", "_roughness.png")
|
| 790 |
+
|
| 791 |
+
v_np = self.v.detach().cpu().numpy()
|
| 792 |
+
vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
|
| 793 |
+
vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
|
| 794 |
+
f_np = self.f.detach().cpu().numpy()
|
| 795 |
+
ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
|
| 796 |
+
fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
|
| 797 |
+
|
| 798 |
+
with open(path, "w") as fp:
|
| 799 |
+
fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
|
| 800 |
+
|
| 801 |
+
for v in v_np:
|
| 802 |
+
fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
|
| 803 |
+
|
| 804 |
+
if vt_np is not None:
|
| 805 |
+
for v in vt_np:
|
| 806 |
+
fp.write(f"vt {v[0]} {1 - v[1]} \n")
|
| 807 |
+
|
| 808 |
+
if vn_np is not None:
|
| 809 |
+
for v in vn_np:
|
| 810 |
+
fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
|
| 811 |
+
|
| 812 |
+
fp.write(f"usemtl defaultMat \n")
|
| 813 |
+
for i in range(len(f_np)):
|
| 814 |
+
fp.write(
|
| 815 |
+
f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
|
| 816 |
+
{f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
|
| 817 |
+
{f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
with open(mtl_path, "w") as fp:
|
| 821 |
+
fp.write(f"newmtl defaultMat \n")
|
| 822 |
+
fp.write(f"Ka 1 1 1 \n")
|
| 823 |
+
fp.write(f"Kd 1 1 1 \n")
|
| 824 |
+
fp.write(f"Ks 0 0 0 \n")
|
| 825 |
+
fp.write(f"Tr 1 \n")
|
| 826 |
+
fp.write(f"illum 1 \n")
|
| 827 |
+
fp.write(f"Ns 0 \n")
|
| 828 |
+
if self.albedo is not None:
|
| 829 |
+
fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
|
| 830 |
+
if self.metallicRoughness is not None:
|
| 831 |
+
# ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
|
| 832 |
+
fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
|
| 833 |
+
fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
|
| 834 |
+
|
| 835 |
+
if self.albedo is not None:
|
| 836 |
+
albedo = self.albedo.detach().cpu().numpy()
|
| 837 |
+
albedo = (albedo * 255).astype(np.uint8)
|
| 838 |
+
cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
|
| 839 |
+
|
| 840 |
+
if self.metallicRoughness is not None:
|
| 841 |
+
metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
|
| 842 |
+
metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
|
| 843 |
+
cv2.imwrite(metallic_path, metallicRoughness[..., 2])
|
| 844 |
+
cv2.imwrite(roughness_path, metallicRoughness[..., 1])
|
| 845 |
+
|
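Taken together, a typical round trip through this Mesh class might look like the sketch below. This is an illustrative usage example, not code from the commit: the file paths are placeholders, and the import assumes the script runs from the mesh_recon directory with trimesh, xatlas and OpenCV installed.

# Illustrative usage sketch only; paths and import location are assumptions.
import torch
from mesh import Mesh  # mesh_recon/mesh.py

mesh = Mesh.load_trimesh("input.glb", device=torch.device("cpu"))
mesh.auto_size(bound=0.9)                    # recenter and rescale into [-0.9, 0.9]^3
mesh.auto_normal()                           # recompute vertex normals
if mesh.vt is None:
    mesh.auto_uv(cache_path="input_uv.npz")  # unwrap with xatlas, cache the result
mesh.write("output.glb")                     # also supports .obj and .ply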
mesh_recon/models/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
models = {}


def register(name):
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator


def make(name, config):
    model = models[name](config)
    return model


from . import nerf, neus, geometry, texture
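The module above is a small name-to-class registry: register stores a class under a string key and make instantiates it from a config. A hedged sketch of how it is used (the 'my-geometry' name and the plain-dict config are made up; running it requires the repository's dependencies, since importing models also imports nerf, neus, geometry and texture):

# Illustrative registry usage; 'my-geometry' is a hypothetical name.
import torch.nn as nn
import models

@models.register('my-geometry')
class MyGeometry(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

geo = models.make('my-geometry', {'radius': 1.0})
print(type(geo).__name__)  # MyGeometry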
mesh_recon/models/base.py
ADDED
|
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn

from utils.misc import get_rank

class BaseModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.setup()
        if self.config.get('weights', None):
            self.load_state_dict(torch.load(self.config.weights))

    def setup(self):
        raise NotImplementedError

    def update_step(self, epoch, global_step):
        pass

    def train(self, mode=True):
        return super().train(mode=mode)

    def eval(self):
        return super().eval()

    def regularizations(self, out):
        return {}

    @torch.no_grad()
    def export(self, export_config):
        return {}
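BaseModel builds everything in setup(), which is called from __init__ after the config and rank are stored, and optionally loads a checkpoint if config.weights is set. A hedged subclass sketch (the 'my-model' name, the in_dim key and the dict-style config are assumptions for illustration):

# Illustrative subclass sketch; subclasses override setup() rather than __init__.
import torch.nn as nn
import models
from models.base import BaseModel

@models.register('my-model')
class MyModel(BaseModel):
    def setup(self):
        # build submodules here; self.config and self.rank are already set
        self.layer = nn.Linear(self.config.get('in_dim', 3), 1)

    def update_step(self, epoch, global_step):
        pass  # e.g. anneal schedules per training step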
mesh_recon/models/geometry.py
ADDED
|
@@ -0,0 +1,238 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_lightning.utilities.rank_zero import rank_zero_info

import models
from models.base import BaseModel
from models.utils import scale_anything, get_activation, cleanup, chunk_batch
from models.network_utils import get_encoding, get_mlp, get_encoding_with_network
from utils.misc import get_rank
from systems.utils import update_module_step
from nerfacc import ContractionType


def contract_to_unisphere(x, radius, contraction_type):
    if contraction_type == ContractionType.AABB:
        x = scale_anything(x, (-radius, radius), (0, 1))
    elif contraction_type == ContractionType.UN_BOUNDED_SPHERE:
        x = scale_anything(x, (-radius, radius), (0, 1))
        x = x * 2 - 1  # aabb is at [-1, 1]
        mag = x.norm(dim=-1, keepdim=True)
        mask = mag.squeeze(-1) > 1
        x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
        x = x / 4 + 0.5  # [-inf, inf] is at [0, 1]
    else:
        raise NotImplementedError
    return x


class MarchingCubeHelper(nn.Module):
    def __init__(self, resolution, use_torch=True):
        super().__init__()
        self.resolution = resolution
        self.use_torch = use_torch
        self.points_range = (0, 1)
        if self.use_torch:
            import torchmcubes
            self.mc_func = torchmcubes.marching_cubes
        else:
            import mcubes
            self.mc_func = mcubes.marching_cubes
        self.verts = None

    def grid_vertices(self):
        if self.verts is None:
            x, y, z = torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(x, y, z, indexing='ij')
            verts = torch.cat([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1).reshape(-1, 3)
            self.verts = verts
        return self.verts

    def forward(self, level, threshold=0.):
        level = level.float().view(self.resolution, self.resolution, self.resolution)
        if self.use_torch:
            verts, faces = self.mc_func(level.to(get_rank()), threshold)
            verts, faces = verts.cpu(), faces.cpu().long()
        else:
            verts, faces = self.mc_func(-level.numpy(), threshold)  # transform to numpy
            verts, faces = torch.from_numpy(verts.astype(np.float32)), torch.from_numpy(faces.astype(np.int64))  # transform back to pytorch
        verts = verts / (self.resolution - 1.)
        return {
            'v_pos': verts,
            't_pos_idx': faces
        }


class BaseImplicitGeometry(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        if self.config.isosurface is not None:
            assert self.config.isosurface.method in ['mc', 'mc-torch']
            if self.config.isosurface.method == 'mc-torch':
                raise NotImplementedError("Please do not use mc-torch. It currently has some scaling issues I haven't fixed yet.")
            self.helper = MarchingCubeHelper(self.config.isosurface.resolution, use_torch=self.config.isosurface.method=='mc-torch')
        self.radius = self.config.radius
        self.contraction_type = None  # assigned in system

    def forward_level(self, points):
        raise NotImplementedError

    def isosurface_(self, vmin, vmax):
        def batch_func(x):
            x = torch.stack([
                scale_anything(x[...,0], (0, 1), (vmin[0], vmax[0])),
                scale_anything(x[...,1], (0, 1), (vmin[1], vmax[1])),
                scale_anything(x[...,2], (0, 1), (vmin[2], vmax[2])),
            ], dim=-1).to(self.rank)
            rv = self.forward_level(x).cpu()
            cleanup()
            return rv

        level = chunk_batch(batch_func, self.config.isosurface.chunk, True, self.helper.grid_vertices())
        mesh = self.helper(level, threshold=self.config.isosurface.threshold)
        mesh['v_pos'] = torch.stack([
            scale_anything(mesh['v_pos'][...,0], (0, 1), (vmin[0], vmax[0])),
            scale_anything(mesh['v_pos'][...,1], (0, 1), (vmin[1], vmax[1])),
            scale_anything(mesh['v_pos'][...,2], (0, 1), (vmin[2], vmax[2]))
        ], dim=-1)
        return mesh

    @torch.no_grad()
    def isosurface(self):
        if self.config.isosurface is None:
            raise NotImplementedError
        mesh_coarse = self.isosurface_((-self.radius, -self.radius, -self.radius), (self.radius, self.radius, self.radius))
        vmin, vmax = mesh_coarse['v_pos'].amin(dim=0), mesh_coarse['v_pos'].amax(dim=0)
        vmin_ = (vmin - (vmax - vmin) * 0.1).clamp(-self.radius, self.radius)
        vmax_ = (vmax + (vmax - vmin) * 0.1).clamp(-self.radius, self.radius)
        mesh_fine = self.isosurface_(vmin_, vmax_)
        return mesh_fine


@models.register('volume-density')
class VolumeDensity(BaseImplicitGeometry):
    def setup(self):
        self.n_input_dims = self.config.get('n_input_dims', 3)
        self.n_output_dims = self.config.feature_dim
        self.encoding_with_network = get_encoding_with_network(self.n_input_dims, self.n_output_dims, self.config.xyz_encoding_config, self.config.mlp_network_config)

    def forward(self, points):
        points = contract_to_unisphere(points, self.radius, self.contraction_type)
        out = self.encoding_with_network(points.view(-1, self.n_input_dims)).view(*points.shape[:-1], self.n_output_dims).float()
        density, feature = out[...,0], out
        if 'density_activation' in self.config:
            density = get_activation(self.config.density_activation)(density + float(self.config.density_bias))
        if 'feature_activation' in self.config:
            feature = get_activation(self.config.feature_activation)(feature)
        return density, feature

    def forward_level(self, points):
        points = contract_to_unisphere(points, self.radius, self.contraction_type)
        density = self.encoding_with_network(points.reshape(-1, self.n_input_dims)).reshape(*points.shape[:-1], self.n_output_dims)[...,0]
        if 'density_activation' in self.config:
            density = get_activation(self.config.density_activation)(density + float(self.config.density_bias))
        return -density

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding_with_network, epoch, global_step)


@models.register('volume-sdf')
class VolumeSDF(BaseImplicitGeometry):
    def setup(self):
        self.n_output_dims = self.config.feature_dim
        encoding = get_encoding(3, self.config.xyz_encoding_config)
        network = get_mlp(encoding.n_output_dims, self.n_output_dims, self.config.mlp_network_config)
        self.encoding, self.network = encoding, network
        self.grad_type = self.config.grad_type
        self.finite_difference_eps = self.config.get('finite_difference_eps', 1e-3)
        # the actual value used in training
        # will update at certain steps if finite_difference_eps="progressive"
        self._finite_difference_eps = None
        if self.grad_type == 'finite_difference':
            rank_zero_info(f"Using finite difference to compute gradients with eps={self.finite_difference_eps}")

    def forward(self, points, with_grad=True, with_feature=True, with_laplace=False):
        with torch.inference_mode(torch.is_inference_mode_enabled() and not (with_grad and self.grad_type == 'analytic')):
            with torch.set_grad_enabled(self.training or (with_grad and self.grad_type == 'analytic')):
                if with_grad and self.grad_type == 'analytic':
                    if not self.training:
                        points = points.clone()  # points may be in inference mode, get a copy to enable grad
                    points.requires_grad_(True)

                points_ = points  # points in the original scale
                points = contract_to_unisphere(points, self.radius, self.contraction_type)  # points normalized to (0, 1)

                out = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims).float()
                sdf, feature = out[...,0], out
                if 'sdf_activation' in self.config:
                    sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias))
                if 'feature_activation' in self.config:
                    feature = get_activation(self.config.feature_activation)(feature)
                if with_grad:
                    if self.grad_type == 'analytic':
                        grad = torch.autograd.grad(
                            sdf, points_, grad_outputs=torch.ones_like(sdf),
                            create_graph=True, retain_graph=True, only_inputs=True
                        )[0]
                    elif self.grad_type == 'finite_difference':
                        eps = self._finite_difference_eps
                        offsets = torch.as_tensor(
                            [
                                [eps, 0.0, 0.0],
                                [-eps, 0.0, 0.0],
                                [0.0, eps, 0.0],
                                [0.0, -eps, 0.0],
                                [0.0, 0.0, eps],
                                [0.0, 0.0, -eps],
                            ]
                        ).to(points_)
                        points_d_ = (points_[...,None,:] + offsets).clamp(-self.radius, self.radius)
                        points_d = scale_anything(points_d_, (-self.radius, self.radius), (0, 1))
                        points_d_sdf = self.network(self.encoding(points_d.view(-1, 3)))[...,0].view(*points.shape[:-1], 6).float()
                        grad = 0.5 * (points_d_sdf[..., 0::2] - points_d_sdf[..., 1::2]) / eps

                        if with_laplace:
                            laplace = (points_d_sdf[..., 0::2] + points_d_sdf[..., 1::2] - 2 * sdf[..., None]).sum(-1) / (eps ** 2)

        rv = [sdf]
        if with_grad:
            rv.append(grad)
        if with_feature:
            rv.append(feature)
        if with_laplace:
            assert self.config.grad_type == 'finite_difference', "Laplace computation is only supported with grad_type='finite_difference'"
            rv.append(laplace)
        rv = [v if self.training else v.detach() for v in rv]
        return rv[0] if len(rv) == 1 else rv

    def forward_level(self, points):
        points = contract_to_unisphere(points, self.radius, self.contraction_type)  # points normalized to (0, 1)
        sdf = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims)[...,0]
        if 'sdf_activation' in self.config:
            sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias))
        return sdf

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)
        update_module_step(self.network, epoch, global_step)
        if self.grad_type == 'finite_difference':
            if isinstance(self.finite_difference_eps, float):
                self._finite_difference_eps = self.finite_difference_eps
            elif self.finite_difference_eps == 'progressive':
                hg_conf = self.config.xyz_encoding_config
                assert hg_conf.otype == "ProgressiveBandHashGrid", "finite_difference_eps='progressive' only works with ProgressiveBandHashGrid"
                current_level = min(
                    hg_conf.start_level + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels
                )
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale**(current_level - 1)
                grid_size = 2 * self.config.radius / grid_res
                if grid_size != self._finite_difference_eps:
                    rank_zero_info(f"Update finite_difference_eps to {grid_size}")
                    self._finite_difference_eps = grid_size
            else:
                raise ValueError(f"Unknown finite_difference_eps={self.finite_difference_eps}")
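The six-tap stencil in VolumeSDF.forward approximates the SDF gradient by central differences along each axis. As a standalone sanity check (illustrative, not part of the commit), the same stencil applied to an analytic sphere SDF recovers the known gradient x / |x|:

# Illustrative check of the central-difference gradient used by VolumeSDF.
import torch

def sphere_sdf(p, r=0.5):
    return p.norm(dim=-1) - r

eps = 1e-3
offsets = torch.tensor([
    [eps, 0.0, 0.0], [-eps, 0.0, 0.0],
    [0.0, eps, 0.0], [0.0, -eps, 0.0],
    [0.0, 0.0, eps], [0.0, 0.0, -eps],
])
p = torch.tensor([[0.3, -0.2, 0.1]])
sdf_taps = sphere_sdf(p[:, None, :] + offsets)               # [N, 6] taps at +-x, +-y, +-z
grad = 0.5 * (sdf_taps[:, 0::2] - sdf_taps[:, 1::2]) / eps   # (f(+) - f(-)) / (2 eps)
print(grad)
print(p / p.norm(dim=-1, keepdim=True))                      # matches up to O(eps^2)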
mesh_recon/models/nerf.py
ADDED
|
@@ -0,0 +1,161 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models.base import BaseModel
from models.utils import chunk_batch
from systems.utils import update_module_step
from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays


@models.register('nerf')
class NeRFModel(BaseModel):
    def setup(self):
        self.geometry = models.make(self.config.geometry.name, self.config.geometry)
        self.texture = models.make(self.config.texture.name, self.config.texture)
        self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32))

        if self.config.learned_background:
            self.occupancy_grid_res = 256
            self.near_plane, self.far_plane = 0.2, 1e4
            self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1.  # approximate
            self.render_step_size = 0.01  # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size)
            self.contraction_type = ContractionType.UN_BOUNDED_SPHERE
        else:
            self.occupancy_grid_res = 128
            self.near_plane, self.far_plane = None, None
            self.cone_angle = 0.0
            self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray
            self.contraction_type = ContractionType.AABB

        self.geometry.contraction_type = self.contraction_type

        if self.config.grid_prune:
            self.occupancy_grid = OccupancyGrid(
                roi_aabb=self.scene_aabb,
                resolution=self.occupancy_grid_res,
                contraction_type=self.contraction_type
            )
        self.randomized = self.config.randomized
        self.background_color = None

    def update_step(self, epoch, global_step):
        update_module_step(self.geometry, epoch, global_step)
        update_module_step(self.texture, epoch, global_step)

        def occ_eval_fn(x):
            density, _ = self.geometry(x)
            # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series
            return density[...,None] * self.render_step_size

        if self.training and self.config.grid_prune:
            self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn)

    def isosurface(self):
        mesh = self.geometry.isosurface()
        return mesh

    def forward_(self, rays):
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)

        def sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
            density, _ = self.geometry(positions)
            return density[...,None]

        def rgb_sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
            density, feature = self.geometry(positions)
            rgb = self.texture(feature, t_dirs)
            return rgb, density[...,None]

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o, rays_d,
                scene_aabb=None if self.config.learned_background else self.scene_aabb,
                grid=self.occupancy_grid if self.config.grid_prune else None,
                sigma_fn=sigma_fn,
                near_plane=self.near_plane, far_plane=self.far_plane,
                render_step_size=self.render_step_size,
                stratified=self.randomized,
                cone_angle=self.cone_angle,
                alpha_thre=0.0
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.
        positions = t_origins + t_dirs * midpoints
        intervals = t_ends - t_starts

        density, feature = self.geometry(positions)
        rgb = self.texture(feature, t_dirs)

        weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays)
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays)
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)

        out = {
            'comp_rgb': comp_rgb,
            'opacity': opacity,
            'depth': depth,
            'rays_valid': opacity > 0,
            'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device)
        }

        if self.training:
            out.update({
                'weights': weights.view(-1),
                'points': midpoints.view(-1),
                'intervals': intervals.view(-1),
                'ray_indices': ray_indices.view(-1)
            })

        return out

    def forward(self, rays):
        if self.training:
            out = self.forward_(rays)
        else:
            out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
        return {
            **out,
        }

    def train(self, mode=True):
        self.randomized = mode and self.config.randomized
        return super().train(mode=mode)

    def eval(self):
        self.randomized = False
        return super().eval()

    def regularizations(self, out):
        losses = {}
        losses.update(self.geometry.regularizations(out))
        losses.update(self.texture.regularizations(out))
        return losses

    @torch.no_grad()
    def export(self, export_config):
        mesh = self.isosurface()
        if export_config.export_vertex_color:
            _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank))
            viewdirs = torch.zeros(feature.shape[0], 3).to(feature)
            viewdirs[...,2] = -1.  # set the viewing directions to be -z (looking down)
            rgb = self.texture(feature, viewdirs).clamp(0,1)
            mesh['v_rgb'] = rgb.cpu()
        return mesh
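The nerfacc calls in forward_ (render_weight_from_density followed by accumulate_along_rays) implement the standard volume-rendering quadrature w_i = T_i * (1 - exp(-sigma_i * delta_i)) with transmittance T_i = exp(-sum over j < i of sigma_j * delta_j). A minimal torch-only sketch for a single ray (illustrative, not the nerfacc implementation itself):

# Illustrative single-ray quadrature behind the nerfacc helpers used above.
import torch

sigma = torch.tensor([0.0, 0.5, 2.0, 4.0])   # densities at the samples
delta = torch.full_like(sigma, 0.1)          # interval lengths
alpha = 1.0 - torch.exp(-sigma * delta)
trans = torch.cumprod(torch.cat([torch.ones(1), 1.0 - alpha[:-1]]), dim=0)
weights = trans * alpha

rgb_samples = torch.rand(4, 3)               # per-sample colors
opacity = weights.sum()
comp_rgb = (weights[:, None] * rgb_samples).sum(dim=0)
comp_rgb = comp_rgb + 1.0 * (1.0 - opacity)  # composite over a white background
print(weights, opacity, comp_rgb)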
mesh_recon/models/network_utils.py
ADDED
|
@@ -0,0 +1,215 @@
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import tinycudann as tcnn
|
| 7 |
+
|
| 8 |
+
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info
|
| 9 |
+
|
| 10 |
+
from utils.misc import config_to_primitive, get_rank
|
| 11 |
+
from models.utils import get_activation
|
| 12 |
+
from systems.utils import update_module_step
|
| 13 |
+
|
| 14 |
+
class VanillaFrequency(nn.Module):
|
| 15 |
+
def __init__(self, in_channels, config):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.N_freqs = config['n_frequencies']
|
| 18 |
+
self.in_channels, self.n_input_dims = in_channels, in_channels
|
| 19 |
+
self.funcs = [torch.sin, torch.cos]
|
| 20 |
+
self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs)
|
| 21 |
+
self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs)
|
| 22 |
+
self.n_masking_step = config.get('n_masking_step', 0)
|
| 23 |
+
self.update_step(None, None) # mask should be updated at the beginning each step
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
out = []
|
| 27 |
+
for freq, mask in zip(self.freq_bands, self.mask):
|
| 28 |
+
for func in self.funcs:
|
| 29 |
+
out += [func(freq*x) * mask]
|
| 30 |
+
return torch.cat(out, -1)
|
| 31 |
+
|
| 32 |
+
def update_step(self, epoch, global_step):
|
| 33 |
+
if self.n_masking_step <= 0 or global_step is None:
|
| 34 |
+
self.mask = torch.ones(self.N_freqs, dtype=torch.float32)
|
| 35 |
+
else:
|
| 36 |
+
self.mask = (1. - torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2.
|
| 37 |
+
rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}')
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ProgressiveBandHashGrid(nn.Module):
|
| 41 |
+
def __init__(self, in_channels, config):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.n_input_dims = in_channels
|
| 44 |
+
encoding_config = config.copy()
|
| 45 |
+
encoding_config['otype'] = 'HashGrid'
|
| 46 |
+
with torch.cuda.device(get_rank()):
|
| 47 |
+
self.encoding = tcnn.Encoding(in_channels, encoding_config)
|
| 48 |
+
self.n_output_dims = self.encoding.n_output_dims
|
| 49 |
+
self.n_level = config['n_levels']
|
| 50 |
+
self.n_features_per_level = config['n_features_per_level']
|
| 51 |
+
self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps']
|
| 52 |
+
self.current_level = self.start_level
|
| 53 |
+
self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank())
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
enc = self.encoding(x)
|
| 57 |
+
enc = enc * self.mask
|
| 58 |
+
return enc
|
| 59 |
+
|
| 60 |
+
def update_step(self, epoch, global_step):
|
| 61 |
+
current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level)
|
| 62 |
+
if current_level > self.current_level:
|
| 63 |
+
rank_zero_info(f'Update grid level to {current_level}')
|
| 64 |
+
self.current_level = current_level
|
| 65 |
+
self.mask[:self.current_level * self.n_features_per_level] = 1.
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class CompositeEncoding(nn.Module):
|
| 69 |
+
def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.):
|
| 70 |
+
super(CompositeEncoding, self).__init__()
|
| 71 |
+
self.encoding = encoding
|
| 72 |
+
self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset
|
| 73 |
+
self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims
|
| 74 |
+
|
| 75 |
+
def forward(self, x, *args):
|
| 76 |
+
return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1)
|
| 77 |
+
|
| 78 |
+
def update_step(self, epoch, global_step):
|
| 79 |
+
update_module_step(self.encoding, epoch, global_step)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_encoding(n_input_dims, config):
|
| 83 |
+
# input suppose to be range [0, 1]
|
| 84 |
+
if config.otype == 'VanillaFrequency':
|
| 85 |
+
encoding = VanillaFrequency(n_input_dims, config_to_primitive(config))
|
| 86 |
+
elif config.otype == 'ProgressiveBandHashGrid':
|
| 87 |
+
encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config))
|
| 88 |
+
else:
|
| 89 |
+
with torch.cuda.device(get_rank()):
|
| 90 |
+
encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config))
|
| 91 |
+
encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.)
|
| 92 |
+
return encoding
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class VanillaMLP(nn.Module):
|
| 96 |
+
def __init__(self, dim_in, dim_out, config):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers']
|
| 99 |
+
self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False)
|
| 100 |
+
self.sphere_init_radius = config.get('sphere_init_radius', 0.5)
|
| 101 |
+
self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
|
| 102 |
+
for i in range(self.n_hidden_layers - 1):
|
| 103 |
+
self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
|
| 104 |
+
self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
|
| 105 |
+
self.layers = nn.Sequential(*self.layers)
|
| 106 |
+
self.output_activation = get_activation(config['output_activation'])
|
| 107 |
+
|
| 108 |
+
@torch.cuda.amp.autocast(False)
|
| 109 |
+
def forward(self, x):
|
| 110 |
+
x = self.layers(x.float())
|
| 111 |
+
x = self.output_activation(x)
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
def make_linear(self, dim_in, dim_out, is_first, is_last):
|
| 115 |
+
layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
|
| 116 |
+
if self.sphere_init:
|
| 117 |
+
if is_last:
|
| 118 |
+
torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
|
| 119 |
+
torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
|
| 120 |
+
elif is_first:
|
| 121 |
+
torch.nn.init.constant_(layer.bias, 0.0)
|
| 122 |
+
torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
|
| 123 |
+
torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
|
| 124 |
+
else:
|
| 125 |
+
torch.nn.init.constant_(layer.bias, 0.0)
|
| 126 |
+
torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
|
| 127 |
+
else:
|
| 128 |
+
torch.nn.init.constant_(layer.bias, 0.0)
|
| 129 |
+
torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
|
| 130 |
+
|
| 131 |
+
if self.weight_norm:
|
| 132 |
+
layer = nn.utils.weight_norm(layer)
|
| 133 |
+
return layer
|
| 134 |
+
|
| 135 |
+
def make_activation(self):
|
| 136 |
+
if self.sphere_init:
|
| 137 |
+
return nn.Softplus(beta=100)
|
| 138 |
+
else:
|
| 139 |
+
return nn.ReLU(inplace=True)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network):
    rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.')
    """
    from https://github.com/NVlabs/tiny-cuda-nn/issues/96
    The flat parameter buffer is the weight matrices of each layer laid out in row-major order and then concatenated.
    Notably, input and output dimensions are padded to multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP).
    The padded input dimensions get a constant value of 1.0,
    whereas the padded output dimensions are simply ignored,
    so the weights pertaining to those can have any value.
    """
    padto = 16 if config.otype == 'FullyFusedMLP' else 8
    n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto
    n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto
    data = list(network.parameters())[0].data
    assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2
    new_data = []
    # first layer
    weight = torch.zeros((config.n_neurons, n_input_dims)).to(data)
    torch.nn.init.constant_(weight[:, 3:], 0.0)
    torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
    new_data.append(weight.flatten())
    # hidden layers
    for i in range(config.n_hidden_layers - 1):
        weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data)
        torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
        new_data.append(weight.flatten())
    # last layer
    weight = torch.zeros((n_output_dims, config.n_neurons)).to(data)
    torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001)
    new_data.append(weight.flatten())
    new_data = torch.cat(new_data)
    data.copy_(new_data)


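# Illustrative note (added for this write-up, not part of the upstream file): the
# flat tcnn parameter buffer overwritten above holds one row-major block per
# layer, with input/output dimensions padded as described in the note inside
# sphere_init_tcnn_network. This hypothetical helper recomputes the expected
# buffer length that the assert above checks.
def _example_expected_tcnn_param_count(n_input_dims, n_output_dims, n_neurons, n_hidden_layers, otype='FullyFusedMLP'):
    padto = 16 if otype == 'FullyFusedMLP' else 8
    padded_in = n_input_dims + (padto - n_input_dims % padto) % padto
    padded_out = n_output_dims + (padto - n_output_dims % padto) % padto
    # first layer (padded_in x n_neurons) + last layer (n_neurons x padded_out)
    # + (n_hidden_layers - 1) square hidden layers
    return (padded_in + padded_out) * n_neurons + (n_hidden_layers - 1) * n_neurons ** 2

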
def get_mlp(n_input_dims, n_output_dims, config):
    if config.otype == 'VanillaMLP':
        network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config))
    else:
        with torch.cuda.device(get_rank()):
            network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config))
            if config.get('sphere_init', False):
                sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network)
    return network


class EncodingWithNetwork(nn.Module):
    def __init__(self, encoding, network):
        super().__init__()
        self.encoding, self.network = encoding, network

    def forward(self, x):
        return self.network(self.encoding(x))

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)
        update_module_step(self.network, epoch, global_step)


def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config):
    # the input is assumed to be in the range [0, 1]
    if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \
        or network_config.otype in ['VanillaMLP']:
        encoding = get_encoding(n_input_dims, encoding_config)
        network = get_mlp(encoding.n_output_dims, n_output_dims, network_config)
        encoding_with_network = EncodingWithNetwork(encoding, network)
    else:
        with torch.cuda.device(get_rank()):
            encoding_with_network = tcnn.NetworkWithInputEncoding(
                n_input_dims=n_input_dims,
                n_output_dims=n_output_dims,
                encoding_config=config_to_primitive(encoding_config),
                network_config=config_to_primitive(network_config)
            )
    return encoding_with_network
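
# Illustrative sketch (added for this write-up, not part of the upstream file):
# how the pieces above combine. With a VanillaMLP network config, the encoding
# and the MLP are built separately and chained by EncodingWithNetwork; otherwise
# a fused tcnn.NetworkWithInputEncoding is created. All hyper-parameter values in
# this hypothetical helper are placeholders, and a CUDA device is assumed.
def _example_encoding_with_network():
    from omegaconf import OmegaConf  # configs in this repo are assumed to be OmegaConf nodes
    encoding_config = OmegaConf.create({
        'otype': 'HashGrid', 'n_levels': 16, 'n_features_per_level': 2,
        'log2_hashmap_size': 19, 'base_resolution': 16, 'per_level_scale': 1.447,
    })
    network_config = OmegaConf.create({
        'otype': 'VanillaMLP', 'n_neurons': 64, 'n_hidden_layers': 2,
        'sphere_init': False, 'weight_norm': False, 'output_activation': 'none',
    })
    model = get_encoding_with_network(3, 4, encoding_config, network_config).cuda()
    x = torch.rand(1024, 3, device='cuda')  # inputs are expected to lie in [0, 1]
    return model(x)  # -> (1024, 4) features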