qitaoz committed
Commit 4562a06 · verified · 1 Parent(s): 253f3e1

Upload 57 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +22 -0
  2. LICENSE +21 -0
  3. assets/demo.png +3 -0
  4. conf/config.yaml +83 -0
  5. conf/diffusion.yml +110 -0
  6. data/demo/jellycat/001.jpg +3 -0
  7. data/demo/jellycat/002.jpg +3 -0
  8. data/demo/jellycat/003.jpg +3 -0
  9. data/demo/jellycat/004.jpg +3 -0
  10. data/demo/jordan/001.png +3 -0
  11. data/demo/jordan/002.png +3 -0
  12. data/demo/jordan/003.png +3 -0
  13. data/demo/jordan/004.png +3 -0
  14. data/demo/jordan/005.png +3 -0
  15. data/demo/jordan/006.png +3 -0
  16. data/demo/jordan/007.png +3 -0
  17. data/demo/jordan/008.png +3 -0
  18. data/demo/kew_gardens_ruined_arch/001.jpeg +3 -0
  19. data/demo/kew_gardens_ruined_arch/002.jpeg +3 -0
  20. data/demo/kew_gardens_ruined_arch/003.jpeg +3 -0
  21. data/demo/kotor_cathedral/001.jpeg +3 -0
  22. data/demo/kotor_cathedral/002.jpeg +3 -0
  23. data/demo/kotor_cathedral/003.jpeg +3 -0
  24. data/demo/kotor_cathedral/004.jpeg +3 -0
  25. data/demo/kotor_cathedral/005.jpeg +3 -0
  26. data/demo/kotor_cathedral/006.jpeg +3 -0
  27. diffusionsfm/__init__.py +1 -0
  28. diffusionsfm/dataset/__init__.py +0 -0
  29. diffusionsfm/dataset/co3d_v2.py +792 -0
  30. diffusionsfm/dataset/custom.py +105 -0
  31. diffusionsfm/eval/__init__.py +0 -0
  32. diffusionsfm/eval/eval_category.py +292 -0
  33. diffusionsfm/eval/eval_jobs.py +175 -0
  34. diffusionsfm/inference/__init__.py +0 -0
  35. diffusionsfm/inference/ddim.py +145 -0
  36. diffusionsfm/inference/load_model.py +97 -0
  37. diffusionsfm/inference/predict.py +93 -0
  38. diffusionsfm/model/base_model.py +16 -0
  39. diffusionsfm/model/blocks.py +247 -0
  40. diffusionsfm/model/diffuser.py +195 -0
  41. diffusionsfm/model/diffuser_dpt.py +331 -0
  42. diffusionsfm/model/dit.py +428 -0
  43. diffusionsfm/model/feature_extractors.py +176 -0
  44. diffusionsfm/model/memory_efficient_attention.py +51 -0
  45. diffusionsfm/model/scheduler.py +128 -0
  46. diffusionsfm/utils/__init__.py +0 -0
  47. diffusionsfm/utils/configs.py +66 -0
  48. diffusionsfm/utils/distortion.py +144 -0
  49. diffusionsfm/utils/distributed.py +31 -0
  50. diffusionsfm/utils/geometry.py +145 -0
.gitattributes CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demo.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jellycat/001.jpg filter=lfs diff=lfs merge=lfs -text
+ data/demo/jellycat/002.jpg filter=lfs diff=lfs merge=lfs -text
+ data/demo/jellycat/003.jpg filter=lfs diff=lfs merge=lfs -text
+ data/demo/jellycat/004.jpg filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/001.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/002.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/003.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/004.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/005.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/006.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/007.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/jordan/008.png filter=lfs diff=lfs merge=lfs -text
+ data/demo/kew_gardens_ruined_arch/001.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kew_gardens_ruined_arch/002.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kew_gardens_ruined_arch/003.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kotor_cathedral/001.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kotor_cathedral/002.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kotor_cathedral/003.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kotor_cathedral/004.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kotor_cathedral/005.jpeg filter=lfs diff=lfs merge=lfs -text
+ data/demo/kotor_cathedral/006.jpeg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Qitao Zhao
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
assets/demo.png ADDED

Git LFS Details

  • SHA256: f5021efbf6bf2ad1de68447a1e9d313581422b79c4f460ccb94654d5c08bb83c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
conf/config.yaml ADDED
@@ -0,0 +1,83 @@
+ training:
+   resume: False # If True, must set hydra.run.dir accordingly
+   pretrain_path: ""
+   interval_visualize: 1000
+   interval_save_checkpoint: 5000
+   interval_delete_checkpoint: 10000
+   interval_evaluate: 5000
+   delete_all_checkpoints_after_training: False
+   lr: 1e-4
+   mixed_precision: True
+   matmul_precision: high
+   max_iterations: 100000
+   batch_size: 64
+   num_workers: 8
+   gpu_id: 0
+   freeze_encoder: True
+   seed: 0
+   job_key: "" # Use this for submitit sweeps where timestamps might collide
+   translation_scale: 1.0
+   regression: False
+   prob_unconditional: 0
+   load_extra_cameras: False
+   calculate_intrinsics: False
+   distort: False
+   normalize_first_camera: True
+   diffuse_origins_and_endpoints: True
+   diffuse_depths: False
+   depth_resolution: 1
+   dpt_head: False
+   full_num_patches_x: 16
+   full_num_patches_y: 16
+   dpt_encoder_features: True
+   nearest_neighbor: True
+   no_bg_targets: True
+   unit_normalize_scene: False
+   sd_scale: 2
+   bfloat: True
+   first_cam_mediod: True
+   gradient_clipping: False
+   l1_loss: False
+   grad_accumulation: False
+   reinit: False
+
+ model:
+   pred_x0: True
+   model_type: dit
+   num_patches_x: 16
+   num_patches_y: 16
+   depth: 16
+   num_images: 1
+   random_num_images: True
+   feature_extractor: dino
+   append_ndc: True
+   within_image: False
+   use_homogeneous: True
+   freeze_transformer: False
+   cond_depth_mask: True
+
+ noise_scheduler:
+   type: linear
+   max_timesteps: 100
+   beta_start: 0.0120
+   beta_end: 0.00085
+   marigold_ddim: False
+
+ dataset:
+   name: co3d
+   shape: all_train
+   apply_augmentation: True
+   use_global_intrinsics: True
+   mask_holes: True
+   image_size: 224
+
+ debug:
+   wandb: True
+   project_name: diffusionsfm
+   run_name:
+   anomaly_detection: False
+
+ hydra:
+   run:
+     dir: ./output/${now:%m%d_%H%M%S_%f}${training.job_key}
+   output_subdir: hydra
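
For quick orientation (not part of this commit), the config above can be inspected directly with OmegaConf, which Hydra uses under the hood; the keys shown are the ones the training code reads, e.g. training.batch_size and model.feature_extractor. A minimal sketch, assuming the file is read standalone rather than composed by Hydra:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("conf/config.yaml")
    print(cfg.training.batch_size)            # 64
    print(cfg.model.feature_extractor)        # dino
    print(cfg.noise_scheduler.max_timesteps)  # 100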
conf/diffusion.yml ADDED
@@ -0,0 +1,110 @@
1
+ name: diffusion
2
+ channels:
3
+ - conda-forge
4
+ - iopath
5
+ - nvidia
6
+ - pkgs/main
7
+ - pytorch
8
+ - xformers
9
+ dependencies:
10
+ - _libgcc_mutex=0.1=conda_forge
11
+ - _openmp_mutex=4.5=2_gnu
12
+ - blas=1.0=mkl
13
+ - brotli-python=1.0.9=py39h5a03fae_9
14
+ - bzip2=1.0.8=h7f98852_4
15
+ - ca-certificates=2023.7.22=hbcca054_0
16
+ - certifi=2023.7.22=pyhd8ed1ab_0
17
+ - charset-normalizer=3.2.0=pyhd8ed1ab_0
18
+ - colorama=0.4.6=pyhd8ed1ab_0
19
+ - cuda-cudart=11.7.99=0
20
+ - cuda-cupti=11.7.101=0
21
+ - cuda-libraries=11.7.1=0
22
+ - cuda-nvrtc=11.7.99=0
23
+ - cuda-nvtx=11.7.91=0
24
+ - cuda-runtime=11.7.1=0
25
+ - ffmpeg=4.3=hf484d3e_0
26
+ - filelock=3.12.2=pyhd8ed1ab_0
27
+ - freetype=2.12.1=hca18f0e_1
28
+ - fvcore=0.1.5.post20221221=pyhd8ed1ab_0
29
+ - gmp=6.2.1=h58526e2_0
30
+ - gmpy2=2.1.2=py39h376b7d2_1
31
+ - gnutls=3.6.13=h85f3911_1
32
+ - idna=3.4=pyhd8ed1ab_0
33
+ - intel-openmp=2022.1.0=h9e868ea_3769
34
+ - iopath=0.1.9=py39
35
+ - jinja2=3.1.2=pyhd8ed1ab_1
36
+ - jpeg=9e=h0b41bf4_3
37
+ - lame=3.100=h166bdaf_1003
38
+ - lcms2=2.15=hfd0df8a_0
39
+ - ld_impl_linux-64=2.40=h41732ed_0
40
+ - lerc=4.0.0=h27087fc_0
41
+ - libblas=3.9.0=16_linux64_mkl
42
+ - libcblas=3.9.0=16_linux64_mkl
43
+ - libcublas=11.10.3.66=0
44
+ - libcufft=10.7.2.124=h4fbf590_0
45
+ - libcufile=1.7.1.12=0
46
+ - libcurand=10.3.3.129=0
47
+ - libcusolver=11.4.0.1=0
48
+ - libcusparse=11.7.4.91=0
49
+ - libdeflate=1.17=h0b41bf4_0
50
+ - libffi=3.3=h58526e2_2
51
+ - libgcc-ng=13.1.0=he5830b7_0
52
+ - libgomp=13.1.0=he5830b7_0
53
+ - libiconv=1.17=h166bdaf_0
54
+ - liblapack=3.9.0=16_linux64_mkl
55
+ - libnpp=11.7.4.75=0
56
+ - libnvjpeg=11.8.0.2=0
57
+ - libpng=1.6.39=h753d276_0
58
+ - libsqlite=3.42.0=h2797004_0
59
+ - libstdcxx-ng=13.1.0=hfd8a6a1_0
60
+ - libtiff=4.5.0=h6adf6a1_2
61
+ - libwebp-base=1.3.1=hd590300_0
62
+ - libxcb=1.13=h7f98852_1004
63
+ - libzlib=1.2.13=hd590300_5
64
+ - markupsafe=2.1.3=py39hd1e30aa_0
65
+ - mkl=2022.1.0=hc2b9512_224
66
+ - mpc=1.3.1=hfe3b2da_0
67
+ - mpfr=4.2.0=hb012696_0
68
+ - mpmath=1.3.0=pyhd8ed1ab_0
69
+ - ncurses=6.4=hcb278e6_0
70
+ - nettle=3.6=he412f7d_0
71
+ - networkx=3.1=pyhd8ed1ab_0
72
+ - numpy=1.25.2=py39h6183b62_0
73
+ - openh264=2.1.1=h780b84a_0
74
+ - openjpeg=2.5.0=hfec8fc6_2
75
+ - openssl=1.1.1v=hd590300_0
76
+ - pillow=9.4.0=py39h2320bf1_1
77
+ - pip=23.2.1=pyhd8ed1ab_0
78
+ - portalocker=2.7.0=py39hf3d152e_0
79
+ - pthread-stubs=0.4=h36c2ea0_1001
80
+ - pysocks=1.7.1=pyha2e5f31_6
81
+ - python=3.9.0=hffdb5ce_5_cpython
82
+ - python_abi=3.9=3_cp39
83
+ - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
84
+ - pytorch-cuda=11.7=h778d358_5
85
+ - pytorch-mutex=1.0=cuda
86
+ - pyyaml=6.0=py39hb9d737c_5
87
+ - readline=8.2=h8228510_1
88
+ - requests=2.31.0=pyhd8ed1ab_0
89
+ - setuptools=68.0.0=pyhd8ed1ab_0
90
+ - sqlite=3.42.0=h2c6b66d_0
91
+ - sympy=1.12=pypyh9d50eac_103
92
+ - tabulate=0.9.0=pyhd8ed1ab_1
93
+ - termcolor=2.3.0=pyhd8ed1ab_0
94
+ - tk=8.6.12=h27826a3_0
95
+ - torchaudio=2.0.2=py39_cu117
96
+ - torchtriton=2.0.0=py39
97
+ - torchvision=0.15.2=py39_cu117
98
+ - tqdm=4.66.1=pyhd8ed1ab_0
99
+ - typing_extensions=4.7.1=pyha770c72_0
100
+ - tzdata=2023c=h71feb2d_0
101
+ - urllib3=2.0.4=pyhd8ed1ab_0
102
+ - wheel=0.41.1=pyhd8ed1ab_0
103
+ - xformers=0.0.21=py39_cu11.8.0_pyt2.0.1
104
+ - xorg-libxau=1.0.11=hd590300_0
105
+ - xorg-libxdmcp=1.1.3=h7f98852_0
106
+ - xz=5.2.6=h166bdaf_0
107
+ - yacs=0.1.8=pyhd8ed1ab_0
108
+ - yaml=0.2.5=h7f98852_2
109
+ - zlib=1.2.13=hd590300_5
110
+ - zstd=1.5.2=hfc55251_7
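
The environment file pins PyTorch 2.0.1 against CUDA 11.7 with xformers 0.0.21. A small sanity check (not part of the repo) that the environment created from conf/diffusion.yml resolved to those pins:

    import torch
    import xformers

    # Versions pinned in conf/diffusion.yml
    assert torch.__version__.startswith("2.0.1"), torch.__version__
    assert torch.version.cuda == "11.7", torch.version.cuda
    print("xformers", xformers.__version__)  # expected 0.0.21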
data/demo/jellycat/001.jpg ADDED

Git LFS Details

  • SHA256: bb252fabcd6588b924266098efbf0538c4cc77f6fd623166a3328692ea04b221
  • Pointer size: 132 Bytes
  • Size of remote file: 6.91 MB
data/demo/jellycat/002.jpg ADDED

Git LFS Details

  • SHA256: 4e36092e3ef63d3de0d9645ce001829c248b2c5ae78011c3578276d4f0009ce6
  • Pointer size: 132 Bytes
  • Size of remote file: 6.86 MB
data/demo/jellycat/003.jpg ADDED

Git LFS Details

  • SHA256: 6fcd32e046e04b809529c202f594a49a210d1fcd38a4664ca619704dd317b550
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
data/demo/jellycat/004.jpg ADDED

Git LFS Details

  • SHA256: 082007c67949ce96af89d34fbb3dd8a6eeca4d000e4dc39a920215881ee5a4e1
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
data/demo/jordan/001.png ADDED

Git LFS Details

  • SHA256: dff6883afa87339f94ac9d2b07a61e01f9107f2e37a7bda326956f209a5b1c61
  • Pointer size: 131 Bytes
  • Size of remote file: 128 kB
data/demo/jordan/002.png ADDED

Git LFS Details

  • SHA256: bee5060ab4b105fd9383398ab47dc1caa2dd329f9e83ae310bb068870d445270
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB
data/demo/jordan/003.png ADDED

Git LFS Details

  • SHA256: 34353a07643bb0f6dcc8d1a40d1658e393998af73db0cd17e392558581840c03
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
data/demo/jordan/004.png ADDED

Git LFS Details

  • SHA256: c671d0fb4ff49d59e6b044e8a673ad6e9293337423f6db1345c3e1c45f0c7427
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
data/demo/jordan/005.png ADDED

Git LFS Details

  • SHA256: aefe2f7ca57407dad2a7ce86759b86698d536d5f6c7fd9f97776d3905bb3ce19
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
data/demo/jordan/006.png ADDED

Git LFS Details

  • SHA256: f5f23767c8a3830921e1c5299cc8373daacf9af68e96c596df5b8edbdc0f4836
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
data/demo/jordan/007.png ADDED

Git LFS Details

  • SHA256: c61521925f93ec2721a02c4c5b4898171985147c763702cb5f9a7efbb341cc2d
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
data/demo/jordan/008.png ADDED

Git LFS Details

  • SHA256: bb1e7c3d5fc1ad0067d2d28d9f83e5d5243010353f0ec0fd12f196cf5939f231
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
data/demo/kew_gardens_ruined_arch/001.jpeg ADDED

Git LFS Details

  • SHA256: 96dfde51e8d0857120387e3b81fff665b9b94524a6f1a9f35246faed4f3e8986
  • Pointer size: 131 Bytes
  • Size of remote file: 624 kB
data/demo/kew_gardens_ruined_arch/002.jpeg ADDED

Git LFS Details

  • SHA256: 8c2c07e43d51594fbbea708fce0040fbe6b5ecd4e01c8b10898a2f71f3abf186
  • Pointer size: 131 Bytes
  • Size of remote file: 590 kB
data/demo/kew_gardens_ruined_arch/003.jpeg ADDED

Git LFS Details

  • SHA256: bfeea8fcb46fcbb0d77450227927851472a73a714c62de21e75b4d60a3dba317
  • Pointer size: 131 Bytes
  • Size of remote file: 586 kB
data/demo/kotor_cathedral/001.jpeg ADDED

Git LFS Details

  • SHA256: 732a5d344ddcfc2e50a97abc3792bb444cf40b82a93638f2d52d955f6595c90a
  • Pointer size: 131 Bytes
  • Size of remote file: 617 kB
data/demo/kotor_cathedral/002.jpeg ADDED

Git LFS Details

  • SHA256: e2dc18fde3559ae7333351ded6d765b217a5b754fd828087c5f88cbc11c84793
  • Pointer size: 131 Bytes
  • Size of remote file: 760 kB
data/demo/kotor_cathedral/003.jpeg ADDED

Git LFS Details

  • SHA256: b100bdbc2d9943151424b86e504ece203f15ff5d616dd4c515fabc7b3d39d11c
  • Pointer size: 131 Bytes
  • Size of remote file: 697 kB
data/demo/kotor_cathedral/004.jpeg ADDED

Git LFS Details

  • SHA256: 26d6433fa500c03bd982a3abaf2f8028d26a9726e409b880192de5eea17d83b5
  • Pointer size: 131 Bytes
  • Size of remote file: 583 kB
data/demo/kotor_cathedral/005.jpeg ADDED

Git LFS Details

  • SHA256: 15b4e342917ae6df3a43a82aba0fb199098fd5ff9946303109b5e21a613a6d30
  • Pointer size: 131 Bytes
  • Size of remote file: 902 kB
data/demo/kotor_cathedral/006.jpeg ADDED

Git LFS Details

  • SHA256: da1098cc0360bbc34eb1a61a224ccdaa43677c4e6b26687c8f9bb95fbf7a2f42
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB
diffusionsfm/__init__.py ADDED
@@ -0,0 +1 @@
+ from .utils.rays import cameras_to_rays, rays_to_cameras, Rays
diffusionsfm/dataset/__init__.py ADDED
File without changes
diffusionsfm/dataset/co3d_v2.py ADDED
@@ -0,0 +1,792 @@
1
+ import gzip
2
+ import json
3
+ import os.path as osp
4
+ import random
5
+ import socket
6
+ import time
7
+ import torch
8
+ import warnings
9
+
10
+ import numpy as np
11
+ from PIL import Image, ImageFile
12
+ from tqdm import tqdm
13
+ from pytorch3d.renderer import PerspectiveCameras
14
+ from torch.utils.data import Dataset
15
+ from torchvision import transforms
16
+ import matplotlib.pyplot as plt
17
+ from scipy import ndimage as nd
18
+
19
+ from diffusionsfm.utils.distortion import distort_image
20
+
21
+
22
+ HOSTNAME = socket.gethostname()
23
+
24
+ CO3D_DIR = "../co3d_data" # update this
25
+ CO3D_ANNOTATION_DIR = osp.join(CO3D_DIR, "co3d_annotations")
26
+ CO3D_DIR = CO3D_DEPTH_DIR = osp.join(CO3D_DIR, "co3d")
27
+ order_path = osp.join(
28
+ CO3D_DIR, "co3d_v2_random_order_{sample_num}/{category}.json"
29
+ )
30
+
31
+
32
+ TRAINING_CATEGORIES = [
33
+ "apple",
34
+ "backpack",
35
+ "banana",
36
+ "baseballbat",
37
+ "baseballglove",
38
+ "bench",
39
+ "bicycle",
40
+ "bottle",
41
+ "bowl",
42
+ "broccoli",
43
+ "cake",
44
+ "car",
45
+ "carrot",
46
+ "cellphone",
47
+ "chair",
48
+ "cup",
49
+ "donut",
50
+ "hairdryer",
51
+ "handbag",
52
+ "hydrant",
53
+ "keyboard",
54
+ "laptop",
55
+ "microwave",
56
+ "motorcycle",
57
+ "mouse",
58
+ "orange",
59
+ "parkingmeter",
60
+ "pizza",
61
+ "plant",
62
+ "stopsign",
63
+ "teddybear",
64
+ "toaster",
65
+ "toilet",
66
+ "toybus",
67
+ "toyplane",
68
+ "toytrain",
69
+ "toytruck",
70
+ "tv",
71
+ "umbrella",
72
+ "vase",
73
+ "wineglass",
74
+ ]
75
+
76
+ TEST_CATEGORIES = [
77
+ "ball",
78
+ "book",
79
+ "couch",
80
+ "frisbee",
81
+ "hotdog",
82
+ "kite",
83
+ "remote",
84
+ "sandwich",
85
+ "skateboard",
86
+ "suitcase",
87
+ ]
88
+
89
+ assert len(TRAINING_CATEGORIES) + len(TEST_CATEGORIES) == 51
90
+
91
+ Image.MAX_IMAGE_PIXELS = None
92
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
93
+
94
+
95
+ def fill_depths(data, invalid=None):
96
+ data_list = []
97
+ for i in range(data.shape[0]):
98
+ data_item = data[i].numpy()
99
+ # Invalid must be 1 where stuff is invalid, 0 where valid
100
+ ind = nd.distance_transform_edt(
101
+ invalid[i], return_distances=False, return_indices=True
102
+ )
103
+ data_list.append(torch.tensor(data_item[tuple(ind)]))
104
+ return torch.stack(data_list, dim=0)
105
+
106
+
107
+ def full_scene_scale(batch):
108
+ cameras = PerspectiveCameras(R=batch["R"], T=batch["T"], device="cuda")
109
+ cc = cameras.get_camera_center()
110
+ centroid = torch.mean(cc, dim=0)
111
+
112
+ diffs = cc - centroid
113
+ norms = torch.linalg.norm(diffs, dim=1)
114
+
115
+ furthest_index = torch.argmax(norms).item()
116
+ scale = norms[furthest_index].item()
117
+ return scale
118
+
119
+
120
+ def square_bbox(bbox, padding=0.0, astype=None, tight=False):
121
+ """
122
+ Computes a square bounding box, with optional padding parameters.
123
+ Args:
124
+ bbox: Bounding box in xyxy format (4,).
125
+ Returns:
126
+ square_bbox in xyxy format (4,).
127
+ """
128
+ if astype is None:
129
+ astype = type(bbox[0])
130
+ bbox = np.array(bbox)
131
+ center = (bbox[:2] + bbox[2:]) / 2
132
+ extents = (bbox[2:] - bbox[:2]) / 2
133
+
134
+ # No black bars if tight
135
+ if tight:
136
+ s = min(extents) * (1 + padding)
137
+ else:
138
+ s = max(extents) * (1 + padding)
139
+
140
+ square_bbox = np.array(
141
+ [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
142
+ dtype=astype,
143
+ )
144
+ return square_bbox
145
+
146
+
147
+ def unnormalize_image(image, return_numpy=True, return_int=True):
148
+ if isinstance(image, torch.Tensor):
149
+ image = image.detach().cpu().numpy()
150
+
151
+ if image.ndim == 3:
152
+ if image.shape[0] == 3:
153
+ image = image[None, ...]
154
+ elif image.shape[2] == 3:
155
+ image = image.transpose(2, 0, 1)[None, ...]
156
+ else:
157
+ raise ValueError(f"Unexpected image shape: {image.shape}")
158
+ elif image.ndim == 4:
159
+ if image.shape[1] == 3:
160
+ pass
161
+ elif image.shape[3] == 3:
162
+ image = image.transpose(0, 3, 1, 2)
163
+ else:
164
+ raise ValueError(f"Unexpected batch image shape: {image.shape}")
165
+ else:
166
+ raise ValueError(f"Unsupported input shape: {image.shape}")
167
+
168
+ mean = np.array([0.485, 0.456, 0.406])[None, :, None, None]
169
+ std = np.array([0.229, 0.224, 0.225])[None, :, None, None]
170
+ image = image * std + mean
171
+
172
+ if return_int:
173
+ image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
174
+ else:
175
+ image = np.clip(image, 0.0, 1.0)
176
+
177
+ if image.shape[0] == 1:
178
+ image = image[0]
179
+
180
+ if return_numpy:
181
+ return image
182
+ else:
183
+ return torch.from_numpy(image)
184
+
185
+
186
+ def unnormalize_image_for_vis(image):
187
+ assert len(image.shape) == 5 and image.shape[2] == 3
188
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1).to(image.device)
189
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1).to(image.device)
190
+ image = image * std + mean
191
+ image = (image - 0.5) / 0.5
192
+ return image
193
+
194
+
195
+ def _transform_intrinsic(image, bbox, principal_point, focal_length):
196
+ # Rescale intrinsics to match bbox
197
+ half_box = np.array([image.width, image.height]).astype(np.float32) / 2
198
+ org_scale = min(half_box).astype(np.float32)
199
+
200
+ # Pixel coordinates
201
+ principal_point_px = half_box - (np.array(principal_point) * org_scale)
202
+ focal_length_px = np.array(focal_length) * org_scale
203
+ principal_point_px -= bbox[:2]
204
+ new_bbox = (bbox[2:] - bbox[:2]) / 2
205
+ new_scale = min(new_bbox)
206
+
207
+ # NDC coordinates
208
+ new_principal_ndc = (new_bbox - principal_point_px) / new_scale
209
+ new_focal_ndc = focal_length_px / new_scale
210
+
211
+ principal_point = torch.tensor(new_principal_ndc.astype(np.float32))
212
+ focal_length = torch.tensor(new_focal_ndc.astype(np.float32))
213
+
214
+ return principal_point, focal_length
215
+
216
+
217
+ def construct_camera_from_batch(batch, device):
218
+ if isinstance(device, int):
219
+ device = f"cuda:{device}"
220
+
221
+ return PerspectiveCameras(
222
+ R=batch["R"].reshape(-1, 3, 3),
223
+ T=batch["T"].reshape(-1, 3),
224
+ focal_length=batch["focal_lengths"].reshape(-1, 2),
225
+ principal_point=batch["principal_points"].reshape(-1, 2),
226
+ image_size=batch["image_sizes"].reshape(-1, 2),
227
+ device=device,
228
+ )
229
+
230
+
231
+ def save_batch_images(images, fname):
232
+ cmap = plt.get_cmap("hsv")
233
+ num_frames = len(images)
234
+ num_rows = len(images)
235
+ num_cols = 4
236
+ figsize = (num_cols * 2, num_rows * 2)
237
+ fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
238
+ axs = axs.flatten()
239
+ for i in range(num_rows):
240
+ for j in range(4):
241
+ if i < num_frames:
242
+ axs[i * 4 + j].imshow(unnormalize_image(images[i][j]))
243
+ for s in ["bottom", "top", "left", "right"]:
244
+ axs[i * 4 + j].spines[s].set_color(cmap(i / (num_frames)))
245
+ axs[i * 4 + j].spines[s].set_linewidth(5)
246
+ axs[i * 4 + j].set_xticks([])
247
+ axs[i * 4 + j].set_yticks([])
248
+ else:
249
+ axs[i * 4 + j].axis("off")
250
+ plt.tight_layout()
251
+ plt.savefig(fname)
252
+
253
+
254
+ def jitter_bbox(
255
+ square_bbox,
256
+ jitter_scale=(1.1, 1.2),
257
+ jitter_trans=(-0.07, 0.07),
258
+ direction_from_size=None,
259
+ ):
260
+
261
+ square_bbox = np.array(square_bbox.astype(float))
262
+ s = np.random.uniform(jitter_scale[0], jitter_scale[1])
263
+
264
+ # Jitter only one dimension if center cropping
265
+ tx, ty = np.random.uniform(jitter_trans[0], jitter_trans[1], size=2)
266
+ if direction_from_size is not None:
267
+ if direction_from_size[0] > direction_from_size[1]:
268
+ tx = 0
269
+ else:
270
+ ty = 0
271
+
272
+ side_length = square_bbox[2] - square_bbox[0]
273
+ center = (square_bbox[:2] + square_bbox[2:]) / 2 + np.array([tx, ty]) * side_length
274
+ extent = side_length / 2 * s
275
+ ul = center - extent
276
+ lr = ul + 2 * extent
277
+ return np.concatenate((ul, lr))
278
+
279
+
280
+ class Co3dDataset(Dataset):
281
+ def __init__(
282
+ self,
283
+ category=("all_train",),
284
+ split="train",
285
+ transform=None,
286
+ num_images=2,
287
+ img_size=224,
288
+ mask_images=False,
289
+ crop_images=True,
290
+ co3d_dir=None,
291
+ co3d_annotation_dir=None,
292
+ precropped_images=False,
293
+ apply_augmentation=True,
294
+ normalize_cameras=True,
295
+ no_images=False,
296
+ sample_num=None,
297
+ seed=0,
298
+ load_extra_cameras=False,
299
+ distort_image=False,
300
+ load_depths=False,
301
+ center_crop=False,
302
+ depth_size=256,
303
+ mask_holes=False,
304
+ object_mask=True,
305
+ ):
306
+ """
307
+ Args:
308
+ num_images: Number of images in each batch.
309
+ perspective_correction (str):
310
+ "none": No perspective correction.
311
+ "warp": Warp the image and label.
312
+ "label_only": Correct the label only.
313
+ """
314
+ start_time = time.time()
315
+
316
+ self.category = category
317
+ self.split = split
318
+ self.transform = transform
319
+ self.num_images = num_images
320
+ self.img_size = img_size
321
+ self.mask_images = mask_images
322
+ self.crop_images = crop_images
323
+ self.precropped_images = precropped_images
324
+ self.apply_augmentation = apply_augmentation
325
+ self.normalize_cameras = normalize_cameras
326
+ self.no_images = no_images
327
+ self.sample_num = sample_num
328
+ self.load_extra_cameras = load_extra_cameras
329
+ self.distort = distort_image
330
+ self.load_depths = load_depths
331
+ self.center_crop = center_crop
332
+ self.depth_size = depth_size
333
+ self.mask_holes = mask_holes
334
+ self.object_mask = object_mask
335
+
336
+ if self.apply_augmentation:
337
+ if self.center_crop:
338
+ self.jitter_scale = (0.8, 1.1)
339
+ self.jitter_trans = (0.0, 0.0)
340
+ else:
341
+ self.jitter_scale = (1.1, 1.2)
342
+ self.jitter_trans = (-0.07, 0.07)
343
+ else:
344
+ # Note if trained with apply_augmentation, we should still use
345
+ # apply_augmentation at test time.
346
+ self.jitter_scale = (1, 1)
347
+ self.jitter_trans = (0.0, 0.0)
348
+
349
+ if self.distort:
350
+ self.k1_max = 1.0
351
+ self.k2_max = 1.0
352
+
353
+ if co3d_dir is not None:
354
+ self.co3d_dir = co3d_dir
355
+ self.co3d_annotation_dir = co3d_annotation_dir
356
+ else:
357
+ self.co3d_dir = CO3D_DIR
358
+ self.co3d_annotation_dir = CO3D_ANNOTATION_DIR
359
+ self.co3d_depth_dir = CO3D_DEPTH_DIR
360
+
361
+ if isinstance(self.category, str):
362
+ self.category = [self.category]
363
+
364
+ if "all_train" in self.category:
365
+ self.category = TRAINING_CATEGORIES
366
+ if "all_test" in self.category:
367
+ self.category = TEST_CATEGORIES
368
+ if "full" in self.category:
369
+ self.category = TRAINING_CATEGORIES + TEST_CATEGORIES
370
+ self.category = sorted(self.category)
371
+ self.is_single_category = len(self.category) == 1
372
+
373
+ # Fixing seed
374
+ torch.manual_seed(seed)
375
+ random.seed(seed)
376
+ np.random.seed(seed)
377
+
378
+ print(f"Co3d ({split}):")
379
+
380
+ self.low_quality_translations = [
381
+ "411_55952_107659",
382
+ "427_59915_115716",
383
+ "435_61970_121848",
384
+ "112_13265_22828",
385
+ "110_13069_25642",
386
+ "165_18080_34378",
387
+ "368_39891_78502",
388
+ "391_47029_93665",
389
+ "20_695_1450",
390
+ "135_15556_31096",
391
+ "417_57572_110680",
392
+ ] # Initialized with sequences with poor depth masks
393
+ self.rotations = {}
394
+ self.category_map = {}
395
+ for c in tqdm(self.category):
396
+ annotation_file = osp.join(
397
+ self.co3d_annotation_dir, f"{c}_{self.split}.jgz"
398
+ )
399
+ with gzip.open(annotation_file, "r") as fin:
400
+ annotation = json.loads(fin.read())
401
+
402
+ counter = 0
403
+ for seq_name, seq_data in annotation.items():
404
+ counter += 1
405
+ if len(seq_data) < self.num_images:
406
+ continue
407
+
408
+ filtered_data = []
409
+ self.category_map[seq_name] = c
410
+ bad_seq = False
411
+ for data in seq_data:
412
+ # Make sure translations are not ridiculous and rotations are valid
413
+ det = np.linalg.det(data["R"])
414
+ if (np.abs(data["T"]) > 1e5).any() or det < 0.99 or det > 1.01:
415
+ bad_seq = True
416
+ self.low_quality_translations.append(seq_name)
417
+ break
418
+
419
+ # Ignore all unnecessary information.
420
+ filtered_data.append(
421
+ {
422
+ "filepath": data["filepath"],
423
+ "bbox": data["bbox"],
424
+ "R": data["R"],
425
+ "T": data["T"],
426
+ "focal_length": data["focal_length"],
427
+ "principal_point": data["principal_point"],
428
+ },
429
+ )
430
+
431
+ if not bad_seq:
432
+ self.rotations[seq_name] = filtered_data
433
+
434
+ self.sequence_list = list(self.rotations.keys())
435
+
436
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
437
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
438
+
439
+ if self.transform is None:
440
+ self.transform = transforms.Compose(
441
+ [
442
+ transforms.ToTensor(),
443
+ transforms.Resize(self.img_size, antialias=True),
444
+ transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
445
+ ]
446
+ )
447
+
448
+ self.transform_depth = transforms.Compose(
449
+ [
450
+ transforms.Resize(
451
+ self.depth_size,
452
+ antialias=False,
453
+ interpolation=transforms.InterpolationMode.NEAREST_EXACT,
454
+ ),
455
+ ]
456
+ )
457
+
458
+ print(
459
+ f"Low quality translation sequences, not used: {self.low_quality_translations}"
460
+ )
461
+ print(f"Data size: {len(self)}")
462
+ print(f"Data loading took {(time.time()-start_time)} seconds.")
463
+
464
+ def __len__(self):
465
+ return len(self.sequence_list)
466
+
467
+ def __getitem__(self, index):
468
+ num_to_load = self.num_images if not self.load_extra_cameras else 8
469
+
470
+ sequence_name = self.sequence_list[index % len(self.sequence_list)]
471
+ metadata = self.rotations[sequence_name]
472
+
473
+ if self.sample_num is not None:
474
+ with open(
475
+ order_path.format(sample_num=self.sample_num, category=self.category[0])
476
+ ) as f:
477
+ order = json.load(f)
478
+ ids = order[sequence_name][:num_to_load]
479
+ else:
480
+ replace = len(metadata) < 8
481
+ ids = np.random.choice(len(metadata), num_to_load, replace=replace)
482
+
483
+ return self.get_data(index=index, ids=ids, num_valid_frames=num_to_load)
484
+
485
+ def _get_scene_scale(self, sequence_name):
486
+ n = len(self.rotations[sequence_name])
487
+
488
+ R = torch.zeros(n, 3, 3)
489
+ T = torch.zeros(n, 3)
490
+
491
+ for i, ann in enumerate(self.rotations[sequence_name]):
492
+ R[i, ...] = torch.tensor(self.rotations[sequence_name][i]["R"])
493
+ T[i, ...] = torch.tensor(self.rotations[sequence_name][i]["T"])
494
+
495
+ cameras = PerspectiveCameras(R=R, T=T)
496
+ cc = cameras.get_camera_center()
497
+ centeroid = torch.mean(cc, dim=0)
498
+ diff = cc - centeroid
499
+
500
+ norm = torch.norm(diff, dim=1)
501
+ scale = torch.max(norm).item()
502
+
503
+ return scale
504
+
505
+ def _crop_image(self, image, bbox):
506
+ image_crop = transforms.functional.crop(
507
+ image,
508
+ top=bbox[1],
509
+ left=bbox[0],
510
+ height=bbox[3] - bbox[1],
511
+ width=bbox[2] - bbox[0],
512
+ )
513
+ return image_crop
514
+
515
+ def _transform_intrinsic(self, image, bbox, principal_point, focal_length):
516
+ half_box = np.array([image.width, image.height]).astype(np.float32) / 2
517
+ org_scale = min(half_box).astype(np.float32)
518
+
519
+ # Pixel coordinates
520
+ principal_point_px = half_box - (np.array(principal_point) * org_scale)
521
+ focal_length_px = np.array(focal_length) * org_scale
522
+ principal_point_px -= bbox[:2]
523
+ new_bbox = (bbox[2:] - bbox[:2]) / 2
524
+ new_scale = min(new_bbox)
525
+
526
+ # NDC coordinates
527
+ new_principal_ndc = (new_bbox - principal_point_px) / new_scale
528
+ new_focal_ndc = focal_length_px / new_scale
529
+
530
+ return new_principal_ndc.astype(np.float32), new_focal_ndc.astype(np.float32)
531
+
532
+ def get_data(
533
+ self,
534
+ index=None,
535
+ sequence_name=None,
536
+ ids=(0, 1),
537
+ no_images=False,
538
+ num_valid_frames=None,
539
+ load_using_order=None,
540
+ ):
541
+ if load_using_order is not None:
542
+ with open(
543
+ order_path.format(sample_num=self.sample_num, category=self.category[0])
544
+ ) as f:
545
+ order = json.load(f)
546
+ ids = order[sequence_name][:load_using_order]
547
+
548
+ if sequence_name is None:
549
+ index = index % len(self.sequence_list)
550
+ sequence_name = self.sequence_list[index]
551
+ metadata = self.rotations[sequence_name]
552
+ category = self.category_map[sequence_name]
553
+
554
+ # Read image & camera information from annotations
555
+ annos = [metadata[i] for i in ids]
556
+ images = []
557
+ image_sizes = []
558
+ PP = []
559
+ FL = []
560
+ crop_parameters = []
561
+ filenames = []
562
+ distortion_parameters = []
563
+ depths = []
564
+ depth_masks = []
565
+ object_masks = []
566
+ dino_images = []
567
+ for anno in annos:
568
+ filepath = anno["filepath"]
569
+
570
+ if not no_images:
571
+ image = Image.open(osp.join(self.co3d_dir, filepath)).convert("RGB")
572
+ image_size = image.size
573
+
574
+ # Optionally mask images with black background
575
+ if self.mask_images:
576
+ black_image = Image.new("RGB", image_size, (0, 0, 0))
577
+ mask_name = osp.basename(filepath.replace(".jpg", ".png"))
578
+
579
+ mask_path = osp.join(
580
+ self.co3d_dir, category, sequence_name, "masks", mask_name
581
+ )
582
+ mask = Image.open(mask_path).convert("L")
583
+
584
+ if mask.size != image_size:
585
+ mask = mask.resize(image_size)
586
+ mask = Image.fromarray(np.array(mask) > 125)
587
+ image = Image.composite(image, black_image, mask)
588
+
589
+ if self.object_mask:
590
+ mask_name = osp.basename(filepath.replace(".jpg", ".png"))
591
+ mask_path = osp.join(
592
+ self.co3d_dir, category, sequence_name, "masks", mask_name
593
+ )
594
+ mask = Image.open(mask_path).convert("L")
595
+
596
+ if mask.size != image_size:
597
+ mask = mask.resize(image_size)
598
+ mask = torch.from_numpy(np.array(mask) > 125)
599
+
600
+ # Determine crop, Resnet wants square images
601
+ bbox = np.array(anno["bbox"])
602
+ good_bbox = ((bbox[2:] - bbox[:2]) > 30).all()
603
+ bbox = (
604
+ anno["bbox"]
605
+ if not self.center_crop and good_bbox
606
+ else [0, 0, image.width, image.height]
607
+ )
608
+
609
+ # Distort image and bbox if desired
610
+ if self.distort:
611
+ k1 = random.uniform(0, self.k1_max)
612
+ k2 = random.uniform(0, self.k2_max)
613
+
614
+ try:
615
+ image, bbox = distort_image(
616
+ image, np.array(bbox), k1, k2, modify_bbox=True
617
+ )
618
+
619
+ except:
620
+ print("INFO:")
621
+ print(sequence_name)
622
+ print(index)
623
+ print(ids)
624
+ print(k1)
625
+ print(k2)
626
+
627
+ distortion_parameters.append(torch.FloatTensor([k1, k2]))
628
+
629
+ bbox = square_bbox(np.array(bbox), tight=self.center_crop)
630
+ if self.apply_augmentation:
631
+ bbox = jitter_bbox(
632
+ bbox,
633
+ jitter_scale=self.jitter_scale,
634
+ jitter_trans=self.jitter_trans,
635
+ direction_from_size=image.size if self.center_crop else None,
636
+ )
637
+ bbox = np.around(bbox).astype(int)
638
+
639
+ # Crop parameters
640
+ crop_center = (bbox[:2] + bbox[2:]) / 2
641
+ principal_point = torch.tensor(anno["principal_point"])
642
+ focal_length = torch.tensor(anno["focal_length"])
643
+
644
+ # convert crop center to correspond to a "square" image
645
+ width, height = image.size
646
+ length = max(width, height)
647
+ s = length / min(width, height)
648
+ crop_center = crop_center + (length - np.array([width, height])) / 2
649
+
650
+ # convert to NDC
651
+ cc = s - 2 * s * crop_center / length
652
+ crop_width = 2 * s * (bbox[2] - bbox[0]) / length
653
+ crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])
654
+
655
+ # Crop and normalize image
656
+ if not self.precropped_images:
657
+ image = self._crop_image(image, bbox)
658
+
659
+ try:
660
+ image = self.transform(image)
661
+ except:
662
+ print("INFO:")
663
+ print(sequence_name)
664
+ print(index)
665
+ print(ids)
666
+ print(k1)
667
+ print(k2)
668
+
669
+ images.append(image[:, : self.img_size, : self.img_size])
670
+ crop_parameters.append(crop_params)
671
+
672
+ if self.load_depths:
673
+ # Open depth map
674
+ depth_name = osp.basename(
675
+ filepath.replace(".jpg", ".jpg.geometric.png")
676
+ )
677
+ depth_path = osp.join(
678
+ self.co3d_depth_dir,
679
+ category,
680
+ sequence_name,
681
+ "depths",
682
+ depth_name,
683
+ )
684
+ depth_pil = Image.open(depth_path)
685
+
686
+ # 16 bit float type casting
687
+ depth = torch.tensor(
688
+ np.frombuffer(
689
+ np.array(depth_pil, dtype=np.uint16), dtype=np.float16
690
+ )
691
+ .astype(np.float32)
692
+ .reshape((depth_pil.size[1], depth_pil.size[0]))
693
+ )
694
+
695
+ # Crop and resize as with images
696
+ if depth_pil.size != image_size:
697
+ # bbox may have the wrong scale
698
+ bbox = depth_pil.size[0] * bbox / image_size[0]
699
+
700
+ if self.object_mask:
701
+ assert mask.shape == depth.shape
702
+
703
+ bbox = np.around(bbox).astype(int)
704
+ depth = self._crop_image(depth, bbox)
705
+
706
+ # Resize
707
+ depth = self.transform_depth(depth.unsqueeze(0))[
708
+ 0, : self.depth_size, : self.depth_size
709
+ ]
710
+ depths.append(depth)
711
+
712
+ if self.object_mask:
713
+ mask = self._crop_image(mask, bbox)
714
+ mask = self.transform_depth(mask.unsqueeze(0))[
715
+ 0, : self.depth_size, : self.depth_size
716
+ ]
717
+ object_masks.append(mask)
718
+
719
+ PP.append(principal_point)
720
+ FL.append(focal_length)
721
+ image_sizes.append(torch.tensor([self.img_size, self.img_size]))
722
+ filenames.append(filepath)
723
+
724
+ if not no_images:
725
+ if self.load_depths:
726
+ depths = torch.stack(depths)
727
+
728
+ depth_masks = torch.logical_or(depths <= 0, depths.isinf())
729
+ depth_masks = (~depth_masks).long()
730
+
731
+ if self.object_mask:
732
+ object_masks = torch.stack(object_masks, dim=0)
733
+
734
+ if self.mask_holes:
735
+ depths = fill_depths(depths, depth_masks == 0)
736
+
737
+ # Sometimes mask_holes misses stuff
738
+ new_masks = torch.logical_or(depths <= 0, depths.isinf())
739
+ new_masks = (~new_masks).long()
740
+ depths[new_masks == 0] = -1
741
+
742
+ assert torch.logical_or(depths > 0, depths == -1).all()
743
+ assert not (depths.isinf()).any()
744
+ assert not (depths.isnan()).any()
745
+
746
+ if self.load_extra_cameras:
747
+ # Remove the extra loaded image, for saving space
748
+ images = images[: self.num_images]
749
+
750
+ if self.distort:
751
+ distortion_parameters = torch.stack(distortion_parameters)
752
+
753
+ images = torch.stack(images)
754
+ crop_parameters = torch.stack(crop_parameters)
755
+ focal_lengths = torch.stack(FL)
756
+ principal_points = torch.stack(PP)
757
+ image_sizes = torch.stack(image_sizes)
758
+ else:
759
+ images = None
760
+ crop_parameters = None
761
+ distortion_parameters = None
762
+ focal_lengths = []
763
+ principal_points = []
764
+ image_sizes = []
765
+
766
+ # Assemble batch info to send back
767
+ R = torch.stack([torch.tensor(anno["R"]) for anno in annos])
768
+ T = torch.stack([torch.tensor(anno["T"]) for anno in annos])
769
+
770
+ batch = {
771
+ "model_id": sequence_name,
772
+ "category": category,
773
+ "n": len(metadata),
774
+ "num_valid_frames": num_valid_frames,
775
+ "ind": torch.tensor(ids),
776
+ "image": images,
777
+ "depth": depths,
778
+ "depth_masks": depth_masks,
779
+ "object_masks": object_masks,
780
+ "R": R,
781
+ "T": T,
782
+ "focal_length": focal_lengths,
783
+ "principal_point": principal_points,
784
+ "image_size": image_sizes,
785
+ "crop_parameters": crop_parameters,
786
+ "distortion_parameters": torch.zeros(4),
787
+ "filename": filenames,
788
+ "category": category,
789
+ "dataset": "co3d",
790
+ }
791
+
792
+ return batch
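
A minimal usage sketch for the dataset above (not part of the commit). It assumes CO3D_DIR and CO3D_ANNOTATION_DIR at the top of co3d_v2.py point at a local CO3D download and the precomputed annotation .jgz files; only arguments visible in the constructor are used:

    from diffusionsfm.dataset.co3d_v2 import Co3dDataset

    dataset = Co3dDataset(
        category="hydrant",   # or "all_train" / "all_test" / "full"
        split="train",
        num_images=8,
        apply_augmentation=True,
        load_depths=True,
        mask_holes=True,
    )
    batch = dataset[0]           # dict with "image", "R", "T", "crop_parameters", ...
    print(batch["image"].shape)  # (8, 3, 224, 224)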
diffusionsfm/dataset/custom.py ADDED
@@ -0,0 +1,105 @@
+
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ from PIL import Image, ImageOps
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+ from diffusionsfm.dataset.co3d_v2 import square_bbox
+
+
+ class CustomDataset(Dataset):
+     def __init__(
+         self,
+         image_list,
+     ):
+         self.images = []
+
+         for image_path in sorted(image_list):
+             img = Image.open(image_path)
+             img = ImageOps.exif_transpose(img).convert("RGB")  # Apply EXIF rotation
+             self.images.append(img)
+
+         self.n = len(self.images)
+         self.jitter_scale = [1, 1]
+         self.jitter_trans = [0, 0]
+         self.transform = transforms.Compose(
+             [
+                 transforms.ToTensor(),
+                 transforms.Resize(224),
+                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+             ]
+         )
+         self.transform_for_vis = transforms.Compose(
+             [
+                 transforms.Resize(224),
+             ]
+         )
+
+     def __len__(self):
+         return 1
+
+     def _crop_image(self, image, bbox, white_bg=False):
+         if white_bg:
+             # Only supports PIL Images
+             image_crop = Image.new(
+                 "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
+             )
+             image_crop.paste(image, (-bbox[0], -bbox[1]))
+         else:
+             image_crop = transforms.functional.crop(
+                 image,
+                 top=bbox[1],
+                 left=bbox[0],
+                 height=bbox[3] - bbox[1],
+                 width=bbox[2] - bbox[0],
+             )
+         return image_crop
+
+     def __getitem__(self, index):  # index is unused; the full image set is always returned
+         return self.get_data()
+
+     def get_data(self):
+         cmap = plt.get_cmap("hsv")
+         ids = [i for i in range(len(self.images))]
+         images = [self.images[i] for i in ids]
+         images_transformed = []
+         images_for_vis = []
+         crop_parameters = []
+
+         for i, image in enumerate(images):
+             bbox = np.array([0, 0, image.width, image.height])
+             bbox = square_bbox(bbox, tight=True)
+             bbox = np.around(bbox).astype(int)
+             image = self._crop_image(image, bbox)
+             images_transformed.append(self.transform(image))
+             image_for_vis = self.transform_for_vis(image)
+             color_float = cmap(i / len(images))
+             color_rgb = tuple(int(255 * c) for c in color_float[:3])
+             image_for_vis = ImageOps.expand(image_for_vis, border=3, fill=color_rgb)
+             images_for_vis.append(image_for_vis)
+
+             width, height = image.size
+             length = max(width, height)
+             s = length / min(width, height)
+             crop_center = (bbox[:2] + bbox[2:]) / 2
+             crop_center = crop_center + (length - np.array([width, height])) / 2
+             # Convert to NDC
+             cc = s - 2 * s * crop_center / length
+             crop_width = 2 * s * (bbox[2] - bbox[0]) / length
+             crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])
+
+             crop_parameters.append(crop_params)
+         images = images_transformed
+
+         batch = {}
+         batch["image"] = torch.stack(images)
+         batch["image_for_vis"] = images_for_vis
+         batch["n"] = len(images)
+         batch["ind"] = torch.tensor(ids)  # fixed: a trailing comma previously made this a tuple
+         batch["crop_parameters"] = torch.stack(crop_parameters)
+         batch["distortion_parameters"] = torch.zeros(4)
+
+         return batch
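
A brief usage sketch for CustomDataset (not part of the upload); the glob pattern is a placeholder pointing at one of the demo folders added in this commit, and the downstream model call is omitted:

    from glob import glob
    from diffusionsfm.dataset.custom import CustomDataset

    image_paths = sorted(glob("data/demo/kotor_cathedral/*.jpeg"))
    dataset = CustomDataset(image_paths)
    batch = dataset.get_data()
    print(batch["n"], batch["image"].shape)  # 6, (6, 3, 224, 224)
    print(batch["crop_parameters"].shape)    # (6, 4)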
diffusionsfm/eval/__init__.py ADDED
File without changes
diffusionsfm/eval/eval_category.py ADDED
@@ -0,0 +1,292 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import numpy as np
6
+ from tqdm.auto import tqdm
7
+
8
+ from diffusionsfm.dataset.co3d_v2 import (
9
+ Co3dDataset,
10
+ full_scene_scale,
11
+ )
12
+ from pytorch3d.renderer import PerspectiveCameras
13
+ from diffusionsfm.utils.visualization import filter_and_align_point_clouds
14
+ from diffusionsfm.inference.load_model import load_model
15
+ from diffusionsfm.inference.predict import predict_cameras
16
+ from diffusionsfm.utils.geometry import (
17
+ compute_angular_error_batch,
18
+ get_error,
19
+ n_to_np_rotations,
20
+ )
21
+ from diffusionsfm.utils.slurm import init_slurm_signals_if_slurm
22
+ from diffusionsfm.utils.rays import cameras_to_rays
23
+ from diffusionsfm.utils.rays import normalize_cameras_batch
24
+
25
+
26
+ @torch.no_grad()
27
+ def evaluate(
28
+ cfg,
29
+ model,
30
+ dataset,
31
+ num_images,
32
+ device,
33
+ use_pbar=True,
34
+ calculate_intrinsics=True,
35
+ additional_timesteps=(),
36
+ num_evaluate=None,
37
+ max_num_images=None,
38
+ mode=None,
39
+ metrics=True,
40
+ load_depth=True,
41
+ ):
42
+ if cfg.training.get("dpt_head", False):
43
+ H_in = W_in = 224
44
+ H_out = W_out = cfg.training.full_num_patches_y
45
+ else:
46
+ H_in = H_out = cfg.model.num_patches_x
47
+ W_in = W_out = cfg.model.num_patches_y
48
+
49
+ results = {}
50
+ instances = np.arange(0, len(dataset)) if num_evaluate is None else np.linspace(0, len(dataset) - 1, num_evaluate, endpoint=True, dtype=int)
51
+ instances = tqdm(instances) if use_pbar else instances
52
+
53
+ for counter, idx in enumerate(instances):
54
+ batch = dataset[idx]
55
+ instance = batch["model_id"]
56
+ images = batch["image"].to(device)
57
+ focal_length = batch["focal_length"].to(device)[:num_images]
58
+ R = batch["R"].to(device)[:num_images]
59
+ T = batch["T"].to(device)[:num_images]
60
+ crop_parameters = batch["crop_parameters"].to(device)[:num_images]
61
+
62
+ if load_depth:
63
+ depths = batch["depth"].to(device)[:num_images]
64
+ depth_masks = batch["depth_masks"].to(device)[:num_images]
65
+ try:
66
+ object_masks = batch["object_masks"].to(device)[:num_images]
67
+ except KeyError:
68
+ object_masks = depth_masks.clone()
69
+
70
+ # Normalize cameras and scale depths for output resolution
71
+ cameras_gt = PerspectiveCameras(
72
+ R=R, T=T, focal_length=focal_length, device=device
73
+ )
74
+ cameras_gt, _, _ = normalize_cameras_batch(
75
+ [cameras_gt],
76
+ first_cam_mediod=cfg.training.first_cam_mediod,
77
+ normalize_first_camera=cfg.training.normalize_first_camera,
78
+ depths=depths.unsqueeze(0),
79
+ crop_parameters=crop_parameters.unsqueeze(0),
80
+ num_patches_x=H_in,
81
+ num_patches_y=W_in,
82
+ return_scales=True,
83
+ )
84
+ cameras_gt = cameras_gt[0]
85
+
86
+ gt_rays = cameras_to_rays(
87
+ cameras=cameras_gt,
88
+ num_patches_x=H_in,
89
+ num_patches_y=W_in,
90
+ crop_parameters=crop_parameters,
91
+ depths=depths,
92
+ mode=mode,
93
+ )
94
+ gt_points = gt_rays.get_segments().view(num_images, -1, 3)
95
+
96
+ resize = torchvision.transforms.Resize(
97
+ 224,
98
+ antialias=False,
99
+ interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT,
100
+ )
101
+ else:
102
+ cameras_gt = PerspectiveCameras(
103
+ R=R, T=T, focal_length=focal_length, device=device
104
+ )
105
+
106
+ pred_cameras, additional_cams = predict_cameras(
107
+ model,
108
+ images,
109
+ device,
110
+ crop_parameters=crop_parameters,
111
+ num_patches_x=H_out,
112
+ num_patches_y=W_out,
113
+ max_num_images=max_num_images,
114
+ additional_timesteps=additional_timesteps,
115
+ calculate_intrinsics=calculate_intrinsics,
116
+ mode=mode,
117
+ return_rays=True,
118
+ use_homogeneous=cfg.model.get("use_homogeneous", False),
119
+ )
120
+ cameras_to_evaluate = additional_cams + [pred_cameras]
121
+
122
+ all_cams_batch = dataset.get_data(
123
+ sequence_name=instance, ids=np.arange(0, batch["n"]), no_images=True
124
+ )
125
+ gt_scene_scale = full_scene_scale(all_cams_batch)
126
+ R_gt = R
127
+ T_gt = T
128
+
129
+ errors = []
130
+ for _, (camera, pred_rays) in enumerate(cameras_to_evaluate):
131
+ R_pred = camera.R
132
+ T_pred = camera.T
133
+ f_pred = camera.focal_length
134
+
135
+ R_pred_rel = n_to_np_rotations(num_images, R_pred).cpu().numpy()
136
+ R_gt_rel = n_to_np_rotations(num_images, batch["R"]).cpu().numpy()
137
+ R_error = compute_angular_error_batch(R_pred_rel, R_gt_rel)
138
+
139
+ CC_error, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale)
140
+
141
+ if load_depth and metrics:
142
+ # Evaluate outputs at the same resolution as DUSt3R
143
+ pred_points = pred_rays.get_segments().view(num_images, H_out, H_out, 3)
144
+ pred_points = pred_points.permute(0, 3, 1, 2)
145
+ pred_points = resize(pred_points).permute(0, 2, 3, 1).view(num_images, H_in*W_in, 3)
146
+
147
+ (
148
+ _,
149
+ _,
150
+ _,
151
+ _,
152
+ metric_values,
153
+ ) = filter_and_align_point_clouds(
154
+ num_images,
155
+ gt_points,
156
+ pred_points,
157
+ depth_masks,
158
+ depth_masks,
159
+ images,
160
+ metrics=metrics,
161
+ num_patches_x=H_in,
162
+ )
163
+
164
+ (
165
+ _,
166
+ _,
167
+ _,
168
+ _,
169
+ object_metric_values,
170
+ ) = filter_and_align_point_clouds(
171
+ num_images,
172
+ gt_points,
173
+ pred_points,
174
+ depth_masks * object_masks,
175
+ depth_masks * object_masks,
176
+ images,
177
+ metrics=metrics,
178
+ num_patches_x=H_in,
179
+ )
180
+
181
+ result = {
182
+ "R_pred": R_pred.detach().cpu().numpy().tolist(),
183
+ "T_pred": T_pred.detach().cpu().numpy().tolist(),
184
+ "f_pred": f_pred.detach().cpu().numpy().tolist(),
185
+ "R_gt": R_gt.detach().cpu().numpy().tolist(),
186
+ "T_gt": T_gt.detach().cpu().numpy().tolist(),
187
+ "f_gt": focal_length.detach().cpu().numpy().tolist(),
188
+ "scene_scale": gt_scene_scale,
189
+ "R_error": R_error.tolist(),
190
+ "CC_error": CC_error,
191
+ }
192
+
193
+ if load_depth and metrics:
194
+ result["CD"] = metric_values[1]
195
+ result["CD_Object"] = object_metric_values[1]
196
+ else:
197
+ result["CD"] = 0
198
+ result["CD_Object"] = 0
199
+
200
+ errors.append(result)
201
+
202
+ results[instance] = errors
203
+
204
+ if counter == len(dataset) - 1:
205
+ break
206
+ return results
207
+
208
+
209
+ def save_results(
210
+ output_dir,
211
+ checkpoint=800_000,
212
+ category="hydrant",
213
+ num_images=None,
214
+ calculate_additional_timesteps=True,
215
+ calculate_intrinsics=True,
216
+ split="test",
217
+ force=False,
218
+ sample_num=1,
219
+ max_num_images=None,
220
+ dataset="co3d",
221
+ ):
222
+ init_slurm_signals_if_slurm()
223
+ os.umask(000) # Default to 777 permissions
224
+ eval_path = os.path.join(
225
+ output_dir,
226
+ f"eval_{dataset}",
227
+ f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}.json",
228
+ )
229
+
230
+ if os.path.exists(eval_path) and not force:
231
+ print(f"File {eval_path} already exists. Skipping.")
232
+ return
233
+
234
+ if num_images is not None and num_images > 8:
235
+ custom_keys = {"model.num_images": num_images}
236
+ ignore_keys = ["pos_table"]
237
+ else:
238
+ custom_keys = None
239
+ ignore_keys = []
240
+
241
+ device = torch.device("cuda")
242
+ model, cfg = load_model(
243
+ output_dir,
244
+ checkpoint=checkpoint,
245
+ device=device,
246
+ custom_keys=custom_keys,
247
+ ignore_keys=ignore_keys,
248
+ )
249
+ if num_images is None:
250
+ num_images = cfg.dataset.num_images
251
+
252
+ if cfg.training.dpt_head:
253
+ # Evaluate outputs at the same resolution as DUSt3R
254
+ depth_size = 224
255
+ else:
256
+ depth_size = cfg.model.num_patches_x
257
+
258
+ dataset = Co3dDataset(
259
+ category=category,
260
+ split=split,
261
+ num_images=num_images,
262
+ apply_augmentation=False,
263
+ sample_num=None if split == "train" else sample_num,
264
+ use_global_intrinsics=cfg.dataset.use_global_intrinsics,
265
+ load_depths=True,
266
+ center_crop=True,
267
+ depth_size=depth_size,
268
+ mask_holes=not cfg.training.regression,
269
+ img_size=256 if cfg.model.unet_diffuser else 224,
270
+ )
271
+ print(f"Category {category} {len(dataset)}")
272
+
273
+ if calculate_additional_timesteps:
274
+ additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
275
+ else:
276
+ additional_timesteps = []
277
+
278
+ results = evaluate(
279
+ cfg=cfg,
280
+ model=model,
281
+ dataset=dataset,
282
+ num_images=num_images,
283
+ device=device,
284
+ calculate_intrinsics=calculate_intrinsics,
285
+ additional_timesteps=additional_timesteps,
286
+ max_num_images=max_num_images,
287
+ mode="segment",
288
+ )
289
+
290
+ os.makedirs(os.path.dirname(eval_path), exist_ok=True)
291
+ with open(eval_path, "w") as f:
292
+ json.dump(results, f)
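
For a single category, the evaluation entry point above can be called directly (a hedged example, not part of the commit; eval_jobs.py below sweeps the same function over categories, image counts, and samples). The output_dir value is a placeholder matching the run directory used in eval_jobs.py:

    from diffusionsfm.eval.eval_category import save_results

    save_results(
        output_dir="output/multi_diffusionsfm_dense",  # trained run directory (placeholder)
        checkpoint=800_000,
        category="hydrant",
        num_images=8,
        sample_num=0,
    )
    # Writes output/multi_diffusionsfm_dense/eval_co3d/hydrant_8_0_ckpt800000.json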
diffusionsfm/eval/eval_jobs.py ADDED
@@ -0,0 +1,175 @@
1
+ """
2
+ python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import submitit
8
+ import argparse
9
+ import itertools
10
+ from glob import glob
11
+
12
+ import numpy as np
13
+ from tqdm.auto import tqdm
14
+
15
+ from diffusionsfm.dataset.co3d_v2 import TEST_CATEGORIES, TRAINING_CATEGORIES
16
+ from diffusionsfm.eval.eval_category import save_results
17
+ from diffusionsfm.utils.slurm import submitit_job_watcher
18
+
19
+
20
+ def evaluate_diffusionsfm(eval_path, use_submitit, mode):
21
+ JOB_PARAMS = {
22
+ "output_dir": [eval_path],
23
+ "checkpoint": [800_000],
24
+ "num_images": [2, 3, 4, 5, 6, 7, 8],
25
+ "sample_num": [0, 1, 2, 3, 4],
26
+ "category": TEST_CATEGORIES, # TRAINING_CATEGORIES + TEST_CATEGORIES,
27
+ "calculate_additional_timesteps": [True],
28
+ }
29
+ if mode == "test":
30
+ JOB_PARAMS["category"] = TEST_CATEGORIES
31
+ elif mode == "train1":
32
+ JOB_PARAMS["category"] = TRAINING_CATEGORIES[:len(TRAINING_CATEGORIES) // 2]
33
+ elif mode == "train2":
34
+ JOB_PARAMS["category"] = TRAINING_CATEGORIES[len(TRAINING_CATEGORIES) // 2:]
35
+ keys, values = zip(*JOB_PARAMS.items())
36
+ job_configs = [dict(zip(keys, p)) for p in itertools.product(*values)]
37
+
38
+ if use_submitit:
39
+ log_output = "./slurm_logs"
40
+ executor = submitit.AutoExecutor(
41
+ cluster=None, folder=log_output, slurm_max_num_timeout=10
42
+ )
43
+ # Use your own parameters
44
+ executor.update_parameters(
45
+ slurm_additional_parameters={
46
+ "nodes": 1,
47
+ "cpus-per-task": 5,
48
+ "gpus": 1,
49
+ "time": "6:00:00",
50
+ "partition": "all",
51
+ "exclude": "grogu-1-9, grogu-1-14,"
52
+ }
53
+ )
54
+ jobs = []
55
+ with executor.batch():
56
+ # This context manager submits all jobs at once at the end.
57
+ for params in job_configs:
58
+ job = executor.submit(save_results, **params)
59
+ job_param = f"{params['category']}_N{params['num_images']}_{params['sample_num']}"
60
+ jobs.append((job_param, job))
61
+ jobs = {f"{job_param}_{job.job_id}": job for job_param, job in jobs}
62
+ submitit_job_watcher(jobs)
63
+ else:
64
+ for job_config in tqdm(job_configs):
65
+ # This is much slower.
66
+ save_results(**job_config)
67
+
68
+
69
+ def process_predictions(eval_path, pred_index, checkpoint=800_000, threshold_R=15, threshold_CC=0.1):
70
+ """
71
+ pred_index should be 1 (corresponding to T=90)
72
+ """
73
+ def aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=None):
74
+ """
75
+ Aggregates one metric over all data points in a prediction file and then across categories.
76
+ - For R_error and CC_error: threshold each error and report the mean accuracy
77
+ - For CD and CD_Object: use median to reduce the effect of outliers
78
+ """
79
+ per_category_values = []
80
+
81
+ for category in tqdm(categories, desc=f"Sample {sample_num}, N={num_images}, {metric_key}"):
82
+ per_pred_values = []
83
+
84
+ data_path = glob(
85
+ os.path.join(eval_path, "eval", f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}*.json")
86
+ )[0]
87
+
88
+ with open(data_path) as f:
89
+ eval_data = json.load(f)
90
+
91
+ for preds in eval_data.values():
92
+ if metric_key in ["R_error", "CC_error"]:
93
+ vals = np.array(preds[pred_index][metric_key])
94
+ per_pred_values.append(np.mean(vals < threshold))
95
+ else:
96
+ per_pred_values.append(preds[pred_index][metric_key])
97
+
98
+ # Aggregate over all predictions within this category
99
+ per_category_values.append(
100
+ np.mean(per_pred_values) if metric_key in ["R_error", "CC_error"]
101
+ else np.median(per_pred_values) # CD or CD_Object — use median to filter outliers
102
+ )
103
+
104
+ if metric_key in ["R_error", "CC_error"]:
105
+ return np.mean(per_category_values)
106
+ else:
107
+ return np.median(per_category_values)
108
+
109
+ def aggregate_metric(categories, metric_key, num_images, threshold=None):
110
+ """Aggregates one metric over 5 random samples per category and returns the final mean"""
111
+ return np.mean([
112
+ aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=threshold)
113
+ for sample_num in range(5)
114
+ ])
115
+
116
+ # Output containers
117
+ all_seen_acc_R, all_seen_acc_CC = [], []
118
+ all_seen_CD, all_seen_CD_Object = [], []
119
+ all_unseen_acc_R, all_unseen_acc_CC = [], []
120
+ all_unseen_CD, all_unseen_CD_Object = [], []
121
+
122
+ for num_images in range(2, 9):
123
+ # Seen categories
124
+ all_seen_acc_R.append(
125
+ aggregate_metric(TRAINING_CATEGORIES, "R_error", num_images, threshold=threshold_R)
126
+ )
127
+ all_seen_acc_CC.append(
128
+ aggregate_metric(TRAINING_CATEGORIES, "CC_error", num_images, threshold=threshold_CC)
129
+ )
130
+ all_seen_CD.append(
131
+ aggregate_metric(TRAINING_CATEGORIES, "CD", num_images)
132
+ )
133
+ all_seen_CD_Object.append(
134
+ aggregate_metric(TRAINING_CATEGORIES, "CD_Object", num_images)
135
+ )
136
+
137
+ # Unseen categories
138
+ all_unseen_acc_R.append(
139
+ aggregate_metric(TEST_CATEGORIES, "R_error", num_images, threshold=threshold_R)
140
+ )
141
+ all_unseen_acc_CC.append(
142
+ aggregate_metric(TEST_CATEGORIES, "CC_error", num_images, threshold=threshold_CC)
143
+ )
144
+ all_unseen_CD.append(
145
+ aggregate_metric(TEST_CATEGORIES, "CD", num_images)
146
+ )
147
+ all_unseen_CD_Object.append(
148
+ aggregate_metric(TEST_CATEGORIES, "CD_Object", num_images)
149
+ )
150
+
151
+ # Print the results in formatted rows
152
+ print("N= ", " ".join(f"{i: 5}" for i in range(2, 9)))
153
+ print("Seen R ", " ".join([f"{x:0.3f}" for x in all_seen_acc_R]))
154
+ print("Seen CC ", " ".join([f"{x:0.3f}" for x in all_seen_acc_CC]))
155
+ print("Seen CD ", " ".join([f"{x:0.3f}" for x in all_seen_CD]))
156
+ print("Seen CD_Obj ", " ".join([f"{x:0.3f}" for x in all_seen_CD_Object]))
157
+ print("Unseen R ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_R]))
158
+ print("Unseen CC ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_CC]))
159
+ print("Unseen CD ", " ".join([f"{x:0.3f}" for x in all_unseen_CD]))
160
+ print("Unseen CD_Obj", " ".join([f"{x:0.3f}" for x in all_unseen_CD_Object]))
161
+
162
+
163
+ if __name__ == "__main__":
164
+ parser = argparse.ArgumentParser()
165
+ parser.add_argument("--eval_path", type=str, default=None)
166
+ parser.add_argument("--use_submitit", action="store_true")
167
+ parser.add_argument("--mode", type=str, default="test")
168
+ args = parser.parse_args()
169
+
170
+ eval_path = "output/multi_diffusionsfm_dense" if args.eval_path is None else args.eval_path
171
+ use_submitit = args.use_submitit
172
+ mode = args.mode
173
+
174
+ evaluate_diffusionsfm(eval_path, use_submitit, mode)
175
+ process_predictions(eval_path, 1)
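
For a single evaluation job outside the sweep, save_results can be called directly with one combination of the parameters enumerated in JOB_PARAMS. A minimal sketch (the output directory and checkpoint mirror the defaults above; the category name is assumed to be a valid CO3D category, and CO3D data plus a GPU are required):

    from diffusionsfm.eval.eval_category import save_results

    save_results(
        output_dir="output/multi_diffusionsfm_dense",
        checkpoint=800_000,
        category="ball",                      # assumed entry of TEST_CATEGORIES
        num_images=4,
        sample_num=0,
        calculate_additional_timesteps=True,
    )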
diffusionsfm/inference/__init__.py ADDED
File without changes
diffusionsfm/inference/ddim.py ADDED
@@ -0,0 +1,145 @@
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ from tqdm.auto import tqdm
5
+
6
+ from diffusionsfm.utils.rays import compute_ndc_coordinates
7
+
8
+
9
+ def inference_ddim(
10
+ model,
11
+ images,
12
+ device,
13
+ crop_parameters=None,
14
+ eta=0,
15
+ num_inference_steps=100,
16
+ pbar=True,
17
+ num_patches_x=16,
18
+ num_patches_y=16,
19
+ visualize=False,
20
+ seed=0,
21
+ ):
22
+ """
23
+ Implements DDIM-style inference.
24
+
25
+ To get multiple samples, batch the images multiple times.
26
+
27
+ Args:
28
+ model: Ray Diffuser.
29
+ images (torch.Tensor): (B, N, C, H, W).
30
+ crop_parameters (torch.Tensor, optional): Per-image crop parameters (B, N, 4),
31
+ used to compute the NDC conditioning coordinates when the model requires them.
32
+ eta (float, optional): Stochasticity coefficient. 0 is completely deterministic,
33
+ 1 is equivalent to DDPM. (Default: 0)
34
+ num_inference_steps (int, optional): Number of inference steps. (Default: 100)
35
+ pbar (bool, optional): Whether to show progress bar. (Default: True)
36
+ """
37
+ timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps)
38
+ batch_size = images.shape[0]
39
+ num_images = images.shape[1]
40
+
41
+ if isinstance(eta, list):
42
+ eta_0, eta_1 = float(eta[0]), float(eta[1])
43
+ else:
44
+ eta_0, eta_1 = 0, 0
45
+
46
+ # Fixing seed
47
+ if seed is not None:
48
+ torch.manual_seed(seed)
49
+ random.seed(seed)
50
+ np.random.seed(seed)
51
+
52
+ with torch.no_grad():
53
+ x_tau = torch.randn(
54
+ batch_size,
55
+ num_images,
56
+ model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
57
+ num_patches_x,
58
+ num_patches_y,
59
+ device=device,
60
+ )
61
+
62
+ if visualize:
63
+ x_taus = [x_tau]
64
+ all_pred = []
65
+ noise_samples = []
66
+
67
+ image_features = model.feature_extractor(images, autoresize=True)
68
+
69
+ if model.append_ndc:
70
+ ndc_coordinates = compute_ndc_coordinates(
71
+ crop_parameters=crop_parameters,
72
+ no_crop_param_device="cpu",
73
+ num_patches_x=model.width,
74
+ num_patches_y=model.width,
75
+ distortion_coeffs=None,
76
+ )[..., :2].to(device)
77
+ ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3)
78
+ else:
79
+ ndc_coordinates = None
80
+
81
+ loop = tqdm(range(len(timesteps))) if pbar else range(len(timesteps))
82
+ for t in loop:
83
+ tau = timesteps[t]
84
+
85
+ if tau > 0 and eta_1 > 0:
86
+ z = torch.randn(
87
+ batch_size,
88
+ num_images,
89
+ model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
90
+ num_patches_x,
91
+ num_patches_y,
92
+ device=device,
93
+ )
94
+ else:
95
+ z = 0
96
+
97
+ alpha = model.noise_scheduler.alphas_cumprod[tau]
98
+ if tau > 0:
99
+ tau_prev = timesteps[t + 1]
100
+ alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev]
101
+ else:
102
+ alpha_prev = torch.tensor(1.0, device=device).float()
103
+
104
+ sigma_t = (
105
+ torch.sqrt((1 - alpha_prev) / (1 - alpha))
106
+ * torch.sqrt(1 - alpha / alpha_prev)
107
+ )
108
+
109
+ eps_pred, noise_sample = model(
110
+ features=image_features,
111
+ rays_noisy=x_tau,
112
+ t=int(tau),
113
+ ndc_coordinates=ndc_coordinates,
114
+ )
115
+
116
+ if model.use_homogeneous:
117
+ p1 = eps_pred[:, :, :4]
118
+ p2 = eps_pred[:, :, 4:]
119
+
120
+ c1 = torch.linalg.norm(p1, dim=2, keepdim=True)
121
+ c2 = torch.linalg.norm(p2, dim=2, keepdim=True)
122
+ eps_pred[:, :, :4] = p1 / c1
123
+ eps_pred[:, :, 4:] = p2 / c2
124
+
125
+ if visualize:
126
+ all_pred.append(eps_pred.clone())
127
+ noise_samples.append(noise_sample)
128
+
129
+ # TODO: Can simplify this a lot
130
+ x0_pred = eps_pred.clone()
131
+ eps_pred = (x_tau - torch.sqrt(alpha) * eps_pred) / torch.sqrt(
132
+ 1 - alpha
133
+ )
134
+
135
+ dir_x_tau = torch.sqrt(1 - alpha_prev - eta_0*sigma_t**2) * eps_pred
136
+ noise = eta_1 * sigma_t * z
137
+
138
+ new_x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise
139
+ x_tau = new_x_tau
140
+
141
+ if visualize:
142
+ x_taus.append(x_tau.detach().clone())
143
+ if visualize:
144
+ return x_tau, x_taus, all_pred, noise_samples
145
+ return x_tau
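
For reference, the loop above implements a generalized DDIM update in which the network output is treated as the clean-sample prediction \hat{x}_0; eta_0 = eta_1 = 0 gives deterministic DDIM and eta_0 = eta_1 = 1 gives DDPM-like sampling:

    \hat{\epsilon}_\tau = \frac{x_\tau - \sqrt{\bar\alpha_\tau}\,\hat{x}_0}{\sqrt{1 - \bar\alpha_\tau}},
    \qquad
    \sigma_\tau = \sqrt{\frac{1 - \bar\alpha_{\tau'}}{1 - \bar\alpha_\tau}}\,\sqrt{1 - \frac{\bar\alpha_\tau}{\bar\alpha_{\tau'}}},
    \qquad
    x_{\tau'} = \sqrt{\bar\alpha_{\tau'}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{\tau'} - \eta_0\,\sigma_\tau^2}\;\hat{\epsilon}_\tau + \eta_1\,\sigma_\tau\, z,

where \tau' is the next (lower-noise) timestep in the inference schedule and z is standard Gaussian noise.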
diffusionsfm/inference/load_model.py ADDED
@@ -0,0 +1,97 @@
1
+ import os.path as osp
2
+ from glob import glob
3
+
4
+ import torch
5
+ from omegaconf import OmegaConf
6
+
7
+ from diffusionsfm.model.diffuser import RayDiffuser
8
+ from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
9
+ from diffusionsfm.model.scheduler import NoiseScheduler
10
+
11
+
12
+ def load_model(
13
+ output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
14
+ ):
15
+ """
16
+ Loads a model and config from an output directory.
17
+
18
+ E.g. to load with different number of images,
19
+ ```
20
+ custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
21
+ ```
22
+
23
+ Args:
24
+ output_dir (str): Path to the output directory.
25
+ checkpoint (str or int): Path to the checkpoint to load. If None, loads the
26
+ latest checkpoint.
27
+ device (str): Device to load the model on.
28
+ custom_keys (dict): Dictionary of custom keys to override in the config.
29
+ """
30
+ if checkpoint is None:
31
+ checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
32
+ else:
33
+ if isinstance(checkpoint, int):
34
+ checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
35
+ else:
36
+ checkpoint_name = checkpoint
37
+ checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
38
+ print("Loading checkpoint", osp.basename(checkpoint_path))
39
+
40
+ cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
41
+ if custom_keys is not None:
42
+ for k, v in custom_keys.items():
43
+ OmegaConf.update(cfg, k, v)
44
+ noise_scheduler = NoiseScheduler(
45
+ type=cfg.noise_scheduler.type,
46
+ max_timesteps=cfg.noise_scheduler.max_timesteps,
47
+ beta_start=cfg.noise_scheduler.beta_start,
48
+ beta_end=cfg.noise_scheduler.beta_end,
49
+ )
50
+
51
+ if not cfg.training.get("dpt_head", False):
52
+ model = RayDiffuser(
53
+ depth=cfg.model.depth,
54
+ width=cfg.model.num_patches_x,
55
+ P=1,
56
+ max_num_images=cfg.model.num_images,
57
+ noise_scheduler=noise_scheduler,
58
+ feature_extractor=cfg.model.feature_extractor,
59
+ append_ndc=cfg.model.append_ndc,
60
+ diffuse_depths=cfg.training.get("diffuse_depths", False),
61
+ depth_resolution=cfg.training.get("depth_resolution", 1),
62
+ use_homogeneous=cfg.model.get("use_homogeneous", False),
63
+ cond_depth_mask=cfg.model.get("cond_depth_mask", False),
64
+ ).to(device)
65
+ else:
66
+ model = RayDiffuserDPT(
67
+ depth=cfg.model.depth,
68
+ width=cfg.model.num_patches_x,
69
+ P=1,
70
+ max_num_images=cfg.model.num_images,
71
+ noise_scheduler=noise_scheduler,
72
+ feature_extractor=cfg.model.feature_extractor,
73
+ append_ndc=cfg.model.append_ndc,
74
+ diffuse_depths=cfg.training.get("diffuse_depths", False),
75
+ depth_resolution=cfg.training.get("depth_resolution", 1),
76
+ encoder_features=cfg.training.get("dpt_encoder_features", False),
77
+ use_homogeneous=cfg.model.get("use_homogeneous", False),
78
+ cond_depth_mask=cfg.model.get("cond_depth_mask", False),
79
+ ).to(device)
80
+
81
+ data = torch.load(checkpoint_path)
82
+ state_dict = {}
83
+ for k, v in data["state_dict"].items():
84
+ include = True
85
+ for ignore_key in ignore_keys:
86
+ if ignore_key in k:
87
+ include = False
88
+ if include:
89
+ state_dict[k] = v
90
+
91
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
92
+ if len(missing) > 0:
93
+ print("Missing keys:", missing)
94
+ if len(unexpected) > 0:
95
+ print("Unexpected keys:", unexpected)
96
+ model = model.eval()
97
+ return model, cfg
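
A minimal usage sketch (the output directory and checkpoint number mirror the values used by the evaluation scripts; the device is an assumption):

    from diffusionsfm.inference.load_model import load_model

    # Reads <output_dir>/hydra/config.yaml and <output_dir>/checkpoints/ckpt_00800000.pth.
    model, cfg = load_model(
        "output/multi_diffusionsfm_dense",
        checkpoint=800_000,
        device="cuda:0",
    )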
diffusionsfm/inference/predict.py ADDED
@@ -0,0 +1,93 @@
1
+ from diffusionsfm.inference.ddim import inference_ddim
2
+ from diffusionsfm.utils.rays import (
3
+ Rays,
4
+ rays_to_cameras,
5
+ rays_to_cameras_homography,
6
+ )
7
+
8
+
9
+ def predict_cameras(
10
+ model,
11
+ images,
12
+ device,
13
+ crop_parameters=None,
14
+ num_patches_x=16,
15
+ num_patches_y=16,
16
+ additional_timesteps=(),
17
+ calculate_intrinsics=False,
18
+ max_num_images=None,
19
+ mode=None,
20
+ return_rays=False,
21
+ use_homogeneous=False,
22
+ seed=0,
23
+ ):
24
+ """
25
+ Args:
26
+ images (torch.Tensor): (N, C, H, W)
27
+ crop_parameters (torch.Tensor): (N, 4) or None
28
+ """
29
+ if calculate_intrinsics:
30
+ ray_to_cam = rays_to_cameras_homography
31
+ else:
32
+ ray_to_cam = rays_to_cameras
33
+
34
+ get_spatial_rays = Rays.from_spatial
35
+
36
+ rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
37
+ model,
38
+ images.unsqueeze(0),
39
+ device,
40
+ visualize=True,
41
+ crop_parameters=crop_parameters.unsqueeze(0),
42
+ num_patches_x=num_patches_x,
43
+ num_patches_y=num_patches_y,
44
+ pbar=False,
45
+ eta=[1, 0],
46
+ num_inference_steps=100,
47
+ )
48
+
49
+ spatial_rays = get_spatial_rays(
50
+ rays_final[0],
51
+ mode=mode,
52
+ num_patches_x=num_patches_x,
53
+ num_patches_y=num_patches_y,
54
+ use_homogeneous=use_homogeneous,
55
+ )
56
+
57
+ pred_cam = ray_to_cam(
58
+ spatial_rays,
59
+ crop_parameters,
60
+ num_patches_x=num_patches_x,
61
+ num_patches_y=num_patches_y,
62
+ depth_resolution=model.depth_resolution,
63
+ average_centers=True,
64
+ directions_from_averaged_center=True,
65
+ )
66
+
67
+ additional_predictions = []
68
+ for t in additional_timesteps:
69
+ ray = pred_intermediate[t]
70
+
71
+ ray = get_spatial_rays(
72
+ ray[0],
73
+ mode=mode,
74
+ num_patches_x=num_patches_x,
75
+ num_patches_y=num_patches_y,
76
+ use_homogeneous=use_homogeneous,
77
+ )
78
+
79
+ cam = ray_to_cam(
80
+ ray,
81
+ crop_parameters,
82
+ num_patches_x=num_patches_x,
83
+ num_patches_y=num_patches_y,
84
+ average_centers=True,
85
+ directions_from_averaged_center=True,
86
+ )
87
+ if return_rays:
88
+ cam = (cam, ray)
89
+ additional_predictions.append(cam)
90
+
91
+ if return_rays:
92
+ return (pred_cam, spatial_rays), additional_predictions
93
+ return pred_cam, additional_predictions, spatial_rays
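
A hedged end-to-end sketch combining load_model and predict_cameras; the dummy tensors only illustrate the documented shapes, and real images should be preprocessed (resized, cropped, normalized) the same way the training dataloader does:

    import torch
    from diffusionsfm.inference.load_model import load_model
    from diffusionsfm.inference.predict import predict_cameras

    device = "cuda:0"
    model, cfg = load_model("output/multi_diffusionsfm_dense", checkpoint=800_000, device=device)

    images = torch.randn(4, 3, 224, 224, device=device)    # (N, C, H, W)
    crop_parameters = torch.zeros(4, 4, device=device)     # (N, 4)

    pred_cam, additional_preds, spatial_rays = predict_cameras(
        model,
        images,
        device,
        crop_parameters=crop_parameters,
        calculate_intrinsics=True,
        use_homogeneous=cfg.model.get("use_homogeneous", False),
        # mode is left at its default here; it may need to match the training setup.
    )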
diffusionsfm/model/base_model.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device("cpu"))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
diffusionsfm/model/blocks.py ADDED
@@ -0,0 +1,247 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from diffusionsfm.model.dit import TimestepEmbedder
4
+ import ipdb
5
+
6
+
7
+ def modulate(x, shift, scale):
8
+ return x * (1 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(
9
+ -1
10
+ )
11
+
12
+
13
+ def _make_fusion_block(features, use_bn, use_ln, dpt_time, resolution):
14
+ return FeatureFusionBlock_custom(
15
+ features,
16
+ nn.ReLU(False),
17
+ deconv=False,
18
+ bn=use_bn,
19
+ expand=False,
20
+ align_corners=True,
21
+ dpt_time=dpt_time,
22
+ ln=use_ln,
23
+ resolution=resolution
24
+ )
25
+
26
+
27
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
28
+ scratch = nn.Module()
29
+
30
+ out_shape1 = out_shape
31
+ out_shape2 = out_shape
32
+ out_shape3 = out_shape
33
+ out_shape4 = out_shape
34
+ if expand == True:
35
+ out_shape1 = out_shape
36
+ out_shape2 = out_shape * 2
37
+ out_shape3 = out_shape * 4
38
+ out_shape4 = out_shape * 8
39
+
40
+ scratch.layer1_rn = nn.Conv2d(
41
+ in_shape[0],
42
+ out_shape1,
43
+ kernel_size=3,
44
+ stride=1,
45
+ padding=1,
46
+ bias=False,
47
+ groups=groups,
48
+ )
49
+ scratch.layer2_rn = nn.Conv2d(
50
+ in_shape[1],
51
+ out_shape2,
52
+ kernel_size=3,
53
+ stride=1,
54
+ padding=1,
55
+ bias=False,
56
+ groups=groups,
57
+ )
58
+ scratch.layer3_rn = nn.Conv2d(
59
+ in_shape[2],
60
+ out_shape3,
61
+ kernel_size=3,
62
+ stride=1,
63
+ padding=1,
64
+ bias=False,
65
+ groups=groups,
66
+ )
67
+ scratch.layer4_rn = nn.Conv2d(
68
+ in_shape[3],
69
+ out_shape4,
70
+ kernel_size=3,
71
+ stride=1,
72
+ padding=1,
73
+ bias=False,
74
+ groups=groups,
75
+ )
76
+
77
+ return scratch
78
+
79
+
80
+ class ResidualConvUnit_custom(nn.Module):
81
+ """Residual convolution module."""
82
+
83
+ def __init__(self, features, activation, bn, ln, dpt_time=False, resolution=16):
84
+ """Init.
85
+
86
+ Args:
87
+ features (int): number of features
88
+ """
89
+ super().__init__()
90
+
91
+ self.bn = bn
92
+ self.ln = ln
93
+
94
+ self.groups = 1
95
+
96
+ self.conv1 = nn.Conv2d(
97
+ features,
98
+ features,
99
+ kernel_size=3,
100
+ stride=1,
101
+ padding=1,
102
+ bias=not self.bn,
103
+ groups=self.groups,
104
+ )
105
+
106
+ self.conv2 = nn.Conv2d(
107
+ features,
108
+ features,
109
+ kernel_size=3,
110
+ stride=1,
111
+ padding=1,
112
+ bias=not self.bn,
113
+ groups=self.groups,
114
+ )
115
+
116
+ nn.init.kaiming_uniform_(self.conv1.weight)
117
+ nn.init.kaiming_uniform_(self.conv2.weight)
118
+
119
+ if self.bn == True:
120
+ self.bn1 = nn.BatchNorm2d(features)
121
+ self.bn2 = nn.BatchNorm2d(features)
122
+
123
+ if self.ln == True:
124
+ self.bn1 = nn.LayerNorm((features, resolution, resolution))
125
+ self.bn2 = nn.LayerNorm((features, resolution, resolution))
126
+
127
+ self.activation = activation
128
+
129
+ if dpt_time:
130
+ self.t_embedder = TimestepEmbedder(hidden_size=features)
131
+ self.adaLN_modulation = nn.Sequential(
132
+ nn.SiLU(), nn.Linear(features, 3 * features, bias=True)
133
+ )
134
+
135
+ def forward(self, x, t=None):
136
+ """Forward pass.
137
+
138
+ Args:
139
+ x (tensor): input
140
+
141
+ Returns:
142
+ tensor: output
143
+ """
144
+ if t is not None:
145
+ # Embed timestamp & calculate shift parameters
146
+ t = self.t_embedder(t) # (B*N)
147
+ shift, scale, gate = self.adaLN_modulation(t).chunk(3, dim=1) # (B * N, T)
148
+
149
+ # Shift & scale x
150
+ x = modulate(x, shift, scale) # (B * N, T, H, W)
151
+
152
+ out = self.activation(x)
153
+ out = self.conv1(out)
154
+ if self.bn or self.ln:
155
+ out = self.bn1(out)
156
+
157
+ out = self.activation(out)
158
+ out = self.conv2(out)
159
+ if self.bn or self.ln:
160
+ out = self.bn2(out)
161
+
162
+ if self.groups > 1:
163
+ out = self.conv_merge(out)
164
+
165
+ if t is not None:
166
+ out = gate.unsqueeze(-1).unsqueeze(-1) * out
167
+
168
+ return out + x
169
+
170
+
171
+ class FeatureFusionBlock_custom(nn.Module):
172
+ """Feature fusion block."""
173
+
174
+ def __init__(
175
+ self,
176
+ features,
177
+ activation,
178
+ deconv=False,
179
+ bn=False,
180
+ ln=False,
181
+ expand=False,
182
+ align_corners=True,
183
+ dpt_time=False,
184
+ resolution=16,
185
+ ):
186
+ """Init.
187
+
188
+ Args:
189
+ features (int): number of features
190
+ """
191
+ super(FeatureFusionBlock_custom, self).__init__()
192
+
193
+ self.deconv = deconv
194
+ self.align_corners = align_corners
195
+
196
+ self.groups = 1
197
+
198
+ self.expand = expand
199
+ out_features = features
200
+ if self.expand == True:
201
+ out_features = features // 2
202
+
203
+ self.out_conv = nn.Conv2d(
204
+ features,
205
+ out_features,
206
+ kernel_size=1,
207
+ stride=1,
208
+ padding=0,
209
+ bias=True,
210
+ groups=1,
211
+ )
212
+
213
+ nn.init.kaiming_uniform_(self.out_conv.weight)
214
+
215
+ # The second block sees time
216
+ self.resConfUnit1 = ResidualConvUnit_custom(
217
+ features, activation, bn=bn, ln=ln, dpt_time=False, resolution=resolution
218
+ )
219
+ self.resConfUnit2 = ResidualConvUnit_custom(
220
+ features, activation, bn=bn, ln=ln, dpt_time=dpt_time, resolution=resolution
221
+ )
222
+
223
+ def forward(self, input, activation=None, t=None):
224
+ """Forward pass.
225
+
226
+ Returns:
227
+ tensor: output
228
+ """
229
+ output = input
230
+
231
+ if activation is not None:
232
+ res = self.resConfUnit1(activation)
233
+
234
+ output += res
235
+
236
+ output = self.resConfUnit2(output, t)
237
+
238
+ output = torch.nn.functional.interpolate(
239
+ output.float(),
240
+ scale_factor=2,
241
+ mode="bilinear",
242
+ align_corners=self.align_corners,
243
+ )
244
+
245
+ output = self.out_conv(output)
246
+
247
+ return output
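
A small shape check for the fusion block defined above (feature width, spatial resolution, and timestep values are arbitrary assumptions):

    import torch
    from diffusionsfm.model.blocks import _make_fusion_block

    block = _make_fusion_block(features=256, use_bn=False, use_ln=False, dpt_time=True, resolution=16)
    x = torch.randn(2, 256, 16, 16)    # (B*N, features, H, W)
    t = torch.randint(0, 100, (2,))    # one timestep per item
    out = block(x, t=t)                # bilinearly upsampled by 2x inside the block
    print(out.shape)                   # torch.Size([2, 256, 32, 32])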
diffusionsfm/model/diffuser.py ADDED
@@ -0,0 +1,195 @@
1
+ import ipdb # noqa: F401
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from diffusionsfm.model.dit import DiT
7
+ from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino
8
+ from diffusionsfm.model.scheduler import NoiseScheduler
9
+
10
+
11
+ class RayDiffuser(nn.Module):
12
+ def __init__(
13
+ self,
14
+ model_type="dit",
15
+ depth=8,
16
+ width=16,
17
+ hidden_size=1152,
18
+ P=1,
19
+ max_num_images=1,
20
+ noise_scheduler=None,
21
+ freeze_encoder=True,
22
+ feature_extractor="dino",
23
+ append_ndc=True,
24
+ use_unconditional=False,
25
+ diffuse_depths=False,
26
+ depth_resolution=1,
27
+ use_homogeneous=False,
28
+ cond_depth_mask=False,
29
+ ):
30
+ super().__init__()
31
+ if noise_scheduler is None:
32
+ self.noise_scheduler = NoiseScheduler()
33
+ else:
34
+ self.noise_scheduler = noise_scheduler
35
+
36
+ self.diffuse_depths = diffuse_depths
37
+ self.depth_resolution = depth_resolution
38
+ self.use_homogeneous = use_homogeneous
39
+
40
+ self.ray_dim = 3
41
+ if self.use_homogeneous:
42
+ self.ray_dim += 1
43
+
44
+ self.ray_dim += self.ray_dim * self.depth_resolution**2
45
+
46
+ if self.diffuse_depths:
47
+ self.ray_dim += 1
48
+
49
+ self.append_ndc = append_ndc
50
+ self.width = width
51
+
52
+ self.max_num_images = max_num_images
53
+ self.model_type = model_type
54
+ self.use_unconditional = use_unconditional
55
+ self.cond_depth_mask = cond_depth_mask
56
+
57
+ if feature_extractor == "dino":
58
+ self.feature_extractor = SpatialDino(
59
+ freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
60
+ )
61
+ self.feature_dim = self.feature_extractor.feature_dim
62
+ elif feature_extractor == "vae":
63
+ self.feature_extractor = PretrainedVAE(
64
+ freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
65
+ )
66
+ self.feature_dim = self.feature_extractor.feature_dim
67
+ else:
68
+ raise Exception(f"Unknown feature extractor {feature_extractor}")
69
+
70
+ if self.use_unconditional:
71
+ self.register_parameter(
72
+ "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1))
73
+ )
74
+
75
+ self.input_dim = self.feature_dim * 2
76
+
77
+ if self.append_ndc:
78
+ self.input_dim += 2
79
+
80
+ if model_type == "dit":
81
+ self.ray_predictor = DiT(
82
+ in_channels=self.input_dim,
83
+ out_channels=self.ray_dim,
84
+ width=width,
85
+ depth=depth,
86
+ hidden_size=hidden_size,
87
+ max_num_images=max_num_images,
88
+ P=P,
89
+ )
90
+
91
+ self.scratch = nn.Module()
92
+ self.scratch.input_conv = nn.Linear(self.ray_dim + int(self.cond_depth_mask), self.feature_dim)
93
+
94
+ def forward_noise(
95
+ self, x, t, epsilon=None, zero_out_mask=None
96
+ ):
97
+ """
98
+ Applies forward diffusion (adds noise) to the input.
99
+
100
+ If a mask is provided, the noise is only applied to the masked inputs.
101
+ """
102
+ t = t.reshape(-1, 1, 1, 1, 1)
103
+
104
+ if epsilon is None:
105
+ epsilon = torch.randn_like(x)
106
+ else:
107
+ epsilon = epsilon.reshape(x.shape)
108
+
109
+ alpha_bar = self.noise_scheduler.alphas_cumprod[t]
110
+ x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon
111
+
112
+ if zero_out_mask is not None and self.cond_depth_mask:
113
+ x_noise = x_noise * zero_out_mask
114
+
115
+ return x_noise, epsilon
116
+
117
+ def forward(
118
+ self,
119
+ features=None,
120
+ images=None,
121
+ rays=None,
122
+ rays_noisy=None,
123
+ t=None,
124
+ ndc_coordinates=None,
125
+ unconditional_mask=None,
126
+ return_dpt_activations=False,
127
+ depth_mask=None,
128
+ ):
129
+ """
130
+ Args:
131
+ images: (B, N, 3, H, W).
132
+ t: (B,).
133
+ rays: (B, N, 6, H, W).
134
+ rays_noisy: (B, N, 6, H, W).
135
+ ndc_coordinates: (B, N, 2, H, W).
136
+ unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples
137
+ and 0 else.
138
+ """
139
+
140
+ if features is None:
141
+ # VAE expects 256x256 images while DINO expects 224x224 images.
142
+ # Both feature extractors support autoresize=True, but ideally we should
143
+ # set this to be false and handle in the dataloader.
144
+ features = self.feature_extractor(images, autoresize=True)
145
+
146
+ B = features.shape[0]
147
+
148
+ if (
149
+ unconditional_mask is not None
150
+ and self.use_unconditional
151
+ ):
152
+ null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1)
153
+ unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1)
154
+ features = (
155
+ features * (1 - unconditional_mask) + null_token * unconditional_mask
156
+ )
157
+
158
+ if isinstance(t, int) or isinstance(t, np.int64):
159
+ t = torch.ones(1, dtype=int).to(features.device) * t
160
+ else:
161
+ t = t.reshape(B)
162
+
163
+ if rays_noisy is None:
164
+ if self.cond_depth_mask:
165
+ rays_noisy, epsilon = self.forward_noise(rays, t, zero_out_mask=depth_mask.unsqueeze(2))
166
+ else:
167
+ rays_noisy, epsilon = self.forward_noise(rays, t)
168
+ else:
169
+ epsilon = None
170
+
171
+ if self.cond_depth_mask:
172
+ if depth_mask is None:
173
+ depth_mask = torch.ones_like(rays_noisy[:, :, 0])
174
+ ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2)
175
+ else:
176
+ ray_repr = rays_noisy
177
+
178
+ ray_repr = ray_repr.permute(0, 1, 3, 4, 2)
179
+ ray_repr = self.scratch.input_conv(ray_repr).permute(0, 1, 4, 2, 3).contiguous()
180
+
181
+ scene_features = torch.cat([features, ray_repr], dim=2)
182
+
183
+ if self.append_ndc:
184
+ scene_features = torch.cat([scene_features, ndc_coordinates], dim=2)
185
+
186
+ epsilon_pred = self.ray_predictor(
187
+ scene_features,
188
+ t,
189
+ return_dpt_activations=return_dpt_activations,
190
+ )
191
+
192
+ if return_dpt_activations:
193
+ return epsilon_pred, rays_noisy, epsilon
194
+
195
+ return epsilon_pred, epsilon
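
A minimal dummy forward pass for RayDiffuser with the default flags, so ray_dim = 3 + 3 * depth_resolution**2 = 6. append_ndc is disabled purely to avoid building NDC coordinates in this sketch, and a CUDA device is assumed because the DiT uses the xformers attention path when xformers is installed:

    import torch
    from diffusionsfm.model.diffuser import RayDiffuser
    from diffusionsfm.model.scheduler import NoiseScheduler

    device = "cuda:0"
    model = RayDiffuser(
        depth=8,
        width=16,
        max_num_images=8,
        noise_scheduler=NoiseScheduler(),
        append_ndc=False,
    ).to(device)

    images = torch.randn(1, 2, 3, 224, 224, device=device)  # ImageNet-normalized in practice
    rays = torch.randn(1, 2, 6, 16, 16, device=device)      # (B, N, ray_dim, H, W)
    t = torch.tensor([50], device=device)

    eps_pred, epsilon = model(images=images, rays=rays, t=t)
    print(eps_pred.shape)                                    # torch.Size([1, 2, 6, 16, 16])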
diffusionsfm/model/diffuser_dpt.py ADDED
@@ -0,0 +1,331 @@
1
+ import ipdb # noqa: F401
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from diffusionsfm.model.dit import DiT
7
+ from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino
8
+ from diffusionsfm.model.blocks import _make_fusion_block, _make_scratch
9
+ from diffusionsfm.model.scheduler import NoiseScheduler
10
+
11
+
12
+ # functional implementation
13
+ def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int):
14
+ """Upsample {x} (NCHW) by scale factor {scale_factor} using nearest neighbor interpolation."""
15
+ s = scale_factor
16
+ return (
17
+ x.reshape(*x.shape, 1, 1)
18
+ .expand(*x.shape, s, s)
19
+ .transpose(-2, -3)
20
+ .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))
21
+ )
22
+
23
+
24
+ class ProjectReadout(nn.Module):
25
+ def __init__(self, in_features, start_index=1):
26
+ super(ProjectReadout, self).__init__()
27
+ self.start_index = start_index
28
+
29
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
30
+
31
+ def forward(self, x):
32
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
33
+ features = torch.cat((x[:, self.start_index :], readout), -1)
34
+
35
+ return self.project(features)
36
+
37
+
38
+ class RayDiffuserDPT(nn.Module):
39
+ def __init__(
40
+ self,
41
+ model_type="dit",
42
+ depth=8,
43
+ width=16,
44
+ hidden_size=1152,
45
+ P=1,
46
+ max_num_images=1,
47
+ noise_scheduler=None,
48
+ freeze_encoder=True,
49
+ feature_extractor="dino",
50
+ append_ndc=True,
51
+ use_unconditional=False,
52
+ diffuse_depths=False,
53
+ depth_resolution=1,
54
+ encoder_features=False,
55
+ use_homogeneous=False,
56
+ freeze_transformer=False,
57
+ cond_depth_mask=False,
58
+ ):
59
+ super().__init__()
60
+ if noise_scheduler is None:
61
+ self.noise_scheduler = NoiseScheduler()
62
+ else:
63
+ self.noise_scheduler = noise_scheduler
64
+
65
+ self.diffuse_depths = diffuse_depths
66
+ self.depth_resolution = depth_resolution
67
+ self.use_homogeneous = use_homogeneous
68
+
69
+ self.ray_dim = 3
70
+
71
+ if self.use_homogeneous:
72
+ self.ray_dim += 1
73
+ self.ray_dim += self.ray_dim * self.depth_resolution**2
74
+
75
+ if self.diffuse_depths:
76
+ self.ray_dim += 1
77
+
78
+ self.append_ndc = append_ndc
79
+ self.width = width
80
+
81
+ self.max_num_images = max_num_images
82
+ self.model_type = model_type
83
+ self.use_unconditional = use_unconditional
84
+ self.cond_depth_mask = cond_depth_mask
85
+ self.encoder_features = encoder_features
86
+
87
+ if feature_extractor == "dino":
88
+ self.feature_extractor = SpatialDino(
89
+ freeze_weights=freeze_encoder,
90
+ num_patches_x=width,
91
+ num_patches_y=width,
92
+ activation_hooks=self.encoder_features,
93
+ )
94
+ self.feature_dim = self.feature_extractor.feature_dim
95
+ elif feature_extractor == "vae":
96
+ self.feature_extractor = PretrainedVAE(
97
+ freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
98
+ )
99
+ self.feature_dim = self.feature_extractor.feature_dim
100
+ else:
101
+ raise Exception(f"Unknown feature extractor {feature_extractor}")
102
+
103
+ if self.use_unconditional:
104
+ self.register_parameter(
105
+ "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1))
106
+ )
107
+
108
+ self.input_dim = self.feature_dim * 2
109
+
110
+ if self.append_ndc:
111
+ self.input_dim += 2
112
+
113
+ if model_type == "dit":
114
+ self.ray_predictor = DiT(
115
+ in_channels=self.input_dim,
116
+ out_channels=self.ray_dim,
117
+ width=width,
118
+ depth=depth,
119
+ hidden_size=hidden_size,
120
+ max_num_images=max_num_images,
121
+ P=P,
122
+ )
123
+
124
+ if freeze_transformer:
125
+ for param in self.ray_predictor.parameters():
126
+ param.requires_grad = False
127
+
128
+ # Fusion blocks
129
+ self.f = 256
130
+
131
+ if self.encoder_features:
132
+ feature_lens = [
133
+ self.feature_extractor.feature_dim,
134
+ self.feature_extractor.feature_dim,
135
+ self.ray_predictor.hidden_size,
136
+ self.ray_predictor.hidden_size,
137
+ ]
138
+ else:
139
+ feature_lens = [self.ray_predictor.hidden_size] * 4
140
+
141
+ self.scratch = _make_scratch(feature_lens, 256, groups=1, expand=False)
142
+ self.scratch.refinenet1 = _make_fusion_block(
143
+ self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=128
144
+ )
145
+ self.scratch.refinenet2 = _make_fusion_block(
146
+ self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=64
147
+ )
148
+ self.scratch.refinenet3 = _make_fusion_block(
149
+ self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=32
150
+ )
151
+ self.scratch.refinenet4 = _make_fusion_block(
152
+ self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=16
153
+ )
154
+
155
+ self.scratch.input_conv = nn.Conv2d(
156
+ self.ray_dim + int(self.cond_depth_mask),
157
+ self.feature_dim,
158
+ kernel_size=16,
159
+ stride=16,
160
+ padding=0
161
+ )
162
+
163
+ self.scratch.output_conv = nn.Sequential(
164
+ nn.Conv2d(self.f, self.f // 2, kernel_size=3, stride=1, padding=1),
165
+ nn.LeakyReLU(),
166
+ nn.Conv2d(self.f // 2, 32, kernel_size=3, stride=1, padding=1),
167
+ nn.LeakyReLU(),
168
+ nn.Conv2d(32, self.ray_dim, kernel_size=1, stride=1, padding=0),
169
+ nn.Identity(),
170
+ )
171
+
172
+ if self.encoder_features:
173
+ self.project_opers = nn.ModuleList([
174
+ ProjectReadout(in_features=self.feature_extractor.feature_dim),
175
+ ProjectReadout(in_features=self.feature_extractor.feature_dim),
176
+ ])
177
+
178
+ def forward_noise(
179
+ self, x, t, epsilon=None, zero_out_mask=None
180
+ ):
181
+ """
182
+ Applies forward diffusion (adds noise) to the input.
183
+
184
+ If a mask is provided, the noise is only applied to the masked inputs.
185
+ """
186
+ t = t.reshape(-1, 1, 1, 1, 1)
187
+ if epsilon is None:
188
+ epsilon = torch.randn_like(x)
189
+ else:
190
+ epsilon = epsilon.reshape(x.shape)
191
+
192
+ alpha_bar = self.noise_scheduler.alphas_cumprod[t]
193
+ x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon
194
+
195
+ if zero_out_mask is not None and self.cond_depth_mask:
196
+ x_noise = zero_out_mask * x_noise
197
+
198
+ return x_noise, epsilon
199
+
200
+ def forward(
201
+ self,
202
+ features=None,
203
+ images=None,
204
+ rays=None,
205
+ rays_noisy=None,
206
+ t=None,
207
+ ndc_coordinates=None,
208
+ unconditional_mask=None,
209
+ encoder_patches=16,
210
+ depth_mask=None,
211
+ multiview_unconditional=False,
212
+ indices=None,
213
+ ):
214
+ """
215
+ Args:
216
+ images: (B, N, 3, H, W).
217
+ t: (B,).
218
+ rays: (B, N, 6, H, W).
219
+ rays_noisy: (B, N, 6, H, W).
220
+ ndc_coordinates: (B, N, 2, H, W).
221
+ unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples
222
+ and 0 else.
223
+ """
224
+
225
+ if features is None:
226
+ # VAE expects 256x256 images while DINO expects 224x224 images.
227
+ # Both feature extractors support autoresize=True, but ideally we should
228
+ # set this to be false and handle in the dataloader.
229
+ features = self.feature_extractor(images, autoresize=True)
230
+
231
+ B = features.shape[0]
232
+
233
+ if unconditional_mask is not None and self.use_unconditional:
234
+ null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1)
235
+ unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1)
236
+ features = (
237
+ features * (1 - unconditional_mask) + null_token * unconditional_mask
238
+ )
239
+
240
+ if isinstance(t, int) or isinstance(t, np.int64):
241
+ t = torch.ones(1, dtype=int).to(features.device) * t
242
+ else:
243
+ t = t.reshape(B)
244
+
245
+ if rays_noisy is None:
246
+ if self.cond_depth_mask:
247
+ rays_noisy, epsilon = self.forward_noise(
248
+ rays, t, zero_out_mask=depth_mask.unsqueeze(2)
249
+ )
250
+ else:
251
+ rays_noisy, epsilon = self.forward_noise(
252
+ rays, t
253
+ )
254
+ else:
255
+ epsilon = None
256
+
257
+ # DOWNSAMPLE RAYS
258
+ B, N, C, H, W = rays_noisy.shape
259
+
260
+ if self.cond_depth_mask:
261
+ if depth_mask is None:
262
+ depth_mask = torch.ones_like(rays_noisy[:, :, 0])
263
+ ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2)
264
+ else:
265
+ ray_repr = rays_noisy
266
+
267
+ ray_repr = self.scratch.input_conv(ray_repr.reshape(B * N, -1, H, W))
268
+ _, CP, HP, WP = ray_repr.shape
269
+ ray_repr = ray_repr.reshape(B, N, CP, HP, WP)
270
+ scene_features = torch.cat([features, ray_repr], dim=2)
271
+
272
+ if self.append_ndc:
273
+ scene_features = torch.cat([scene_features, ndc_coordinates], dim=2)
274
+
275
+ # DIT FORWARD PASS
276
+ activations = self.ray_predictor(
277
+ scene_features,
278
+ t,
279
+ return_dpt_activations=True,
280
+ multiview_unconditional=multiview_unconditional,
281
+ )
282
+
283
+ # PROJECT ENCODER ACTIVATIONS & RESHAPE
284
+ if self.encoder_features:
285
+ for i in range(2):
286
+ name = f"encoder{i+1}"
287
+
288
+ if indices is not None:
289
+ act = self.feature_extractor.activations[name][indices]
290
+ else:
291
+ act = self.feature_extractor.activations[name]
292
+
293
+ act = self.project_opers[i](act).permute(0, 2, 1)
294
+ act = act.reshape(
295
+ (
296
+ B * N,
297
+ self.feature_extractor.feature_dim,
298
+ encoder_patches,
299
+ encoder_patches,
300
+ )
301
+ )
302
+ activations[i] = act
303
+
304
+ # UPSAMPLE ACTIVATIONS
305
+ for i, act in enumerate(activations):
306
+ k = 3 - i
307
+ activations[i] = nearest_neighbor_upsample(act, 2**k)
308
+
309
+ # FUSION BLOCKS
310
+ layer_1_rn = self.scratch.layer1_rn(activations[0])
311
+ layer_2_rn = self.scratch.layer2_rn(activations[1])
312
+ layer_3_rn = self.scratch.layer3_rn(activations[2])
313
+ layer_4_rn = self.scratch.layer4_rn(activations[3])
314
+
315
+ # RESHAPE TIMESTEPS
316
+ if t.shape[0] == B:
317
+ t = t.unsqueeze(-1).repeat((1, N)).reshape(B * N)
318
+ elif t.shape[0] == 1 and B > 1:
319
+ t = t.repeat((B * N))
320
+ else:
321
+ assert False
322
+
323
+ path_4 = self.scratch.refinenet4(layer_4_rn, t=t)
324
+ path_3 = self.scratch.refinenet3(path_4, activation=layer_3_rn, t=t)
325
+ path_2 = self.scratch.refinenet2(path_3, activation=layer_2_rn, t=t)
326
+ path_1 = self.scratch.refinenet1(path_2, activation=layer_1_rn, t=t)
327
+
328
+ epsilon_pred = self.scratch.output_conv(path_1)
329
+ epsilon_pred = epsilon_pred.reshape((B, N, C, H, W))
330
+
331
+ return epsilon_pred, epsilon
diffusionsfm/model/dit.py ADDED
@@ -0,0 +1,428 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import math
13
+
14
+ import ipdb # noqa: F401
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
19
+ from diffusionsfm.model.memory_efficient_attention import MEAttention
20
+
21
+
22
+ def modulate(x, shift, scale):
23
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
24
+
25
+
26
+ #################################################################################
27
+ # Embedding Layers for Timesteps and Class Labels #
28
+ #################################################################################
29
+
30
+
31
+ class TimestepEmbedder(nn.Module):
32
+ """
33
+ Embeds scalar timesteps into vector representations.
34
+ """
35
+
36
+ def __init__(self, hidden_size, frequency_embedding_size=256):
37
+ super().__init__()
38
+ self.mlp = nn.Sequential(
39
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
40
+ nn.SiLU(),
41
+ nn.Linear(hidden_size, hidden_size, bias=True),
42
+ )
43
+ self.frequency_embedding_size = frequency_embedding_size
44
+
45
+ @staticmethod
46
+ def timestep_embedding(t, dim, max_period=10000):
47
+ """
48
+ Create sinusoidal timestep embeddings.
49
+ :param t: a 1-D Tensor of N indices, one per batch element.
50
+ These may be fractional.
51
+ :param dim: the dimension of the output.
52
+ :param max_period: controls the minimum frequency of the embeddings.
53
+ :return: an (N, D) Tensor of positional embeddings.
54
+ """
55
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
56
+ half = dim // 2
57
+ freqs = torch.exp(
58
+ -math.log(max_period)
59
+ * torch.arange(start=0, end=half, dtype=torch.float32)
60
+ / half
61
+ ).to(device=t.device)
62
+ args = t[:, None].float() * freqs[None]
63
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
64
+ if dim % 2:
65
+ embedding = torch.cat(
66
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
67
+ )
68
+ return embedding
69
+
70
+ def forward(self, t):
71
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
72
+ t_emb = self.mlp(t_freq)
73
+ return t_emb
74
+
75
+
76
+ #################################################################################
77
+ # Core DiT Model #
78
+ #################################################################################
79
+
80
+
81
+ class DiTBlock(nn.Module):
82
+ """
83
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ hidden_size,
89
+ num_heads,
90
+ mlp_ratio=4.0,
91
+ use_xformers_attention=False,
92
+ **block_kwargs
93
+ ):
94
+ super().__init__()
95
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
96
+ attn = MEAttention if use_xformers_attention else Attention
97
+ self.attn = attn(
98
+ hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
99
+ )
100
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
101
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
102
+
103
+ def approx_gelu():
104
+ return nn.GELU(approximate="tanh")
105
+
106
+ self.mlp = Mlp(
107
+ in_features=hidden_size,
108
+ hidden_features=mlp_hidden_dim,
109
+ act_layer=approx_gelu,
110
+ drop=0,
111
+ )
112
+ self.adaLN_modulation = nn.Sequential(
113
+ nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
114
+ )
115
+
116
+ def forward(self, x, c):
117
+ (
118
+ shift_msa,
119
+ scale_msa,
120
+ gate_msa,
121
+ shift_mlp,
122
+ scale_mlp,
123
+ gate_mlp,
124
+ ) = self.adaLN_modulation(c).chunk(6, dim=1)
125
+ x = x + gate_msa.unsqueeze(1) * self.attn(
126
+ modulate(self.norm1(x), shift_msa, scale_msa)
127
+ )
128
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(
129
+ modulate(self.norm2(x), shift_mlp, scale_mlp)
130
+ )
131
+ return x
132
+
133
+
134
+ class FinalLayer(nn.Module):
135
+ """
136
+ The final layer of DiT.
137
+ """
138
+
139
+ def __init__(self, hidden_size, patch_size, out_channels):
140
+ super().__init__()
141
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142
+ self.linear = nn.Linear(
143
+ hidden_size, patch_size * patch_size * out_channels, bias=True
144
+ )
145
+ self.adaLN_modulation = nn.Sequential(
146
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
147
+ )
148
+
149
+ def forward(self, x, c):
150
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
151
+ x = modulate(self.norm_final(x), shift, scale)
152
+ x = self.linear(x)
153
+ return x
154
+
155
+
156
+ class DiT(nn.Module):
157
+ """
158
+ Diffusion model with a Transformer backbone.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ in_channels=442,
164
+ out_channels=6,
165
+ width=16,
166
+ hidden_size=1152,
167
+ depth=8,
168
+ num_heads=16,
169
+ mlp_ratio=4.0,
170
+ max_num_images=8,
171
+ P=1,
172
+ within_image=False,
173
+ ):
174
+ super().__init__()
175
+ self.num_heads = num_heads
176
+ self.in_channels = in_channels
177
+ self.out_channels = out_channels
178
+ self.width = width
179
+ self.hidden_size = hidden_size
180
+ self.max_num_images = max_num_images
181
+ self.P = P
182
+ self.within_image = within_image
183
+
184
+ # self.x_embedder = nn.Linear(in_channels, hidden_size)
185
+ # self.x_embedder = PatchEmbed(in_channels, hidden_size, kernel_size=P, hidden_size=P)
186
+ self.x_embedder = PatchEmbed(
187
+ img_size=self.width,
188
+ patch_size=self.P,
189
+ in_chans=in_channels,
190
+ embed_dim=hidden_size,
191
+ bias=True,
192
+ flatten=False,
193
+ )
194
+ self.x_pos_enc = FeaturePositionalEncoding(
195
+ max_num_images, hidden_size, width**2, P=self.P
196
+ )
197
+ self.t_embedder = TimestepEmbedder(hidden_size)
198
+
199
+ try:
200
+ import xformers
201
+
202
+ use_xformers_attention = True
203
+ except ImportError:
204
+ # xformers not available
205
+ use_xformers_attention = False
206
+
207
+ self.blocks = nn.ModuleList(
208
+ [
209
+ DiTBlock(
210
+ hidden_size,
211
+ num_heads,
212
+ mlp_ratio=mlp_ratio,
213
+ use_xformers_attention=use_xformers_attention,
214
+ )
215
+ for _ in range(depth)
216
+ ]
217
+ )
218
+ self.final_layer = FinalLayer(hidden_size, P, out_channels)
219
+ self.initialize_weights()
220
+
221
+ def initialize_weights(self):
222
+ # Initialize transformer layers:
223
+ def _basic_init(module):
224
+ if isinstance(module, nn.Linear):
225
+ torch.nn.init.xavier_uniform_(module.weight)
226
+ if module.bias is not None:
227
+ nn.init.constant_(module.bias, 0)
228
+
229
+ self.apply(_basic_init)
230
+
231
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
232
+ w = self.x_embedder.proj.weight.data
233
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
234
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
235
+
236
+ # Initialize timestep embedding MLP:
237
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
238
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
239
+
240
+ # Zero-out adaLN modulation layers in DiT blocks:
241
+ for block in self.blocks:
242
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
243
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
244
+
245
+ # Zero-out output layers:
246
+ # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
247
+ # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
248
+ # nn.init.constant_(self.final_layer.linear.weight, 0)
249
+ # nn.init.constant_(self.final_layer.linear.bias, 0)
250
+
251
+ def unpatchify(self, x):
252
+ """
253
+ x: (N, T, patch_size**2 * C)
254
+ imgs: (N, H, W, C)
255
+ """
256
+ c = self.out_channels
257
+ p = self.x_embedder.patch_size[0]
258
+ h = w = int(x.shape[1] ** 0.5)
259
+
260
+ # print("unpatchify", c, p, h, w, x.shape)
261
+ # assert h * w == x.shape[2]
262
+
263
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
264
+ x = torch.einsum("nhwpqc->nhpwqc", x)
265
+ imgs = x.reshape(shape=(x.shape[0], h * p, h * p, c))
266
+ return imgs
267
+
268
+ def forward(
269
+ self,
270
+ x,
271
+ t,
272
+ return_dpt_activations=False,
273
+ multiview_unconditional=False,
274
+ ):
275
+ """
276
+
277
+ Args:
278
+ x: Image/Ray features (B, N, C, H, W).
279
+ t: Timesteps (B,), or a length-1 tensor that is broadcast across the batch.
280
+
281
+ Returns:
282
+ (B, N, D, H, W)
283
+ """
284
+ B, N, c, h, w = x.shape
285
+ P = self.P
286
+
287
+ x = x.reshape((B * N, c, h, w)) # (B * N, C, H, W)
288
+ x = self.x_embedder(x) # (B * N, C, H / P, W / P)
289
+
290
+ x = x.permute(0, 2, 3, 1) # (B * N, H / P, W / P, C)
291
+ # (B, N, H / P, W / P, C)
292
+ x = x.reshape((B, N, h // P, w // P, self.hidden_size))
293
+ x = self.x_pos_enc(x) # (B, N, H * W / P ** 2, C)
294
+ # TODO: fix positional encoding to work with (N, C, H, W) format.
295
+
296
+ # Eval time, we get a scalar t
297
+ if x.shape[0] != t.shape[0] and t.shape[0] == 1:
298
+ t = t.repeat_interleave(B)
299
+
300
+ if self.within_image or multiview_unconditional:
301
+ t_within = t.repeat_interleave(N)
302
+ t_within = self.t_embedder(t_within)
303
+
304
+ t = self.t_embedder(t)
305
+
306
+ dpt_activations = []
307
+ for i, block in enumerate(self.blocks):
308
+ # Within image block
309
+ if (self.within_image and i % 2 == 0) or multiview_unconditional:
310
+ x = x.reshape((B * N, h * w // P**2, self.hidden_size))
311
+ x = block(x, t_within)
312
+
313
+ # All patches block
314
+ # Final layer is an all patches layer
315
+ else:
316
+ x = x.reshape((B, N * h * w // P**2, self.hidden_size))
317
+ x = block(x, t) # (N, T, D)
318
+
319
+ if return_dpt_activations and i % 4 == 3:
320
+ x_prime = x.reshape(B, N, h, w, self.hidden_size)
321
+ x_prime = x.reshape(B * N, h, w, self.hidden_size)
322
+ x_prime = x_prime.permute((0, 3, 1, 2))
323
+ dpt_activations.append(x_prime)
324
+
325
+ # Reshape the output back to original shape
326
+ if multiview_unconditional:
327
+ x = x.reshape((B, N * h * w // P**2, self.hidden_size))
328
+
329
+ # (B, N * H * W / P ** 2, D)
330
+ x = self.final_layer(
331
+ x, t
332
+ ) # (B, N * H * W / P ** 2, 6 * P ** 2) or (N, T, patch_size ** 2 * out_channels)
333
+
334
+ x = x.reshape((B * N, w * w // P**2, self.out_channels * P**2))
335
+ x = self.unpatchify(x) # (B * N, H, W, C)
336
+ x = x.reshape((B, N) + x.shape[1:])
337
+ x = x.permute(0, 1, 4, 2, 3) # (B, N, C, H, W)
338
+
339
+ if return_dpt_activations:
340
+ return dpt_activations[:4]
341
+
342
+ return x
343
+
344
+
345
+ class FeaturePositionalEncoding(nn.Module):
346
+ def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
347
+ """Sinusoid position encoding table"""
348
+
349
+ def get_position_angle_vec(position):
350
+ return [
351
+ position / np.power(base, 2 * (hid_j // 2) / d_hid)
352
+ for hid_j in range(d_hid)
353
+ ]
354
+
355
+ sinusoid_table = np.array(
356
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
357
+ )
358
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
359
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
360
+
361
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
362
+
363
+ def __init__(self, max_num_images=8, feature_dim=1152, num_patches=256, P=1):
364
+ super().__init__()
365
+ self.max_num_images = max_num_images
366
+ self.feature_dim = feature_dim
367
+ self.P = P
368
+ self.num_patches = num_patches // self.P**2
369
+
370
+ self.register_buffer(
371
+ "image_pos_table",
372
+ self._get_sinusoid_encoding_table(
373
+ self.max_num_images, self.feature_dim, 10000
374
+ ),
375
+ )
376
+
377
+ self.register_buffer(
378
+ "token_pos_table",
379
+ self._get_sinusoid_encoding_table(
380
+ self.num_patches, self.feature_dim, 70007
381
+ ),
382
+ )
383
+
384
+ def forward(self, x):
385
+ batch_size = x.shape[0]
386
+ num_images = x.shape[1]
387
+
388
+ x = x.reshape(batch_size, num_images, self.num_patches, self.feature_dim)
389
+
390
+ # To encode image index
391
+ pe1 = self.image_pos_table[:, :num_images].clone().detach()
392
+ pe1 = pe1.reshape((1, num_images, 1, self.feature_dim))
393
+ pe1 = pe1.repeat((batch_size, 1, self.num_patches, 1))
394
+
395
+ # To encode patch index
396
+ pe2 = self.token_pos_table.clone().detach()
397
+ pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
398
+ pe2 = pe2.repeat((batch_size, num_images, 1, 1))
399
+
400
+ x_pe = x + pe1 + pe2
401
+ x_pe = x_pe.reshape(
402
+ (batch_size, num_images * self.num_patches, self.feature_dim)
403
+ )
404
+
405
+ return x_pe
406
+
407
+ def forward_unet(self, x, B, N):
408
+ D = int(self.num_patches**0.5)
409
+
410
+ # x should be (B, N, T, D, D)
411
+ x = x.permute((0, 2, 3, 1))
412
+ x = x.reshape(B, N, self.num_patches, self.feature_dim)
413
+
414
+ # To encode image index
415
+ pe1 = self.image_pos_table[:, :N].clone().detach()
416
+ pe1 = pe1.reshape((1, N, 1, self.feature_dim))
417
+ pe1 = pe1.repeat((B, 1, self.num_patches, 1))
418
+
419
+ # To encode patch index
420
+ pe2 = self.token_pos_table.clone().detach()
421
+ pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
422
+ pe2 = pe2.repeat((B, N, 1, 1))
423
+
424
+ x_pe = x + pe1 + pe2
425
+ x_pe = x_pe.reshape((B * N, D, D, self.feature_dim))
426
+ x_pe = x_pe.permute((0, 3, 1, 2))
427
+
428
+ return x_pe
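
A standalone shape check for the DiT backbone; 770 input channels correspond to the concatenation used by the diffusers above (2 x 384 DINO feature channels + 2 NDC channels), and the depth is kept small for the sketch. The xformers attention path, when available, expects CUDA tensors:

    import torch
    from diffusionsfm.model.dit import DiT

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    dit = DiT(in_channels=770, out_channels=6, width=16, hidden_size=1152, depth=2, max_num_images=8).to(device)

    x = torch.randn(1, 2, 770, 16, 16, device=device)  # (B, N, C, H, W) scene features
    t = torch.tensor([10], device=device)              # one timestep per batch element
    out = dit(x, t)
    print(out.shape)                                   # torch.Size([1, 2, 6, 16, 16])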
diffusionsfm/model/feature_extractors.py ADDED
@@ -0,0 +1,176 @@
1
+ import importlib
2
+ import os
3
+ import socket
4
+ import sys
5
+
6
+ import ipdb # noqa: F401
7
+ import torch
8
+ import torch.nn as nn
9
+ from omegaconf import OmegaConf
10
+
11
+ HOSTNAME = socket.gethostname()
12
+
13
+ if "trinity" in HOSTNAME:
14
+ # Might be outdated
15
+ config_path = "/home/amylin2/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
16
+ weights_path = "/home/amylin2/latent-diffusion/model.ckpt"
17
+ elif "grogu" in HOSTNAME:
18
+ # Might be outdated
19
+ config_path = "/home/jasonzh2/code/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
20
+ weights_path = "/home/jasonzh2/code/latent-diffusion/model.ckpt"
21
+ elif "ender" in HOSTNAME:
22
+ config_path = "/home/jason/ray_diffusion/external/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
23
+ weights_path = "/home/jason/ray_diffusion/external/latent-diffusion/model.ckpt"
24
+ else:
25
+ config_path = None
26
+ weights_path = None
27
+
28
+
29
+ if weights_path is not None:
30
+ LDM_PATH = os.path.dirname(weights_path)
31
+ if LDM_PATH not in sys.path:
32
+ sys.path.append(LDM_PATH)
33
+
34
+
35
+ def resize(image, size=None, scale_factor=None):
36
+ return nn.functional.interpolate(
37
+ image,
38
+ size=size,
39
+ scale_factor=scale_factor,
40
+ mode="bilinear",
41
+ align_corners=False,
42
+ )
43
+
44
+
45
+ def instantiate_from_config(config):
46
+ if "target" not in config:
47
+ if config == "__is_first_stage__":
48
+ return None
49
+ elif config == "__is_unconditional__":
50
+ return None
51
+ raise KeyError("Expected key `target` to instantiate.")
52
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
53
+
54
+
55
+ def get_obj_from_str(string, reload=False):
56
+ module, cls = string.rsplit(".", 1)
57
+ if reload:
58
+ module_imp = importlib.import_module(module)
59
+ importlib.reload(module_imp)
60
+ return getattr(importlib.import_module(module, package=None), cls)
61
+
62
+
63
+ class PretrainedVAE(nn.Module):
64
+ def __init__(self, freeze_weights=True, num_patches_x=16, num_patches_y=16):
65
+ super().__init__()
66
+ config = OmegaConf.load(config_path)
67
+ self.model = instantiate_from_config(config.model)
68
+ self.model.init_from_ckpt(weights_path)
69
+ self.model.eval()
70
+ self.feature_dim = 16
71
+ self.num_patches_x = num_patches_x
72
+ self.num_patches_y = num_patches_y
73
+
74
+ if freeze_weights:
75
+ for param in self.model.parameters():
76
+ param.requires_grad = False
77
+
78
+ def forward(self, x, autoresize=False):
79
+ """
80
+ Spatial dimensions of output will be H // 16, W // 16. If autoresize is True,
81
+ then the input will be resized such that the output feature map is the correct
82
+ dimensions.
83
+
84
+ Args:
85
+ x (torch.Tensor): Images (B, C, H, W). Should be normalized to be [-1, 1].
86
+ autoresize (bool): Whether to resize the input to match the num_patch
87
+ dimensions.
88
+
89
+ Returns:
90
+ torch.Tensor: Latent sample (B, 16, h, w)
91
+ """
92
+
93
+ *B, c, h, w = x.shape
94
+ x = x.reshape(-1, c, h, w)
95
+ if autoresize:
96
+ new_w = self.num_patches_x * 16
97
+ new_h = self.num_patches_y * 16
98
+ x = resize(x, size=(new_h, new_w))
99
+
100
+ decoded, latent = self.model(x)
101
+ # Note: feature_dim, num_patches_y, and num_patches_x all default to 16; the layout here is (c, h, w)
102
+ latent_sample = latent.sample().reshape(
103
+ *B, self.feature_dim, self.num_patches_y, self.num_patches_x
104
+ )
105
+ return latent_sample
106
+
107
+
108
+ activations = {}
109
+
110
+
111
+ def get_activation(name):
112
+ def hook(model, input, output):
113
+ activations[name] = output
114
+
115
+ return hook
116
+
117
+
118
+ class SpatialDino(nn.Module):
119
+ def __init__(
120
+ self,
121
+ freeze_weights=True,
122
+ model_type="dinov2_vits14",
123
+ num_patches_x=16,
124
+ num_patches_y=16,
125
+ activation_hooks=False,
126
+ ):
127
+ super().__init__()
128
+ self.model = torch.hub.load("facebookresearch/dinov2", model_type)
129
+ self.feature_dim = self.model.embed_dim
130
+ self.num_patches_x = num_patches_x
131
+ self.num_patches_y = num_patches_y
132
+ if freeze_weights:
133
+ for param in self.model.parameters():
134
+ param.requires_grad = False
135
+
136
+ self.activation_hooks = activation_hooks
137
+
138
+ if self.activation_hooks:
139
+ self.model.blocks[5].register_forward_hook(get_activation("encoder1"))
140
+ self.model.blocks[11].register_forward_hook(get_activation("encoder2"))
141
+ self.activations = activations
142
+
143
+ def forward(self, x, autoresize=False):
144
+ """
145
+ Spatial dimensions of output will be H // 14, W // 14. If autoresize is True,
146
+ then the output will be resized to the correct dimensions.
147
+
148
+ Args:
149
+ x (torch.Tensor): Images (B, C, H, W). Should be ImageNet normalized.
150
+ autoresize (bool): Whether to resize the input to match the num_patch
151
+ dimensions.
152
+
153
+ Returns:
154
+ feature_map (torch.tensor): (B, C, h, w)
155
+ """
156
+ *B, c, h, w = x.shape
157
+
158
+ x = x.reshape(-1, c, h, w)
159
+ # if autoresize:
160
+ # new_w = self.num_patches_x * 14
161
+ # new_h = self.num_patches_y * 14
162
+ # x = resize(x, size=(new_h, new_w))
163
+
164
+ # Output will be (B, H * W, C)
165
+ features = self.model.forward_features(x)["x_norm_patchtokens"]
166
+ features = features.permute(0, 2, 1)
167
+ features = features.reshape( # (B, C, H, W)
168
+ -1, self.feature_dim, h // 14, w // 14
169
+ )
170
+ if autoresize:
171
+ features = resize(features, size=(self.num_patches_y, self.num_patches_x))
172
+
173
+ features = features.reshape(
174
+ *B, self.feature_dim, self.num_patches_y, self.num_patches_x
175
+ )
176
+ return features
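
A quick usage sketch for the DINOv2 extractor above (not part of the upload): the batch shape, the 224x224 resolution, and the assumption that the input is already ImageNet-normalized are illustration choices on my part, and the first call downloads the dinov2_vits14 weights through torch.hub.

# Usage sketch for SpatialDino (illustrative only; shapes and sizes are assumed).
import torch

from diffusionsfm.model.feature_extractors import SpatialDino

extractor = SpatialDino(freeze_weights=True, num_patches_x=16, num_patches_y=16)
extractor.eval()

# A batch of 2 scenes x 3 views of ImageNet-normalized 224x224 images.
images = torch.randn(2, 3, 3, 224, 224)

with torch.no_grad():
    features = extractor(images, autoresize=True)

# dinov2_vits14 has embed_dim 384 and 224 // 14 = 16 patches per side,
# so this prints torch.Size([2, 3, 384, 16, 16]).
print(features.shape)
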
diffusionsfm/model/memory_efficient_attention.py ADDED
@@ -0,0 +1,51 @@
1
+ import ipdb  # noqa: F401
2
+ import torch.nn as nn
3
+ from xformers.ops import memory_efficient_attention
4
+
5
+
6
+ class MEAttention(nn.Module):
7
+ def __init__(
8
+ self,
9
+ dim,
10
+ num_heads=8,
11
+ qkv_bias=False,
12
+ qk_norm=False,
13
+ attn_drop=0.0,
14
+ proj_drop=0.0,
15
+ norm_layer=nn.LayerNorm,
16
+ ):
17
+ super().__init__()
18
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
19
+ self.num_heads = num_heads
20
+ self.head_dim = dim // num_heads
21
+ self.scale = self.head_dim**-0.5
22
+
23
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
24
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
25
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
26
+ self.attn_drop = nn.Dropout(attn_drop)
27
+ self.proj = nn.Linear(dim, dim)
28
+ self.proj_drop = nn.Dropout(proj_drop)
29
+
30
+ def forward(self, x):
31
+ B, N, C = x.shape
32
+ qkv = (
33
+ self.qkv(x)
34
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
35
+ .permute(2, 0, 3, 1, 4)
36
+ )
37
+ q, k, v = qkv.unbind(0)
38
+ q, k = self.q_norm(q), self.k_norm(k)
39
+
40
+ # MEA expects [B, N, H, D], whereas timm uses [B, H, N, D]
41
+ x = memory_efficient_attention(
42
+ q.transpose(1, 2),
43
+ k.transpose(1, 2),
44
+ v.transpose(1, 2),
45
+ scale=self.scale,
46
+ )
47
+ x = x.reshape(B, N, C)
48
+
49
+ x = self.proj(x)
50
+ x = self.proj_drop(x)
51
+ return x
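
A rough smoke test for the block above; xformers' memory_efficient_attention kernels expect CUDA tensors, so this sketch assumes a GPU is available, and the width, sequence length, and dtype are arbitrary choices.

# Smoke-test sketch for MEAttention (assumes xformers and a CUDA device).
import torch

from diffusionsfm.model.memory_efficient_attention import MEAttention

if torch.cuda.is_available():
    attn = MEAttention(dim=384, num_heads=8, qkv_bias=True).cuda().half()
    tokens = torch.randn(2, 256, 384, device="cuda", dtype=torch.float16)  # (B, N, C)
    out = attn(tokens)
    assert out.shape == tokens.shape  # attention preserves the token shape
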
diffusionsfm/model/scheduler.py ADDED
@@ -0,0 +1,128 @@
1
+ import ipdb # noqa: F401
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusionsfm.utils.visualization import plot_to_image
8
+
9
+
10
+ class NoiseScheduler(nn.Module):
11
+ def __init__(
12
+ self,
13
+ max_timesteps=1000,
14
+ beta_start=0.0001,
15
+ beta_end=0.02,
16
+ cos_power=2,
17
+ num_inference_steps=100,
18
+ type="linear",
19
+ ):
20
+ super().__init__()
21
+ self.max_timesteps = max_timesteps
22
+ self.num_inference_steps = num_inference_steps
23
+ self.beta_start = beta_start
24
+ self.beta_end = beta_end
25
+ self.cos_power = cos_power
26
+ self.type = type
27
+
28
+ if type == "linear":
29
+ self.register_linear_schedule()
30
+ elif type == "cosine":
31
+ self.register_cosine_schedule(cos_power)
32
+ elif type == "scaled_linear":
33
+ self.register_scaled_linear_schedule()
34
+
35
+ self.inference_timesteps = self.compute_inference_timesteps()
36
+
37
+ def register_linear_schedule(self):
38
+ # zero terminal SNR (https://arxiv.org/pdf/2305.08891)
39
+ betas = torch.linspace(
40
+ self.beta_start,
41
+ self.beta_end,
42
+ self.max_timesteps,
43
+ dtype=torch.float32,
44
+ )
45
+ alphas = 1.0 - betas
46
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
47
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
48
+
49
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
50
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
51
+
52
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
53
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
54
+
55
+ alphas_bar = alphas_bar_sqrt**2
56
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
57
+ alphas = torch.cat([alphas_bar[0:1], alphas])
58
+ betas = 1 - alphas
59
+
60
+ self.register_buffer(
61
+ "betas",
62
+ betas,
63
+ )
64
+ self.register_buffer("alphas", 1.0 - self.betas)
65
+ self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))
66
+
67
+ def register_cosine_schedule(self, cos_power, s=0.008):
68
+ timesteps = (
69
+ torch.arange(self.max_timesteps + 1, dtype=torch.float32)
70
+ / self.max_timesteps
71
+ )
72
+ alpha_bars = (timesteps + s) / (1 + s) * np.pi / 2
73
+ alpha_bars = torch.cos(alpha_bars).pow(cos_power)
74
+ alpha_bars = alpha_bars / alpha_bars[0]
75
+ betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
76
+ betas = torch.clamp(betas, min=0.0, max=0.999)
77
+
78
+ self.register_buffer(
79
+ "betas",
80
+ betas,
81
+ )
82
+ self.register_buffer("alphas", 1.0 - betas)
83
+ self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))
84
+
85
+ def register_scaled_linear_schedule(self):
86
+ self.register_buffer(
87
+ "betas",
88
+ torch.linspace(
89
+ self.beta_start**0.5,
90
+ self.beta_end**0.5,
91
+ self.max_timesteps,
92
+ dtype=torch.float32,
93
+ )
94
+ ** 2,
95
+ )
96
+ self.register_buffer("alphas", 1.0 - self.betas)
97
+ self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))
98
+
99
+ def compute_inference_timesteps(
100
+ self, num_inference_steps=None, num_train_steps=None
101
+ ):
102
+ # based on diffusers's scheduling code
103
+ if num_inference_steps is None:
104
+ num_inference_steps = self.num_inference_steps
105
+ if num_train_steps is None:
106
+ num_train_steps = self.max_timesteps
107
+ step_ratio = num_train_steps // num_inference_steps
108
+ timesteps = (
109
+ (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(int)
110
+ )
111
+ return timesteps
112
+
113
+ def plot_schedule(self, return_image=False):
114
+ fig = plt.figure(figsize=(6, 4), dpi=100)
115
+ alpha_bars = self.alphas_cumprod.cpu().numpy()
116
+ plt.plot(np.sqrt(alpha_bars))
117
+ plt.grid()
118
+ if self.type == "linear":
119
+ plt.title(
120
+ f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
121
+ )
122
+ else:
+ # cosine or scaled_linear schedule
+ plt.title(f"Cosine (T={self.max_timesteps}, P={self.cos_power})")
125
+ if return_image:
126
+ image = plot_to_image(fig)
127
+ plt.close(fig)
128
+ return image
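
The scheduler only registers the beta/alpha buffers; the actual noising step lives in the training code elsewhere in this upload. For reference, the standard DDPM forward process written against the alphas_cumprod buffer would look roughly like the sketch below (shapes are arbitrary; this is not the repo's training loop).

# Sketch: sample x_t ~ q(x_t | x_0) using NoiseScheduler's registered buffers.
import torch

from diffusionsfm.model.scheduler import NoiseScheduler

scheduler = NoiseScheduler(max_timesteps=1000, num_inference_steps=100)  # linear schedule

x0 = torch.randn(4, 6, 16, 16)                       # clean targets (shape is arbitrary here)
t = torch.randint(0, scheduler.max_timesteps, (4,))  # one timestep per sample
noise = torch.randn_like(x0)

alpha_bar = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
x_t = alpha_bar.sqrt() * x0 + (1.0 - alpha_bar).sqrt() * noise

# Inference uses the subsampled timesteps computed in the constructor.
print(scheduler.inference_timesteps[:5])
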
diffusionsfm/utils/__init__.py ADDED
File without changes
diffusionsfm/utils/configs.py ADDED
@@ -0,0 +1,66 @@
1
+ import argparse
2
+ import os
3
+
4
+ from omegaconf import OmegaConf
5
+
6
+
7
+ def load_cfg(config_path):
8
+ """
9
+ Loads a yaml configuration file.
10
+
11
+ Follows the chain of yaml configuration files that have a `_BASE` key, and updates
12
+ the new keys accordingly. _BASE configurations can be specified using relative
13
+ paths.
14
+ """
15
+ config_dir = os.path.dirname(config_path)
16
+ config_path = os.path.basename(config_path)
17
+ return load_cfg_recursive(config_dir, config_path)
18
+
19
+
20
+ def load_cfg_recursive(config_dir, config_path):
21
+ """
22
+ Recursively loads config files.
23
+
24
+ Follows the chain of yaml configuration files that have a `_BASE` key, and updates
25
+ the new keys accordingly. _BASE configurations can be specified using relative
26
+ paths.
27
+ """
28
+ cfg = OmegaConf.load(os.path.join(config_dir, config_path))
29
+ base_path = OmegaConf.select(cfg, "_BASE", default=None)
30
+ if base_path is not None:
31
+ base_cfg = load_cfg_recursive(config_dir, base_path)
32
+ cfg = OmegaConf.merge(base_cfg, cfg)
33
+ return cfg
34
+
35
+
36
+ def get_cfg():
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--config-path", type=str, required=True)
39
+ args = parser.parse_args()
40
+ cfg = load_cfg(args.config_path)
41
+ print(OmegaConf.to_yaml(cfg))
42
+
43
+ exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag)
44
+ os.makedirs(exp_dir, exist_ok=True)
45
+ to_path = os.path.join(exp_dir, os.path.basename(args.config_path))
46
+ if not os.path.exists(to_path):
47
+ OmegaConf.save(config=cfg, f=to_path)
48
+ return cfg
49
+
50
+
51
+ def get_cfg_from_path(config_path):
52
+ """
53
+ Args:
+ config_path: Path to the yaml configuration file.
55
+ """
56
+ print("getting config from path")
57
+
58
+ cfg = load_cfg(config_path)
59
+ print(OmegaConf.to_yaml(cfg))
60
+
61
+ exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag)
62
+ os.makedirs(exp_dir, exist_ok=True)
63
+ to_path = os.path.join(exp_dir, os.path.basename(config_path))
64
+ if not os.path.exists(to_path):
65
+ OmegaConf.save(config=cfg, f=to_path)
66
+ return cfg
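
To make the `_BASE` chaining concrete, here is a hypothetical pair of configs and how load_cfg would resolve them; the file names and keys below are invented for the example.

# Hypothetical _BASE chain (file names and keys are made up):
#
#   configs/base.yaml
#       training:
#         runs_dir: runs
#         batch_size: 8
#
#   configs/exp_small.yaml
#       _BASE: base.yaml          # resolved relative to configs/
#       training:
#         exp_tag: exp_small
#         batch_size: 4
#
from diffusionsfm.utils.configs import load_cfg

cfg = load_cfg("configs/exp_small.yaml")
# The child config is merged over the base, so:
#   cfg.training.batch_size == 4
#   cfg.training.runs_dir == "runs"
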
diffusionsfm/utils/distortion.py ADDED
@@ -0,0 +1,144 @@
1
+ import cv2
2
+ import ipdb  # noqa: F401
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+
7
+
8
+ # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
9
+ def apply_distortion(pts, k1, k2):
10
+ """
11
+ Arguments:
12
+ pts (N x 2): numpy array in NDC coordinates
13
+ k1, k2 distortion coefficients
14
+ Return:
15
+ pts (N x 2): distorted points in NDC coordinates
16
+ """
17
+ r2 = np.square(pts).sum(-1)
18
+ f = 1 + k1 * r2 + k2 * r2**2
19
+ return f[..., None] * pts
20
+
21
+
22
+ # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
23
+ def apply_distortion_tensor(pts, k1, k2):
24
+ """
25
+ Arguments:
26
+ pts (N x 2): torch tensor in NDC coordinates
27
+ k1, k2 distortion coefficients
28
+ Return:
29
+ pts (N x 2): distorted points in NDC coordinates
30
+ """
31
+ r2 = torch.square(pts).sum(-1)
32
+ f = 1 + k1 * r2 + k2 * r2**2
33
+ return f[..., None] * pts
34
+
35
+
36
+ # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
37
+ def remove_distortion_iter(points, k1, k2):
38
+ """
39
+ Arguments:
40
+ pts (N x 2): distorted points (numpy array) in NDC coordinates
41
+ k1, k2 distortion coefficients
42
+ Return:
43
+ pts (N x 2): undistorted points in NDC coordinates
44
+ """
45
+ pts = ptsd = points
46
+ for _ in range(5):
47
+ r2 = np.square(pts).sum(-1)
48
+ f = 1 + k1 * r2 + k2 * r2**2
49
+ pts = ptsd / f[..., None]
50
+
51
+ return pts
52
+
53
+
54
+ def make_square(im, fill_color=(0, 0, 0)):
55
+ x, y = im.size
56
+ size = max(x, y)
57
+ new_im = Image.new("RGB", (size, size), fill_color)
58
+ corner = (int((size - x) / 2), int((size - y) / 2))
59
+ new_im.paste(im, corner)
60
+ return new_im, corner
61
+
62
+
63
+ def pixel_to_ndc(coords, image_size):
64
+ """
65
+ Converts pixel coordinates to normalized device coordinates (Pytorch3D convention
66
+ with upper left = (1, 1)) for a square image.
67
+
68
+ Args:
69
+ coords: Pixel coordinates UL=(0, 0), LR=(image_size, image_size).
70
+ image_size (int): Image size.
71
+
72
+ Returns:
73
+ NDC coordinates UL=(1, 1) LR=(-1, -1).
74
+ """
75
+ coords = np.array(coords)
76
+ return 1 - coords / image_size * 2
77
+
78
+
79
+ def ndc_to_pixel(coords, image_size):
80
+ """
81
+ Converts normalized device coordinates to pixel coordinates for a square image.
82
+ """
83
+ num_points = coords.shape[0]
84
+ sizes = np.tile(np.array(image_size, dtype=np.float32)[None, ...], (num_points, 1))
85
+
86
+ coords = np.array(coords, dtype=np.float32)
87
+ return (1 - coords) * sizes / 2
88
+
89
+
90
+ def distort_image(image, bbox, k1, k2, modify_bbox=False):
91
+ # We want to operate in -1 to 1 space using the padded square of the original image
92
+ image, corner = make_square(image)
93
+ bbox[:2] += np.array(corner)
94
+ bbox[2:] += np.array(corner)
95
+
96
+ # Construct grid points
97
+ x = np.linspace(1, -1, image.width, dtype=np.float32)
98
+ y = np.linspace(1, -1, image.height, dtype=np.float32)
99
+ x, y = np.meshgrid(x, y, indexing="xy")
100
+ xy_grid = np.stack((x, y), axis=-1)
101
+ points = xy_grid.reshape((image.height * image.width, 2))
102
+ new_points = ndc_to_pixel(apply_distortion(points, k1, k2), image.size)
103
+
104
+ # Distort image by remapping
105
+ map_x = new_points[:, 0].reshape((image.height, image.width))
106
+ map_y = new_points[:, 1].reshape((image.height, image.width))
107
+ distorted = cv2.remap(
108
+ np.asarray(image),
109
+ map_x,
110
+ map_y,
111
+ cv2.INTER_LINEAR,
112
+ )
113
+ distorted = Image.fromarray(distorted)
114
+
115
+ # Find distorted crop bounds - inverse process of above
116
+ if modify_bbox:
117
+ center = (bbox[:2] + bbox[2:]) / 2
118
+ top, bottom = (bbox[0], center[1]), (bbox[2], center[1])
119
+ left, right = (center[0], bbox[1]), (center[0], bbox[3])
120
+ bbox_points = np.array(
121
+ [
122
+ pixel_to_ndc(top, image.size),
123
+ pixel_to_ndc(left, image.size),
124
+ pixel_to_ndc(bottom, image.size),
125
+ pixel_to_ndc(right, image.size),
126
+ ],
127
+ dtype=np.float32,
128
+ )
129
+ else:
130
+ bbox_points = np.array(
131
+ [pixel_to_ndc(bbox[:2], image.size), pixel_to_ndc(bbox[2:], image.size)],
132
+ dtype=np.float32,
133
+ )
134
+
135
+ # Inverse mapping
136
+ distorted_bbox = remove_distortion_iter(bbox_points, k1, k2)
137
+
138
+ if modify_bbox:
139
+ p = ndc_to_pixel(distorted_bbox, image.size)
140
+ distorted_bbox = np.array([p[0][0], p[1][1], p[2][0], p[3][1]])
141
+ else:
142
+ distorted_bbox = ndc_to_pixel(distorted_bbox, image.size).reshape(4)
143
+
144
+ return distorted, distorted_bbox
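
A small usage sketch for distort_image; the image path, crop box, and radial coefficients are placeholders. Note that the function shifts the bbox in place when padding the image to a square, which is why a copy is passed.

# Placeholder usage of distort_image (path, bbox, and k1/k2 are made up).
import numpy as np
from PIL import Image

from diffusionsfm.utils.distortion import distort_image

image = Image.open("example.jpg").convert("RGB")
bbox = np.array([50.0, 60.0, 250.0, 260.0])  # (x0, y0, x1, y1) crop in pixels

distorted, distorted_bbox = distort_image(image, bbox.copy(), k1=-0.2, k2=0.05)
distorted.save("example_distorted.jpg")
print(distorted_bbox)  # the crop box mapped into the distorted image
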
diffusionsfm/utils/distributed.py ADDED
@@ -0,0 +1,31 @@
1
+ import os
2
+ import socket
3
+ from contextlib import closing
4
+
5
+ import torch.distributed as dist
6
+
7
+
8
+ def get_open_port():
9
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
10
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.bind(("", 0))
12
+ return s.getsockname()[1]
13
+
14
+
15
+ # Distributed process group
16
+ def ddp_setup(rank, world_size, port="12345"):
17
+ """
18
+ Args:
19
+ rank: Unique identifier of the current process.
+ world_size: Total number of processes.
21
+ """
22
+ os.environ["MASTER_ADDR"] = "localhost"
23
+ print(f"MasterPort: {str(port)}")
24
+ os.environ["MASTER_PORT"] = str(port)
25
+
26
+ # initialize the process group
27
+ dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
28
+
29
+
30
+ def cleanup():
31
+ dist.destroy_process_group()
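
These helpers are intended to be used with torch.multiprocessing; a minimal launch sketch follows (the worker body is a stub, and one GPU per process is assumed).

# Minimal DDP launch sketch using the helpers above (worker body is a stub).
import torch
import torch.multiprocessing as mp

from diffusionsfm.utils.distributed import cleanup, ddp_setup, get_open_port


def worker(rank, world_size, port):
    ddp_setup(rank, world_size, port=port)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it with DistributedDataParallel, run training ...
    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size, get_open_port()), nprocs=world_size)
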
diffusionsfm/utils/geometry.py ADDED
@@ -0,0 +1,145 @@
1
+ import numpy as np
2
+ import torch
3
+ from pytorch3d.renderer import FoVPerspectiveCameras
4
+ from pytorch3d.transforms import quaternion_to_matrix
5
+
6
+
7
+ def generate_random_rotations(N=1, device="cpu"):
8
+ q = torch.randn(N, 4, device=device)
9
+ q = q / q.norm(dim=-1, keepdim=True)
10
+ return quaternion_to_matrix(q)
11
+
12
+
13
+ def symmetric_orthogonalization(x):
14
+ """Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
15
+
16
+ x: should have size [batch_size, 9]
17
+
18
+ Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3).
19
+ """
20
+ m = x.view(-1, 3, 3)
21
+ u, s, v = torch.svd(m)
22
+ vt = torch.transpose(v, 1, 2)
23
+ det = torch.det(torch.matmul(u, vt))
24
+ det = det.view(-1, 1, 1)
25
+ vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1)
26
+ r = torch.matmul(u, vt)
27
+ return r
28
+
29
+
30
+ def get_permutations(num_images):
31
+ permutations = []
32
+ for i in range(0, num_images):
33
+ for j in range(0, num_images):
34
+ if i != j:
35
+ permutations.append((j, i))
36
+
37
+ return permutations
38
+
39
+
40
+ def n_to_np_rotations(num_frames, n_rots):
41
+ R_pred_rel = []
42
+ permutations = get_permutations(num_frames)
43
+ for i, j in permutations:
44
+ R_pred_rel.append(n_rots[i].T @ n_rots[j])
45
+ R_pred_rel = torch.stack(R_pred_rel)
46
+
47
+ return R_pred_rel
48
+
49
+
50
+ def compute_angular_error_batch(rotation1, rotation2):
51
+ R_rel = np.einsum("Bij,Bjk ->Bik", rotation2, rotation1.transpose(0, 2, 1))
52
+ t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
53
+ theta = np.arccos(np.clip(t, -1, 1))
54
+ return theta * 180 / np.pi
55
+
56
+
57
+ # A should be GT, B should be predicted
58
+ def compute_optimal_alignment(A, B):
59
+ """
60
+ Compute the optimal scale s, rotation R, and translation t that minimizes:
61
+ || A - (s * B @ R + T) || ^ 2
62
+
63
+ Reference: Umeyama (TPAMI 91)
64
+
65
+ Args:
66
+ A (torch.Tensor): (N, 3).
67
+ B (torch.Tensor): (N, 3).
68
+
69
+ Returns:
70
+ s (float): scale.
71
+ R (torch.Tensor): rotation matrix (3, 3).
72
+ t (torch.Tensor): translation (3,).
73
+ """
74
+ A_bar = A.mean(0)
75
+ B_bar = B.mean(0)
76
+ # normally with R @ B, this would be A @ B.T
77
+ H = (B - B_bar).T @ (A - A_bar)
78
+ U, S, Vh = torch.linalg.svd(H, full_matrices=True)
79
+ s = torch.linalg.det(U @ Vh)
80
+ S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
81
+ variance = torch.sum((B - B_bar) ** 2)
82
+ scale = 1 / variance * torch.trace(torch.diag(S) @ S_prime)
83
+ R = U @ S_prime @ Vh
84
+ t = A_bar - scale * B_bar @ R
85
+
86
+ A_hat = scale * B @ R + t
87
+ return A_hat, scale, R, t
88
+
89
+
90
+ def compute_optimal_translation_alignment(T_A, T_B, R_B):
91
+ """
92
+ Assuming right-multiplied rotation matrices.
93
+
94
+ E.g., for world2cam R and T, a world coordinate is transformed to camera coordinate
95
+ system using X_cam = X_world.T @ R + T = R.T @ X_world + T
96
+
97
+ Finds s, t that minimizes || T_A - (s * T_B + R_B.T @ t) ||^2
98
+
99
+ Args:
100
+ T_A (torch.Tensor): Target translation (N, 3).
101
+ T_B (torch.Tensor): Initial translation (N, 3).
102
+ R_B (torch.Tensor): Initial rotation (N, 3, 3).
103
+
104
+ Returns:
105
+ T_A_hat (torch.Tensor): s * T_B + t @ R_B (N, 3).
106
+ scale s (torch.Tensor): (1,).
107
+ translation t (torch.Tensor): (1, 3).
108
+ """
109
+ n = len(T_A)
110
+
111
+ T_A = T_A.unsqueeze(2)
112
+ T_B = T_B.unsqueeze(2)
113
+
114
+ A = torch.sum(T_B * T_A)
115
+ B = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) @ (R_B @ T_A).sum(0) / n
116
+ C = torch.sum(T_B * T_B)
117
+ D = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0)
118
+ E = (D * D).sum() / n
119
+
120
+ s = (A - B.sum()) / (C - E.sum())
121
+
122
+ t = (R_B @ (T_A - s * T_B)).sum(0) / n
123
+
124
+ T_A_hat = s * T_B + R_B.transpose(1, 2) @ t
125
+
126
+ return T_A_hat.squeeze(2), s, t.transpose(1, 0)
127
+
128
+
129
+ def get_error(predict_rotations, R_pred, T_pred, R_gt, T_gt, gt_scene_scale):
130
+ if predict_rotations:
131
+ cameras_gt = FoVPerspectiveCameras(R=R_gt, T=T_gt)
132
+ cc_gt = cameras_gt.get_camera_center()
133
+ cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred)
134
+ cc_pred = cameras_pred.get_camera_center()
135
+
136
+ A_hat, _, _, _ = compute_optimal_alignment(cc_gt, cc_pred)
137
+ norm = torch.linalg.norm(cc_gt - A_hat, dim=1) / gt_scene_scale
138
+
139
+ norms = np.ndarray.tolist(norm.detach().cpu().numpy())
140
+ return norms, A_hat
141
+ else:
142
+ T_A_hat, _, _ = compute_optimal_translation_alignment(T_gt, T_pred, R_pred)
143
+ norm = torch.linalg.norm(T_gt - T_A_hat, dim=1) / gt_scene_scale
144
+ norms = np.ndarray.tolist(norm.detach().cpu().numpy())
145
+ return norms, T_A_hat
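
Finally, a toy sketch of how these utilities combine when scoring predicted cameras against ground truth; the tensors are random stand-ins, so the printed numbers are meaningless, and PyTorch3D must be installed for get_error.

# Toy evaluation sketch: pairwise rotation error plus aligned camera-center error.
import torch

from diffusionsfm.utils.geometry import (
    compute_angular_error_batch,
    generate_random_rotations,
    get_error,
    n_to_np_rotations,
)

N = 8
R_gt, R_pred = generate_random_rotations(N), generate_random_rotations(N)
T_gt, T_pred = torch.randn(N, 3), torch.randn(N, 3)

# Relative-rotation angular errors (degrees) over all ordered pairs.
rel_gt = n_to_np_rotations(N, R_gt).numpy()
rel_pred = n_to_np_rotations(N, R_pred).numpy()
rotation_errors = compute_angular_error_batch(rel_pred, rel_gt)

# Camera-center errors after the optimal similarity alignment.
center_errors, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale=1.0)

print(rotation_errors.mean(), sum(center_errors) / len(center_errors))
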