Upload 57 files
Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full file list.
- .gitattributes +22 -0
- LICENSE +21 -0
- assets/demo.png +3 -0
- conf/config.yaml +83 -0
- conf/diffusion.yml +110 -0
- data/demo/jellycat/001.jpg +3 -0
- data/demo/jellycat/002.jpg +3 -0
- data/demo/jellycat/003.jpg +3 -0
- data/demo/jellycat/004.jpg +3 -0
- data/demo/jordan/001.png +3 -0
- data/demo/jordan/002.png +3 -0
- data/demo/jordan/003.png +3 -0
- data/demo/jordan/004.png +3 -0
- data/demo/jordan/005.png +3 -0
- data/demo/jordan/006.png +3 -0
- data/demo/jordan/007.png +3 -0
- data/demo/jordan/008.png +3 -0
- data/demo/kew_gardens_ruined_arch/001.jpeg +3 -0
- data/demo/kew_gardens_ruined_arch/002.jpeg +3 -0
- data/demo/kew_gardens_ruined_arch/003.jpeg +3 -0
- data/demo/kotor_cathedral/001.jpeg +3 -0
- data/demo/kotor_cathedral/002.jpeg +3 -0
- data/demo/kotor_cathedral/003.jpeg +3 -0
- data/demo/kotor_cathedral/004.jpeg +3 -0
- data/demo/kotor_cathedral/005.jpeg +3 -0
- data/demo/kotor_cathedral/006.jpeg +3 -0
- diffusionsfm/__init__.py +1 -0
- diffusionsfm/dataset/__init__.py +0 -0
- diffusionsfm/dataset/co3d_v2.py +792 -0
- diffusionsfm/dataset/custom.py +105 -0
- diffusionsfm/eval/__init__.py +0 -0
- diffusionsfm/eval/eval_category.py +292 -0
- diffusionsfm/eval/eval_jobs.py +175 -0
- diffusionsfm/inference/__init__.py +0 -0
- diffusionsfm/inference/ddim.py +145 -0
- diffusionsfm/inference/load_model.py +97 -0
- diffusionsfm/inference/predict.py +93 -0
- diffusionsfm/model/base_model.py +16 -0
- diffusionsfm/model/blocks.py +247 -0
- diffusionsfm/model/diffuser.py +195 -0
- diffusionsfm/model/diffuser_dpt.py +331 -0
- diffusionsfm/model/dit.py +428 -0
- diffusionsfm/model/feature_extractors.py +176 -0
- diffusionsfm/model/memory_efficient_attention.py +51 -0
- diffusionsfm/model/scheduler.py +128 -0
- diffusionsfm/utils/__init__.py +0 -0
- diffusionsfm/utils/configs.py +66 -0
- diffusionsfm/utils/distortion.py +144 -0
- diffusionsfm/utils/distributed.py +31 -0
- diffusionsfm/utils/geometry.py +145 -0
.gitattributes
CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jellycat/001.jpg filter=lfs diff=lfs merge=lfs -text
+data/demo/jellycat/002.jpg filter=lfs diff=lfs merge=lfs -text
+data/demo/jellycat/003.jpg filter=lfs diff=lfs merge=lfs -text
+data/demo/jellycat/004.jpg filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/001.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/002.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/003.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/004.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/005.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/006.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/007.png filter=lfs diff=lfs merge=lfs -text
+data/demo/jordan/008.png filter=lfs diff=lfs merge=lfs -text
+data/demo/kew_gardens_ruined_arch/001.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kew_gardens_ruined_arch/002.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kew_gardens_ruined_arch/003.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kotor_cathedral/001.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kotor_cathedral/002.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kotor_cathedral/003.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kotor_cathedral/004.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kotor_cathedral/005.jpeg filter=lfs diff=lfs merge=lfs -text
+data/demo/kotor_cathedral/006.jpeg filter=lfs diff=lfs merge=lfs -text
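(Each added line above routes the matching binary asset through the Git LFS filter instead of storing it directly in the repository; this is standard Git LFS behavior for image files.)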
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Qitao Zhao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
assets/demo.png
ADDED
(binary image tracked with Git LFS)
conf/config.yaml
ADDED
@@ -0,0 +1,83 @@
training:
  resume: False  # If True, must set hydra.run.dir accordingly
  pretrain_path: ""
  interval_visualize: 1000
  interval_save_checkpoint: 5000
  interval_delete_checkpoint: 10000
  interval_evaluate: 5000
  delete_all_checkpoints_after_training: False
  lr: 1e-4
  mixed_precision: True
  matmul_precision: high
  max_iterations: 100000
  batch_size: 64
  num_workers: 8
  gpu_id: 0
  freeze_encoder: True
  seed: 0
  job_key: ""  # Use this for submitit sweeps where timestamps might collide
  translation_scale: 1.0
  regression: False
  prob_unconditional: 0
  load_extra_cameras: False
  calculate_intrinsics: False
  distort: False
  normalize_first_camera: True
  diffuse_origins_and_endpoints: True
  diffuse_depths: False
  depth_resolution: 1
  dpt_head: False
  full_num_patches_x: 16
  full_num_patches_y: 16
  dpt_encoder_features: True
  nearest_neighbor: True
  no_bg_targets: True
  unit_normalize_scene: False
  sd_scale: 2
  bfloat: True
  first_cam_mediod: True
  gradient_clipping: False
  l1_loss: False
  grad_accumulation: False
  reinit: False

model:
  pred_x0: True
  model_type: dit
  num_patches_x: 16
  num_patches_y: 16
  depth: 16
  num_images: 1
  random_num_images: True
  feature_extractor: dino
  append_ndc: True
  within_image: False
  use_homogeneous: True
  freeze_transformer: False
  cond_depth_mask: True

noise_scheduler:
  type: linear
  max_timesteps: 100
  beta_start: 0.0120
  beta_end: 0.00085
  marigold_ddim: False

dataset:
  name: co3d
  shape: all_train
  apply_augmentation: True
  use_global_intrinsics: True
  mask_holes: True
  image_size: 224

debug:
  wandb: True
  project_name: diffusionsfm
  run_name:
  anomaly_detection: False

hydra:
  run:
    dir: ./output/${now:%m%d_%H%M%S_%f}${training.job_key}
  output_subdir: hydra
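The file above is a Hydra config. As an illustrative sketch only (not part of this commit, and assuming a recent hydra-core), it could be composed and overridden like this; the override values are placeholders:

# Illustrative sketch: composing conf/config.yaml with Hydra's compose API.
# The override values below are placeholders, not values used by this repository.
from hydra import compose, initialize

with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config",
        overrides=["training.batch_size=32", "model.num_images=8"],
    )

print(cfg.training.lr)              # 1e-4
print(cfg.model.feature_extractor)  # dino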
conf/diffusion.yml
ADDED
@@ -0,0 +1,110 @@
name: diffusion
channels:
  - conda-forge
  - iopath
  - nvidia
  - pkgs/main
  - pytorch
  - xformers
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py39h5a03fae_9
  - bzip2=1.0.8=h7f98852_4
  - ca-certificates=2023.7.22=hbcca054_0
  - certifi=2023.7.22=pyhd8ed1ab_0
  - charset-normalizer=3.2.0=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - cuda-cudart=11.7.99=0
  - cuda-cupti=11.7.101=0
  - cuda-libraries=11.7.1=0
  - cuda-nvrtc=11.7.99=0
  - cuda-nvtx=11.7.91=0
  - cuda-runtime=11.7.1=0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.12.2=pyhd8ed1ab_0
  - freetype=2.12.1=hca18f0e_1
  - fvcore=0.1.5.post20221221=pyhd8ed1ab_0
  - gmp=6.2.1=h58526e2_0
  - gmpy2=2.1.2=py39h376b7d2_1
  - gnutls=3.6.13=h85f3911_1
  - idna=3.4=pyhd8ed1ab_0
  - intel-openmp=2022.1.0=h9e868ea_3769
  - iopath=0.1.9=py39
  - jinja2=3.1.2=pyhd8ed1ab_1
  - jpeg=9e=h0b41bf4_3
  - lame=3.100=h166bdaf_1003
  - lcms2=2.15=hfd0df8a_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - lerc=4.0.0=h27087fc_0
  - libblas=3.9.0=16_linux64_mkl
  - libcblas=3.9.0=16_linux64_mkl
  - libcublas=11.10.3.66=0
  - libcufft=10.7.2.124=h4fbf590_0
  - libcufile=1.7.1.12=0
  - libcurand=10.3.3.129=0
  - libcusolver=11.4.0.1=0
  - libcusparse=11.7.4.91=0
  - libdeflate=1.17=h0b41bf4_0
  - libffi=3.3=h58526e2_2
  - libgcc-ng=13.1.0=he5830b7_0
  - libgomp=13.1.0=he5830b7_0
  - libiconv=1.17=h166bdaf_0
  - liblapack=3.9.0=16_linux64_mkl
  - libnpp=11.7.4.75=0
  - libnvjpeg=11.8.0.2=0
  - libpng=1.6.39=h753d276_0
  - libsqlite=3.42.0=h2797004_0
  - libstdcxx-ng=13.1.0=hfd8a6a1_0
  - libtiff=4.5.0=h6adf6a1_2
  - libwebp-base=1.3.1=hd590300_0
  - libxcb=1.13=h7f98852_1004
  - libzlib=1.2.13=hd590300_5
  - markupsafe=2.1.3=py39hd1e30aa_0
  - mkl=2022.1.0=hc2b9512_224
  - mpc=1.3.1=hfe3b2da_0
  - mpfr=4.2.0=hb012696_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - ncurses=6.4=hcb278e6_0
  - nettle=3.6=he412f7d_0
  - networkx=3.1=pyhd8ed1ab_0
  - numpy=1.25.2=py39h6183b62_0
  - openh264=2.1.1=h780b84a_0
  - openjpeg=2.5.0=hfec8fc6_2
  - openssl=1.1.1v=hd590300_0
  - pillow=9.4.0=py39h2320bf1_1
  - pip=23.2.1=pyhd8ed1ab_0
  - portalocker=2.7.0=py39hf3d152e_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.0=hffdb5ce_5_cpython
  - python_abi=3.9=3_cp39
  - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
  - pytorch-cuda=11.7=h778d358_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0=py39hb9d737c_5
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - setuptools=68.0.0=pyhd8ed1ab_0
  - sqlite=3.42.0=h2c6b66d_0
  - sympy=1.12=pypyh9d50eac_103
  - tabulate=0.9.0=pyhd8ed1ab_1
  - termcolor=2.3.0=pyhd8ed1ab_0
  - tk=8.6.12=h27826a3_0
  - torchaudio=2.0.2=py39_cu117
  - torchtriton=2.0.0=py39
  - torchvision=0.15.2=py39_cu117
  - tqdm=4.66.1=pyhd8ed1ab_0
  - typing_extensions=4.7.1=pyha770c72_0
  - tzdata=2023c=h71feb2d_0
  - urllib3=2.0.4=pyhd8ed1ab_0
  - wheel=0.41.1=pyhd8ed1ab_0
  - xformers=0.0.21=py39_cu11.8.0_pyt2.0.1
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xz=5.2.6=h166bdaf_0
  - yacs=0.1.8=pyhd8ed1ab_0
  - yaml=0.2.5=h7f98852_2
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.2=hfc55251_7
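(This is a pinned conda environment specification; it is normally instantiated with the standard command conda env create -f conf/diffusion.yml. That usage is inferred from the file format and is not spelled out in this commit.)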
The following demo images were ADDED (each is a binary file tracked with Git LFS):
- data/demo/jellycat/001.jpg through 004.jpg
- data/demo/jordan/001.png through 008.png
- data/demo/kew_gardens_ruined_arch/001.jpeg through 003.jpeg
- data/demo/kotor_cathedral/001.jpeg through 006.jpeg
diffusionsfm/__init__.py
ADDED
@@ -0,0 +1 @@
from .utils.rays import cameras_to_rays, rays_to_cameras, Rays
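The package root re-exports the ray utilities. A minimal, illustrative import (the semantics of these helpers are inferred from their names and from the call sites in eval_category.py later in this commit, not spelled out here):

# Illustrative only: the symbols re-exported by diffusionsfm/__init__.py.
# cameras_to_rays / rays_to_cameras presumably convert between PyTorch3D
# PerspectiveCameras and the Rays representation used by the model.
from diffusionsfm import cameras_to_rays, rays_to_cameras, Rays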
diffusionsfm/dataset/__init__.py
ADDED
(empty file)
diffusionsfm/dataset/co3d_v2.py
ADDED
@@ -0,0 +1,792 @@
import gzip
import json
import os.path as osp
import random
import socket
import time
import torch
import warnings

import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from pytorch3d.renderer import PerspectiveCameras
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from scipy import ndimage as nd

from diffusionsfm.utils.distortion import distort_image


HOSTNAME = socket.gethostname()

CO3D_DIR = "../co3d_data"  # update this
CO3D_ANNOTATION_DIR = osp.join(CO3D_DIR, "co3d_annotations")
CO3D_DIR = CO3D_DEPTH_DIR = osp.join(CO3D_DIR, "co3d")
order_path = osp.join(
    CO3D_DIR, "co3d_v2_random_order_{sample_num}/{category}.json"
)


TRAINING_CATEGORIES = [
    "apple",
    "backpack",
    "banana",
    "baseballbat",
    "baseballglove",
    "bench",
    "bicycle",
    "bottle",
    "bowl",
    "broccoli",
    "cake",
    "car",
    "carrot",
    "cellphone",
    "chair",
    "cup",
    "donut",
    "hairdryer",
    "handbag",
    "hydrant",
    "keyboard",
    "laptop",
    "microwave",
    "motorcycle",
    "mouse",
    "orange",
    "parkingmeter",
    "pizza",
    "plant",
    "stopsign",
    "teddybear",
    "toaster",
    "toilet",
    "toybus",
    "toyplane",
    "toytrain",
    "toytruck",
    "tv",
    "umbrella",
    "vase",
    "wineglass",
]

TEST_CATEGORIES = [
    "ball",
    "book",
    "couch",
    "frisbee",
    "hotdog",
    "kite",
    "remote",
    "sandwich",
    "skateboard",
    "suitcase",
]

assert len(TRAINING_CATEGORIES) + len(TEST_CATEGORIES) == 51

Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True


def fill_depths(data, invalid=None):
    data_list = []
    for i in range(data.shape[0]):
        data_item = data[i].numpy()
        # Invalid must be 1 where stuff is invalid, 0 where valid
        ind = nd.distance_transform_edt(
            invalid[i], return_distances=False, return_indices=True
        )
        data_list.append(torch.tensor(data_item[tuple(ind)]))
    return torch.stack(data_list, dim=0)


def full_scene_scale(batch):
    cameras = PerspectiveCameras(R=batch["R"], T=batch["T"], device="cuda")
    cc = cameras.get_camera_center()
    centroid = torch.mean(cc, dim=0)

    diffs = cc - centroid
    norms = torch.linalg.norm(diffs, dim=1)

    furthest_index = torch.argmax(norms).item()
    scale = norms[furthest_index].item()
    return scale


def square_bbox(bbox, padding=0.0, astype=None, tight=False):
    """
    Computes a square bounding box, with optional padding parameters.
    Args:
        bbox: Bounding box in xyxy format (4,).
    Returns:
        square_bbox in xyxy format (4,).
    """
    if astype is None:
        astype = type(bbox[0])
    bbox = np.array(bbox)
    center = (bbox[:2] + bbox[2:]) / 2
    extents = (bbox[2:] - bbox[:2]) / 2

    # No black bars if tight
    if tight:
        s = min(extents) * (1 + padding)
    else:
        s = max(extents) * (1 + padding)

    square_bbox = np.array(
        [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
        dtype=astype,
    )
    return square_bbox


def unnormalize_image(image, return_numpy=True, return_int=True):
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()

    if image.ndim == 3:
        if image.shape[0] == 3:
            image = image[None, ...]
        elif image.shape[2] == 3:
            image = image.transpose(2, 0, 1)[None, ...]
        else:
            raise ValueError(f"Unexpected image shape: {image.shape}")
    elif image.ndim == 4:
        if image.shape[1] == 3:
            pass
        elif image.shape[3] == 3:
            image = image.transpose(0, 3, 1, 2)
        else:
            raise ValueError(f"Unexpected batch image shape: {image.shape}")
    else:
        raise ValueError(f"Unsupported input shape: {image.shape}")

    mean = np.array([0.485, 0.456, 0.406])[None, :, None, None]
    std = np.array([0.229, 0.224, 0.225])[None, :, None, None]
    image = image * std + mean

    if return_int:
        image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
    else:
        image = np.clip(image, 0.0, 1.0)

    if image.shape[0] == 1:
        image = image[0]

    if return_numpy:
        return image
    else:
        return torch.from_numpy(image)


def unnormalize_image_for_vis(image):
    assert len(image.shape) == 5 and image.shape[2] == 3
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1).to(image.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1).to(image.device)
    image = image * std + mean
    image = (image - 0.5) / 0.5
    return image


def _transform_intrinsic(image, bbox, principal_point, focal_length):
    # Rescale intrinsics to match bbox
    half_box = np.array([image.width, image.height]).astype(np.float32) / 2
    org_scale = min(half_box).astype(np.float32)

    # Pixel coordinates
    principal_point_px = half_box - (np.array(principal_point) * org_scale)
    focal_length_px = np.array(focal_length) * org_scale
    principal_point_px -= bbox[:2]
    new_bbox = (bbox[2:] - bbox[:2]) / 2
    new_scale = min(new_bbox)

    # NDC coordinates
    new_principal_ndc = (new_bbox - principal_point_px) / new_scale
    new_focal_ndc = focal_length_px / new_scale

    principal_point = torch.tensor(new_principal_ndc.astype(np.float32))
    focal_length = torch.tensor(new_focal_ndc.astype(np.float32))

    return principal_point, focal_length


def construct_camera_from_batch(batch, device):
    if isinstance(device, int):
        device = f"cuda:{device}"

    return PerspectiveCameras(
        R=batch["R"].reshape(-1, 3, 3),
        T=batch["T"].reshape(-1, 3),
        focal_length=batch["focal_lengths"].reshape(-1, 2),
        principal_point=batch["principal_points"].reshape(-1, 2),
        image_size=batch["image_sizes"].reshape(-1, 2),
        device=device,
    )


def save_batch_images(images, fname):
    cmap = plt.get_cmap("hsv")
    num_frames = len(images)
    num_rows = len(images)
    num_cols = 4
    figsize = (num_cols * 2, num_rows * 2)
    fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
    axs = axs.flatten()
    for i in range(num_rows):
        for j in range(4):
            if i < num_frames:
                axs[i * 4 + j].imshow(unnormalize_image(images[i][j]))
                for s in ["bottom", "top", "left", "right"]:
                    axs[i * 4 + j].spines[s].set_color(cmap(i / (num_frames)))
                    axs[i * 4 + j].spines[s].set_linewidth(5)
                axs[i * 4 + j].set_xticks([])
                axs[i * 4 + j].set_yticks([])
            else:
                axs[i * 4 + j].axis("off")
    plt.tight_layout()
    plt.savefig(fname)


def jitter_bbox(
    square_bbox,
    jitter_scale=(1.1, 1.2),
    jitter_trans=(-0.07, 0.07),
    direction_from_size=None,
):

    square_bbox = np.array(square_bbox.astype(float))
    s = np.random.uniform(jitter_scale[0], jitter_scale[1])

    # Jitter only one dimension if center cropping
    tx, ty = np.random.uniform(jitter_trans[0], jitter_trans[1], size=2)
    if direction_from_size is not None:
        if direction_from_size[0] > direction_from_size[1]:
            tx = 0
        else:
            ty = 0

    side_length = square_bbox[2] - square_bbox[0]
    center = (square_bbox[:2] + square_bbox[2:]) / 2 + np.array([tx, ty]) * side_length
    extent = side_length / 2 * s
    ul = center - extent
    lr = ul + 2 * extent
    return np.concatenate((ul, lr))


class Co3dDataset(Dataset):
    def __init__(
        self,
        category=("all_train",),
        split="train",
        transform=None,
        num_images=2,
        img_size=224,
        mask_images=False,
        crop_images=True,
        co3d_dir=None,
        co3d_annotation_dir=None,
        precropped_images=False,
        apply_augmentation=True,
        normalize_cameras=True,
        no_images=False,
        sample_num=None,
        seed=0,
        load_extra_cameras=False,
        distort_image=False,
        load_depths=False,
        center_crop=False,
        depth_size=256,
        mask_holes=False,
        object_mask=True,
    ):
        """
        Args:
            num_images: Number of images in each batch.
            perspective_correction (str):
                "none": No perspective correction.
                "warp": Warp the image and label.
                "label_only": Correct the label only.
        """
        start_time = time.time()

        self.category = category
        self.split = split
        self.transform = transform
        self.num_images = num_images
        self.img_size = img_size
        self.mask_images = mask_images
        self.crop_images = crop_images
        self.precropped_images = precropped_images
        self.apply_augmentation = apply_augmentation
        self.normalize_cameras = normalize_cameras
        self.no_images = no_images
        self.sample_num = sample_num
        self.load_extra_cameras = load_extra_cameras
        self.distort = distort_image
        self.load_depths = load_depths
        self.center_crop = center_crop
        self.depth_size = depth_size
        self.mask_holes = mask_holes
        self.object_mask = object_mask

        if self.apply_augmentation:
            if self.center_crop:
                self.jitter_scale = (0.8, 1.1)
                self.jitter_trans = (0.0, 0.0)
            else:
                self.jitter_scale = (1.1, 1.2)
                self.jitter_trans = (-0.07, 0.07)
        else:
            # Note if trained with apply_augmentation, we should still use
            # apply_augmentation at test time.
            self.jitter_scale = (1, 1)
            self.jitter_trans = (0.0, 0.0)

        if self.distort:
            self.k1_max = 1.0
            self.k2_max = 1.0

        if co3d_dir is not None:
            self.co3d_dir = co3d_dir
            self.co3d_annotation_dir = co3d_annotation_dir
        else:
            self.co3d_dir = CO3D_DIR
            self.co3d_annotation_dir = CO3D_ANNOTATION_DIR
            self.co3d_depth_dir = CO3D_DEPTH_DIR

        if isinstance(self.category, str):
            self.category = [self.category]

        if "all_train" in self.category:
            self.category = TRAINING_CATEGORIES
        if "all_test" in self.category:
            self.category = TEST_CATEGORIES
        if "full" in self.category:
            self.category = TRAINING_CATEGORIES + TEST_CATEGORIES
        self.category = sorted(self.category)
        self.is_single_category = len(self.category) == 1

        # Fixing seed
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        print(f"Co3d ({split}):")

        self.low_quality_translations = [
            "411_55952_107659",
            "427_59915_115716",
            "435_61970_121848",
            "112_13265_22828",
            "110_13069_25642",
            "165_18080_34378",
            "368_39891_78502",
            "391_47029_93665",
            "20_695_1450",
            "135_15556_31096",
            "417_57572_110680",
        ]  # Initialized with sequences with poor depth masks
        self.rotations = {}
        self.category_map = {}
        for c in tqdm(self.category):
            annotation_file = osp.join(
                self.co3d_annotation_dir, f"{c}_{self.split}.jgz"
            )
            with gzip.open(annotation_file, "r") as fin:
                annotation = json.loads(fin.read())

            counter = 0
            for seq_name, seq_data in annotation.items():
                counter += 1
                if len(seq_data) < self.num_images:
                    continue

                filtered_data = []
                self.category_map[seq_name] = c
                bad_seq = False
                for data in seq_data:
                    # Make sure translations are not ridiculous and rotations are valid
                    det = np.linalg.det(data["R"])
                    if (np.abs(data["T"]) > 1e5).any() or det < 0.99 or det > 1.01:
                        bad_seq = True
                        self.low_quality_translations.append(seq_name)
                        break

                    # Ignore all unnecessary information.
                    filtered_data.append(
                        {
                            "filepath": data["filepath"],
                            "bbox": data["bbox"],
                            "R": data["R"],
                            "T": data["T"],
                            "focal_length": data["focal_length"],
                            "principal_point": data["principal_point"],
                        },
                    )

                if not bad_seq:
                    self.rotations[seq_name] = filtered_data

        self.sequence_list = list(self.rotations.keys())

        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        if self.transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize(self.img_size, antialias=True),
                    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
                ]
            )

        self.transform_depth = transforms.Compose(
            [
                transforms.Resize(
                    self.depth_size,
                    antialias=False,
                    interpolation=transforms.InterpolationMode.NEAREST_EXACT,
                ),
            ]
        )

        print(
            f"Low quality translation sequences, not used: {self.low_quality_translations}"
        )
        print(f"Data size: {len(self)}")
        print(f"Data loading took {(time.time()-start_time)} seconds.")

    def __len__(self):
        return len(self.sequence_list)

    def __getitem__(self, index):
        num_to_load = self.num_images if not self.load_extra_cameras else 8

        sequence_name = self.sequence_list[index % len(self.sequence_list)]
        metadata = self.rotations[sequence_name]

        if self.sample_num is not None:
            with open(
                order_path.format(sample_num=self.sample_num, category=self.category[0])
            ) as f:
                order = json.load(f)
            ids = order[sequence_name][:num_to_load]
        else:
            replace = len(metadata) < 8
            ids = np.random.choice(len(metadata), num_to_load, replace=replace)

        return self.get_data(index=index, ids=ids, num_valid_frames=num_to_load)

    def _get_scene_scale(self, sequence_name):
        n = len(self.rotations[sequence_name])

        R = torch.zeros(n, 3, 3)
        T = torch.zeros(n, 3)

        for i, ann in enumerate(self.rotations[sequence_name]):
            R[i, ...] = torch.tensor(self.rotations[sequence_name][i]["R"])
            T[i, ...] = torch.tensor(self.rotations[sequence_name][i]["T"])

        cameras = PerspectiveCameras(R=R, T=T)
        cc = cameras.get_camera_center()
        centeroid = torch.mean(cc, dim=0)
        diff = cc - centeroid

        norm = torch.norm(diff, dim=1)
        scale = torch.max(norm).item()

        return scale

    def _crop_image(self, image, bbox):
        image_crop = transforms.functional.crop(
            image,
            top=bbox[1],
            left=bbox[0],
            height=bbox[3] - bbox[1],
            width=bbox[2] - bbox[0],
        )
        return image_crop

    def _transform_intrinsic(self, image, bbox, principal_point, focal_length):
        half_box = np.array([image.width, image.height]).astype(np.float32) / 2
        org_scale = min(half_box).astype(np.float32)

        # Pixel coordinates
        principal_point_px = half_box - (np.array(principal_point) * org_scale)
        focal_length_px = np.array(focal_length) * org_scale
        principal_point_px -= bbox[:2]
        new_bbox = (bbox[2:] - bbox[:2]) / 2
        new_scale = min(new_bbox)

        # NDC coordinates
        new_principal_ndc = (new_bbox - principal_point_px) / new_scale
        new_focal_ndc = focal_length_px / new_scale

        return new_principal_ndc.astype(np.float32), new_focal_ndc.astype(np.float32)

    def get_data(
        self,
        index=None,
        sequence_name=None,
        ids=(0, 1),
        no_images=False,
        num_valid_frames=None,
        load_using_order=None,
    ):
        if load_using_order is not None:
            with open(
                order_path.format(sample_num=self.sample_num, category=self.category[0])
            ) as f:
                order = json.load(f)
            ids = order[sequence_name][:load_using_order]

        if sequence_name is None:
            index = index % len(self.sequence_list)
            sequence_name = self.sequence_list[index]
        metadata = self.rotations[sequence_name]
        category = self.category_map[sequence_name]

        # Read image & camera information from annotations
        annos = [metadata[i] for i in ids]
        images = []
        image_sizes = []
        PP = []
        FL = []
        crop_parameters = []
        filenames = []
        distortion_parameters = []
        depths = []
        depth_masks = []
        object_masks = []
        dino_images = []
        for anno in annos:
            filepath = anno["filepath"]

            if not no_images:
                image = Image.open(osp.join(self.co3d_dir, filepath)).convert("RGB")
                image_size = image.size

                # Optionally mask images with black background
                if self.mask_images:
                    black_image = Image.new("RGB", image_size, (0, 0, 0))
                    mask_name = osp.basename(filepath.replace(".jpg", ".png"))

                    mask_path = osp.join(
                        self.co3d_dir, category, sequence_name, "masks", mask_name
                    )
                    mask = Image.open(mask_path).convert("L")

                    if mask.size != image_size:
                        mask = mask.resize(image_size)
                    mask = Image.fromarray(np.array(mask) > 125)
                    image = Image.composite(image, black_image, mask)

                if self.object_mask:
                    mask_name = osp.basename(filepath.replace(".jpg", ".png"))
                    mask_path = osp.join(
                        self.co3d_dir, category, sequence_name, "masks", mask_name
                    )
                    mask = Image.open(mask_path).convert("L")

                    if mask.size != image_size:
                        mask = mask.resize(image_size)
                    mask = torch.from_numpy(np.array(mask) > 125)

                # Determine crop, Resnet wants square images
                bbox = np.array(anno["bbox"])
                good_bbox = ((bbox[2:] - bbox[:2]) > 30).all()
                bbox = (
                    anno["bbox"]
                    if not self.center_crop and good_bbox
                    else [0, 0, image.width, image.height]
                )

                # Distort image and bbox if desired
                if self.distort:
                    k1 = random.uniform(0, self.k1_max)
                    k2 = random.uniform(0, self.k2_max)

                    try:
                        image, bbox = distort_image(
                            image, np.array(bbox), k1, k2, modify_bbox=True
                        )

                    except:
                        print("INFO:")
                        print(sequence_name)
                        print(index)
                        print(ids)
                        print(k1)
                        print(k2)

                    distortion_parameters.append(torch.FloatTensor([k1, k2]))

                bbox = square_bbox(np.array(bbox), tight=self.center_crop)
                if self.apply_augmentation:
                    bbox = jitter_bbox(
                        bbox,
                        jitter_scale=self.jitter_scale,
                        jitter_trans=self.jitter_trans,
                        direction_from_size=image.size if self.center_crop else None,
                    )
                bbox = np.around(bbox).astype(int)

                # Crop parameters
                crop_center = (bbox[:2] + bbox[2:]) / 2
                principal_point = torch.tensor(anno["principal_point"])
                focal_length = torch.tensor(anno["focal_length"])

                # convert crop center to correspond to a "square" image
                width, height = image.size
                length = max(width, height)
                s = length / min(width, height)
                crop_center = crop_center + (length - np.array([width, height])) / 2

                # convert to NDC
                cc = s - 2 * s * crop_center / length
                crop_width = 2 * s * (bbox[2] - bbox[0]) / length
                crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])

                # Crop and normalize image
                if not self.precropped_images:
                    image = self._crop_image(image, bbox)

                try:
                    image = self.transform(image)
                except:
                    print("INFO:")
                    print(sequence_name)
                    print(index)
                    print(ids)
                    print(k1)
                    print(k2)

                images.append(image[:, : self.img_size, : self.img_size])
                crop_parameters.append(crop_params)

                if self.load_depths:
                    # Open depth map
                    depth_name = osp.basename(
                        filepath.replace(".jpg", ".jpg.geometric.png")
                    )
                    depth_path = osp.join(
                        self.co3d_depth_dir,
                        category,
                        sequence_name,
                        "depths",
                        depth_name,
                    )
                    depth_pil = Image.open(depth_path)

                    # 16 bit float type casting
                    depth = torch.tensor(
                        np.frombuffer(
                            np.array(depth_pil, dtype=np.uint16), dtype=np.float16
                        )
                        .astype(np.float32)
                        .reshape((depth_pil.size[1], depth_pil.size[0]))
                    )

                    # Crop and resize as with images
                    if depth_pil.size != image_size:
                        # bbox may have the wrong scale
                        bbox = depth_pil.size[0] * bbox / image_size[0]

                    if self.object_mask:
                        assert mask.shape == depth.shape

                    bbox = np.around(bbox).astype(int)
                    depth = self._crop_image(depth, bbox)

                    # Resize
                    depth = self.transform_depth(depth.unsqueeze(0))[
                        0, : self.depth_size, : self.depth_size
                    ]
                    depths.append(depth)

                    if self.object_mask:
                        mask = self._crop_image(mask, bbox)
                        mask = self.transform_depth(mask.unsqueeze(0))[
                            0, : self.depth_size, : self.depth_size
                        ]
                        object_masks.append(mask)

            PP.append(principal_point)
            FL.append(focal_length)
            image_sizes.append(torch.tensor([self.img_size, self.img_size]))
            filenames.append(filepath)

        if not no_images:
            if self.load_depths:
                depths = torch.stack(depths)

                depth_masks = torch.logical_or(depths <= 0, depths.isinf())
                depth_masks = (~depth_masks).long()

                if self.object_mask:
                    object_masks = torch.stack(object_masks, dim=0)

                if self.mask_holes:
                    depths = fill_depths(depths, depth_masks == 0)

                    # Sometimes mask_holes misses stuff
                    new_masks = torch.logical_or(depths <= 0, depths.isinf())
                    new_masks = (~new_masks).long()
                    depths[new_masks == 0] = -1

                assert torch.logical_or(depths > 0, depths == -1).all()
                assert not (depths.isinf()).any()
                assert not (depths.isnan()).any()

            if self.load_extra_cameras:
                # Remove the extra loaded image, for saving space
                images = images[: self.num_images]

            if self.distort:
                distortion_parameters = torch.stack(distortion_parameters)

            images = torch.stack(images)
            crop_parameters = torch.stack(crop_parameters)
            focal_lengths = torch.stack(FL)
            principal_points = torch.stack(PP)
            image_sizes = torch.stack(image_sizes)
        else:
            images = None
            crop_parameters = None
            distortion_parameters = None
            focal_lengths = []
            principal_points = []
            image_sizes = []

        # Assemble batch info to send back
        R = torch.stack([torch.tensor(anno["R"]) for anno in annos])
        T = torch.stack([torch.tensor(anno["T"]) for anno in annos])

        batch = {
            "model_id": sequence_name,
            "category": category,
            "n": len(metadata),
            "num_valid_frames": num_valid_frames,
            "ind": torch.tensor(ids),
            "image": images,
            "depth": depths,
            "depth_masks": depth_masks,
            "object_masks": object_masks,
            "R": R,
            "T": T,
            "focal_length": focal_lengths,
            "principal_point": principal_points,
            "image_size": image_sizes,
            "crop_parameters": crop_parameters,
            "distortion_parameters": torch.zeros(4),
            "filename": filenames,
            "category": category,
            "dataset": "co3d",
        }

        return batch
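A minimal usage sketch for the dataset above (not part of the commit). It assumes CO3D_DIR and CO3D_ANNOTATION_DIR have been pointed at a local CO3D-v2 download, as the "update this" comment indicates, and the printed shapes assume the default DataLoader collate behavior:

# Illustrative sketch: wrapping Co3dDataset in a DataLoader.
from torch.utils.data import DataLoader
from diffusionsfm.dataset.co3d_v2 import Co3dDataset

dataset = Co3dDataset(
    category=("hydrant",),   # any entry of TRAINING_CATEGORIES or TEST_CATEGORIES
    split="train",
    num_images=8,
    apply_augmentation=True,
    load_depths=True,
    mask_holes=True,
)
loader = DataLoader(dataset, batch_size=2, num_workers=4, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape)   # expected: (2, 8, 3, 224, 224)
print(batch["depth"].shape)   # expected: (2, 8, 256, 256) with the default depth_size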
diffusionsfm/dataset/custom.py
ADDED
@@ -0,0 +1,105 @@
import torch
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms

from diffusionsfm.dataset.co3d_v2 import square_bbox


class CustomDataset(Dataset):
    def __init__(
        self,
        image_list,
    ):
        self.images = []

        for image_path in sorted(image_list):
            img = Image.open(image_path)
            img = ImageOps.exif_transpose(img).convert("RGB")  # Apply EXIF rotation
            self.images.append(img)

        self.n = len(self.images)
        self.jitter_scale = [1, 1]
        self.jitter_trans = [0, 0]
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(224),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.transform_for_vis = transforms.Compose(
            [
                transforms.Resize(224),
            ]
        )

    def __len__(self):
        return 1

    def _crop_image(self, image, bbox, white_bg=False):
        if white_bg:
            # Only support PIL Images
            image_crop = Image.new(
                "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
            )
            image_crop.paste(image, (-bbox[0], -bbox[1]))
        else:
            image_crop = transforms.functional.crop(
                image,
                top=bbox[1],
                left=bbox[0],
                height=bbox[3] - bbox[1],
                width=bbox[2] - bbox[0],
            )
        return image_crop

    def __getitem__(self):
        return self.get_data()

    def get_data(self):
        cmap = plt.get_cmap("hsv")
        ids = [i for i in range(len(self.images))]
        images = [self.images[i] for i in ids]
        images_transformed = []
        images_for_vis = []
        crop_parameters = []

        for i, image in enumerate(images):
            bbox = np.array([0, 0, image.width, image.height])
            bbox = square_bbox(bbox, tight=True)
            bbox = np.around(bbox).astype(int)
            image = self._crop_image(image, bbox)
            images_transformed.append(self.transform(image))
            image_for_vis = self.transform_for_vis(image)
            color_float = cmap(i / len(images))
            color_rgb = tuple(int(255 * c) for c in color_float[:3])
            image_for_vis = ImageOps.expand(image_for_vis, border=3, fill=color_rgb)
            images_for_vis.append(image_for_vis)

            width, height = image.size
            length = max(width, height)
            s = length / min(width, height)
            crop_center = (bbox[:2] + bbox[2:]) / 2
            crop_center = crop_center + (length - np.array([width, height])) / 2
            # convert to NDC
            cc = s - 2 * s * crop_center / length
            crop_width = 2 * s * (bbox[2] - bbox[0]) / length
            crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])

            crop_parameters.append(crop_params)
        images = images_transformed

        batch = {}
        batch["image"] = torch.stack(images)
        batch["image_for_vis"] = images_for_vis
        batch["n"] = len(images)
        batch["ind"] = torch.tensor(ids)
        batch["crop_parameters"] = torch.stack(crop_parameters)
        batch["distortion_parameters"] = torch.zeros(4)

        return batch
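An illustrative sketch (not part of the commit) of running the in-the-wild loader above on one of the demo image sets added in this upload:

# Illustrative only: CustomDataset on the demo images shipped under data/demo.
from glob import glob
from diffusionsfm.dataset.custom import CustomDataset

image_list = sorted(glob("data/demo/kotor_cathedral/*.jpeg"))
dataset = CustomDataset(image_list)
batch = dataset.get_data()
print(batch["image"].shape)            # expected: (6, 3, 224, 224) for the 6 demo views
print(batch["crop_parameters"].shape)  # expected: (6, 4)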
diffusionsfm/eval/__init__.py
ADDED
(empty file)
diffusionsfm/eval/eval_category.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
import numpy as np
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
|
8 |
+
from diffusionsfm.dataset.co3d_v2 import (
|
9 |
+
Co3dDataset,
|
10 |
+
full_scene_scale,
|
11 |
+
)
|
12 |
+
from pytorch3d.renderer import PerspectiveCameras
|
13 |
+
from diffusionsfm.utils.visualization import filter_and_align_point_clouds
|
14 |
+
from diffusionsfm.inference.load_model import load_model
|
15 |
+
from diffusionsfm.inference.predict import predict_cameras
|
16 |
+
from diffusionsfm.utils.geometry import (
|
17 |
+
compute_angular_error_batch,
|
18 |
+
get_error,
|
19 |
+
n_to_np_rotations,
|
20 |
+
)
|
21 |
+
from diffusionsfm.utils.slurm import init_slurm_signals_if_slurm
|
22 |
+
from diffusionsfm.utils.rays import cameras_to_rays
|
23 |
+
from diffusionsfm.utils.rays import normalize_cameras_batch
|
24 |
+
|
25 |
+
|
26 |
+
@torch.no_grad()
|
27 |
+
def evaluate(
|
28 |
+
cfg,
|
29 |
+
model,
|
30 |
+
dataset,
|
31 |
+
num_images,
|
32 |
+
device,
|
33 |
+
use_pbar=True,
|
34 |
+
calculate_intrinsics=True,
|
35 |
+
additional_timesteps=(),
|
36 |
+
num_evaluate=None,
|
37 |
+
max_num_images=None,
|
38 |
+
mode=None,
|
39 |
+
metrics=True,
|
40 |
+
load_depth=True,
|
41 |
+
):
|
42 |
+
if cfg.training.get("dpt_head", False):
|
43 |
+
H_in = W_in = 224
|
44 |
+
H_out = W_out = cfg.training.full_num_patches_y
|
45 |
+
else:
|
46 |
+
H_in = H_out = cfg.model.num_patches_x
|
47 |
+
W_in = W_out = cfg.model.num_patches_y
|
48 |
+
|
49 |
+
results = {}
|
50 |
+
instances = np.arange(0, len(dataset)) if num_evaluate is None else np.linspace(0, len(dataset) - 1, num_evaluate, endpoint=True, dtype=int)
|
51 |
+
instances = tqdm(instances) if use_pbar else instances
|
52 |
+
|
53 |
+
for counter, idx in enumerate(instances):
|
54 |
+
batch = dataset[idx]
|
55 |
+
instance = batch["model_id"]
|
56 |
+
images = batch["image"].to(device)
|
57 |
+
focal_length = batch["focal_length"].to(device)[:num_images]
|
58 |
+
R = batch["R"].to(device)[:num_images]
|
59 |
+
T = batch["T"].to(device)[:num_images]
|
60 |
+
crop_parameters = batch["crop_parameters"].to(device)[:num_images]
|
61 |
+
|
62 |
+
if load_depth:
|
63 |
+
depths = batch["depth"].to(device)[:num_images]
|
64 |
+
depth_masks = batch["depth_masks"].to(device)[:num_images]
|
65 |
+
try:
|
66 |
+
object_masks = batch["object_masks"].to(device)[:num_images]
|
67 |
+
except KeyError:
|
68 |
+
object_masks = depth_masks.clone()
|
69 |
+
|
70 |
+
# Normalize cameras and scale depths for output resolution
|
71 |
+
cameras_gt = PerspectiveCameras(
|
72 |
+
R=R, T=T, focal_length=focal_length, device=device
|
73 |
+
)
|
74 |
+
cameras_gt, _, _ = normalize_cameras_batch(
|
75 |
+
[cameras_gt],
|
76 |
+
first_cam_mediod=cfg.training.first_cam_mediod,
|
77 |
+
normalize_first_camera=cfg.training.normalize_first_camera,
|
78 |
+
depths=depths.unsqueeze(0),
|
79 |
+
crop_parameters=crop_parameters.unsqueeze(0),
|
80 |
+
num_patches_x=H_in,
|
81 |
+
num_patches_y=W_in,
|
82 |
+
return_scales=True,
|
83 |
+
)
|
84 |
+
cameras_gt = cameras_gt[0]
|
85 |
+
|
86 |
+
gt_rays = cameras_to_rays(
|
87 |
+
cameras=cameras_gt,
|
88 |
+
num_patches_x=H_in,
|
89 |
+
num_patches_y=W_in,
|
90 |
+
crop_parameters=crop_parameters,
|
91 |
+
depths=depths,
|
92 |
+
mode=mode,
|
93 |
+
)
|
94 |
+
gt_points = gt_rays.get_segments().view(num_images, -1, 3)
|
95 |
+
|
96 |
+
resize = torchvision.transforms.Resize(
|
97 |
+
224,
|
98 |
+
antialias=False,
|
99 |
+
interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT,
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
        cameras_gt = PerspectiveCameras(
            R=R, T=T, focal_length=focal_length, device=device
        )

        pred_cameras, additional_cams = predict_cameras(
            model,
            images,
            device,
            crop_parameters=crop_parameters,
            num_patches_x=H_out,
            num_patches_y=W_out,
            max_num_images=max_num_images,
            additional_timesteps=additional_timesteps,
            calculate_intrinsics=calculate_intrinsics,
            mode=mode,
            return_rays=True,
            use_homogeneous=cfg.model.get("use_homogeneous", False),
        )
        cameras_to_evaluate = additional_cams + [pred_cameras]

        all_cams_batch = dataset.get_data(
            sequence_name=instance, ids=np.arange(0, batch["n"]), no_images=True
        )
        gt_scene_scale = full_scene_scale(all_cams_batch)
        R_gt = R
        T_gt = T

        errors = []
        for _, (camera, pred_rays) in enumerate(cameras_to_evaluate):
            R_pred = camera.R
            T_pred = camera.T
            f_pred = camera.focal_length

            R_pred_rel = n_to_np_rotations(num_images, R_pred).cpu().numpy()
            R_gt_rel = n_to_np_rotations(num_images, batch["R"]).cpu().numpy()
            R_error = compute_angular_error_batch(R_pred_rel, R_gt_rel)

            CC_error, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale)

            if load_depth and metrics:
                # Evaluate outputs at the same resolution as DUSt3R
                pred_points = pred_rays.get_segments().view(num_images, H_out, H_out, 3)
                pred_points = pred_points.permute(0, 3, 1, 2)
                pred_points = resize(pred_points).permute(0, 2, 3, 1).view(num_images, H_in*W_in, 3)

                (
                    _,
                    _,
                    _,
                    _,
                    metric_values,
                ) = filter_and_align_point_clouds(
                    num_images,
                    gt_points,
                    pred_points,
                    depth_masks,
                    depth_masks,
                    images,
                    metrics=metrics,
                    num_patches_x=H_in,
                )

                (
                    _,
                    _,
                    _,
                    _,
                    object_metric_values,
                ) = filter_and_align_point_clouds(
                    num_images,
                    gt_points,
                    pred_points,
                    depth_masks * object_masks,
                    depth_masks * object_masks,
                    images,
                    metrics=metrics,
                    num_patches_x=H_in,
                )

            result = {
                "R_pred": R_pred.detach().cpu().numpy().tolist(),
                "T_pred": T_pred.detach().cpu().numpy().tolist(),
                "f_pred": f_pred.detach().cpu().numpy().tolist(),
                "R_gt": R_gt.detach().cpu().numpy().tolist(),
                "T_gt": T_gt.detach().cpu().numpy().tolist(),
                "f_gt": focal_length.detach().cpu().numpy().tolist(),
                "scene_scale": gt_scene_scale,
                "R_error": R_error.tolist(),
                "CC_error": CC_error,
            }

            if load_depth and metrics:
                result["CD"] = metric_values[1]
                result["CD_Object"] = object_metric_values[1]
            else:
                result["CD"] = 0
                result["CD_Object"] = 0

            errors.append(result)

        results[instance] = errors

        if counter == len(dataset) - 1:
            break
    return results


def save_results(
    output_dir,
    checkpoint=800_000,
    category="hydrant",
    num_images=None,
    calculate_additional_timesteps=True,
    calculate_intrinsics=True,
    split="test",
    force=False,
    sample_num=1,
    max_num_images=None,
    dataset="co3d",
):
    init_slurm_signals_if_slurm()
    os.umask(000)  # Default to 777 permissions
    eval_path = os.path.join(
        output_dir,
        f"eval_{dataset}",
        f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}.json",
    )

    if os.path.exists(eval_path) and not force:
        print(f"File {eval_path} already exists. Skipping.")
        return

    if num_images is not None and num_images > 8:
        custom_keys = {"model.num_images": num_images}
        ignore_keys = ["pos_table"]
    else:
        custom_keys = None
        ignore_keys = []

    device = torch.device("cuda")
    model, cfg = load_model(
        output_dir,
        checkpoint=checkpoint,
        device=device,
        custom_keys=custom_keys,
        ignore_keys=ignore_keys,
    )
    if num_images is None:
        num_images = cfg.dataset.num_images

    if cfg.training.dpt_head:
        # Evaluate outputs at the same resolution as DUSt3R
        depth_size = 224
    else:
        depth_size = cfg.model.num_patches_x

    dataset = Co3dDataset(
        category=category,
        split=split,
        num_images=num_images,
        apply_augmentation=False,
        sample_num=None if split == "train" else sample_num,
        use_global_intrinsics=cfg.dataset.use_global_intrinsics,
        load_depths=True,
        center_crop=True,
        depth_size=depth_size,
        mask_holes=not cfg.training.regression,
        img_size=256 if cfg.model.unet_diffuser else 224,
    )
    print(f"Category {category} {len(dataset)}")

    if calculate_additional_timesteps:
        additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    else:
        additional_timesteps = []

    results = evaluate(
        cfg=cfg,
        model=model,
        dataset=dataset,
        num_images=num_images,
        device=device,
        calculate_intrinsics=calculate_intrinsics,
        additional_timesteps=additional_timesteps,
        max_num_images=max_num_images,
        mode="segment",
    )

    os.makedirs(os.path.dirname(eval_path), exist_ok=True)
    with open(eval_path, "w") as f:
        json.dump(results, f)

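Note (not part of the file): the JSON written by save_results maps each sequence name to a list of per-prediction dicts, one entry per additional timestep plus the final prediction last, since cameras_to_evaluate is additional_cams + [pred_cameras]. A minimal sketch of reading one file back, with a hypothetical path:

    import json
    import numpy as np

    # Hypothetical output path following the naming scheme above
    with open("output/multi_diffusionsfm_dense/eval_co3d/hydrant_8_0_ckpt800000.json") as f:
        eval_data = json.load(f)
    for instance, per_pred in eval_data.items():
        final = per_pred[-1]  # last entry is the final prediction
        acc15 = np.mean(np.array(final["R_error"]) < 15)  # rotation accuracy at 15 degrees
        print(instance, acc15, final["CD"])
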
diffusionsfm/eval/eval_jobs.py
ADDED
@@ -0,0 +1,175 @@
"""
python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit
"""

import os
import json
import submitit
import argparse
import itertools
from glob import glob

import numpy as np
from tqdm.auto import tqdm

from diffusionsfm.dataset.co3d_v2 import TEST_CATEGORIES, TRAINING_CATEGORIES
from diffusionsfm.eval.eval_category import save_results
from diffusionsfm.utils.slurm import submitit_job_watcher


def evaluate_diffusionsfm(eval_path, use_submitit, mode):
    JOB_PARAMS = {
        "output_dir": [eval_path],
        "checkpoint": [800_000],
        "num_images": [2, 3, 4, 5, 6, 7, 8],
        "sample_num": [0, 1, 2, 3, 4],
        "category": TEST_CATEGORIES,  # TRAINING_CATEGORIES + TEST_CATEGORIES,
        "calculate_additional_timesteps": [True],
    }
    if mode == "test":
        JOB_PARAMS["category"] = TEST_CATEGORIES
    elif mode == "train1":
        JOB_PARAMS["category"] = TRAINING_CATEGORIES[:len(TRAINING_CATEGORIES) // 2]
    elif mode == "train2":
        JOB_PARAMS["category"] = TRAINING_CATEGORIES[len(TRAINING_CATEGORIES) // 2:]
    keys, values = zip(*JOB_PARAMS.items())
    job_configs = [dict(zip(keys, p)) for p in itertools.product(*values)]

    if use_submitit:
        log_output = "./slurm_logs"
        executor = submitit.AutoExecutor(
            cluster=None, folder=log_output, slurm_max_num_timeout=10
        )
        # Use your own parameters
        executor.update_parameters(
            slurm_additional_parameters={
                "nodes": 1,
                "cpus-per-task": 5,
                "gpus": 1,
                "time": "6:00:00",
                "partition": "all",
                "exclude": "grogu-1-9, grogu-1-14,"
            }
        )
        jobs = []
        with executor.batch():
            # This context manager submits all jobs at once at the end.
            for params in job_configs:
                job = executor.submit(save_results, **params)
                job_param = f"{params['category']}_N{params['num_images']}_{params['sample_num']}"
                jobs.append((job_param, job))
        jobs = {f"{job_param}_{job.job_id}": job for job_param, job in jobs}
        submitit_job_watcher(jobs)
    else:
        for job_config in tqdm(job_configs):
            # This is much slower.
            save_results(**job_config)


def process_predictions(eval_path, pred_index, checkpoint=800_000, threshold_R=15, threshold_CC=0.1):
    """
    pred_index should be 1 (corresponding to T=90)
    """
    def aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=None):
        """
        Aggregates one metric over all data points in a prediction file and then across categories.
        - For R_error and CC_error: convert to threshold-based accuracy, then take the mean
        - For CD and CD_Object: use the median to reduce the effect of outliers
        """
        per_category_values = []

        for category in tqdm(categories, desc=f"Sample {sample_num}, N={num_images}, {metric_key}"):
            per_pred_values = []

            data_path = glob(
                os.path.join(eval_path, "eval", f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}*.json")
            )[0]

            with open(data_path) as f:
                eval_data = json.load(f)

            for preds in eval_data.values():
                if metric_key in ["R_error", "CC_error"]:
                    vals = np.array(preds[pred_index][metric_key])
                    per_pred_values.append(np.mean(vals < threshold))
                else:
                    per_pred_values.append(preds[pred_index][metric_key])

            # Aggregate over all predictions within this category
            per_category_values.append(
                np.mean(per_pred_values) if metric_key in ["R_error", "CC_error"]
                else np.median(per_pred_values)  # CD or CD_Object: use median to filter outliers
            )

        if metric_key in ["R_error", "CC_error"]:
            return np.mean(per_category_values)
        else:
            return np.median(per_category_values)

    def aggregate_metric(categories, metric_key, num_images, threshold=None):
        """Aggregates one metric over 5 random samples per category and returns the final mean"""
        return np.mean([
            aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=threshold)
            for sample_num in range(5)
        ])

    # Output containers
    all_seen_acc_R, all_seen_acc_CC = [], []
    all_seen_CD, all_seen_CD_Object = [], []
    all_unseen_acc_R, all_unseen_acc_CC = [], []
    all_unseen_CD, all_unseen_CD_Object = [], []

    for num_images in range(2, 9):
        # Seen categories
        all_seen_acc_R.append(
            aggregate_metric(TRAINING_CATEGORIES, "R_error", num_images, threshold=threshold_R)
        )
        all_seen_acc_CC.append(
            aggregate_metric(TRAINING_CATEGORIES, "CC_error", num_images, threshold=threshold_CC)
        )
        all_seen_CD.append(
            aggregate_metric(TRAINING_CATEGORIES, "CD", num_images)
        )
        all_seen_CD_Object.append(
            aggregate_metric(TRAINING_CATEGORIES, "CD_Object", num_images)
        )

        # Unseen categories
        all_unseen_acc_R.append(
            aggregate_metric(TEST_CATEGORIES, "R_error", num_images, threshold=threshold_R)
        )
        all_unseen_acc_CC.append(
            aggregate_metric(TEST_CATEGORIES, "CC_error", num_images, threshold=threshold_CC)
        )
        all_unseen_CD.append(
            aggregate_metric(TEST_CATEGORIES, "CD", num_images)
        )
        all_unseen_CD_Object.append(
            aggregate_metric(TEST_CATEGORIES, "CD_Object", num_images)
        )

    # Print the results in formatted rows
    print("N=           ", " ".join(f"{i: 5}" for i in range(2, 9)))
    print("Seen R       ", " ".join([f"{x:0.3f}" for x in all_seen_acc_R]))
    print("Seen CC      ", " ".join([f"{x:0.3f}" for x in all_seen_acc_CC]))
    print("Seen CD      ", " ".join([f"{x:0.3f}" for x in all_seen_CD]))
    print("Seen CD_Obj  ", " ".join([f"{x:0.3f}" for x in all_seen_CD_Object]))
    print("Unseen R     ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_R]))
    print("Unseen CC    ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_CC]))
    print("Unseen CD    ", " ".join([f"{x:0.3f}" for x in all_unseen_CD]))
    print("Unseen CD_Obj", " ".join([f"{x:0.3f}" for x in all_unseen_CD_Object]))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_path", type=str, default=None)
    parser.add_argument("--use_submitit", action="store_true")
    parser.add_argument("--mode", type=str, default="test")
    args = parser.parse_args()

    eval_path = "output/multi_diffusionsfm_dense" if args.eval_path is None else args.eval_path
    use_submitit = args.use_submitit
    mode = args.mode

    evaluate_diffusionsfm(eval_path, use_submitit, mode)
    process_predictions(eval_path, 1)

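As a small illustration of the aggregation rule above (not part of the repository): thresholded errors are turned into accuracies before averaging, while Chamfer-style metrics are reduced with the median.

    import numpy as np

    R_errors = np.array([3.2, 7.9, 21.5, 4.1])  # per-pair angular errors in degrees
    acc_at_15 = np.mean(R_errors < 15)          # -> 0.75
    CDs = [0.02, 0.03, 0.5, 0.025]              # one outlier
    robust_CD = np.median(CDs)                  # -> 0.0275
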
diffusionsfm/inference/__init__.py
ADDED
File without changes
diffusionsfm/inference/ddim.py
ADDED
@@ -0,0 +1,145 @@
import torch
import random
import numpy as np
from tqdm.auto import tqdm

from diffusionsfm.utils.rays import compute_ndc_coordinates


def inference_ddim(
    model,
    images,
    device,
    crop_parameters=None,
    eta=0,
    num_inference_steps=100,
    pbar=True,
    num_patches_x=16,
    num_patches_y=16,
    visualize=False,
    seed=0,
):
    """
    Implements DDIM-style inference.

    To get multiple samples, batch the images multiple times.

    Args:
        model: Ray Diffuser.
        images (torch.Tensor): (B, N, C, H, W).
        patch_rays_gt (torch.Tensor): If provided, the patch rays which are ground
            truth (B, N, P, 6).
        eta (float, optional): Stochasticity coefficient. 0 is completely deterministic,
            1 is equivalent to DDPM. (Default: 0)
        num_inference_steps (int, optional): Number of inference steps. (Default: 100)
        pbar (bool, optional): Whether to show progress bar. (Default: True)
    """
    timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps)
    batch_size = images.shape[0]
    num_images = images.shape[1]

    if isinstance(eta, list):
        eta_0, eta_1 = float(eta[0]), float(eta[1])
    else:
        eta_0, eta_1 = 0, 0

    # Fixing seed
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    with torch.no_grad():
        x_tau = torch.randn(
            batch_size,
            num_images,
            model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
            num_patches_x,
            num_patches_y,
            device=device,
        )

        if visualize:
            x_taus = [x_tau]
            all_pred = []
            noise_samples = []

        image_features = model.feature_extractor(images, autoresize=True)

        if model.append_ndc:
            ndc_coordinates = compute_ndc_coordinates(
                crop_parameters=crop_parameters,
                no_crop_param_device="cpu",
                num_patches_x=model.width,
                num_patches_y=model.width,
                distortion_coeffs=None,
            )[..., :2].to(device)
            ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3)
        else:
            ndc_coordinates = None

        loop = tqdm(range(len(timesteps))) if pbar else range(len(timesteps))
        for t in loop:
            tau = timesteps[t]

            if tau > 0 and eta_1 > 0:
                z = torch.randn(
                    batch_size,
                    num_images,
                    model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
                    num_patches_x,
                    num_patches_y,
                    device=device,
                )
            else:
                z = 0

            alpha = model.noise_scheduler.alphas_cumprod[tau]
            if tau > 0:
                tau_prev = timesteps[t + 1]
                alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev]
            else:
                alpha_prev = torch.tensor(1.0, device=device).float()

            sigma_t = (
                torch.sqrt((1 - alpha_prev) / (1 - alpha))
                * torch.sqrt(1 - alpha / alpha_prev)
            )

            eps_pred, noise_sample = model(
                features=image_features,
                rays_noisy=x_tau,
                t=int(tau),
                ndc_coordinates=ndc_coordinates,
            )

            if model.use_homogeneous:
                p1 = eps_pred[:, :, :4]
                p2 = eps_pred[:, :, 4:]

                c1 = torch.linalg.norm(p1, dim=2, keepdim=True)
                c2 = torch.linalg.norm(p2, dim=2, keepdim=True)
                eps_pred[:, :, :4] = p1 / c1
                eps_pred[:, :, 4:] = p2 / c2

            if visualize:
                all_pred.append(eps_pred.clone())
                noise_samples.append(noise_sample)

            # TODO: Can simplify this a lot
            x0_pred = eps_pred.clone()
            eps_pred = (x_tau - torch.sqrt(alpha) * eps_pred) / torch.sqrt(
                1 - alpha
            )

            dir_x_tau = torch.sqrt(1 - alpha_prev - eta_0*sigma_t**2) * eps_pred
            noise = eta_1 * sigma_t * z

            new_x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise
            x_tau = new_x_tau

            if visualize:
                x_taus.append(x_tau.detach().clone())
    if visualize:
        return x_tau, x_taus, all_pred, noise_samples
    return x_tau

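The update inside the loop is the standard DDIM step: the network output is treated as a prediction of the clean sample x0, the implied noise is recovered as eps_pred = (x_tau - sqrt(alpha) * x0_pred) / sqrt(1 - alpha), and the next iterate is x_tau_prev = sqrt(alpha_prev) * x0_pred + sqrt(1 - alpha_prev - eta_0 * sigma_t**2) * eps_pred + eta_1 * sigma_t * z. With eta_0 = eta_1 = 0 the update is fully deterministic; eta = [1, 0] (as used by predict_cameras below) keeps the variance term inside the direction coefficient but adds no fresh noise.
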
diffusionsfm/inference/load_model.py
ADDED
@@ -0,0 +1,97 @@
import os.path as osp
from glob import glob

import torch
from omegaconf import OmegaConf

from diffusionsfm.model.diffuser import RayDiffuser
from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
from diffusionsfm.model.scheduler import NoiseScheduler


def load_model(
    output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
):
    """
    Loads a model and config from an output directory.

    E.g. to load with different number of images,
    ```
    custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
    ```

    Args:
        output_dir (str): Path to the output directory.
        checkpoint (str or int): Path to the checkpoint to load. If None, loads the
            latest checkpoint.
        device (str): Device to load the model on.
        custom_keys (dict): Dictionary of custom keys to override in the config.
    """
    if checkpoint is None:
        checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
    else:
        if isinstance(checkpoint, int):
            checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
        else:
            checkpoint_name = checkpoint
        checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
    print("Loading checkpoint", osp.basename(checkpoint_path))

    cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
    if custom_keys is not None:
        for k, v in custom_keys.items():
            OmegaConf.update(cfg, k, v)
    noise_scheduler = NoiseScheduler(
        type=cfg.noise_scheduler.type,
        max_timesteps=cfg.noise_scheduler.max_timesteps,
        beta_start=cfg.noise_scheduler.beta_start,
        beta_end=cfg.noise_scheduler.beta_end,
    )

    if not cfg.training.get("dpt_head", False):
        model = RayDiffuser(
            depth=cfg.model.depth,
            width=cfg.model.num_patches_x,
            P=1,
            max_num_images=cfg.model.num_images,
            noise_scheduler=noise_scheduler,
            feature_extractor=cfg.model.feature_extractor,
            append_ndc=cfg.model.append_ndc,
            diffuse_depths=cfg.training.get("diffuse_depths", False),
            depth_resolution=cfg.training.get("depth_resolution", 1),
            use_homogeneous=cfg.model.get("use_homogeneous", False),
            cond_depth_mask=cfg.model.get("cond_depth_mask", False),
        ).to(device)
    else:
        model = RayDiffuserDPT(
            depth=cfg.model.depth,
            width=cfg.model.num_patches_x,
            P=1,
            max_num_images=cfg.model.num_images,
            noise_scheduler=noise_scheduler,
            feature_extractor=cfg.model.feature_extractor,
            append_ndc=cfg.model.append_ndc,
            diffuse_depths=cfg.training.get("diffuse_depths", False),
            depth_resolution=cfg.training.get("depth_resolution", 1),
            encoder_features=cfg.training.get("dpt_encoder_features", False),
            use_homogeneous=cfg.model.get("use_homogeneous", False),
            cond_depth_mask=cfg.model.get("cond_depth_mask", False),
        ).to(device)

    data = torch.load(checkpoint_path)
    state_dict = {}
    for k, v in data["state_dict"].items():
        include = True
        for ignore_key in ignore_keys:
            if ignore_key in k:
                include = False
        if include:
            state_dict[k] = v

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        print("Missing keys:", missing)
    if len(unexpected) > 0:
        print("Unexpected keys:", unexpected)
    model = model.eval()
    return model, cfg

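A minimal usage sketch (the output directory and checkpoint number below are the defaults used elsewhere in this repo and may differ for your run):

    from diffusionsfm.inference.load_model import load_model

    model, cfg = load_model(
        "output/multi_diffusionsfm_dense", checkpoint=800_000, device="cuda:0"
    )
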
diffusionsfm/inference/predict.py
ADDED
@@ -0,0 +1,93 @@
from diffusionsfm.inference.ddim import inference_ddim
from diffusionsfm.utils.rays import (
    Rays,
    rays_to_cameras,
    rays_to_cameras_homography,
)


def predict_cameras(
    model,
    images,
    device,
    crop_parameters=None,
    num_patches_x=16,
    num_patches_y=16,
    additional_timesteps=(),
    calculate_intrinsics=False,
    max_num_images=None,
    mode=None,
    return_rays=False,
    use_homogeneous=False,
    seed=0,
):
    """
    Args:
        images (torch.Tensor): (N, C, H, W)
        crop_parameters (torch.Tensor): (N, 4) or None
    """
    if calculate_intrinsics:
        ray_to_cam = rays_to_cameras_homography
    else:
        ray_to_cam = rays_to_cameras

    get_spatial_rays = Rays.from_spatial

    rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
        model,
        images.unsqueeze(0),
        device,
        visualize=True,
        crop_parameters=crop_parameters.unsqueeze(0),
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        pbar=False,
        eta=[1, 0],
        num_inference_steps=100,
    )

    spatial_rays = get_spatial_rays(
        rays_final[0],
        mode=mode,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        use_homogeneous=use_homogeneous,
    )

    pred_cam = ray_to_cam(
        spatial_rays,
        crop_parameters,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        depth_resolution=model.depth_resolution,
        average_centers=True,
        directions_from_averaged_center=True,
    )

    additional_predictions = []
    for t in additional_timesteps:
        ray = pred_intermediate[t]

        ray = get_spatial_rays(
            ray[0],
            mode=mode,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            use_homogeneous=use_homogeneous,
        )

        cam = ray_to_cam(
            ray,
            crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            average_centers=True,
            directions_from_averaged_center=True,
        )
        if return_rays:
            cam = (cam, ray)
        additional_predictions.append(cam)

    if return_rays:
        return (pred_cam, spatial_rays), additional_predictions
    return pred_cam, additional_predictions, spatial_rays

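A minimal calling sketch, assuming a model loaded as above and placeholder input tensors (shapes follow the docstring; the mode string, image size, and patch counts must match how the checkpoint was trained, so treat these values as assumptions):

    import torch
    from diffusionsfm.inference.predict import predict_cameras

    # model, cfg = load_model(...)  # as in the previous file
    images = torch.randn(4, 3, 224, 224, device="cuda")  # (N, C, H, W), placeholder data
    crop_parameters = torch.zeros(4, 4, device="cuda")   # (N, 4), placeholder values
    pred_cam, additional, rays = predict_cameras(
        model, images, "cuda", crop_parameters=crop_parameters, mode="segment"
    )
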
diffusionsfm/model/base_model.py
ADDED
@@ -0,0 +1,16 @@
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device("cpu"))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)

diffusionsfm/model/blocks.py
ADDED
@@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from diffusionsfm.model.dit import TimestepEmbedder
import ipdb


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(
        -1
    )


def _make_fusion_block(features, use_bn, use_ln, dpt_time, resolution):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        dpt_time=dpt_time,
        ln=use_ln,
        resolution=resolution
    )


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand == True:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0],
        out_shape1,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1],
        out_shape2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2],
        out_shape3,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3],
        out_shape4,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )

    return scratch


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn, ln, dpt_time=False, resolution=16):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn
        self.ln = ln

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        nn.init.kaiming_uniform_(self.conv1.weight)
        nn.init.kaiming_uniform_(self.conv2.weight)

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        if self.ln == True:
            self.bn1 = nn.LayerNorm((features, resolution, resolution))
            self.bn2 = nn.LayerNorm((features, resolution, resolution))

        self.activation = activation

        if dpt_time:
            self.t_embedder = TimestepEmbedder(hidden_size=features)
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(features, 3 * features, bias=True)
            )

    def forward(self, x, t=None):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        if t is not None:
            # Embed timestamp & calculate shift parameters
            t = self.t_embedder(t)  # (B*N)
            shift, scale, gate = self.adaLN_modulation(t).chunk(3, dim=1)  # (B * N, T)

            # Shift & scale x
            x = modulate(x, shift, scale)  # (B * N, T, H, W)

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn or self.ln:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn or self.ln:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        if t is not None:
            out = gate.unsqueeze(-1).unsqueeze(-1) * out

        return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        ln=False,
        expand=False,
        align_corners=True,
        dpt_time=False,
        resolution=16,
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=1,
        )

        nn.init.kaiming_uniform_(self.out_conv.weight)

        # The second block sees time
        self.resConfUnit1 = ResidualConvUnit_custom(
            features, activation, bn=bn, ln=ln, dpt_time=False, resolution=resolution
        )
        self.resConfUnit2 = ResidualConvUnit_custom(
            features, activation, bn=bn, ln=ln, dpt_time=dpt_time, resolution=resolution
        )

    def forward(self, input, activation=None, t=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = input

        if activation is not None:
            res = self.resConfUnit1(activation)

            output += res

        output = self.resConfUnit2(output, t)

        output = torch.nn.functional.interpolate(
            output.float(),
            scale_factor=2,
            mode="bilinear",
            align_corners=self.align_corners,
        )

        output = self.out_conv(output)

        return output

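For clarity (an illustration, not part of the file): the adaLN-style modulation used by ResidualConvUnit_custom broadcasts per-sample shift and scale vectors over the spatial dimensions, so zero shift and scale leave the feature map unchanged.

    import torch
    from diffusionsfm.model.blocks import modulate

    x = torch.randn(2, 256, 16, 16)  # (B*N, C, H, W) feature map
    shift = torch.zeros(2, 256)      # per-sample, per-channel shift
    scale = torch.zeros(2, 256)      # per-sample, per-channel scale
    out = modulate(x, shift, scale)
    assert torch.allclose(out, x)
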
diffusionsfm/model/diffuser.py
ADDED
@@ -0,0 +1,195 @@
import ipdb  # noqa: F401
import numpy as np
import torch
import torch.nn as nn

from diffusionsfm.model.dit import DiT
from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino
from diffusionsfm.model.scheduler import NoiseScheduler


class RayDiffuser(nn.Module):
    def __init__(
        self,
        model_type="dit",
        depth=8,
        width=16,
        hidden_size=1152,
        P=1,
        max_num_images=1,
        noise_scheduler=None,
        freeze_encoder=True,
        feature_extractor="dino",
        append_ndc=True,
        use_unconditional=False,
        diffuse_depths=False,
        depth_resolution=1,
        use_homogeneous=False,
        cond_depth_mask=False,
    ):
        super().__init__()
        if noise_scheduler is None:
            self.noise_scheduler = NoiseScheduler()
        else:
            self.noise_scheduler = noise_scheduler

        self.diffuse_depths = diffuse_depths
        self.depth_resolution = depth_resolution
        self.use_homogeneous = use_homogeneous

        self.ray_dim = 3
        if self.use_homogeneous:
            self.ray_dim += 1

        self.ray_dim += self.ray_dim * self.depth_resolution**2

        if self.diffuse_depths:
            self.ray_dim += 1

        self.append_ndc = append_ndc
        self.width = width

        self.max_num_images = max_num_images
        self.model_type = model_type
        self.use_unconditional = use_unconditional
        self.cond_depth_mask = cond_depth_mask

        if feature_extractor == "dino":
            self.feature_extractor = SpatialDino(
                freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
            )
            self.feature_dim = self.feature_extractor.feature_dim
        elif feature_extractor == "vae":
            self.feature_extractor = PretrainedVAE(
                freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
            )
            self.feature_dim = self.feature_extractor.feature_dim
        else:
            raise Exception(f"Unknown feature extractor {feature_extractor}")

        if self.use_unconditional:
            self.register_parameter(
                "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1))
            )

        self.input_dim = self.feature_dim * 2

        if self.append_ndc:
            self.input_dim += 2

        if model_type == "dit":
            self.ray_predictor = DiT(
                in_channels=self.input_dim,
                out_channels=self.ray_dim,
                width=width,
                depth=depth,
                hidden_size=hidden_size,
                max_num_images=max_num_images,
                P=P,
            )

        self.scratch = nn.Module()
        self.scratch.input_conv = nn.Linear(self.ray_dim + int(self.cond_depth_mask), self.feature_dim)

    def forward_noise(
        self, x, t, epsilon=None, zero_out_mask=None
    ):
        """
        Applies forward diffusion (adds noise) to the input.

        If a mask is provided, the noise is only applied to the masked inputs.
        """
        t = t.reshape(-1, 1, 1, 1, 1)

        if epsilon is None:
            epsilon = torch.randn_like(x)
        else:
            epsilon = epsilon.reshape(x.shape)

        alpha_bar = self.noise_scheduler.alphas_cumprod[t]
        x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon

        if zero_out_mask is not None and self.cond_depth_mask:
            x_noise = x_noise * zero_out_mask

        return x_noise, epsilon

    def forward(
        self,
        features=None,
        images=None,
        rays=None,
        rays_noisy=None,
        t=None,
        ndc_coordinates=None,
        unconditional_mask=None,
        return_dpt_activations=False,
        depth_mask=None,
    ):
        """
        Args:
            images: (B, N, 3, H, W).
            t: (B,).
            rays: (B, N, 6, H, W).
            rays_noisy: (B, N, 6, H, W).
            ndc_coordinates: (B, N, 2, H, W).
            unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples
                and 0 else.
        """

        if features is None:
            # VAE expects 256x256 images while DINO expects 224x224 images.
            # Both feature extractors support autoresize=True, but ideally we should
            # set this to be false and handle in the dataloader.
            features = self.feature_extractor(images, autoresize=True)

        B = features.shape[0]

        if (
            unconditional_mask is not None
            and self.use_unconditional
        ):
            null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1)
            unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1)
            features = (
                features * (1 - unconditional_mask) + null_token * unconditional_mask
            )

        if isinstance(t, int) or isinstance(t, np.int64):
            t = torch.ones(1, dtype=int).to(features.device) * t
        else:
            t = t.reshape(B)

        if rays_noisy is None:
            if self.cond_depth_mask:
                rays_noisy, epsilon = self.forward_noise(rays, t, zero_out_mask=depth_mask.unsqueeze(2))
            else:
                rays_noisy, epsilon = self.forward_noise(rays, t)
        else:
            epsilon = None

        if self.cond_depth_mask:
            if depth_mask is None:
                depth_mask = torch.ones_like(rays_noisy[:, :, 0])
            ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2)
        else:
            ray_repr = rays_noisy

        ray_repr = ray_repr.permute(0, 1, 3, 4, 2)
        ray_repr = self.scratch.input_conv(ray_repr).permute(0, 1, 4, 2, 3).contiguous()

        scene_features = torch.cat([features, ray_repr], dim=2)

        if self.append_ndc:
            scene_features = torch.cat([scene_features, ndc_coordinates], dim=2)

        epsilon_pred = self.ray_predictor(
            scene_features,
            t,
            return_dpt_activations=return_dpt_activations,
        )

        if return_dpt_activations:
            return epsilon_pred, rays_noisy, epsilon

        return epsilon_pred, epsilon

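forward_noise above is the standard closed-form forward diffusion q(x_t | x_0): with cumulative product alpha_bar_t, the noised sample is sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps with eps ~ N(0, I), optionally zeroed outside the depth mask. A tiny numerical restatement of that one line (illustration only, shapes chosen arbitrarily):

    import torch

    x0 = torch.randn(1, 2, 6, 16, 16)
    alpha_bar = torch.tensor(0.9)
    eps = torch.randn_like(x0)
    x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
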
diffusionsfm/model/diffuser_dpt.py
ADDED
@@ -0,0 +1,331 @@
import ipdb  # noqa: F401
import numpy as np
import torch
import torch.nn as nn

from diffusionsfm.model.dit import DiT
from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino
from diffusionsfm.model.blocks import _make_fusion_block, _make_scratch
from diffusionsfm.model.scheduler import NoiseScheduler


# functional implementation
def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int):
    """Upsample {x} (NCHW) by scale factor {scale_factor} using nearest neighbor interpolation."""
    s = scale_factor
    return (
        x.reshape(*x.shape, 1, 1)
        .expand(*x.shape, s, s)
        .transpose(-2, -3)
        .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))
    )


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)

        return self.project(features)


class RayDiffuserDPT(nn.Module):
    def __init__(
        self,
        model_type="dit",
        depth=8,
        width=16,
        hidden_size=1152,
        P=1,
        max_num_images=1,
        noise_scheduler=None,
        freeze_encoder=True,
        feature_extractor="dino",
        append_ndc=True,
        use_unconditional=False,
        diffuse_depths=False,
        depth_resolution=1,
        encoder_features=False,
        use_homogeneous=False,
        freeze_transformer=False,
        cond_depth_mask=False,
    ):
        super().__init__()
        if noise_scheduler is None:
            self.noise_scheduler = NoiseScheduler()
        else:
            self.noise_scheduler = noise_scheduler

        self.diffuse_depths = diffuse_depths
        self.depth_resolution = depth_resolution
        self.use_homogeneous = use_homogeneous

        self.ray_dim = 3

        if self.use_homogeneous:
            self.ray_dim += 1
        self.ray_dim += self.ray_dim * self.depth_resolution**2

        if self.diffuse_depths:
            self.ray_dim += 1

        self.append_ndc = append_ndc
        self.width = width

        self.max_num_images = max_num_images
        self.model_type = model_type
        self.use_unconditional = use_unconditional
        self.cond_depth_mask = cond_depth_mask
        self.encoder_features = encoder_features

        if feature_extractor == "dino":
            self.feature_extractor = SpatialDino(
                freeze_weights=freeze_encoder,
                num_patches_x=width,
                num_patches_y=width,
                activation_hooks=self.encoder_features,
            )
            self.feature_dim = self.feature_extractor.feature_dim
        elif feature_extractor == "vae":
            self.feature_extractor = PretrainedVAE(
                freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
            )
            self.feature_dim = self.feature_extractor.feature_dim
        else:
            raise Exception(f"Unknown feature extractor {feature_extractor}")

        if self.use_unconditional:
            self.register_parameter(
                "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1))
            )

        self.input_dim = self.feature_dim * 2

        if self.append_ndc:
            self.input_dim += 2

        if model_type == "dit":
            self.ray_predictor = DiT(
                in_channels=self.input_dim,
                out_channels=self.ray_dim,
                width=width,
                depth=depth,
                hidden_size=hidden_size,
                max_num_images=max_num_images,
                P=P,
            )

        if freeze_transformer:
            for param in self.ray_predictor.parameters():
                param.requires_grad = False

        # Fusion blocks
        self.f = 256

        if self.encoder_features:
            feature_lens = [
                self.feature_extractor.feature_dim,
                self.feature_extractor.feature_dim,
                self.ray_predictor.hidden_size,
                self.ray_predictor.hidden_size,
            ]
        else:
            feature_lens = [self.ray_predictor.hidden_size] * 4

        self.scratch = _make_scratch(feature_lens, 256, groups=1, expand=False)
        self.scratch.refinenet1 = _make_fusion_block(
            self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=128
        )
        self.scratch.refinenet2 = _make_fusion_block(
            self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=64
        )
        self.scratch.refinenet3 = _make_fusion_block(
            self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=32
        )
        self.scratch.refinenet4 = _make_fusion_block(
            self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=16
        )

        self.scratch.input_conv = nn.Conv2d(
            self.ray_dim + int(self.cond_depth_mask),
            self.feature_dim,
            kernel_size=16,
            stride=16,
            padding=0
        )

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(self.f, self.f // 2, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(self.f // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(32, self.ray_dim, kernel_size=1, stride=1, padding=0),
            nn.Identity(),
        )

        if self.encoder_features:
            self.project_opers = nn.ModuleList([
                ProjectReadout(in_features=self.feature_extractor.feature_dim),
                ProjectReadout(in_features=self.feature_extractor.feature_dim),
            ])

    def forward_noise(
        self, x, t, epsilon=None, zero_out_mask=None
    ):
        """
        Applies forward diffusion (adds noise) to the input.

        If a mask is provided, the noise is only applied to the masked inputs.
        """
        t = t.reshape(-1, 1, 1, 1, 1)
        if epsilon is None:
            epsilon = torch.randn_like(x)
        else:
            epsilon = epsilon.reshape(x.shape)

        alpha_bar = self.noise_scheduler.alphas_cumprod[t]
        x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon

        if zero_out_mask is not None and self.cond_depth_mask:
            x_noise = zero_out_mask * x_noise

        return x_noise, epsilon

    def forward(
        self,
        features=None,
        images=None,
        rays=None,
        rays_noisy=None,
        t=None,
        ndc_coordinates=None,
        unconditional_mask=None,
        encoder_patches=16,
        depth_mask=None,
        multiview_unconditional=False,
        indices=None,
    ):
        """
        Args:
            images: (B, N, 3, H, W).
            t: (B,).
            rays: (B, N, 6, H, W).
            rays_noisy: (B, N, 6, H, W).
            ndc_coordinates: (B, N, 2, H, W).
            unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples
                and 0 else.
        """

        if features is None:
            # VAE expects 256x256 images while DINO expects 224x224 images.
            # Both feature extractors support autoresize=True, but ideally we should
            # set this to be false and handle in the dataloader.
            features = self.feature_extractor(images, autoresize=True)

        B = features.shape[0]

        if unconditional_mask is not None and self.use_unconditional:
            null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1)
            unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1)
            features = (
                features * (1 - unconditional_mask) + null_token * unconditional_mask
            )

        if isinstance(t, int) or isinstance(t, np.int64):
            t = torch.ones(1, dtype=int).to(features.device) * t
        else:
            t = t.reshape(B)

        if rays_noisy is None:
            if self.cond_depth_mask:
                rays_noisy, epsilon = self.forward_noise(
                    rays, t, zero_out_mask=depth_mask.unsqueeze(2)
                )
            else:
                rays_noisy, epsilon = self.forward_noise(
                    rays, t
                )
        else:
            epsilon = None

        # DOWNSAMPLE RAYS
        B, N, C, H, W = rays_noisy.shape

        if self.cond_depth_mask:
            if depth_mask is None:
                depth_mask = torch.ones_like(rays_noisy[:, :, 0])
            ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2)
        else:
            ray_repr = rays_noisy

        ray_repr = self.scratch.input_conv(ray_repr.reshape(B * N, -1, H, W))
        _, CP, HP, WP = ray_repr.shape
        ray_repr = ray_repr.reshape(B, N, CP, HP, WP)
        scene_features = torch.cat([features, ray_repr], dim=2)

        if self.append_ndc:
            scene_features = torch.cat([scene_features, ndc_coordinates], dim=2)

        # DIT FORWARD PASS
        activations = self.ray_predictor(
            scene_features,
            t,
            return_dpt_activations=True,
            multiview_unconditional=multiview_unconditional,
        )

        # PROJECT ENCODER ACTIVATIONS & RESHAPE
        if self.encoder_features:
            for i in range(2):
                name = f"encoder{i+1}"

                if indices is not None:
                    act = self.feature_extractor.activations[name][indices]
                else:
                    act = self.feature_extractor.activations[name]

                act = self.project_opers[i](act).permute(0, 2, 1)
                act = act.reshape(
                    (
                        B * N,
                        self.feature_extractor.feature_dim,
                        encoder_patches,
                        encoder_patches,
                    )
                )
                activations[i] = act

        # UPSAMPLE ACTIVATIONS
        for i, act in enumerate(activations):
            k = 3 - i
            activations[i] = nearest_neighbor_upsample(act, 2**k)

        # FUSION BLOCKS
        layer_1_rn = self.scratch.layer1_rn(activations[0])
        layer_2_rn = self.scratch.layer2_rn(activations[1])
        layer_3_rn = self.scratch.layer3_rn(activations[2])
        layer_4_rn = self.scratch.layer4_rn(activations[3])

        # RESHAPE TIMESTEPS
        if t.shape[0] == B:
            t = t.unsqueeze(-1).repeat((1, N)).reshape(B * N)
        elif t.shape[0] == 1 and B > 1:
            t = t.repeat((B * N))
        else:
            assert False

        path_4 = self.scratch.refinenet4(layer_4_rn, t=t)
        path_3 = self.scratch.refinenet3(path_4, activation=layer_3_rn, t=t)
        path_2 = self.scratch.refinenet2(path_3, activation=layer_2_rn, t=t)
        path_1 = self.scratch.refinenet1(path_2, activation=layer_1_rn, t=t)

        epsilon_pred = self.scratch.output_conv(path_1)
        epsilon_pred = epsilon_pred.reshape((B, N, C, H, W))

        return epsilon_pred, epsilon

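nearest_neighbor_upsample above is a reshape/expand trick that repeats each pixel into an s x s block, so it should agree with torch.nn.functional.interpolate in nearest mode. A quick self-check (illustration only, not part of the file):

    import torch
    import torch.nn.functional as F
    from diffusionsfm.model.diffuser_dpt import nearest_neighbor_upsample

    x = torch.randn(1, 3, 4, 4)
    a = nearest_neighbor_upsample(x, 2)
    b = F.interpolate(x, scale_factor=2, mode="nearest")
    assert torch.equal(a, b)
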
diffusionsfm/model/dit.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import math

import ipdb  # noqa: F401
import numpy as np
import torch
import torch.nn as nn
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
from diffusionsfm.model.memory_efficient_attention import MEAttention


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


#################################################################################
#                                 Core DiT Model                                #
#################################################################################


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        use_xformers_attention=False,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        attn = MEAttention if use_xformers_attention else Attention
        self.attn = attn(
            hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(
            modulate(self.norm1(x), shift_msa, scale_msa)
        )
        x = x + gate_mlp.unsqueeze(1) * self.mlp(
            modulate(self.norm2(x), shift_mlp, scale_mlp)
        )
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        in_channels=442,
        out_channels=6,
        width=16,
        hidden_size=1152,
        depth=8,
        num_heads=16,
        mlp_ratio=4.0,
        max_num_images=8,
        P=1,
        within_image=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.hidden_size = hidden_size
        self.max_num_images = max_num_images
        self.P = P
        self.within_image = within_image

        # self.x_embedder = nn.Linear(in_channels, hidden_size)
        # self.x_embedder = PatchEmbed(in_channels, hidden_size, kernel_size=P, hidden_size=P)
        self.x_embedder = PatchEmbed(
            img_size=self.width,
            patch_size=self.P,
            in_chans=in_channels,
            embed_dim=hidden_size,
            bias=True,
            flatten=False,
        )
        self.x_pos_enc = FeaturePositionalEncoding(
            max_num_images, hidden_size, width**2, P=self.P
        )
        self.t_embedder = TimestepEmbedder(hidden_size)

        try:
            import xformers

            use_xformers_attention = True
        except ImportError:
            # xformers not available
            use_xformers_attention = False

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    use_xformers_attention=use_xformers_attention,
                )
                for _ in range(depth)
            ]
        )
        self.final_layer = FinalLayer(hidden_size, P, out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        # nn.init.constant_(self.final_layer.linear.weight, 0)
        # nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)

        # print("unpatchify", c, p, h, w, x.shape)
        # assert h * w == x.shape[2]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nhpwqc", x)
        imgs = x.reshape(shape=(x.shape[0], h * p, h * p, c))
        return imgs

    def forward(
        self,
        x,
        t,
        return_dpt_activations=False,
        multiview_unconditional=False,
    ):
        """
        Args:
            x: Image/Ray features (B, N, C, H, W).
            t: Timesteps (N,).

        Returns:
            (B, N, D, H, W)
        """
        B, N, c, h, w = x.shape
        P = self.P

        x = x.reshape((B * N, c, h, w))  # (B * N, C, H, W)
        x = self.x_embedder(x)  # (B * N, C, H / P, W / P)

        x = x.permute(0, 2, 3, 1)  # (B * N, H / P, W / P, C)
        # (B, N, H / P, W / P, C)
        x = x.reshape((B, N, h // P, w // P, self.hidden_size))
        x = self.x_pos_enc(x)  # (B, N, H * W / P ** 2, C)
        # TODO: fix positional encoding to work with (N, C, H, W) format.

        # Eval time, we get a scalar t
        if x.shape[0] != t.shape[0] and t.shape[0] == 1:
            t = t.repeat_interleave(B)

        if self.within_image or multiview_unconditional:
            t_within = t.repeat_interleave(N)
            t_within = self.t_embedder(t_within)

        t = self.t_embedder(t)

        dpt_activations = []
        for i, block in enumerate(self.blocks):
            # Within image block
            if (self.within_image and i % 2 == 0) or multiview_unconditional:
                x = x.reshape((B * N, h * w // P**2, self.hidden_size))
                x = block(x, t_within)

            # All patches block
            # Final layer is an all patches layer
            else:
                x = x.reshape((B, N * h * w // P**2, self.hidden_size))
                x = block(x, t)  # (N, T, D)

            if return_dpt_activations and i % 4 == 3:
                x_prime = x.reshape(B, N, h, w, self.hidden_size)
                x_prime = x.reshape(B * N, h, w, self.hidden_size)
                x_prime = x_prime.permute((0, 3, 1, 2))
                dpt_activations.append(x_prime)

        # Reshape the output back to original shape
        if multiview_unconditional:
            x = x.reshape((B, N * h * w // P**2, self.hidden_size))

        # (B, N * H * W / P ** 2, D)
        x = self.final_layer(
            x, t
        )  # (B, N * H * W / P ** 2, 6 * P ** 2) or (N, T, patch_size ** 2 * out_channels)

        x = x.reshape((B * N, w * w // P**2, self.out_channels * P**2))
        x = self.unpatchify(x)  # (B * N, H, W, C)
        x = x.reshape((B, N) + x.shape[1:])
        x = x.permute(0, 1, 4, 2, 3)  # (B, N, C, H, W)

        if return_dpt_activations:
            return dpt_activations[:4]

        return x


class FeaturePositionalEncoding(nn.Module):
    def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
        """Sinusoid position encoding table"""

        def get_position_angle_vec(position):
            return [
                position / np.power(base, 2 * (hid_j // 2) / d_hid)
                for hid_j in range(d_hid)
            ]

        sinusoid_table = np.array(
            [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
        )
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def __init__(self, max_num_images=8, feature_dim=1152, num_patches=256, P=1):
        super().__init__()
        self.max_num_images = max_num_images
        self.feature_dim = feature_dim
        self.P = P
        self.num_patches = num_patches // self.P**2

        self.register_buffer(
            "image_pos_table",
            self._get_sinusoid_encoding_table(
                self.max_num_images, self.feature_dim, 10000
            ),
        )

        self.register_buffer(
            "token_pos_table",
            self._get_sinusoid_encoding_table(
                self.num_patches, self.feature_dim, 70007
            ),
        )

    def forward(self, x):
        batch_size = x.shape[0]
        num_images = x.shape[1]

        x = x.reshape(batch_size, num_images, self.num_patches, self.feature_dim)

        # To encode image index
        pe1 = self.image_pos_table[:, :num_images].clone().detach()
        pe1 = pe1.reshape((1, num_images, 1, self.feature_dim))
        pe1 = pe1.repeat((batch_size, 1, self.num_patches, 1))

        # To encode patch index
        pe2 = self.token_pos_table.clone().detach()
        pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
        pe2 = pe2.repeat((batch_size, num_images, 1, 1))

        x_pe = x + pe1 + pe2
        x_pe = x_pe.reshape(
            (batch_size, num_images * self.num_patches, self.feature_dim)
        )

        return x_pe

    def forward_unet(self, x, B, N):
        D = int(self.num_patches**0.5)

        # x should be (B, N, T, D, D)
        x = x.permute((0, 2, 3, 1))
        x = x.reshape(B, N, self.num_patches, self.feature_dim)

        # To encode image index
        pe1 = self.image_pos_table[:, :N].clone().detach()
        pe1 = pe1.reshape((1, N, 1, self.feature_dim))
        pe1 = pe1.repeat((B, 1, self.num_patches, 1))

        # To encode patch index
        pe2 = self.token_pos_table.clone().detach()
        pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
        pe2 = pe2.repeat((B, N, 1, 1))

        x_pe = x + pe1 + pe2
        x_pe = x_pe.reshape((B * N, D, D, self.feature_dim))
        x_pe = x_pe.permute((0, 3, 1, 2))

        return x_pe
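The following is a minimal shape-level smoke test for the DiT module above, added here for illustration only. It assumes xformers is not installed (so the timm Attention path is chosen) and uses deliberately small hyperparameters rather than the repository defaults (hidden_size=1152, depth=8, in_channels=442).

import torch

from diffusionsfm.model.dit import DiT

model = DiT(
    in_channels=16,   # toy feature dimension for the sketch, not the repo default
    out_channels=6,
    width=16,
    hidden_size=64,
    depth=2,
    num_heads=4,
    P=1,
)
x = torch.randn(2, 4, 16, 16, 16)   # (B, N, C, H, W): 2 scenes, 4 views each
t = torch.randint(0, 100, (2,))     # one diffusion timestep per scene
out = model(x, t)
print(out.shape)                    # expected: torch.Size([2, 4, 6, 16, 16])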
diffusionsfm/model/feature_extractors.py
ADDED
@@ -0,0 +1,176 @@
import importlib
import os
import socket
import sys

import ipdb  # noqa: F401
import torch
import torch.nn as nn
from omegaconf import OmegaConf

HOSTNAME = socket.gethostname()

if "trinity" in HOSTNAME:
    # Might be outdated
    config_path = "/home/amylin2/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/amylin2/latent-diffusion/model.ckpt"
elif "grogu" in HOSTNAME:
    # Might be outdated
    config_path = "/home/jasonzh2/code/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/jasonzh2/code/latent-diffusion/model.ckpt"
elif "ender" in HOSTNAME:
    config_path = "/home/jason/ray_diffusion/external/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/jason/ray_diffusion/external/latent-diffusion/model.ckpt"
else:
    config_path = None
    weights_path = None


if weights_path is not None:
    LDM_PATH = os.path.dirname(weights_path)
    if LDM_PATH not in sys.path:
        sys.path.append(LDM_PATH)


def resize(image, size=None, scale_factor=None):
    return nn.functional.interpolate(
        image,
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


class PretrainedVAE(nn.Module):
    def __init__(self, freeze_weights=True, num_patches_x=16, num_patches_y=16):
        super().__init__()
        config = OmegaConf.load(config_path)
        self.model = instantiate_from_config(config.model)
        self.model.init_from_ckpt(weights_path)
        self.model.eval()
        self.feature_dim = 16
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y

        if freeze_weights:
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(self, x, autoresize=False):
        """
        Spatial dimensions of output will be H // 16, W // 16. If autoresize is True,
        then the input will be resized such that the output feature map is the correct
        dimensions.

        Args:
            x (torch.Tensor): Images (B, C, H, W). Should be normalized to be [-1, 1].
            autoresize (bool): Whether to resize the input to match the num_patch
                dimensions.

        Returns:
            torch.Tensor: Latent sample (B, 16, h, w)
        """

        *B, c, h, w = x.shape
        x = x.reshape(-1, c, h, w)
        if autoresize:
            new_w = self.num_patches_x * 16
            new_h = self.num_patches_y * 16
            x = resize(x, size=(new_h, new_w))

        decoded, latent = self.model(x)
        # A little ambiguous bc it's all 16, but it is (c, h, w)
        latent_sample = latent.sample().reshape(
            *B, self.feature_dim, self.num_patches_y, self.num_patches_x
        )
        return latent_sample


activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


class SpatialDino(nn.Module):
    def __init__(
        self,
        freeze_weights=True,
        model_type="dinov2_vits14",
        num_patches_x=16,
        num_patches_y=16,
        activation_hooks=False,
    ):
        super().__init__()
        self.model = torch.hub.load("facebookresearch/dinov2", model_type)
        self.feature_dim = self.model.embed_dim
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y
        if freeze_weights:
            for param in self.model.parameters():
                param.requires_grad = False

        self.activation_hooks = activation_hooks

        if self.activation_hooks:
            self.model.blocks[5].register_forward_hook(get_activation("encoder1"))
            self.model.blocks[11].register_forward_hook(get_activation("encoder2"))
            self.activations = activations

    def forward(self, x, autoresize=False):
        """
        Spatial dimensions of output will be H // 14, W // 14. If autoresize is True,
        then the output will be resized to the correct dimensions.

        Args:
            x (torch.Tensor): Images (B, C, H, W). Should be ImageNet normalized.
            autoresize (bool): Whether to resize the input to match the num_patch
                dimensions.

        Returns:
            feature_map (torch.tensor): (B, C, h, w)
        """
        *B, c, h, w = x.shape

        x = x.reshape(-1, c, h, w)
        # if autoresize:
        #     new_w = self.num_patches_x * 14
        #     new_h = self.num_patches_y * 14
        #     x = resize(x, size=(new_h, new_w))

        # Output will be (B, H * W, C)
        features = self.model.forward_features(x)["x_norm_patchtokens"]
        features = features.permute(0, 2, 1)
        features = features.reshape(  # (B, C, H, W)
            -1, self.feature_dim, h // 14, w // 14
        )
        if autoresize:
            features = resize(features, size=(self.num_patches_y, self.num_patches_x))

        features = features.reshape(
            *B, self.feature_dim, self.num_patches_y, self.num_patches_x
        )
        return features
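A usage sketch for SpatialDino, added here for illustration only: it assumes network access for the torch.hub download of DINOv2, and the 384-dim feature size assumes the default dinov2_vits14 backbone.

import torch

from diffusionsfm.model.feature_extractors import SpatialDino

encoder = SpatialDino(num_patches_x=16, num_patches_y=16)
images = torch.randn(2, 3, 224, 224)          # stand-in for ImageNet-normalized images
features = encoder(images, autoresize=True)   # (2, 384, 16, 16) for dinov2_vits14
print(features.shape)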
diffusionsfm/model/memory_efficient_attention.py
ADDED
@@ -0,0 +1,51 @@
import ipdb  # noqa: F401
import torch.nn as nn
from xformers.ops import memory_efficient_attention


class MEAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # MEA expects [B, N, H, D], whereas timm uses [B, H, N, D]
        x = memory_efficient_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            scale=self.scale,
        )
        x = x.reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
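Illustrative usage of MEAttention, not part of the upload: it keeps timm's (B, N, C) token interface, which is why DiTBlock can swap it in for timm's Attention. The sketch assumes xformers is installed and a CUDA device is available, since memory_efficient_attention targets GPU kernels.

import torch

from diffusionsfm.model.memory_efficient_attention import MEAttention

attn = MEAttention(dim=64, num_heads=4, qkv_bias=True).cuda()
tokens = torch.randn(2, 256, 64, device="cuda")  # (B, N, C) token sequence
out = attn(tokens)
print(out.shape)  # torch.Size([2, 256, 64])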
diffusionsfm/model/scheduler.py
ADDED
@@ -0,0 +1,128 @@
import ipdb  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from diffusionsfm.utils.visualization import plot_to_image


class NoiseScheduler(nn.Module):
    def __init__(
        self,
        max_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        cos_power=2,
        num_inference_steps=100,
        type="linear",
    ):
        super().__init__()
        self.max_timesteps = max_timesteps
        self.num_inference_steps = num_inference_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.cos_power = cos_power
        self.type = type

        if type == "linear":
            self.register_linear_schedule()
        elif type == "cosine":
            self.register_cosine_schedule(cos_power)
        elif type == "scaled_linear":
            self.register_scaled_linear_schedule()

        self.inference_timesteps = self.compute_inference_timesteps()

    def register_linear_schedule(self):
        # zero terminal SNR (https://arxiv.org/pdf/2305.08891)
        betas = torch.linspace(
            self.beta_start,
            self.beta_end,
            self.max_timesteps,
            dtype=torch.float32,
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        alphas_bar_sqrt -= alphas_bar_sqrt_T
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas

        self.register_buffer(
            "betas",
            betas,
        )
        self.register_buffer("alphas", 1.0 - self.betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def register_cosine_schedule(self, cos_power, s=0.008):
        timesteps = (
            torch.arange(self.max_timesteps + 1, dtype=torch.float32)
            / self.max_timesteps
        )
        alpha_bars = (timesteps + s) / (1 + s) * np.pi / 2
        alpha_bars = torch.cos(alpha_bars).pow(cos_power)
        alpha_bars = alpha_bars / alpha_bars[0]
        betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

        self.register_buffer(
            "betas",
            betas,
        )
        self.register_buffer("alphas", 1.0 - betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def register_scaled_linear_schedule(self):
        self.register_buffer(
            "betas",
            torch.linspace(
                self.beta_start**0.5,
                self.beta_end**0.5,
                self.max_timesteps,
                dtype=torch.float32,
            )
            ** 2,
        )
        self.register_buffer("alphas", 1.0 - self.betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def compute_inference_timesteps(
        self, num_inference_steps=None, num_train_steps=None
    ):
        # based on diffusers's scheduling code
        if num_inference_steps is None:
            num_inference_steps = self.num_inference_steps
        if num_train_steps is None:
            num_train_steps = self.max_timesteps
        step_ratio = num_train_steps // num_inference_steps
        timesteps = (
            (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(int)
        )
        return timesteps

    def plot_schedule(self, return_image=False):
        fig = plt.figure(figsize=(6, 4), dpi=100)
        alpha_bars = self.alphas_cumprod.cpu().numpy()
        plt.plot(np.sqrt(alpha_bars))
        plt.grid()
        if self.type == "linear":
            plt.title(
                f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
            )
        else:  # cosine
            plt.title(f"Cosine (T={self.max_timesteps}, P={self.cos_power})")
        if return_image:
            image = plot_to_image(fig)
            plt.close(fig)
            return image
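As a hedged illustration of how the registered alphas_cumprod buffer would typically be consumed, here is a generic DDPM-style forward-noising step; it is a sketch only, not the repository's training loop.

import torch

from diffusionsfm.model.scheduler import NoiseScheduler

scheduler = NoiseScheduler(max_timesteps=1000, type="linear")

x0 = torch.randn(4, 6, 16, 16)                        # clean sample
t = torch.randint(0, scheduler.max_timesteps, (4,))   # per-sample timesteps
alpha_bar = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
noise = torch.randn_like(x0)
# Standard DDPM forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise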
diffusionsfm/utils/__init__.py
ADDED
File without changes
|
diffusionsfm/utils/configs.py
ADDED
@@ -0,0 +1,66 @@
import argparse
import os

from omegaconf import OmegaConf


def load_cfg(config_path):
    """
    Loads a yaml configuration file.

    Follows the chain of yaml configuration files that have a `_BASE` key, and updates
    the new keys accordingly. _BASE configurations can be specified using relative
    paths.
    """
    config_dir = os.path.dirname(config_path)
    config_path = os.path.basename(config_path)
    return load_cfg_recursive(config_dir, config_path)


def load_cfg_recursive(config_dir, config_path):
    """
    Recursively loads config files.

    Follows the chain of yaml configuration files that have a `_BASE` key, and updates
    the new keys accordingly. _BASE configurations can be specified using relative
    paths.
    """
    cfg = OmegaConf.load(os.path.join(config_dir, config_path))
    base_path = OmegaConf.select(cfg, "_BASE", default=None)
    if base_path is not None:
        base_cfg = load_cfg_recursive(config_dir, base_path)
        cfg = OmegaConf.merge(base_cfg, cfg)
    return cfg


def get_cfg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-path", type=str, required=True)
    args = parser.parse_args()
    cfg = load_cfg(args.config_path)
    print(OmegaConf.to_yaml(cfg))

    exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag)
    os.makedirs(exp_dir, exist_ok=True)
    to_path = os.path.join(exp_dir, os.path.basename(args.config_path))
    if not os.path.exists(to_path):
        OmegaConf.save(config=cfg, f=to_path)
    return cfg


def get_cfg_from_path(config_path):
    """
    Args:
        config_path: Path to the yaml configuration file.
    """
    print("getting config from path")

    cfg = load_cfg(config_path)
    print(OmegaConf.to_yaml(cfg))

    exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag)
    os.makedirs(exp_dir, exist_ok=True)
    to_path = os.path.join(exp_dir, os.path.basename(config_path))
    if not os.path.exists(to_path):
        OmegaConf.save(config=cfg, f=to_path)
    return cfg
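A hypothetical example of the `_BASE` chaining that load_cfg implements; the file names and keys below are invented for illustration and are not files in this repository.

# conf/base.yaml (hypothetical):
#   training:
#     lr: 0.0001
#     batch_size: 8
#
# conf/experiment.yaml (hypothetical):
#   _BASE: base.yaml
#   training:
#     batch_size: 4        # overrides the base value; lr is inherited

from diffusionsfm.utils.configs import load_cfg

cfg = load_cfg("conf/experiment.yaml")
print(cfg.training.lr, cfg.training.batch_size)  # 0.0001 4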
diffusionsfm/utils/distortion.py
ADDED
@@ -0,0 +1,144 @@
import cv2
import ipdb  # noqa: F401
import numpy as np
from PIL import Image
import torch


# https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
def apply_distortion(pts, k1, k2):
    """
    Arguments:
        pts (N x 2): numpy array in NDC coordinates
        k1, k2: distortion coefficients
    Return:
        pts (N x 2): distorted points in NDC coordinates
    """
    r2 = np.square(pts).sum(-1)
    f = 1 + k1 * r2 + k2 * r2**2
    return f[..., None] * pts


# https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
def apply_distortion_tensor(pts, k1, k2):
    """
    Arguments:
        pts (N x 2): torch tensor in NDC coordinates
        k1, k2: distortion coefficients
    Return:
        pts (N x 2): distorted points in NDC coordinates
    """
    r2 = torch.square(pts).sum(-1)
    f = 1 + k1 * r2 + k2 * r2**2
    return f[..., None] * pts


# https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
def remove_distortion_iter(points, k1, k2):
    """
    Arguments:
        points (N x 2): numpy array of distorted points in NDC coordinates
        k1, k2: distortion coefficients
    Return:
        pts (N x 2): undistorted points in NDC coordinates
    """
    pts = ptsd = points
    for _ in range(5):
        r2 = np.square(pts).sum(-1)
        f = 1 + k1 * r2 + k2 * r2**2
        pts = ptsd / f[..., None]

    return pts


def make_square(im, fill_color=(0, 0, 0)):
    x, y = im.size
    size = max(x, y)
    new_im = Image.new("RGB", (size, size), fill_color)
    corner = (int((size - x) / 2), int((size - y) / 2))
    new_im.paste(im, corner)
    return new_im, corner


def pixel_to_ndc(coords, image_size):
    """
    Converts pixel coordinates to normalized device coordinates (Pytorch3D convention
    with upper left = (1, 1)) for a square image.

    Args:
        coords: Pixel coordinates UL=(0, 0), LR=(image_size, image_size).
        image_size (int): Image size.

    Returns:
        NDC coordinates UL=(1, 1) LR=(-1, -1).
    """
    coords = np.array(coords)
    return 1 - coords / image_size * 2


def ndc_to_pixel(coords, image_size):
    """
    Converts normalized device coordinates to pixel coordinates for a square image.
    """
    num_points = coords.shape[0]
    sizes = np.tile(np.array(image_size, dtype=np.float32)[None, ...], (num_points, 1))

    coords = np.array(coords, dtype=np.float32)
    return (1 - coords) * sizes / 2


def distort_image(image, bbox, k1, k2, modify_bbox=False):
    # We want to operate in -1 to 1 space using the padded square of the original image
    image, corner = make_square(image)
    bbox[:2] += np.array(corner)
    bbox[2:] += np.array(corner)

    # Construct grid points
    x = np.linspace(1, -1, image.width, dtype=np.float32)
    y = np.linspace(1, -1, image.height, dtype=np.float32)
    x, y = np.meshgrid(x, y, indexing="xy")
    xy_grid = np.stack((x, y), axis=-1)
    points = xy_grid.reshape((image.height * image.width, 2))
    new_points = ndc_to_pixel(apply_distortion(points, k1, k2), image.size)

    # Distort image by remapping
    map_x = new_points[:, 0].reshape((image.height, image.width))
    map_y = new_points[:, 1].reshape((image.height, image.width))
    distorted = cv2.remap(
        np.asarray(image),
        map_x,
        map_y,
        cv2.INTER_LINEAR,
    )
    distorted = Image.fromarray(distorted)

    # Find distorted crop bounds - inverse process of above
    if modify_bbox:
        center = (bbox[:2] + bbox[2:]) / 2
        top, bottom = (bbox[0], center[1]), (bbox[2], center[1])
        left, right = (center[0], bbox[1]), (center[0], bbox[3])
        bbox_points = np.array(
            [
                pixel_to_ndc(top, image.size),
                pixel_to_ndc(left, image.size),
                pixel_to_ndc(bottom, image.size),
                pixel_to_ndc(right, image.size),
            ],
            dtype=np.float32,
        )
    else:
        bbox_points = np.array(
            [pixel_to_ndc(bbox[:2], image.size), pixel_to_ndc(bbox[2:], image.size)],
            dtype=np.float32,
        )

    # Inverse mapping
    distorted_bbox = remove_distortion_iter(bbox_points, k1, k2)

    if modify_bbox:
        p = ndc_to_pixel(distorted_bbox, image.size)
        distorted_bbox = np.array([p[0][0], p[1][1], p[2][0], p[3][1]])
    else:
        distorted_bbox = ndc_to_pixel(distorted_bbox, image.size).reshape(4)

    return distorted, distorted_bbox
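A quick numerical sanity check for the radial distortion helpers, added for illustration: applying the distortion model and then the fixed-point inverse should approximately recover the original NDC points.

import numpy as np

from diffusionsfm.utils.distortion import apply_distortion, remove_distortion_iter

pts = np.random.uniform(-0.8, 0.8, size=(100, 2))   # points in NDC space
k1, k2 = -0.2, 0.05                                  # example distortion coefficients
distorted = apply_distortion(pts, k1, k2)
recovered = remove_distortion_iter(distorted, k1, k2)
print(np.abs(recovered - pts).max())                 # small residual after 5 iterations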
diffusionsfm/utils/distributed.py
ADDED
@@ -0,0 +1,31 @@
import os
import socket
from contextlib import closing

import torch.distributed as dist


def get_open_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]


# Distributed process group
def ddp_setup(rank, world_size, port="12345"):
    """
    Args:
        rank: Unique identifier of the process.
        world_size: Number of processes.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    print(f"MasterPort: {str(port)}")
    os.environ["MASTER_PORT"] = str(port)

    # initialize the process group
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
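A sketch of how ddp_setup and cleanup are typically combined with torch.multiprocessing; the worker body is a placeholder rather than the repository's training entry point, and the NCCL backend assumes one GPU per process.

import torch
import torch.multiprocessing as mp

from diffusionsfm.utils.distributed import cleanup, ddp_setup, get_open_port


def worker(rank, world_size, port):
    ddp_setup(rank, world_size, port=port)
    # ... build the model, wrap it in DistributedDataParallel, run training ...
    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size, get_open_port()), nprocs=world_size)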
diffusionsfm/utils/geometry.py
ADDED
@@ -0,0 +1,145 @@
import numpy as np
import torch
from pytorch3d.renderer import FoVPerspectiveCameras
from pytorch3d.transforms import quaternion_to_matrix


def generate_random_rotations(N=1, device="cpu"):
    q = torch.randn(N, 4, device=device)
    q = q / q.norm(dim=-1, keepdim=True)
    return quaternion_to_matrix(q)


def symmetric_orthogonalization(x):
    """Maps 9D input vectors onto SO(3) via symmetric orthogonalization.

    x: should have size [batch_size, 9]

    Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3).
    """
    m = x.view(-1, 3, 3)
    u, s, v = torch.svd(m)
    vt = torch.transpose(v, 1, 2)
    det = torch.det(torch.matmul(u, vt))
    det = det.view(-1, 1, 1)
    vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1)
    r = torch.matmul(u, vt)
    return r


def get_permutations(num_images):
    permutations = []
    for i in range(0, num_images):
        for j in range(0, num_images):
            if i != j:
                permutations.append((j, i))

    return permutations


def n_to_np_rotations(num_frames, n_rots):
    R_pred_rel = []
    permutations = get_permutations(num_frames)
    for i, j in permutations:
        R_pred_rel.append(n_rots[i].T @ n_rots[j])
    R_pred_rel = torch.stack(R_pred_rel)

    return R_pred_rel


def compute_angular_error_batch(rotation1, rotation2):
    R_rel = np.einsum("Bij,Bjk ->Bik", rotation2, rotation1.transpose(0, 2, 1))
    t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
    theta = np.arccos(np.clip(t, -1, 1))
    return theta * 180 / np.pi


# A should be GT, B should be predicted
def compute_optimal_alignment(A, B):
    """
    Compute the optimal scale s, rotation R, and translation t that minimizes:
    || A - (s * B @ R + t) || ^ 2

    Reference: Umeyama (TPAMI 91)

    Args:
        A (torch.Tensor): (N, 3).
        B (torch.Tensor): (N, 3).

    Returns:
        A_hat (torch.Tensor): s * B @ R + t (N, 3).
        s (float): scale.
        R (torch.Tensor): rotation matrix (3, 3).
        t (torch.Tensor): translation (3,).
    """
    A_bar = A.mean(0)
    B_bar = B.mean(0)
    # normally with R @ B, this would be A @ B.T
    H = (B - B_bar).T @ (A - A_bar)
    U, S, Vh = torch.linalg.svd(H, full_matrices=True)
    s = torch.linalg.det(U @ Vh)
    S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
    variance = torch.sum((B - B_bar) ** 2)
    scale = 1 / variance * torch.trace(torch.diag(S) @ S_prime)
    R = U @ S_prime @ Vh
    t = A_bar - scale * B_bar @ R

    A_hat = scale * B @ R + t
    return A_hat, scale, R, t


def compute_optimal_translation_alignment(T_A, T_B, R_B):
    """
    Assuming right-multiplied rotation matrices.

    E.g., for world2cam R and T, a world coordinate is transformed to camera coordinate
    system using X_cam = X_world.T @ R + T = R.T @ X_world + T

    Finds s, t that minimizes || T_A - (s * T_B + R_B.T @ t) ||^2

    Args:
        T_A (torch.Tensor): Target translation (N, 3).
        T_B (torch.Tensor): Initial translation (N, 3).
        R_B (torch.Tensor): Initial rotation (N, 3, 3).

    Returns:
        T_A_hat (torch.Tensor): s * T_B + t @ R_B (N, 3).
        scale s (torch.Tensor): (1,).
        translation t (torch.Tensor): (1, 3).
    """
    n = len(T_A)

    T_A = T_A.unsqueeze(2)
    T_B = T_B.unsqueeze(2)

    A = torch.sum(T_B * T_A)
    B = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) @ (R_B @ T_A).sum(0) / n
    C = torch.sum(T_B * T_B)
    D = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0)
    E = (D * D).sum() / n

    s = (A - B.sum()) / (C - E.sum())

    t = (R_B @ (T_A - s * T_B)).sum(0) / n

    T_A_hat = s * T_B + R_B.transpose(1, 2) @ t

    return T_A_hat.squeeze(2), s, t.transpose(1, 0)


def get_error(predict_rotations, R_pred, T_pred, R_gt, T_gt, gt_scene_scale):
    if predict_rotations:
        cameras_gt = FoVPerspectiveCameras(R=R_gt, T=T_gt)
        cc_gt = cameras_gt.get_camera_center()
        cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred)
        cc_pred = cameras_pred.get_camera_center()

        A_hat, _, _, _ = compute_optimal_alignment(cc_gt, cc_pred)
        norm = torch.linalg.norm(cc_gt - A_hat, dim=1) / gt_scene_scale

        norms = np.ndarray.tolist(norm.detach().cpu().numpy())
        return norms, A_hat
    else:
        T_A_hat, _, _ = compute_optimal_translation_alignment(T_gt, T_pred, R_pred)
        norm = torch.linalg.norm(T_gt - T_A_hat, dim=1) / gt_scene_scale
        norms = np.ndarray.tolist(norm.detach().cpu().numpy())
        return norms, T_A_hat
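An illustrative self-check for the alignment utilities, not part of the upload: a point set that is exactly a scaled, rotated, and shifted copy of another should be aligned back with near-zero residual.

import torch

from diffusionsfm.utils.geometry import (
    compute_angular_error_batch,
    compute_optimal_alignment,
    generate_random_rotations,
)

R = generate_random_rotations(2)
# Identical rotations should give ~0 degrees of relative angular error.
print(compute_angular_error_batch(R.numpy(), R.numpy()))

A = torch.randn(10, 3)
# Construct B so that A = 0.5 * B @ R[0] + t holds exactly for some translation t.
B = 2.0 * (A - A.mean(0)) @ R[0].T + 1.0
A_hat, scale, R_fit, t_fit = compute_optimal_alignment(A, B)
print(torch.linalg.norm(A - A_hat))  # close to zero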