diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..350e9f60924dcd28d569334bae41b5e6964d93cb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo.png filter=lfs diff=lfs merge=lfs -text +data/demo/jellycat/001.jpg filter=lfs diff=lfs merge=lfs -text +data/demo/jellycat/002.jpg filter=lfs diff=lfs merge=lfs -text +data/demo/jellycat/003.jpg filter=lfs diff=lfs merge=lfs -text +data/demo/jellycat/004.jpg filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/001.png filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/002.png filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/003.png filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/004.png filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/005.png filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/006.png filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/007.png filter=lfs diff=lfs merge=lfs -text +data/demo/jordan/008.png filter=lfs diff=lfs merge=lfs -text +data/demo/kew_gardens_ruined_arch/001.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kew_gardens_ruined_arch/002.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kew_gardens_ruined_arch/003.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kotor_cathedral/001.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kotor_cathedral/002.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kotor_cathedral/003.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kotor_cathedral/004.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kotor_cathedral/005.jpeg filter=lfs diff=lfs merge=lfs -text +data/demo/kotor_cathedral/006.jpeg filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c08c032512a4c09f39b35f0c2dc17c550999bc51 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Qitao Zhao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/assets/demo.png b/assets/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..3f400242b56a250d0acd53dadae4027fa55df63e --- /dev/null +++ b/assets/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5021efbf6bf2ad1de68447a1e9d313581422b79c4f460ccb94654d5c08bb83c +size 1253593 diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1972c2ff8ee688c4d04b3b3563f057fffde72570 --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,83 @@ +training: + resume: False # If True, must set hydra.run.dir accordingly + pretrain_path: "" + interval_visualize: 1000 + interval_save_checkpoint: 5000 + interval_delete_checkpoint: 10000 + interval_evaluate: 5000 + delete_all_checkpoints_after_training: False + lr: 1e-4 + mixed_precision: True + matmul_precision: high + max_iterations: 100000 + batch_size: 64 + num_workers: 8 + gpu_id: 0 + freeze_encoder: True + seed: 0 + job_key: "" # Use this for submitit sweeps where timestamps might collide + translation_scale: 1.0 + regression: False + prob_unconditional: 0 + load_extra_cameras: False + calculate_intrinsics: False + distort: False + normalize_first_camera: True + diffuse_origins_and_endpoints: True + diffuse_depths: False + depth_resolution: 1 + dpt_head: False + full_num_patches_x: 16 + full_num_patches_y: 16 + dpt_encoder_features: True + nearest_neighbor: True + no_bg_targets: True + unit_normalize_scene: False + sd_scale: 2 + bfloat: True + first_cam_mediod: True + gradient_clipping: False + l1_loss: False + grad_accumulation: False + reinit: False + +model: + pred_x0: True + model_type: dit + num_patches_x: 16 + num_patches_y: 16 + depth: 16 + num_images: 1 + random_num_images: True + feature_extractor: dino + append_ndc: True + within_image: False + use_homogeneous: True + freeze_transformer: False + cond_depth_mask: True + +noise_scheduler: + type: linear + max_timesteps: 100 + beta_start: 0.0120 + beta_end: 0.00085 + marigold_ddim: False + +dataset: + name: co3d + shape: all_train + apply_augmentation: True + use_global_intrinsics: True + mask_holes: True + image_size: 224 + +debug: + wandb: True + project_name: diffusionsfm + run_name: + anomaly_detection: False + +hydra: + run: + dir: ./output/${now:%m%d_%H%M%S_%f}${training.job_key} + output_subdir: hydra diff --git a/conf/diffusion.yml b/conf/diffusion.yml new file mode 100644 index 0000000000000000000000000000000000000000..8609191263b129d977f11730611c3b23d88f3d9f --- /dev/null +++ b/conf/diffusion.yml @@ -0,0 +1,110 @@ +name: diffusion +channels: +- conda-forge +- iopath +- nvidia +- pkgs/main +- pytorch +- xformers +dependencies: +- _libgcc_mutex=0.1=conda_forge +- _openmp_mutex=4.5=2_gnu +- blas=1.0=mkl +- brotli-python=1.0.9=py39h5a03fae_9 +- bzip2=1.0.8=h7f98852_4 +- ca-certificates=2023.7.22=hbcca054_0 +- certifi=2023.7.22=pyhd8ed1ab_0 +- charset-normalizer=3.2.0=pyhd8ed1ab_0 +- colorama=0.4.6=pyhd8ed1ab_0 +- cuda-cudart=11.7.99=0 +- cuda-cupti=11.7.101=0 +- cuda-libraries=11.7.1=0 +- cuda-nvrtc=11.7.99=0 +- cuda-nvtx=11.7.91=0 +- cuda-runtime=11.7.1=0 +- ffmpeg=4.3=hf484d3e_0 +- filelock=3.12.2=pyhd8ed1ab_0 +- freetype=2.12.1=hca18f0e_1 +- fvcore=0.1.5.post20221221=pyhd8ed1ab_0 +- gmp=6.2.1=h58526e2_0 +- gmpy2=2.1.2=py39h376b7d2_1 +- gnutls=3.6.13=h85f3911_1 +- idna=3.4=pyhd8ed1ab_0 +- intel-openmp=2022.1.0=h9e868ea_3769 +- iopath=0.1.9=py39 +- jinja2=3.1.2=pyhd8ed1ab_1 +- jpeg=9e=h0b41bf4_3 +- lame=3.100=h166bdaf_1003 +- lcms2=2.15=hfd0df8a_0 +- 
ld_impl_linux-64=2.40=h41732ed_0 +- lerc=4.0.0=h27087fc_0 +- libblas=3.9.0=16_linux64_mkl +- libcblas=3.9.0=16_linux64_mkl +- libcublas=11.10.3.66=0 +- libcufft=10.7.2.124=h4fbf590_0 +- libcufile=1.7.1.12=0 +- libcurand=10.3.3.129=0 +- libcusolver=11.4.0.1=0 +- libcusparse=11.7.4.91=0 +- libdeflate=1.17=h0b41bf4_0 +- libffi=3.3=h58526e2_2 +- libgcc-ng=13.1.0=he5830b7_0 +- libgomp=13.1.0=he5830b7_0 +- libiconv=1.17=h166bdaf_0 +- liblapack=3.9.0=16_linux64_mkl +- libnpp=11.7.4.75=0 +- libnvjpeg=11.8.0.2=0 +- libpng=1.6.39=h753d276_0 +- libsqlite=3.42.0=h2797004_0 +- libstdcxx-ng=13.1.0=hfd8a6a1_0 +- libtiff=4.5.0=h6adf6a1_2 +- libwebp-base=1.3.1=hd590300_0 +- libxcb=1.13=h7f98852_1004 +- libzlib=1.2.13=hd590300_5 +- markupsafe=2.1.3=py39hd1e30aa_0 +- mkl=2022.1.0=hc2b9512_224 +- mpc=1.3.1=hfe3b2da_0 +- mpfr=4.2.0=hb012696_0 +- mpmath=1.3.0=pyhd8ed1ab_0 +- ncurses=6.4=hcb278e6_0 +- nettle=3.6=he412f7d_0 +- networkx=3.1=pyhd8ed1ab_0 +- numpy=1.25.2=py39h6183b62_0 +- openh264=2.1.1=h780b84a_0 +- openjpeg=2.5.0=hfec8fc6_2 +- openssl=1.1.1v=hd590300_0 +- pillow=9.4.0=py39h2320bf1_1 +- pip=23.2.1=pyhd8ed1ab_0 +- portalocker=2.7.0=py39hf3d152e_0 +- pthread-stubs=0.4=h36c2ea0_1001 +- pysocks=1.7.1=pyha2e5f31_6 +- python=3.9.0=hffdb5ce_5_cpython +- python_abi=3.9=3_cp39 +- pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0 +- pytorch-cuda=11.7=h778d358_5 +- pytorch-mutex=1.0=cuda +- pyyaml=6.0=py39hb9d737c_5 +- readline=8.2=h8228510_1 +- requests=2.31.0=pyhd8ed1ab_0 +- setuptools=68.0.0=pyhd8ed1ab_0 +- sqlite=3.42.0=h2c6b66d_0 +- sympy=1.12=pypyh9d50eac_103 +- tabulate=0.9.0=pyhd8ed1ab_1 +- termcolor=2.3.0=pyhd8ed1ab_0 +- tk=8.6.12=h27826a3_0 +- torchaudio=2.0.2=py39_cu117 +- torchtriton=2.0.0=py39 +- torchvision=0.15.2=py39_cu117 +- tqdm=4.66.1=pyhd8ed1ab_0 +- typing_extensions=4.7.1=pyha770c72_0 +- tzdata=2023c=h71feb2d_0 +- urllib3=2.0.4=pyhd8ed1ab_0 +- wheel=0.41.1=pyhd8ed1ab_0 +- xformers=0.0.21=py39_cu11.8.0_pyt2.0.1 +- xorg-libxau=1.0.11=hd590300_0 +- xorg-libxdmcp=1.1.3=h7f98852_0 +- xz=5.2.6=h166bdaf_0 +- yacs=0.1.8=pyhd8ed1ab_0 +- yaml=0.2.5=h7f98852_2 +- zlib=1.2.13=hd590300_5 +- zstd=1.5.2=hfc55251_7 diff --git a/data/demo/jellycat/001.jpg b/data/demo/jellycat/001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4a250f2aac6ba3bd0d6decdd5a78d5f029d80a5a --- /dev/null +++ b/data/demo/jellycat/001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb252fabcd6588b924266098efbf0538c4cc77f6fd623166a3328692ea04b221 +size 6913058 diff --git a/data/demo/jellycat/002.jpg b/data/demo/jellycat/002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..74320cb106c27c85e701204be3be4734401a130d --- /dev/null +++ b/data/demo/jellycat/002.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e36092e3ef63d3de0d9645ce001829c248b2c5ae78011c3578276d4f0009ce6 +size 6857518 diff --git a/data/demo/jellycat/003.jpg b/data/demo/jellycat/003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..70f1037e75eb12aaf96045f37395ae25fde76056 --- /dev/null +++ b/data/demo/jellycat/003.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fcd32e046e04b809529c202f594a49a210d1fcd38a4664ca619704dd317b550 +size 169160 diff --git a/data/demo/jellycat/004.jpg b/data/demo/jellycat/004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c9dc3d42fa64beed041b090a211d778a89fef737 --- /dev/null +++ b/data/demo/jellycat/004.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:082007c67949ce96af89d34fbb3dd8a6eeca4d000e4dc39a920215881ee5a4e1 +size 118663 diff --git a/data/demo/jordan/001.png b/data/demo/jordan/001.png new file mode 100644 index 0000000000000000000000000000000000000000..b2303d2eeca0b5808feae6fc6b31ab22b120d0cc --- /dev/null +++ b/data/demo/jordan/001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dff6883afa87339f94ac9d2b07a61e01f9107f2e37a7bda326956f209a5b1c61 +size 128454 diff --git a/data/demo/jordan/002.png b/data/demo/jordan/002.png new file mode 100644 index 0000000000000000000000000000000000000000..3cd39ca383ad17ea342be2455c212504dd1d5087 --- /dev/null +++ b/data/demo/jordan/002.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bee5060ab4b105fd9383398ab47dc1caa2dd329f9e83ae310bb068870d445270 +size 125808 diff --git a/data/demo/jordan/003.png b/data/demo/jordan/003.png new file mode 100644 index 0000000000000000000000000000000000000000..c59163b39e6cd6208e71faa04a0bf99a960b8c4d --- /dev/null +++ b/data/demo/jordan/003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34353a07643bb0f6dcc8d1a40d1658e393998af73db0cd17e392558581840c03 +size 135412 diff --git a/data/demo/jordan/004.png b/data/demo/jordan/004.png new file mode 100644 index 0000000000000000000000000000000000000000..8e5947503cfbaacb3664588ff8082058b28dc862 --- /dev/null +++ b/data/demo/jordan/004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c671d0fb4ff49d59e6b044e8a673ad6e9293337423f6db1345c3e1c45f0c7427 +size 123310 diff --git a/data/demo/jordan/005.png b/data/demo/jordan/005.png new file mode 100644 index 0000000000000000000000000000000000000000..16667b9a510bb018d8a9f063c04c1deb0fdc9037 --- /dev/null +++ b/data/demo/jordan/005.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aefe2f7ca57407dad2a7ce86759b86698d536d5f6c7fd9f97776d3905bb3ce19 +size 114628 diff --git a/data/demo/jordan/006.png b/data/demo/jordan/006.png new file mode 100644 index 0000000000000000000000000000000000000000..650c4dbda8e727aea445195b95d5a450a47ffdb5 --- /dev/null +++ b/data/demo/jordan/006.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5f23767c8a3830921e1c5299cc8373daacf9af68e96c596df5b8edbdc0f4836 +size 145214 diff --git a/data/demo/jordan/007.png b/data/demo/jordan/007.png new file mode 100644 index 0000000000000000000000000000000000000000..47124e5f8df22526d9eb6345aadc49c85bae9766 --- /dev/null +++ b/data/demo/jordan/007.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c61521925f93ec2721a02c4c5b4898171985147c763702cb5f9a7efbb341cc2d +size 145108 diff --git a/data/demo/jordan/008.png b/data/demo/jordan/008.png new file mode 100644 index 0000000000000000000000000000000000000000..cc6e034320a78c886e31fbe78111843aa839002e --- /dev/null +++ b/data/demo/jordan/008.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb1e7c3d5fc1ad0067d2d28d9f83e5d5243010353f0ec0fd12f196cf5939f231 +size 142700 diff --git a/data/demo/kew_gardens_ruined_arch/001.jpeg b/data/demo/kew_gardens_ruined_arch/001.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..6d70e4c236c6f1abb20a4e667e10638cf02fdecf --- /dev/null +++ b/data/demo/kew_gardens_ruined_arch/001.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96dfde51e8d0857120387e3b81fff665b9b94524a6f1a9f35246faed4f3e8986 +size 624410 diff --git a/data/demo/kew_gardens_ruined_arch/002.jpeg 
b/data/demo/kew_gardens_ruined_arch/002.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..1f0c249b66cc8f6c62f926182b9b47cf86900823 --- /dev/null +++ b/data/demo/kew_gardens_ruined_arch/002.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c2c07e43d51594fbbea708fce0040fbe6b5ecd4e01c8b10898a2f71f3abf186 +size 589599 diff --git a/data/demo/kew_gardens_ruined_arch/003.jpeg b/data/demo/kew_gardens_ruined_arch/003.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..2812562af772fdf499aa353cda8ede5aa67ce440 --- /dev/null +++ b/data/demo/kew_gardens_ruined_arch/003.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bfeea8fcb46fcbb0d77450227927851472a73a714c62de21e75b4d60a3dba317 +size 586250 diff --git a/data/demo/kotor_cathedral/001.jpeg b/data/demo/kotor_cathedral/001.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..b62a46fe151cafbcb08140d14629ba3d95f17567 --- /dev/null +++ b/data/demo/kotor_cathedral/001.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:732a5d344ddcfc2e50a97abc3792bb444cf40b82a93638f2d52d955f6595c90a +size 617086 diff --git a/data/demo/kotor_cathedral/002.jpeg b/data/demo/kotor_cathedral/002.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..65e449a78fd2c287a5fbb3944dc695f7214e1f4a --- /dev/null +++ b/data/demo/kotor_cathedral/002.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2dc18fde3559ae7333351ded6d765b217a5b754fd828087c5f88cbc11c84793 +size 759782 diff --git a/data/demo/kotor_cathedral/003.jpeg b/data/demo/kotor_cathedral/003.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..3b4683862a3d2347f7908b3a236c09133e4a7d82 --- /dev/null +++ b/data/demo/kotor_cathedral/003.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b100bdbc2d9943151424b86e504ece203f15ff5d616dd4c515fabc7b3d39d11c +size 697022 diff --git a/data/demo/kotor_cathedral/004.jpeg b/data/demo/kotor_cathedral/004.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..4ed6b74203ef2587bc5b7a9cf372e66631623a3d --- /dev/null +++ b/data/demo/kotor_cathedral/004.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26d6433fa500c03bd982a3abaf2f8028d26a9726e409b880192de5eea17d83b5 +size 582854 diff --git a/data/demo/kotor_cathedral/005.jpeg b/data/demo/kotor_cathedral/005.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..3ac5b846a306839c7a39b1782ae0fbf4ce4f37dd --- /dev/null +++ b/data/demo/kotor_cathedral/005.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15b4e342917ae6df3a43a82aba0fb199098fd5ff9946303109b5e21a613a6d30 +size 901648 diff --git a/data/demo/kotor_cathedral/006.jpeg b/data/demo/kotor_cathedral/006.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..e13a4e2ff1ddfefe42fc2ee66da105e67dc9f73f --- /dev/null +++ b/data/demo/kotor_cathedral/006.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da1098cc0360bbc34eb1a61a224ccdaa43677c4e6b26687c8f9bb95fbf7a2f42 +size 411364 diff --git a/diffusionsfm/__init__.py b/diffusionsfm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5d822a1517fb8d1f195741fd743b1d2c917cf8 --- /dev/null +++ b/diffusionsfm/__init__.py @@ -0,0 +1 @@ +from .utils.rays import cameras_to_rays, rays_to_cameras, Rays diff --git a/diffusionsfm/dataset/__init__.py 
b/diffusionsfm/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusionsfm/dataset/co3d_v2.py b/diffusionsfm/dataset/co3d_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..0385059207991d68ba355f5ad256eff437f31438 --- /dev/null +++ b/diffusionsfm/dataset/co3d_v2.py @@ -0,0 +1,792 @@ +import gzip +import json +import os.path as osp +import random +import socket +import time +import torch +import warnings + +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm +from pytorch3d.renderer import PerspectiveCameras +from torch.utils.data import Dataset +from torchvision import transforms +import matplotlib.pyplot as plt +from scipy import ndimage as nd + +from diffusionsfm.utils.distortion import distort_image + + +HOSTNAME = socket.gethostname() + +CO3D_DIR = "../co3d_data" # update this +CO3D_ANNOTATION_DIR = osp.join(CO3D_DIR, "co3d_annotations") +CO3D_DIR = CO3D_DEPTH_DIR = osp.join(CO3D_DIR, "co3d") +order_path = osp.join( + CO3D_DIR, "co3d_v2_random_order_{sample_num}/{category}.json" +) + + +TRAINING_CATEGORIES = [ + "apple", + "backpack", + "banana", + "baseballbat", + "baseballglove", + "bench", + "bicycle", + "bottle", + "bowl", + "broccoli", + "cake", + "car", + "carrot", + "cellphone", + "chair", + "cup", + "donut", + "hairdryer", + "handbag", + "hydrant", + "keyboard", + "laptop", + "microwave", + "motorcycle", + "mouse", + "orange", + "parkingmeter", + "pizza", + "plant", + "stopsign", + "teddybear", + "toaster", + "toilet", + "toybus", + "toyplane", + "toytrain", + "toytruck", + "tv", + "umbrella", + "vase", + "wineglass", +] + +TEST_CATEGORIES = [ + "ball", + "book", + "couch", + "frisbee", + "hotdog", + "kite", + "remote", + "sandwich", + "skateboard", + "suitcase", +] + +assert len(TRAINING_CATEGORIES) + len(TEST_CATEGORIES) == 51 + +Image.MAX_IMAGE_PIXELS = None +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def fill_depths(data, invalid=None): + data_list = [] + for i in range(data.shape[0]): + data_item = data[i].numpy() + # Invalid must be 1 where stuff is invalid, 0 where valid + ind = nd.distance_transform_edt( + invalid[i], return_distances=False, return_indices=True + ) + data_list.append(torch.tensor(data_item[tuple(ind)])) + return torch.stack(data_list, dim=0) + + +def full_scene_scale(batch): + cameras = PerspectiveCameras(R=batch["R"], T=batch["T"], device="cuda") + cc = cameras.get_camera_center() + centroid = torch.mean(cc, dim=0) + + diffs = cc - centroid + norms = torch.linalg.norm(diffs, dim=1) + + furthest_index = torch.argmax(norms).item() + scale = norms[furthest_index].item() + return scale + + +def square_bbox(bbox, padding=0.0, astype=None, tight=False): + """ + Computes a square bounding box, with optional padding parameters. + Args: + bbox: Bounding box in xyxy format (4,). + Returns: + square_bbox in xyxy format (4,). 
+ """ + if astype is None: + astype = type(bbox[0]) + bbox = np.array(bbox) + center = (bbox[:2] + bbox[2:]) / 2 + extents = (bbox[2:] - bbox[:2]) / 2 + + # No black bars if tight + if tight: + s = min(extents) * (1 + padding) + else: + s = max(extents) * (1 + padding) + + square_bbox = np.array( + [center[0] - s, center[1] - s, center[0] + s, center[1] + s], + dtype=astype, + ) + return square_bbox + + +def unnormalize_image(image, return_numpy=True, return_int=True): + if isinstance(image, torch.Tensor): + image = image.detach().cpu().numpy() + + if image.ndim == 3: + if image.shape[0] == 3: + image = image[None, ...] + elif image.shape[2] == 3: + image = image.transpose(2, 0, 1)[None, ...] + else: + raise ValueError(f"Unexpected image shape: {image.shape}") + elif image.ndim == 4: + if image.shape[1] == 3: + pass + elif image.shape[3] == 3: + image = image.transpose(0, 3, 1, 2) + else: + raise ValueError(f"Unexpected batch image shape: {image.shape}") + else: + raise ValueError(f"Unsupported input shape: {image.shape}") + + mean = np.array([0.485, 0.456, 0.406])[None, :, None, None] + std = np.array([0.229, 0.224, 0.225])[None, :, None, None] + image = image * std + mean + + if return_int: + image = np.clip(image * 255.0, 0, 255).astype(np.uint8) + else: + image = np.clip(image, 0.0, 1.0) + + if image.shape[0] == 1: + image = image[0] + + if return_numpy: + return image + else: + return torch.from_numpy(image) + + +def unnormalize_image_for_vis(image): + assert len(image.shape) == 5 and image.shape[2] == 3 + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1).to(image.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1).to(image.device) + image = image * std + mean + image = (image - 0.5) / 0.5 + return image + + +def _transform_intrinsic(image, bbox, principal_point, focal_length): + # Rescale intrinsics to match bbox + half_box = np.array([image.width, image.height]).astype(np.float32) / 2 + org_scale = min(half_box).astype(np.float32) + + # Pixel coordinates + principal_point_px = half_box - (np.array(principal_point) * org_scale) + focal_length_px = np.array(focal_length) * org_scale + principal_point_px -= bbox[:2] + new_bbox = (bbox[2:] - bbox[:2]) / 2 + new_scale = min(new_bbox) + + # NDC coordinates + new_principal_ndc = (new_bbox - principal_point_px) / new_scale + new_focal_ndc = focal_length_px / new_scale + + principal_point = torch.tensor(new_principal_ndc.astype(np.float32)) + focal_length = torch.tensor(new_focal_ndc.astype(np.float32)) + + return principal_point, focal_length + + +def construct_camera_from_batch(batch, device): + if isinstance(device, int): + device = f"cuda:{device}" + + return PerspectiveCameras( + R=batch["R"].reshape(-1, 3, 3), + T=batch["T"].reshape(-1, 3), + focal_length=batch["focal_lengths"].reshape(-1, 2), + principal_point=batch["principal_points"].reshape(-1, 2), + image_size=batch["image_sizes"].reshape(-1, 2), + device=device, + ) + + +def save_batch_images(images, fname): + cmap = plt.get_cmap("hsv") + num_frames = len(images) + num_rows = len(images) + num_cols = 4 + figsize = (num_cols * 2, num_rows * 2) + fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize) + axs = axs.flatten() + for i in range(num_rows): + for j in range(4): + if i < num_frames: + axs[i * 4 + j].imshow(unnormalize_image(images[i][j])) + for s in ["bottom", "top", "left", "right"]: + axs[i * 4 + j].spines[s].set_color(cmap(i / (num_frames))) + axs[i * 4 + j].spines[s].set_linewidth(5) + axs[i * 4 + j].set_xticks([]) + axs[i * 4 + 
j].set_yticks([]) + else: + axs[i * 4 + j].axis("off") + plt.tight_layout() + plt.savefig(fname) + + +def jitter_bbox( + square_bbox, + jitter_scale=(1.1, 1.2), + jitter_trans=(-0.07, 0.07), + direction_from_size=None, +): + + square_bbox = np.array(square_bbox.astype(float)) + s = np.random.uniform(jitter_scale[0], jitter_scale[1]) + + # Jitter only one dimension if center cropping + tx, ty = np.random.uniform(jitter_trans[0], jitter_trans[1], size=2) + if direction_from_size is not None: + if direction_from_size[0] > direction_from_size[1]: + tx = 0 + else: + ty = 0 + + side_length = square_bbox[2] - square_bbox[0] + center = (square_bbox[:2] + square_bbox[2:]) / 2 + np.array([tx, ty]) * side_length + extent = side_length / 2 * s + ul = center - extent + lr = ul + 2 * extent + return np.concatenate((ul, lr)) + + +class Co3dDataset(Dataset): + def __init__( + self, + category=("all_train",), + split="train", + transform=None, + num_images=2, + img_size=224, + mask_images=False, + crop_images=True, + co3d_dir=None, + co3d_annotation_dir=None, + precropped_images=False, + apply_augmentation=True, + normalize_cameras=True, + no_images=False, + sample_num=None, + seed=0, + load_extra_cameras=False, + distort_image=False, + load_depths=False, + center_crop=False, + depth_size=256, + mask_holes=False, + object_mask=True, + ): + """ + Args: + num_images: Number of images in each batch. + perspective_correction (str): + "none": No perspective correction. + "warp": Warp the image and label. + "label_only": Correct the label only. + """ + start_time = time.time() + + self.category = category + self.split = split + self.transform = transform + self.num_images = num_images + self.img_size = img_size + self.mask_images = mask_images + self.crop_images = crop_images + self.precropped_images = precropped_images + self.apply_augmentation = apply_augmentation + self.normalize_cameras = normalize_cameras + self.no_images = no_images + self.sample_num = sample_num + self.load_extra_cameras = load_extra_cameras + self.distort = distort_image + self.load_depths = load_depths + self.center_crop = center_crop + self.depth_size = depth_size + self.mask_holes = mask_holes + self.object_mask = object_mask + + if self.apply_augmentation: + if self.center_crop: + self.jitter_scale = (0.8, 1.1) + self.jitter_trans = (0.0, 0.0) + else: + self.jitter_scale = (1.1, 1.2) + self.jitter_trans = (-0.07, 0.07) + else: + # Note if trained with apply_augmentation, we should still use + # apply_augmentation at test time. 
+ self.jitter_scale = (1, 1) + self.jitter_trans = (0.0, 0.0) + + if self.distort: + self.k1_max = 1.0 + self.k2_max = 1.0 + + if co3d_dir is not None: + self.co3d_dir = co3d_dir + self.co3d_annotation_dir = co3d_annotation_dir + else: + self.co3d_dir = CO3D_DIR + self.co3d_annotation_dir = CO3D_ANNOTATION_DIR + self.co3d_depth_dir = CO3D_DEPTH_DIR + + if isinstance(self.category, str): + self.category = [self.category] + + if "all_train" in self.category: + self.category = TRAINING_CATEGORIES + if "all_test" in self.category: + self.category = TEST_CATEGORIES + if "full" in self.category: + self.category = TRAINING_CATEGORIES + TEST_CATEGORIES + self.category = sorted(self.category) + self.is_single_category = len(self.category) == 1 + + # Fixing seed + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + print(f"Co3d ({split}):") + + self.low_quality_translations = [ + "411_55952_107659", + "427_59915_115716", + "435_61970_121848", + "112_13265_22828", + "110_13069_25642", + "165_18080_34378", + "368_39891_78502", + "391_47029_93665", + "20_695_1450", + "135_15556_31096", + "417_57572_110680", + ] # Initialized with sequences with poor depth masks + self.rotations = {} + self.category_map = {} + for c in tqdm(self.category): + annotation_file = osp.join( + self.co3d_annotation_dir, f"{c}_{self.split}.jgz" + ) + with gzip.open(annotation_file, "r") as fin: + annotation = json.loads(fin.read()) + + counter = 0 + for seq_name, seq_data in annotation.items(): + counter += 1 + if len(seq_data) < self.num_images: + continue + + filtered_data = [] + self.category_map[seq_name] = c + bad_seq = False + for data in seq_data: + # Make sure translations are not ridiculous and rotations are valid + det = np.linalg.det(data["R"]) + if (np.abs(data["T"]) > 1e5).any() or det < 0.99 or det > 1.01: + bad_seq = True + self.low_quality_translations.append(seq_name) + break + + # Ignore all unnecessary information. 
+ filtered_data.append( + { + "filepath": data["filepath"], + "bbox": data["bbox"], + "R": data["R"], + "T": data["T"], + "focal_length": data["focal_length"], + "principal_point": data["principal_point"], + }, + ) + + if not bad_seq: + self.rotations[seq_name] = filtered_data + + self.sequence_list = list(self.rotations.keys()) + + IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) + IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + + if self.transform is None: + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(self.img_size, antialias=True), + transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + ] + ) + + self.transform_depth = transforms.Compose( + [ + transforms.Resize( + self.depth_size, + antialias=False, + interpolation=transforms.InterpolationMode.NEAREST_EXACT, + ), + ] + ) + + print( + f"Low quality translation sequences, not used: {self.low_quality_translations}" + ) + print(f"Data size: {len(self)}") + print(f"Data loading took {(time.time()-start_time)} seconds.") + + def __len__(self): + return len(self.sequence_list) + + def __getitem__(self, index): + num_to_load = self.num_images if not self.load_extra_cameras else 8 + + sequence_name = self.sequence_list[index % len(self.sequence_list)] + metadata = self.rotations[sequence_name] + + if self.sample_num is not None: + with open( + order_path.format(sample_num=self.sample_num, category=self.category[0]) + ) as f: + order = json.load(f) + ids = order[sequence_name][:num_to_load] + else: + replace = len(metadata) < 8 + ids = np.random.choice(len(metadata), num_to_load, replace=replace) + + return self.get_data(index=index, ids=ids, num_valid_frames=num_to_load) + + def _get_scene_scale(self, sequence_name): + n = len(self.rotations[sequence_name]) + + R = torch.zeros(n, 3, 3) + T = torch.zeros(n, 3) + + for i, ann in enumerate(self.rotations[sequence_name]): + R[i, ...] = torch.tensor(self.rotations[sequence_name][i]["R"]) + T[i, ...] 
= torch.tensor(self.rotations[sequence_name][i]["T"]) + + cameras = PerspectiveCameras(R=R, T=T) + cc = cameras.get_camera_center() + centeroid = torch.mean(cc, dim=0) + diff = cc - centeroid + + norm = torch.norm(diff, dim=1) + scale = torch.max(norm).item() + + return scale + + def _crop_image(self, image, bbox): + image_crop = transforms.functional.crop( + image, + top=bbox[1], + left=bbox[0], + height=bbox[3] - bbox[1], + width=bbox[2] - bbox[0], + ) + return image_crop + + def _transform_intrinsic(self, image, bbox, principal_point, focal_length): + half_box = np.array([image.width, image.height]).astype(np.float32) / 2 + org_scale = min(half_box).astype(np.float32) + + # Pixel coordinates + principal_point_px = half_box - (np.array(principal_point) * org_scale) + focal_length_px = np.array(focal_length) * org_scale + principal_point_px -= bbox[:2] + new_bbox = (bbox[2:] - bbox[:2]) / 2 + new_scale = min(new_bbox) + + # NDC coordinates + new_principal_ndc = (new_bbox - principal_point_px) / new_scale + new_focal_ndc = focal_length_px / new_scale + + return new_principal_ndc.astype(np.float32), new_focal_ndc.astype(np.float32) + + def get_data( + self, + index=None, + sequence_name=None, + ids=(0, 1), + no_images=False, + num_valid_frames=None, + load_using_order=None, + ): + if load_using_order is not None: + with open( + order_path.format(sample_num=self.sample_num, category=self.category[0]) + ) as f: + order = json.load(f) + ids = order[sequence_name][:load_using_order] + + if sequence_name is None: + index = index % len(self.sequence_list) + sequence_name = self.sequence_list[index] + metadata = self.rotations[sequence_name] + category = self.category_map[sequence_name] + + # Read image & camera information from annotations + annos = [metadata[i] for i in ids] + images = [] + image_sizes = [] + PP = [] + FL = [] + crop_parameters = [] + filenames = [] + distortion_parameters = [] + depths = [] + depth_masks = [] + object_masks = [] + dino_images = [] + for anno in annos: + filepath = anno["filepath"] + + if not no_images: + image = Image.open(osp.join(self.co3d_dir, filepath)).convert("RGB") + image_size = image.size + + # Optionally mask images with black background + if self.mask_images: + black_image = Image.new("RGB", image_size, (0, 0, 0)) + mask_name = osp.basename(filepath.replace(".jpg", ".png")) + + mask_path = osp.join( + self.co3d_dir, category, sequence_name, "masks", mask_name + ) + mask = Image.open(mask_path).convert("L") + + if mask.size != image_size: + mask = mask.resize(image_size) + mask = Image.fromarray(np.array(mask) > 125) + image = Image.composite(image, black_image, mask) + + if self.object_mask: + mask_name = osp.basename(filepath.replace(".jpg", ".png")) + mask_path = osp.join( + self.co3d_dir, category, sequence_name, "masks", mask_name + ) + mask = Image.open(mask_path).convert("L") + + if mask.size != image_size: + mask = mask.resize(image_size) + mask = torch.from_numpy(np.array(mask) > 125) + + # Determine crop, Resnet wants square images + bbox = np.array(anno["bbox"]) + good_bbox = ((bbox[2:] - bbox[:2]) > 30).all() + bbox = ( + anno["bbox"] + if not self.center_crop and good_bbox + else [0, 0, image.width, image.height] + ) + + # Distort image and bbox if desired + if self.distort: + k1 = random.uniform(0, self.k1_max) + k2 = random.uniform(0, self.k2_max) + + try: + image, bbox = distort_image( + image, np.array(bbox), k1, k2, modify_bbox=True + ) + + except: + print("INFO:") + print(sequence_name) + print(index) + print(ids) + print(k1) + 
print(k2) + + distortion_parameters.append(torch.FloatTensor([k1, k2])) + + bbox = square_bbox(np.array(bbox), tight=self.center_crop) + if self.apply_augmentation: + bbox = jitter_bbox( + bbox, + jitter_scale=self.jitter_scale, + jitter_trans=self.jitter_trans, + direction_from_size=image.size if self.center_crop else None, + ) + bbox = np.around(bbox).astype(int) + + # Crop parameters + crop_center = (bbox[:2] + bbox[2:]) / 2 + principal_point = torch.tensor(anno["principal_point"]) + focal_length = torch.tensor(anno["focal_length"]) + + # convert crop center to correspond to a "square" image + width, height = image.size + length = max(width, height) + s = length / min(width, height) + crop_center = crop_center + (length - np.array([width, height])) / 2 + + # convert to NDC + cc = s - 2 * s * crop_center / length + crop_width = 2 * s * (bbox[2] - bbox[0]) / length + crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s]) + + # Crop and normalize image + if not self.precropped_images: + image = self._crop_image(image, bbox) + + try: + image = self.transform(image) + except: + print("INFO:") + print(sequence_name) + print(index) + print(ids) + print(k1) + print(k2) + + images.append(image[:, : self.img_size, : self.img_size]) + crop_parameters.append(crop_params) + + if self.load_depths: + # Open depth map + depth_name = osp.basename( + filepath.replace(".jpg", ".jpg.geometric.png") + ) + depth_path = osp.join( + self.co3d_depth_dir, + category, + sequence_name, + "depths", + depth_name, + ) + depth_pil = Image.open(depth_path) + + # 16 bit float type casting + depth = torch.tensor( + np.frombuffer( + np.array(depth_pil, dtype=np.uint16), dtype=np.float16 + ) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0])) + ) + + # Crop and resize as with images + if depth_pil.size != image_size: + # bbox may have the wrong scale + bbox = depth_pil.size[0] * bbox / image_size[0] + + if self.object_mask: + assert mask.shape == depth.shape + + bbox = np.around(bbox).astype(int) + depth = self._crop_image(depth, bbox) + + # Resize + depth = self.transform_depth(depth.unsqueeze(0))[ + 0, : self.depth_size, : self.depth_size + ] + depths.append(depth) + + if self.object_mask: + mask = self._crop_image(mask, bbox) + mask = self.transform_depth(mask.unsqueeze(0))[ + 0, : self.depth_size, : self.depth_size + ] + object_masks.append(mask) + + PP.append(principal_point) + FL.append(focal_length) + image_sizes.append(torch.tensor([self.img_size, self.img_size])) + filenames.append(filepath) + + if not no_images: + if self.load_depths: + depths = torch.stack(depths) + + depth_masks = torch.logical_or(depths <= 0, depths.isinf()) + depth_masks = (~depth_masks).long() + + if self.object_mask: + object_masks = torch.stack(object_masks, dim=0) + + if self.mask_holes: + depths = fill_depths(depths, depth_masks == 0) + + # Sometimes mask_holes misses stuff + new_masks = torch.logical_or(depths <= 0, depths.isinf()) + new_masks = (~new_masks).long() + depths[new_masks == 0] = -1 + + assert torch.logical_or(depths > 0, depths == -1).all() + assert not (depths.isinf()).any() + assert not (depths.isnan()).any() + + if self.load_extra_cameras: + # Remove the extra loaded image, for saving space + images = images[: self.num_images] + + if self.distort: + distortion_parameters = torch.stack(distortion_parameters) + + images = torch.stack(images) + crop_parameters = torch.stack(crop_parameters) + focal_lengths = torch.stack(FL) + principal_points = torch.stack(PP) + image_sizes = 
torch.stack(image_sizes) + else: + images = None + crop_parameters = None + distortion_parameters = None + focal_lengths = [] + principal_points = [] + image_sizes = [] + + # Assemble batch info to send back + R = torch.stack([torch.tensor(anno["R"]) for anno in annos]) + T = torch.stack([torch.tensor(anno["T"]) for anno in annos]) + + batch = { + "model_id": sequence_name, + "category": category, + "n": len(metadata), + "num_valid_frames": num_valid_frames, + "ind": torch.tensor(ids), + "image": images, + "depth": depths, + "depth_masks": depth_masks, + "object_masks": object_masks, + "R": R, + "T": T, + "focal_length": focal_lengths, + "principal_point": principal_points, + "image_size": image_sizes, + "crop_parameters": crop_parameters, + "distortion_parameters": torch.zeros(4), + "filename": filenames, + "dataset": "co3d", + } + + return batch diff --git a/diffusionsfm/dataset/custom.py b/diffusionsfm/dataset/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..533d3a34062f9cb1bc075838f1d39859a15d64ff --- /dev/null +++ b/diffusionsfm/dataset/custom.py @@ -0,0 +1,105 @@ + +import torch +import numpy as np +import matplotlib.pyplot as plt + +from PIL import Image, ImageOps +from torch.utils.data import Dataset +from torchvision import transforms + +from diffusionsfm.dataset.co3d_v2 import square_bbox + + +class CustomDataset(Dataset): + def __init__( + self, + image_list, + ): + self.images = [] + + for image_path in sorted(image_list): + img = Image.open(image_path) + img = ImageOps.exif_transpose(img).convert("RGB") # Apply EXIF rotation + self.images.append(img) + + self.n = len(self.images) + self.jitter_scale = [1, 1] + self.jitter_trans = [0, 0] + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(224), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + self.transform_for_vis = transforms.Compose( + [ + transforms.Resize(224), + ] + ) + + def __len__(self): + return 1 + + def _crop_image(self, image, bbox, white_bg=False): + if white_bg: + # Only support PIL Images + image_crop = Image.new( + "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255) + ) + image_crop.paste(image, (-bbox[0], -bbox[1])) + else: + image_crop = transforms.functional.crop( + image, + top=bbox[1], + left=bbox[0], + height=bbox[3] - bbox[1], + width=bbox[2] - bbox[0], + ) + return image_crop + + def __getitem__(self, index): + return self.get_data() + + def get_data(self): + cmap = plt.get_cmap("hsv") + ids = [i for i in range(len(self.images))] + images = [self.images[i] for i in ids] + images_transformed = [] + images_for_vis = [] + crop_parameters = [] + + for i, image in enumerate(images): + bbox = np.array([0, 0, image.width, image.height]) + bbox = square_bbox(bbox, tight=True) + bbox = np.around(bbox).astype(int) + image = self._crop_image(image, bbox) + images_transformed.append(self.transform(image)) + image_for_vis = self.transform_for_vis(image) + color_float = cmap(i / len(images)) + color_rgb = tuple(int(255 * c) for c in color_float[:3]) + image_for_vis = ImageOps.expand(image_for_vis, border=3, fill=color_rgb) + images_for_vis.append(image_for_vis) + + width, height = image.size + length = max(width, height) + s = length / min(width, height) + crop_center = (bbox[:2] + bbox[2:]) / 2 + crop_center = crop_center + (length - np.array([width, height])) / 2 + # convert to NDC + cc = s - 2 * s * crop_center / length + crop_width = 2 * s * (bbox[2] - bbox[0]) / length + crop_params =
torch.tensor([-cc[0], -cc[1], crop_width, s]) + + crop_parameters.append(crop_params) + images = images_transformed + + batch = {} + batch["image"] = torch.stack(images) + batch["image_for_vis"] = images_for_vis + batch["n"] = len(images) + batch["ind"] = torch.tensor(ids) + batch["crop_parameters"] = torch.stack(crop_parameters) + batch["distortion_parameters"] = torch.zeros(4) + + return batch diff --git a/diffusionsfm/eval/__init__.py b/diffusionsfm/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusionsfm/eval/eval_category.py b/diffusionsfm/eval/eval_category.py new file mode 100644 index 0000000000000000000000000000000000000000..250b7c0f5f625d222ff4dce6b479837b41c1215c --- /dev/null +++ b/diffusionsfm/eval/eval_category.py @@ -0,0 +1,292 @@ +import os +import json +import torch +import torchvision +import numpy as np +from tqdm.auto import tqdm + +from diffusionsfm.dataset.co3d_v2 import ( + Co3dDataset, + full_scene_scale, +) +from pytorch3d.renderer import PerspectiveCameras +from diffusionsfm.utils.visualization import filter_and_align_point_clouds +from diffusionsfm.inference.load_model import load_model +from diffusionsfm.inference.predict import predict_cameras +from diffusionsfm.utils.geometry import ( + compute_angular_error_batch, + get_error, + n_to_np_rotations, +) +from diffusionsfm.utils.slurm import init_slurm_signals_if_slurm +from diffusionsfm.utils.rays import cameras_to_rays +from diffusionsfm.utils.rays import normalize_cameras_batch + + +@torch.no_grad() +def evaluate( + cfg, + model, + dataset, + num_images, + device, + use_pbar=True, + calculate_intrinsics=True, + additional_timesteps=(), + num_evaluate=None, + max_num_images=None, + mode=None, + metrics=True, + load_depth=True, +): + if cfg.training.get("dpt_head", False): + H_in = W_in = 224 + H_out = W_out = cfg.training.full_num_patches_y + else: + H_in = H_out = cfg.model.num_patches_x + W_in = W_out = cfg.model.num_patches_y + + results = {} + instances = np.arange(0, len(dataset)) if num_evaluate is None else np.linspace(0, len(dataset) - 1, num_evaluate, endpoint=True, dtype=int) + instances = tqdm(instances) if use_pbar else instances + + for counter, idx in enumerate(instances): + batch = dataset[idx] + instance = batch["model_id"] + images = batch["image"].to(device) + focal_length = batch["focal_length"].to(device)[:num_images] + R = batch["R"].to(device)[:num_images] + T = batch["T"].to(device)[:num_images] + crop_parameters = batch["crop_parameters"].to(device)[:num_images] + + if load_depth: + depths = batch["depth"].to(device)[:num_images] + depth_masks = batch["depth_masks"].to(device)[:num_images] + try: + object_masks = batch["object_masks"].to(device)[:num_images] + except KeyError: + object_masks = depth_masks.clone() + + # Normalize cameras and scale depths for output resolution + cameras_gt = PerspectiveCameras( + R=R, T=T, focal_length=focal_length, device=device + ) + cameras_gt, _, _ = normalize_cameras_batch( + [cameras_gt], + first_cam_mediod=cfg.training.first_cam_mediod, + normalize_first_camera=cfg.training.normalize_first_camera, + depths=depths.unsqueeze(0), + crop_parameters=crop_parameters.unsqueeze(0), + num_patches_x=H_in, + num_patches_y=W_in, + return_scales=True, + ) + cameras_gt = cameras_gt[0] + + gt_rays = cameras_to_rays( + cameras=cameras_gt, + num_patches_x=H_in, + num_patches_y=W_in, + crop_parameters=crop_parameters, + depths=depths, + mode=mode, + ) + gt_points =
gt_rays.get_segments().view(num_images, -1, 3) + + resize = torchvision.transforms.Resize( + 224, + antialias=False, + interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT, + ) + else: + cameras_gt = PerspectiveCameras( + R=R, T=T, focal_length=focal_length, device=device + ) + + pred_cameras, additional_cams = predict_cameras( + model, + images, + device, + crop_parameters=crop_parameters, + num_patches_x=H_out, + num_patches_y=W_out, + max_num_images=max_num_images, + additional_timesteps=additional_timesteps, + calculate_intrinsics=calculate_intrinsics, + mode=mode, + return_rays=True, + use_homogeneous=cfg.model.get("use_homogeneous", False), + ) + cameras_to_evaluate = additional_cams + [pred_cameras] + + all_cams_batch = dataset.get_data( + sequence_name=instance, ids=np.arange(0, batch["n"]), no_images=True + ) + gt_scene_scale = full_scene_scale(all_cams_batch) + R_gt = R + T_gt = T + + errors = [] + for _, (camera, pred_rays) in enumerate(cameras_to_evaluate): + R_pred = camera.R + T_pred = camera.T + f_pred = camera.focal_length + + R_pred_rel = n_to_np_rotations(num_images, R_pred).cpu().numpy() + R_gt_rel = n_to_np_rotations(num_images, batch["R"]).cpu().numpy() + R_error = compute_angular_error_batch(R_pred_rel, R_gt_rel) + + CC_error, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale) + + if load_depth and metrics: + # Evaluate outputs at the same resolution as DUSt3R + pred_points = pred_rays.get_segments().view(num_images, H_out, H_out, 3) + pred_points = pred_points.permute(0, 3, 1, 2) + pred_points = resize(pred_points).permute(0, 2, 3, 1).view(num_images, H_in*W_in, 3) + + ( + _, + _, + _, + _, + metric_values, + ) = filter_and_align_point_clouds( + num_images, + gt_points, + pred_points, + depth_masks, + depth_masks, + images, + metrics=metrics, + num_patches_x=H_in, + ) + + ( + _, + _, + _, + _, + object_metric_values, + ) = filter_and_align_point_clouds( + num_images, + gt_points, + pred_points, + depth_masks * object_masks, + depth_masks * object_masks, + images, + metrics=metrics, + num_patches_x=H_in, + ) + + result = { + "R_pred": R_pred.detach().cpu().numpy().tolist(), + "T_pred": T_pred.detach().cpu().numpy().tolist(), + "f_pred": f_pred.detach().cpu().numpy().tolist(), + "R_gt": R_gt.detach().cpu().numpy().tolist(), + "T_gt": T_gt.detach().cpu().numpy().tolist(), + "f_gt": focal_length.detach().cpu().numpy().tolist(), + "scene_scale": gt_scene_scale, + "R_error": R_error.tolist(), + "CC_error": CC_error, + } + + if load_depth and metrics: + result["CD"] = metric_values[1] + result["CD_Object"] = object_metric_values[1] + else: + result["CD"] = 0 + result["CD_Object"] = 0 + + errors.append(result) + + results[instance] = errors + + if counter == len(dataset) - 1: + break + return results + + +def save_results( + output_dir, + checkpoint=800_000, + category="hydrant", + num_images=None, + calculate_additional_timesteps=True, + calculate_intrinsics=True, + split="test", + force=False, + sample_num=1, + max_num_images=None, + dataset="co3d", +): + init_slurm_signals_if_slurm() + os.umask(000) # Default to 777 permissions + eval_path = os.path.join( + output_dir, + f"eval_{dataset}", + f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}.json", + ) + + if os.path.exists(eval_path) and not force: + print(f"File {eval_path} already exists. 
Skipping.") + return + + if num_images is not None and num_images > 8: + custom_keys = {"model.num_images": num_images} + ignore_keys = ["pos_table"] + else: + custom_keys = None + ignore_keys = [] + + device = torch.device("cuda") + model, cfg = load_model( + output_dir, + checkpoint=checkpoint, + device=device, + custom_keys=custom_keys, + ignore_keys=ignore_keys, + ) + if num_images is None: + num_images = cfg.dataset.num_images + + if cfg.training.dpt_head: + # Evaluate outputs at the same resolution as DUSt3R + depth_size = 224 + else: + depth_size = cfg.model.num_patches_x + + dataset = Co3dDataset( + category=category, + split=split, + num_images=num_images, + apply_augmentation=False, + sample_num=None if split == "train" else sample_num, + use_global_intrinsics=cfg.dataset.use_global_intrinsics, + load_depths=True, + center_crop=True, + depth_size=depth_size, + mask_holes=not cfg.training.regression, + img_size=256 if cfg.model.unet_diffuser else 224, + ) + print(f"Category {category} {len(dataset)}") + + if calculate_additional_timesteps: + additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90] + else: + additional_timesteps = [] + + results = evaluate( + cfg=cfg, + model=model, + dataset=dataset, + num_images=num_images, + device=device, + calculate_intrinsics=calculate_intrinsics, + additional_timesteps=additional_timesteps, + max_num_images=max_num_images, + mode="segment", + ) + + os.makedirs(os.path.dirname(eval_path), exist_ok=True) + with open(eval_path, "w") as f: + json.dump(results, f) \ No newline at end of file diff --git a/diffusionsfm/eval/eval_jobs.py b/diffusionsfm/eval/eval_jobs.py new file mode 100644 index 0000000000000000000000000000000000000000..93e13629553ec5e9bd2d61038bd90207470b8396 --- /dev/null +++ b/diffusionsfm/eval/eval_jobs.py @@ -0,0 +1,175 @@ +""" +python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit +""" + +import os +import json +import submitit +import argparse +import itertools +from glob import glob + +import numpy as np +from tqdm.auto import tqdm + +from diffusionsfm.dataset.co3d_v2 import TEST_CATEGORIES, TRAINING_CATEGORIES +from diffusionsfm.eval.eval_category import save_results +from diffusionsfm.utils.slurm import submitit_job_watcher + + +def evaluate_diffusionsfm(eval_path, use_submitit, mode): + JOB_PARAMS = { + "output_dir": [eval_path], + "checkpoint": [800_000], + "num_images": [2, 3, 4, 5, 6, 7, 8], + "sample_num": [0, 1, 2, 3, 4], + "category": TEST_CATEGORIES, # TRAINING_CATEGORIES + TEST_CATEGORIES, + "calculate_additional_timesteps": [True], + } + if mode == "test": + JOB_PARAMS["category"] = TEST_CATEGORIES + elif mode == "train1": + JOB_PARAMS["category"] = TRAINING_CATEGORIES[:len(TRAINING_CATEGORIES) // 2] + elif mode == "train2": + JOB_PARAMS["category"] = TRAINING_CATEGORIES[len(TRAINING_CATEGORIES) // 2:] + keys, values = zip(*JOB_PARAMS.items()) + job_configs = [dict(zip(keys, p)) for p in itertools.product(*values)] + + if use_submitit: + log_output = "./slurm_logs" + executor = submitit.AutoExecutor( + cluster=None, folder=log_output, slurm_max_num_timeout=10 + ) + # Use your own parameters + executor.update_parameters( + slurm_additional_parameters={ + "nodes": 1, + "cpus-per-task": 5, + "gpus": 1, + "time": "6:00:00", + "partition": "all", + "exclude": "grogu-1-9, grogu-1-14," + } + ) + jobs = [] + with executor.batch(): + # This context manager submits all jobs at once at the end. 
+ for params in job_configs: + job = executor.submit(save_results, **params) + job_param = f"{params['category']}_N{params['num_images']}_{params['sample_num']}" + jobs.append((job_param, job)) + jobs = {f"{job_param}_{job.job_id}": job for job_param, job in jobs} + submitit_job_watcher(jobs) + else: + for job_config in tqdm(job_configs): + # This is much slower. + save_results(**job_config) + + +def process_predictions(eval_path, pred_index, checkpoint=800_000, threshold_R=15, threshold_CC=0.1): + """ + pred_index should be 1 (corresponding to T=90) + """ + def aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=None): + """ + Aggregates one metric over all data points in a prediction file and then across categories. + - For R_error and CC_error: use mean to threshold-based accuracy + - For CD and CD_Object: use median to reduce the effect of outliers + """ + per_category_values = [] + + for category in tqdm(categories, desc=f"Sample {sample_num}, N={num_images}, {metric_key}"): + per_pred_values = [] + + data_path = glob( + os.path.join(eval_path, "eval", f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}*.json") + )[0] + + with open(data_path) as f: + eval_data = json.load(f) + + for preds in eval_data.values(): + if metric_key in ["R_error", "CC_error"]: + vals = np.array(preds[pred_index][metric_key]) + per_pred_values.append(np.mean(vals < threshold)) + else: + per_pred_values.append(preds[pred_index][metric_key]) + + # Aggregate over all predictions within this category + per_category_values.append( + np.mean(per_pred_values) if metric_key in ["R_error", "CC_error"] + else np.median(per_pred_values) # CD or CD_Object — use median to filter outliers + ) + + if metric_key in ["R_error", "CC_error"]: + return np.mean(per_category_values) + else: + return np.median(per_category_values) + + def aggregate_metric(categories, metric_key, num_images, threshold=None): + """Aggregates one metric over 5 random samples per category and returns the final mean""" + return np.mean([ + aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=threshold) + for sample_num in range(5) + ]) + + # Output containers + all_seen_acc_R, all_seen_acc_CC = [], [] + all_seen_CD, all_seen_CD_Object = [], [] + all_unseen_acc_R, all_unseen_acc_CC = [], [] + all_unseen_CD, all_unseen_CD_Object = [], [] + + for num_images in range(2, 9): + # Seen categories + all_seen_acc_R.append( + aggregate_metric(TRAINING_CATEGORIES, "R_error", num_images, threshold=threshold_R) + ) + all_seen_acc_CC.append( + aggregate_metric(TRAINING_CATEGORIES, "CC_error", num_images, threshold=threshold_CC) + ) + all_seen_CD.append( + aggregate_metric(TRAINING_CATEGORIES, "CD", num_images) + ) + all_seen_CD_Object.append( + aggregate_metric(TRAINING_CATEGORIES, "CD_Object", num_images) + ) + + # Unseen categories + all_unseen_acc_R.append( + aggregate_metric(TEST_CATEGORIES, "R_error", num_images, threshold=threshold_R) + ) + all_unseen_acc_CC.append( + aggregate_metric(TEST_CATEGORIES, "CC_error", num_images, threshold=threshold_CC) + ) + all_unseen_CD.append( + aggregate_metric(TEST_CATEGORIES, "CD", num_images) + ) + all_unseen_CD_Object.append( + aggregate_metric(TEST_CATEGORIES, "CD_Object", num_images) + ) + + # Print the results in formatted rows + print("N= ", " ".join(f"{i: 5}" for i in range(2, 9))) + print("Seen R ", " ".join([f"{x:0.3f}" for x in all_seen_acc_R])) + print("Seen CC ", " ".join([f"{x:0.3f}" for x in all_seen_acc_CC])) + print("Seen CD ", " ".join([f"{x:0.3f}" 
for x in all_seen_CD])) + print("Seen CD_Obj ", " ".join([f"{x:0.3f}" for x in all_seen_CD_Object])) + print("Unseen R ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_R])) + print("Unseen CC ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_CC])) + print("Unseen CD ", " ".join([f"{x:0.3f}" for x in all_unseen_CD])) + print("Unseen CD_Obj", " ".join([f"{x:0.3f}" for x in all_unseen_CD_Object])) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--eval_path", type=str, default=None) + parser.add_argument("--use_submitit", action="store_true") + parser.add_argument("--mode", type=str, default="test") + args = parser.parse_args() + + eval_path = "output/multi_diffusionsfm_dense" if args.eval_path is None else args.eval_path + use_submitit = args.use_submitit + mode = args.mode + + evaluate_diffusionsfm(eval_path, use_submitit, mode) + process_predictions(eval_path, 1) \ No newline at end of file diff --git a/diffusionsfm/inference/__init__.py b/diffusionsfm/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusionsfm/inference/ddim.py b/diffusionsfm/inference/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..8f710d3ec4504b91a3870afb5158c2a95460bc7b --- /dev/null +++ b/diffusionsfm/inference/ddim.py @@ -0,0 +1,145 @@ +import torch +import random +import numpy as np +from tqdm.auto import tqdm + +from diffusionsfm.utils.rays import compute_ndc_coordinates + + +def inference_ddim( + model, + images, + device, + crop_parameters=None, + eta=0, + num_inference_steps=100, + pbar=True, + num_patches_x=16, + num_patches_y=16, + visualize=False, + seed=0, +): + """ + Implements DDIM-style inference. + + To get multiple samples, batch the images multiple times. + + Args: + model: Ray Diffuser. + images (torch.Tensor): (B, N, C, H, W). + patch_rays_gt (torch.Tensor): If provided, the patch rays which are ground + truth (B, N, P, 6). + eta (float, optional): Stochasticity coefficient. 0 is completely deterministic, + 1 is equivalent to DDPM. (Default: 0) + num_inference_steps (int, optional): Number of inference steps. (Default: 100) + pbar (bool, optional): Whether to show progress bar. 
(Default: True) + """ + timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps) + batch_size = images.shape[0] + num_images = images.shape[1] + + if isinstance(eta, list): + eta_0, eta_1 = float(eta[0]), float(eta[1]) + else: + eta_0, eta_1 = 0, 0 + + # Fixing seed + if seed is not None: + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + with torch.no_grad(): + x_tau = torch.randn( + batch_size, + num_images, + model.ray_out if hasattr(model, "ray_out") else model.ray_dim, + num_patches_x, + num_patches_y, + device=device, + ) + + if visualize: + x_taus = [x_tau] + all_pred = [] + noise_samples = [] + + image_features = model.feature_extractor(images, autoresize=True) + + if model.append_ndc: + ndc_coordinates = compute_ndc_coordinates( + crop_parameters=crop_parameters, + no_crop_param_device="cpu", + num_patches_x=model.width, + num_patches_y=model.width, + distortion_coeffs=None, + )[..., :2].to(device) + ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3) + else: + ndc_coordinates = None + + loop = tqdm(range(len(timesteps))) if pbar else range(len(timesteps)) + for t in loop: + tau = timesteps[t] + + if tau > 0 and eta_1 > 0: + z = torch.randn( + batch_size, + num_images, + model.ray_out if hasattr(model, "ray_out") else model.ray_dim, + num_patches_x, + num_patches_y, + device=device, + ) + else: + z = 0 + + alpha = model.noise_scheduler.alphas_cumprod[tau] + if tau > 0: + tau_prev = timesteps[t + 1] + alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev] + else: + alpha_prev = torch.tensor(1.0, device=device).float() + + sigma_t = ( + torch.sqrt((1 - alpha_prev) / (1 - alpha)) + * torch.sqrt(1 - alpha / alpha_prev) + ) + + eps_pred, noise_sample = model( + features=image_features, + rays_noisy=x_tau, + t=int(tau), + ndc_coordinates=ndc_coordinates, + ) + + if model.use_homogeneous: + p1 = eps_pred[:, :, :4] + p2 = eps_pred[:, :, 4:] + + c1 = torch.linalg.norm(p1, dim=2, keepdim=True) + c2 = torch.linalg.norm(p2, dim=2, keepdim=True) + eps_pred[:, :, :4] = p1 / c1 + eps_pred[:, :, 4:] = p2 / c2 + + if visualize: + all_pred.append(eps_pred.clone()) + noise_samples.append(noise_sample) + + # TODO: Can simplify this a lot + x0_pred = eps_pred.clone() + eps_pred = (x_tau - torch.sqrt(alpha) * eps_pred) / torch.sqrt( + 1 - alpha + ) + + dir_x_tau = torch.sqrt(1 - alpha_prev - eta_0*sigma_t**2) * eps_pred + noise = eta_1 * sigma_t * z + + new_x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise + x_tau = new_x_tau + + if visualize: + x_taus.append(x_tau.detach().clone()) + if visualize: + return x_tau, x_taus, all_pred, noise_samples + return x_tau diff --git a/diffusionsfm/inference/load_model.py b/diffusionsfm/inference/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8d6032271a89befdefe2bb6579842add2d9216 --- /dev/null +++ b/diffusionsfm/inference/load_model.py @@ -0,0 +1,97 @@ +import os.path as osp +from glob import glob + +import torch +from omegaconf import OmegaConf + +from diffusionsfm.model.diffuser import RayDiffuser +from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT +from diffusionsfm.model.scheduler import NoiseScheduler + + +def load_model( + output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=() +): + """ + Loads a model and config from an output directory. + + E.g. 
to load with different number of images, + ``` + custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"] + ``` + + Args: + output_dir (str): Path to the output directory. + checkpoint (str or int): Path to the checkpoint to load. If None, loads the + latest checkpoint. + device (str): Device to load the model on. + custom_keys (dict): Dictionary of custom keys to override in the config. + """ + if checkpoint is None: + checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1] + else: + if isinstance(checkpoint, int): + checkpoint_name = f"ckpt_{checkpoint:08d}.pth" + else: + checkpoint_name = checkpoint + checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name) + print("Loading checkpoint", osp.basename(checkpoint_path)) + + cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml")) + if custom_keys is not None: + for k, v in custom_keys.items(): + OmegaConf.update(cfg, k, v) + noise_scheduler = NoiseScheduler( + type=cfg.noise_scheduler.type, + max_timesteps=cfg.noise_scheduler.max_timesteps, + beta_start=cfg.noise_scheduler.beta_start, + beta_end=cfg.noise_scheduler.beta_end, + ) + + if not cfg.training.get("dpt_head", False): + model = RayDiffuser( + depth=cfg.model.depth, + width=cfg.model.num_patches_x, + P=1, + max_num_images=cfg.model.num_images, + noise_scheduler=noise_scheduler, + feature_extractor=cfg.model.feature_extractor, + append_ndc=cfg.model.append_ndc, + diffuse_depths=cfg.training.get("diffuse_depths", False), + depth_resolution=cfg.training.get("depth_resolution", 1), + use_homogeneous=cfg.model.get("use_homogeneous", False), + cond_depth_mask=cfg.model.get("cond_depth_mask", False), + ).to(device) + else: + model = RayDiffuserDPT( + depth=cfg.model.depth, + width=cfg.model.num_patches_x, + P=1, + max_num_images=cfg.model.num_images, + noise_scheduler=noise_scheduler, + feature_extractor=cfg.model.feature_extractor, + append_ndc=cfg.model.append_ndc, + diffuse_depths=cfg.training.get("diffuse_depths", False), + depth_resolution=cfg.training.get("depth_resolution", 1), + encoder_features=cfg.training.get("dpt_encoder_features", False), + use_homogeneous=cfg.model.get("use_homogeneous", False), + cond_depth_mask=cfg.model.get("cond_depth_mask", False), + ).to(device) + + data = torch.load(checkpoint_path) + state_dict = {} + for k, v in data["state_dict"].items(): + include = True + for ignore_key in ignore_keys: + if ignore_key in k: + include = False + if include: + state_dict[k] = v + + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if len(missing) > 0: + print("Missing keys:", missing) + if len(unexpected) > 0: + print("Unexpected keys:", unexpected) + model = model.eval() + return model, cfg diff --git a/diffusionsfm/inference/predict.py b/diffusionsfm/inference/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..fc38fb376f64d4340a87ff854dbdbcde8351a177 --- /dev/null +++ b/diffusionsfm/inference/predict.py @@ -0,0 +1,93 @@ +from diffusionsfm.inference.ddim import inference_ddim +from diffusionsfm.utils.rays import ( + Rays, + rays_to_cameras, + rays_to_cameras_homography, +) + + +def predict_cameras( + model, + images, + device, + crop_parameters=None, + num_patches_x=16, + num_patches_y=16, + additional_timesteps=(), + calculate_intrinsics=False, + max_num_images=None, + mode=None, + return_rays=False, + use_homogeneous=False, + seed=0, +): + """ + Args: + images (torch.Tensor): (N, C, H, W) + crop_parameters (torch.Tensor): (N, 4) or None + """ + if 
calculate_intrinsics: + ray_to_cam = rays_to_cameras_homography + else: + ray_to_cam = rays_to_cameras + + get_spatial_rays = Rays.from_spatial + + rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim( + model, + images.unsqueeze(0), + device, + visualize=True, + crop_parameters=crop_parameters.unsqueeze(0), + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + pbar=False, + eta=[1, 0], + num_inference_steps=100, + ) + + spatial_rays = get_spatial_rays( + rays_final[0], + mode=mode, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + use_homogeneous=use_homogeneous, + ) + + pred_cam = ray_to_cam( + spatial_rays, + crop_parameters, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + depth_resolution=model.depth_resolution, + average_centers=True, + directions_from_averaged_center=True, + ) + + additional_predictions = [] + for t in additional_timesteps: + ray = pred_intermediate[t] + + ray = get_spatial_rays( + ray[0], + mode=mode, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + use_homogeneous=use_homogeneous, + ) + + cam = ray_to_cam( + ray, + crop_parameters, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + average_centers=True, + directions_from_averaged_center=True, + ) + if return_rays: + cam = (cam, ray) + additional_predictions.append(cam) + + if return_rays: + return (pred_cam, spatial_rays), additional_predictions + return pred_cam, additional_predictions, spatial_rays \ No newline at end of file diff --git a/diffusionsfm/model/base_model.py b/diffusionsfm/model/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2e0e93b0495f48a3405546b6fe1969be3480a2 --- /dev/null +++ b/diffusionsfm/model/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. 
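+ If the loaded file also contains an optimizer state, only its "model" entry is used to populate the module.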
+ + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device("cpu")) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/diffusionsfm/model/blocks.py b/diffusionsfm/model/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..829769f7ee646d40b6ed1464966d440a1acd5363 --- /dev/null +++ b/diffusionsfm/model/blocks.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +from diffusionsfm.model.dit import TimestepEmbedder +import ipdb + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze( + -1 + ) + + +def _make_fusion_block(features, use_bn, use_ln, dpt_time, resolution): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + dpt_time=dpt_time, + ln=use_ln, + resolution=resolution + ) + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, ln, dpt_time=False, resolution=16): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.ln = ln + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + nn.init.kaiming_uniform_(self.conv1.weight) + nn.init.kaiming_uniform_(self.conv2.weight) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + if self.ln == True: + self.bn1 = nn.LayerNorm((features, resolution, resolution)) + self.bn2 = nn.LayerNorm((features, resolution, resolution)) + + self.activation = activation + + if dpt_time: + self.t_embedder = TimestepEmbedder(hidden_size=features) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(features, 3 * features, bias=True) + ) + + def forward(self, x, t=None): + """Forward pass. 
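+ When a timestep t is given, an adaLN-style shift/scale modulation derived from the timestep embedding is applied to the input, and the residual branch is gated before being added back.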
+ + Args: + x (tensor): input + + Returns: + tensor: output + """ + if t is not None: + # Embed timestamp & calculate shift parameters + t = self.t_embedder(t) # (B*N) + shift, scale, gate = self.adaLN_modulation(t).chunk(3, dim=1) # (B * N, T) + + # Shift & scale x + x = modulate(x, shift, scale) # (B * N, T, H, W) + + out = self.activation(x) + out = self.conv1(out) + if self.bn or self.ln: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn or self.ln: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + if t is not None: + out = gate.unsqueeze(-1).unsqueeze(-1) * out + + return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + ln=False, + expand=False, + align_corners=True, + dpt_time=False, + resolution=16, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + nn.init.kaiming_uniform_(self.out_conv.weight) + + # The second block sees time + self.resConfUnit1 = ResidualConvUnit_custom( + features, activation, bn=bn, ln=ln, dpt_time=False, resolution=resolution + ) + self.resConfUnit2 = ResidualConvUnit_custom( + features, activation, bn=bn, ln=ln, dpt_time=dpt_time, resolution=resolution + ) + + def forward(self, input, activation=None, t=None): + """Forward pass. + + Returns: + tensor: output + """ + output = input + + if activation is not None: + res = self.resConfUnit1(activation) + + output += res + + output = self.resConfUnit2(output, t) + + output = torch.nn.functional.interpolate( + output.float(), + scale_factor=2, + mode="bilinear", + align_corners=self.align_corners, + ) + + output = self.out_conv(output) + + return output diff --git a/diffusionsfm/model/diffuser.py b/diffusionsfm/model/diffuser.py new file mode 100644 index 0000000000000000000000000000000000000000..95d69eee2751192219cf1a5926dc29ea562a6abc --- /dev/null +++ b/diffusionsfm/model/diffuser.py @@ -0,0 +1,195 @@ +import ipdb # noqa: F401 +import numpy as np +import torch +import torch.nn as nn + +from diffusionsfm.model.dit import DiT +from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino +from diffusionsfm.model.scheduler import NoiseScheduler + + +class RayDiffuser(nn.Module): + def __init__( + self, + model_type="dit", + depth=8, + width=16, + hidden_size=1152, + P=1, + max_num_images=1, + noise_scheduler=None, + freeze_encoder=True, + feature_extractor="dino", + append_ndc=True, + use_unconditional=False, + diffuse_depths=False, + depth_resolution=1, + use_homogeneous=False, + cond_depth_mask=False, + ): + super().__init__() + if noise_scheduler is None: + self.noise_scheduler = NoiseScheduler() + else: + self.noise_scheduler = noise_scheduler + + self.diffuse_depths = diffuse_depths + self.depth_resolution = depth_resolution + self.use_homogeneous = use_homogeneous + + self.ray_dim = 3 + if self.use_homogeneous: + self.ray_dim += 1 + + self.ray_dim += self.ray_dim * self.depth_resolution**2 + + if self.diffuse_depths: + self.ray_dim += 1 + + self.append_ndc = append_ndc + self.width = width + + self.max_num_images = 
max_num_images + self.model_type = model_type + self.use_unconditional = use_unconditional + self.cond_depth_mask = cond_depth_mask + + if feature_extractor == "dino": + self.feature_extractor = SpatialDino( + freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width + ) + self.feature_dim = self.feature_extractor.feature_dim + elif feature_extractor == "vae": + self.feature_extractor = PretrainedVAE( + freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width + ) + self.feature_dim = self.feature_extractor.feature_dim + else: + raise Exception(f"Unknown feature extractor {feature_extractor}") + + if self.use_unconditional: + self.register_parameter( + "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1)) + ) + + self.input_dim = self.feature_dim * 2 + + if self.append_ndc: + self.input_dim += 2 + + if model_type == "dit": + self.ray_predictor = DiT( + in_channels=self.input_dim, + out_channels=self.ray_dim, + width=width, + depth=depth, + hidden_size=hidden_size, + max_num_images=max_num_images, + P=P, + ) + + self.scratch = nn.Module() + self.scratch.input_conv = nn.Linear(self.ray_dim + int(self.cond_depth_mask), self.feature_dim) + + def forward_noise( + self, x, t, epsilon=None, zero_out_mask=None + ): + """ + Applies forward diffusion (adds noise) to the input. + + If a mask is provided, the noise is only applied to the masked inputs. + """ + t = t.reshape(-1, 1, 1, 1, 1) + + if epsilon is None: + epsilon = torch.randn_like(x) + else: + epsilon = epsilon.reshape(x.shape) + + alpha_bar = self.noise_scheduler.alphas_cumprod[t] + x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon + + if zero_out_mask is not None and self.cond_depth_mask: + x_noise = x_noise * zero_out_mask + + return x_noise, epsilon + + def forward( + self, + features=None, + images=None, + rays=None, + rays_noisy=None, + t=None, + ndc_coordinates=None, + unconditional_mask=None, + return_dpt_activations=False, + depth_mask=None, + ): + """ + Args: + images: (B, N, 3, H, W). + t: (B,). + rays: (B, N, 6, H, W). + rays_noisy: (B, N, 6, H, W). + ndc_coordinates: (B, N, 2, H, W). + unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples + and 0 else. + """ + + if features is None: + # VAE expects 256x256 images while DINO expects 224x224 images. + # Both feature extractors support autoresize=True, but ideally we should + # set this to be false and handle in the dataloader. 
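+ # Shape sketch (assuming the default DINO extractor with a 16x16 patch grid):
+ # images (B, N, 3, H, W) -> features (B, N, feature_dim, 16, 16).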
+ features = self.feature_extractor(images, autoresize=True) + + B = features.shape[0] + + if ( + unconditional_mask is not None + and self.use_unconditional + ): + null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1) + unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1) + features = ( + features * (1 - unconditional_mask) + null_token * unconditional_mask + ) + + if isinstance(t, int) or isinstance(t, np.int64): + t = torch.ones(1, dtype=int).to(features.device) * t + else: + t = t.reshape(B) + + if rays_noisy is None: + if self.cond_depth_mask: + rays_noisy, epsilon = self.forward_noise(rays, t, zero_out_mask=depth_mask.unsqueeze(2)) + else: + rays_noisy, epsilon = self.forward_noise(rays, t) + else: + epsilon = None + + if self.cond_depth_mask: + if depth_mask is None: + depth_mask = torch.ones_like(rays_noisy[:, :, 0]) + ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2) + else: + ray_repr = rays_noisy + + ray_repr = ray_repr.permute(0, 1, 3, 4, 2) + ray_repr = self.scratch.input_conv(ray_repr).permute(0, 1, 4, 2, 3).contiguous() + + scene_features = torch.cat([features, ray_repr], dim=2) + + if self.append_ndc: + scene_features = torch.cat([scene_features, ndc_coordinates], dim=2) + + epsilon_pred = self.ray_predictor( + scene_features, + t, + return_dpt_activations=return_dpt_activations, + ) + + if return_dpt_activations: + return epsilon_pred, rays_noisy, epsilon + + return epsilon_pred, epsilon diff --git a/diffusionsfm/model/diffuser_dpt.py b/diffusionsfm/model/diffuser_dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..14abd38f3c27cc58f750d505bde0a0c61c11aeb7 --- /dev/null +++ b/diffusionsfm/model/diffuser_dpt.py @@ -0,0 +1,331 @@ +import ipdb # noqa: F401 +import numpy as np +import torch +import torch.nn as nn + +from diffusionsfm.model.dit import DiT +from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino +from diffusionsfm.model.blocks import _make_fusion_block, _make_scratch +from diffusionsfm.model.scheduler import NoiseScheduler + + +# functional implementation +def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int): + """Upsample {x} (NCHW) by scale factor {scale_factor} using nearest neighbor interpolation.""" + s = scale_factor + return ( + x.reshape(*x.shape, 1, 1) + .expand(*x.shape, s, s) + .transpose(-2, -3) + .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:])) + ) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class RayDiffuserDPT(nn.Module): + def __init__( + self, + model_type="dit", + depth=8, + width=16, + hidden_size=1152, + P=1, + max_num_images=1, + noise_scheduler=None, + freeze_encoder=True, + feature_extractor="dino", + append_ndc=True, + use_unconditional=False, + diffuse_depths=False, + depth_resolution=1, + encoder_features=False, + use_homogeneous=False, + freeze_transformer=False, + cond_depth_mask=False, + ): + super().__init__() + if noise_scheduler is None: + self.noise_scheduler = NoiseScheduler() + else: + self.noise_scheduler = noise_scheduler + + self.diffuse_depths = diffuse_depths + self.depth_resolution = depth_resolution + self.use_homogeneous = 
use_homogeneous + + self.ray_dim = 3 + + if self.use_homogeneous: + self.ray_dim += 1 + self.ray_dim += self.ray_dim * self.depth_resolution**2 + + if self.diffuse_depths: + self.ray_dim += 1 + + self.append_ndc = append_ndc + self.width = width + + self.max_num_images = max_num_images + self.model_type = model_type + self.use_unconditional = use_unconditional + self.cond_depth_mask = cond_depth_mask + self.encoder_features = encoder_features + + if feature_extractor == "dino": + self.feature_extractor = SpatialDino( + freeze_weights=freeze_encoder, + num_patches_x=width, + num_patches_y=width, + activation_hooks=self.encoder_features, + ) + self.feature_dim = self.feature_extractor.feature_dim + elif feature_extractor == "vae": + self.feature_extractor = PretrainedVAE( + freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width + ) + self.feature_dim = self.feature_extractor.feature_dim + else: + raise Exception(f"Unknown feature extractor {feature_extractor}") + + if self.use_unconditional: + self.register_parameter( + "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1)) + ) + + self.input_dim = self.feature_dim * 2 + + if self.append_ndc: + self.input_dim += 2 + + if model_type == "dit": + self.ray_predictor = DiT( + in_channels=self.input_dim, + out_channels=self.ray_dim, + width=width, + depth=depth, + hidden_size=hidden_size, + max_num_images=max_num_images, + P=P, + ) + + if freeze_transformer: + for param in self.ray_predictor.parameters(): + param.requires_grad = False + + # Fusion blocks + self.f = 256 + + if self.encoder_features: + feature_lens = [ + self.feature_extractor.feature_dim, + self.feature_extractor.feature_dim, + self.ray_predictor.hidden_size, + self.ray_predictor.hidden_size, + ] + else: + feature_lens = [self.ray_predictor.hidden_size] * 4 + + self.scratch = _make_scratch(feature_lens, 256, groups=1, expand=False) + self.scratch.refinenet1 = _make_fusion_block( + self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=128 + ) + self.scratch.refinenet2 = _make_fusion_block( + self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=64 + ) + self.scratch.refinenet3 = _make_fusion_block( + self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=32 + ) + self.scratch.refinenet4 = _make_fusion_block( + self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=16 + ) + + self.scratch.input_conv = nn.Conv2d( + self.ray_dim + int(self.cond_depth_mask), + self.feature_dim, + kernel_size=16, + stride=16, + padding=0 + ) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(self.f, self.f // 2, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(), + nn.Conv2d(self.f // 2, 32, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(), + nn.Conv2d(32, self.ray_dim, kernel_size=1, stride=1, padding=0), + nn.Identity(), + ) + + if self.encoder_features: + self.project_opers = nn.ModuleList([ + ProjectReadout(in_features=self.feature_extractor.feature_dim), + ProjectReadout(in_features=self.feature_extractor.feature_dim), + ]) + + def forward_noise( + self, x, t, epsilon=None, zero_out_mask=None + ): + """ + Applies forward diffusion (adds noise) to the input. + + If a mask is provided, the noise is only applied to the masked inputs. 
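+ Concretely, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon, where epsilon ~ N(0, I) is drawn internally when not supplied.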
+ """ + t = t.reshape(-1, 1, 1, 1, 1) + if epsilon is None: + epsilon = torch.randn_like(x) + else: + epsilon = epsilon.reshape(x.shape) + + alpha_bar = self.noise_scheduler.alphas_cumprod[t] + x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon + + if zero_out_mask is not None and self.cond_depth_mask: + x_noise = zero_out_mask * x_noise + + return x_noise, epsilon + + def forward( + self, + features=None, + images=None, + rays=None, + rays_noisy=None, + t=None, + ndc_coordinates=None, + unconditional_mask=None, + encoder_patches=16, + depth_mask=None, + multiview_unconditional=False, + indices=None, + ): + """ + Args: + images: (B, N, 3, H, W). + t: (B,). + rays: (B, N, 6, H, W). + rays_noisy: (B, N, 6, H, W). + ndc_coordinates: (B, N, 2, H, W). + unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples + and 0 else. + """ + + if features is None: + # VAE expects 256x256 images while DINO expects 224x224 images. + # Both feature extractors support autoresize=True, but ideally we should + # set this to be false and handle in the dataloader. + features = self.feature_extractor(images, autoresize=True) + + B = features.shape[0] + + if unconditional_mask is not None and self.use_unconditional: + null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1) + unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1) + features = ( + features * (1 - unconditional_mask) + null_token * unconditional_mask + ) + + if isinstance(t, int) or isinstance(t, np.int64): + t = torch.ones(1, dtype=int).to(features.device) * t + else: + t = t.reshape(B) + + if rays_noisy is None: + if self.cond_depth_mask: + rays_noisy, epsilon = self.forward_noise( + rays, t, zero_out_mask=depth_mask.unsqueeze(2) + ) + else: + rays_noisy, epsilon = self.forward_noise( + rays, t + ) + else: + epsilon = None + + # DOWNSAMPLE RAYS + B, N, C, H, W = rays_noisy.shape + + if self.cond_depth_mask: + if depth_mask is None: + depth_mask = torch.ones_like(rays_noisy[:, :, 0]) + ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2) + else: + ray_repr = rays_noisy + + ray_repr = self.scratch.input_conv(ray_repr.reshape(B * N, -1, H, W)) + _, CP, HP, WP = ray_repr.shape + ray_repr = ray_repr.reshape(B, N, CP, HP, WP) + scene_features = torch.cat([features, ray_repr], dim=2) + + if self.append_ndc: + scene_features = torch.cat([scene_features, ndc_coordinates], dim=2) + + # DIT FORWARD PASS + activations = self.ray_predictor( + scene_features, + t, + return_dpt_activations=True, + multiview_unconditional=multiview_unconditional, + ) + + # PROJECT ENCODER ACTIVATIONS & RESHAPE + if self.encoder_features: + for i in range(2): + name = f"encoder{i+1}" + + if indices is not None: + act = self.feature_extractor.activations[name][indices] + else: + act = self.feature_extractor.activations[name] + + act = self.project_opers[i](act).permute(0, 2, 1) + act = act.reshape( + ( + B * N, + self.feature_extractor.feature_dim, + encoder_patches, + encoder_patches, + ) + ) + activations[i] = act + + # UPSAMPLE ACTIVATIONS + for i, act in enumerate(activations): + k = 3 - i + activations[i] = nearest_neighbor_upsample(act, 2**k) + + # FUSION BLOCKS + layer_1_rn = self.scratch.layer1_rn(activations[0]) + layer_2_rn = self.scratch.layer2_rn(activations[1]) + layer_3_rn = self.scratch.layer3_rn(activations[2]) + layer_4_rn = self.scratch.layer4_rn(activations[3]) + + # RESHAPE TIMESTEPS + if t.shape[0] == B: + t = t.unsqueeze(-1).repeat((1, N)).reshape(B * N) + elif t.shape[0] == 1 and B > 1: + 
t = t.repeat((B * N)) + else: + assert False + + path_4 = self.scratch.refinenet4(layer_4_rn, t=t) + path_3 = self.scratch.refinenet3(path_4, activation=layer_3_rn, t=t) + path_2 = self.scratch.refinenet2(path_3, activation=layer_2_rn, t=t) + path_1 = self.scratch.refinenet1(path_2, activation=layer_1_rn, t=t) + + epsilon_pred = self.scratch.output_conv(path_1) + epsilon_pred = epsilon_pred.reshape((B, N, C, H, W)) + + return epsilon_pred, epsilon diff --git a/diffusionsfm/model/dit.py b/diffusionsfm/model/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..19d453b33fd548a4420707e81e58e8d9f082aba8 --- /dev/null +++ b/diffusionsfm/model/dit.py @@ -0,0 +1,428 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math + +import ipdb # noqa: F401 +import numpy as np +import torch +import torch.nn as nn +from timm.models.vision_transformer import Attention, Mlp, PatchEmbed +from diffusionsfm.model.memory_efficient_attention import MEAttention + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +################################################################################# +# Core DiT Model # +################################################################################# + + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
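+ The conditioning vector c is mapped to shift, scale, and gate parameters that modulate both the attention and MLP branches.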
+ """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + use_xformers_attention=False, + **block_kwargs + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + attn = MEAttention if use_xformers_attention else Attention + self.attn = attn( + hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs + ) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + def approx_gelu(): + return nn.GELU(approximate="tanh") + + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn( + modulate(self.norm1(x), shift_msa, scale_msa) + ) + x = x + gate_mlp.unsqueeze(1) * self.mlp( + modulate(self.norm2(x), shift_mlp, scale_mlp) + ) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. 
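+ Input features of shape (B, N, C, H, W) are patch-embedded, combined with per-image and per-patch positional encodings, passed through a stack of adaLN-conditioned DiT blocks, and unpatchified back to (B, N, out_channels, H, W).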
+ """ + + def __init__( + self, + in_channels=442, + out_channels=6, + width=16, + hidden_size=1152, + depth=8, + num_heads=16, + mlp_ratio=4.0, + max_num_images=8, + P=1, + within_image=False, + ): + super().__init__() + self.num_heads = num_heads + self.in_channels = in_channels + self.out_channels = out_channels + self.width = width + self.hidden_size = hidden_size + self.max_num_images = max_num_images + self.P = P + self.within_image = within_image + + # self.x_embedder = nn.Linear(in_channels, hidden_size) + # self.x_embedder = PatchEmbed(in_channels, hidden_size, kernel_size=P, hidden_size=P) + self.x_embedder = PatchEmbed( + img_size=self.width, + patch_size=self.P, + in_chans=in_channels, + embed_dim=hidden_size, + bias=True, + flatten=False, + ) + self.x_pos_enc = FeaturePositionalEncoding( + max_num_images, hidden_size, width**2, P=self.P + ) + self.t_embedder = TimestepEmbedder(hidden_size) + + try: + import xformers + + use_xformers_attention = True + except ImportError: + # xformers not available + use_xformers_attention = False + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + use_xformers_attention=use_xformers_attention, + ) + for _ in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, P, out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + # nn.init.constant_(self.final_layer.linear.weight, 0) + # nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + + # print("unpatchify", c, p, h, w, x.shape) + # assert h * w == x.shape[2] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nhpwqc", x) + imgs = x.reshape(shape=(x.shape[0], h * p, h * p, c)) + return imgs + + def forward( + self, + x, + t, + return_dpt_activations=False, + multiview_unconditional=False, + ): + """ + + Args: + x: Image/Ray features (B, N, C, H, W). + t: Timesteps (N,). + + Returns: + (B, N, D, H, W) + """ + B, N, c, h, w = x.shape + P = self.P + + x = x.reshape((B * N, c, h, w)) # (B * N, C, H, W) + x = self.x_embedder(x) # (B * N, C, H / P, W / P) + + x = x.permute(0, 2, 3, 1) # (B * N, H / P, W / P, C) + # (B, N, H / P, W / P, C) + x = x.reshape((B, N, h // P, w // P, self.hidden_size)) + x = self.x_pos_enc(x) # (B, N, H * W / P ** 2, C) + # TODO: fix positional encoding to work with (N, C, H, W) format. 
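+ # All-patches blocks below attend over tokens from all N images jointly; within-image blocks restrict attention to each image's own patches.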
+ + # Eval time, we get a scalar t + if x.shape[0] != t.shape[0] and t.shape[0] == 1: + t = t.repeat_interleave(B) + + if self.within_image or multiview_unconditional: + t_within = t.repeat_interleave(N) + t_within = self.t_embedder(t_within) + + t = self.t_embedder(t) + + dpt_activations = [] + for i, block in enumerate(self.blocks): + # Within image block + if (self.within_image and i % 2 == 0) or multiview_unconditional: + x = x.reshape((B * N, h * w // P**2, self.hidden_size)) + x = block(x, t_within) + + # All patches block + # Final layer is an all patches layer + else: + x = x.reshape((B, N * h * w // P**2, self.hidden_size)) + x = block(x, t) # (N, T, D) + + if return_dpt_activations and i % 4 == 3: + x_prime = x.reshape(B, N, h, w, self.hidden_size) + x_prime = x.reshape(B * N, h, w, self.hidden_size) + x_prime = x_prime.permute((0, 3, 1, 2)) + dpt_activations.append(x_prime) + + # Reshape the output back to original shape + if multiview_unconditional: + x = x.reshape((B, N * h * w // P**2, self.hidden_size)) + + # (B, N * H * W / P ** 2, D) + x = self.final_layer( + x, t + ) # (B, N * H * W / P ** 2, 6 * P ** 2) or (N, T, patch_size ** 2 * out_channels) + + x = x.reshape((B * N, w * w // P**2, self.out_channels * P**2)) + x = self.unpatchify(x) # (B * N, H, W, C) + x = x.reshape((B, N) + x.shape[1:]) + x = x.permute(0, 1, 4, 2, 3) # (B, N, C, H, W) + + if return_dpt_activations: + return dpt_activations[:4] + + return x + + +class FeaturePositionalEncoding(nn.Module): + def _get_sinusoid_encoding_table(self, n_position, d_hid, base): + """Sinusoid position encoding table""" + + def get_position_angle_vec(position): + return [ + position / np.power(base, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + def __init__(self, max_num_images=8, feature_dim=1152, num_patches=256, P=1): + super().__init__() + self.max_num_images = max_num_images + self.feature_dim = feature_dim + self.P = P + self.num_patches = num_patches // self.P**2 + + self.register_buffer( + "image_pos_table", + self._get_sinusoid_encoding_table( + self.max_num_images, self.feature_dim, 10000 + ), + ) + + self.register_buffer( + "token_pos_table", + self._get_sinusoid_encoding_table( + self.num_patches, self.feature_dim, 70007 + ), + ) + + def forward(self, x): + batch_size = x.shape[0] + num_images = x.shape[1] + + x = x.reshape(batch_size, num_images, self.num_patches, self.feature_dim) + + # To encode image index + pe1 = self.image_pos_table[:, :num_images].clone().detach() + pe1 = pe1.reshape((1, num_images, 1, self.feature_dim)) + pe1 = pe1.repeat((batch_size, 1, self.num_patches, 1)) + + # To encode patch index + pe2 = self.token_pos_table.clone().detach() + pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim)) + pe2 = pe2.repeat((batch_size, num_images, 1, 1)) + + x_pe = x + pe1 + pe2 + x_pe = x_pe.reshape( + (batch_size, num_images * self.num_patches, self.feature_dim) + ) + + return x_pe + + def forward_unet(self, x, B, N): + D = int(self.num_patches**0.5) + + # x should be (B, N, T, D, D) + x = x.permute((0, 2, 3, 1)) + x = x.reshape(B, N, self.num_patches, self.feature_dim) + + # To encode image index + pe1 = self.image_pos_table[:, :N].clone().detach() + pe1 = pe1.reshape((1, N, 1, 
self.feature_dim)) + pe1 = pe1.repeat((B, 1, self.num_patches, 1)) + + # To encode patch index + pe2 = self.token_pos_table.clone().detach() + pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim)) + pe2 = pe2.repeat((B, N, 1, 1)) + + x_pe = x + pe1 + pe2 + x_pe = x_pe.reshape((B * N, D, D, self.feature_dim)) + x_pe = x_pe.permute((0, 3, 1, 2)) + + return x_pe diff --git a/diffusionsfm/model/feature_extractors.py b/diffusionsfm/model/feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..7db6ca95674242ce52a8a7547d105d62c4e01667 --- /dev/null +++ b/diffusionsfm/model/feature_extractors.py @@ -0,0 +1,176 @@ +import importlib +import os +import socket +import sys + +import ipdb # noqa: F401 +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +HOSTNAME = socket.gethostname() + +if "trinity" in HOSTNAME: + # Might be outdated + config_path = "/home/amylin2/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml" + weights_path = "/home/amylin2/latent-diffusion/model.ckpt" +elif "grogu" in HOSTNAME: + # Might be outdated + config_path = "/home/jasonzh2/code/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml" + weights_path = "/home/jasonzh2/code/latent-diffusion/model.ckpt" +elif "ender" in HOSTNAME: + config_path = "/home/jason/ray_diffusion/external/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml" + weights_path = "/home/jason/ray_diffusion/external/latent-diffusion/model.ckpt" +else: + config_path = None + weights_path = None + + +if weights_path is not None: + LDM_PATH = os.path.dirname(weights_path) + if LDM_PATH not in sys.path: + sys.path.append(LDM_PATH) + + +def resize(image, size=None, scale_factor=None): + return nn.functional.interpolate( + image, + size=size, + scale_factor=scale_factor, + mode="bilinear", + align_corners=False, + ) + + +def instantiate_from_config(config): + if "target" not in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class PretrainedVAE(nn.Module): + def __init__(self, freeze_weights=True, num_patches_x=16, num_patches_y=16): + super().__init__() + config = OmegaConf.load(config_path) + self.model = instantiate_from_config(config.model) + self.model.init_from_ckpt(weights_path) + self.model.eval() + self.feature_dim = 16 + self.num_patches_x = num_patches_x + self.num_patches_y = num_patches_y + + if freeze_weights: + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, x, autoresize=False): + """ + Spatial dimensions of output will be H // 16, W // 16. If autoresize is True, + then the input will be resized such that the output feature map is the correct + dimensions. + + Args: + x (torch.Tensor): Images (B, C, H, W). Should be normalized to be [-1, 1]. + autoresize (bool): Whether to resize the input to match the num_patch + dimensions. 
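+ With the default 16x16 patch grid, autoresize rescales the input to 256x256.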
+ + Returns: + torch.Tensor: Latent sample (B, 16, h, w) + """ + + *B, c, h, w = x.shape + x = x.reshape(-1, c, h, w) + if autoresize: + new_w = self.num_patches_x * 16 + new_h = self.num_patches_y * 16 + x = resize(x, size=(new_h, new_w)) + + decoded, latent = self.model(x) + # A little ambiguous bc it's all 16, but it is (c, h, w) + latent_sample = latent.sample().reshape( + *B, self.feature_dim, self.num_patches_y, self.num_patches_x + ) + return latent_sample + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +class SpatialDino(nn.Module): + def __init__( + self, + freeze_weights=True, + model_type="dinov2_vits14", + num_patches_x=16, + num_patches_y=16, + activation_hooks=False, + ): + super().__init__() + self.model = torch.hub.load("facebookresearch/dinov2", model_type) + self.feature_dim = self.model.embed_dim + self.num_patches_x = num_patches_x + self.num_patches_y = num_patches_y + if freeze_weights: + for param in self.model.parameters(): + param.requires_grad = False + + self.activation_hooks = activation_hooks + + if self.activation_hooks: + self.model.blocks[5].register_forward_hook(get_activation("encoder1")) + self.model.blocks[11].register_forward_hook(get_activation("encoder2")) + self.activations = activations + + def forward(self, x, autoresize=False): + """ + Spatial dimensions of output will be H // 14, W // 14. If autoresize is True, + then the output will be resized to the correct dimensions. + + Args: + x (torch.Tensor): Images (B, C, H, W). Should be ImageNet normalized. + autoresize (bool): Whether to resize the input to match the num_patch + dimensions. + + Returns: + feature_map (torch.tensor): (B, C, h, w) + """ + *B, c, h, w = x.shape + + x = x.reshape(-1, c, h, w) + # if autoresize: + # new_w = self.num_patches_x * 14 + # new_h = self.num_patches_y * 14 + # x = resize(x, size=(new_h, new_w)) + + # Output will be (B, H * W, C) + features = self.model.forward_features(x)["x_norm_patchtokens"] + features = features.permute(0, 2, 1) + features = features.reshape( # (B, C, H, W) + -1, self.feature_dim, h // 14, w // 14 + ) + if autoresize: + features = resize(features, size=(self.num_patches_y, self.num_patches_x)) + + features = features.reshape( + *B, self.feature_dim, self.num_patches_y, self.num_patches_x + ) + return features diff --git a/diffusionsfm/model/memory_efficient_attention.py b/diffusionsfm/model/memory_efficient_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa86856fdb07b064d424a8bf4546ecbd1cba1a8 --- /dev/null +++ b/diffusionsfm/model/memory_efficient_attention.py @@ -0,0 +1,51 @@ +import ipdb +import torch.nn as nn +from xformers.ops import memory_efficient_attention + + +class MEAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 
3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + # MEA expects [B, N, H, D], whereas timm uses [B, H, N, D] + x = memory_efficient_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + scale=self.scale, + ) + x = x.reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/diffusionsfm/model/scheduler.py b/diffusionsfm/model/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5ab7252b95c79e8dbc2aa229779dd499646f98 --- /dev/null +++ b/diffusionsfm/model/scheduler.py @@ -0,0 +1,128 @@ +import ipdb # noqa: F401 +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn + +from diffusionsfm.utils.visualization import plot_to_image + + +class NoiseScheduler(nn.Module): + def __init__( + self, + max_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + cos_power=2, + num_inference_steps=100, + type="linear", + ): + super().__init__() + self.max_timesteps = max_timesteps + self.num_inference_steps = num_inference_steps + self.beta_start = beta_start + self.beta_end = beta_end + self.cos_power = cos_power + self.type = type + + if type == "linear": + self.register_linear_schedule() + elif type == "cosine": + self.register_cosine_schedule(cos_power) + elif type == "scaled_linear": + self.register_scaled_linear_schedule() + + self.inference_timesteps = self.compute_inference_timesteps() + + def register_linear_schedule(self): + # zero terminal SNR (https://arxiv.org/pdf/2305.08891) + betas = torch.linspace( + self.beta_start, + self.beta_end, + self.max_timesteps, + dtype=torch.float32, + ) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + alphas_bar_sqrt -= alphas_bar_sqrt_T + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + alphas_bar = alphas_bar_sqrt**2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + self.register_buffer( + "betas", + betas, + ) + self.register_buffer("alphas", 1.0 - self.betas) + self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0)) + + def register_cosine_schedule(self, cos_power, s=0.008): + timesteps = ( + torch.arange(self.max_timesteps + 1, dtype=torch.float32) + / self.max_timesteps + ) + alpha_bars = (timesteps + s) / (1 + s) * np.pi / 2 + alpha_bars = torch.cos(alpha_bars).pow(cos_power) + alpha_bars = alpha_bars / alpha_bars[0] + betas = 1 - alpha_bars[1:] / alpha_bars[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + self.register_buffer( + "betas", + betas, + ) + self.register_buffer("alphas", 1.0 - betas) + self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0)) + + def register_scaled_linear_schedule(self): + self.register_buffer( + "betas", + torch.linspace( + self.beta_start**0.5, + self.beta_end**0.5, + self.max_timesteps, + dtype=torch.float32, + ) + ** 2, + ) + self.register_buffer("alphas", 1.0 - self.betas) + self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0)) + + def compute_inference_timesteps( + self, num_inference_steps=None, num_train_steps=None + ): + # based on diffusers's scheduling code + if num_inference_steps is None: + num_inference_steps = self.num_inference_steps + if num_train_steps is None: + num_train_steps = 
self.max_timesteps + step_ratio = num_train_steps // num_inference_steps + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(int) + ) + return timesteps + + def plot_schedule(self, return_image=False): + fig = plt.figure(figsize=(6, 4), dpi=100) + alpha_bars = self.alphas_cumprod.cpu().numpy() + plt.plot(np.sqrt(alpha_bars)) + plt.grid() + if self.type == "linear": + plt.title( + f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})" + ) + else: + self.type == "cosine" + plt.title(f"Cosine (T={self.max_timesteps}, P={self.cos_power})") + if return_image: + image = plot_to_image(fig) + plt.close(fig) + return image diff --git a/diffusionsfm/utils/__init__.py b/diffusionsfm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusionsfm/utils/configs.py b/diffusionsfm/utils/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..9f61b1e11e3b0a9b7db254e05d18a22d829583e1 --- /dev/null +++ b/diffusionsfm/utils/configs.py @@ -0,0 +1,66 @@ +import argparse +import os + +from omegaconf import OmegaConf + + +def load_cfg(config_path): + """ + Loads a yaml configuration file. + + Follows the chain of yaml configuration files that have a `_BASE` key, and updates + the new keys accordingly. _BASE configurations can be specified using relative + paths. + """ + config_dir = os.path.dirname(config_path) + config_path = os.path.basename(config_path) + return load_cfg_recursive(config_dir, config_path) + + +def load_cfg_recursive(config_dir, config_path): + """ + Recursively loads config files. + + Follows the chain of yaml configuration files that have a `_BASE` key, and updates + the new keys accordingly. _BASE configurations can be specified using relative + paths. 
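+ For example, a child config containing "_BASE: base.yaml" is resolved relative to the same config directory and merged underneath the child, so the child's keys take precedence.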
+ """ + cfg = OmegaConf.load(os.path.join(config_dir, config_path)) + base_path = OmegaConf.select(cfg, "_BASE", default=None) + if base_path is not None: + base_cfg = load_cfg_recursive(config_dir, base_path) + cfg = OmegaConf.merge(base_cfg, cfg) + return cfg + + +def get_cfg(): + parser = argparse.ArgumentParser() + parser.add_argument("--config-path", type=str, required=True) + args = parser.parse_args() + cfg = load_cfg(args.config_path) + print(OmegaConf.to_yaml(cfg)) + + exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag) + os.makedirs(exp_dir, exist_ok=True) + to_path = os.path.join(exp_dir, os.path.basename(args.config_path)) + if not os.path.exists(to_path): + OmegaConf.save(config=cfg, f=to_path) + return cfg + + +def get_cfg_from_path(config_path): + """ + args: + config_path - get config from path + """ + print("getting config from path") + + cfg = load_cfg(config_path) + print(OmegaConf.to_yaml(cfg)) + + exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag) + os.makedirs(exp_dir, exist_ok=True) + to_path = os.path.join(exp_dir, os.path.basename(config_path)) + if not os.path.exists(to_path): + OmegaConf.save(config=cfg, f=to_path) + return cfg diff --git a/diffusionsfm/utils/distortion.py b/diffusionsfm/utils/distortion.py new file mode 100644 index 0000000000000000000000000000000000000000..3c638db8a4f2bd89e294f86ce84eb0cfc215d6ea --- /dev/null +++ b/diffusionsfm/utils/distortion.py @@ -0,0 +1,144 @@ +import cv2 +import ipdb +import numpy as np +from PIL import Image +import torch + + +# https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb +def apply_distortion(pts, k1, k2): + """ + Arguments: + pts (N x 2): numpy array in NDC coordinates + k1, k2 distortion coefficients + Return: + pts (N x 2): distorted points in NDC coordinates + """ + r2 = np.square(pts).sum(-1) + f = 1 + k1 * r2 + k2 * r2**2 + return f[..., None] * pts + + +# https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb +def apply_distortion_tensor(pts, k1, k2): + """ + Arguments: + pts (N x 2): numpy array in NDC coordinates + k1, k2 distortion coefficients + Return: + pts (N x 2): distorted points in NDC coordinates + """ + r2 = torch.square(pts).sum(-1) + f = 1 + k1 * r2 + k2 * r2**2 + return f[..., None] * pts + + +# https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb +def remove_distortion_iter(points, k1, k2): + """ + Arguments: + pts (N x 2): numpy array in NDC coordinates + k1, k2 distortion coefficients + Return: + pts (N x 2): distorted points in NDC coordinates + """ + pts = ptsd = points + for _ in range(5): + r2 = np.square(pts).sum(-1) + f = 1 + k1 * r2 + k2 * r2**2 + pts = ptsd / f[..., None] + + return pts + + +def make_square(im, fill_color=(0, 0, 0)): + x, y = im.size + size = max(x, y) + new_im = Image.new("RGB", (size, size), fill_color) + corner = (int((size - x) / 2), int((size - y) / 2)) + new_im.paste(im, corner) + return new_im, corner + + +def pixel_to_ndc(coords, image_size): + """ + Converts pixel coordinates to normalized device coordinates (Pytorch3D convention + with upper left = (1, 1)) for a square image. + + Args: + coords: Pixel coordinates UL=(0, 0), LR=(image_size, image_size). + image_size (int): Image size. + + Returns: + NDC coordinates UL=(1, 1) LR=(-1, -1). + """ + coords = np.array(coords) + return 1 - coords / image_size * 2 + + +def ndc_to_pixel(coords, image_size): + """ + Converts normalized device coordinates to pixel coordinates for a square image. 
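+ This is the inverse of pixel_to_ndc: pixel = (1 - ndc) * image_size / 2, so NDC (1, 1) maps to pixel (0, 0) and (-1, -1) maps to (image_size, image_size).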
+ """ + num_points = coords.shape[0] + sizes = np.tile(np.array(image_size, dtype=np.float32)[None, ...], (num_points, 1)) + + coords = np.array(coords, dtype=np.float32) + return (1 - coords) * sizes / 2 + + +def distort_image(image, bbox, k1, k2, modify_bbox=False): + # We want to operate in -1 to 1 space using the padded square of the original image + image, corner = make_square(image) + bbox[:2] += np.array(corner) + bbox[2:] += np.array(corner) + + # Construct grid points + x = np.linspace(1, -1, image.width, dtype=np.float32) + y = np.linspace(1, -1, image.height, dtype=np.float32) + x, y = np.meshgrid(x, y, indexing="xy") + xy_grid = np.stack((x, y), axis=-1) + points = xy_grid.reshape((image.height * image.width, 2)) + new_points = ndc_to_pixel(apply_distortion(points, k1, k2), image.size) + + # Distort image by remapping + map_x = new_points[:, 0].reshape((image.height, image.width)) + map_y = new_points[:, 1].reshape((image.height, image.width)) + distorted = cv2.remap( + np.asarray(image), + map_x, + map_y, + cv2.INTER_LINEAR, + ) + distorted = Image.fromarray(distorted) + + # Find distorted crop bounds - inverse process of above + if modify_bbox: + center = (bbox[:2] + bbox[2:]) / 2 + top, bottom = (bbox[0], center[1]), (bbox[2], center[1]) + left, right = (center[0], bbox[1]), (center[0], bbox[3]) + bbox_points = np.array( + [ + pixel_to_ndc(top, image.size), + pixel_to_ndc(left, image.size), + pixel_to_ndc(bottom, image.size), + pixel_to_ndc(right, image.size), + ], + dtype=np.float32, + ) + else: + bbox_points = np.array( + [pixel_to_ndc(bbox[:2], image.size), pixel_to_ndc(bbox[2:], image.size)], + dtype=np.float32, + ) + + # Inverse mapping + distorted_bbox = remove_distortion_iter(bbox_points, k1, k2) + + if modify_bbox: + p = ndc_to_pixel(distorted_bbox, image.size) + distorted_bbox = np.array([p[0][0], p[1][1], p[2][0], p[3][1]]) + else: + distorted_bbox = ndc_to_pixel(distorted_bbox, image.size).reshape(4) + + return distorted, distorted_bbox diff --git a/diffusionsfm/utils/distributed.py b/diffusionsfm/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..5f523fe6769748f0b9e5007c77790a953111c5ab --- /dev/null +++ b/diffusionsfm/utils/distributed.py @@ -0,0 +1,31 @@ +import os +import socket +from contextlib import closing + +import torch.distributed as dist + + +def get_open_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +# Distributed process group +def ddp_setup(rank, world_size, port="12345"): + """ + Args: + rank: Unique Identifier + world_size: number of processes + """ + os.environ["MASTER_ADDR"] = "localhost" + print(f"MasterPort: {str(port)}") + os.environ["MASTER_PORT"] = str(port) + + # initialize the process group + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() diff --git a/diffusionsfm/utils/geometry.py b/diffusionsfm/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..c92f2dd7505c028bdca0c58c8fd7fea09c8f38b7 --- /dev/null +++ b/diffusionsfm/utils/geometry.py @@ -0,0 +1,145 @@ +import numpy as np +import torch +from pytorch3d.renderer import FoVPerspectiveCameras +from pytorch3d.transforms import quaternion_to_matrix + + +def generate_random_rotations(N=1, device="cpu"): + q = torch.randn(N, 4, device=device) + q = q / q.norm(dim=-1, keepdim=True) + return 
quaternion_to_matrix(q) + + +def symmetric_orthogonalization(x): + """Maps 9D input vectors onto SO(3) via symmetric orthogonalization. + + x: should have size [batch_size, 9] + + Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3). + """ + m = x.view(-1, 3, 3) + u, s, v = torch.svd(m) + vt = torch.transpose(v, 1, 2) + det = torch.det(torch.matmul(u, vt)) + det = det.view(-1, 1, 1) + vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) + r = torch.matmul(u, vt) + return r + + +def get_permutations(num_images): + permutations = [] + for i in range(0, num_images): + for j in range(0, num_images): + if i != j: + permutations.append((j, i)) + + return permutations + + +def n_to_np_rotations(num_frames, n_rots): + R_pred_rel = [] + permutations = get_permutations(num_frames) + for i, j in permutations: + R_pred_rel.append(n_rots[i].T @ n_rots[j]) + R_pred_rel = torch.stack(R_pred_rel) + + return R_pred_rel + + +def compute_angular_error_batch(rotation1, rotation2): + R_rel = np.einsum("Bij,Bjk ->Bik", rotation2, rotation1.transpose(0, 2, 1)) + t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2 + theta = np.arccos(np.clip(t, -1, 1)) + return theta * 180 / np.pi + + +# A should be GT, B should be predicted +def compute_optimal_alignment(A, B): + """ + Compute the optimal scale s, rotation R, and translation t that minimizes: + || A - (s * B @ R + T) || ^ 2 + + Reference: Umeyama (TPAMI 91) + + Args: + A (torch.Tensor): (N, 3). + B (torch.Tensor): (N, 3). + + Returns: + s (float): scale. + R (torch.Tensor): rotation matrix (3, 3). + t (torch.Tensor): translation (3,). + """ + A_bar = A.mean(0) + B_bar = B.mean(0) + # normally with R @ B, this would be A @ B.T + H = (B - B_bar).T @ (A - A_bar) + U, S, Vh = torch.linalg.svd(H, full_matrices=True) + s = torch.linalg.det(U @ Vh) + S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device)) + variance = torch.sum((B - B_bar) ** 2) + scale = 1 / variance * torch.trace(torch.diag(S) @ S_prime) + R = U @ S_prime @ Vh + t = A_bar - scale * B_bar @ R + + A_hat = scale * B @ R + t + return A_hat, scale, R, t + + +def compute_optimal_translation_alignment(T_A, T_B, R_B): + """ + Assuming right-multiplied rotation matrices. + + E.g., for world2cam R and T, a world coordinate is transformed to camera coordinate + system using X_cam = X_world.T @ R + T = R.T @ X_world + T + + Finds s, t that minimizes || T_A - (s * T_B + R_B.T @ t) ||^2 + + Args: + T_A (torch.Tensor): Target translation (N, 3). + T_B (torch.Tensor): Initial translation (N, 3). + R_B (torch.Tensor): Initial rotation (N, 3, 3). + + Returns: + T_A_hat (torch.Tensor): s * T_B + t @ R_B (N, 3). + scale s (torch.Tensor): (1,). + translation t (torch.Tensor): (1, 3). 
+ """ + n = len(T_A) + + T_A = T_A.unsqueeze(2) + T_B = T_B.unsqueeze(2) + + A = torch.sum(T_B * T_A) + B = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) @ (R_B @ T_A).sum(0) / n + C = torch.sum(T_B * T_B) + D = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) + E = (D * D).sum() / n + + s = (A - B.sum()) / (C - E.sum()) + + t = (R_B @ (T_A - s * T_B)).sum(0) / n + + T_A_hat = s * T_B + R_B.transpose(1, 2) @ t + + return T_A_hat.squeeze(2), s, t.transpose(1, 0) + + +def get_error(predict_rotations, R_pred, T_pred, R_gt, T_gt, gt_scene_scale): + if predict_rotations: + cameras_gt = FoVPerspectiveCameras(R=R_gt, T=T_gt) + cc_gt = cameras_gt.get_camera_center() + cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred) + cc_pred = cameras_pred.get_camera_center() + + A_hat, _, _, _ = compute_optimal_alignment(cc_gt, cc_pred) + norm = torch.linalg.norm(cc_gt - A_hat, dim=1) / gt_scene_scale + + norms = np.ndarray.tolist(norm.detach().cpu().numpy()) + return norms, A_hat + else: + T_A_hat, _, _ = compute_optimal_translation_alignment(T_gt, T_pred, R_pred) + norm = torch.linalg.norm(T_gt - T_A_hat, dim=1) / gt_scene_scale + norms = np.ndarray.tolist(norm.detach().cpu().numpy()) + return norms, T_A_hat diff --git a/diffusionsfm/utils/normalize.py b/diffusionsfm/utils/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..287be15f32edcaf2f14e7d8e133d71aa5d6768ca --- /dev/null +++ b/diffusionsfm/utils/normalize.py @@ -0,0 +1,92 @@ +import ipdb # noqa: F401 +import torch +from pytorch3d.transforms import Rotate, Translate + + +def intersect_skew_line_groups(p, r, mask=None): + # p, r both of shape (B, N, n_intersected_lines, 3) + # mask of shape (B, N, n_intersected_lines) + p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask) + if p_intersect is None: + return None, None, None, None + _, p_line_intersect = point_line_distance( + p, r, p_intersect[..., None, :].expand_as(p) + ) + intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum( + dim=-1 + ) + return p_intersect, p_line_intersect, intersect_dist_squared, r + + +def intersect_skew_lines_high_dim(p, r, mask=None): + # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions + dim = p.shape[-1] + # make sure the heading vectors are l2-normed + if mask is None: + mask = torch.ones_like(p[..., 0]) + r = torch.nn.functional.normalize(r, dim=-1) + + eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None] + I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None] + sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3) + + p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] + + if torch.any(torch.isnan(p_intersect)): + print(p_intersect) + return None, None + ipdb.set_trace() + assert False + return p_intersect, r + + +def point_line_distance(p1, r1, p2): + df = p2 - p1 + proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1) + line_pt_nearest = p2 - proj_vector + d = (proj_vector).norm(dim=-1) + return d, line_pt_nearest + + +def compute_optical_axis_intersection(cameras): + centers = cameras.get_camera_center() + principal_points = cameras.principal_point + + one_vec = torch.ones((len(cameras), 1), device=centers.device) + optical_axis = torch.cat((principal_points, one_vec), -1) + + pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True) + pp2 = torch.diagonal(pp, dim1=0, dim2=1).T + + directions = pp2 - centers + centers = centers.unsqueeze(0).unsqueeze(0) + 
directions = directions.unsqueeze(0).unsqueeze(0) + + p_intersect, p_line_intersect, _, r = intersect_skew_line_groups( + p=centers, r=directions, mask=None + ) + + if p_intersect is None: + dist = None + else: + p_intersect = p_intersect.squeeze().unsqueeze(0) + dist = (p_intersect - centers).norm(dim=-1) + + return p_intersect, dist, p_line_intersect, pp2, r + + +def first_camera_transform(cameras, rotation_only=True): + new_cameras = cameras.clone() + new_transform = new_cameras.get_world_to_view_transform() + tR = Rotate(new_cameras.R[0].unsqueeze(0)) + if rotation_only: + t = tR.inverse() + else: + tT = Translate(new_cameras.T[0].unsqueeze(0)) + t = tR.compose(tT).inverse() + + new_transform = t.compose(new_transform) + new_cameras.R = new_transform.get_matrix()[:, :3, :3] + new_cameras.T = new_transform.get_matrix()[:, 3, :3] + + return new_cameras diff --git a/diffusionsfm/utils/rays.py b/diffusionsfm/utils/rays.py new file mode 100644 index 0000000000000000000000000000000000000000..5cb899acaf7d392488ec2545db03822e65eb7602 --- /dev/null +++ b/diffusionsfm/utils/rays.py @@ -0,0 +1,1390 @@ +from tkinter import FALSE +import cv2 +import ipdb # noqa: F401 +import numpy as np +import torch +from pytorch3d.renderer import PerspectiveCameras, RayBundle +from pytorch3d.transforms import Rotate, Translate + + +from diffusionsfm.utils.normalize import ( + compute_optical_axis_intersection, + intersect_skew_line_groups, + first_camera_transform, + intersect_skew_lines_high_dim, +) +from diffusionsfm.utils.distortion import apply_distortion_tensor + + +class Rays(object): + def __init__( + self, + rays=None, + origins=None, + directions=None, + moments=None, + segments=None, + depths=None, + moments_rescale=1.0, + ndc_coordinates=None, + crop_parameters=None, + num_patches_x=16, + num_patches_y=16, + distortion_coeffs=None, + camera_coordinate_rays=None, + mode=None, + unprojected=None, + depth_resolution=1, + row_form=False, + ): + """ + Ray class to keep track of current ray representation. + + Args: + rays: (..., 6). + origins: (..., 3). + directions: (..., 3). + moments: (..., 3). + mode: One of "ray", "plucker" or "segment". + moments_rescale: Rescale the moment component of the rays by a scalar. + ndc_coordinates: (..., 2): NDC coordinates of each ray. 
+ """ + self.depth_resolution = depth_resolution + self.num_patches_x = num_patches_x + self.num_patches_y = num_patches_y + + if rays is not None: + self.rays = rays + assert mode is not None + self._mode = mode + elif segments is not None: + if not row_form: + segments = Rays.patches_to_rows( + segments, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + depth_resolution=depth_resolution, + ) + self.rays = torch.cat((origins, segments), dim=-1) + self._mode = "segment" + elif origins is not None and directions is not None: + self.rays = torch.cat((origins, directions), dim=-1) + self._mode = "ray" + elif directions is not None and moments is not None: + self.rays = torch.cat((directions, moments), dim=-1) + self._mode = "plucker" + else: + raise Exception("Invalid combination of arguments") + + if depths is not None: + self._mode = mode + self.depths = depths + else: + self.depths = None + assert mode is not None + + if unprojected is not None: + self.unprojected = unprojected + else: + self.unprojected = None + + if moments_rescale != 1.0: + self.rescale_moments(moments_rescale) + + if ndc_coordinates is not None: + self.ndc_coordinates = ndc_coordinates + elif crop_parameters is not None: + # (..., H, W, 2) + xy_grid = compute_ndc_coordinates( + crop_parameters, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + distortion_coeffs=distortion_coeffs, + )[..., :2] + xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2) + self.ndc_coordinates = xy_grid + else: + self.ndc_coordinates = None + + if camera_coordinate_rays is not None: + self.camera_ray_coordinates = True + self.camera_coordinate_ray_directions = camera_coordinate_rays + else: + self.camera_ray_coordinates = False + + def __getitem__(self, index): + cam_coord_rays = None + if self.camera_ray_coordinates: + cam_coord_rays = self.camera_coordinate_ray_directions[index] + + return Rays( + rays=self.rays[index], + mode=self._mode, + camera_coordinate_rays=cam_coord_rays, + ndc_coordinates=( + self.ndc_coordinates[index] + if self.ndc_coordinates is not None + else None + ), + num_patches_x=self.num_patches_x, + num_patches_y=self.num_patches_y, + depths=( + self.depths[index] + if self.ndc_coordinates is not None and self.depths is not None + else None + ), + unprojected=( + self.unprojected[index] if self.ndc_coordinates is not None else None + ), + depth_resolution=self.depth_resolution, + ) + + def __len__(self): + return self.rays.shape[0] + + def to_spatial( + self, include_ndc_coordinates=False, include_depths=False, use_homogeneous=False + ): + """ + Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W) + + If use_homogeneous is True, then each 3D component will be 4D and normalized. 
+ + Returns: + torch.Tensor: (..., 6, H, W) + """ + if self._mode == "ray": + rays = self.to_plucker().rays + else: + rays = self.rays + + *batch_dims, P, D = rays.shape + H = W = int(np.sqrt(P)) + assert H * W == P + + if use_homogeneous: + rays_reshaped = rays.reshape(*batch_dims, P, D // 3, 3) + ones = torch.ones_like(rays_reshaped[..., :1]) + rays_reshaped = torch.cat((rays_reshaped, ones), dim=-1) + rays = torch.nn.functional.normalize(rays_reshaped, dim=-1) + D = (4 * D) // 3 + rays = rays.reshape(*batch_dims, P, D) + + rays = torch.transpose(rays, -1, -2) # (..., 6, H * W) + rays = rays.reshape(*batch_dims, D, H, W) + + if include_depths: + depths = self.depths.unsqueeze(1) + rays = torch.cat((rays, depths), dim=-3) + + if include_ndc_coordinates: + ndc_coords = self.ndc_coordinates.transpose(-1, -2) # (..., 2, H * W) + ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W) + rays = torch.cat((rays, ndc_coords), dim=-3) + + return rays + + def to_spatial_with_camera_coordinate_rays( + self, + I_camera, + crop_params, + moments_rescale=1.0, + include_ndc_coordinates=False, + use_homogeneous=False, + ): + """ + Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W) + + Returns: + torch.Tensor: (..., 6, H, W) + """ + + rays = self.to_spatial( + include_ndc_coordinates=include_ndc_coordinates, + use_homogeneous=use_homogeneous, + ) + N, _, H, W = rays.shape + + camera_coord_rays = ( + cameras_to_rays( + cameras=I_camera, + num_patches_x=H, + num_patches_y=W, + crop_parameters=crop_params, + ) + .rescale_moments(1 / moments_rescale) + .get_directions() + ) + + self.camera_coordinate_ray_directions = camera_coord_rays + + # camera_coord_rays = torch.stack(camera_coord_rays) + camera_coord_rays = torch.transpose(camera_coord_rays, -1, -2) + camera_coord_rays = camera_coord_rays.reshape(N, 3, H, W) + + rays = torch.cat((camera_coord_rays, rays), dim=-3) + + return rays + + def rescale_moments(self, scale): + """ + Rescale the moment component of the rays by a scalar. Might be desirable since + moments may come from a very narrow distribution. + + Note that this modifies in place! 
+ """ + assert False, "Deprecated" + if self._mode == "plucker": + self.rays[..., 3:] *= scale + return self + else: + return self.to_plucker().rescale_moments(scale) + + def to_spatial_with_camera_coordinate_rays_object( + self, + I_camera, + crop_params, + moments_rescale=1.0, + include_ndc_coordinates=False, + use_homogeneous=False, + ): + """ + Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W) + + Returns: + torch.Tensor: (..., 6, H, W) + """ + + rays = self.to_spatial(include_ndc_coordinates, use_homogeneous=use_homogeneous) + N, _, H, W = rays.shape + + camera_coord_rays = ( + cameras_to_rays( + cameras=I_camera, + num_patches_x=H, + num_patches_y=W, + crop_parameters=crop_params, + ) + .rescale_moments(1 / moments_rescale) + .get_directions() + ) + + self.camera_coordinate_ray_directions = camera_coord_rays + + camera_coord_rays = torch.transpose(camera_coord_rays, -1, -2) + camera_coord_rays = camera_coord_rays.reshape(N, 3, H, W) + + @classmethod + def patches_to_rows(cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1): + B, P, C = x.shape + assert P == (depth_resolution**2 * num_patches_x * num_patches_y) + + x = x.reshape( + B, + depth_resolution * num_patches_x, + depth_resolution * num_patches_y, + C, + ) + + new = x.unfold(1, depth_resolution, depth_resolution).unfold( + 2, depth_resolution, depth_resolution + ) + new = new.permute((0, 1, 2, 4, 5, 3)) + new = new.reshape( + (B, num_patches_x * num_patches_y, depth_resolution * depth_resolution * C) + ) + return new + + @classmethod + def rows_to_patches(cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1): + B, P, CP = x.shape + assert P == (num_patches_x * num_patches_y) + C = CP // (depth_resolution**2) + HP, WP = num_patches_x * depth_resolution, num_patches_y * depth_resolution + + x = x.reshape( + B, num_patches_x, num_patches_y, depth_resolution, depth_resolution, C + ) + x = x.permute(0, 1, 3, 2, 4, 5) + x = x.reshape(B, HP * WP, C) + return x + + @classmethod + def upsample_origins( + cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1 + ): + B, P, C = x.shape + origins = x.permute((0, 2, 1)) + origins = origins.reshape((B, C, num_patches_x, num_patches_y)) + origins = torch.nn.functional.interpolate( + origins, scale_factor=(depth_resolution, depth_resolution) + ) + origins = origins.permute((0, 2, 3, 1)).reshape( + (B, P * depth_resolution * depth_resolution, C) + ) + return origins + + @classmethod + def from_spatial_with_camera_coordinate_rays( + cls, rays, mode, moments_rescale=1.0, use_homogeneous=False + ): + """ + Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6) + + Args: + rays: (..., 6, H, W) + + Returns: + Rays: (..., H * W, 6) + """ + *batch_dims, D, H, W = rays.shape + rays = rays.reshape(*batch_dims, D, H * W) + rays = torch.transpose(rays, -1, -2) + + camera_coordinate_ray_directions = rays[..., :3] + rays = rays[..., 3:] + + return cls( + rays=rays, + mode=mode, + moments_rescale=moments_rescale, + camera_coordinate_rays=camera_coordinate_ray_directions, + ) + + @classmethod + def from_spatial( + cls, + rays, + mode, + moments_rescale=1.0, + ndc_coordinates=None, + num_patches_x=16, + num_patches_y=16, + use_homogeneous=False, + ): + """ + Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6) + + Args: + rays: (..., 6, H, W) + + Returns: + Rays: (..., H * W, 6) + """ + *batch_dims, D, H, W = rays.shape + rays = rays.reshape(*batch_dims, D, H * W) + rays = torch.transpose(rays, -1, -2) + + if 
use_homogeneous: + D -= 2 + + if D == 7: + if use_homogeneous: + r1 = rays[..., :3] / (rays[..., 3:4] + 1e-6) + r2 = rays[..., 4:7] / (rays[..., 7:8] + 1e-6) + rays = torch.cat((r1, r2), dim=-1) + depths = rays[8] + else: + old_rays = rays + rays = rays[..., :6] + depths = old_rays[..., 6] + return cls( + rays=rays, + mode=mode, + moments_rescale=moments_rescale, + ndc_coordinates=ndc_coordinates, + depths=depths.reshape(*batch_dims, H, W), + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + ) + elif D > 7: + + D += 2 + if use_homogeneous: + rays_reshaped = rays.reshape((*batch_dims, H * W, D // 4, 4)) + rays_not_homo = rays_reshaped / rays_reshaped[..., :, 3].unsqueeze(-1) + rays = rays_not_homo[..., :, :3].reshape( + (*batch_dims, H * W, (D // 4) * 3) + ) + D = (D // 4) * 3 + + ray = cls( + origins=rays[:, :, :3], + segments=rays[:, :, 3:], + mode="segment", + moments_rescale=moments_rescale, + ndc_coordinates=ndc_coordinates, + # depths=rays[:, :, -1].reshape(*batch_dims, H, W), + row_form=True, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + depth_resolution=int(((D - 3) // 3) ** 0.5), + ) + + if mode == "ray": + return ray.to_point_direction() + elif mode == "plucker": + return ray.to_plucker() + elif mode == "segment": + return ray + else: + assert False + else: + if use_homogeneous: + r1 = rays[..., :3] / (rays[..., 3:4] + 1e-6) + r2 = rays[..., 4:7] / (rays[..., 7:8] + 1e-6) + rays = torch.cat((r1, r2), dim=-1) + return cls( + rays=rays, + mode=mode, + moments_rescale=moments_rescale, + ndc_coordinates=ndc_coordinates, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + ) + + def to_point_direction(self, normalize_moment=True): + """ + Convert to point direction representation . + + Returns: + rays: (..., 6). + """ + if self._mode == "plucker": + direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1) + moment = self.rays[..., 3:] + if normalize_moment: + c = torch.linalg.norm(direction, dim=-1, keepdim=True) + moment = moment / c + points = torch.cross(direction, moment, dim=-1) + return Rays( + rays=torch.cat((points, direction), dim=-1), + mode="ray", + ndc_coordinates=self.ndc_coordinates, + num_patches_x=self.num_patches_x, + num_patches_y=self.num_patches_y, + depths=self.depths, + unprojected=self.unprojected, + depth_resolution=self.depth_resolution, + ) + elif self._mode == "segment": + origins = self.get_origins(high_res=True) + + direction = self.get_segments() - origins + direction = torch.nn.functional.normalize(direction, dim=-1) + + return Rays( + rays=torch.cat((origins, direction), dim=-1), + mode="ray", + ndc_coordinates=self.ndc_coordinates, + num_patches_x=self.num_patches_x, + num_patches_y=self.num_patches_y, + depths=self.depths, + unprojected=self.unprojected, + depth_resolution=self.depth_resolution, + ) + else: + return self + + def to_plucker(self): + """ + Convert to plucker representation . 
+ """ + if self._mode == "plucker": + return self + elif self._mode == "ray": + ray = self.rays.clone() + ray_origins = ray[..., :3] + ray_directions = ray[..., 3:] + + # Normalize ray directions to unit vectors + ray_directions = ray_directions / torch.linalg.vector_norm( + ray_directions, dim=-1, keepdim=True + ) + plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1) + new_ray = torch.cat([ray_directions, plucker_normal], dim=-1) + return Rays( + rays=new_ray, + mode="plucker", + ndc_coordinates=self.ndc_coordinates, + num_patches_x=self.num_patches_x, + num_patches_y=self.num_patches_y, + depths=self.depths, + unprojected=self.unprojected, + depth_resolution=self.depth_resolution, + ) + elif self._mode == "segment": + return self.to_point_direction().to_plucker() + + def get_directions(self, normalize=True): + if self._mode == "plucker": + directions = self.rays[..., :3] + elif self._mode == "segment": + directions = self.to_point_direction().get_directions() + else: + directions = self.rays[..., 3:] + if normalize: + directions = torch.nn.functional.normalize(directions, dim=-1) + return directions + + def get_camera_coordinate_rays(self, normalize=True): + directions = self.camera_coordinate_ray_directions + if normalize: + directions = torch.nn.functional.normalize(directions, dim=-1) + return directions + + def get_origins(self, high_res=False): + if self._mode == "plucker": + origins = self.to_point_direction().get_origins(high_res=high_res) + elif self._mode == "ray": + origins = self.rays[..., :3] + elif self._mode == "segment": + origins = Rays.upsample_origins( + self.rays[..., :3], + num_patches_x=self.num_patches_x, + num_patches_y=self.num_patches_y, + depth_resolution=self.depth_resolution, + ) + else: + assert False + + return origins + + def get_moments(self): + if self._mode == "plucker": + moments = self.rays[..., 3:] + elif self._mode in ["ray", "segment"]: + moments = self.to_plucker().get_moments() + + return moments + + def get_segments(self): + assert self._mode == "segment" + + if self.unprojected is not None: + return self.unprojected + else: + return Rays.rows_to_patches( + self.rays[..., 3:], + num_patches_x=self.num_patches_x, + num_patches_y=self.num_patches_y, + depth_resolution=self.depth_resolution, + ) + + def get_ndc_coordinates(self): + return self.ndc_coordinates + + @property + def mode(self): + return self._mode + + @mode.setter + def mode(self, mode): + self._mode = mode + + @property + def device(self): + return self.rays.device + + def __repr__(self, *args, **kwargs): + ray_str = self.rays.__repr__(*args, **kwargs)[6:] # remove "tensor" + + if self._mode == "plucker": + return "PluRay" + ray_str + elif self._mode == "ray": + return "DirRay" + ray_str + else: + return "SegRay" + ray_str + + def to(self, device): + self.rays = self.rays.to(device) + + def clone(self): + return Rays(rays=self.rays.clone(), mode=self._mode) + + @property + def shape(self): + return self.rays.shape + + def visualize(self): + directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu() + moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu() + return (directions + 1) / 2, (moments + 1) / 2 + + def to_ray_bundle(self, length=0.3, recompute_origin=False, true_length=False): + """ + Args: + length (float): Length of the rays for visualization. + recompute_origin (bool): If True, origin is set to the intersection point of + all rays. 
If False, origins are the point along the ray closest + """ + origins = self.get_origins(high_res=self.depth_resolution > 1) + lengths = torch.ones_like(origins[..., :2]) * length + lengths[..., 0] = 0 + p_intersect, p_closest, _, _ = intersect_skew_line_groups( + origins.float(), self.get_directions().float() + ) + if recompute_origin: + centers = p_intersect + centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1) + else: + centers = p_closest + + if true_length: + length = torch.norm(self.get_segments() - centers, dim=-1).unsqueeze(-1) + lengths = torch.ones_like(origins[..., :2]) * length + lengths[..., 0] = 0 + + return RayBundle( + origins=centers, + directions=self.get_directions(), + lengths=lengths, + xys=self.get_directions(), + ) + + +def cameras_to_rays( + cameras, + crop_parameters, + use_half_pix=True, + use_plucker=True, + num_patches_x=16, + num_patches_y=16, + no_crop_param_device="cpu", + distortion_coeffs=None, + depths=None, + visualize=False, + mode=None, + depth_resolution=1, + nearest_neighbor=True, + distortion_coefficients=None, +): + """ + Unprojects rays from camera center to grid on image plane. + + To match Moneish's code, set use_half_pix=False, use_plucker=True. Also, the + arguments to meshgrid should be swapped (x first, then y). I'm following Pytorch3d + convention to have y first. + + distortion_coeffs refers to Amy's distortion experiments + distortion_coefficients refers to the fisheye parameters from colmap + + Args: + cameras: Pytorch3D cameras to unproject. Can be batched. + crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale). + Shape is (B, 4). + use_half_pix: If True, use half pixel offset (Default: True). + use_plucker: If True, return rays in plucker coordinates (Default: False). + num_patches_x: Number of patches in x direction (Default: 16). + num_patches_y: Number of patches in y direction (Default: 16). 
+ """ + + unprojected = [] + unprojected_ones = [] + crop_parameters_list = ( + crop_parameters if crop_parameters is not None else [None for _ in cameras] + ) + depths_list = depths if depths is not None else [None for _ in cameras] + if distortion_coeffs is None: + zs = [] + for i, (camera, crop_param, depth) in enumerate( + zip(cameras, crop_parameters_list, depths_list) + ): + xyd_grid = compute_ndc_coordinates( + crop_parameters=crop_param, + use_half_pix=use_half_pix, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + no_crop_param_device=no_crop_param_device, + depths=depth, + return_zs=True, + depth_resolution=depth_resolution, + nearest_neighbor=nearest_neighbor, + ) + + xyd_grid, z, ones_grid = xyd_grid + zs.append(z) + + if ( + distortion_coefficients is not None + and (distortion_coefficients[i] != 0).any() + ): + xyd_grid = undistort_ndc_coordinates( + ndc_coordinates=xyd_grid, + principal_point=camera.principal_point[0], + focal_length=camera.focal_length[0], + distortion_coefficients=distortion_coefficients[i], + ) + + unprojected.append( + camera.unproject_points( + xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True + ) + ) + + if depths is not None and mode == "plucker": + unprojected_ones.append( + camera.unproject_points( + ones_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True + ) + ) + + else: + for camera, crop_param, distort_coeff in zip( + cameras, crop_parameters_list, distortion_coeffs + ): + xyd_grid = compute_ndc_coordinates( + crop_parameters=crop_param, + use_half_pix=use_half_pix, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + no_crop_param_device=no_crop_param_device, + distortion_coeffs=distort_coeff, + depths=depths, + nearest_neighbor=nearest_neighbor, + ) + + unprojected.append( + camera.unproject_points( + xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True + ) + ) + + unprojected = torch.stack(unprojected, dim=0) # (N, P, 3) + origins = cameras.get_camera_center().unsqueeze(1) # (N, 1, 3) + origins = origins.repeat(1, num_patches_x * num_patches_y, 1) # (N, P, 3) + + if depths is None: + directions = unprojected - origins + rays = Rays( + origins=origins, + directions=directions, + crop_parameters=crop_parameters, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + distortion_coeffs=distortion_coeffs, + mode="ray", + unprojected=unprojected, + ) + if use_plucker: + return rays.to_plucker() + elif mode == "segment": + rays = Rays( + origins=origins, + segments=unprojected, + crop_parameters=crop_parameters, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + distortion_coeffs=distortion_coeffs, + depths=torch.stack(zs, dim=0), + mode=mode, + unprojected=unprojected, + depth_resolution=depth_resolution, + ) + elif mode == "plucker" or mode == "ray": + unprojected_ones = torch.stack(unprojected_ones) + directions = unprojected_ones - origins + + rays = Rays( + origins=origins, + directions=directions, + crop_parameters=crop_parameters, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + distortion_coeffs=distortion_coeffs, + depths=torch.stack(zs, dim=0), + mode="ray", + unprojected=unprojected, + ) + + if mode == "plucker": + rays = rays.to_plucker() + else: + assert False + + if visualize: + return rays, unprojected, torch.stack(zs, dim=0) + + return rays + + +def rays_to_cameras( + rays, + crop_parameters, + num_patches_x=16, + num_patches_y=16, + use_half_pix=True, + no_crop_param_device="cpu", + sampled_ray_idx=None, + cameras=None, + 
focal_length=(3.453,), + distortion_coeffs=None, + calculate_distortion=False, + depth_resolution=1, + average_centers=False, +): + """ + If cameras are provided, will use those intrinsics. Otherwise will use the provided + focal_length(s). Dataset default is 3.32. + + Args: + rays (Rays): (N, P, 6) + crop_parameters (torch.Tensor): (N, 4) + """ + + device = rays.device + origins = rays.get_origins(high_res=True) + directions = rays.get_directions() + + if average_centers: + camera_centers = torch.mean(origins, dim=1) + else: + camera_centers, _ = intersect_skew_lines_high_dim(origins, directions) + + # Retrieve target rays + if cameras is None: + if len(focal_length) == 1: + focal_length = focal_length * rays.shape[0] + I_camera = PerspectiveCameras(focal_length=focal_length, device=device) + else: + # Use same intrinsics but reset to identity extrinsics. + I_camera = cameras.clone() + I_camera.R[:] = torch.eye(3, device=device) + I_camera.T[:] = torch.zeros(3, device=device) + + if distortion_coeffs is not None and not calculate_distortion: + coeff = distortion_coeffs + else: + coeff = None + + I_patch_rays = cameras_to_rays( + cameras=I_camera, + num_patches_x=num_patches_x * depth_resolution, + num_patches_y=num_patches_y * depth_resolution, + use_half_pix=use_half_pix, + crop_parameters=crop_parameters, + no_crop_param_device=no_crop_param_device, + distortion_coeffs=coeff, + mode="plucker", + depth_resolution=depth_resolution, + ).get_directions() + + if sampled_ray_idx is not None: + I_patch_rays = I_patch_rays[:, sampled_ray_idx] + + # Compute optimal rotation to align rays + R = torch.zeros_like(I_camera.R) + for i in range(len(I_camera)): + R[i] = compute_optimal_rotation_alignment( + I_patch_rays[i], + directions[i], + ) + + # Construct and return rotated camera + cam = I_camera.clone() + cam.R = R + cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2) + return cam + + +# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/ +def ql_decomposition(A): + P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float() + A_tilde = torch.matmul(A, P) + Q_tilde, R_tilde = torch.linalg.qr(A_tilde) + Q = torch.matmul(Q_tilde, P) + L = torch.matmul(torch.matmul(P, R_tilde), P) + d = torch.diag(L) + Q[:, 0] *= torch.sign(d[0]) + Q[:, 1] *= torch.sign(d[1]) + Q[:, 2] *= torch.sign(d[2]) + L[0] *= torch.sign(d[0]) + L[1] *= torch.sign(d[1]) + L[2] *= torch.sign(d[2]) + return Q, L + + +def rays_to_cameras_homography( + rays, + crop_parameters, + num_patches_x=16, + num_patches_y=16, + use_half_pix=True, + sampled_ray_idx=None, + reproj_threshold=0.2, + camera_coordinate_rays=False, + average_centers=False, + depth_resolution=1, + directions_from_averaged_center=False, +): + """ + Args: + rays (Rays): (N, P, 6) + crop_parameters (torch.Tensor): (N, 4) + """ + device = rays.device + origins = rays.get_origins(high_res=True) + directions = rays.get_directions() + + if average_centers: + camera_centers = torch.mean(origins, dim=1) + else: + camera_centers, _ = intersect_skew_lines_high_dim(origins, directions) + + if directions_from_averaged_center: + assert rays.mode == "segment" + directions = rays.get_segments() - camera_centers.unsqueeze(1).repeat( + (1, num_patches_x * num_patches_y, 1) + ) + + # Retrieve target rays + I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device) + I_patch_rays = cameras_to_rays( + cameras=I_camera, + num_patches_x=num_patches_x * depth_resolution, + 
num_patches_y=num_patches_y * depth_resolution, + use_half_pix=use_half_pix, + crop_parameters=crop_parameters, + no_crop_param_device=device, + mode="plucker", + ).get_directions() + + if sampled_ray_idx is not None: + I_patch_rays = I_patch_rays[:, sampled_ray_idx] + + # Compute optimal rotation to align rays + if camera_coordinate_rays: + directions_used = rays.get_camera_coordinate_rays() + else: + directions_used = directions + + Rs = [] + focal_lengths = [] + principal_points = [] + for i in range(rays.shape[-3]): + R, f, pp = compute_optimal_rotation_intrinsics( + I_patch_rays[i], + directions_used[i], + reproj_threshold=reproj_threshold, + ) + Rs.append(R) + focal_lengths.append(f) + principal_points.append(pp) + + R = torch.stack(Rs) + focal_lengths = torch.stack(focal_lengths) + principal_points = torch.stack(principal_points) + T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2) + return PerspectiveCameras( + R=R, + T=T, + focal_length=focal_lengths, + principal_point=principal_points, + device=device, + ) + + +def compute_optimal_rotation_alignment(A, B): + """ + Compute optimal R that minimizes: || A - B @ R ||_F + + Args: + A (torch.Tensor): (N, 3) + B (torch.Tensor): (N, 3) + + Returns: + R (torch.tensor): (3, 3) + """ + # normally with R @ B, this would be A @ B.T + H = B.T @ A + U, _, Vh = torch.linalg.svd(H, full_matrices=True) + s = torch.linalg.det(U @ Vh) + S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device)) + return U @ S_prime @ Vh + + +def compute_optimal_rotation_intrinsics( + rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2 +): + """ + Note: for some reason, f seems to be 1/f. + + Args: + rays_origin (torch.Tensor): (N, 3) + rays_target (torch.Tensor): (N, 3) + z_threshold (float): Threshold for z value to be considered valid. + + Returns: + R (torch.tensor): (3, 3) + focal_length (torch.tensor): (2,) + principal_point (torch.tensor): (2,) + """ + device = rays_origin.device + z_mask = torch.logical_and( + torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold + )[:, 2] + + rays_target = rays_target[z_mask] + rays_origin = rays_origin[z_mask] + rays_origin = rays_origin[:, :2] / rays_origin[:, -1:] + rays_target = rays_target[:, :2] / rays_target[:, -1:] + + try: + A, _ = cv2.findHomography( + rays_origin.cpu().numpy(), + rays_target.cpu().numpy(), + cv2.RANSAC, + reproj_threshold, + ) + except: + A, _ = cv2.findHomography( + rays_origin.cpu().numpy(), + rays_target.cpu().numpy(), + cv2.RANSAC, + reproj_threshold, + ) + A = torch.from_numpy(A).float().to(device) + + if torch.linalg.det(A) < 0: + # TODO: Find a better fix for this. This gives the correct R but incorrect + # intrinsics. + A = -A + + R, L = ql_decomposition(A) + L = L / L[2][2] + + f = torch.stack((L[0][0], L[1][1])) + # f = torch.stack(((L[0][0] + L[1][1]) / 2, (L[0][0] + L[1][1]) / 2)) + pp = torch.stack((L[2][0], L[2][1])) + return R, f, pp + + +def compute_ndc_coordinates( + crop_parameters=None, + use_half_pix=True, + num_patches_x=16, + num_patches_y=16, + no_crop_param_device="cpu", + distortion_coeffs=None, + depths=None, + return_zs=False, + depth_resolution=1, + nearest_neighbor=True, +): + """ + Computes NDC Grid using crop_parameters. If crop_parameters is not provided, + then it assumes that the crop is the entire image (corresponding to an NDC grid + where top left corner is (1, 1) and bottom right corner is (-1, -1)). 
+ """ + + if crop_parameters is None: + cc_x, cc_y, width = 0, 0, 2 + device = no_crop_param_device + else: + if len(crop_parameters.shape) > 1: + if distortion_coeffs is None: + return torch.stack( + [ + compute_ndc_coordinates( + crop_parameters=crop_param, + use_half_pix=use_half_pix, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + nearest_neighbor=nearest_neighbor, + depths=depths[i] if depths is not None else None, + ) + for i, crop_param in enumerate(crop_parameters) + ], + dim=0, + ) + else: + patch_params = zip(crop_parameters, distortion_coeffs) + return torch.stack( + [ + compute_ndc_coordinates( + crop_parameters=crop_param, + use_half_pix=use_half_pix, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + distortion_coeffs=distortion_coeff, + nearest_neighbor=nearest_neighbor, + ) + for crop_param, distortion_coeff in patch_params + ], + dim=0, + ) + device = crop_parameters.device + cc_x, cc_y, width, _ = crop_parameters + + dx = 1 / num_patches_x + dy = 1 / num_patches_y + if use_half_pix: + min_y = 1 - dy + max_y = -min_y + min_x = 1 - dx + max_x = -min_x + else: + min_y = min_x = 1 + max_y = -1 + 2 * dy + max_x = -1 + 2 * dx + + y, x = torch.meshgrid( + torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device), + torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device), + indexing="ij", + ) + + x_prime = x * width / 2 - cc_x + y_prime = y * width / 2 - cc_y + + if distortion_coeffs is not None: + points = torch.cat( + (x_prime.flatten().unsqueeze(-1), y_prime.flatten().unsqueeze(-1)), + dim=-1, + ) + new_points = apply_distortion_tensor( + points, distortion_coeffs[0], distortion_coeffs[1] + ) + x_prime = new_points[:, 0].reshape((num_patches_x, num_patches_y)) + y_prime = new_points[:, 1].reshape((num_patches_x, num_patches_y)) + + if depths is not None: + if depth_resolution > 1: + high_res_grid = compute_ndc_coordinates( + crop_parameters=crop_parameters, + use_half_pix=use_half_pix, + num_patches_x=num_patches_x * depth_resolution, + num_patches_y=num_patches_y * depth_resolution, + no_crop_param_device=no_crop_param_device, + ) + x_prime = high_res_grid[..., 0] + y_prime = high_res_grid[..., 1] + + z = depths + xyd_grid = torch.stack([x_prime, y_prime, z], dim=-1) + else: + z = torch.ones_like(x) + + xyd_grid = torch.stack([x_prime, y_prime, z], dim=-1) + xyd_grid_ones = torch.stack([x_prime, y_prime, torch.ones_like(x_prime)], dim=-1) + + if return_zs: + return xyd_grid, z, xyd_grid_ones + + return xyd_grid + + +def undistort_ndc_coordinates( + ndc_coordinates, principal_point, focal_length, distortion_coefficients +): + """ + Given NDC coordinates from a fisheye camera, computes where the coordinates would + have been for a pinhole camera. 
+ + Args: + ndc_coordinates (torch.Tensor): (H, W, 3) + principal_point (torch.Tensor): (2,) + focal_length (torch.Tensor): (2,) + distortion_coefficients (torch.Tensor): (4,) + + Returns: + torch.Tensor: (H, W, 3) + """ + device = ndc_coordinates.device + x = ndc_coordinates[..., 0] + y = ndc_coordinates[..., 1] + d = ndc_coordinates[..., 2] + # Compute normalized coordinates (using opencv convention where negative is top-left + x = -(x - principal_point[0]) / focal_length[0] + y = -(y - principal_point[1]) / focal_length[1] + distorted = torch.stack((x.flatten(), y.flatten()), 1).unsqueeze(1).cpu().numpy() + undistorted = cv2.fisheye.undistortPoints( + distorted, np.eye(3), distortion_coefficients.cpu().numpy(), np.eye(3) + ) + u = torch.tensor(undistorted[:, 0, 0], device=device) + v = torch.tensor(undistorted[:, 0, 1], device=device) + new_x = -u * focal_length[0] + principal_point[0] + new_y = -v * focal_length[1] + principal_point[1] + return torch.stack((new_x.reshape(x.shape), new_y.reshape(y.shape), d), -1) + + +def get_identity_cameras_with_intrinsics(cameras): + D = len(cameras) + device = cameras.R.device + + new_cameras = cameras.clone() + new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1)) + new_cameras.T = torch.zeros((D, 3), device=device) + + return new_cameras + + +def normalize_cameras_batch( + cameras, + scale=1.0, + normalize_first_camera=False, + depths=None, + crop_parameters=None, + num_patches_x=16, + num_patches_y=16, + distortion_coeffs=[None], + first_cam_mediod=False, + return_scales=False, +): + new_cameras = [] + undo_transforms = [] + scales = [] + for i, cam in enumerate(cameras): + if normalize_first_camera: + # Normalize cameras such that first camera is identity and origin is at + # first camera center. 
+ + s = 1 + if first_cam_mediod: + s = scale_first_cam_mediod( + cam[0], + depths=depths[i][0].unsqueeze(0) if depths is not None else None, + crop_parameters=crop_parameters[i][0].unsqueeze(0), + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + distortion_coeffs=( + distortion_coeffs[i][0].unsqueeze(0) + if distortion_coeffs[i] is not None + else None + ), + ) + scales.append(s) + + normalized_cameras = first_camera_transform(cam, s, rotation_only=False) + undo_transform = None + else: + out = normalize_cameras(cam, scale=scale, return_scale=depths is not None) + normalized_cameras, undo_transform, s = out + + if depths is not None: + depths[i] *= s + + if depths.isnan().any(): + assert False + + new_cameras.append(normalized_cameras) + undo_transforms.append(undo_transform) + + if return_scales: + return new_cameras, undo_transforms, scales + + return new_cameras, undo_transforms + + +def scale_first_cam_mediod( + cameras, + scale=1.0, + return_scale=False, + depths=None, + crop_parameters=None, + num_patches_x=16, + num_patches_y=16, + distortion_coeffs=None, +): + xy_grid = ( + compute_ndc_coordinates( + depths=depths, + crop_parameters=crop_parameters, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + distortion_coeffs=distortion_coeffs, + ) + .reshape((-1, 3)) + .to(depths.device) + ) + verts = cameras.unproject_points(xy_grid, from_ndc=True, world_coordinates=True) + p_intersect = torch.median( + verts.reshape((-1, 3))[: num_patches_x * num_patches_y].float(), dim=0 + ).values.unsqueeze(0) + d = torch.norm(p_intersect - cameras.get_camera_center()) + + if d < 0.001: + return 1 + + return 1 / d + + +def normalize_cameras(cameras, scale=1.0, return_scale=False): + """ + Normalizes cameras such that the optical axes point to the origin, the rotation is + identity, and the norm of the translation of the first camera is 1. + + Args: + cameras (pytorch3d.renderer.cameras.CamerasBase). + scale (float): Norm of the translation of the first camera. + + Returns: + new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras. + undo_transform (function): Function that undoes the normalization. + """ + # Let distance from first camera to origin be unit + new_cameras = cameras.clone() + new_transform = ( + new_cameras.get_world_to_view_transform() + ) # potential R is not valid matrix + + p_intersect, dist, _, _, _ = compute_optical_axis_intersection(cameras) + + if p_intersect is None: + print("Warning: optical axes code has a nan. Returning identity cameras.") + new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype) + new_cameras.T[:] = torch.tensor( + [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype + ) + return new_cameras, lambda x: x, 1 / scale + + d = dist.squeeze(dim=1).squeeze(dim=0)[0] + # Degenerate case + if d == 0: + print(cameras.T) + print(new_transform.get_matrix()[:, 3, :3]) + assert False + assert d != 0 + + # Can't figure out how to make scale part of the transform too without messing up R. + # Ideally, we would just wrap it all in a single Pytorch3D transform so that it + # would work with any structure (eg PointClouds, Meshes). 
+ tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse() + tT = Translate(p_intersect) + t = tR.compose(tT) + + new_transform = t.compose(new_transform) + new_cameras.R = new_transform.get_matrix()[:, :3, :3] + new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale + + def undo_transform(cameras): + cameras_copy = cameras.clone() + cameras_copy.T *= d / scale + new_t = ( + t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix() + ) + cameras_copy.R = new_t[:, :3, :3] + cameras_copy.T = new_t[:, 3, :3] + return cameras_copy + + if return_scale: + return new_cameras, undo_transform, scale / d + + return new_cameras, undo_transform + + +def first_camera_transform(cameras, s, rotation_only=True): + new_cameras = cameras.clone() + new_transform = new_cameras.get_world_to_view_transform() + tR = Rotate(new_cameras.R[0].unsqueeze(0)) + if rotation_only: + t = tR.inverse() + else: + tT = Translate(new_cameras.T[0].unsqueeze(0)) + t = tR.compose(tT).inverse() + + new_transform = t.compose(new_transform) + new_cameras.R = new_transform.get_matrix()[:, :3, :3] + new_cameras.T = new_transform.get_matrix()[:, 3, :3] * s + + return new_cameras diff --git a/diffusionsfm/utils/slurm.py b/diffusionsfm/utils/slurm.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bccdae97203324f4d6537ead772077ca704f40 --- /dev/null +++ b/diffusionsfm/utils/slurm.py @@ -0,0 +1,87 @@ +import os +import signal +import subprocess +import sys +import time + + +def submitit_job_watcher(jobs, check_period: int = 15): + job_out = {} + + try: + while True: + job_states = [job.state for job in jobs.values()] + state_counts = { + state: len([j for j in job_states if j == state]) + for state in set(job_states) + } + + n_done = sum(job.done() for job in jobs.values()) + + for job_name, job in jobs.items(): + if job_name not in job_out and job.done(): + job_out[job_name] = { + "stderr": job.stderr(), + "stdout": job.stdout(), + } + + exc = job.exception() + if exc is not None: + print(f"{job_name} crashed!!!") + if job_out[job_name]["stderr"] is not None: + print("===== STDERR =====") + print(job_out[job_name]["stderr"]) + else: + print(f"{job_name} done!") + + print("Job states:") + for state, count in state_counts.items(): + print(f" {state:15s} {count:6d} ({100.*count/len(jobs):.1f}%)") + + if n_done == len(jobs): + print("All done!") + return + + time.sleep(check_period) + + except KeyboardInterrupt: + for job_name, job in jobs.items(): + if not job.done(): + print(f"Killing {job_name}") + job.cancel(check=False) + + +def get_jid(): + if "SLURM_ARRAY_TASK_ID" in os.environ: + return f"{os.environ['SLURM_ARRAY_JOB_ID']}_{os.environ['SLURM_ARRAY_TASK_ID']}" + return os.environ["SLURM_JOB_ID"] + + +def signal_helper(signum, frame): + print(f"Caught signal {signal.Signals(signum).name} on for the this job") + jid = get_jid() + cmd = ["scontrol", "requeue", jid] + try: + print("calling", cmd) + rtn = subprocess.check_call(cmd) + print("subprocc", rtn) + except: + print("subproc call failed") + return sys.exit(10) + + +def bypass(signum, frame): + print(f"Ignoring signal {signal.Signals(signum).name} on for the this job") + + +def init_slurm_signals(): + signal.signal(signal.SIGCONT, bypass) + signal.signal(signal.SIGCHLD, bypass) + signal.signal(signal.SIGTERM, bypass) + signal.signal(signal.SIGUSR2, signal_helper) + print("SLURM signal installed", flush=True) + + +def init_slurm_signals_if_slurm(): + if "SLURM_JOB_ID" in os.environ: + init_slurm_signals() diff --git 
a/diffusionsfm/utils/visualization.py b/diffusionsfm/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..0f235193872c660211673b73ad1cfefdebb5aa73 --- /dev/null +++ b/diffusionsfm/utils/visualization.py @@ -0,0 +1,550 @@ +from http.client import MOVED_PERMANENTLY +import io + +import ipdb # noqa: F401 +import matplotlib.pyplot as plt +import numpy as np +import trimesh +import torch +import torchvision +from pytorch3d.loss import chamfer_distance +from scipy.spatial.transform import Rotation + +from diffusionsfm.inference.ddim import inference_ddim +from diffusionsfm.utils.rays import ( + Rays, + cameras_to_rays, + rays_to_cameras, + rays_to_cameras_homography, +) +from diffusionsfm.utils.geometry import ( + compute_optimal_alignment, +) + +cmap = plt.get_cmap("hsv") + + +def create_training_visualizations( + model, + images, + device, + cameras_gt, + num_images, + crop_parameters, + pred_x0=False, + no_crop_param_device="cpu", + visualize_pred=False, + return_first=False, + calculate_intrinsics=False, + mode=None, + depths=None, + scale_min=-1, + scale_max=1, + diffuse_depths=False, + vis_mode=None, + average_centers=True, + full_num_patches_x=16, + full_num_patches_y=16, + use_homogeneous=False, + distortion_coefficients=None, +): + + if model.depth_resolution == 1: + W_in = W_out = full_num_patches_x + H_in = H_out = full_num_patches_y + else: + W_in = H_in = model.width + W_out = model.width * model.depth_resolution + H_out = model.width * model.depth_resolution + + rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim( + model, + images, + device, + crop_parameters=crop_parameters, + eta=[1, 0], + num_patches_x=W_in, + num_patches_y=H_in, + visualize=True, + ) + + if vis_mode is None: + vis_mode = mode + + T = model.noise_scheduler.max_timesteps + if T == 1000: + ts = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 999] + else: + ts = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99] + + # Get predicted cameras from rays + pred_cameras_batched = [] + vis_images = [] + pred_rays = [] + for index in range(len(images)): + pred_cameras = [] + per_sample_images = [] + for ii in range(num_images): + rays_gt = cameras_to_rays( + cameras_gt[index], + crop_parameters[index], + no_crop_param_device=no_crop_param_device, + num_patches_x=W_in, + num_patches_y=H_in, + depths=None if depths is None else depths[index], + mode=mode, + depth_resolution=model.depth_resolution, + distortion_coefficients=( + None + if distortion_coefficients is None + else distortion_coefficients[index] + ), + ) + image_vis = (images[index, ii].cpu().permute(1, 2, 0).numpy() + 1) / 2 + + if diffuse_depths: + fig, axs = plt.subplots(3, 13, figsize=(15, 4.5), dpi=100) + else: + fig, axs = plt.subplots(3, 9, figsize=(12, 4.5), dpi=100) + + for i, t in enumerate(ts): + r, c = i // 4, i % 4 + if visualize_pred: + curr = pred_intermediate[t][index] + else: + curr = rays_intermediate[t][index] + rays = Rays.from_spatial( + curr, + mode=mode, + num_patches_x=H_in, + num_patches_y=W_in, + use_homogeneous=use_homogeneous, + ) + + if vis_mode == "segment": + vis = ( + torch.clip( + rays.get_segments()[ii], min=scale_min, max=scale_max + ) + - scale_min + ) / (scale_max - scale_min) + + else: + vis = ( + torch.nn.functional.normalize(rays.get_moments()[ii], dim=-1) + + 1 + ) / 2 + + axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) + axs[r, c].set_title(f"T={T - t}") + + i += 1 + r, c = i // 4, i % 4 + + if vis_mode == "segment": + vis = ( + torch.clip(rays_gt.get_segments()[ii], 
min=scale_min, max=scale_max) + - scale_min + ) / (scale_max - scale_min) + else: + vis = ( + torch.nn.functional.normalize(rays_gt.get_moments()[ii], dim=-1) + 1 + ) / 2 + + axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) + + type_str = "Endpoints" if vis_mode == "segment" else "Moments" + axs[r, c].set_title(f"GT {type_str}") + + for i, t in enumerate(ts): + r, c = i // 4, i % 4 + 4 + if visualize_pred: + curr = pred_intermediate[t][index] + else: + curr = rays_intermediate[t][index] + rays = Rays.from_spatial( + curr, + mode, + num_patches_x=H_in, + num_patches_y=W_in, + use_homogeneous=use_homogeneous, + ) + + if vis_mode == "segment": + vis = ( + torch.clip( + rays.get_origins(high_res=True)[ii], + min=scale_min, + max=scale_max, + ) + - scale_min + ) / (scale_max - scale_min) + else: + vis = ( + torch.nn.functional.normalize(rays.get_directions()[ii], dim=-1) + + 1 + ) / 2 + + axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) + axs[r, c].set_title(f"T={T - t}") + + i += 1 + r, c = i // 4, i % 4 + 4 + + if vis_mode == "segment": + vis = ( + torch.clip( + rays_gt.get_origins(high_res=True)[ii], + min=scale_min, + max=scale_max, + ) + - scale_min + ) / (scale_max - scale_min) + else: + vis = ( + torch.nn.functional.normalize(rays_gt.get_directions()[ii], dim=-1) + + 1 + ) / 2 + axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) + type_str = "Origins" if vis_mode == "segment" else "Directions" + axs[r, c].set_title(f"GT {type_str}") + + if diffuse_depths: + for i, t in enumerate(ts): + r, c = i // 4, i % 4 + 8 + if visualize_pred: + curr = pred_intermediate[t][index] + else: + curr = rays_intermediate[t][index] + rays = Rays.from_spatial( + curr, + mode, + num_patches_x=H_in, + num_patches_y=W_in, + use_homogeneous=use_homogeneous, + ) + + vis = rays.depths[ii] + if len(rays.depths[ii].shape) < 2: + vis = rays.depths[ii].reshape(H_out, W_out) + + axs[r, c].imshow(vis.cpu()) + axs[r, c].set_title(f"T={T - t}") + + i += 1 + r, c = i // 4, i % 4 + 8 + + vis = depths[index][ii] + if len(rays.depths[ii].shape) < 2: + vis = depths[index][ii].reshape(256, 256) + + axs[r, c].imshow(vis.cpu()) + axs[r, c].set_title(f"GT Depths") + + axs[2, -1].imshow(image_vis) + axs[2, -1].set_title("Input Image") + for s in ["bottom", "top", "left", "right"]: + axs[2, -1].spines[s].set_color(cmap(ii / (num_images))) + axs[2, -1].spines[s].set_linewidth(5) + + for ax in axs.flatten(): + ax.set_xticks([]) + ax.set_yticks([]) + plt.tight_layout() + img = plot_to_image(fig) + plt.close() + per_sample_images.append(img) + + if return_first: + rays_camera = pred_intermediate[0][index] + elif pred_x0: + rays_camera = pred_intermediate[-1][index] + else: + rays_camera = rays_final[index] + rays = Rays.from_spatial( + rays_camera, + mode=mode, + num_patches_x=H_in, + num_patches_y=W_in, + use_homogeneous=use_homogeneous, + ) + if calculate_intrinsics: + pred_camera = rays_to_cameras_homography( + rays=rays[ii, None], + crop_parameters=crop_parameters[index], + num_patches_x=W_in, + num_patches_y=H_in, + average_centers=average_centers, + depth_resolution=model.depth_resolution, + ) + else: + pred_camera = rays_to_cameras( + rays=rays[ii, None], + crop_parameters=crop_parameters[index], + no_crop_param_device=no_crop_param_device, + num_patches_x=W_in, + num_patches_y=H_in, + depth_resolution=model.depth_resolution, + average_centers=average_centers, + ) + pred_cameras.append(pred_camera[0]) + pred_rays.append(rays) + + pred_cameras_batched.append(pred_cameras) + vis_images.append(np.vstack(per_sample_images)) 
+ + return vis_images, pred_cameras_batched, pred_rays + + +def plot_to_image(figure, dpi=100): + """Converts matplotlib fig to a png for logging with tf.summary.image.""" + buffer = io.BytesIO() + figure.savefig(buffer, format="raw", dpi=dpi) + plt.close(figure) + buffer.seek(0) + image = np.reshape( + np.frombuffer(buffer.getvalue(), dtype=np.uint8), + newshape=(int(figure.bbox.bounds[3]), int(figure.bbox.bounds[2]), -1), + ) + return image[..., :3] + + +def view_color_coded_images_from_tensor(images, depth=False): + num_frames = images.shape[0] + cmap = plt.get_cmap("hsv") + num_rows = 3 + num_cols = 3 + figsize = (num_cols * 2, num_rows * 2) + fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize) + axs = axs.flatten() + for i in range(num_rows * num_cols): + if i < num_frames: + if images[i].shape[0] == 3: + image = images[i].permute(1, 2, 0) + else: + image = images[i].unsqueeze(-1) + + if not depth: + image = image * 0.5 + 0.5 + else: + image = image.repeat((1, 1, 3)) / torch.max(image) + + axs[i].imshow(image) + for s in ["bottom", "top", "left", "right"]: + axs[i].spines[s].set_color(cmap(i / (num_frames))) + axs[i].spines[s].set_linewidth(5) + axs[i].set_xticks([]) + axs[i].set_yticks([]) + else: + axs[i].axis("off") + plt.tight_layout() + return fig + + +def color_and_filter_points(points, images, mask, num_show, resolution): + # Resize images + resize = torchvision.transforms.Resize(resolution) + images = resize(images) * 0.5 + 0.5 + + # Reshape points and calculate mask + points = points.reshape(num_show * resolution * resolution, 3) + mask = mask.reshape(num_show * resolution * resolution) + depth_mask = torch.argwhere(mask > 0.5)[:, 0] + points = points[depth_mask] + + # Mask and reshape colors + colors = images.permute(0, 2, 3, 1).reshape(num_show * resolution * resolution, 3) + colors = colors[depth_mask] + + return points, colors + + +def filter_and_align_point_clouds( + num_frames, + gt_points, + pred_points, + gt_masks, + pred_masks, + images, + metrics=False, + num_patches_x=16, +): + + # Filter and color points + gt_points, gt_colors = color_and_filter_points( + gt_points, images, gt_masks, num_show=num_frames, resolution=num_patches_x + ) + pred_points, pred_colors = color_and_filter_points( + pred_points, images, pred_masks, num_show=num_frames, resolution=num_patches_x + ) + + pred_points, _, _, _ = compute_optimal_alignment( + gt_points.float(), pred_points.float() + ) + + # Scale PCL so that furthest point from centroid is distance 1 + centroid = torch.mean(gt_points, dim=0) + dists = torch.norm(gt_points - centroid.unsqueeze(0), dim=-1) + scale = torch.mean(dists) + gt_points_scaled = (gt_points - centroid) / scale + pred_points_scaled = (pred_points - centroid) / scale + + if metrics: + + cd, _ = chamfer_distance( + pred_points_scaled.unsqueeze(0), gt_points_scaled.unsqueeze(0) + ) + cd = cd.item() + mse = torch.mean( + torch.norm(pred_points_scaled - gt_points_scaled, dim=-1), dim=-1 + ).item() + else: + mse, cd = None, None + + return ( + gt_points, + pred_points, + gt_colors, + pred_colors, + [mse, cd, None], + ) + + +def add_scene_cam(scene, c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03): + OPENGL = np.array([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1] + ]) + + if image is not None: + H, W, THREE = image.shape + assert THREE == 3 + if image.dtype != np.uint8: + image = np.uint8(255*image) + elif imsize is not None: + W, H = imsize + elif focal is not None: + H = W = focal / 1.1 + else: + H = W = 1 + + if 
focal is None: + focal = min(H, W) * 1.1 # default value + elif isinstance(focal, np.ndarray): + focal = focal[0] + + # create fake camera + height = focal * screen_width / H + width = screen_width * 0.5**0.5 + rot45 = np.eye(4) + rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix() + rot45[2, 3] = -height # set the tip of the cone = optical center + aspect_ratio = np.eye(4) + aspect_ratio[0, 0] = W/H + transform = c2w @ OPENGL @ aspect_ratio @ rot45 + cam = trimesh.creation.cone(width, height, sections=4) + + # this is the camera mesh + rot2 = np.eye(4) + rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(4)).as_matrix() + vertices = cam.vertices + vertices_offset = 0.9 * cam.vertices + vertices = np.r_[vertices, vertices_offset, geotrf(rot2, cam.vertices)] + vertices = geotrf(transform, vertices) + faces = [] + for face in cam.faces: + if 0 in face: + continue + a, b, c = face + a2, b2, c2 = face + len(cam.vertices) + + # add 3 pseudo-edges + faces.append((a, b, b2)) + faces.append((a, a2, c)) + faces.append((c2, b, c)) + + faces.append((a, b2, a2)) + faces.append((a2, c, c2)) + faces.append((c2, b2, b)) + + # no culling + faces += [(c, b, a) for a, b, c in faces] + + for i,face in enumerate(cam.faces): + if 0 in face: + continue + + if i == 1 or i == 5: + a, b, c = face + faces.append((a, b, c)) + + cam = trimesh.Trimesh(vertices=vertices, faces=faces) + cam.visual.face_colors[:, :3] = edge_color + + scene.add_geometry(cam) + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. 
+ """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and + Trf.ndim == 3 and pts.ndim == 4): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d+1: + pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] + else: + raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}') + else: + if Trf.ndim >= 3: + n = Trf.ndim-2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1]+1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res \ No newline at end of file diff --git a/docs/eval.md b/docs/eval.md new file mode 100644 index 0000000000000000000000000000000000000000..702899c286fb536755a66a9397f20fd6ab05e17f --- /dev/null +++ b/docs/eval.md @@ -0,0 +1,27 @@ +## Evaluation Directions + +Use the scripts from `diffusionsfm/eval` to evaluate the performance of the dense model on the CO3D dataset: + +``` +python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit +``` + +**Note:** The `use_submitit` flag is optional. If you have a SLURM system available, enabling it will dispatch jobs in parallel across available GPUs, significantly accelerating the evaluation process. + +The expected output at the end of evaluating the dense model is: + +``` +N= 2 3 4 5 6 7 8 +Seen R 0.926 0.941 0.946 0.950 0.953 0.955 0.955 +Seen CC 1.000 0.956 0.934 0.924 0.917 0.911 0.907 +Seen CD 0.023 0.023 0.026 0.026 0.028 0.031 0.030 +Seen CD_Obj 0.040 0.037 0.033 0.032 0.032 0.032 0.033 +Unseen R 0.913 0.928 0.938 0.945 0.950 0.951 0.953 +Unseen CC 1.000 0.926 0.884 0.870 0.864 0.851 0.847 +Unseen CD 0.024 0.024 0.025 0.024 0.025 0.026 0.027 +Unseen CD_Obj 0.028 0.023 0.022 0.022 0.023 0.021 0.020 +``` + +This reports rotation and camera center accuracy, as well as Chamfer Distance on both all points (CD) and foreground points (CD_Obj), evaluated on held-out sequences from both seen and unseen object categories using varying numbers of input images. Performance is averaged over five runs to reduce variance. + +Note that minor variations in the reported numbers may occur due to randomness in the evaluation and inference processes. 
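For reference, the `R` and `CC` rows are threshold accuracies rather than raw errors: the fraction of predictions within a rotation-error threshold and a camera-center-distance threshold, respectively. Below is a minimal sketch of that aggregation, assuming the same 15° and 0.1 thresholds used by the training-time evaluation in `train.py` (the error arrays are placeholders, not real outputs):

```
import numpy as np

# Placeholder errors with shape (num_runs, num_sequences):
# rotation errors in degrees and camera-center errors in normalized scene units.
R_err = np.random.uniform(0, 30, size=(5, 100))
CC_err = np.random.uniform(0, 0.3, size=(5, 100))

R_acc = np.mean(R_err < 15)     # fraction of rotations within 15 degrees
CC_acc = np.mean(CC_err < 0.1)  # fraction of camera centers within 0.1 of ground truth
print(f"R: {R_acc:.3f}, CC: {CC_acc:.3f}")
```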
\ No newline at end of file diff --git a/docs/train.md b/docs/train.md new file mode 100644 index 0000000000000000000000000000000000000000..a90432489c8eb6778f4878fbfbff175f76703c59 --- /dev/null +++ b/docs/train.md @@ -0,0 +1,58 @@ +## Training Directions + +### Prepare CO3D Dataset + +Please refer to the instructions from [RayDiffusion](https://github.com/jasonyzhang/RayDiffusion/blob/main/docs/train.md#training-directions) to set up the CO3D dataset. + +### Setting up `accelerate` + +Use `accelerate config` to set up `accelerate`. We recommend using multiple GPUs without any mixed precision (we handle AMP ourselves). + +### Training models + +Our model is trained in two stages. In the first stage, we train a *sparse model* that predicts ray origins and endpoints at a low resolution (16×16). In the second stage, we initialize the dense model using the DiT weights from the sparse model and append a DPT decoder to produce high-resolution outputs (256×256 ray origins and endpoints). + +To train the sparse model, run: + +``` +accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 train.py \ +    training.batch_size=8 \ +    training.max_iterations=400000 \ +    model.num_images=8 \ +    dataset.name=co3d \ +    debug.project_name=diffusionsfm_co3d \ +    debug.run_name=co3d_diffusionsfm_sparse +``` + +To train the dense model (initialized from the sparse model weights), run: + +``` +accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 train.py \ +    training.batch_size=4 \ +    training.max_iterations=800000 \ +    model.num_images=8 \ +    dataset.name=co3d \ +    debug.project_name=diffusionsfm_co3d \ +    debug.run_name=co3d_diffusionsfm_dense \ +    training.dpt_head=True \ +    training.full_num_patches_x=256 \ +    training.full_num_patches_y=256 \ +    training.gradient_clipping=True \ +    training.reinit=True \ +    training.freeze_encoder=True \ +    model.freeze_transformer=True \ +    training.pretrain_path=.pth +``` + +Some notes: + +- `batch_size` refers to the batch size per GPU. The total batch size will be `batch_size * num_gpu`. +- Depending on your setup, you can adjust the number of GPUs and batch size. You may also need to adjust the number of training iterations accordingly. +- You can resume training from a checkpoint by specifying `training.resume=True hydra.run.dir=/path/to/your/output_dir`. +- If you are getting NaNs, try turning off mixed precision. This will increase the amount of memory used.
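The dense-stage initialization above amounts to a non-strict state-dict load from the sparse checkpoint, followed by freezing the modules shared with the sparse model. The snippet below is only an illustrative sketch under those assumptions (`init_dense_from_sparse` and the parameter prefixes are shorthand for exposition; the actual logic lives in `Trainer.load_model` together with the `training.freeze_encoder` and `model.freeze_transformer` flags):

```
import torch

def init_dense_from_sparse(dense_model, sparse_ckpt_path):
    ckpt = torch.load(sparse_ckpt_path, map_location="cpu")
    # The DPT decoder is new, so missing keys are expected for a non-strict load.
    missing, unexpected = dense_model.load_state_dict(ckpt["state_dict"], strict=False)
    print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")
    # Freeze the DINO encoder and the DiT so that only the DPT decoder is trained at first.
    for name, param in dense_model.named_parameters():
        if name.startswith(("feature_extractor.", "ray_predictor.")):
            param.requires_grad_(False)
    return dense_model
```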
+ +For debugging, we recommend using a single-GPU job with a single category: + +``` +accelerate launch train.py training.batch_size=4 dataset.category=apple debug.wandb=False hydra.run.dir=output_debug +``` \ No newline at end of file diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..23bbee161468f3f597c831d77d5ee3f040bf4e31 --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,244 @@ +import os +import time +import shutil +import argparse +import functools +import torch +import torchvision +from PIL import Image +import gradio as gr +import numpy as np +import matplotlib.pyplot as plt +import trimesh + +from diffusionsfm.dataset.custom import CustomDataset +from diffusionsfm.dataset.co3d_v2 import unnormalize_image +from diffusionsfm.inference.load_model import load_model +from diffusionsfm.inference.predict import predict_cameras +from diffusionsfm.utils.visualization import add_scene_cam + + +def info_fn(): + gr.Info("Data preprocessing completed!") + + +def get_select_index(evt: gr.SelectData): + selected = evt.index + return examples_full[selected][0], selected + + +def check_img_input(control_image): + if control_image is None: + raise gr.Error("Please select or upload an input image.") + + +def preprocess(args, image_block, selected): + cate_name = time.strftime("%m%d_%H%M%S") if selected is None else examples_list[selected] + + demo_dir = os.path.join(args.output_dir, f'demo/{cate_name}') + shutil.rmtree(demo_dir, ignore_errors=True) + + os.makedirs(os.path.join(demo_dir, 'source'), exist_ok=True) + os.makedirs(os.path.join(demo_dir, 'processed'), exist_ok=True) + + dataset = CustomDataset(image_block) + batch = dataset.get_data() + batch['cate_name'] = cate_name + + processed_image_block = [] + for i, file_path in enumerate(image_block): + file_name = os.path.basename(file_path) + raw_img = Image.open(file_path) + try: + raw_img.save(os.path.join(demo_dir, 'source', file_name)) + except OSError: + raw_img.convert('RGB').save(os.path.join(demo_dir, 'source', file_name)) + + batch['image_for_vis'][i].save(os.path.join(demo_dir, 'processed', file_name)) + processed_image_block.append(os.path.join(demo_dir, 'processed', file_name)) + + return processed_image_block, batch + + +def transform_cameras(pred_cameras): + num_cameras = pred_cameras.R.shape[0] + Rs = pred_cameras.R.transpose(1, 2).detach() + ts = pred_cameras.T.unsqueeze(-1).detach() + c2ws = torch.zeros(num_cameras, 4, 4) + c2ws[:, :3, :3] = Rs + c2ws[:, :3, -1:] = ts + c2ws[:, 3, 3] = 1 + c2ws[:, :2] *= -1 # PyTorch3D to OpenCV + c2ws = torch.linalg.inv(c2ws).numpy() + + return c2ws + + +def run_inference(args, cfg, model, batch): + device = args.device + images = batch["image"].to(device) + crop_parameters = batch["crop_parameters"].to(device) + + _, additional_cams = predict_cameras( + model=model, + images=images, + device=device, + crop_parameters=crop_parameters, + num_patches_x=cfg.training.full_num_patches_x, + num_patches_y=cfg.training.full_num_patches_y, + additional_timesteps=list(range(11)), + calculate_intrinsics=True, + max_num_images=8, + mode="segment", + return_rays=True, + use_homogeneous=True, + seed=0, + ) + pred_cameras, pred_rays = additional_cams[10] + + # Unnormalize and resize input images + images = unnormalize_image(images, return_numpy=False, return_int=False) + images = torchvision.transforms.Resize(256)(images) + rgbs = images.permute(0, 2, 3, 1).contiguous().view(-1, 3) + xyzs = pred_rays.get_segments().view(-1, 3).cpu() + + # Create point 
cloud and scene + scene = trimesh.Scene() + point_cloud = trimesh.points.PointCloud(xyzs, colors=rgbs) + scene.add_geometry(point_cloud) + + # Add predicted cameras to the scene + num_images = images.shape[0] + c2ws = transform_cameras(pred_cameras) + cmap = plt.get_cmap("hsv") + + for i, c2w in enumerate(c2ws): + color_rgb = (np.array(cmap(i / num_images))[:3] * 255).astype(int) + add_scene_cam( + scene=scene, + c2w=c2w, + edge_color=color_rgb, + image=None, + focal=None, + imsize=(256, 256), + screen_width=0.1 + ) + + # Export GLB + cate_name = batch['cate_name'] + output_path = os.path.join(args.output_dir, f'demo/{cate_name}/{cate_name}.glb') + scene.export(output_path) + + return output_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', default='output/multi_diffusionsfm_dense', type=str, help='Output directory') + parser.add_argument('--device', default='cuda', type=str, help='Device to run inference on') + args = parser.parse_args() + + _TITLE = "DiffusionSfM: Predicting Structure and Motion via Ray Origin and Endpoint Diffusion" + _DESCRIPTION = """ +
+ + +
+ DiffusionSfM learns to predict scene geometry and camera poses as pixel-wise ray origins and endpoints using a denoising diffusion model. + """ + + # Load demo examples + examples_list = ["kew_gardens_ruined_arch", "jellycat", "kotor_cathedral", "jordan"] + examples_full = [] + for example in examples_list: + folder = os.path.join(os.path.dirname(__file__), "data/demo", example) + examples = sorted(os.path.join(folder, x) for x in os.listdir(folder)) + examples_full.append([examples]) + + model, cfg = load_model(args.output_dir, device=args.device) + print("Loaded DiffusionSfM model!") + + preprocess = functools.partial(preprocess, args) + run_inference = functools.partial(run_inference, args, cfg, model) + + with gr.Blocks(title=_TITLE, theme=gr.themes.Soft()) as demo: + gr.Markdown(f"# {_TITLE}") + gr.Markdown(_DESCRIPTION) + + with gr.Row(variant='panel'): + with gr.Column(scale=2): + image_block = gr.Files(file_count="multiple", label="Upload Images") + + gr.Markdown( + "You can run our model by either: (1) **Uploading images** above " + "or (2) selecting a **pre-collected example** below." + ) + + gallery = gr.Gallery( + value=[example[0][0] for example in examples_full], + label="Examples", + show_label=True, + columns=[4], + rows=[1], + object_fit="contain", + height="256", + ) + + selected = gr.State() + batch = gr.State() + + preprocessed_data = gr.Gallery( + label="Preprocessed Images", + show_label=True, + columns=[4], + rows=[1], + object_fit="contain", + height="256", + ) + + with gr.Row(variant='panel'): + run_inference_btn = gr.Button("Run Inference") + + with gr.Column(scale=4): + output_3D = gr.Model3D( + clear_color=[0.0, 0.0, 0.0, 0.0], + height=520, + zoom_speed=0.5, + pan_speed=0.5, + label="3D Point Cloud and Cameras" + ) + + # Link image gallery selection + gallery.select( + fn=get_select_index, + inputs=None, + outputs=[image_block, selected] + ).success( + fn=preprocess, + inputs=[image_block, selected], + outputs=[preprocessed_data, batch], + queue=False, + show_progress="full" + ) + + # Handle user uploads + image_block.upload( + preprocess, + inputs=[image_block], + outputs=[preprocessed_data, batch], + queue=False, + show_progress="full" + ).success(info_fn, None, None) + + # Run 3D reconstruction + run_inference_btn.click( + check_img_input, + inputs=[image_block], + queue=False + ).success( + run_inference, + inputs=[batch], + outputs=[output_3D] + ) + + demo.queue().launch(share=True) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..28d668576ec0f86206d2b2e5e57bcc07ed78765d --- /dev/null +++ b/train.py @@ -0,0 +1,916 @@ +""" +Configurations can be overwritten by adding: key=value +Use debug.wandb=False to disable logging to wandb. 
+""" + +import datetime +from datetime import timedelta +import os +import random +import socket +import time +from glob import glob + +import hydra +import ipdb # noqa: F401 +import numpy as np +import omegaconf +import torch +import wandb +from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs +from pytorch3d.renderer import PerspectiveCameras + +from diffusionsfm.dataset.co3d_v2 import Co3dDataset, unnormalize_image_for_vis +# from diffusionsfm.dataset.multiloader import get_multiloader, MultiDataset +from diffusionsfm.eval.eval_category import evaluate +from diffusionsfm.model.diffuser import RayDiffuser +from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT +from diffusionsfm.model.scheduler import NoiseScheduler +from diffusionsfm.utils.rays import cameras_to_rays, normalize_cameras_batch, compute_ndc_coordinates +from diffusionsfm.utils.visualization import ( + create_training_visualizations, + view_color_coded_images_from_tensor, +) + +os.umask(000) # Default to 777 permissions + + +class Trainer(object): + def __init__(self, cfg): + seed = cfg.training.seed + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + self.cfg = cfg + self.debug = cfg.debug + self.resume = cfg.training.resume + self.pretrain_path = cfg.training.pretrain_path + + self.batch_size = cfg.training.batch_size + self.max_iterations = cfg.training.max_iterations + self.mixed_precision = cfg.training.mixed_precision + self.interval_visualize = cfg.training.interval_visualize + self.interval_save_checkpoint = cfg.training.interval_save_checkpoint + self.interval_delete_checkpoint = cfg.training.interval_delete_checkpoint + self.interval_evaluate = cfg.training.interval_evaluate + self.delete_all = cfg.training.delete_all_checkpoints_after_training + self.freeze_encoder = cfg.training.freeze_encoder + self.translation_scale = cfg.training.translation_scale + self.regression = cfg.training.regression + self.prob_unconditional = cfg.training.prob_unconditional + self.load_extra_cameras = cfg.training.load_extra_cameras + self.calculate_intrinsics = cfg.training.calculate_intrinsics + self.distort = cfg.training.distort + self.diffuse_origins_and_endpoints = cfg.training.diffuse_origins_and_endpoints + self.diffuse_depths = cfg.training.diffuse_depths + self.depth_resolution = cfg.training.depth_resolution + self.dpt_head = cfg.training.dpt_head + self.full_num_patches_x = cfg.training.full_num_patches_x + self.full_num_patches_y = cfg.training.full_num_patches_y + self.dpt_encoder_features = cfg.training.dpt_encoder_features + self.nearest_neighbor = cfg.training.nearest_neighbor + self.no_bg_targets = cfg.training.no_bg_targets + self.unit_normalize_scene = cfg.training.unit_normalize_scene + self.sd_scale = cfg.training.sd_scale + self.bfloat = cfg.training.bfloat + self.first_cam_mediod = cfg.training.first_cam_mediod + self.normalize_first_camera = cfg.training.normalize_first_camera + self.gradient_clipping = cfg.training.gradient_clipping + self.l1_loss = cfg.training.l1_loss + self.reinit = cfg.training.reinit + + if self.first_cam_mediod: + assert self.normalize_first_camera + + self.pred_x0 = cfg.model.pred_x0 + self.num_patches_x = cfg.model.num_patches_x + self.num_patches_y = cfg.model.num_patches_y + self.depth = cfg.model.depth + self.num_images = cfg.model.num_images + self.num_visualize = min(self.batch_size, 2) + self.random_num_images = cfg.model.random_num_images + self.feature_extractor = cfg.model.feature_extractor + self.append_ndc = 
cfg.model.append_ndc + self.use_homogeneous = cfg.model.use_homogeneous + self.freeze_transformer = cfg.model.freeze_transformer + self.cond_depth_mask = cfg.model.cond_depth_mask + + self.dataset_name = cfg.dataset.name + self.shape = cfg.dataset.shape + self.apply_augmentation = cfg.dataset.apply_augmentation + self.mask_holes = cfg.dataset.mask_holes + self.image_size = cfg.dataset.image_size + + if not self.regression and (self.diffuse_origins_and_endpoints or self.diffuse_depths): + assert self.mask_holes or self.cond_depth_mask + + if self.regression: + assert self.pred_x0 + + self.start_time = None + self.iteration = 0 + self.epoch = 0 + self.wandb_id = None + + self.hostname = socket.gethostname() + + if self.dpt_head: + find_unused_parameters = True + else: + find_unused_parameters = False + + ddp_scaler = DistributedDataParallelKwargs( + find_unused_parameters=find_unused_parameters + ) + init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400)) + self.accelerator = Accelerator( + even_batches=False, + device_placement=False, + kwargs_handlers=[ddp_scaler, init_kwargs], + ) + self.device = self.accelerator.device + + scheduler = NoiseScheduler( + type=cfg.noise_scheduler.type, + max_timesteps=cfg.noise_scheduler.max_timesteps, + beta_start=cfg.noise_scheduler.beta_start, + beta_end=cfg.noise_scheduler.beta_end, + ) + if self.dpt_head: + self.model = RayDiffuserDPT( + depth=self.depth, + width=self.num_patches_x, + P=1, + max_num_images=self.num_images, + noise_scheduler=scheduler, + freeze_encoder=self.freeze_encoder, + feature_extractor=self.feature_extractor, + append_ndc=self.append_ndc, + use_unconditional=self.prob_unconditional > 0, + diffuse_depths=self.diffuse_depths, + depth_resolution=self.depth_resolution, + encoder_features=self.dpt_encoder_features, + use_homogeneous=self.use_homogeneous, + freeze_transformer=self.freeze_transformer, + cond_depth_mask=self.cond_depth_mask, + ).to(self.device) + else: + self.model = RayDiffuser( + depth=self.depth, + width=self.num_patches_x, + P=1, + max_num_images=self.num_images, + noise_scheduler=scheduler, + freeze_encoder=self.freeze_encoder, + feature_extractor=self.feature_extractor, + append_ndc=self.append_ndc, + use_unconditional=self.prob_unconditional > 0, + diffuse_depths=self.diffuse_depths, + depth_resolution=self.depth_resolution, + use_homogeneous=self.use_homogeneous, + cond_depth_mask=self.cond_depth_mask, + ).to(self.device) + + if self.dpt_head: + depth_size = self.full_num_patches_x + elif self.depth_resolution > 1: + depth_size = self.num_patches_x * self.depth_resolution + else: + depth_size = self.num_patches_x + self.depth_size = depth_size + + if self.dataset_name == "multi": + self.dataset, self.train_dataloader, self.test_dataset = get_multiloader( + num_images=self.num_images, + apply_augmentation=self.apply_augmentation, + load_extra_cameras=self.load_extra_cameras, + distort_image=self.distort, + center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths, + crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths), + load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths, + depth_size=depth_size, + mask_holes=self.mask_holes, + img_size=self.image_size, + batch_size=self.batch_size, + num_workers=cfg.training.num_workers, + dust3r_pairs=True, + ) + elif self.dataset_name == "co3d": + self.dataset = Co3dDataset( + category=self.shape, + split="train", + num_images=self.num_images, + apply_augmentation=self.apply_augmentation, + 
load_extra_cameras=self.load_extra_cameras, + distort_image=self.distort, + center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths, + crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths), + load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths, + depth_size=depth_size, + mask_holes=self.mask_holes, + img_size=self.image_size, + ) + self.train_dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=cfg.training.num_workers, + pin_memory=True, + drop_last=True, + ) + self.test_dataset = Co3dDataset( + category=self.shape, + split="test", + num_images=self.num_images, + apply_augmentation=False, + load_extra_cameras=self.load_extra_cameras, + distort_image=self.distort, + center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths, + crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths), + load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths, + depth_size=depth_size, + mask_holes=self.mask_holes, + img_size=self.image_size, + ) + else: + raise NotImplementedError(f"Dataset '{self.dataset_name}' is not supported.") + self.lr = 1e-4 + + self.output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + self.checkpoint_dir = os.path.join(self.output_dir, "checkpoints") + if self.accelerator.is_main_process: + name = os.path.basename(self.output_dir) + name += f"_{self.debug.run_name}" + + print("Output dir:", self.output_dir) + with open(os.path.join(self.output_dir, name), "w"): + # Create empty tag with name + pass + self.name = name + + conf_dict = omegaconf.OmegaConf.to_container( + cfg, resolve=True, throw_on_missing=True + ) + conf_dict["output_dir"] = self.output_dir + conf_dict["hostname"] = self.hostname + + if self.dpt_head: + self.init_optimizer_with_separate_lrs() + else: + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) + self.gradscaler = torch.cuda.amp.GradScaler(growth_interval=100000, enabled=self.mixed_precision) + + self.model, self.optimizer, self.train_dataloader = self.accelerator.prepare( + self.model, self.optimizer, self.train_dataloader + ) + + if self.resume: + checkpoint_files = sorted(glob(os.path.join(self.checkpoint_dir, "*.pth"))) + last_checkpoint = checkpoint_files[-1] + print("Resuming from checkpoint:", last_checkpoint) + self.load_model(last_checkpoint, load_metadata=True) + elif self.pretrain_path != "": + print("Loading pretrained model:", self.pretrain_path) + self.load_model(self.pretrain_path, load_metadata=False) + + if self.accelerator.is_main_process: + mode = "online" if cfg.debug.wandb else "disabled" + if self.wandb_id is None: + self.wandb_id = wandb.util.generate_id() + self.wandb_run = wandb.init( + mode=mode, + name=name, + project=cfg.debug.project_name, + config=conf_dict, + resume=self.resume, + id=self.wandb_id, + ) + wandb.define_metric("iteration") + noise_schedule = self.get_module().noise_scheduler.plot_schedule( + return_image=True + ) + wandb.log( + {"Schedule": wandb.Image(noise_schedule, caption="Noise Schedule")} + ) + + def get_module(self): + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model = self.model.module + else: + model = self.model + + return model + + def init_optimizer_with_separate_lrs(self): + print("Use different LRs for the DINOv2 encoder and DiT!") + + feature_extractor_params = [ + p for n, p in self.model.feature_extractor.named_parameters() + ] + feature_extractor_param_names = [ + 
"feature_extractor." + n for n, _ in self.model.feature_extractor.named_parameters() + ] + ray_predictor_params = [ + p for n, p in self.model.ray_predictor.named_parameters() + ] + ray_predictor_param_names = [ + "ray_predictor." + n for n, p in self.model.ray_predictor.named_parameters() + ] + other_params = [ + p for n, p in self.model.named_parameters() + if n not in feature_extractor_param_names + ray_predictor_param_names + ] + + self.optimizer = torch.optim.Adam([ + {'params': feature_extractor_params, 'lr': self.lr * 0.1}, # Lower LR for feature extractor + {'params': ray_predictor_params, 'lr': self.lr * 0.1}, # Lower LR for DIT (ray_predictor) + {'params': other_params, 'lr': self.lr} # Normal LR for other parts of the model + ]) + + def train(self): + while self.iteration < self.max_iterations: + for batch in self.train_dataloader: + t0 = time.time() + self.optimizer.zero_grad() + + float_type = torch.bfloat16 if self.bfloat else torch.float16 + with torch.cuda.amp.autocast( + enabled=self.mixed_precision, dtype=float_type + ): + images = batch["image"].to(self.device) + focal_lengths = batch["focal_length"].to(self.device) + crop_params = batch["crop_parameters"].to(self.device) + principal_points = batch["principal_point"].to(self.device) + R = batch["R"].to(self.device) + T = batch["T"].to(self.device) + if "distortion_coefficients" in batch: + distortion_coefficients = batch["distortion_coefficients"] + else: + distortion_coefficients = [None for _ in range(R.shape[0])] + + depths = batch["depth"].to(self.device) + if self.no_bg_targets: + masks = batch["depth_masks"].to(self.device).bool() + + cameras_og = [ + PerspectiveCameras( + focal_length=focal_lengths[b], + principal_point=principal_points[b], + R=R[b], + T=T[b], + device=self.device, + ) + for b in range(self.batch_size) + ] + + cameras, _ = normalize_cameras_batch( + cameras=cameras_og, + scale=self.translation_scale, + normalize_first_camera=self.normalize_first_camera, + depths=( + None + if not (self.diffuse_origins_and_endpoints or self.diffuse_depths) + else depths + ), + first_cam_mediod=self.first_cam_mediod, + crop_parameters=crop_params, + num_patches_x=self.depth_size, + num_patches_y=self.depth_size, + distortion_coeffs=distortion_coefficients, + ) + + # Now that cameras are normalized, fix shapes of camera parameters + if self.load_extra_cameras or self.random_num_images: + if self.random_num_images: + num_images = torch.randint(2, self.num_images + 1, (1,)) + else: + num_images = self.num_images + + # The correct number of images is already loaded. + # Only need to modify these camera parameters shapes. 
+ focal_lengths = focal_lengths[:, :num_images] + crop_params = crop_params[:, :num_images] + R = R[:, :num_images] + T = T[:, :num_images] + images = images[:, :num_images] + depths = depths[:, :num_images] + masks = masks[:, :num_images] + + cameras = [ + PerspectiveCameras( + focal_length=cameras[b].focal_length[:num_images], + principal_point=cameras[b].principal_point[:num_images], + R=cameras[b].R[:num_images], + T=cameras[b].T[:num_images], + device=self.device, + ) + for b in range(self.batch_size) + ] + + if self.regression: + low = self.get_module().noise_scheduler.max_timesteps - 1 + else: + low = 0 + + t = torch.randint( + low=low, + high=self.get_module().noise_scheduler.max_timesteps, + size=(self.batch_size,), + device=self.device, + ) + + if self.prob_unconditional > 0: + unconditional_mask = ( + (torch.rand(self.batch_size) < self.prob_unconditional) + .float() + .to(self.device) + ) + else: + unconditional_mask = None + + if self.distort: + raise NotImplementedError() + else: + gt_rays = [] + rays_dirs = [] + rays = [] + for i, (camera, crop_param, depth) in enumerate( + zip(cameras, crop_params, depths) + ): + if self.diffuse_origins_and_endpoints: + mode = "segment" + else: + mode = "plucker" + + r = cameras_to_rays( + cameras=camera, + num_patches_x=self.full_num_patches_x, + num_patches_y=self.full_num_patches_y, + crop_parameters=crop_param, + depths=depth, + mode=mode, + depth_resolution=self.depth_resolution, + nearest_neighbor=self.nearest_neighbor, + distortion_coefficients=distortion_coefficients[i], + ) + rays_dirs.append(r.get_directions()) + gt_rays.append(r) + + if self.diffuse_origins_and_endpoints: + assert r.mode == "segment" + elif self.diffuse_depths: + assert r.mode == "plucker" + + if self.unit_normalize_scene: + if self.diffuse_origins_and_endpoints: + assert r.mode == "segment" + # Let's say SD should be 0.5 + scale = r.get_segments().std() * self.sd_scale + + if scale.isnan().any(): + assert False + + camera.T /= scale + r.rays /= scale + depths[i] /= scale + else: + assert r.mode == "plucker" + scale = r.depths.std() * self.sd_scale + + if scale.isnan().any(): + assert False + + camera.T /= scale + r.depths /= scale + depths[i] /= scale + + rays.append( + r.to_spatial( + include_ndc_coordinates=self.append_ndc, + include_depths=self.diffuse_depths, + use_homogeneous=self.use_homogeneous, + ) + ) + + rays_tensor = torch.stack(rays, dim=0) + + if self.append_ndc: + ndc_coordinates = rays_tensor[..., -2:, :, :] + rays_tensor = rays_tensor[..., :-2, :, :] + + if self.dpt_head: + xy_grid = compute_ndc_coordinates( + crop_params, + num_patches_x=self.depth_size // 16, + num_patches_y=self.depth_size // 16, + distortion_coeffs=distortion_coefficients, + )[..., :2] + ndc_coordinates = xy_grid.permute(0, 1, 4, 2, 3).contiguous() + + else: + ndc_coordinates = None + + if self.cond_depth_mask: + condition_mask = masks + else: + condition_mask = None + + if rays_tensor.isnan().any(): + import pickle + + with open("bad.json", "wb") as f: + pickle.dump(batch, f) + ipdb.set_trace() + + eps_pred, eps = self.model( + images=images, + rays=rays_tensor, + t=t, + ndc_coordinates=ndc_coordinates, + unconditional_mask=unconditional_mask, + depth_mask=condition_mask, + ) + if self.pred_x0: + target = rays_tensor + else: + target = eps + + if self.no_bg_targets: + C = eps_pred.shape[2] + loss_masks = masks.unsqueeze(2).repeat(1, 1, C, 1, 1) + eps_pred = loss_masks * eps_pred + target = loss_masks * target + + loss = 0 + + if self.l1_loss: + loss_reconstruction = 
torch.mean(torch.abs(eps_pred - target)) + else: + loss_reconstruction = torch.mean((eps_pred - target) ** 2) + + loss += loss_reconstruction + + if self.mixed_precision: + self.gradscaler.scale(loss).backward() + + scaled_norm = 0 + for p in self.model.parameters(): + if p.requires_grad and p.grad is not None: + param_norm = p.grad.data.norm(2) + scaled_norm += param_norm.item() ** 2 + scaled_norm = scaled_norm ** 0.5 + + if self.gradient_clipping and self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.get_module().parameters(), 1 + ) + + clipped_norm = 0 + for p in self.model.parameters(): + if p.requires_grad and p.grad is not None: + param_norm = p.grad.data.norm(2) + clipped_norm += param_norm.item() ** 2 + clipped_norm = clipped_norm ** 0.5 + + self.gradscaler.unscale_(self.optimizer) + unscaled_norm = 0 + for p in self.model.parameters(): + if p.requires_grad and p.grad is not None: + param_norm = p.grad.data.norm(2) + unscaled_norm += param_norm.item() ** 2 + unscaled_norm = unscaled_norm ** 0.5 + + self.gradscaler.step(self.optimizer) + self.gradscaler.update() + else: + self.accelerator.backward(loss) + + if self.gradient_clipping and self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.get_module().parameters(), 10 + ) + self.optimizer.step() + + if self.accelerator.is_main_process: + if self.iteration % 10 == 0: + self.log_info( + loss_reconstruction, + t0, + self.lr, + scaled_norm, + unscaled_norm, + clipped_norm, + ) + + if self.iteration % self.interval_visualize == 0: + self.visualize( + images=unnormalize_image_for_vis(images.clone()), + cameras_gt=cameras, + depths=depths, + crop_parameters=crop_params, + distortion_coefficients=distortion_coefficients, + depth_mask=masks, + ) + + if self.iteration % self.interval_save_checkpoint == 0 and self.iteration != 0: + self.save_model() + + if self.iteration % self.interval_delete_checkpoint == 0: + self.clear_old_checkpoints(self.checkpoint_dir) + + if ( + self.iteration % self.interval_evaluate == 0 + and self.iteration > 0 + ): + self.evaluate_train_acc() + + if self.iteration >= self.max_iterations + 1: + if self.delete_all: + self.clear_old_checkpoints( + self.checkpoint_dir, clear_all_old=True + ) + return + + self.iteration += 1 + + if self.reinit and self.iteration >= 50000: + state_dict = self.get_module().state_dict() + self.model = RayDiffuserDPT( + depth=self.depth, + width=self.num_patches_x, + P=1, + max_num_images=self.num_images, + noise_scheduler=self.get_module().noise_scheduler, + freeze_encoder=False, + feature_extractor=self.feature_extractor, + append_ndc=self.append_ndc, + use_unconditional=self.prob_unconditional > 0, + diffuse_depths=self.diffuse_depths, + depth_resolution=self.depth_resolution, + encoder_features=self.dpt_encoder_features, + use_homogeneous=self.use_homogeneous, + freeze_transformer=False, + cond_depth_mask=self.cond_depth_mask, + ).to(self.device) + + self.init_optimizer_with_separate_lrs() + self.gradscaler = torch.cuda.amp.GradScaler(growth_interval=100000, enabled=self.mixed_precision) + + self.model, self.optimizer = self.accelerator.prepare( + self.model, self.optimizer + ) + + msg = self.get_module().load_state_dict( + state_dict, + strict=True, + ) + print(msg) + + self.reinit = False + + self.epoch += 1 + + def load_model(self, path, load_metadata=True): + save_dict = torch.load(path, map_location=self.device) + del save_dict["state_dict"]["ray_predictor.x_pos_enc.image_pos_table"] + + if not self.resume: + if 
len(save_dict["state_dict"]["scratch.input_conv.weight"].shape) == 2 and self.dpt_head: + print("Initialize conv layer weights from the linear layer!") + C = save_dict["state_dict"]["scratch.input_conv.weight"].shape[1] + input_conv_weight = save_dict["state_dict"]["scratch.input_conv.weight"].view(384, C, 1, 1).repeat(1, 1, 16, 16) / 256. + input_conv_bias = save_dict["state_dict"]["scratch.input_conv.bias"] + + self.get_module().scratch.input_conv.weight.data = input_conv_weight + self.get_module().scratch.input_conv.bias.data = input_conv_bias + + del save_dict["state_dict"]["scratch.input_conv.weight"] + del save_dict["state_dict"]["scratch.input_conv.bias"] + + missing, unexpected = self.get_module().load_state_dict( + save_dict["state_dict"], + strict=False, + ) + print(f"Missing keys: {missing}") + print(f"Unexpected keys: {unexpected}") + if load_metadata: + self.iteration = save_dict["iteration"] + self.epoch = save_dict["epoch"] + time_elapsed = save_dict["elapsed"] + self.start_time = time.time() - time_elapsed + if "wandb_id" in save_dict: + self.wandb_id = save_dict["wandb_id"] + self.optimizer.load_state_dict(save_dict["optimizer"]) + self.gradscaler.load_state_dict(save_dict["gradscaler"]) + + def save_model(self): + path = os.path.join(self.checkpoint_dir, f"ckpt_{self.iteration:08d}.pth") + os.makedirs(os.path.dirname(path), exist_ok=True) + elapsed = time.time() - self.start_time if self.start_time is not None else 0 + save_dict = { + "epoch": self.epoch, + "elapsed": elapsed, + "gradscaler": self.gradscaler.state_dict(), + "iteration": self.iteration, + "state_dict": self.get_module().state_dict(), + "optimizer": self.optimizer.state_dict(), + "wandb_id": self.wandb_id, + } + torch.save(save_dict, path) + + def clear_old_checkpoints(self, checkpoint_dir, clear_all_old=False): + print("Clearing old checkpoints") + checkpoint_files = sorted(glob(os.path.join(checkpoint_dir, "ckpt_*.pth"))) + if clear_all_old: + for checkpoint_file in checkpoint_files[:-1]: + os.remove(checkpoint_file) + else: + for checkpoint_file in checkpoint_files: + checkpoint = os.path.basename(checkpoint_file) + checkpoint_iteration = int("".join(filter(str.isdigit, checkpoint))) + if checkpoint_iteration % self.interval_delete_checkpoint != 0: + os.remove(checkpoint_file) + + def log_info( + self, + loss, + t0, + lr, + scaled_norm, + unscaled_norm, + clipped_norm, + ): + if self.start_time is None: + self.start_time = time.time() + time_elapsed = round(time.time() - self.start_time) + time_remaining = round( + (time.time() - self.start_time) + / (self.iteration + 1) + * (self.max_iterations - self.iteration) + ) + disp = [ + f"Iter: {self.iteration}/{self.max_iterations}", + f"Epoch: {self.epoch}", + f"Loss: {loss.item():.4f}", + f"LR: {lr:.7f}", + f"Grad Norm: {scaled_norm:.4f}/{unscaled_norm:.4f}/{clipped_norm:.4f}", + f"Elap: {str(datetime.timedelta(seconds=time_elapsed))}", + f"Rem: {str(datetime.timedelta(seconds=time_remaining))}", + self.hostname, + self.name, + ] + print(", ".join(disp), flush=True) + wandb_log = { + "loss": loss.item(), + "iter_time": time.time() - t0, + "lr": lr, + "iteration": self.iteration, + "hours_remaining": time_remaining / 3600, + "gradient norm": scaled_norm, + "unscaled norm": unscaled_norm, + "clipped norm": clipped_norm, + } + wandb.log(wandb_log) + + def visualize( + self, + images, + cameras_gt, + crop_parameters=None, + depths=None, + distortion_coefficients=None, + depth_mask=None, + high_loss=False, + ): + self.get_module().eval() + + for camera in 
cameras_gt: + # AMP may not cast back to float + camera.R = camera.R.float() + camera.T = camera.T.float() + + loss_tag = "" if not high_loss else " HIGH LOSS" + + for i in range(self.num_visualize): + imgs = view_color_coded_images_from_tensor(images[i].cpu(), depth=False) + im = wandb.Image(imgs, caption=f"iteration {self.iteration} example {i}") + wandb.log({f"Vis images {i}{loss_tag}": im}) + + if self.cond_depth_mask: + imgs = view_color_coded_images_from_tensor( + depth_mask[i].cpu(), depth=True + ) + im = wandb.Image( + imgs, caption=f"iteration {self.iteration} example {i}" + ) + wandb.log({f"Vis masks {i}{loss_tag}": im}) + + vis_depths, _, _ = create_training_visualizations( + model=self.get_module(), + images=images[: self.num_visualize], + device=self.device, + cameras_gt=cameras_gt, + pred_x0=self.pred_x0, + num_images=images.shape[1], + crop_parameters=crop_parameters[: self.num_visualize], + visualize_pred=self.regression, + return_first=self.regression, + calculate_intrinsics=self.calculate_intrinsics, + mode="segment" if self.diffuse_origins_and_endpoints else "plucker", + depths=depths[: self.num_visualize], + diffuse_depths=self.diffuse_depths, + full_num_patches_x=self.full_num_patches_x, + full_num_patches_y=self.full_num_patches_y, + use_homogeneous=self.use_homogeneous, + distortion_coefficients=distortion_coefficients, + ) + + for i, vis_image in enumerate(vis_depths): + im = wandb.Image( + vis_image, caption=f"iteration {self.iteration} example {i}" + ) + + for i, vis_image in enumerate(vis_depths): + im = wandb.Image( + vis_image, caption=f"iteration {self.iteration} example {i}" + ) + wandb.log({f"Vis origins and endpoints {i}{loss_tag}": im}) + + self.get_module().train() + + def evaluate_train_acc(self, num_evaluate=10): + print("Evaluating train accuracy") + model = self.get_module() + model.eval() + additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90] + num_images = self.num_images + + for split in ["train", "test"]: + if split == "train": + if self.dataset_name != "co3d": + to_evaluate = self.dataset.datasets + names = self.dataset.names + else: + to_evaluate = [self.dataset] + names = ["co3d"] + elif split == "test": + if self.dataset_name != "co3d": + to_evaluate = self.test_dataset.datasets + names = self.test_dataset.names + else: + to_evaluate = [self.test_dataset] + names = ["co3d"] + + for name, dataset in zip(names, to_evaluate): + results = evaluate( + cfg=self.cfg, + model=model, + dataset=dataset, + num_images=num_images, + device=self.device, + additional_timesteps=additional_timesteps, + num_evaluate=num_evaluate, + use_pbar=True, + mode="segment" if self.diffuse_origins_and_endpoints else "plucker", + metrics=False, + ) + + R_err = [] + CC_err = [] + for key in results.keys(): + R_err.append([v["R_error"] for v in results[key]]) + CC_err.append([v["CC_error"] for v in results[key]]) + + R_err = np.array(R_err) + CC_err = np.array(CC_err) + + R_acc_15 = np.mean(R_err < 15, (0, 2)).max() + CC_acc = np.mean(CC_err < 0.1, (0, 2)).max() + + wandb.log( + { + f"R_acc_15_{name}_{split}": R_acc_15, + "iteration": self.iteration, + } + ) + wandb.log( + { + f"CC_acc_0.1_{name}_{split}": CC_acc, + "iteration": self.iteration, + } + ) + model.train() + + +@hydra.main(config_path="./conf", config_name="config", version_base="1.3") +def main(cfg): + print(cfg) + torch.autograd.set_detect_anomaly(cfg.debug.anomaly_detection) + torch.set_float32_matmul_precision(cfg.training.matmul_precision) + trainer = Trainer(cfg=cfg) + trainer.train() + + +if 
__name__ == "__main__": + main()
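As the module docstring of `train.py` notes, any configuration key can be overridden from the command line as `key=value`. A minimal, standalone illustration of that Hydra pattern (a toy entry point for exposition, not part of the repository):

```
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="./conf", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    # e.g. `python toy_train.py training.batch_size=4 debug.wandb=False`
    print(OmegaConf.to_yaml(cfg.training))

if __name__ == "__main__":
    main()
```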