diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..8d6975514cc07df38ac9e568dfcdfd5e23e034fb 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,32 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo_gradio.gif filter=lfs diff=lfs merge=lfs -text
+assets/pipeline.jpg filter=lfs diff=lfs merge=lfs -text
+examples/video/bungeenerf_colosseum.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/dtu_scan_106.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/fillerbuster_hand_hand.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/fillerbuster_ramen.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/fox.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/horizongs_hillside_summer.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/kitti360.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/llff_fortress.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/llff_horns.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/matrixcity_street.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/meganerf_rubble.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/vrnerf_apartment.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/vrnerf_kitchen.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/vrnerf_riverview.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/video/vrnerf_workshop.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0001.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0010.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0019.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0028.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0037.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0046.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0055.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0064.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0073.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0082.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0091.jpg filter=lfs diff=lfs merge=lfs -text
+examples/vrnerf/riverview/21_DSC0100.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5f7a63dc185f14f57709daf100520105612b0865
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Lihan Jiang and Yucheng Mao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/assets/demo_gradio.gif b/assets/demo_gradio.gif
new file mode 100644
index 0000000000000000000000000000000000000000..580880f1825044f969cb5732e6915d470bf3f596
--- /dev/null
+++ b/assets/demo_gradio.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2de19dc0b15b0d64b408355b016ef8da1ce455913ee37fda935c5b7a43df248
+size 3774652
diff --git a/assets/pipeline.jpg b/assets/pipeline.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..df8e74ed3c4a7aa9f8aa67998cb1512ba7ca36c5
--- /dev/null
+++ b/assets/pipeline.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eafeeddafbf266caf2a1ea911aec24fb08d4d1177813b7621081b0f92d4a63aa
+size 110525
diff --git a/config/compute_metrics.yaml b/config/compute_metrics.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..079d07e553c531cfdd212f1a510c62e4c6c367c6
--- /dev/null
+++ b/config/compute_metrics.yaml
@@ -0,0 +1,28 @@
+defaults:
+ - model/encoder: noposplat
+ - loss: []
+ - override dataset/view_sampler@dataset.re10k.view_sampler: evaluation
+
+dataset:
+ re10k:
+ view_sampler:
+ index_path: assets/evaluation_index_re10k.json
+
+data_loader:
+ train:
+ num_workers: 0
+ persistent_workers: true
+ batch_size: 1
+ seed: 1234
+ test:
+ num_workers: 4
+ persistent_workers: false
+ batch_size: 1
+ seed: 2345
+ val:
+ num_workers: 0
+ persistent_workers: true
+ batch_size: 1
+ seed: 3456
+
+seed: 111123
diff --git a/config/dataset/base_dataset.yaml b/config/dataset/base_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ab32021ac0f178d8f893e22885b169f30d93a8f5
--- /dev/null
+++ b/config/dataset/base_dataset.yaml
@@ -0,0 +1,7 @@
+make_baseline_1: true
+relative_pose: true
+augment: true
+background_color: [1.0, 1.0, 1.0]
+overfit_to_scene: null
+skip_bad_shape: true
+rescale_to_1cube: false
\ No newline at end of file
diff --git a/config/dataset/co3d.yaml b/config/dataset/co3d.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3ca06fbbf4107699ca88ac7e4a2c877a3bad6ba
--- /dev/null
+++ b/config/dataset/co3d.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - base_dataset
+ - view_sampler: rank
+
+name: co3d
+roots: [datasets/co3dv2]
+
+input_image_shape: [256, 256]
+original_image_shape: [540, 960]
+cameras_are_circular: false
+
+baseline_min: 1e-3
+baseline_max: 1e2
+max_fov: 110.0
+avg_pose: false
\ No newline at end of file
diff --git a/config/dataset/dl3dv.yaml b/config/dataset/dl3dv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0216ee1b59489f721e24da4b0df85a71e5c8ab4e
--- /dev/null
+++ b/config/dataset/dl3dv.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - base_dataset
+ - view_sampler: bounded
+
+name: dl3dv
+roots: [datasets/dl3dv]
+
+input_image_shape: [256, 256]
+original_image_shape: [540, 960]
+cameras_are_circular: false
+
+baseline_min: 1e-3
+baseline_max: 1e2
+max_fov: 100.0
+avg_pose: false
+
+rescale_to_1cube: true
+make_baseline_1: false
+intr_augment: true
\ No newline at end of file
diff --git a/config/dataset/scannetpp.yaml b/config/dataset/scannetpp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bcb468bbdb32899917377188ca2926d5ed6cc346
--- /dev/null
+++ b/config/dataset/scannetpp.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - base_dataset
+ - view_sampler: rank
+
+name: scannetpp
+roots: [datasets/scannetpp]
+
+input_image_shape: [256, 256]
+original_image_shape: [690, 1035]
+cameras_are_circular: false
+
+baseline_min: 1e-3
+baseline_max: 1e2
+max_fov: 130.0 # 120.0
+metric_thre: 0.5 # aggressive metric threshold!!
+
+skip_bad_shape: true # if use dlsr and iphone, set to false
+
+rescale_to_1cube: true
+make_baseline_1: false
+intr_augment: true
+normalize_by_pts3d: false
\ No newline at end of file
diff --git a/config/dataset/view_sampler/all.yaml b/config/dataset/view_sampler/all.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9c49fb4661ff2e30ecd98c7e233e2835c9071014
--- /dev/null
+++ b/config/dataset/view_sampler/all.yaml
@@ -0,0 +1 @@
+name: all
diff --git a/config/dataset/view_sampler/arbitrary.yaml b/config/dataset/view_sampler/arbitrary.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c947c2ce3ddf89fac2136ee8133c6662536af3d5
--- /dev/null
+++ b/config/dataset/view_sampler/arbitrary.yaml
@@ -0,0 +1,7 @@
+name: arbitrary
+
+num_target_views: 1
+num_context_views: 2
+
+# If you want to hard-code context views, do so here.
+context_views: null
diff --git a/config/dataset/view_sampler/bounded.yaml b/config/dataset/view_sampler/bounded.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..064258a8047965b0f79ae5bf78702c6a6c18518f
--- /dev/null
+++ b/config/dataset/view_sampler/bounded.yaml
@@ -0,0 +1,16 @@
+name: bounded
+
+num_target_views: 1
+num_context_views: 24
+
+min_distance_between_context_views: 2
+max_distance_between_context_views: 6
+min_distance_to_context_views: 0
+
+warm_up_steps: 0
+initial_min_distance_between_context_views: 2
+initial_max_distance_between_context_views: 6
+
+max_img_per_gpu: 24
+min_gap_multiplier: 3
+max_gap_multiplier: 5
\ No newline at end of file
diff --git a/config/dataset/view_sampler/evaluation.yaml b/config/dataset/view_sampler/evaluation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..054089ad27b2bdb558b57b145f5154b26f97e1ce
--- /dev/null
+++ b/config/dataset/view_sampler/evaluation.yaml
@@ -0,0 +1,4 @@
+name: evaluation
+
+index_path: assets/evaluation_index_re10k.json
+num_context_views: 2
diff --git a/config/dataset/view_sampler/rank.yaml b/config/dataset/view_sampler/rank.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6e8ff80b628c58950d12c96aa2b08ba53b7956c1
--- /dev/null
+++ b/config/dataset/view_sampler/rank.yaml
@@ -0,0 +1,14 @@
+name: rank
+
+num_target_views: 4
+num_context_views: 24
+
+min_distance_between_context_views: 8
+max_distance_between_context_views: 22
+min_distance_to_context_views: 0
+
+warm_up_steps: 0
+initial_min_distance_between_context_views: 5
+initial_max_distance_between_context_views: 7
+
+max_img_per_gpu: 24
diff --git a/config/experiment/co3d.yaml b/config/experiment/co3d.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5ee2f35ef40db33199ac7221c801aea20fffb2fb
--- /dev/null
+++ b/config/experiment/co3d.yaml
@@ -0,0 +1,90 @@
+# @package _global_
+
+defaults:
+ - /dataset@_group_.co3d: co3d
+ - override /model/encoder: anysplat
+ - override /model/encoder/backbone: croco
+ - override /loss: [mse, lpips, depth_consis] # ablate: opacity loss
+
+wandb:
+ name: co3d
+ tags: [co3d, 448x448]
+
+model:
+ encoder:
+ gs_params_head_type: dpt_gs
+ pose_free: true
+ intrinsics_embed_loc: encoder
+ intrinsics_embed_type: token
+ pretrained_weights: ''
+ voxel_size: 0.002
+ pred_pose: true
+ anchor_feat_dim: 128
+ gs_prune: false # ablate: opacity loss
+ pred_head_type: depth
+ freeze_backbone: false
+ distill: true
+ render_conf: false
+ conf_threshold: 0.1
+ freeze_module: patch_embed
+ voxelize: true
+ intermediate_layer_idx: [4, 11, 17, 23]
+
+dataset:
+ co3d:
+ input_image_shape: [224, 448]
+ view_sampler:
+ num_context_views: 24
+ num_target_views: 1
+ min_distance_between_context_views: 32
+ max_distance_between_context_views: 256
+ max_img_per_gpu: 24 # keep the same as num_context_views
+ avg_pose: false
+ intr_augment: true
+ normalize_by_pts3d: false
+ rescale_to_1cube: false
+
+optimizer:
+ lr: 2e-4
+ warm_up_steps: 1000
+ backbone_lr_multiplier: 0.1
+
+data_loader:
+ train:
+ batch_size: 1 # not used here
+
+trainer:
+ max_steps: 30000
+ val_check_interval: 500
+ num_nodes: 1
+ accumulate_grad_batches: 1
+ precision: bf16-mixed
+
+checkpointing:
+ load: null
+ every_n_train_steps: 200
+ save_weights_only: false
+ save_top_k: 5
+
+train:
+ pose_loss_alpha: 1.0
+ pose_loss_delta: 1.0
+ cxt_depth_weight: 0.0
+ weight_pose: 10.0
+ weight_depth: 0.0
+ weight_normal: 0.0
+
+hydra:
+ run:
+ dir: output/exp_${wandb.name}/${now:%Y-%m-%d_%H-%M-%S}
+
+loss:
+ mse:
+ conf: false
+ lpips:
+ conf: false
+ depth_consis:
+ weight: 0.1
+ loss_type: MSE
+
+
diff --git a/config/experiment/dl3dv.yaml b/config/experiment/dl3dv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..197f698214fa193096b69b15490295403584bab5
--- /dev/null
+++ b/config/experiment/dl3dv.yaml
@@ -0,0 +1,92 @@
+# @package _global_
+
+defaults:
+ - /dataset@_group_.dl3dv: dl3dv
+ - override /model/encoder: anysplat
+ - override /model/encoder/backbone: croco
+ - override /loss: [mse, lpips, depth_consis] # ablate: opacity loss
+
+wandb:
+ name: dl3dv
+ tags: [dl3dv, 448x448]
+
+model:
+ encoder:
+ gs_params_head_type: dpt_gs
+ pose_free: true
+ intrinsics_embed_loc: encoder
+ intrinsics_embed_type: token
+ pretrained_weights: ''
+ voxel_size: 0.002
+ pred_pose: true
+ anchor_feat_dim: 128
+ gs_prune: false # ablate: opacity loss
+ pred_head_type: depth
+ freeze_backbone: false
+ distill: true
+ render_conf: false
+ conf_threshold: 0.1
+ freeze_module: patch_embed
+ voxelize: true
+ intermediate_layer_idx: [4, 11, 17, 23]
+
+dataset:
+ dl3dv:
+ input_image_shape: [224, 448]
+ view_sampler:
+
+ num_target_views: 2
+ min_distance_between_context_views: 32
+ max_distance_between_context_views: 256
+
+ min_gap_multiplier: 3
+ max_gap_multiplier: 5
+ avg_pose: false
+ intr_augment: true
+ normalize_by_pts3d: false
+ rescale_to_1cube: false
+
+optimizer:
+ lr: 2e-4
+ warm_up_steps: 1000
+ backbone_lr_multiplier: 0.1
+
+data_loader:
+ train:
+ batch_size: 1 # not used here
+
+trainer:
+ max_steps: 30000
+ val_check_interval: 500
+ num_nodes: 1
+ accumulate_grad_batches: 1
+ precision: bf16-mixed
+
+checkpointing:
+ load: null
+ every_n_train_steps: 200
+ save_weights_only: false
+ save_top_k: 5
+
+train:
+ pose_loss_alpha: 1.0
+ pose_loss_delta: 1.0
+ cxt_depth_weight: 0.0
+ weight_pose: 10.0
+ weight_depth: 1.0
+ weight_normal: 0.0
+
+hydra:
+ run:
+ dir: output/exp_${wandb.name}/${now:%Y-%m-%d_%H-%M-%S}
+
+loss:
+ mse:
+ conf: false
+ lpips:
+ conf: false
+ depth_consis:
+ weight: 0.1
+ loss_type: MSE
+
+
diff --git a/config/experiment/multi-dataset.yaml b/config/experiment/multi-dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dcfdda33c8cf1c5f3303e25bce2e4054576e750e
--- /dev/null
+++ b/config/experiment/multi-dataset.yaml
@@ -0,0 +1,121 @@
+# @package _global_
+
+defaults:
+ - /dataset@_group_.dl3dv: dl3dv
+ - /dataset@_group_.co3d: co3d
+ - /dataset@_group_.scannetpp: scannetpp
+ - override /model/encoder: anysplat
+ - override /model/encoder/backbone: croco
+ - override /loss: [mse, lpips, depth_consis] # ablate: opacity loss
+
+wandb:
+ name: multidataset-16gpu
+ tags: [multidataset, 448x448]
+
+model:
+ encoder:
+ gs_params_head_type: dpt_gs
+ pose_free: true
+ intrinsics_embed_loc: encoder
+ intrinsics_embed_type: token
+ pretrained_weights: ''
+ voxel_size: 0.002
+ pred_pose: true
+ anchor_feat_dim: 128
+ gs_prune: false # ablate: opacity loss
+ pred_head_type: depth
+ freeze_backbone: false
+ distill: true
+ render_conf: false
+ conf_threshold: 0.1
+ freeze_module: patch_embed
+ voxelize: true
+ intermediate_layer_idx: [4, 11, 17, 23]
+
+dataset:
+ dl3dv:
+ input_image_shape: [224, 448]
+ view_sampler:
+
+ num_target_views: 2
+ min_distance_between_context_views: 32
+ max_distance_between_context_views: 256
+
+ min_gap_multiplier: 3
+ max_gap_multiplier: 5
+ max_img_per_gpu: 24
+ avg_pose: false
+ intr_augment: true
+ normalize_by_pts3d: false
+ rescale_to_1cube: false
+
+ co3d:
+ input_image_shape: [224, 448]
+ view_sampler:
+
+ num_target_views: 1
+ min_distance_between_context_views: 32
+ max_distance_between_context_views: 256
+ max_img_per_gpu: 24
+
+ avg_pose: false
+ intr_augment: true
+ normalize_by_pts3d: false
+ rescale_to_1cube: false
+
+ scannetpp:
+ input_image_shape: [224, 448]
+ view_sampler:
+ num_target_views: 2
+ min_distance_between_context_views: 128
+ max_distance_between_context_views: 512
+ max_img_per_gpu: 24
+ avg_pose: false
+ intr_augment: true
+ normalize_by_pts3d: false
+ rescale_to_1cube: false
+
+optimizer:
+ lr: 2e-4
+ warm_up_steps: 1000
+ backbone_lr_multiplier: 0.1
+
+data_loader:
+ train:
+ batch_size: 1 # not used here
+
+trainer:
+ max_steps: 30000
+ val_check_interval: 500
+ num_nodes: 2
+ accumulate_grad_batches: 1
+ precision: bf16-mixed
+
+checkpointing:
+ load: null
+ every_n_train_steps: 200
+ save_weights_only: false
+ save_top_k: 5
+
+train:
+ pose_loss_alpha: 1.0
+ pose_loss_delta: 1.0
+ cxt_depth_weight: 0.0
+ weight_pose: 10.0
+ weight_depth: 0.0
+ weight_normal: 0.0
+
+hydra:
+ run:
+ dir: output/exp_${wandb.name}/${now:%Y-%m-%d_%H-%M-%S}
+
+loss:
+ mse:
+ conf: false
+ lpips:
+ conf: false
+ depth_consis:
+ weight: 0.1
+ loss_type: MSE
+
+
diff --git a/config/experiment/scannetpp.yaml b/config/experiment/scannetpp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd158557536871f0bba27cf27160f2224691f4e3
--- /dev/null
+++ b/config/experiment/scannetpp.yaml
@@ -0,0 +1,90 @@
+# @package _global_
+
+defaults:
+ - /dataset@_group_.scannetpp: scannetpp
+ - override /model/encoder: anysplat
+ - override /model/encoder/backbone: croco
+ - override /loss: [mse, lpips, depth_consis] # ablate: opacity loss
+
+wandb:
+ name: vggt-mdataset-new-scannetpp-dynamic_batchsampler
+ tags: [multidataset, 448x448]
+
+model:
+ encoder:
+ gs_params_head_type: dpt_gs
+ pose_free: true
+ intrinsics_embed_loc: encoder
+ intrinsics_embed_type: token
+ pretrained_weights: ''
+ voxel_size: 0.002
+ pred_pose: true
+ anchor_feat_dim: 128
+ gs_prune: false # ablate: opacity loss
+ pred_head_type: depth
+ freeze_backbone: false
+ distill: true
+ render_conf: false
+ conf_threshold: 0.1
+ freeze_module: patch_embed
+ voxelize: true
+ intermediate_layer_idx: [4, 11, 17, 23]
+
+dataset:
+ scannetpp:
+ input_image_shape: [224, 448]
+ view_sampler:
+ num_context_views: 24
+ num_target_views: 2
+ min_distance_between_context_views: 128
+ max_distance_between_context_views: 512
+ max_img_per_gpu: 24 # keep the same as num_context_views
+ avg_pose: false
+ intr_augment: true
+ normalize_by_pts3d: false
+ rescale_to_1cube: false
+
+optimizer:
+ lr: 2e-4
+ warm_up_steps: 1000
+ backbone_lr_multiplier: 0.1
+
+data_loader:
+ train:
+ batch_size: 1 # not used here
+
+trainer:
+ max_steps: 30000
+ val_check_interval: 500
+ num_nodes: 1
+ accumulate_grad_batches: 1
+ precision: bf16-mixed
+
+checkpointing:
+ load: null
+ every_n_train_steps: 200
+ save_weights_only: false
+ save_top_k: 5
+
+train:
+ pose_loss_alpha: 1.0
+ pose_loss_delta: 1.0
+ cxt_depth_weight: 0.0
+ weight_pose: 10.0
+ weight_depth: 0.0
+ weight_normal: 0.0
+
+hydra:
+ run:
+ dir: output/exp_${wandb.name}/${now:%Y-%m-%d_%H-%M-%S}
+
+loss:
+ mse:
+ conf: false
+ lpips:
+ conf: false
+ depth_consis:
+ weight: 0.1
+ loss_type: MSE
+
+
diff --git a/config/generate_evaluation_index.yaml b/config/generate_evaluation_index.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8774a6adfb604fe4f1467dde13dedf59d40503e5
--- /dev/null
+++ b/config/generate_evaluation_index.yaml
@@ -0,0 +1,36 @@
+defaults:
+ - dataset: re10k
+ - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset}
+ - override dataset/view_sampler: all
+
+dataset:
+ overfit_to_scene: null
+
+data_loader:
+ train:
+ num_workers: 0
+ persistent_workers: true
+ batch_size: 1
+ seed: 1234
+ test:
+ num_workers: 8
+ persistent_workers: false
+ batch_size: 1
+ seed: 2345
+ val:
+ num_workers: 0
+ persistent_workers: true
+ batch_size: 1
+ seed: 3456
+
+index_generator:
+ num_target_views: 3
+ min_overlap: 0.6
+ max_overlap: 1.0
+ min_distance: 45
+ max_distance: 135
+ output_path: outputs/evaluation_index_re10k
+ save_previews: false
+ seed: 123
+
+seed: 456
diff --git a/config/loss/chamfer_distance.yaml b/config/loss/chamfer_distance.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2db2ed7c75bc78e03c2298796e1cc19cef64def9
--- /dev/null
+++ b/config/loss/chamfer_distance.yaml
@@ -0,0 +1,5 @@
+chamfer_distance:
+ weight: 0.01
+ down_sample_ratio: 0.1
+ sigma_image: null
+
diff --git a/config/loss/depth.yaml b/config/loss/depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f90e902a296f126598e1e6d243b8e94e99bc7ee6
--- /dev/null
+++ b/config/loss/depth.yaml
@@ -0,0 +1,4 @@
+depth:
+ weight: 0.01
+ sigma_image: null
+ use_second_derivative: false
diff --git a/config/loss/depth_consis.yaml b/config/loss/depth_consis.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a7c055e8b2dcfe86594e28202669ee958deb963a
--- /dev/null
+++ b/config/loss/depth_consis.yaml
@@ -0,0 +1,4 @@
+depth_consis:
+ weight: 1.0
+ sigma_image: null
+ use_second_derivative: false
diff --git a/config/loss/depthgt.yaml b/config/loss/depthgt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..51df47d3c98569a3cbd998302cccfe1624a7c15a
--- /dev/null
+++ b/config/loss/depthgt.yaml
@@ -0,0 +1,3 @@
+depthgt:
+ weight: 0.1
+ type: l1+gradient
diff --git a/config/loss/lod.yaml b/config/loss/lod.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..95477e846b2a6549c50bebb6e3259516b00b4689
--- /dev/null
+++ b/config/loss/lod.yaml
@@ -0,0 +1,3 @@
+lod:
+ mse_weight: 1.0
+ lpips_weight: 0.05
diff --git a/config/loss/lpips.yaml b/config/loss/lpips.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c4c7d403b826e46383515ae4cda01794a487b203
--- /dev/null
+++ b/config/loss/lpips.yaml
@@ -0,0 +1,3 @@
+lpips:
+ weight: 0.05
+ apply_after_step: 0
diff --git a/config/loss/mse.yaml b/config/loss/mse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..80cc0be6dc7950661998336bd2bb5cb4ff06ba07
--- /dev/null
+++ b/config/loss/mse.yaml
@@ -0,0 +1,2 @@
+mse:
+ weight: 1.0
diff --git a/config/loss/normal_consis.yaml b/config/loss/normal_consis.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bedb61dec6b92d9115621efc8c6f7f0104ad9d6a
--- /dev/null
+++ b/config/loss/normal_consis.yaml
@@ -0,0 +1,5 @@
+normal_consis:
+ normal_weight: 1.0
+ smooth_weight: 1.0
+ sigma_image: null
+ use_second_derivative: false
diff --git a/config/loss/opacity.yaml b/config/loss/opacity.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..adbd0beb842e27f36fcfb2cf7123cb1c70d061ec
--- /dev/null
+++ b/config/loss/opacity.yaml
@@ -0,0 +1,3 @@
+opacity:
+ weight: 0.1
+ type: exp+mean
\ No newline at end of file
diff --git a/config/main.yaml b/config/main.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78a092c5429fd4ffc6ab9de6710021a9a2f56296
--- /dev/null
+++ b/config/main.yaml
@@ -0,0 +1,81 @@
+defaults:
+ - model/encoder: anysplat
+ - model/decoder: splatting_cuda
+ - loss: [mse]
+
+wandb:
+ project: anysplat
+ entity: scene-representation-group
+ name: debug
+ mode: online
+mode: train
+
+#dataset:
+# overfit_to_scene: null
+
+data_loader:
+ # Avoid having to spin up new processes to print out visualizations.
+ train:
+ num_workers: 16 # 16
+ persistent_workers: true
+ batch_size: 4
+ seed: 1234
+ test:
+ num_workers: 4
+ persistent_workers: false
+ batch_size: 1
+ seed: 2345
+ val:
+ num_workers: 1
+ persistent_workers: true
+ batch_size: 1
+ seed: 3456
+
+optimizer:
+ lr: 1.5e-4
+ warm_up_steps: 2000
+ backbone_lr_multiplier: 0.1
+
+checkpointing:
+ load: null
+ every_n_train_steps: 5000
+ save_top_k: 1
+ save_weights_only: true
+
+train:
+ output_path: ${hydra.run.dir}
+ depth_mode: null
+ extended_visualization: false
+ print_log_every_n_steps: 10
+ distiller: ''
+ distill_max_steps: 1000000
+ random_context_views: false
+
+test:
+ output_path: outputs/test-nopo
+ align_pose: true
+ pose_align_steps: 100
+ rot_opt_lr: 0.005
+ trans_opt_lr: 0.005
+ compute_scores: true
+ save_image: true
+ save_video: false
+ save_compare: true
+ generate_video: false
+ mode: inference
+ image_folder: examples/bungeenerf
+
+seed: 111123
+
+trainer:
+ max_steps: -1
+ val_check_interval: 250
+ gradient_clip_val: 0.5
+ num_nodes: 1
+ accumulate_grad_batches: 1
+
+hydra:
+ run:
+ dir: output-debug/exp_${wandb.name}/${now:%Y-%m-%d_%H-%M-%S}
+ # run:
+ # dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_rank${oc.env:LOCAL_RANK,0}
diff --git a/config/model/decoder/splatting_cuda.yaml b/config/model/decoder/splatting_cuda.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..743730cc0149d3ca8b6d3606dd0f65b1fe06a6ba
--- /dev/null
+++ b/config/model/decoder/splatting_cuda.yaml
@@ -0,0 +1,3 @@
+name: splatting_cuda
+background_color: [1.0, 1.0, 1.0]
+make_scale_invariant: false
diff --git a/config/model/encoder/anysplat.yaml b/config/model/encoder/anysplat.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..62fb1859a27a4aaed2ebce86437b4fa26f4dc8b0
--- /dev/null
+++ b/config/model/encoder/anysplat.yaml
@@ -0,0 +1,62 @@
+defaults:
+ - backbone: croco
+
+name: anysplat
+
+opacity_mapping:
+ initial: 0.0
+ final: 0.0
+ warm_up: 1
+
+num_monocular_samples: 32
+num_surfaces: 1
+predict_opacity: false
+
+gaussians_per_pixel: 1
+
+gaussian_adapter:
+ gaussian_scale_min: 0.5
+ gaussian_scale_max: 15.0
+ sh_degree: 4
+
+d_feature: 32
+
+visualizer:
+ num_samples: 8
+ min_resolution: 256
+ export_ply: false
+
+apply_bounds_shim: true
+
+gs_params_head_type: dpt_gs
+pose_free: true
+pretrained_weights: ""
+scale_align: false
+
+voxel_size: 0.001
+n_offsets: 2
+anchor_feat_dim: 83 # 32
+add_view: false
+color_attr: 3D # 3D or RGB
+mlp_type: unified
+scaffold: true
+
+# unet3d:
+# # lifter_params:
+# # img_in_dim: 32
+# # voxel_out_dim: 32
+# img_feature_source: dino
+# in_channels: 83 # 32 keep same as anchor_feat_dim
+# num_blocks: 2 # 512 -> 128
+# f_maps: 83 # 32
+# # f_maps_2d: 32
+# neck_dense_type: "UNCHANGED"
+# neck_bound: 4
+# use_attention: true
+# gs_enhanced: "original"
+# gsplat_upsample: 4
+# occ_upsample: 1
+# max_scaling: 10
+# max_return: 2
+# feature_pooling_2d: "max"
+# gs_free_space: "free-1"
\ No newline at end of file
diff --git a/config/model/encoder/backbone/croco.yaml b/config/model/encoder/backbone/croco.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b89ab47f7bbfd5b70f25b094addec2f72e6a434a
--- /dev/null
+++ b/config/model/encoder/backbone/croco.yaml
@@ -0,0 +1,9 @@
+name: croco
+
+model: ViTLarge_BaseDecoder
+patch_embed_cls: PatchEmbedDust3R
+asymmetry_decoder: true
+
+intrinsics_embed_loc: 'encoder'
+intrinsics_embed_degree: 4
+intrinsics_embed_type: 'token'
\ No newline at end of file
diff --git a/demo_gradio.py b/demo_gradio.py
new file mode 100644
index 0000000000000000000000000000000000000000..18e477d7f73deb83f3930e2bb9e1f7c27d74a1e6
--- /dev/null
+++ b/demo_gradio.py
@@ -0,0 +1,459 @@
+#!/usr/bin/env python3
+import functools
+import gc
+import os
+import shutil
+import sys
+import tempfile
+import time
+from datetime import datetime
+from pathlib import Path
+
+import cv2
+import gradio as gr
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.misc.image_io import save_interpolated_video
+from src.model.model.anysplat import AnySplat
+from src.model.ply_export import export_ply
+from src.utils.image import process_image
+
+
+# 1) Core model inference
+def get_reconstructed_scene(outdir, model, device):
+ # Load Images
+ image_files = sorted(
+ [
+ os.path.join(outdir, "images", f)
+ for f in os.listdir(os.path.join(outdir, "images"))
+ ]
+ )
+ images = [process_image(img_path) for img_path in image_files]
+ images = torch.stack(images, dim=0).unsqueeze(0).to(device) # [1, K, 3, 448, 448]
+ b, v, c, h, w = images.shape
+
+ assert c == 3, "Images must have 3 channels"
+
+ # Run Inference
+ gaussians, pred_context_pose = model.inference((images + 1) * 0.5)
+
+ # Save the results
+ pred_all_extrinsic = pred_context_pose["extrinsic"]
+ pred_all_intrinsic = pred_context_pose["intrinsic"]
+ video, depth_colored = save_interpolated_video(
+ pred_all_extrinsic,
+ pred_all_intrinsic,
+ b,
+ h,
+ w,
+ gaussians,
+ outdir,
+ model.decoder,
+ )
+
+ plyfile = os.path.join(outdir, "gaussians.ply")
+ export_ply(
+ gaussians.means[0],
+ gaussians.scales[0],
+ gaussians.rotations[0],
+ gaussians.harmonics[0],
+ gaussians.opacities[0],
+ Path(plyfile),
+ save_sh_dc_only=True,
+ )
+
+ # Clean up
+ torch.cuda.empty_cache()
+ return plyfile, video, depth_colored
+
+
+# 2) Handle uploaded video/images --> produce target_dir + images
+def handle_uploads(input_video, input_images):
+ """
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
+ images or extracted frames from video into it. Return (target_dir, image_paths).
+ """
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Create a unique folder name
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ target_dir = f"input_images_{timestamp}"
+ target_dir_images = os.path.join(target_dir, "images")
+
+ # Clean up if somehow that folder already exists
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ os.makedirs(target_dir)
+ os.makedirs(target_dir_images)
+
+ image_paths = []
+
+ # --- Handle images ---
+ if input_images is not None:
+ for file_data in input_images:
+ if isinstance(file_data, dict) and "name" in file_data:
+ file_path = file_data["name"]
+ else:
+ file_path = file_data
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
+ shutil.copy(file_path, dst_path)
+ image_paths.append(dst_path)
+
+ # --- Handle video ---
+ if input_video is not None:
+ if isinstance(input_video, dict) and "name" in input_video:
+ video_path = input_video["name"]
+ else:
+ video_path = input_video
+
+ vs = cv2.VideoCapture(video_path)
+ fps = vs.get(cv2.CAP_PROP_FPS)
+ frame_interval = int(fps * 1) # 1 frame/sec
+
+ count = 0
+ video_frame_num = 0
+ while True:
+ gotit, frame = vs.read()
+ if not gotit:
+ break
+ count += 1
+ if count % frame_interval == 0:
+ image_path = os.path.join(
+ target_dir_images, f"{video_frame_num:06}.png"
+ )
+ cv2.imwrite(image_path, frame)
+ image_paths.append(image_path)
+ video_frame_num += 1
+
+ # Sort final images for gallery
+ image_paths = sorted(image_paths)
+
+ end_time = time.time()
+ print(
+ f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
+ )
+ return target_dir, image_paths
+
+
+# 3) Update gallery on upload
+def update_gallery_on_upload(input_video, input_images):
+ """
+ Whenever user uploads or changes files, immediately handle them
+ and show in the gallery. Return (target_dir, image_paths).
+ If nothing is uploaded, returns "None" and empty list.
+ """
+ if not input_video and not input_images:
+ return None, None, None
+ target_dir, image_paths = handle_uploads(input_video, input_images)
+ return None, target_dir, image_paths
+
+
+# 4) Reconstruction: uses the target_dir plus any viz parameters
+def gradio_demo(
+ target_dir,
+):
+ """
+ Perform reconstruction using the already-created target_dir/images.
+ """
+ if not os.path.isdir(target_dir) or target_dir == "None":
+ return None, None, None
+
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Prepare frame_filter dropdown
+ target_dir_images = os.path.join(target_dir, "images")
+ all_files = (
+ sorted(os.listdir(target_dir_images))
+ if os.path.isdir(target_dir_images)
+ else []
+ )
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+
+ print("Running run_model...")
+ with torch.no_grad():
+ plyfile, video, depth_colored = get_reconstructed_scene(
+ target_dir, model, device
+ )
+
+ end_time = time.time()
+ print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
+
+ return plyfile, video, depth_colored
+
+
+def clear_fields():
+ """
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
+ """
+ return None, None, None
+
+
+if __name__ == "__main__":
+ server_name = "127.0.0.1"
+ server_port = None
+ share = True
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load model
+ model = AnySplat.from_pretrained(
+ "lhjiang/anysplat"
+ )
+ model = model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ theme = gr.themes.Ocean()
+ theme.set(
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
+ checkbox_label_text_color_selected="*button_primary_text_color",
+ )
+ css = """
+ .custom-log * {
+ font-style: italic;
+ font-size: 22px !important;
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ font-weight: bold !important;
+ color: transparent !important;
+ text-align: center !important;
+ }
+
+ .example-log * {
+ font-style: italic;
+ font-size: 16px !important;
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ color: transparent !important;
+ }
+
+ #my_radio .wrap {
+ display: flex;
+ flex-wrap: nowrap;
+ justify-content: center;
+ align-items: center;
+ }
+
+ #my_radio .wrap label {
+ display: flex;
+ width: 50%;
+ justify-content: center;
+ align-items: center;
+ margin: 0;
+ padding: 10px 0;
+ box-sizing: border-box;
+ }
+ """
+ with gr.Blocks(css=css, title="AnySplat Demo", theme=theme) as demo:
+ gr.Markdown(
+ """
+
AnySplat: Feed-forward 3D Gaussian Splatting from Unconstrained Views
+ """
+ )
+
+ with gr.Row():
+ gr.Markdown(
+ """
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ )
+ with gr.Row():
+ gr.Markdown(
+ """
+ ### Getting Started:
+
+ 1. Upload Your Data: Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).
+
+ 2. Preview: Your uploaded images will appear in the gallery on the left.
+
+ 3. Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction process.
+
+ 4. Visualize: The reconstructed 3D Gaussian Splat will appear in the viewer on the right, along with the rendered RGB and depth videos. The trajectory of the rendered video is obtained by interpolating the estimated input image poses.
+
+ Please note: The generated splats are large in size, so they may not load successfully in the Hugging Face demo. You can download the .ply file and render it using other viewers, such as [SuperSplat](https://playcanvas.com/supersplat/editor).
+ """
+ )
+
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
+ dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
+ scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
+ image_type = gr.Textbox(label="image_type", visible=False, value="None")
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ with gr.Tabs():
+ with gr.Tab("Input Data"):
+ input_video = gr.Video(label="Upload Video", interactive=True)
+ input_images = gr.File(
+ file_count="multiple",
+ label="Upload Images",
+ interactive=True,
+ )
+
+ image_gallery = gr.Gallery(
+ label="Preview",
+ columns=4,
+ height="300px",
+ show_download_button=True,
+ object_fit="contain",
+ preview=True,
+ )
+
+ with gr.Column(scale=4):
+ with gr.Tabs():
+ with gr.Tab("AnySplat Output"):
+ with gr.Column():
+ reconstruction_output = gr.Model3D(
+ label="3D Reconstructed Gaussian Splat",
+ height=540,
+ zoom_speed=0.5,
+ pan_speed=0.5,
+ camera_position=[20, 20, 20],
+ )
+
+ with gr.Row():
+ with gr.Row():
+ rgb_video = gr.Video(
+ label="RGB Video", interactive=False, autoplay=True
+ )
+ depth_video = gr.Video(
+ label="Depth Video",
+ interactive=False,
+ autoplay=True,
+ )
+
+ with gr.Row():
+ submit_btn = gr.Button(
+ "Reconstruct", scale=1, variant="primary"
+ )
+ clear_btn = gr.ClearButton(
+ [
+ input_video,
+ input_images,
+ reconstruction_output,
+ target_dir_output,
+ image_gallery,
+ rgb_video,
+ depth_video,
+ ],
+ scale=1,
+ )
+
+ # ---------------------- Examples section ----------------------
+
+ examples = [
+ [None, "examples/video/re10k_1eca36ec55b88fe4.mp4", "re10k", "1eca36ec55b88fe4", "2", "Real", "True",],
+ [None, "examples/video/bungeenerf_colosseum.mp4", "bungeenerf", "colosseum", "8", "Synthetic", "True",],
+ [None, "examples/video/fox.mp4", "InstantNGP", "fox", "14", "Real", "True",],
+ [None, "examples/video/matrixcity_street.mp4", "matrixcity", "street", "32", "Synthetic", "True",],
+ [None, "examples/video/vrnerf_apartment.mp4", "vrnerf", "apartment", "32", "Real", "True",],
+ [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",],
+ [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",],
+ [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",],
+ [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",],
+ [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",],
+ [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",],
+ [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",],
+ [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",],
+ [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",],
+ [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",],
+ ]
+
+ def example_pipeline(
+ input_images,
+ input_video,
+ dataset_name,
+ scene_name,
+ num_images_str,
+ image_type,
+ is_example,
+ ):
+ """
+ 1) Copy example images to new target_dir
+ 2) Reconstruct
+ 3) Return model3D + logs + new_dir + updated dropdown + gallery
+ We do NOT return is_example. It's just an input.
+ """
+ target_dir, image_paths = handle_uploads(input_video, input_images)
+ plyfile, video, depth_colored = gradio_demo(target_dir)
+ return plyfile, video, depth_colored, target_dir, image_paths
+
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
+
+ gr.Examples(
+ examples=examples,
+ inputs=[
+ input_images,
+ input_video,
+ dataset_name,
+ scene_name,
+ num_images,
+ image_type,
+ is_example,
+ ],
+ outputs=[
+ reconstruction_output,
+ rgb_video,
+ depth_video,
+ target_dir_output,
+ image_gallery,
+ ],
+ fn=example_pipeline,
+ cache_examples=False,
+ examples_per_page=50,
+ )
+
+ gr.Markdown("We thank VGGT for their excellent gradio implementation!
")
+
+ submit_btn.click(
+ fn=clear_fields,
+ inputs=[],
+ outputs=[reconstruction_output, rgb_video, depth_video],
+ ).then(
+ fn=gradio_demo,
+ inputs=[
+ target_dir_output,
+ ],
+ outputs=[reconstruction_output, rgb_video, depth_video],
+ ).then(
+ fn=lambda: "False", inputs=[], outputs=[is_example]
+ )
+
+ input_video.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_video, input_images],
+ outputs=[reconstruction_output, target_dir_output, image_gallery],
+ )
+ input_images.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_video, input_images],
+ outputs=[reconstruction_output, target_dir_output, image_gallery],
+ )
+
+ # demo.launch(share=share, server_name=server_name, server_port=server_port)
+ demo.queue(max_size=20).launch(show_error=True, share=True)
+
+ # We thank VGGT for their excellent gradio implementation
diff --git a/examples/video/bungeenerf_colosseum.mp4 b/examples/video/bungeenerf_colosseum.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..333243a8dcc35dbc1c58fac65694fe77ede1b8bf
--- /dev/null
+++ b/examples/video/bungeenerf_colosseum.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:416b6af945547b5d19476823672de552944c7b5a147d29e9e8243e91a16aee3e
+size 329073
diff --git a/examples/video/dtu_scan_106.mp4 b/examples/video/dtu_scan_106.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..60ab7c593788a8b52156d3b0bf7e296ef260d158
--- /dev/null
+++ b/examples/video/dtu_scan_106.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16d7a06325cd368b134908e600a6c0741c7d0d188f1db690532b8ac85d65fef5
+size 352188
diff --git a/examples/video/fillerbuster_hand_hand.mp4 b/examples/video/fillerbuster_hand_hand.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c7978c090a854d1798284f44f14d7aa5da70a3cc
--- /dev/null
+++ b/examples/video/fillerbuster_hand_hand.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b4ca982672bc92342b3e722c171d9d2e4d67a5a8116cd9f346956fbe01e253f
+size 319404
diff --git a/examples/video/fillerbuster_ramen.mp4 b/examples/video/fillerbuster_ramen.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..31a2941246aa4c708a488979f39be2a0fdfc4962
--- /dev/null
+++ b/examples/video/fillerbuster_ramen.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d60346a64a0a0d6805131d0d57edeeb0dae24f24c3f10560e95df65531221229
+size 660736
diff --git a/examples/video/fox.mp4 b/examples/video/fox.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a42782fe6c0a29797a2a152c40373bd90d968be1
--- /dev/null
+++ b/examples/video/fox.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3fa2ccff78e5d8085bb58f3def2d482e8df285ced5ef1b56abfe3766f0d90e0
+size 2361921
diff --git a/examples/video/horizongs_hillside_summer.mp4 b/examples/video/horizongs_hillside_summer.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7e8b209dba1ae99397e91e37db5b6d0f7745e37e
--- /dev/null
+++ b/examples/video/horizongs_hillside_summer.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5dff78d9c00b3776bfca3a370061698bddead2ae940fe5a42d082ccf2ca80d1
+size 1606537
diff --git a/examples/video/kitti360.mp4 b/examples/video/kitti360.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..961d93921562cad405c46960d64e9d4e1aa8049d
--- /dev/null
+++ b/examples/video/kitti360.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c6b13929b2c2aae8b95921d8626f5be06f6afffe05ea4e47940ffeb9906f9fc
+size 1843629
diff --git a/examples/video/llff_fortress.mp4 b/examples/video/llff_fortress.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c858dffb86c1d5bea72b2220ad1545b641e205c4
--- /dev/null
+++ b/examples/video/llff_fortress.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90ea046a0ec78651975529ebe6b9c72b60c19561fe61b15b15b9df0e44d9fe9a
+size 196243
diff --git a/examples/video/llff_horns.mp4 b/examples/video/llff_horns.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5f0717367a7d8750013d0ac4e6d6b3f47978a778
--- /dev/null
+++ b/examples/video/llff_horns.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bc4c443c2a3f889f0c1283e98bd6a7026c36858fb37808bb2e8699ad1a2c1d8
+size 372570
diff --git a/examples/video/matrixcity_street.mp4 b/examples/video/matrixcity_street.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7d6b95073dfa14a45df403080002ff18ffbf33d3
--- /dev/null
+++ b/examples/video/matrixcity_street.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa415f27177398b4e06f580beb3778701ca55784afade2fd6a058212213febc8
+size 3163684
diff --git a/examples/video/meganerf_rubble.mp4 b/examples/video/meganerf_rubble.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..50797215d0e71fb57268c1811ea51f0d2b2bcd0c
--- /dev/null
+++ b/examples/video/meganerf_rubble.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3410c759eb73ca2403ab8fe35d5ebabdbc25e3a0e67d8670a89fe17686246ed0
+size 450116
diff --git a/examples/video/re10k_1eca36ec55b88fe4.mp4 b/examples/video/re10k_1eca36ec55b88fe4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..84ca9af870cbfd343dc32398579b029bb5ef2a1d
Binary files /dev/null and b/examples/video/re10k_1eca36ec55b88fe4.mp4 differ
diff --git a/examples/video/vrnerf_apartment.mp4 b/examples/video/vrnerf_apartment.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..bc87ca70b8b9053528cb023f9d117f8d10d94d78
--- /dev/null
+++ b/examples/video/vrnerf_apartment.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fdd5f165a4293cd95e3dd88d84b1f370decdd86308aa67a9d3832e01f4d6906
+size 2076392
diff --git a/examples/video/vrnerf_kitchen.mp4 b/examples/video/vrnerf_kitchen.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4ef4be83a532a1357a7760fea728f81c66c9846e
--- /dev/null
+++ b/examples/video/vrnerf_kitchen.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3db5d766ec86a7abdfe1f033b252337e6d934ea15035fafb4d0fc0c0e9e9740a
+size 775715
diff --git a/examples/video/vrnerf_riverview.mp4 b/examples/video/vrnerf_riverview.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b4a30cf0165251fa9bb990ebd662b1e3805e03d4
--- /dev/null
+++ b/examples/video/vrnerf_riverview.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b8187936cc49910ef330a37b1bbdab0076096d6c01f33b097c11937184de168
+size 768290
diff --git a/examples/video/vrnerf_workshop.mp4 b/examples/video/vrnerf_workshop.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..afb7f73b9f56b6182f38371a344077b000410bfd
--- /dev/null
+++ b/examples/video/vrnerf_workshop.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0f1334acc74bd70086a9be94d0c36838ebd7499af27f942c315e1ba282e285b
+size 1718918
diff --git a/examples/vrnerf/riverview/21_DSC0001.jpg b/examples/vrnerf/riverview/21_DSC0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..802a075a6334bda6007962080b03af5cdbe7a39a
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0001.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7600a24a0725bf42c2260c748f11f39ef495a065f187dd894a6a6b643d209a79
+size 478234
diff --git a/examples/vrnerf/riverview/21_DSC0010.jpg b/examples/vrnerf/riverview/21_DSC0010.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ab75e2ceb80f2f95d4637ce5d0b37ecf674104a4
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0010.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1621945dbca88d01326e23372993067c7400acf7076c803a3300a159fdb053cc
+size 467263
diff --git a/examples/vrnerf/riverview/21_DSC0019.jpg b/examples/vrnerf/riverview/21_DSC0019.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b2f7ba0ee6a8bd8cbf4a52d61d051b2601880359
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0019.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2309a28deb2797f96f516cf9bb758e97f86b84f78b9bc88dea793cd43f67715f
+size 453802
diff --git a/examples/vrnerf/riverview/21_DSC0028.jpg b/examples/vrnerf/riverview/21_DSC0028.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2a025f1b078ddc914bdd75bbac05220f8762aecd
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0028.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bffe2e849bb4bbb32e6e028f966ed1169a0a7d3c6707c90c37a42a471939a1d6
+size 451209
diff --git a/examples/vrnerf/riverview/21_DSC0037.jpg b/examples/vrnerf/riverview/21_DSC0037.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2d5cf7c1eee7763895ed7d88fa0115cdd1e424a1
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0037.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a2f68fc6c9036a62b9b601587ae82334c957d1a93ef47ed92cbff6209648591
+size 450246
diff --git a/examples/vrnerf/riverview/21_DSC0046.jpg b/examples/vrnerf/riverview/21_DSC0046.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4e32c91d346d2a832e19c94b4eab31d7a74f36a4
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0046.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4799bf9781b4e6c0ea0b3bc6bb60a718b14c7343063f7828cc99a934ca2718bf
+size 452324
diff --git a/examples/vrnerf/riverview/21_DSC0055.jpg b/examples/vrnerf/riverview/21_DSC0055.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0db079c6359ab986766f91b1edbfd3f9c20d7d4a
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0055.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f8ae4628e53df1dee87e24eb0e32f475f864c693e7ff6e565478abd327ab746
+size 455865
diff --git a/examples/vrnerf/riverview/21_DSC0064.jpg b/examples/vrnerf/riverview/21_DSC0064.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d7d72edb989d336e1a02a5a377806293c068a8ae
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0064.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a89effe9cba1bb833cf231b3c43b2d2b148194a416e42d46d67d32e57a4b9a13
+size 448712
diff --git a/examples/vrnerf/riverview/21_DSC0073.jpg b/examples/vrnerf/riverview/21_DSC0073.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1f677a9d4bc3a67c2efac76f17ff03ef2eee7e3f
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0073.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ce299db65e14c3ec67568e58eec6e33b1177251153bb0284609838cb95a40d1
+size 440003
diff --git a/examples/vrnerf/riverview/21_DSC0082.jpg b/examples/vrnerf/riverview/21_DSC0082.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7eafbb06fc3d8a506a80468a88d1759b3363778b
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0082.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6dcf7be5b41a46d2a879163bcbde6cc11dcbab7fde2c437b120c428da97cfe98
+size 432309
diff --git a/examples/vrnerf/riverview/21_DSC0091.jpg b/examples/vrnerf/riverview/21_DSC0091.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..768a4497034b1c5e4422843d0a481ed041dacc78
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0091.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97229231f32411afd68a25b9bb6ffc67b5b068312a3aea675d2deda4471f6cd9
+size 437419
diff --git a/examples/vrnerf/riverview/21_DSC0100.jpg b/examples/vrnerf/riverview/21_DSC0100.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b3553ef7e126e0a274c91c372da10a782eeab1a2
--- /dev/null
+++ b/examples/vrnerf/riverview/21_DSC0100.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea6bbb2a4dbafdd50d8ea49b23349d64e7677f5a70db88663cbec3e8d2211ba1
+size 436117
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a30790fdedc795621cd01c3e8b4e01121d4048
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,39 @@
+
+from pathlib import Path
+import torch
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.misc.image_io import save_interpolated_video
+from src.model.ply_export import export_ply
+from src.model.model.anysplat import AnySplat
+from src.utils.image import process_image
+
+def main():
+ # Load the model from Hugging Face
+ model = AnySplat.from_pretrained("lhjiang/anysplat")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ # Load Images
+ image_folder = "examples/vrnerf/riverview"
+ images = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+ images = [process_image(img_path) for img_path in images]
+ images = torch.stack(images, dim=0).unsqueeze(0).to(device) # [1, K, 3, 448, 448]
+ b, v, _, h, w = images.shape
+
+ # Run Inference
+ gaussians, pred_context_pose = model.inference((images+1)*0.5)
+
+ # Save the results
+ pred_all_extrinsic = pred_context_pose['extrinsic']
+ pred_all_intrinsic = pred_context_pose['intrinsic']
+ save_interpolated_video(pred_all_extrinsic, pred_all_intrinsic, b, h, w, gaussians, image_folder, model.decoder)
+ export_ply(gaussians.means[0], gaussians.scales[0], gaussians.rotations[0], gaussians.harmonics[0], gaussians.opacities[0], Path(image_folder) / "gaussians.ply")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b1c809772e8cab1e73844177229c2c3814db82c6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,38 @@
+numpy==1.25.0
+wheel
+tqdm
+lightning
+black
+ruff
+hydra-core
+jaxtyping
+beartype
+wandb
+einops
+colorama
+scikit-image
+colorspacious
+matplotlib
+moviepy
+imageio
+timm
+dacite
+lpips
+e3nn
+plyfile
+tabulate
+svg.py
+scikit-video
+opencv-python
+Pillow
+huggingface_hub
+gradio
+xformers==0.0.24
+torch_scatter==2.1.2
+moviepy==1.0.3
+pydantic
+open3d
+einops
+safetensors
+git+https://github.com/facebookresearch/pytorch3d.git
+https://github.com/nerfstudio-project/gsplat/releases/download/v1.4.0/gsplat-1.4.0%2Bpt22cu121-cp310-cp310-linux_x86_64.whl
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4646a0e79be3cd6801b9fc6bc2ca363eb09dd359
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal, Optional, Type, TypeVar
+
+from dacite import Config, from_dict
+from omegaconf import DictConfig, OmegaConf
+
+from .dataset import DatasetCfgWrapper
+from .dataset.data_module import DataLoaderCfg
+from .loss import LossCfgWrapper
+from .model.decoder import DecoderCfg
+from .model.encoder import EncoderCfg
+from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg
+
+
+@dataclass
+class CheckpointingCfg:
+ load: Optional[str] # Not a path, since it could be something like wandb://...
+ every_n_train_steps: int
+ save_top_k: int
+ save_weights_only: bool
+
+
+@dataclass
+class ModelCfg:
+ decoder: DecoderCfg
+ encoder: EncoderCfg
+
+
+@dataclass
+class TrainerCfg:
+ max_steps: int
+ val_check_interval: int | float | None
+ gradient_clip_val: int | float | None
+ num_nodes: int = 1
+ accumulate_grad_batches: int = 1
+ precision: Literal["32", "16-mixed", "bf16-mixed"] = "32"
+
+
+@dataclass
+class RootCfg:
+ wandb: dict
+ mode: Literal["train", "test"]
+ dataset: list[DatasetCfgWrapper]
+ data_loader: DataLoaderCfg
+ model: ModelCfg
+ optimizer: OptimizerCfg
+ checkpointing: CheckpointingCfg
+ trainer: TrainerCfg
+ loss: list[LossCfgWrapper]
+ test: TestCfg
+ train: TrainCfg
+ seed: int
+
+
+TYPE_HOOKS = {
+ Path: Path,
+}
+
+
+T = TypeVar("T")
+
+
+def load_typed_config(
+ cfg: DictConfig,
+ data_class: Type[T],
+ extra_type_hooks: dict = {},
+) -> T:
+ return from_dict(
+ data_class,
+ OmegaConf.to_container(cfg),
+ config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}),
+ )
+
+
+def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
+ # The dummy allows the union to be converted.
+ @dataclass
+ class Dummy:
+ dummy: LossCfgWrapper
+
+ return [
+ load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy
+ for k, v in joined.items()
+ ]
+
+
+def separate_dataset_cfg_wrappers(joined: dict) -> list[DatasetCfgWrapper]:
+ # The dummy allows the union to be converted.
+ @dataclass
+ class Dummy:
+ dummy: DatasetCfgWrapper
+
+ return [
+ load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy
+ for k, v in joined.items()
+ ]
+
+
+def load_typed_root_config(cfg: DictConfig) -> RootCfg:
+ return load_typed_config(
+ cfg,
+ RootCfg,
+ {list[LossCfgWrapper]: separate_loss_cfg_wrappers,
+ list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers},
+ )
diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1d7936add0b5e69799329be24447e536058d72a
--- /dev/null
+++ b/src/dataset/__init__.py
@@ -0,0 +1,98 @@
+from dataclasses import fields
+from typing import Callable
+from torch.utils.data import Dataset, ConcatDataset
+import bisect
+
+from ..misc.step_tracker import StepTracker
+from .types import Stage
+from .view_sampler import get_view_sampler
+from .dataset_dl3dv import DatasetDL3DV, DatasetDL3DVCfgWrapper
+from .dataset_scannetpp import DatasetScannetpp, DatasetScannetppCfgWrapper
+from .dataset_co3d import DatasetCo3d, DatasetCo3dCfgWrapper
+
+DATASETS: dict[str, Dataset] = {
+ "co3d": DatasetCo3d,
+ "scannetpp": DatasetScannetpp,
+ "dl3dv": DatasetDL3DV,
+}
+
+DatasetCfgWrapper = DatasetDL3DVCfgWrapper | DatasetScannetppCfgWrapper | DatasetCo3dCfgWrapper
+
+class TestDatasetWarpper(Dataset):
+ def __init__(self, dataset: Dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, idx):
+
+ return self.dataset[(idx, self.dataset.view_sampler.num_context_views, self.dataset.cfg.input_image_shape[1] // 14)] # fake parameters here, to fit the input of dataset
+
+ def __len__(self):
+ return len(self.dataset)
+
+
+
+class CustomConcatDataset(ConcatDataset):
+
+ def __getitem__(self, idx_tuple):
+
+ if isinstance(idx_tuple, list):
+ idx_tuple = idx_tuple[0]
+
+ idx = idx_tuple[0]
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][(sample_idx, idx_tuple[1], idx_tuple[2])]
+
+
+def get_dataset(
+ cfgs: list[DatasetCfgWrapper],
+ stage: Stage,
+ step_tracker: StepTracker | None,
+ dataset_shim: Callable[[Dataset, str], Dataset]
+) -> list[Dataset]:
+ datasets = []
+ if stage != "test":
+ if stage == "val":
+ cfgs = [cfgs[0]]
+ for cfg in cfgs:
+ (field,) = fields(type(cfg))
+ cfg = getattr(cfg, field.name)
+
+ view_sampler = get_view_sampler(
+ cfg.view_sampler,
+ stage,
+ cfg.overfit_to_scene is not None,
+ cfg.cameras_are_circular,
+ step_tracker,
+ )
+ dataset = DATASETS[cfg.name](cfg, stage, view_sampler)
+ dataset = dataset_shim(dataset, stage)
+ datasets.append(dataset)
+
+ return CustomConcatDataset(datasets), datasets
+ elif stage == "test":
+ assert len(cfgs) == 1
+ cfg = cfgs[0]
+ (field,) = fields(type(cfg))
+ cfg = getattr(cfg, field.name)
+
+ view_sampler = get_view_sampler(
+ cfg.view_sampler,
+ stage,
+ cfg.overfit_to_scene is not None,
+ cfg.cameras_are_circular,
+ step_tracker,
+ )
+ dataset = DATASETS[cfg.name](cfg, stage, view_sampler)
+ dataset = dataset_shim(dataset, stage)
+
+ return TestDatasetWarpper(dataset)
+ else:
+ NotImplementedError(f"Stage {stage} is not supported")
diff --git a/src/dataset/data_module.py b/src/dataset/data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed4767e3fa087f8ad50ad9e6538a88516e4c98a7
--- /dev/null
+++ b/src/dataset/data_module.py
@@ -0,0 +1,194 @@
+import random
+from dataclasses import dataclass
+from typing import Callable
+
+import numpy as np
+import torch
+from lightning.pytorch import LightningDataModule
+from torch import Generator, nn
+from torch.utils.data import DataLoader, Dataset, IterableDataset
+
+from src.dataset import *
+from src.global_cfg import get_cfg
+
+
+from ..misc.step_tracker import StepTracker
+from ..misc.utils import get_world_size, get_rank
+from . import DatasetCfgWrapper, get_dataset
+from .types import DataShim, Stage
+from .data_sampler import BatchedRandomSampler, MixedBatchSampler, custom_collate_fn
+from .validation_wrapper import ValidationWrapper
+
+def get_data_shim(encoder: nn.Module) -> DataShim:
+ """Get functions that modify the batch. It's sometimes necessary to modify batches
+ outside the data loader because GPU computations are required to modify the batch or
+ because the modification depends on something outside the data loader.
+ """
+
+ shims: list[DataShim] = []
+ if hasattr(encoder, "get_data_shim"):
+ shims.append(encoder.get_data_shim())
+
+ def combined_shim(batch):
+ for shim in shims:
+ batch = shim(batch)
+ return batch
+
+ return combined_shim
+
+# the training ratio of datasets (example)
+prob_mapping = {DatasetScannetpp: 0.5,
+ DatasetDL3DV: 0.5,
+ DatasetCo3d: 0.5}
+
+@dataclass
+class DataLoaderStageCfg:
+ batch_size: int
+ num_workers: int
+ persistent_workers: bool
+ seed: int | None
+
+
+@dataclass
+class DataLoaderCfg:
+ train: DataLoaderStageCfg
+ test: DataLoaderStageCfg
+ val: DataLoaderStageCfg
+
+
+DatasetShim = Callable[[Dataset, Stage], Dataset]
+
+
+def worker_init_fn(worker_id: int) -> None:
+ random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1))
+ np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1))
+
+
+class DataModule(LightningDataModule):
+ dataset_cfgs: list[DatasetCfgWrapper]
+ data_loader_cfg: DataLoaderCfg
+ step_tracker: StepTracker | None
+ dataset_shim: DatasetShim
+ global_rank: int
+
+ def __init__(
+ self,
+ dataset_cfgs: list[DatasetCfgWrapper],
+ data_loader_cfg: DataLoaderCfg,
+ step_tracker: StepTracker | None = None,
+ dataset_shim: DatasetShim = lambda dataset, _: dataset,
+ global_rank: int = 0,
+ ) -> None:
+ super().__init__()
+ self.dataset_cfgs = dataset_cfgs
+ self.data_loader_cfg = data_loader_cfg
+ self.step_tracker = step_tracker
+ self.dataset_shim = dataset_shim
+ self.global_rank = global_rank
+
+ def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None:
+ return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers
+
+ def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None:
+ if loader_cfg.seed is None:
+ return None
+ generator = Generator()
+ generator.manual_seed(loader_cfg.seed + self.global_rank)
+ self.generator = generator
+ return self.generator
+
+ def train_dataloader(self):
+ dataset, datasets_ls = get_dataset(self.dataset_cfgs, "train", self.step_tracker, self.dataset_shim)
+ world_size = get_world_size()
+ rank = get_rank()
+ # breakpoint()
+ prob_ls = [prob_mapping[type(dataset)] for dataset in datasets_ls]
+ # we assume all the dataset share the same num_context_views
+
+ if len(datasets_ls) > 1:
+ prob = prob_ls
+ context_num_views = [dataset.cfg.view_sampler.num_context_views for dataset in datasets_ls]
+ else:
+ prob = None
+ dataset_key = next(iter(get_cfg()["dataset"]))
+ dataset_cfg = get_cfg()["dataset"][dataset_key]
+ context_num_views = dataset_cfg['view_sampler']['num_context_views']
+
+ sampler = MixedBatchSampler(datasets_ls,
+ batch_size=self.data_loader_cfg.train.batch_size, # Not used here!
+ num_context_views=context_num_views,
+ world_size=world_size,
+ rank=rank,
+ prob=prob,
+ generator=self.get_generator(self.data_loader_cfg.train))
+ sampler.set_epoch(0)
+ self.train_loader = DataLoader(
+ dataset,
+ # self.data_loader_cfg.train.batch_size,
+ # shuffle=not isinstance(dataset, IterableDataset),
+ batch_sampler=sampler,
+ num_workers=self.data_loader_cfg.train.num_workers,
+ generator=self.generator,
+ worker_init_fn=worker_init_fn,
+ # collate_fn=custom_collate_fn,
+ persistent_workers=self.get_persistent(self.data_loader_cfg.train),
+ )
+ # breakpoint()
+ # Set epoch for train and validation loaders (if applicable)
+ if hasattr(self.train_loader, "dataset") and hasattr(self.train_loader.dataset, "set_epoch"):
+ print("Training: Set Epoch in DataModule")
+ self.train_loader.dataset.set_epoch(0)
+ if hasattr(self.train_loader, "sampler") and hasattr(self.train_loader.sampler, "set_epoch"):
+ print("Training: Set Epoch in DataModule")
+ self.train_loader.sampler.set_epoch(0)
+
+ return self.train_loader
+
+ def val_dataloader(self):
+ dataset, datasets_ls = get_dataset(self.dataset_cfgs, "val", self.step_tracker, self.dataset_shim)
+ world_size = get_world_size()
+ rank = get_rank()
+ # here, we random select one dataset for val
+ dataset_key = next(iter(get_cfg()["dataset"]))
+ dataset_cfg = get_cfg()["dataset"][dataset_key]
+ if len(datasets_ls) > 1:
+ prob = [0.5] * len(datasets_ls)
+ else:
+ prob = None
+ sampler = MixedBatchSampler(datasets_ls,
+ batch_size=self.data_loader_cfg.train.batch_size,
+ num_context_views=dataset_cfg['view_sampler']['num_context_views'],
+ world_size=world_size,
+ rank=rank,
+ prob=prob,
+ generator=self.get_generator(self.data_loader_cfg.train))
+ sampler.set_epoch(0)
+ self.val_loader = DataLoader(
+ dataset,
+ self.data_loader_cfg.val.batch_size,
+ num_workers=self.data_loader_cfg.val.num_workers,
+ sampler=sampler,
+ generator=self.get_generator(self.data_loader_cfg.val),
+ worker_init_fn=worker_init_fn,
+ persistent_workers=self.get_persistent(self.data_loader_cfg.val),
+ )
+ if hasattr(self.val_loader, "dataset") and hasattr(self.val_loader.dataset, "set_epoch"):
+ print("Validation: Set Epoch in DataModule")
+ self.val_loader.dataset.set_epoch(0)
+ if hasattr(self.val_loader, "sampler") and hasattr(self.val_loader.sampler, "set_epoch"):
+ print("Validation: Set Epoch in DataModule")
+ self.val_loader.sampler.set_epoch(0)
+ return self.val_loader
+
+ def test_dataloader(self):
+ dataset = get_dataset(self.dataset_cfgs, "test", self.step_tracker, self.dataset_shim)
+ data_loader = DataLoader(
+ dataset,
+ self.data_loader_cfg.test.batch_size,
+ num_workers=self.data_loader_cfg.test.num_workers,
+ generator=self.get_generator(self.data_loader_cfg.test),
+ worker_init_fn=worker_init_fn,
+ persistent_workers=self.get_persistent(self.data_loader_cfg.test),
+ )
+
+ return data_loader
\ No newline at end of file
diff --git a/src/dataset/data_sampler.py b/src/dataset/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f5dda411ddad56a66c552540957691f93062240
--- /dev/null
+++ b/src/dataset/data_sampler.py
@@ -0,0 +1,374 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Random sampling under a constraint
+# --------------------------------------------------------
+import numpy as np
+import torch
+from typing import Callable, Iterable, Optional
+from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler, BatchSampler
+import random
+
+def custom_collate_fn(batch):
+ """
+ Custom collate function to handle variable batch sizes
+
+ Args:
+ batch: A list where each element could be either:
+ - A single tuple (idx, num_images, ...)
+ - A list of tuples [(idx1, num_images1, ...), (idx2, num_images2, ...)]
+ """
+ # If batch contains lists (variable batch size case)
+ breakpoint()
+ if isinstance(batch[0], list):
+ # Flatten the batch
+ flattened = []
+ for item in batch:
+ flattened.extend(item)
+ batch = flattened
+
+ # Now batch is a list of tuples, process normally
+ return torch.utils.data.default_collate(batch)
+
+class BatchedRandomSampler:
+ """Random sampling under a constraint: each sample in the batch has the same feature,
+ which is chosen randomly from a known pool of 'features' for each batch.
+
+ For instance, the 'feature' could be the image aspect-ratio.
+
+ The index returned is a tuple (sample_idx, feat_idx).
+ This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
+ """
+
+ def __init__(
+ self, dataset, batch_size, num_context_views, min_patch_num=20, max_patch_num=32, world_size=1, rank=0, drop_last=True
+ ):
+ self.batch_size = batch_size
+ self.num_context_views = num_context_views
+
+ self.len_dataset = N = len(dataset)
+ self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+ self.min_patch_num = min_patch_num
+ self.max_patch_num = max_patch_num
+ assert (
+ world_size == 1 or drop_last
+ ), "must drop the last batch in distributed mode"
+
+ # distributed sampler
+ self.world_size = world_size
+ self.rank = rank
+ self.epoch = None
+
+ def __len__(self):
+
+
+ return self.total_size // self.world_size
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def __iter__(self):
+ # prepare RNG
+ if self.epoch is None:
+ assert (
+ self.world_size == 1 and self.rank == 0
+ ), "use set_epoch() if distributed mode is used"
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.epoch + 777
+ rng = np.random.default_rng(seed=seed)
+
+ # random indices (will restart from 0 if not drop_last)
+ sample_idxs = np.arange(self.total_size)
+ rng.shuffle(sample_idxs)
+
+ # random feat_idxs (same across each batch)
+ n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+ num_imgs = rng.integers(low=2, high=self.num_context_views, size=n_batches)
+ # num_imgs = (np.ones(n_batches) * self.num_context_views).astype(np.int64) # same number of context views for each batch
+ num_imgs = np.broadcast_to(num_imgs[:, None], (n_batches, self.batch_size))
+ num_imgs = num_imgs.ravel()[: self.total_size]
+
+ # put them together
+ idxs = np.c_[sample_idxs, num_imgs] # shape = (total_size, 2)
+
+ # Distributed sampler: we select a subset of batches
+ # make sure the slice for each node is aligned with batch_size
+ size_per_proc = self.batch_size * (
+ (self.total_size + self.world_size * self.batch_size - 1)
+ // (self.world_size * self.batch_size)
+ )
+ idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
+
+ yield from (tuple(idx) for idx in idxs)
+
+class DynamicBatchSampler(Sampler):
+ """
+ A custom batch sampler that dynamically adjusts batch size, aspect ratio, and image number
+ for each sample. Batches within a sample share the same aspect ratio and image number.
+ """
+ def __init__(self,
+ sampler,
+ image_num_range,
+ h_range,
+ epoch=0,
+ seed=42,
+ max_img_per_gpu=48):
+ """
+ Initializes the dynamic batch sampler.
+
+ Args:
+ sampler: Instance of DynamicDistributedSampler.
+ aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio].
+ image_num_range: List containing [min_images, max_images] per sample.
+ epoch: Current epoch number.
+ seed: Random seed for reproducibility.
+ max_img_per_gpu: Maximum number of images to fit in GPU memory.
+ """
+ self.sampler = sampler
+ self.image_num_range = image_num_range
+ self.h_range = h_range
+ self.rng = random.Random()
+
+ # Uniformly sample from the range of possible image numbers
+ # For any image number, the weight is 1.0 (uniform sampling). You can set any different weights here.
+ self.image_num_weights = {num_images: float(num_images**2) for num_images in range(image_num_range[0], image_num_range[1]+1)}
+
+ # Possible image numbers, e.g., [2, 3, 4, ..., 24]
+ self.possible_nums = np.array([n for n in self.image_num_weights.keys()
+ if self.image_num_range[0] <= n <= self.image_num_range[1]])
+
+ # Normalize weights for sampling
+ weights = [self.image_num_weights[n] for n in self.possible_nums]
+ self.normalized_weights = np.array(weights) / sum(weights)
+
+ # Maximum image number per GPU
+ self.max_img_per_gpu = max_img_per_gpu
+
+ # Set the epoch for the sampler
+ self.set_epoch(epoch + seed)
+
+ def set_epoch(self, epoch):
+ """
+ Sets the epoch for this sampler, affecting the random sequence.
+
+ Args:
+ epoch: The epoch number.
+ """
+ self.sampler.set_epoch(epoch)
+ self.epoch = epoch
+ self.rng.seed(epoch * 100)
+
+ def __iter__(self):
+ """
+ Yields batches of samples with synchronized dynamic parameters.
+
+ Returns:
+ Iterator yielding batches of indices with associated parameters.
+ """
+ sampler_iterator = iter(self.sampler)
+
+ while True:
+ try:
+ # Sample random image number and aspect ratio
+ random_image_num = int(np.random.choice(self.possible_nums, p=self.normalized_weights))
+ random_ps_h = np.random.randint(low=(self.h_range[0] // 14), high=(self.h_range[1] // 14)+1)
+
+ # Update sampler parameters
+ self.sampler.update_parameters(
+ image_num=random_image_num,
+ ps_h=random_ps_h
+ )
+
+ # Calculate batch size based on max images per GPU and current image number
+ batch_size = self.max_img_per_gpu / random_image_num
+ batch_size = np.floor(batch_size).astype(int)
+ batch_size = max(1, batch_size) # Ensure batch size is at least 1
+
+ # Collect samples for the current batch
+ current_batch = []
+ for _ in range(batch_size):
+ try:
+ item = next(sampler_iterator) # item is (idx, aspect_ratio, image_num)
+ current_batch.append(item)
+ except StopIteration:
+ break # No more samples
+
+ if not current_batch:
+ break # No more data to yield
+
+ yield current_batch
+
+ except StopIteration:
+ break # End of sampler's iterator
+
+ def __len__(self):
+ # Return a large dummy length
+ return 1000000
+
+
+class DynamicDistributedSampler(DistributedSampler):
+ """
+ Extends PyTorch's DistributedSampler to include dynamic aspect_ratio and image_num
+ parameters, which can be passed into the dataset's __getitem__ method.
+ """
+ def __init__(
+ self,
+ dataset,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = False,
+ seed: int = 0,
+ drop_last: bool = False,
+ ):
+ super().__init__(
+ dataset,
+ num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last
+ )
+ self.image_num = None
+ self.ps_h = None
+
+ def __iter__(self):
+ """
+ Yields a sequence of (index, image_num, aspect_ratio).
+ Relies on the parent class's logic for shuffling/distributing
+ the indices across replicas, then attaches extra parameters.
+ """
+ indices_iter = super().__iter__()
+
+ for idx in indices_iter:
+ yield (idx, self.image_num, self.ps_h, )
+
+ def update_parameters(self, image_num, ps_h):
+ """
+ Updates dynamic parameters for each new epoch or iteration.
+
+ Args:
+ aspect_ratio: The aspect ratio to set.
+ image_num: The number of images to set.
+ """
+ self.image_num = image_num
+ self.ps_h = ps_h
+
+class MixedBatchSampler(BatchSampler):
+ """Sample one batch from a selected dataset with given probability.
+ Compatible with datasets at different resolution
+ """
+
+ def __init__(
+ self, src_dataset_ls, batch_size, num_context_views, world_size=1, rank=0, prob=None, sampler=None, generator=None
+ ):
+ self.base_sampler = None
+ self.batch_size = batch_size
+ self.num_context_views = num_context_views
+ self.world_size = world_size
+ self.rank = rank
+ self.drop_last = True
+ self.generator = generator
+
+ self.src_dataset_ls = src_dataset_ls
+ self.n_dataset = len(self.src_dataset_ls)
+
+ # Dataset length
+ self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
+ self.cum_dataset_length = [
+ sum(self.dataset_length[:i]) for i in range(self.n_dataset)
+ ] # cumulative dataset length
+
+ # BatchSamplers for each source dataset
+ self.src_batch_samplers = []
+ for ds in self.src_dataset_ls:
+ sampler = DynamicDistributedSampler(ds, num_replicas=self.world_size, rank=self.rank, seed=42, shuffle=True)
+ sampler.set_epoch(0)
+
+ if hasattr(ds, "epoch"):
+ ds.epoch = 0
+ if hasattr(ds, "set_epoch"):
+ ds.set_epoch(0)
+ batch_sampler = DynamicBatchSampler(
+ sampler,
+ [2, ds.cfg.view_sampler.num_context_views],
+ ds.cfg.input_image_shape,
+ seed=42,
+ max_img_per_gpu=ds.cfg.view_sampler.max_img_per_gpu
+ )
+ self.src_batch_samplers.append(batch_sampler)
+
+ # self.src_batch_samplers = [
+ # BatchedRandomSampler(
+ # ds,
+ # num_context_views=ds.cfg.view_sampler.num_context_views,
+ # world_size=self.world_size,
+ # rank=self.rank,
+ # batch_size=self.batch_size,
+ # drop_last=self.drop_last,
+ # )
+ # for ds in self.src_dataset_ls
+ # ]
+ # set epoch here
+ print("Setting epoch for all underlying BatchedRandomSamplers")
+ # for sampler in self.src_batch_samplers:
+ # sampler.set_epoch(0)
+ self.raw_batches = [
+ list(bs) for bs in self.src_batch_samplers
+ ] # index in original dataset
+ self.n_batches = [len(b) for b in self.raw_batches]
+ self.n_total_batch = sum(self.n_batches)
+ # print("Total batch num is ", self.n_total_batch)
+ # sampling probability
+ if prob is None:
+ # if not given, decide by dataset length
+ self.prob = torch.tensor(self.n_batches) / self.n_total_batch
+ else:
+ self.prob = torch.as_tensor(prob)
+
+ def __iter__(self):
+ """Yields batches of indices in the format of (sample_idx, feat_idx) tuples,
+ where indices correspond to ConcatDataset of src_dataset_ls
+ """
+ for _ in range(self.n_total_batch):
+ idx_ds = torch.multinomial(
+ self.prob, 1, replacement=True, generator=self.generator
+ ).item()
+
+ if 0 == len(self.raw_batches[idx_ds]):
+ self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
+
+ # get a batch from list - this is already in (sample_idx, feat_idx) format
+ batch_raw = self.raw_batches[idx_ds].pop()
+
+ # shift only the sample_idx by cumulative dataset length, keep feat_idx unchanged
+ shift = self.cum_dataset_length[idx_ds]
+ processed_batch = []
+
+ for item in batch_raw:
+ # item[0] is the sample index, item[1] is the number of images
+ processed_item = (item[0] + shift, item[1], item[2])
+ processed_batch.append(processed_item)
+ yield processed_batch
+
+ def set_epoch(self, epoch):
+ """Set epoch for all underlying BatchedRandomSamplers"""
+ for sampler in self.src_batch_samplers:
+ sampler.set_epoch(epoch)
+ # Reset raw_batches after setting new epoch
+ self.raw_batches = [list(bs) for bs in self.src_batch_samplers]
+
+ def __len__(self):
+ return self.n_total_batch
+
+def round_by(total, multiple, up=False):
+ if up:
+ total = total + multiple - 1
+ return (total // multiple) * multiple
\ No newline at end of file
diff --git a/src/dataset/dataset.py b/src/dataset/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..18a29582cd56d920447fdf613b3d0491f15cd0ee
--- /dev/null
+++ b/src/dataset/dataset.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+from .view_sampler import ViewSamplerCfg
+
+
+@dataclass
+class DatasetCfgCommon:
+ original_image_shape: list[int]
+ input_image_shape: list[int]
+ background_color: list[float]
+ cameras_are_circular: bool
+ overfit_to_scene: str | None
+ view_sampler: ViewSamplerCfg
diff --git a/src/dataset/dataset_co3d.py b/src/dataset/dataset_co3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf96da135b9279e1e2b08ad031624cd9a06cf610
--- /dev/null
+++ b/src/dataset/dataset_co3d.py
@@ -0,0 +1,380 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import json
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+import random
+from typing import Literal
+import os
+import numpy as np
+import torch
+import torchvision.transforms as tf
+from einops import rearrange, repeat
+from jaxtyping import Float, UInt8
+from PIL import Image
+from torch import Tensor
+from torch.utils.data import Dataset
+import os.path as osp
+import cv2
+from ..geometry.projection import get_fov
+from .dataset import DatasetCfgCommon
+from .shims.augmentation_shim import apply_augmentation_shim
+from .shims.crop_shim import apply_crop_shim
+from .types import Stage
+from .view_sampler import ViewSampler
+from ..misc.cam_utils import camera_normalization
+
+from .shims.geometry_shim import depthmap_to_absolute_camera_coordinates
+
+CATEGORY = {'train':
+ ["backpack", "ball", "banana", "baseballbat", "baseballglove",
+ "bench", "bicycle", "book", "bottle", "bowl", "broccoli", "cake", "car", "carrot",
+ "cellphone", "chair", "couch", "cup", "donut", "frisbee", "hairdryer", "handbag",
+ "hotdog", "hydrant", "keyboard", "kite", "laptop", "microwave",
+ "motorcycle",
+ "mouse", "orange", "parkingmeter", "pizza", "plant", "remote", "sandwich",
+ "skateboard", "stopsign",
+ "suitcase", "teddybear", "toaster", "toilet", "toybus",
+ "toyplane", "toytrain", "toytruck", "tv",
+ "umbrella", "vase", "wineglass",],
+ 'test': ['teddybear']}
+
+@dataclass
+class DatasetCo3dCfg(DatasetCfgCommon):
+ name: str
+ roots: list[Path]
+ baseline_min: float
+ baseline_max: float
+ max_fov: float
+ make_baseline_1: bool
+ augment: bool
+ relative_pose: bool
+ skip_bad_shape: bool
+ normalize_by_pts3d: bool
+ intr_augment: bool
+ rescale_to_1cube: bool
+ mask_bg: Literal['rand', True, False] = True
+
+@dataclass
+class DatasetCo3dCfgWrapper:
+ co3d: DatasetCo3dCfg
+
+
+class DatasetCo3d(Dataset):
+ cfg: DatasetCo3dCfg
+ stage: Stage
+ view_sampler: ViewSampler
+
+ to_tensor: tf.ToTensor
+ chunks: list[Path]
+ near: float = 0.1
+ far: float = 100.0
+
+ def __init__(
+ self,
+ cfg: DatasetCo3dCfg,
+ stage: Stage,
+ view_sampler: ViewSampler,
+ ) -> None:
+ super().__init__()
+ self.cfg = cfg
+ self.stage = stage
+ self.view_sampler = view_sampler
+ self.to_tensor = tf.ToTensor()
+
+ self.root = cfg.roots[0]
+ self.mask_bg = cfg.mask_bg
+ assert self.mask_bg in ('rand', True, False)
+
+ # load all scenes
+ self.categories = CATEGORY[self.data_stage]
+ self.scene_seq_dict = {}
+ self.scene_ids = []
+ for category in self.categories:
+ with open(osp.join(self.root, f"{category}/valid_seq.json"), "r") as f:
+ scene_seq_dict = json.load(f)
+ for scene, seqs in scene_seq_dict.items():
+ self.scene_seq_dict[f"{category}/{scene}"] = seqs
+ self.scene_ids.append(f"{category}/{scene}")
+
+ print(f"CO3Dv2 {self.stage}: loaded {len(self.scene_seq_dict)} scenes")
+
+ def load_frames(self, scene_id, frame_ids):
+ with ThreadPoolExecutor(max_workers=32) as executor:
+ # Create a list to store futures with their original indices
+ futures_with_idx = []
+ for idx, frame_id in enumerate(frame_ids):
+ file_path = os.path.join(self.root, f"{scene_id}/images/frame{frame_id:06d}.jpg")
+ futures_with_idx.append(
+ (
+ idx,
+ executor.submit(
+ lambda p: self.to_tensor(Image.open(p).convert("RGB")),
+ file_path,
+ ),
+ )
+ )
+
+ # Pre-allocate list with correct size to maintain order
+ torch_images = [None] * len(frame_ids)
+ for idx, future in futures_with_idx:
+ torch_images[idx] = future.result()
+ # Check if all images have the same size
+ sizes = set(img.shape for img in torch_images)
+ if len(sizes) == 1:
+ torch_images = torch.stack(torch_images)
+ # Return as list if images have different sizes
+ return torch_images
+
+ def load_npz(self, scene_id, frame_id):
+ npzpath = os.path.join(self.root, f"{scene_id}/images/frame{frame_id:06d}.npz")
+ imgpath = os.path.join(self.root, f"{scene_id}/images/frame{frame_id:06d}.jpg")
+ img = Image.open(imgpath)
+ # breakpoint()
+ W, H = img.size
+ npzdata = np.load(npzpath)
+ intri = npzdata['camera_intrinsics']
+ extri = npzdata['camera_pose']
+ intri[0, 0] /= float(W)
+ intri[1, 1] /= float(H)
+ intri[0, 2] /= float(W)
+ intri[1, 2] /= float(H)
+ md = npzdata['maximum_depth']
+ return intri, extri, md
+
+ def load_depth(self, scene_id, frame_ids, mds):
+ torch_depths = []
+ for frame_id in frame_ids:
+ depthpath = os.path.join(self.root, f"{scene_id}/depths/frame{frame_id:06d}.jpg.geometric.png")
+ depth = cv2.imread(depthpath, cv2.IMREAD_UNCHANGED)/65535*np.nan_to_num(mds[frame_id])
+ depth = np.nan_to_num(depth)
+ torch_depths.append(torch.from_numpy(depth))
+ return torch_depths
+
+ def load_masks(self, scene_id, frame_ids):
+ masks = []
+ for frame_id in frame_ids:
+ maskpath = os.path.join(self.root, f"{scene_id}/masks/frame{frame_id:06d}.png")
+ maskmap = cv2.imread(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
+ maskmap = (maskmap / 255.0) > 0.1
+ masks.append(torch.from_numpy(maskmap))
+ return masks
+
+ def getitem(self, index: int, num_context_views: int, patchsize: tuple) -> dict:
+ scene_id = self.scene_ids[index]
+ seq = self.scene_seq_dict[scene_id]
+
+ extrinsics = []
+ intrinsics = []
+ frame_ids = []
+ mds = {}
+ for frame_id in seq:
+ intri, extri, md = self.load_npz(scene_id, frame_id)
+ extrinsics.append(extri)
+ intrinsics.append(intri)
+ frame_ids.append(frame_id)
+ mds[frame_id] = md
+
+ extrinsics = np.array(extrinsics)
+ intrinsics = np.array(intrinsics)
+ extrinsics = torch.tensor(extrinsics, dtype=torch.float32)
+ intrinsics = torch.tensor(intrinsics, dtype=torch.float32)
+
+ num_views = extrinsics.shape[0]
+ context_indices = torch.tensor(random.sample(range(num_views), num_context_views))
+ remaining_indices = torch.tensor([i for i in range(num_views) if i not in context_indices])
+ target_indices = torch.tensor(random.sample(remaining_indices.tolist(), self.view_sampler.num_target_views))
+
+ # Skip the example if the field of view is too wide.
+ if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any():
+ raise Exception("Field of view too wide")
+
+ input_frames = [frame_ids[i] for i in context_indices]
+ target_frame = [frame_ids[i] for i in target_indices]
+
+ context_images = self.load_frames(scene_id, input_frames)
+ target_images = self.load_frames(scene_id, target_frame)
+ context_depths = self.load_depth(scene_id, input_frames, mds)
+ target_depths = self.load_depth(scene_id, target_frame, mds)
+
+ mask_bg = (self.mask_bg == True) or (self.mask_bg == "rand" and np.random.random() < 0.5)
+ if mask_bg:
+ context_masks = self.load_masks(scene_id, input_frames)
+ target_mask = self.load_masks(scene_id, target_frame)
+
+ # update the depthmap with mask
+ context_depths = [depth * mask for depth, mask in zip(context_depths, context_masks)]
+ target_depths = [depth * mask for depth, mask in zip(target_depths, target_mask)]
+
+
+ # Resize the world to make the baseline 1.
+ context_extrinsics = extrinsics[context_indices]
+ if self.cfg.make_baseline_1:
+ a, b = context_extrinsics[0, :3, 3], context_extrinsics[-1, :3, 3]
+ scale = (a - b).norm()
+ if scale < self.cfg.baseline_min or scale > self.cfg.baseline_max:
+ print(
+ f"Skipped {scene_id} because of baseline out of range: "
+ f"{scale:.6f}"
+ )
+ raise Exception("baseline out of range")
+ extrinsics[:, :3, 3] /= scale
+ else:
+ scale = 1
+
+ if self.cfg.relative_pose:
+ extrinsics = camera_normalization(extrinsics[context_indices][0:1], extrinsics)
+
+ # self.cfg.rescale_to_1cube = True
+ if self.cfg.rescale_to_1cube:
+ scene_scale = torch.max(torch.abs(extrinsics[context_indices][:, :3, 3])) # target pose is not included
+ # all_extrinsics = torch.cat([extrinsics[context_indices], extrinsics[target_indices]], dim=0) # [N, 4, 4]
+ # scene_scale = torch.max(torch.abs(all_extrinsics[:, :3, 3]))
+ rescale_factor = 1 * scene_scale
+ extrinsics[:, :3, 3] /= rescale_factor
+
+ example = {
+ "context": {
+ "extrinsics": extrinsics[context_indices],
+ "intrinsics": intrinsics[context_indices],
+ "image": context_images,
+ "depth": context_depths,
+ "near": self.get_bound("near", len(context_indices)),
+ "far": self.get_bound("far", len(context_indices)),
+ "index": context_indices,
+ # "overlap": overlap,
+ },
+ "target": {
+ "extrinsics": extrinsics[target_indices],
+ "intrinsics": intrinsics[target_indices],
+ "image": target_images,
+ "depth": target_depths,
+ "near": self.get_bound("near", len(target_indices)),
+ "far": self.get_bound("far", len(target_indices)),
+ "index": target_indices,
+ },
+ "scene": f"CO3Dv2 {scene_id}",
+ }
+
+ if self.stage == "train" and self.cfg.intr_augment:
+ intr_aug = True
+ else:
+ intr_aug = False
+
+ example = apply_crop_shim(example, (patchsize[0] * 14, patchsize[1] * 14), intr_aug=intr_aug)
+
+ if self.stage == "train" and self.cfg.augment:
+ example = apply_augmentation_shim(example)
+
+ # example_1 = copy.deepcopy(example)
+ # world pts
+ image_size = example["context"]["image"].shape[2:]
+ context_intrinsics = example["context"]["intrinsics"].clone().detach().numpy()
+ context_intrinsics[:, 0] = context_intrinsics[:, 0] * image_size[1]
+ context_intrinsics[:, 1] = context_intrinsics[:, 1] * image_size[0]
+
+ target_intrinsics = example["target"]["intrinsics"].clone().detach().numpy()
+ target_intrinsics[:, 0] = target_intrinsics[:, 0] * image_size[1]
+ target_intrinsics[:, 1] = target_intrinsics[:, 1] * image_size[0]
+
+ context_pts3d_list, context_valid_mask_list = [], []
+ target_pts3d_list, target_valid_mask_list = [], []
+
+ for i in range(len(example["context"]["depth"])):
+ context_pts3d, context_valid_mask = depthmap_to_absolute_camera_coordinates(example["context"]["depth"][i].numpy(), context_intrinsics[i], example["context"]["extrinsics"][i].numpy())
+ context_pts3d_list.append(torch.from_numpy(context_pts3d).to(torch.float32))
+ context_valid_mask_list.append(torch.from_numpy(context_valid_mask))
+
+ context_pts3d = torch.stack(context_pts3d_list, dim=0)
+ context_valid_mask = torch.stack(context_valid_mask_list, dim=0)
+
+ for i in range(len(example["target"]["depth"])):
+ target_pts3d, target_valid_mask = depthmap_to_absolute_camera_coordinates(example["target"]["depth"][i].numpy(), target_intrinsics[i], example["target"]["extrinsics"][i].numpy())
+ target_pts3d_list.append(torch.from_numpy(target_pts3d).to(torch.float32))
+ target_valid_mask_list.append(torch.from_numpy(target_valid_mask))
+
+ target_pts3d = torch.stack(target_pts3d_list, dim=0)
+ target_valid_mask = torch.stack(target_valid_mask_list, dim=0)
+
+ # normalize by context pts3d
+ if self.cfg.normalize_by_pts3d:
+ transformed_pts3d = context_pts3d[context_valid_mask]
+ scene_factor = transformed_pts3d.norm(dim=-1).mean().clip(min=1e-8)
+
+ context_pts3d /= scene_factor
+ example["context"]["depth"] /= scene_factor
+ example["context"]["extrinsics"][:, :3, 3] /= scene_factor
+
+ target_pts3d /= scene_factor
+ example["target"]["depth"] /= scene_factor
+ example["target"]["extrinsics"][:, :3, 3] /= scene_factor
+
+ example["context"]["pts3d"] = context_pts3d
+ example["target"]["pts3d"] = target_pts3d
+ example["context"]["valid_mask"] = context_valid_mask
+ example["target"]["valid_mask"] = target_valid_mask
+
+ if torch.isnan(example["context"]["depth"]).any() or torch.isinf(example["context"]["depth"]).any() or \
+ torch.isnan(example["context"]["extrinsics"]).any() or torch.isinf(example["context"]["extrinsics"]).any() or \
+ torch.isnan(example["context"]["pts3d"]).any() or torch.isinf(example["context"]["pts3d"]).any() or \
+ torch.isnan(example["context"]["intrinsics"]).any() or torch.isinf(example["context"]["intrinsics"]).any() or \
+ torch.isnan(example["target"]["depth"]).any() or torch.isinf(example["target"]["depth"]).any() or \
+ torch.isnan(example["target"]["extrinsics"]).any() or torch.isinf(example["target"]["extrinsics"]).any() or \
+ torch.isnan(example["target"]["pts3d"]).any() or torch.isinf(example["target"]["pts3d"]).any() or \
+ torch.isnan(example["target"]["intrinsics"]).any() or torch.isinf(example["target"]["intrinsics"]).any():
+ raise Exception("encounter nan or inf in context depth")
+
+ for key in ["context", "target"]:
+ example[key]["valid_mask"] = (torch.ones_like(example[key]["valid_mask"]) * -1).type(torch.int32)
+
+ return example
+
+
+ def __getitem__(self, index_tuple: tuple) -> dict:
+ index, num_context_views, patchsize_h = index_tuple
+ patchsize_w = (self.cfg.input_image_shape[1] // 14)
+ try:
+ return self.getitem(index, num_context_views, (patchsize_h, patchsize_w))
+ except Exception as e:
+ print(f"Error: {e}")
+ index = np.random.randint(len(self))
+ return self.__getitem__((index, num_context_views, patchsize_h))
+
+ def get_bound(
+ self,
+ bound: Literal["near", "far"],
+ num_views: int,
+ ) -> Float[Tensor, " view"]:
+ value = torch.tensor(getattr(self, bound), dtype=torch.float32)
+ return repeat(value, "-> v", v=num_views)
+
+ @property
+ def data_stage(self) -> Stage:
+ if self.cfg.overfit_to_scene is not None:
+ return "test"
+ if self.stage == "val":
+ return "test"
+ return self.stage
+
+ @cached_property
+ def index(self) -> dict[str, Path]:
+ merged_index = {}
+ data_stages = [self.data_stage]
+ if self.cfg.overfit_to_scene is not None:
+ data_stages = ("test", "train")
+ for data_stage in data_stages:
+ for root in self.cfg.roots:
+ # Load the root's index.
+ with (root / data_stage / "index.json").open("r") as f:
+ index = json.load(f)
+ index = {k: Path(root / data_stage / v) for k, v in index.items()}
+
+ # The constituent datasets should have unique keys.
+ assert not (set(merged_index.keys()) & set(index.keys()))
+
+ # Merge the root's index into the main index.
+ merged_index = {**merged_index, **index}
+ return merged_index
+
+ def __len__(self) -> int:
+ return len(self.scene_ids)
\ No newline at end of file
diff --git a/src/dataset/dataset_dl3dv.py b/src/dataset/dataset_dl3dv.py
new file mode 100644
index 0000000000000000000000000000000000000000..8da706db5c0bcde903a11442d144c5318717dd75
--- /dev/null
+++ b/src/dataset/dataset_dl3dv.py
@@ -0,0 +1,421 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import json
+from dataclasses import dataclass
+from functools import cached_property
+from io import BytesIO
+from pathlib import Path
+from typing import Literal
+import os
+import numpy as np
+import torch
+import torchvision.transforms as tf
+from einops import rearrange, repeat
+from jaxtyping import Float, UInt8
+from PIL import Image
+from torch import Tensor
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+
+from ..geometry.projection import get_fov
+from .dataset import DatasetCfgCommon
+from .shims.augmentation_shim import apply_augmentation_shim
+from .shims.crop_shim import apply_crop_shim
+from .types import Stage
+from .view_sampler import ViewSampler
+from ..misc.cam_utils import camera_normalization
+
+
+@dataclass
+class DatasetDl3dvCfg(DatasetCfgCommon):
+ name: str
+ roots: list[Path]
+ baseline_min: float
+ baseline_max: float
+ max_fov: float
+ make_baseline_1: bool
+ augment: bool
+ relative_pose: bool
+ skip_bad_shape: bool
+ avg_pose: bool
+ rescale_to_1cube: bool
+ intr_augment: bool
+ normalize_by_pts3d: bool
+ rescale_to_1cube: bool
+
+
+@dataclass
+class DatasetDL3DVCfgWrapper:
+ dl3dv: DatasetDl3dvCfg
+
+
+
+class DatasetDL3DV(Dataset):
+ cfg: DatasetDl3dvCfg
+ stage: Stage
+ view_sampler: ViewSampler
+
+ to_tensor: tf.ToTensor
+ chunks: list[Path]
+ near: float = 0.1
+ far: float = 100.0
+
+ def __init__(
+ self,
+ cfg: DatasetDl3dvCfg,
+ stage: Stage,
+ view_sampler: ViewSampler,
+ ) -> None:
+ super().__init__()
+ self.cfg = cfg
+ self.stage = stage
+ self.view_sampler = view_sampler
+ self.to_tensor = tf.ToTensor()
+
+ # load data
+ self.data_root = cfg.roots[0]
+ self.data_list = []
+ with open(f"{self.data_root}/{self.data_stage}_index.json", "r") as file:
+ data_index = json.load(file)
+
+ self.data_list = [
+ os.path.join(self.data_root, item) for item in data_index
+ ] # train: 9900 test: 140
+
+ self.scene_ids = {}
+ self.scenes = {}
+ index = 0
+ with ThreadPoolExecutor(max_workers=32) as executor:
+ futures = [executor.submit(self.load_jsons, scene_path) for scene_path in self.data_list]
+ for future in as_completed(futures):
+ scene_frames, scene_id = future.result()
+ self.scenes[scene_id] = scene_frames
+ self.scene_ids[index] = scene_id
+ index += 1
+ print(f"DL3DV: {self.stage}: loaded {len(self.scene_ids)} scenes")
+
+ def convert_intrinsics(self, meta_data):
+ store_h, store_w = meta_data["h"], meta_data["w"]
+ fx, fy, cx, cy = (
+ meta_data["fl_x"],
+ meta_data["fl_y"],
+ meta_data["cx"],
+ meta_data["cy"],
+ )
+ intrinsics = np.eye(3, dtype=np.float32)
+ intrinsics[0, 0] = float(fx) / float(store_w)
+ intrinsics[1, 1] = float(fy) / float(store_h)
+ intrinsics[0, 2] = float(cx) / float(store_w)
+ intrinsics[1, 2] = float(cy) / float(store_h)
+ return intrinsics
+
+ def blender2opencv_c2w(self, pose):
+ blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+ opencv_c2w = np.array(pose) @ blender2opencv
+ return opencv_c2w.tolist()
+
+ def load_jsons(self, scene_path):
+ json_path = os.path.join(scene_path, "transforms.json")
+ with open(json_path, "r") as f:
+ data = json.load(f)
+
+ scene_frames = []
+ scene_id = scene_path.split("/")[-1].split(".")[0]
+ for i, frame in enumerate(data["frames"]):
+ frame_tmp = {}
+ frame_tmp["file_path"] = os.path.join(scene_path, frame["file_path"])
+ frame_tmp["intrinsics"] = self.convert_intrinsics(data).tolist()
+ frame_tmp["extrinsics"] = self.blender2opencv_c2w(frame["transform_matrix"])
+ scene_frames.append(frame_tmp)
+ return scene_frames, scene_id
+
+ def load_frames(self, frames):
+ with ThreadPoolExecutor(max_workers=32) as executor:
+ # Create a list to store futures with their original indices
+ futures_with_idx = []
+ for idx, file_path in enumerate(frames):
+ file_path = file_path["file_path"].replace("images", "images_4")
+ futures_with_idx.append(
+ (
+ idx,
+ executor.submit(
+ lambda p: self.to_tensor(Image.open(p).convert("RGB")),
+ file_path,
+ ),
+ )
+ )
+
+ # Pre-allocate list with correct size to maintain order
+ torch_images = [None] * len(frames)
+ for idx, future in futures_with_idx:
+ torch_images[idx] = future.result()
+ # Check if all images have the same size
+ sizes = set(img.shape for img in torch_images)
+ if len(sizes) == 1:
+ torch_images = torch.stack(torch_images)
+ # Return as list if images have different sizes
+ return torch_images
+
+ def load_depth(self, frames):
+ depth_list = []
+ for frame_name in frames:
+ depth_path = frame_name.replace("images", "depth").replace("jpg", "npy")
+ depth = torch.from_numpy(np.load(depth_path))
+ positive_depths = depth[depth > 0]
+ if len(positive_depths) > 1000000: # If more than 1M points, sample randomly
+ indices = torch.randperm(len(positive_depths))[:1000000]
+ positive_depths = positive_depths[indices]
+ percentile_95 = torch.quantile(positive_depths, 0.95)
+ # Set depth values greater than the 95th percentile to 0
+ depth[depth > percentile_95] = 0
+ depth_list.append(depth)
+
+ return torch.stack(depth_list)
+
+ def shuffle(self, lst: list) -> list:
+ indices = torch.randperm(len(lst))
+ return [lst[x] for x in indices]
+
+ def getitem(self, index: int, num_context_views: int, patchsize: tuple) -> dict:
+
+ scene = self.scene_ids[index]
+
+ example = self.scenes[scene]
+ # load poses
+ extrinsics = []
+ intrinsics = []
+ for frame in example:
+ extrinsic = frame["extrinsics"]
+ intrinsic = frame["intrinsics"]
+ extrinsics.append(extrinsic)
+ intrinsics.append(intrinsic)
+
+ extrinsics = np.array(extrinsics)
+ intrinsics = np.array(intrinsics)
+ extrinsics = torch.tensor(extrinsics, dtype=torch.float32)
+ intrinsics = torch.tensor(intrinsics, dtype=torch.float32)
+
+ try:
+ context_indices, target_indices, overlap = self.view_sampler.sample(
+ scene,
+ num_context_views,
+ extrinsics,
+ intrinsics,
+ )
+ except ValueError:
+ # Skip because the example doesn't have enough frames.
+ raise Exception("Not enough frames")
+
+ # Skip the example if the field of view is too wide.
+ if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any():
+ raise Exception("Field of view too wide")
+
+ # Load the images.
+ input_frames = [example[i] for i in context_indices]
+ target_frame = [example[i] for i in target_indices]
+
+ context_images = self.load_frames(input_frames)
+ target_images = self.load_frames(target_frame)
+
+ # context_depth = self.load_depth(input_frames)
+ # target_depth = self.load_depth(target_frame)
+ context_depth = torch.ones_like(context_images)[:, 0]
+ target_depth = torch.ones_like(target_images)[:, 0]
+
+ # Skip the example if the images don't have the right shape.
+ context_image_invalid = context_images.shape[1:] != (3, *self.cfg.original_image_shape)
+ target_image_invalid = target_images.shape[1:] != (3, *self.cfg.original_image_shape)
+ if self.cfg.skip_bad_shape and (context_image_invalid or target_image_invalid):
+ print(
+ f"Skipped bad example {example['key']}. Context shape was "
+ f"{context_images.shape} and target shape was "
+ f"{target_images.shape}."
+ )
+ raise Exception("Bad example image shape")
+
+ # Resize the world to make the baseline 1.
+ context_extrinsics = extrinsics[context_indices]
+ if self.cfg.make_baseline_1:
+ a, b = context_extrinsics[0, :3, 3], context_extrinsics[-1, :3, 3]
+ scale = (a - b).norm()
+ if scale < self.cfg.baseline_min or scale > self.cfg.baseline_max:
+ print(
+ f"Skipped {scene} because of baseline out of range: "
+ f"{scale:.6f}"
+ )
+ raise Exception("baseline out of range")
+ extrinsics[:, :3, 3] /= scale
+ else:
+ scale = 1
+
+ if self.cfg.relative_pose:
+ extrinsics = camera_normalization(extrinsics[context_indices][0:1], extrinsics)
+
+ if self.cfg.rescale_to_1cube:
+ scene_scale = torch.max(torch.abs(extrinsics[context_indices][:, :3, 3])) # target pose is not included
+ rescale_factor = 1 * scene_scale
+ extrinsics[:, :3, 3] /= rescale_factor
+
+ if torch.isnan(extrinsics).any() or torch.isinf(extrinsics).any():
+ raise Exception("encounter nan or inf in input poses")
+
+ example = {
+ "context": {
+ "extrinsics": extrinsics[context_indices],
+ "intrinsics": intrinsics[context_indices],
+ "image": context_images,
+ "depth": context_depth,
+ "near": self.get_bound("near", len(context_indices)) / scale,
+ "far": self.get_bound("far", len(context_indices)) / scale,
+ "index": context_indices,
+ # "overlap": overlap,
+ },
+ "target": {
+ "extrinsics": extrinsics[target_indices],
+ "intrinsics": intrinsics[target_indices],
+ "image": target_images,
+ "depth": target_depth,
+ "near": self.get_bound("near", len(target_indices)) / scale,
+ "far": self.get_bound("far", len(target_indices)) / scale,
+ "index": target_indices,
+ },
+ "scene": "dl3dv_"+scene,
+ }
+ if self.stage == "train" and self.cfg.augment:
+ example = apply_augmentation_shim(example)
+
+ if self.stage == "train" and self.cfg.intr_augment:
+ intr_aug = True
+ else:
+ intr_aug = False
+
+ example = apply_crop_shim(example, (patchsize[0] * 14, patchsize[1] * 14), intr_aug=intr_aug)
+
+ image_size = example["context"]["image"].shape[2:]
+ context_intrinsics = example["context"]["intrinsics"].clone().detach().numpy()
+ context_intrinsics[:, 0] = context_intrinsics[:, 0] * image_size[1]
+ context_intrinsics[:, 1] = context_intrinsics[:, 1] * image_size[0]
+
+ target_intrinsics = example["target"]["intrinsics"].clone().detach().numpy()
+ target_intrinsics[:, 0] = target_intrinsics[:, 0] * image_size[1]
+ target_intrinsics[:, 1] = target_intrinsics[:, 1] * image_size[0]
+
+ context_pts3d_list, context_valid_mask_list = [], []
+ target_pts3d_list, target_valid_mask_list = [], []
+
+ # for i in range(len(example["context"]["depth"])):
+ # context_pts3d, context_valid_mask = depthmap_to_absolute_camera_coordinates(example["context"]["depth"][i].numpy(), context_intrinsics[i], example["context"]["extrinsics"][i].numpy())
+ # context_pts3d_list.append(torch.from_numpy(context_pts3d).to(torch.float32))
+ # context_valid_mask_list.append(torch.from_numpy(context_valid_mask))
+
+ # context_pts3d = torch.stack(context_pts3d_list, dim=0)
+ # context_valid_mask = torch.stack(context_valid_mask_list, dim=0)
+
+ context_pts3d = torch.ones_like(example["context"]["image"]).permute(0, 2, 3, 1) # [N, H, W, 3]
+ context_valid_mask = torch.ones_like(example["context"]["image"])[:, 0].bool() # [N, H, W]
+
+ target_pts3d = torch.ones_like(target_images).permute(0, 2, 3, 1) # [N, H, W, 3]
+ target_valid_mask = torch.ones_like(target_images)[:, 0].bool() # [N, H, W]
+
+ # normalize by context pts3d
+ if self.cfg.normalize_by_pts3d:
+ transformed_pts3d = context_pts3d[context_valid_mask]
+ scene_factor = transformed_pts3d.norm(dim=-1).mean().clip(min=1e-8)
+
+ context_pts3d /= scene_factor
+ example["context"]["depth"] /= scene_factor
+ example["context"]["extrinsics"][:, :3, 3] /= scene_factor
+
+ target_pts3d /= scene_factor
+ example["target"]["depth"] /= scene_factor
+ example["target"]["extrinsics"][:, :3, 3] /= scene_factor
+
+ example["context"]["pts3d"] = context_pts3d
+ example["target"]["pts3d"] = target_pts3d
+ example["context"]["valid_mask"] = context_valid_mask * -1
+ example["target"]["valid_mask"] = target_valid_mask * -1
+
+ return example
+
+ def __getitem__(self, index_tuple: tuple) -> dict:
+ index, num_context_views, patchsize_h = index_tuple
+ patchsize_w = (self.cfg.input_image_shape[1] // 14)
+ try:
+ return self.getitem(index, num_context_views, (patchsize_h, patchsize_w))
+ except Exception as e:
+ print(f"Error: {e}")
+ index = np.random.randint(len(self))
+ return self.__getitem__((index, num_context_views, patchsize_h))
+
+ def convert_poses(
+ self,
+ poses: Float[Tensor, "batch 18"],
+ ) -> tuple[
+ Float[Tensor, "batch 4 4"], # extrinsics
+ Float[Tensor, "batch 3 3"], # intrinsics
+ ]:
+ b, _ = poses.shape
+
+ # Convert the intrinsics to a 3x3 normalized K matrix.
+ intrinsics = torch.eye(3, dtype=torch.float32)
+ intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone()
+ fx, fy, cx, cy = poses[:, :4].T
+ intrinsics[:, 0, 0] = fx
+ intrinsics[:, 1, 1] = fy
+ intrinsics[:, 0, 2] = cx
+ intrinsics[:, 1, 2] = cy
+
+ # Convert the extrinsics to a 4x4 OpenCV-style W2C matrix.
+ w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone()
+ w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4)
+ return w2c.inverse(), intrinsics
+
+ def convert_images(
+ self,
+ images: list[UInt8[Tensor, "..."]],
+ ) -> Float[Tensor, "batch 3 height width"]:
+ torch_images = []
+ for image in images:
+ image = Image.open(BytesIO(image.numpy().tobytes()))
+ torch_images.append(self.to_tensor(image))
+ return torch.stack(torch_images)
+
+ def get_bound(
+ self,
+ bound: Literal["near", "far"],
+ num_views: int,
+ ) -> Float[Tensor, " view"]:
+ value = torch.tensor(getattr(self, bound), dtype=torch.float32)
+ return repeat(value, "-> v", v=num_views)
+
+ @property
+ def data_stage(self) -> Stage:
+ if self.cfg.overfit_to_scene is not None:
+ return "test"
+ if self.stage == "val":
+ return "test"
+ return self.stage
+
+ @cached_property
+ def index(self) -> dict[str, Path]:
+ merged_index = {}
+ data_stages = [self.data_stage]
+ if self.cfg.overfit_to_scene is not None:
+ data_stages = ("test", "train")
+ for data_stage in data_stages:
+ for root in self.cfg.roots:
+ # Load the root's index.
+ with (root / data_stage / "index.json").open("r") as f:
+ index = json.load(f)
+ index = {k: Path(root / data_stage / v) for k, v in index.items()}
+
+ # The constituent datasets should have unique keys.
+ assert not (set(merged_index.keys()) & set(index.keys()))
+
+ # Merge the root's index into the main index.
+ merged_index = {**merged_index, **index}
+ return merged_index
+
+ def __len__(self) -> int:
+ return len(self.scene_ids)
diff --git a/src/dataset/dataset_scannetpp.py b/src/dataset/dataset_scannetpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..050a0943a682ca92481b1d194785fcf5e96b6429
--- /dev/null
+++ b/src/dataset/dataset_scannetpp.py
@@ -0,0 +1,443 @@
+import json
+from dataclasses import dataclass
+from functools import cached_property
+from io import BytesIO
+from pathlib import Path
+import random
+from typing import Literal
+import os
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as tf
+from einops import rearrange, repeat
+from jaxtyping import Float, UInt8
+from PIL import Image
+import torchvision
+from torch import Tensor
+from torch.utils.data import Dataset
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import copy
+from .shims.geometry_shim import depthmap_to_absolute_camera_coordinates
+
+from .shims.load_shim import imread_cv2
+
+from ..geometry.projection import get_fov
+from .dataset import DatasetCfgCommon
+from .shims.augmentation_shim import apply_augmentation_shim
+from .shims.crop_shim import apply_crop_shim
+from .types import Stage
+from .view_sampler import ViewSampler
+from ..misc.cam_utils import camera_normalization
+
+@dataclass
+class DatasetScannetppCfg(DatasetCfgCommon):
+ name: str
+ roots: list[Path]
+ baseline_min: float
+ baseline_max: float
+ max_fov: float
+ make_baseline_1: bool
+ augment: bool
+ relative_pose: bool
+ skip_bad_shape: bool
+ metric_thre: float
+ intr_augment: bool
+ make_baseline_1: bool
+ rescale_to_1cube: bool
+ normalize_by_pts3d: bool
+
+@dataclass
+class DatasetScannetppCfgWrapper:
+ scannetpp: DatasetScannetppCfg
+
+
+class DatasetScannetpp(Dataset):
+ cfg: DatasetScannetppCfgWrapper
+ stage: Stage
+ view_sampler: ViewSampler
+
+ to_tensor: tf.ToTensor
+ chunks: list[Path]
+ near: float = 0.1
+ far: float = 100.0
+
+ def __init__(
+ self,
+ cfg: DatasetScannetppCfgWrapper,
+ stage: Stage,
+ view_sampler: ViewSampler,
+ ) -> None:
+ super().__init__()
+ self.cfg = cfg
+ self.stage = stage
+ self.view_sampler = view_sampler
+ self.to_tensor = tf.ToTensor()
+
+ # load data
+ self.data_root = cfg.roots[0]
+ self.data_list = [] # we use dslr rather than iphone
+ data_index = os.listdir(f"{self.data_root}") # we train all the scenes
+
+ if self.stage != "train":
+ with open(f"{self.data_root}/valid.json", "r") as file:
+ data_index = json.load(file)[:10]
+ data_index = data_index * 100
+ random.shuffle(data_index)
+ else:
+ with open(f"{self.data_root}/valid.json", "r") as file:
+ data_index = json.load(file)[10:]
+
+ self.data_list = [
+ os.path.join(self.data_root, item) for item in data_index
+ ]
+
+ self.scene_ids = {}
+ self.scenes = {}
+ index = 0
+ with ThreadPoolExecutor(max_workers=32) as executor:
+ futures = [executor.submit(self.load_metadata, scene_path) for scene_path in self.data_list]
+ for future in as_completed(futures):
+ scene_frames, scene_id = future.result()
+ self.scenes[scene_id] = scene_frames
+ self.scene_ids[index] = scene_id
+ index += 1
+
+ # if self.stage != "train":
+ # self.scene_ids = self.scene_ids
+ # random.shuffle(self.scene_ids)
+ print(f"Scannetpp: {self.stage}: loaded {len(self.scene_ids)} scenes")
+
+ def shuffle(self, lst: list) -> list:
+ indices = torch.randperm(len(lst))
+ return [lst[x] for x in indices]
+
+ def load_metadata(self, scene_path):
+ metadata_path = os.path.join(scene_path, "scene_metadata.npz")
+ metadata = np.load(metadata_path, allow_pickle=True)
+ intrinsics = metadata["intrinsics"]
+ trajectories = metadata["trajectories"]
+ images = metadata["images"]
+
+ scene_id = scene_path.split("/")[-1].split(".")[0]
+ scene_frames = [
+ {
+ "file_path": os.path.join(scene_path, "images", images[i].split(".")[0] + ".jpg"),
+ "depth_path": os.path.join(scene_path, "depth", images[i].split(".")[0] + ".png"),
+ "intrinsics": self.convert_intrinsics(intrinsics[i]),
+ "extrinsics": trajectories[i],
+ }
+ for i in range(len(images))
+ ]
+ scene_frames.sort(key=lambda x: x["file_path"]) # sort by file path to ensure correct order
+ return scene_frames, scene_id
+
+ def convert_intrinsics(self, intrinsics):
+ w = intrinsics[0, 2] * 2
+ h = intrinsics[1, 2] * 2
+ intrinsics[0, 0] = intrinsics[0, 0] / w
+ intrinsics[1, 1] = intrinsics[1, 1] / h
+ intrinsics[0, 2] = intrinsics[0, 2] / w
+ intrinsics[1, 2] = intrinsics[1, 2] / h
+ return intrinsics
+
+ def blender2opencv_c2w(self, pose):
+ blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+ opencv_c2w = np.array(pose) @ blender2opencv
+ return opencv_c2w.tolist()
+
+ def load_frames(self, frames):
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ # Create a list to store futures with their original indices
+ futures_with_idx = []
+ for idx, file_path in enumerate(frames):
+ file_path = file_path["file_path"]
+ futures_with_idx.append(
+ (
+ idx,
+ executor.submit(
+ lambda p: self.to_tensor(Image.open(p).convert("RGB")),
+ file_path,
+ ),
+ )
+ )
+
+ # Pre-allocate list with correct size to maintain order
+ torch_images = [None] * len(frames)
+ for idx, future in futures_with_idx:
+ torch_images[idx] = future.result()
+ # Check if all images have the same size
+ sizes = set(img.shape for img in torch_images)
+ if len(sizes) == 1:
+ torch_images = torch.stack(torch_images)
+ # Return as list if images have different sizes
+ return torch_images
+
+ def load_depths(self, frames):
+ torch_depths = []
+ for idx, frame in enumerate(frames):
+ depthmap = imread_cv2(frame["depth_path"], cv2.IMREAD_UNCHANGED)
+ depthmap = depthmap.astype(np.float32) / 1000
+ depthmap[~np.isfinite(depthmap)] = 0
+ torch_depths.append(torch.from_numpy(depthmap))
+ return torch.stack(torch_depths) # [N, H, W]
+
+
+ def getitem(self, index: int, num_context_views: int, patchsize: tuple) -> dict:
+ # import time
+ # start_time = time.time()
+
+ scene = self.scene_ids[index]
+ example = self.scenes[scene]
+ # load poses
+ extrinsics = []
+ intrinsics = []
+ for frame in example:
+ extrinsic = frame["extrinsics"]
+ intrinsic = frame["intrinsics"]
+ extrinsics.append(extrinsic)
+ intrinsics.append(intrinsic)
+
+ extrinsics = np.array(extrinsics)
+ intrinsics = np.array(intrinsics)
+ extrinsics = torch.tensor(extrinsics, dtype=torch.float32)
+ intrinsics = torch.tensor(intrinsics, dtype=torch.float32)
+
+ try:
+ context_indices, target_indices, overlap = self.view_sampler.sample(
+ "scannetpp_"+scene,
+ num_context_views,
+ extrinsics,
+ intrinsics,
+ )
+ except ValueError:
+ # Skip because the example doesn't have enough frames.
+ raise Exception("Not enough frames")
+
+ # Skip the example if the field of view is too wide.
+ if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any():
+ raise Exception("Field of view too wide")
+
+ # Load the images.
+ input_frames = [example[i] for i in context_indices]
+ target_frame = [example[i] for i in target_indices]
+
+ context_images = self.load_frames(input_frames)
+ target_images = self.load_frames(target_frame)
+
+ context_depths = self.load_depths(input_frames)
+ target_depths = self.load_depths(target_frame)
+
+ # Skip the example if the images don't have the right shape.
+ context_image_invalid = context_images.shape[1:] != (3, *self.cfg.original_image_shape)
+ target_image_invalid = target_images.shape[1:] != (3, *self.cfg.original_image_shape)
+ if self.cfg.skip_bad_shape and (context_image_invalid or target_image_invalid):
+ print(
+ f"Skipped bad example {example['key']}. Context shape was "
+ f"{context_images.shape} and target shape was "
+ f"{target_images.shape}."
+ )
+ raise Exception("Bad example image shape")
+
+ context_extrinsics = extrinsics[context_indices]
+
+ if self.cfg.make_baseline_1:
+ a, b = context_extrinsics[0, :3, 3], context_extrinsics[-1, :3, 3]
+ scale = (a - b).norm()
+ if scale < self.cfg.baseline_min or scale > self.cfg.baseline_max:
+ print(
+ f"Skipped {scene} because of baseline out of range: "
+ f"{scale:.6f}"
+ )
+ raise Exception("baseline out of range")
+ extrinsics[:, :3, 3] /= scale
+ else:
+ scale = 1
+
+ if self.cfg.relative_pose:
+ extrinsics = camera_normalization(extrinsics[context_indices][0:1], extrinsics)
+
+ if self.cfg.rescale_to_1cube:
+ scene_scale = torch.max(torch.abs(extrinsics[context_indices][:, :3, 3])) # target pose is not included
+ rescale_factor = 1 * scene_scale
+ extrinsics[:, :3, 3] /= rescale_factor
+
+ if torch.isnan(extrinsics).any() or torch.isinf(extrinsics).any():
+ raise Exception("encounter nan or inf in input poses")
+
+ example = {
+ "context": {
+ "extrinsics": extrinsics[context_indices],
+ "intrinsics": intrinsics[context_indices],
+ "image": context_images,
+ "depth": context_depths,
+ "near": self.get_bound("near", len(context_indices)) / scale,
+ "far": self.get_bound("far", len(context_indices)) / scale,
+ "index": context_indices,
+ "overlap": overlap,
+ },
+ "target": {
+ "extrinsics": extrinsics[target_indices],
+ "intrinsics": intrinsics[target_indices],
+ "image": target_images,
+ "depth": target_depths,
+ "near": self.get_bound("near", len(target_indices)) / scale,
+ "far": self.get_bound("far", len(target_indices)) / scale,
+ "index": target_indices,
+ },
+ "scene": f"Scannetpp {scene}",
+ }
+ if self.stage == "train" and self.cfg.augment:
+ example = apply_augmentation_shim(example)
+
+ if self.stage == "train" and self.cfg.intr_augment:
+ intr_aug = True
+ else:
+ intr_aug = False
+
+ example = apply_crop_shim(example, (patchsize[0] * 14, patchsize[1] * 14), intr_aug=intr_aug)
+
+ # world pts
+ image_size = example["context"]["image"].shape[2:]
+ context_intrinsics = example["context"]["intrinsics"].clone().detach().numpy()
+ context_intrinsics[:, 0] = context_intrinsics[:, 0] * image_size[1]
+ context_intrinsics[:, 1] = context_intrinsics[:, 1] * image_size[0]
+
+ target_intrinsics = example["target"]["intrinsics"].clone().detach().numpy()
+ target_intrinsics[:, 0] = target_intrinsics[:, 0] * image_size[1]
+ target_intrinsics[:, 1] = target_intrinsics[:, 1] * image_size[0]
+
+ context_pts3d_list, context_valid_mask_list = [], []
+ target_pts3d_list, target_valid_mask_list = [], []
+
+ for i in range(len(example["context"]["depth"])):
+ context_pts3d, context_valid_mask = depthmap_to_absolute_camera_coordinates(example["context"]["depth"][i].numpy(), context_intrinsics[i], example["context"]["extrinsics"][i].numpy())
+ context_pts3d_list.append(torch.from_numpy(context_pts3d).to(torch.float32))
+ context_valid_mask_list.append(torch.from_numpy(context_valid_mask))
+
+ context_pts3d = torch.stack(context_pts3d_list, dim=0)
+ context_valid_mask = torch.stack(context_valid_mask_list, dim=0)
+
+ for i in range(len(example["target"]["depth"])):
+ target_pts3d, target_valid_mask = depthmap_to_absolute_camera_coordinates(example["target"]["depth"][i].numpy(), target_intrinsics[i], example["target"]["extrinsics"][i].numpy())
+ target_pts3d_list.append(torch.from_numpy(target_pts3d).to(torch.float32))
+ target_valid_mask_list.append(torch.from_numpy(target_valid_mask))
+
+ target_pts3d = torch.stack(target_pts3d_list, dim=0)
+ target_valid_mask = torch.stack(target_valid_mask_list, dim=0)
+
+ # normalize by context pts3d
+ if self.cfg.normalize_by_pts3d:
+ transformed_pts3d = context_pts3d[context_valid_mask]
+ scene_factor = transformed_pts3d.norm(dim=-1).mean().clip(min=1e-8)
+ context_pts3d /= scene_factor
+ example["context"]["depth"] /= scene_factor
+ example["context"]["extrinsics"][:, :3, 3] /= scene_factor
+
+ target_pts3d /= scene_factor
+ example["target"]["depth"] /= scene_factor
+ example["target"]["extrinsics"][:, :3, 3] /= scene_factor
+
+ example["context"]["pts3d"] = context_pts3d
+ example["target"]["pts3d"] = target_pts3d
+ example["context"]["valid_mask"] = context_valid_mask
+ example["target"]["valid_mask"] = target_valid_mask
+
+ if torch.isnan(example["context"]["depth"]).any() or torch.isinf(example["context"]["depth"]).any() or \
+ torch.isnan(example["context"]["extrinsics"]).any() or torch.isinf(example["context"]["extrinsics"]).any() or \
+ torch.isnan(example["context"]["intrinsics"]).any() or torch.isinf(example["context"]["intrinsics"]).any() or \
+ torch.isnan(example["target"]["depth"]).any() or torch.isinf(example["target"]["depth"]).any() or \
+ torch.isnan(example["target"]["extrinsics"]).any() or torch.isinf(example["target"]["extrinsics"]).any() or \
+ torch.isnan(example["target"]["intrinsics"]).any() or torch.isinf(example["target"]["intrinsics"]).any():
+ raise Exception("encounter nan or inf in context depth")
+
+ for key in ["context", "target"]:
+ example[key]["valid_mask"] = (torch.ones_like(example[key]["valid_mask"]) * -1).type(torch.int32)
+
+ return example
+
+
+ def __getitem__(self, index_tuple: tuple) -> dict:
+ index, num_context_views, patchsize_h = index_tuple
+ # generate a random patch size
+ patchsize_w = (self.cfg.input_image_shape[1] // 14)
+ try:
+ return self.getitem(index, num_context_views, (patchsize_h, patchsize_w))
+ except Exception as e:
+ print(f"Error: {e}")
+ index = np.random.randint(len(self))
+ return self.__getitem__((index, num_context_views, patchsize_h))
+
+ def convert_poses(
+ self,
+ poses: Float[Tensor, "batch 18"],
+ ) -> tuple[
+ Float[Tensor, "batch 4 4"], # extrinsics
+ Float[Tensor, "batch 3 3"], # intrinsics
+ ]:
+ b, _ = poses.shape
+
+ # Convert the intrinsics to a 3x3 normalized K matrix.
+ intrinsics = torch.eye(3, dtype=torch.float32)
+ intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone()
+ fx, fy, cx, cy = poses[:, :4].T
+ intrinsics[:, 0, 0] = fx
+ intrinsics[:, 1, 1] = fy
+ intrinsics[:, 0, 2] = cx
+ intrinsics[:, 1, 2] = cy
+
+ # Convert the extrinsics to a 4x4 OpenCV-style W2C matrix.
+ w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone()
+ w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4)
+ return w2c.inverse(), intrinsics
+
+ def convert_images(
+ self,
+ images: list[UInt8[Tensor, "..."]],
+ ) -> Float[Tensor, "batch 3 height width"]:
+ torch_images = []
+ for image in images:
+ image = Image.open(BytesIO(image.numpy().tobytes()))
+ torch_images.append(self.to_tensor(image))
+ return torch.stack(torch_images)
+
+ def get_bound(
+ self,
+ bound: Literal["near", "far"],
+ num_views: int,
+ ) -> Float[Tensor, " view"]:
+ value = torch.tensor(getattr(self, bound), dtype=torch.float32)
+ return repeat(value, "-> v", v=num_views)
+
+ @property
+ def data_stage(self) -> Stage:
+ if self.cfg.overfit_to_scene is not None:
+ return "test"
+ if self.stage == "val":
+ return "test"
+ return self.stage
+
+ @cached_property
+ def index(self) -> dict[str, Path]:
+ merged_index = {}
+ data_stages = [self.data_stage]
+ if self.cfg.overfit_to_scene is not None:
+ data_stages = ("test", "train")
+ for data_stage in data_stages:
+ for root in self.cfg.roots:
+ # Load the root's index.
+ with (root / data_stage / "index.json").open("r") as f:
+ index = json.load(f)
+ index = {k: Path(root / data_stage / v) for k, v in index.items()}
+
+ # The constituent datasets should have unique keys.
+ assert not (set(merged_index.keys()) & set(index.keys()))
+
+ # Merge the root's index into the main index.
+ merged_index = {**merged_index, **index}
+ return merged_index
+
+ def __len__(self) -> int:
+ return len(self.data_list)
\ No newline at end of file
diff --git a/src/dataset/shims/augmentation_shim.py b/src/dataset/shims/augmentation_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e6e7dd0d4ffe77476a34768c5b0f9d91867a57
--- /dev/null
+++ b/src/dataset/shims/augmentation_shim.py
@@ -0,0 +1,219 @@
+import copy
+import random
+import numpy as np
+import torch
+from jaxtyping import Float
+from torch import Tensor
+
+from ..types import AnyExample, AnyViews
+
+
+def reflect_extrinsics(
+ extrinsics: Float[Tensor, "*batch 4 4"],
+) -> Float[Tensor, "*batch 4 4"]:
+ reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device)
+ reflect[0, 0] = -1
+ return reflect @ extrinsics @ reflect
+
+
+def reflect_views(views: AnyViews) -> AnyViews:
+ if "depth" in views.keys():
+ return {
+ **views,
+ "image": views["image"].flip(-1),
+ "extrinsics": reflect_extrinsics(views["extrinsics"]),
+ "depth": views["depth"].flip(-1),
+ }
+ else:
+ return {
+ **views,
+ "image": views["image"].flip(-1),
+ "extrinsics": reflect_extrinsics(views["extrinsics"]),
+ }
+
+
+def apply_augmentation_shim(
+ example: AnyExample,
+ generator: torch.Generator | None = None,
+) -> AnyExample:
+ """Randomly augment the training images."""
+ # Do not augment with 50% chance.
+ if torch.rand(tuple(), generator=generator) < 0.5:
+ return example
+
+ return {
+ **example,
+ "context": reflect_views(example["context"]),
+ "target": reflect_views(example["target"]),
+ }
+
+def rotate_90_degrees(
+ image: torch.Tensor, depth_map: torch.Tensor | None, extri_opencv: torch.Tensor, intri_opencv: torch.Tensor, clockwise=True
+):
+ """
+ Rotates the input image, depth map, and camera parameters by 90 degrees.
+
+ Applies one of two 90-degree rotations:
+ - Clockwise
+ - Counterclockwise (if clockwise=False)
+
+ The extrinsic and intrinsic matrices are adjusted accordingly to maintain
+ correct camera geometry.
+
+ Args:
+ image (torch.Tensor):
+ Input image tensor of shape (C, H, W).
+ depth_map (torch.Tensor or None):
+ Depth map tensor of shape (H, W), or None if not available.
+ extri_opencv (torch.Tensor):
+ Extrinsic matrix (3x4) in OpenCV convention.
+ intri_opencv (torch.Tensor):
+ Intrinsic matrix (3x3).
+ clockwise (bool):
+ If True, rotates the image 90 degrees clockwise; else 90 degrees counterclockwise.
+
+ Returns:
+ tuple:
+ (
+ rotated_image,
+ rotated_depth_map,
+ new_extri_opencv,
+ new_intri_opencv
+ )
+
+ Where each is the updated version after the rotation.
+ """
+ image_height, image_width = image.shape[-2:]
+
+ # Rotate the image and depth map
+ rotated_image, rotated_depth_map = rotate_image_and_depth_rot90(image, depth_map, clockwise)
+ # Adjust the intrinsic matrix
+ new_intri_opencv = adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise)
+ # Adjust the extrinsic matrix
+ new_extri_opencv = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise)
+
+ return (
+ rotated_image,
+ rotated_depth_map,
+ new_extri_opencv,
+ new_intri_opencv,
+ )
+
+
+def rotate_image_and_depth_rot90(image: torch.Tensor, depth_map: torch.Tensor | None, clockwise: bool):
+ """
+ Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise).
+
+ Args:
+ image (torch.Tensor):
+ Input image tensor of shape (C, H, W).
+ depth_map (torch.Tensor or None):
+ Depth map tensor of shape (H, W), or None if not available.
+ clockwise (bool):
+ If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.
+
+ Returns:
+ tuple:
+ (rotated_image, rotated_depth_map)
+ """
+ rotated_depth_map = None
+ if clockwise:
+ rotated_image = torch.rot90(image, k=-1, dims=[-2, -1])
+ if depth_map is not None:
+ rotated_depth_map = torch.rot90(depth_map, k=-1, dims=[-2, -1])
+ else:
+ rotated_image = torch.rot90(image, k=1, dims=[-2, -1])
+ if depth_map is not None:
+ rotated_depth_map = torch.rot90(depth_map, k=1, dims=[-2, -1])
+ return rotated_image, rotated_depth_map
+
+
+def adjust_extrinsic_matrix_rot90(extri_opencv: torch.Tensor, clockwise: bool):
+ """
+ Adjusts the extrinsic matrix (3x4) for a 90-degree rotation of the image.
+
+ The rotation is in the image plane. This modifies the camera orientation
+ accordingly. The function applies either a clockwise or counterclockwise
+ 90-degree rotation.
+
+ Args:
+ extri_opencv (torch.Tensor):
+ Extrinsic matrix (3x4) in OpenCV convention.
+ clockwise (bool):
+ If True, rotate extrinsic for a 90-degree clockwise image rotation;
+ otherwise, counterclockwise.
+
+ Returns:
+ torch.Tensor:
+ A new 3x4 extrinsic matrix after the rotation.
+ """
+ R = extri_opencv[:3, :3]
+ t = extri_opencv[:3, 3]
+
+ if clockwise:
+ R_rotation = torch.tensor([
+ [0, -1, 0],
+ [1, 0, 0],
+ [0, 0, 1]
+ ], dtype=extri_opencv.dtype, device=extri_opencv.device)
+ else:
+ R_rotation = torch.tensor([
+ [0, 1, 0],
+ [-1, 0, 0],
+ [0, 0, 1]
+ ], dtype=extri_opencv.dtype, device=extri_opencv.device)
+
+ new_R = torch.matmul(R_rotation, R)
+ new_t = torch.matmul(R_rotation, t)
+ new_extri_opencv = torch.cat((new_R, new_t.reshape(-1, 1)), dim=1)
+ new_extri_opencv = torch.cat((new_extri_opencv,
+ torch.tensor([[0, 0, 0, 1]],
+ dtype=extri_opencv.dtype, device=extri_opencv.device)), dim=0)
+ return new_extri_opencv
+
+
+def adjust_intrinsic_matrix_rot90(intri_opencv: torch.Tensor, image_width: int, image_height: int, clockwise: bool):
+ """
+ Adjusts the intrinsic matrix (3x3) for a 90-degree rotation of the image in the image plane.
+
+ Args:
+ intri_opencv (torch.Tensor):
+ Intrinsic matrix (3x3).
+ image_width (int):
+ Original width of the image.
+ image_height (int):
+ Original height of the image.
+ clockwise (bool):
+ If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.
+
+ Returns:
+ torch.Tensor:
+ A new 3x3 intrinsic matrix after the rotation.
+ """
+ intri_opencv = copy.deepcopy(intri_opencv)
+ intri_opencv[0, :] *= image_width
+ intri_opencv[1, :] *= image_height
+
+ fx, fy, cx, cy = (
+ intri_opencv[0, 0],
+ intri_opencv[1, 1],
+ intri_opencv[0, 2],
+ intri_opencv[1, 2],
+ )
+
+ new_intri_opencv = torch.eye(3, dtype=intri_opencv.dtype, device=intri_opencv.device)
+ if clockwise:
+ new_intri_opencv[0, 0] = fy
+ new_intri_opencv[1, 1] = fx
+ new_intri_opencv[0, 2] = image_height - cy
+ new_intri_opencv[1, 2] = cx
+ else:
+ new_intri_opencv[0, 0] = fy
+ new_intri_opencv[1, 1] = fx
+ new_intri_opencv[0, 2] = cy
+ new_intri_opencv[1, 2] = image_width - cx
+
+ new_intri_opencv[0, :] /= image_height
+ new_intri_opencv[1, :] /= image_width
+
+ return new_intri_opencv
diff --git a/src/dataset/shims/bounds_shim.py b/src/dataset/shims/bounds_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..f699867554de265bace335d22a90dfb259560650
--- /dev/null
+++ b/src/dataset/shims/bounds_shim.py
@@ -0,0 +1,80 @@
+import torch
+from einops import einsum, reduce, repeat
+from jaxtyping import Float
+from torch import Tensor
+
+from ..types import BatchedExample
+
+
+def compute_depth_for_disparity(
+ extrinsics: Float[Tensor, "batch view 4 4"],
+ intrinsics: Float[Tensor, "batch view 3 3"],
+ image_shape: tuple[int, int],
+ disparity: float,
+ delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth.
+) -> Float[Tensor, " batch"]:
+ """Compute the depth at which moving the maximum distance between cameras
+ corresponds to the specified disparity (in pixels).
+ """
+
+ # Use the furthest distance between cameras as the baseline.
+ origins = extrinsics[:, :, :3, 3]
+ deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1)
+ deltas = deltas.clip(min=delta_min)
+ baselines = reduce(deltas, "b v ov -> b", "max")
+
+ # Compute a single pixel's size at depth 1.
+ h, w = image_shape
+ pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device)
+ pixel_size = einsum(
+ intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i"
+ )
+
+ # This wouldn't make sense with non-square pixels, but then again, non-square pixels
+ # don't make much sense anyway.
+ mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean")
+
+ return baselines / (disparity * mean_pixel_size)
+
+
+def apply_bounds_shim(
+ batch: BatchedExample,
+ near_disparity: float,
+ far_disparity: float,
+) -> BatchedExample:
+ """Compute reasonable near and far planes (lower and upper bounds on depth). This
+ assumes that all of an example's views are of roughly the same thing.
+ """
+
+ context = batch["context"]
+ _, cv, _, h, w = context["image"].shape
+
+ # Compute near and far planes using the context views.
+ near = compute_depth_for_disparity(
+ context["extrinsics"],
+ context["intrinsics"],
+ (h, w),
+ near_disparity,
+ )
+ far = compute_depth_for_disparity(
+ context["extrinsics"],
+ context["intrinsics"],
+ (h, w),
+ far_disparity,
+ )
+
+ target = batch["target"]
+ _, tv, _, _, _ = target["image"].shape
+ return {
+ **batch,
+ "context": {
+ **context,
+ "near": repeat(near, "b -> b v", v=cv),
+ "far": repeat(far, "b -> b v", v=cv),
+ },
+ "target": {
+ **target,
+ "near": repeat(near, "b -> b v", v=tv),
+ "far": repeat(far, "b -> b v", v=tv),
+ },
+ }
diff --git a/src/dataset/shims/crop_shim.py b/src/dataset/shims/crop_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..25b50f4b7359cb96aca86615dfcb6091f76df3a8
--- /dev/null
+++ b/src/dataset/shims/crop_shim.py
@@ -0,0 +1,196 @@
+import random
+import numpy as np
+import torch
+from einops import rearrange
+from jaxtyping import Float
+from PIL import Image
+from torch import Tensor
+import torchvision.transforms.functional as F
+import cv2
+
+from ..types import AnyExample, AnyViews
+
+
+def rescale(
+ image: Float[Tensor, "3 h_in w_in"],
+ shape: tuple[int, int],
+) -> Float[Tensor, "3 h_out w_out"]:
+ h, w = shape
+ image_new = (image * 255).clip(min=0, max=255).type(torch.uint8)
+ image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy()
+ image_new = Image.fromarray(image_new)
+ image_new = image_new.resize((w, h), Image.LANCZOS)
+ image_new = np.array(image_new) / 255
+ image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device)
+ return rearrange(image_new, "h w c -> c h w")
+
+def rescale_depth(
+ depth: Float[Tensor, "1 h w"],
+ shape: tuple[int, int],
+) -> Float[Tensor, "1 h_out w_out"]:
+ h, w = shape
+ depth_new = depth.detach().cpu().numpy()
+ depth_new = cv2.resize(depth_new, (w,h), interpolation=cv2.INTER_NEAREST)
+ depth_new = torch.from_numpy(depth_new).to(depth.device)
+ return depth_new
+
+def center_crop(
+ images: Float[Tensor, "*#batch c h w"],
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+ shape: tuple[int, int],
+ depths: Float[Tensor, "*#batch 1 h w"] | None = None,
+) -> tuple[
+ Float[Tensor, "*#batch c h_out w_out"], # updated images
+ Float[Tensor, "*#batch 3 3"], # updated intrinsics
+ Float[Tensor, "*#batch 1 h_out w_out"] | None, # updated depths
+]:
+ *_, h_in, w_in = images.shape
+ h_out, w_out = shape
+
+ # Note that odd input dimensions induce half-pixel misalignments.
+ row = (h_in - h_out) // 2
+ col = (w_in - w_out) // 2
+
+ # Center-crop the image.
+ images = images[..., :, row : row + h_out, col : col + w_out]
+
+ if depths is not None:
+ depths = depths[..., row : row + h_out, col : col + w_out]
+
+ # Adjust the intrinsics to account for the cropping.
+ intrinsics = intrinsics.clone()
+ intrinsics[..., 0, 0] *= w_in / w_out # fx
+ intrinsics[..., 1, 1] *= h_in / h_out # fy
+
+
+ if depths is not None:
+ return images, intrinsics, depths
+ else:
+ return images, intrinsics
+
+
+def rescale_and_crop(
+ images: Float[Tensor, "*#batch c h w"],
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+ shape: tuple[int, int],
+ intr_aug: bool = False,
+ scale_range: tuple[float, float] = (0.77, 1.0),
+ depths: Float[Tensor, "*#batch 1 h w"] | None = None,
+) -> tuple[
+ Float[Tensor, "*#batch c h_out w_out"], # updated images
+ Float[Tensor, "*#batch 3 3"], # updated intrinsics
+ Float[Tensor, "*#batch 1 h_out w_out"] | None, # updated depths
+]:
+ if type(images) == list:
+ images_new = []
+ intrinsics_new = []
+ for i in range(len(images)):
+ image = images[i]
+ intrinsic = intrinsics[i]
+
+ *_, h_in, w_in = image.shape
+ h_out, w_out = shape
+
+ scale_factor = max(h_out / h_in, w_out / w_in)
+ h_scaled = round(h_in * scale_factor)
+ w_scaled = round(w_in * scale_factor)
+ image = F.resize(image, (h_scaled, w_scaled))
+ image = F.center_crop(image, (h_out, w_out))
+ images_new.append(image)
+
+ intrinsic_new = intrinsic.clone()
+ intrinsic_new[..., 0, 0] *= w_scaled / w_in # fx
+ intrinsic_new[..., 1, 1] *= h_scaled / h_in # fy
+ intrinsics_new.append(intrinsic_new)
+
+ if depths is not None:
+ depths_new = []
+ for i in range(len(depths)):
+ depth = depths[i]
+ depth = rescale_depth(depth, (h_out, w_out))
+ depth = F.center_crop(depth, (h_out, w_out))
+ depths_new.append(depth)
+ return torch.stack(images_new), torch.stack(intrinsics_new), torch.stack(depths_new)
+ else:
+ return torch.stack(images_new), torch.stack(intrinsics_new)
+
+ else:
+ # we only support intr_aug for clean datasets
+ *_, h_in, w_in = images.shape
+ h_out, w_out = shape
+ # assert h_out <= h_in and w_out <= w_in # to avoid the case that the image is too small, like co3d
+
+ if intr_aug:
+ scale = random.uniform(*scale_range)
+ h_scale = round(h_out * scale)
+ w_scale = round(w_out * scale)
+ else:
+ h_scale = h_out
+ w_scale = w_out
+
+ scale_factor = max(h_scale / h_in, w_scale / w_in)
+ h_scaled = round(h_in * scale_factor)
+ w_scaled = round(w_in * scale_factor)
+ assert h_scaled == h_scale or w_scaled == w_scale
+
+ # Reshape the images to the correct size. Assume we don't have to worry about
+ # changing the intrinsics based on how the images are rounded.
+ *batch, c, h, w = images.shape
+ images = images.reshape(-1, c, h, w)
+ images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images])
+ images = images.reshape(*batch, c, h_scaled, w_scaled)
+
+ if depths is not None:
+ if type(depths) == list:
+ depths_new = []
+ for i in range(len(depths)):
+ depth = depths[i]
+ depth = rescale_depth(depth, (h_scaled, w_scaled))
+ depths_new.append(depth)
+ depths = torch.stack(depths_new)
+ else:
+ depths = depths.reshape(-1, h, w)
+ depths = torch.stack([rescale_depth(depth, (h_scaled, w_scaled)) for depth in depths])
+ depths = depths.reshape(*batch, h_scaled, w_scaled)
+
+ images, intrinsics, depths = center_crop(images, intrinsics, (h_scale, w_scale), depths)
+
+ if intr_aug:
+ images = F.resize(images, size=(h_out, w_out), interpolation=F.InterpolationMode.BILINEAR)
+ depths = F.resize(depths, size=(h_out, w_out), interpolation=F.InterpolationMode.NEAREST)
+
+ return images, intrinsics, depths
+ else:
+ images, intrinsics = center_crop(images, intrinsics, (h_scale, w_scale))
+
+ if intr_aug:
+ images = F.resize(images, size=(h_out, w_out))
+
+ return images, intrinsics
+
+
+def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int], intr_aug: bool = False) -> AnyViews:
+ if "depth" in views.keys():
+ images, intrinsics, depths = rescale_and_crop(views["image"], views["intrinsics"], shape, depths=views["depth"], intr_aug=intr_aug)
+ return {
+ **views,
+ "image": images,
+ "intrinsics": intrinsics,
+ "depth": depths,
+ }
+ else:
+ images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape, intr_aug)
+ return {
+ **views,
+ "image": images,
+ "intrinsics": intrinsics,
+ }
+
+
+def apply_crop_shim(example: AnyExample, shape: tuple[int, int], intr_aug: bool = False) -> AnyExample:
+ """Crop images in the example."""
+ return {
+ **example,
+ "context": apply_crop_shim_to_views(example["context"], shape, intr_aug),
+ "target": apply_crop_shim_to_views(example["target"], shape, intr_aug),
+ }
diff --git a/src/dataset/shims/geometry_shim.py b/src/dataset/shims/geometry_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a1bbceaac16e18914678162d7f55b27c4d7561
--- /dev/null
+++ b/src/dataset/shims/geometry_shim.py
@@ -0,0 +1,383 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# geometry utilitary functions
+# --------------------------------------------------------
+import torch
+import numpy as np
+from scipy.spatial import cKDTree as KDTree
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = float('nan')
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = 0
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+ else:
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr, nnz
+
+
+def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
+ """ Output a (H,W,2) array of int32
+ with output[j,i,0] = i + origin[0]
+ output[j,i,1] = j + origin[1]
+ """
+ if device is None:
+ # numpy
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
+ else:
+ # torch
+ arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
+ meshgrid, stack = torch.meshgrid, torch.stack
+ ones = lambda *a: torch.ones(*a, device=device)
+
+ tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
+ grid = meshgrid(tw, th, indexing='xy')
+ if homogeneous:
+ grid = grid + (ones((H, W)),)
+ if unsqueeze is not None:
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
+ if cat_dim is not None:
+ grid = stack(grid, cat_dim)
+ return grid
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+ """ Apply a geometric transformation to a list of 3-D points.
+
+ H: 3x3 or 4x4 projection matrix (typically a Homography)
+ p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
+
+ ncol: int. number of columns of the result (2 or 3)
+ norm: float. if != 0, the resut is projected on the z=norm plane.
+
+ Returns an array of projected 2d points.
+ """
+ assert Trf.ndim >= 2
+ if isinstance(Trf, np.ndarray):
+ pts = np.asarray(pts)
+ elif isinstance(Trf, torch.Tensor):
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+ # adapt shape if necessary
+ output_reshape = pts.shape[:-1]
+ ncol = ncol or pts.shape[-1]
+
+ # optimized code
+ if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
+ Trf.ndim == 3 and pts.ndim == 4):
+ d = pts.shape[3]
+ if Trf.shape[-1] == d:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+ elif Trf.shape[-1] == d + 1:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
+ else:
+ raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
+ else:
+ if Trf.ndim >= 3:
+ n = Trf.ndim - 2
+ assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+ if pts.ndim > Trf.ndim:
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+ elif pts.ndim == 2:
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+ pts = pts[:, None, :]
+
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+ elif pts.shape[-1] == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf
+ else:
+ pts = Trf @ pts.T
+ if pts.ndim >= 2:
+ pts = pts.swapaxes(-1, -2)
+
+ if norm:
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
+ if norm != 1:
+ pts *= norm
+
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
+ return res
+
+
+def inv(mat):
+ """ Invert a torch or numpy matrix
+ """
+ if isinstance(mat, torch.Tensor):
+ return torch.linalg.inv(mat)
+ if isinstance(mat, np.ndarray):
+ return np.linalg.inv(mat)
+ raise ValueError(f'bad matrix type = {type(mat)}')
+
+
+def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
+ """
+ Args:
+ - depthmap (BxHxW array):
+ - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
+ Returns:
+ pointmap of absolute coordinates (BxHxWx3 array)
+ """
+
+ if len(depth.shape) == 4:
+ B, H, W, n = depth.shape
+ else:
+ B, H, W = depth.shape
+ n = None
+
+ if len(pseudo_focal.shape) == 3: # [B,H,W]
+ pseudo_focalx = pseudo_focaly = pseudo_focal
+ elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
+ pseudo_focalx = pseudo_focal[:, 0]
+ if pseudo_focal.shape[1] == 2:
+ pseudo_focaly = pseudo_focal[:, 1]
+ else:
+ pseudo_focaly = pseudo_focalx
+ else:
+ raise NotImplementedError("Error, unknown input focal shape format.")
+
+ assert pseudo_focalx.shape == depth.shape[:3]
+ assert pseudo_focaly.shape == depth.shape[:3]
+ grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
+
+ # set principal point
+ if pp is None:
+ grid_x = grid_x - (W - 1) / 2
+ grid_y = grid_y - (H - 1) / 2
+ else:
+ grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
+ grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
+
+ if n is None:
+ pts3d = torch.empty((B, H, W, 3), device=depth.device)
+ pts3d[..., 0] = depth * grid_x / pseudo_focalx
+ pts3d[..., 1] = depth * grid_y / pseudo_focaly
+ pts3d[..., 2] = depth
+ else:
+ pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
+ pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
+ pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
+ pts3d[..., 2, :] = depth
+ return pts3d
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: there are no skew terms
+ assert camera_intrinsics[0, 1] == 0.0
+ assert camera_intrinsics[1, 0] == 0.0
+ if pseudo_focal is None:
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ else:
+ assert pseudo_focal.shape == (H, W)
+ fu = fv = pseudo_focal
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ z_cam = depthmap
+ x_cam = (u - cu) * z_cam / fu
+ y_cam = (v - cv) * z_cam / fv
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ # Mask for valid coordinates
+ valid_mask = (depthmap > 0.0)
+ return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose=None, **kw):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+ X_world = X_cam # default
+ if camera_pose is not None:
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute coordinates (invalid depth values)
+ X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+
+ return X_world, valid_mask
+
+
+def colmap_to_opencv_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] -= 0.5
+ K[1, 2] -= 0.5
+ return K
+
+
+def opencv_to_colmap_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] += 0.5
+ K[1, 2] += 0.5
+ return K
+
+
+def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None, ret_factor=False):
+ """ renorm pointmaps pts1, pts2 with norm_mode
+ """
+ assert pts1.ndim >= 3 and pts1.shape[-1] == 3
+ assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
+ norm_mode, dis_mode = norm_mode.split('_')
+
+ if norm_mode == 'avg':
+ # gather all points together (joint normalization)
+ nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
+ nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+ # compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+ if dis_mode == 'dis':
+ pass # do nothing
+ elif dis_mode == 'log1p':
+ all_dis = torch.log1p(all_dis)
+ elif dis_mode == 'warp-log1p':
+ # actually warp input points before normalizing them
+ log_dis = torch.log1p(all_dis)
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
+ H1, W1 = pts1.shape[1:-1]
+ pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1)
+ if pts2 is not None:
+ H2, W2 = pts2.shape[1:-1]
+ pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1)
+ all_dis = log_dis # this is their true distance afterwards
+ else:
+ raise ValueError(f'bad {dis_mode=}')
+
+ norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
+ else:
+ # gather all points together (joint normalization)
+ nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
+ nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+ # compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+
+ if norm_mode == 'avg':
+ norm_factor = all_dis.nanmean(dim=1)
+ elif norm_mode == 'median':
+ norm_factor = all_dis.nanmedian(dim=1).values.detach()
+ elif norm_mode == 'sqrt':
+ norm_factor = all_dis.sqrt().nanmean(dim=1)**2
+ else:
+ raise ValueError(f'bad {norm_mode=}')
+
+ norm_factor = norm_factor.clip(min=1e-8)
+ while norm_factor.ndim < pts1.ndim:
+ norm_factor.unsqueeze_(-1)
+
+ res = pts1 / norm_factor
+ if pts2 is not None:
+ res = (res, pts2 / norm_factor)
+ if ret_factor:
+ res = res + (norm_factor,)
+ return res
+
+
+@torch.no_grad()
+def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
+ # set invalid points to NaN
+ _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
+ _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
+ _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
+
+ # compute median depth overall (ignoring nans)
+ if quantile == 0.5:
+ shift_z = torch.nanmedian(_z, dim=-1).values
+ else:
+ shift_z = torch.nanquantile(_z, quantile, dim=-1)
+ return shift_z # (B,)
+
+
+@torch.no_grad()
+def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
+ # set invalid points to NaN
+ _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
+ _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
+ _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
+
+ # compute median center
+ _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
+ if z_only:
+ _center[..., :2] = 0 # do not center X and Y
+
+ # compute median norm
+ _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
+ scale = torch.nanmedian(_norm, dim=1).values
+ return _center[:, None, :, :], scale[:, None, None, None]
+
+
+def find_reciprocal_matches(P1, P2):
+ """
+ returns 3 values:
+ 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
+ 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
+ 3 - reciprocal_in_P2.sum(): the number of matches
+ """
+ tree1 = KDTree(P1)
+ tree2 = KDTree(P2)
+
+ _, nn1_in_P2 = tree2.query(P1, workers=8)
+ _, nn2_in_P1 = tree1.query(P2, workers=8)
+
+ reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
+ reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
+ assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
+ return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
+
+
+def get_med_dist_between_poses(poses):
+ from scipy.spatial.distance import pdist
+ return np.median(pdist([p[:3, 3].detach().cpu().numpy() for p in poses]))
\ No newline at end of file
diff --git a/src/dataset/shims/load_shim.py b/src/dataset/shims/load_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..d94f272f06adddc5be565e17e0275770bb2023df
--- /dev/null
+++ b/src/dataset/shims/load_shim.py
@@ -0,0 +1,12 @@
+import cv2
+
+def imread_cv2(path, options=cv2.IMREAD_COLOR):
+ """Open an image or a depthmap with opencv-python."""
+ if path.endswith((".exr", "EXR")):
+ options = cv2.IMREAD_ANYDEPTH
+ img = cv2.imread(path, options)
+ if img is None:
+ raise IOError(f"Could not load image={path} with {options=}")
+ if img.ndim == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
\ No newline at end of file
diff --git a/src/dataset/shims/normalize_shim.py b/src/dataset/shims/normalize_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..5650a51ccc0cd90b8db0f2a3c45eeb5c22b5c52c
--- /dev/null
+++ b/src/dataset/shims/normalize_shim.py
@@ -0,0 +1,27 @@
+import torch
+from einops import einsum, reduce, repeat
+from jaxtyping import Float
+from torch import Tensor
+
+from ..types import BatchedExample
+
+
+def inverse_normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
+ mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
+ std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
+ return tensor * std + mean
+
+
+def normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
+ mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
+ std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
+ return (tensor - mean) / std
+
+
+def apply_normalize_shim(
+ batch: BatchedExample,
+ mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
+ std: tuple[float, float, float] = (0.5, 0.5, 0.5),
+) -> BatchedExample:
+ batch["context"]["image"] = normalize_image(batch["context"]["image"], mean, std)
+ return batch
diff --git a/src/dataset/shims/patch_shim.py b/src/dataset/shims/patch_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..60dc944995332726b060e7eac9bf914b63c3c840
--- /dev/null
+++ b/src/dataset/shims/patch_shim.py
@@ -0,0 +1,38 @@
+from ..types import BatchedExample, BatchedViews
+
+
+def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews:
+ _, _, _, h, w = views["image"].shape
+
+ # Image size must be even so that naive center-cropping does not cause misalignment.
+ assert h % 2 == 0 and w % 2 == 0
+
+ h_new = (h // patch_size) * patch_size
+ row = (h - h_new) // 2
+ w_new = (w // patch_size) * patch_size
+ col = (w - w_new) // 2
+
+ # Center-crop the image.
+ image = views["image"][:, :, :, row : row + h_new, col : col + w_new]
+
+ # Adjust the intrinsics to account for the cropping.
+ intrinsics = views["intrinsics"].clone()
+ intrinsics[:, :, 0, 0] *= w / w_new # fx
+ intrinsics[:, :, 1, 1] *= h / h_new # fy
+
+ return {
+ **views,
+ "image": image,
+ "intrinsics": intrinsics,
+ }
+
+
+def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample:
+ """Crop images in the batch so that their dimensions are cleanly divisible by the
+ specified patch size.
+ """
+ return {
+ **batch,
+ "context": apply_patch_shim_to_views(batch["context"], patch_size),
+ "target": apply_patch_shim_to_views(batch["target"], patch_size),
+ }
diff --git a/src/dataset/types.py b/src/dataset/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..76736c58b2daed4ca317a1031b6767cda6eed4f0
--- /dev/null
+++ b/src/dataset/types.py
@@ -0,0 +1,51 @@
+from typing import Callable, Literal, TypedDict
+
+from jaxtyping import Float, Int64
+from torch import Tensor
+
+Stage = Literal["train", "val", "test"]
+
+
+# The following types mainly exist to make type-hinted keys show up in VS Code. Some
+# dimensions are annotated as "_" because either:
+# 1. They're expected to change as part of a function call (e.g., resizing the dataset).
+# 2. They're expected to vary within the same function call (e.g., the number of views,
+# which differs between context and target BatchedViews).
+
+
+class BatchedViews(TypedDict, total=False):
+ extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4
+ intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3
+ image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width
+ near: Float[Tensor, "batch _"] # batch view
+ far: Float[Tensor, "batch _"] # batch view
+ index: Int64[Tensor, "batch _"] # batch view
+ overlap: Float[Tensor, "batch _"] # batch view
+
+
+class BatchedExample(TypedDict, total=False):
+ target: BatchedViews
+ context: BatchedViews
+ scene: list[str]
+
+
+class UnbatchedViews(TypedDict, total=False):
+ extrinsics: Float[Tensor, "_ 4 4"]
+ intrinsics: Float[Tensor, "_ 3 3"]
+ image: Float[Tensor, "_ 3 height width"]
+ near: Float[Tensor, " _"]
+ far: Float[Tensor, " _"]
+ index: Int64[Tensor, " _"]
+
+
+class UnbatchedExample(TypedDict, total=False):
+ target: UnbatchedViews
+ context: UnbatchedViews
+ scene: str
+
+
+# A data shim modifies the example after it's been returned from the data loader.
+DataShim = Callable[[BatchedExample], BatchedExample]
+
+AnyExample = BatchedExample | UnbatchedExample
+AnyViews = BatchedViews | UnbatchedViews
diff --git a/src/dataset/validation_wrapper.py b/src/dataset/validation_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c3fef9ffdfa52c1fa1e75c52aaecd79bf147e43
--- /dev/null
+++ b/src/dataset/validation_wrapper.py
@@ -0,0 +1,34 @@
+from typing import Iterator, Optional
+
+import torch
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ValidationWrapper(Dataset):
+ """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a
+ visualization step.
+ """
+
+ dataset: Dataset
+ dataset_iterator: Optional[Iterator]
+ length: int
+
+ def __init__(self, dataset: Dataset, length: int) -> None:
+ super().__init__()
+ self.dataset = dataset
+ self.length = length
+ self.dataset_iterator = None
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, index: tuple):
+ if isinstance(self.dataset, IterableDataset):
+ if self.dataset_iterator is None:
+ self.dataset_iterator = iter(self.dataset)
+ return next(self.dataset_iterator)
+
+ random_index = torch.randint(0, len(self.dataset), tuple())
+ random_context_num = torch.randint(2, self.dataset.view_sampler.num_context_views + 1, tuple())
+ # breakpoint()
+ return self.dataset[random_index.item(), random_context_num.item()]
diff --git a/src/dataset/view_sampler/__init__.py b/src/dataset/view_sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c09a072e743d1f06e73770295d12d7ada57702b
--- /dev/null
+++ b/src/dataset/view_sampler/__init__.py
@@ -0,0 +1,40 @@
+from typing import Any
+
+from ...misc.step_tracker import StepTracker
+from ..types import Stage
+from .view_sampler import ViewSampler
+from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg
+from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg
+from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg
+from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg
+from .view_sampler_rank import ViewSamplerRank, ViewSamplerRankCfg
+VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = {
+ "all": ViewSamplerAll,
+ "arbitrary": ViewSamplerArbitrary,
+ "bounded": ViewSamplerBounded,
+ "evaluation": ViewSamplerEvaluation,
+ "rank": ViewSamplerRank,
+}
+
+ViewSamplerCfg = (
+ ViewSamplerArbitraryCfg
+ | ViewSamplerBoundedCfg
+ | ViewSamplerEvaluationCfg
+ | ViewSamplerAllCfg
+ | ViewSamplerRankCfg
+)
+
+def get_view_sampler(
+ cfg: ViewSamplerCfg,
+ stage: Stage,
+ overfit: bool,
+ cameras_are_circular: bool,
+ step_tracker: StepTracker | None,
+) -> ViewSampler[Any]:
+ return VIEW_SAMPLERS[cfg.name](
+ cfg,
+ stage,
+ overfit,
+ cameras_are_circular,
+ step_tracker,
+ )
diff --git a/src/dataset/view_sampler/three_view_hack.py b/src/dataset/view_sampler/three_view_hack.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a47cb624a1472651bc6611de2c69b04fdc085ad
--- /dev/null
+++ b/src/dataset/view_sampler/three_view_hack.py
@@ -0,0 +1,10 @@
+import torch
+from jaxtyping import Int
+from torch import Tensor
+
+
+def add_third_context_index(
+ indices: Int[Tensor, "*batch 2"]
+) -> Int[Tensor, "*batch 3"]:
+ left, right = indices.unbind(dim=-1)
+ return torch.stack((left, (left + right) // 2, right), dim=-1)
diff --git a/src/dataset/view_sampler/view_sampler.py b/src/dataset/view_sampler/view_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a6276945363ad7a8691295edcb2d5ba4dd0134
--- /dev/null
+++ b/src/dataset/view_sampler/view_sampler.py
@@ -0,0 +1,61 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+import torch
+from jaxtyping import Float, Int64
+from torch import Tensor
+
+from ...misc.step_tracker import StepTracker
+from ..types import Stage
+
+T = TypeVar("T")
+
+
+class ViewSampler(ABC, Generic[T]):
+ cfg: T
+ stage: Stage
+ is_overfitting: bool
+ cameras_are_circular: bool
+ step_tracker: StepTracker | None
+
+ def __init__(
+ self,
+ cfg: T,
+ stage: Stage,
+ is_overfitting: bool,
+ cameras_are_circular: bool,
+ step_tracker: StepTracker | None,
+ ) -> None:
+ self.cfg = cfg
+ self.stage = stage
+ self.is_overfitting = is_overfitting
+ self.cameras_are_circular = cameras_are_circular
+ self.step_tracker = step_tracker
+
+ @abstractmethod
+ def sample(
+ self,
+ scene: str,
+ extrinsics: Float[Tensor, "view 4 4"],
+ intrinsics: Float[Tensor, "view 3 3"],
+ device: torch.device = torch.device("cpu"),
+ ) -> tuple[
+ Int64[Tensor, " context_view"], # indices for context views
+ Int64[Tensor, " target_view"], # indices for target views
+ Float[Tensor, " overlap"], # overlap
+ ]:
+ pass
+
+ @property
+ @abstractmethod
+ def num_target_views(self) -> int:
+ pass
+
+ @property
+ @abstractmethod
+ def num_context_views(self) -> int:
+ pass
+
+ @property
+ def global_step(self) -> int:
+ return 0 if self.step_tracker is None else self.step_tracker.get_step()
diff --git a/src/dataset/view_sampler/view_sampler_all.py b/src/dataset/view_sampler/view_sampler_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..91ac3dcc5c9f8cfdec7fc495f3162008cf5942f2
--- /dev/null
+++ b/src/dataset/view_sampler/view_sampler_all.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from jaxtyping import Float, Int64
+from torch import Tensor
+
+from .view_sampler import ViewSampler
+
+
+@dataclass
+class ViewSamplerAllCfg:
+ name: Literal["all"]
+
+
+class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]):
+ def sample(
+ self,
+ scene: str,
+ extrinsics: Float[Tensor, "view 4 4"],
+ intrinsics: Float[Tensor, "view 3 3"],
+ device: torch.device = torch.device("cpu"),
+ ) -> tuple[
+ Int64[Tensor, " context_view"], # indices for context views
+ Int64[Tensor, " target_view"], # indices for target views
+ ]:
+ v, _, _ = extrinsics.shape
+ all_frames = torch.arange(v, device=device)
+ return all_frames, all_frames
+
+ @property
+ def num_context_views(self) -> int:
+ return 0
+
+ @property
+ def num_target_views(self) -> int:
+ return 0
diff --git a/src/dataset/view_sampler/view_sampler_arbitrary.py b/src/dataset/view_sampler/view_sampler_arbitrary.py
new file mode 100644
index 0000000000000000000000000000000000000000..44e68d051bf48629b0b51df187844a9f91f7d493
--- /dev/null
+++ b/src/dataset/view_sampler/view_sampler_arbitrary.py
@@ -0,0 +1,77 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from jaxtyping import Float, Int64
+from torch import Tensor
+
+from .three_view_hack import add_third_context_index
+from .view_sampler import ViewSampler
+
+
+@dataclass
+class ViewSamplerArbitraryCfg:
+ name: Literal["arbitrary"]
+ num_context_views: int
+ num_target_views: int
+ context_views: list[int] | None
+ target_views: list[int] | None
+
+
+class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]):
+ def sample(
+ self,
+ scene: str,
+ extrinsics: Float[Tensor, "view 4 4"],
+ intrinsics: Float[Tensor, "view 3 3"],
+ device: torch.device = torch.device("cpu"),
+ ) -> tuple[
+ Int64[Tensor, " context_view"], # indices for context views
+ Int64[Tensor, " target_view"], # indices for target views
+ Float[Tensor, " overlap"], # overlap
+ ]:
+ """Arbitrarily sample context and target views."""
+ num_views, _, _ = extrinsics.shape
+
+ index_context = torch.randint(
+ 0,
+ num_views,
+ size=(self.cfg.num_context_views,),
+ device=device,
+ )
+
+ # Allow the context views to be fixed.
+ if self.cfg.context_views is not None:
+ index_context = torch.tensor(
+ self.cfg.context_views, dtype=torch.int64, device=device
+ )
+
+ if self.cfg.num_context_views == 3 and len(self.cfg.context_views) == 2:
+ index_context = add_third_context_index(index_context)
+ else:
+ assert len(self.cfg.context_views) == self.cfg.num_context_views
+ index_target = torch.randint(
+ 0,
+ num_views,
+ size=(self.cfg.num_target_views,),
+ device=device,
+ )
+
+ # Allow the target views to be fixed.
+ if self.cfg.target_views is not None:
+ assert len(self.cfg.target_views) == self.cfg.num_target_views
+ index_target = torch.tensor(
+ self.cfg.target_views, dtype=torch.int64, device=device
+ )
+
+ overlap = torch.tensor([0.5], dtype=torch.float32, device=device) # dummy
+
+ return index_context, index_target, overlap
+
+ @property
+ def num_context_views(self) -> int:
+ return self.cfg.num_context_views
+
+ @property
+ def num_target_views(self) -> int:
+ return self.cfg.num_target_views
diff --git a/src/dataset/view_sampler/view_sampler_bounded.py b/src/dataset/view_sampler/view_sampler_bounded.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca4b9cef21dfc9da95ba6b1c3e600a3f51392d80
--- /dev/null
+++ b/src/dataset/view_sampler/view_sampler_bounded.py
@@ -0,0 +1,151 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from jaxtyping import Float, Int64
+from torch import Tensor
+
+from .view_sampler import ViewSampler
+
+
+@dataclass
+class ViewSamplerBoundedCfg:
+ name: Literal["bounded"]
+ num_context_views: int
+ num_target_views: int
+ min_distance_between_context_views: int
+ max_distance_between_context_views: int
+ min_distance_to_context_views: int
+ warm_up_steps: int
+ initial_min_distance_between_context_views: int
+ initial_max_distance_between_context_views: int
+ max_img_per_gpu: int
+ min_gap_multiplier: int
+ max_gap_multiplier: int
+
+class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
+ def schedule(self, initial: int, final: int) -> int:
+ fraction = self.global_step / self.cfg.warm_up_steps
+ return min(initial + int((final - initial) * fraction), final)
+
+ def sample(
+ self,
+ scene: str,
+ num_context_views: int,
+ extrinsics: Float[Tensor, "view 4 4"],
+ intrinsics: Float[Tensor, "view 3 3"],
+ device: torch.device = torch.device("cpu"),
+ ) -> tuple[
+ Int64[Tensor, " context_view"], # indices for context views
+ Int64[Tensor, " target_view"], # indices for target views
+ Float[Tensor, " overlap"], # overlap
+ ]:
+ num_views, _, _ = extrinsics.shape
+
+ # Compute the context view spacing based on the current global step.
+ if self.stage == "test":
+ # When testing, always use the full gap.
+ max_gap = self.cfg.max_distance_between_context_views
+ min_gap = self.cfg.max_distance_between_context_views
+ # elif self.cfg.warm_up_steps > 0:
+ # max_gap = self.schedule(
+ # self.cfg.initial_max_distance_between_context_views,
+ # self.cfg.max_distance_between_context_views,
+ # )
+ # min_gap = self.schedule(
+ # self.cfg.initial_min_distance_between_context_views,
+ # self.cfg.min_distance_between_context_views,
+ # )
+ # else:
+ # max_gap = self.cfg.max_distance_between_context_views
+ # min_gap = self.cfg.min_distance_between_context_views
+
+ min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
+ max_gap = min(max_gap, num_views-1)
+ # Pick the gap between the context views.
+ if not self.cameras_are_circular:
+ max_gap = min(num_views - 1, max_gap)
+ min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)
+ if max_gap < min_gap:
+ raise ValueError("Example does not have enough frames!")
+ context_gap = torch.randint(
+ min_gap,
+ max_gap + 1,
+ size=tuple(),
+ device=device,
+ ).item()
+
+ # Pick the left and right context indices.
+ index_context_left = torch.randint(
+ num_views if self.cameras_are_circular else num_views - context_gap,
+ size=tuple(),
+ device=device,
+ ).item()
+ if self.stage == "test":
+ index_context_left = index_context_left * 0
+ index_context_right = index_context_left + context_gap
+
+ if self.is_overfitting:
+ index_context_left *= 0
+ index_context_right *= 0
+ index_context_right += max_gap
+
+ # Pick the target view indices.
+ if self.stage == "test":
+ # When testing, pick all.
+ index_target = torch.arange(
+ index_context_left,
+ index_context_right + 1,
+ device=device,
+ )
+ else:
+ # When training or validating (visualizing), pick at random.
+ index_target = torch.randint(
+ index_context_left + self.cfg.min_distance_to_context_views,
+ index_context_right + 1 - self.cfg.min_distance_to_context_views,
+ size=(self.cfg.num_target_views,),
+ device=device,
+ )
+
+ # Apply modulo for circular datasets.
+ if self.cameras_are_circular:
+ index_target %= num_views
+ index_context_right %= num_views
+
+ # If more than two context views are desired, pick extra context views between
+ # the left and right ones.
+ if num_context_views > 2:
+ num_extra_views = num_context_views - 2
+ extra_views = []
+ while len(set(extra_views)) != num_extra_views:
+ extra_views = torch.randint(
+ index_context_left + 1,
+ index_context_right,
+ (num_extra_views,),
+ ).tolist()
+ else:
+ extra_views = []
+
+ overlap = torch.tensor([0.5], dtype=torch.float32, device=device) # dummy
+
+ return (
+ torch.tensor((index_context_left, *extra_views, index_context_right)),
+ index_target,
+ overlap
+ )
+
+ @property
+ def num_context_views(self) -> int:
+ return self.cfg.num_context_views
+
+ @property
+ def num_target_views(self) -> int:
+ return self.cfg.num_target_views
+
+ @property
+ def num_ctxt_gap_mapping(self) -> dict:
+ mapping = dict()
+ for num_ctxt in range(2, self.cfg.num_context_views + 1):
+ mapping[num_ctxt] = [min(num_ctxt * self.cfg.min_gap_multiplier, self.cfg.min_distance_between_context_views),
+ min(max(num_ctxt * self.cfg.max_gap_multiplier, num_ctxt ** 2), self.cfg.max_distance_between_context_views)]
+ return mapping
diff --git a/src/dataset/view_sampler/view_sampler_evaluation.py b/src/dataset/view_sampler/view_sampler_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad0b91e5c5527db2432dd11eca33b52a183d6213
--- /dev/null
+++ b/src/dataset/view_sampler/view_sampler_evaluation.py
@@ -0,0 +1,70 @@
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal
+
+import torch
+from dacite import Config, from_dict
+from jaxtyping import Float, Int64
+from torch import Tensor
+
+from ...evaluation.evaluation_index_generator import IndexEntry
+from ...global_cfg import get_cfg
+from ...misc.step_tracker import StepTracker
+from ..types import Stage
+from .three_view_hack import add_third_context_index
+from .view_sampler import ViewSampler
+
+
+@dataclass
+class ViewSamplerEvaluationCfg:
+ name: Literal["evaluation"]
+ index_path: Path
+ num_context_views: int
+
+
+class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]):
+ index: dict[str, IndexEntry | None]
+
+ def __init__(
+ self,
+ cfg: ViewSamplerEvaluationCfg,
+ stage: Stage,
+ is_overfitting: bool,
+ cameras_are_circular: bool,
+ step_tracker: StepTracker | None,
+ ) -> None:
+ super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker)
+
+ dacite_config = Config(cast=[tuple])
+ with cfg.index_path.open("r") as f:
+ self.index = {
+ k: None if v is None else from_dict(IndexEntry, v, dacite_config)
+ for k, v in json.load(f).items()
+ }
+
+ def sample(
+ self,
+ scene: str,
+ extrinsics: Float[Tensor, "view 4 4"],
+ intrinsics: Float[Tensor, "view 3 3"],
+ device: torch.device = torch.device("cpu"),
+ **kwargs,
+ ) -> tuple[
+ Int64[Tensor, " context_view"], # indices for context views
+ Int64[Tensor, " target_view"], # indices for target views
+ ]:
+ entry = self.index.get(scene)
+ if entry is None:
+ raise ValueError(f"No indices available for scene {scene}.")
+ context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device)
+ target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device)
+ return context_indices, target_indices, torch.zeros(1)
+
+ @property
+ def num_context_views(self) -> int:
+ return 0
+
+ @property
+ def num_target_views(self) -> int:
+ return 0
diff --git a/src/dataset/view_sampler/view_sampler_rank.py b/src/dataset/view_sampler/view_sampler_rank.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be507f07620c32eba589ff26a779c2456348827
--- /dev/null
+++ b/src/dataset/view_sampler/view_sampler_rank.py
@@ -0,0 +1,263 @@
+
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import copy
+from jaxtyping import Float, Int64
+from torch import Tensor
+import random
+from .view_sampler import ViewSampler
+
+
+@dataclass
+class ViewSamplerRankCfg:
+ name: Literal["rank"]
+ num_context_views: int # max number of context views
+ num_target_views: int
+ min_distance_between_context_views: int
+ max_distance_between_context_views: int
+ min_distance_to_context_views: int
+ warm_up_steps: int
+ initial_min_distance_between_context_views: int
+ initial_max_distance_between_context_views: int
+ max_img_per_gpu: int
+
+
+def rotation_angle(R1, R2):
+ # R1 and R2 are 3x3 rotation matrices
+ R = R1.T @ R2
+ # Numerical stability: clamp values into [-1,1]
+ val = (torch.trace(R) - 1) / 2
+ val = torch.clamp(val, -1.0, 1.0)
+ angle_rad = torch.acos(val)
+ angle_deg = angle_rad * 180 / torch.pi # Convert radians to degrees
+ return angle_deg
+
+def extrinsic_distance(extrinsic1, extrinsic2, lambda_t=1.0):
+ R1, t1 = extrinsic1[:3, :3], extrinsic1[:3, 3]
+ R2, t2 = extrinsic2[:3, :3], extrinsic2[:3, 3]
+ rot_diff = rotation_angle(R1, R2) / 180
+
+ center_diff = torch.norm(t1 - t2)
+ return rot_diff + lambda_t * center_diff
+
+def rotation_angle_batch(R1, R2):
+ # R1, R2: shape (N, 3, 3)
+ # We want a matrix of rotation angles for all pairs.
+ # We'll get R1^T R2 for each pair.
+ # Expand dimensions to broadcast:
+ # R1^T: (N,3,3) -> (N,1,3,3)
+ # R2: (N,3,3) -> (1,N,3,3)
+ R1_t = R1.transpose(-2, -1)[:, None, :, :] # shape (N,1,3,3)
+ R2_b = R2[None, :, :, :] # shape (1,N,3,3)
+ R_mult = torch.matmul(R1_t, R2_b) # shape (N,N,3,3)
+ # trace(R) for each pair
+ trace_vals = R_mult[..., 0, 0] + R_mult[..., 1, 1] + R_mult[..., 2, 2] # (N,N)
+ val = (trace_vals - 1) / 2
+ val = torch.clamp(val, -1.0, 1.0)
+ angle_rad = torch.acos(val)
+ angle_deg = angle_rad * 180 / torch.pi
+ return angle_deg / 180.0 # normalized rotation difference
+
+def extrinsic_distance_batch(extrinsics, lambda_t=1.0):
+ # extrinsics: (N,4,4)
+ # Extract rotation and translation
+ R = extrinsics[:, :3, :3] # (N,3,3)
+ t = extrinsics[:, :3, 3] # (N,3)
+ # Compute all pairwise rotation differences
+ rot_diff = rotation_angle_batch(R, R) # (N,N)
+ # Compute all pairwise translation differences
+ # For t, shape (N,3). We want all pair differences: t[i] - t[j].
+ # t_i: (N,1,3), t_j: (1,N,3)
+ t_i = t[:, None, :] # (N,1,3)
+ t_j = t[None, :, :] # (1,N,3)
+ trans_diff = torch.norm(t_i - t_j, dim=2) # (N,N)
+ dists = rot_diff + lambda_t * trans_diff
+ return dists
+
+
+def compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True):
+
+ if normalize:
+ extrinsics = copy.deepcopy(extrinsics)
+ camera_center = copy.deepcopy(extrinsics[:, :3, 3])
+ camera_center_scale = torch.norm(camera_center, dim=1)
+ avg_scale = torch.mean(camera_center_scale)
+ extrinsics[:, :3, 3] = extrinsics[:, :3, 3] / avg_scale
+
+
+ if batched:
+ dists = extrinsic_distance_batch(extrinsics, lambda_t=lambda_t)
+ else:
+ N = extrinsics.shape[0]
+ dists = torch.zeros((N, N), device=extrinsics.device)
+ for i in range(N):
+ for j in range(N):
+ dists[i,j] = extrinsic_distance(extrinsics[i], extrinsics[j], lambda_t=lambda_t)
+ ranking = torch.argsort(dists, dim=1)
+ return ranking, dists
+
+# class ViewSamplerRank(ViewSampler[ViewSamplerRankCfg]):
+# def schedule(self, initial: int, final: int) -> int:
+# fraction = self.global_step / self.cfg.warm_up_steps
+# return min(initial + int((final - initial) * fraction), final)
+
+# def sample(
+# self,
+# scene: str,
+# extrinsics: Float[Tensor, "view 4 4"],
+# intrinsics: Float[Tensor, "view 3 3"],
+# device: torch.device = torch.device("cpu"),
+# ) -> tuple[
+# Int64[Tensor, " context_view"], # indices for context views
+# Int64[Tensor, " target_view"], # indices for target views
+# Float[Tensor, " overlap"], # overlap
+# ]:
+# num_views, _, _ = extrinsics.shape
+# # breakpoint()
+# # Compute the context view spacing based on the current global step.
+# ranking, dists = compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True)
+# reference_view = random.sample(range(num_views), 1)[0]
+
+# refview_ranking = ranking[reference_view]
+# # if self.cfg.warm_up_steps > 0:
+# # max_gap = self.schedule(
+# # self.cfg.initial_max_distance_between_context_views,
+# # self.cfg.max_distance_between_context_views,
+# # )
+# # min_gap = self.schedule(
+# # self.cfg.initial_min_distance_between_context_views,
+# # self.cfg.min_distance_between_context_views,
+# # )
+# # else:
+# max_gap = self.cfg.max_distance_between_context_views
+# min_gap = self.cfg.min_distance_between_context_views
+
+# index_context_left = reference_view
+# rightmost_index = random.sample(range(min_gap, max_gap + 1), 1)[0] + 1
+# index_context_right = refview_ranking[rightmost_index].item()
+
+# middle_indices = refview_ranking[1: rightmost_index].tolist()
+# index_target = random.sample(middle_indices, self.num_target_views)
+
+# remaining_indices = [idx for idx in middle_indices if idx not in index_target]
+
+# # Sample extra context views if needed
+# extra_views = []
+# num_extra_views = self.num_context_views - 2 # subtract left and right context views
+# if num_extra_views > 0 and remaining_indices:
+# extra_views = random.sample(remaining_indices, min(num_extra_views, len(remaining_indices)))
+# else:
+# extra_views = []
+
+# overlap = torch.zeros(1)
+
+# return (
+# torch.tensor((index_context_left, *extra_views, index_context_right)),
+# torch.tensor(index_target),
+# overlap
+# )
+
+
+# @property
+# def num_context_views(self) -> int:
+# return self.cfg.num_context_views
+
+# @property
+# def num_target_views(self) -> int:
+# return self.cfg.num_target_views
+
+
+class ViewSamplerRank(ViewSampler[ViewSamplerRankCfg]):
+
+ def sample(
+ self,
+ scene: str,
+ num_context_views: int,
+ extrinsics: Float[Tensor, "view 4 4"],
+ intrinsics: Float[Tensor, "view 3 3"],
+ device: torch.device = torch.device("cpu"),
+ ) -> tuple[
+ Int64[Tensor, " context_view"], # indices for context views
+ Int64[Tensor, " target_view"], # indices for target views
+ Float[Tensor, " overlap"], # overlap
+ ]:
+ num_views, _, _ = extrinsics.shape
+ # breakpoint()
+ extrinsics = extrinsics.clone()
+ # Compute the context view spacing based on the current global step.
+ ranking, dists = compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True)
+ reference_view = random.sample(range(num_views), 1)[0]
+
+ refview_ranking = ranking[reference_view]
+ # if self.cfg.warm_up_steps > 0:
+ # max_gap = self.schedule(
+ # self.cfg.initial_max_distance_between_context_views,
+ # self.cfg.max_distance_between_context_views,
+ # )
+ # min_gap = self.schedule(
+ # self.cfg.initial_min_distance_between_context_views,
+ # self.cfg.min_distance_between_context_views,
+ # )
+ # else:
+ min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
+
+ # min_gap = self.cfg.min_distance_between_context_views
+ # max_gap = self.cfg.max_distance_between_context_views
+
+ max_gap = min(max_gap, num_views-1)
+ # print(f"num_context_views: {num_context_views}, min_gap: {min_gap}, max_gap: {max_gap}")
+ index_context_left = reference_view
+ rightmost_index = random.sample(range(min_gap, max_gap + 1), 1)[0]
+
+ # #! hard code for visualization
+ # rightmost_index = self.cfg.max_distance_between_context_views
+
+ index_context_right = refview_ranking[rightmost_index].item()
+
+ middle_indices = refview_ranking[1: rightmost_index].tolist()
+ index_target = random.sample(middle_indices, self.num_target_views)
+
+ remaining_indices = [idx for idx in middle_indices if idx not in index_target]
+
+ # Sample extra context views if needed
+ extra_views = []
+ num_extra_views = num_context_views - 2 # subtract left and right context views
+ if num_extra_views > 0 and remaining_indices:
+ extra_views = random.sample(remaining_indices, min(num_extra_views, len(remaining_indices)))
+ else:
+ extra_views = []
+
+ overlap = torch.zeros(1)
+
+ return (
+ torch.tensor((index_context_left, *extra_views, index_context_right)),
+ torch.tensor(index_target),
+ overlap
+ )
+
+
+ @property
+ def num_context_views(self) -> int:
+ return self.cfg.num_context_views
+
+ @property
+ def num_target_views(self) -> int:
+ return self.cfg.num_target_views
+
+ @property
+ def num_ctxt_gap_mapping_target(self) -> dict:
+ mapping = dict()
+ for num_ctxt in range(2, self.cfg.num_context_views + 1):
+ mapping[num_ctxt] = [max(num_ctxt * 2, self.cfg.num_target_views + num_ctxt), max(self.cfg.num_target_views + num_ctxt, min(num_ctxt ** 2, self.cfg.max_distance_between_context_views))]
+ return mapping
+
+ @property
+ def num_ctxt_gap_mapping(self) -> dict:
+ mapping = dict()
+ for num_ctxt in range(2, self.cfg.num_context_views + 1):
+ mapping[num_ctxt] = [min(num_ctxt * 3, self.cfg.min_distance_between_context_views), min(max(num_ctxt * 5, num_ctxt ** 2), self.cfg.max_distance_between_context_views)]
+ return mapping
+
+
diff --git a/src/eval_nvs.py b/src/eval_nvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..06725ab45992ec5c7668c1e1f5caad62de2efce0
--- /dev/null
+++ b/src/eval_nvs.py
@@ -0,0 +1,115 @@
+import os
+from pathlib import Path
+import sys
+import json
+import gzip
+import argparse
+import numpy as np
+from PIL import Image
+
+import torch
+import torch.nn as nn
+import torchvision
+from einops import rearrange
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.evaluation.metrics import compute_lpips, compute_psnr, compute_ssim
+from misc.image_io import save_image, save_interpolated_video
+from src.utils.image import process_image
+
+from src.model.model.anysplat import AnySplat
+from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+
+def setup_args():
+ """Set up command-line arguments for the eval NVS script."""
+ parser = argparse.ArgumentParser(description='Test AnySplat on NVS evaluation')
+ parser.add_argument('--data_dir', type=str, required=True, help='Path to NVS dataset')
+ parser.add_argument('--llffhold', type=int, default=8, help='LLFF holdout')
+ parser.add_argument('--output_path', type=str, default="outputs/nvs", help='Path to output directory')
+ return parser.parse_args()
+
+def compute_metrics(pred_image, image):
+ psnr = compute_psnr(pred_image, image)
+ ssim = compute_ssim(pred_image, image)
+ lpips = compute_lpips(pred_image, image)
+ return psnr, ssim, lpips
+
+def evaluate(args: argparse.Namespace):
+ model = AnySplat.from_pretrained("lhjiang/anysplat")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ os.makedirs(args.output_path, exist_ok=True)
+
+ # load images
+ image_folder = args.data_dir
+ image_names = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+ images = [process_image(img_path) for img_path in image_names]
+ ctx_indices = [idx for idx, name in enumerate(image_names) if idx % args.llffhold != 0]
+ tgt_indices = [idx for idx, name in enumerate(image_names) if idx % args.llffhold == 0]
+
+ ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
+ tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
+ ctx_images = (ctx_images+1)*0.5
+ tgt_images = (tgt_images+1)*0.5
+ b, v, _, h, w = tgt_images.shape
+
+ # run inference
+ encoder_output = model.encoder(
+ ctx_images,
+ global_step=0,
+ visualization_dump={},
+ )
+ gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
+
+ num_context_view = ctx_images.shape[1]
+ vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
+ aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(vggt_input_image, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
+ with torch.cuda.amp.autocast(enabled=False):
+ fp32_tokens = [token.float() for token in aggregated_tokens_list]
+ pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
+ pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, vggt_input_image.shape[-2:])
+
+ extrinsic_padding = torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype).view(1, 1, 1, 4).repeat(b, vggt_input_image.shape[1], 1, 1)
+ pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()
+
+ pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
+ pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h
+ pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:]
+ pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:]
+
+ scale_factor = pred_context_pose['extrinsic'][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean()
+ pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
+ pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
+ print("scale_factor:", scale_factor)
+
+ output = model.decoder.forward(
+ gaussians,
+ pred_all_target_extrinsic,
+ pred_all_target_intrinsic.float(),
+ torch.ones(1, v, device=device) * 0.01,
+ torch.ones(1, v, device=device) * 100,
+ (h, w)
+ )
+
+ save_interpolated_video(pred_all_context_extrinsic, pred_all_context_intrinsic, b, h, w, gaussians, args.output_path, model.decoder)
+
+ # Save original images
+ save_path = Path(args.output_path)
+ # os.makedirs(save_path, exist_ok=True)
+ for idx, (gt_image, pred_image) in enumerate(zip(tgt_images[0], output.color[0])):
+ save_image(gt_image, save_path / "gt" / f"{idx:0>6}.jpg")
+ save_image(pred_image, save_path / "pred" / f"{idx:0>6}.jpg")
+
+ # compute metrics
+ psnr, ssim, lpips = compute_metrics(output.color[0], tgt_images[0])
+ print(f"PSNR: {psnr.mean():.2f}, SSIM: {ssim.mean():.3f}, LPIPS: {lpips.mean():.3f}")
+
+if __name__ == "__main__":
+ args = setup_args()
+ evaluate(args)
diff --git a/src/eval_pose.py b/src/eval_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..068626e0f9b3818065d717e07cd3051843fa9f27
--- /dev/null
+++ b/src/eval_pose.py
@@ -0,0 +1,339 @@
+import os
+import sys
+import json
+import gzip
+import argparse
+import numpy as np
+from PIL import Image
+
+import torch
+import torch.nn as nn
+import torchvision
+from einops import rearrange
+from lpips import LPIPS
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.model.model.anysplat import AnySplat
+from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from src.model.encoder.vggt.utils.load_fn import load_and_preprocess_images
+from src.utils.pose import align_to_first_camera, calculate_auc_np, convert_pt3d_RT_to_opencv, se3_to_relative_pose_error
+from src.misc.cam_utils import camera_normalization, pose_auc, rotation_6d_to_matrix, update_pose, get_pnp_pose
+
+def setup_args():
+ """Set up command-line arguments for the CO3D evaluation script."""
+ parser = argparse.ArgumentParser(description='Test AnySplat on CO3D dataset')
+ parser.add_argument('--debug', action='store_true', help='Enable debug mode (only test on specific category)')
+ parser.add_argument('--use_ba', action='store_true', default=False, help='Enable bundle adjustment')
+ parser.add_argument('--fast_eval', action='store_true', default=False, help='Only evaluate 10 sequences per category')
+ parser.add_argument('--min_num_images', type=int, default=50, help='Minimum number of images for a sequence')
+ parser.add_argument('--num_frames', type=int, default=10, help='Number of frames to use for testing')
+ parser.add_argument('--co3d_dir', type=str, required=True, help='Path to CO3D dataset')
+ parser.add_argument('--co3d_anno_dir', type=str, required=True, help='Path to CO3D annotations')
+ parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
+ return parser.parse_args()
+
+lpips = LPIPS(net="vgg")
+
+def rendering_loss(pred_image, image):
+ lpips_loss = lpips.forward(rearrange(pred_image, "b v c h w -> (b v) c h w"), rearrange(image, "b v c h w -> (b v) c h w"), normalize=True)
+ delta = pred_image - (image + 1) / 2
+ mse_loss = (delta**2).mean()
+ return mse_loss + 0.05 * lpips_loss.mean()
+
+def process_sequence(model, seq_name, seq_data, category, co3d_dir, min_num_images, num_frames, use_ba, device, dtype):
+ """
+ Process a single sequence and compute pose errors.
+
+ Args:
+ model: AnySplat model
+ seq_name: Sequence name
+ seq_data: Sequence data
+ category: Category name
+ co3d_dir: CO3D dataset directory
+ min_num_images: Minimum number of images required
+ num_frames: Number of frames to sample
+ use_ba: Whether to use bundle adjustment
+ device: Device to run on
+ dtype: Data type for model inference
+
+ Returns:
+ rError: Rotation errors
+ tError: Translation errors
+ """
+ if len(seq_data) < min_num_images:
+ return None, None
+
+ metadata = []
+ for data in seq_data:
+ # Make sure translations are not ridiculous
+ if data["T"][0] + data["T"][1] + data["T"][2] > 1e5:
+ return None, None
+
+ extri_opencv = convert_pt3d_RT_to_opencv(data["R"], data["T"])
+ metadata.append({
+ "filepath": data["filepath"],
+ "extri": extri_opencv,
+ })
+
+ ids = np.random.choice(len(metadata), num_frames, replace=False)
+ image_names = [os.path.join(co3d_dir, metadata[i]["filepath"]) for i in ids]
+ gt_extri = [np.array(metadata[i]["extri"]) for i in ids]
+ gt_extri = np.stack(gt_extri, axis=0)
+
+ max_size = max(Image.open(image_names[0]).size)
+ if max_size < 448:
+ return None, None
+ images = load_and_preprocess_images(image_names)[None].to(device)
+
+ batch = {
+ "context": {
+ "image": images*2.0-1,
+ "image_names": image_names,
+ "index": ids,
+ },
+ "scene": "co3d"
+ }
+
+ if use_ba:
+ try:
+ encoder_output = model.encoder(
+ batch,
+ global_step=0,
+ visualization_dump={},
+ )
+ gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
+ pred_extrinsic = pred_context_pose['extrinsic']
+ pred_intrinsic = pred_context_pose['intrinsic']
+ # rendering ba
+ b, v, _, h, w = images.shape
+ with torch.set_grad_enabled(True), torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+ cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], requires_grad=True, device=pred_extrinsic.device, dtype=torch.float32))
+ cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], requires_grad=True, device=pred_extrinsic.device, dtype=torch.float32))
+ opt_params = []
+ model.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=torch.float32).to(pred_extrinsic.device))
+ opt_params.append(
+ {
+ "params": [cam_rot_delta],
+ "lr": 0.005,
+ }
+ )
+ opt_params.append(
+ {
+ "params": [cam_trans_delta],
+ "lr": 0.005,
+ }
+ )
+ pose_optimizer = torch.optim.Adam(opt_params)
+ extrinsics = pred_extrinsic.clone().float()
+
+ for i in range(100):
+ pose_optimizer.zero_grad()
+ dx, drot = cam_trans_delta, cam_rot_delta
+ rot = rotation_6d_to_matrix(
+ drot + model.identity.expand(b, v, -1)
+ ) # (..., 3, 3)
+
+ transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
+ transform[..., :3, :3] = rot
+ transform[..., :3, 3] = dx
+
+ new_extrinsics = torch.matmul(extrinsics, transform)
+ # breakpoint()
+ output = model.decoder.forward(
+ gaussians,
+ new_extrinsics,
+ pred_intrinsic.float(),
+ 0.1,
+ 100.0,
+ (h, w),
+ # cam_rot_delta=cam_rot_delta,
+ # cam_trans_delta=cam_trans_delta,
+ )
+ # export_ply(gaussians.means[0], gaussians.scales[0], gaussians.rotations[0], gaussians.harmonics[0], gaussians.opacities[0], Path(f"gaussians_co3d.ply"))
+ rendering_loss = rendering_loss(output.color, images*2.0-1)
+ torchvision.utils.save_image(output.color[0], f"outputs/vis/output_co3d_{i}.png")
+ print(f"Rendering loss: {rendering_loss.item()}")
+ # print(f"Rendering loss: {rendering_loss.item()}")
+
+ rendering_loss.backward()
+ pose_optimizer.step()
+ torchvision.utils.save_image(images[0], f"outputs/vis/gt_co3d.png")
+ pred_extrinsic = new_extrinsics.inverse()[0][:,:-1,:]
+
+ except Exception as e:
+ print(f"BA failed with error: {e}. Falling back to standard VGGT inference.")
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
+ aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(images, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
+ with torch.cuda.amp.autocast(dtype=torch.float32):
+ fp32_tokens = [token.float() for token in aggregated_tokens_list]
+ pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
+ pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
+ pred_extrinsic = pred_all_extrinsic[0]
+ else:
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
+ aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(images, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
+ with torch.cuda.amp.autocast(dtype=torch.float32):
+ fp32_tokens = [token.float() for token in aggregated_tokens_list]
+ pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
+ pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
+ pred_extrinsic = pred_all_extrinsic[0]
+
+ with torch.cuda.amp.autocast(dtype=torch.float32):
+ gt_extrinsic = torch.from_numpy(gt_extri).to(device)
+ add_row = torch.tensor([0, 0, 0, 1], device=device).expand(pred_extrinsic.size(0), 1, 4)
+
+ pred_se3 = torch.cat((pred_extrinsic, add_row), dim=1)
+ gt_se3 = torch.cat((gt_extrinsic, add_row), dim=1)
+
+ # Set the coordinate of the first camera as the coordinate of the world
+ # NOTE: DO NOT REMOVE THIS UNLESS YOU KNOW WHAT YOU ARE DOING
+ # pred_se3 = align_to_first_camera(pred_se3)
+ gt_se3 = align_to_first_camera(gt_se3)
+
+ rel_rangle_deg, rel_tangle_deg = se3_to_relative_pose_error(pred_se3, gt_se3, num_frames)
+ print(f"{category} sequence {seq_name} Rot Error: {rel_rangle_deg.mean().item():.4f}")
+ print(f"{category} sequence {seq_name} Trans Error: {rel_tangle_deg.mean().item():.4f}")
+
+ return rel_rangle_deg.cpu().numpy(), rel_tangle_deg.cpu().numpy()
+
+def evaluate(args: argparse.Namespace):
+ model = AnySplat.from_pretrained("lhjiang/anysplat")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ # CO3D evaluation
+ SEEN_CATEGORIES = [
+ "apple", "backpack", "banana", "baseballbat", "baseballglove",
+ "bench", "bicycle", "bottle", "bowl", "broccoli",
+ "cake", "car", "carrot", "cellphone", "chair",
+ "cup", "donut", "hairdryer", "handbag", "hydrant",
+ "keyboard", "laptop", "microwave", "motorcycle", "mouse",
+ "orange", "parkingmeter", "pizza", "plant", "stopsign",
+ "teddybear", "toaster", "toilet", "toybus", "toyplane",
+ "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass",
+ ]
+
+ if args.debug:
+ SEEN_CATEGORIES = ["apple"]
+
+ per_category_results = {}
+
+ for category in SEEN_CATEGORIES:
+ print(f"Loading annotation for {category} test set")
+ annotation_file = os.path.join(args.co3d_anno_dir, f"{category}_test.jgz")
+
+ try:
+ with gzip.open(annotation_file, "r") as fin:
+ annotation = json.loads(fin.read())
+ except FileNotFoundError:
+ print(f"Annotation file not found for {category}, skipping")
+ continue
+
+ rError = []
+ tError = []
+
+ for seq_name, seq_data in annotation.items():
+ print("-" * 50)
+
+ print(f"Processing {seq_name} for {category} test set")
+ if args.debug and not os.path.exists(os.path.join(args.co3d_dir, category, seq_name)):
+ print(f"Skipping {seq_name} (not found)")
+ continue
+
+ seq_rError, seq_tError = process_sequence(
+ model, seq_name, seq_data, category, args.co3d_dir,
+ args.min_num_images, args.num_frames, args.use_ba, device, torch.bfloat16
+ )
+
+ print("-" * 50)
+
+ if seq_rError is not None and seq_tError is not None:
+ rError.extend(seq_rError)
+ tError.extend(seq_tError)
+
+ if not rError:
+ print(f"No valid sequences found for {category}, skipping")
+ continue
+
+ rError = np.array(rError)
+ tError = np.array(tError)
+
+ thresholds = [5, 10, 20, 30]
+ Aucs = {}
+
+ for threshold in thresholds:
+ Auc, _ = calculate_auc_np(rError, tError, max_threshold=threshold)
+ Aucs[threshold] = Auc
+
+ print("="*80)
+ print(f"AUC of {category} test set: {Aucs[30]:.4f}")
+ print("="*80)
+
+ per_category_results[category] = {
+ "rError": rError,
+ "tError": tError,
+ "Auc_5": Aucs[5],
+ "Auc_10": Aucs[10],
+ "Auc_20": Aucs[20],
+ "Auc_30": Aucs[30],
+ }
+
+ # Print summary results
+ print("\nSummary of AUC results:")
+ print("-"*50)
+ for category in sorted(per_category_results.keys()):
+ print(f"{category:<15} AUC_5: {per_category_results[category]['Auc_5']:.4f}")
+ print(f"{category:<15} AUC_30: {per_category_results[category]['Auc_30']:.4f}")
+ print(f"{category:<15} AUC_20: {per_category_results[category]['Auc_20']:.4f}")
+ print(f"{category:<15} AUC_10: {per_category_results[category]['Auc_10']:.4f}")
+
+ if per_category_results:
+ mean_AUC_30 = np.mean([per_category_results[category]["Auc_30"] for category in per_category_results])
+ mean_AUC_20 = np.mean([per_category_results[category]["Auc_20"] for category in per_category_results])
+ mean_AUC_10 = np.mean([per_category_results[category]["Auc_10"] for category in per_category_results])
+ mean_AUC_5 = np.mean([per_category_results[category]["Auc_5"] for category in per_category_results])
+ print("-"*50)
+ print(f"Mean AUC_5: {mean_AUC_5:.4f}")
+ print(f"Mean AUC_30: {mean_AUC_30:.4f}")
+ print(f"Mean AUC_20: {mean_AUC_20:.4f}")
+ print(f"Mean AUC_10: {mean_AUC_10:.4f}")
+
+ # Generate a random index to avoid overwriting previous results
+ # random_index = torch.randint(0, 10000, (1,)).item()
+ # Use timestamp as index instead of random number
+ import datetime
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+ random_index = timestamp
+ results_file = f"co3d_results_{random_index}.txt"
+
+ with open(results_file, "w") as f:
+ f.write("CO3D Evaluation Results\n")
+ f.write("=" * 50 + "\n\n")
+
+ f.write("Per-category results:\n")
+ f.write("-" * 50 + "\n")
+ for category in sorted(per_category_results.keys()):
+ f.write(f"{category:<15} AUC_30: {per_category_results[category]['Auc_30']:.4f}\n")
+ f.write(f"{category:<15} AUC_20: {per_category_results[category]['Auc_20']:.4f}\n")
+ f.write(f"{category:<15} AUC_10: {per_category_results[category]['Auc_10']:.4f}\n")
+ f.write(f"{category:<15} AUC_5: {per_category_results[category]['Auc_5']:.4f}\n")
+ f.write("\n")
+
+ if per_category_results:
+ f.write("-" * 50 + "\n")
+ f.write(f"Mean AUC_30: {mean_AUC_30:.4f}\n")
+ f.write(f"Mean AUC_20: {mean_AUC_20:.4f}\n")
+ f.write(f"Mean AUC_10: {mean_AUC_10:.4f}\n")
+ f.write(f"Mean AUC_5: {mean_AUC_5:.4f}\n")
+ f.write("\n" + "=" * 50 + "\n")
+
+ print(f"Results saved to {results_file}")
+
+
+if __name__ == "__main__":
+ args = setup_args()
+ evaluate(args)
diff --git a/src/evaluation/evaluation_cfg.py b/src/evaluation/evaluation_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a0245c787f2941b2c79b960c3691fcabe700842
--- /dev/null
+++ b/src/evaluation/evaluation_cfg.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class MethodCfg:
+ name: str
+ key: str
+ path: Path
+
+
+@dataclass
+class SceneCfg:
+ scene: str
+ target_index: int
+
+
+@dataclass
+class EvaluationCfg:
+ methods: list[MethodCfg]
+ side_by_side_path: Path | None
+ output_metrics_path: Path
+ animate_side_by_side: bool
+ highlighted: list[SceneCfg]
diff --git a/src/evaluation/evaluation_index_generator.py b/src/evaluation/evaluation_index_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3d037a442a76adcbeb20c0ed16c075f9a3c81e1
--- /dev/null
+++ b/src/evaluation/evaluation_index_generator.py
@@ -0,0 +1,160 @@
+import json
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+from einops import rearrange
+from lightning.pytorch import LightningModule
+from tqdm import tqdm
+
+from ..geometry.epipolar_lines import project_rays
+from ..geometry.projection import get_world_rays, sample_image_grid
+from ..misc.image_io import save_image
+from ..visualization.annotation import add_label
+from ..visualization.layout import add_border, hcat
+
+
+@dataclass
+class EvaluationIndexGeneratorCfg:
+ num_target_views: int
+ min_distance: int
+ max_distance: int
+ min_overlap: float
+ max_overlap: float
+ output_path: Path
+ save_previews: bool
+ seed: int
+
+
+@dataclass
+class IndexEntry:
+ context: tuple[int, ...]
+ target: tuple[int, ...]
+ overlap: Optional[str | float] = None # choose from ["small", "medium", "large"] or a float number indicates the overlap ratio
+
+
+class EvaluationIndexGenerator(LightningModule):
+ generator: torch.Generator
+ cfg: EvaluationIndexGeneratorCfg
+ index: dict[str, IndexEntry | None]
+
+ def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None:
+ super().__init__()
+ self.cfg = cfg
+ self.generator = torch.Generator()
+ self.generator.manual_seed(cfg.seed)
+ self.index = {}
+
+ def test_step(self, batch, batch_idx):
+ b, v, _, h, w = batch["target"]["image"].shape
+ assert b == 1
+ extrinsics = batch["target"]["extrinsics"][0]
+ intrinsics = batch["target"]["intrinsics"][0]
+ scene = batch["scene"][0]
+
+ context_indices = torch.randperm(v, generator=self.generator)
+ for context_index in tqdm(context_indices, "Finding context pair"):
+ xy, _ = sample_image_grid((h, w), self.device)
+ context_origins, context_directions = get_world_rays(
+ rearrange(xy, "h w xy -> (h w) xy"),
+ extrinsics[context_index],
+ intrinsics[context_index],
+ )
+
+ # Step away from context view until the minimum overlap threshold is met.
+ valid_indices = []
+ for step in (1, -1):
+ min_distance = self.cfg.min_distance
+ max_distance = self.cfg.max_distance
+ current_index = context_index + step * min_distance
+
+ while 0 <= current_index.item() < v:
+ # Compute overlap.
+ current_origins, current_directions = get_world_rays(
+ rearrange(xy, "h w xy -> (h w) xy"),
+ extrinsics[current_index],
+ intrinsics[current_index],
+ )
+ projection_onto_current = project_rays(
+ context_origins,
+ context_directions,
+ extrinsics[current_index],
+ intrinsics[current_index],
+ )
+ projection_onto_context = project_rays(
+ current_origins,
+ current_directions,
+ extrinsics[context_index],
+ intrinsics[context_index],
+ )
+ overlap_a = projection_onto_context["overlaps_image"].float().mean()
+ overlap_b = projection_onto_current["overlaps_image"].float().mean()
+
+ overlap = min(overlap_a, overlap_b)
+ delta = (current_index - context_index).abs()
+
+ min_overlap = self.cfg.min_overlap
+ max_overlap = self.cfg.max_overlap
+ if min_overlap <= overlap <= max_overlap:
+ valid_indices.append(
+ (current_index.item(), overlap_a, overlap_b)
+ )
+
+ # Stop once the camera has panned away too much.
+ if overlap < min_overlap or delta > max_distance:
+ break
+
+ current_index += step
+
+ if valid_indices:
+ # Pick a random valid view. Index the resulting views.
+ num_options = len(valid_indices)
+ chosen = torch.randint(
+ 0, num_options, size=tuple(), generator=self.generator
+ )
+ chosen, overlap_a, overlap_b = valid_indices[chosen]
+
+ context_left = min(chosen, context_index.item())
+ context_right = max(chosen, context_index.item())
+ delta = context_right - context_left
+
+ # Pick non-repeated random target views.
+ while True:
+ target_views = torch.randint(
+ context_left,
+ context_right + 1,
+ (self.cfg.num_target_views,),
+ generator=self.generator,
+ )
+ if (target_views.unique(return_counts=True)[1] == 1).all():
+ break
+
+ target = tuple(sorted(target_views.tolist()))
+ self.index[scene] = IndexEntry(
+ context=(context_left, context_right),
+ target=target,
+ )
+
+ # Optionally, save a preview.
+ if self.cfg.save_previews:
+ preview_path = self.cfg.output_path / "previews"
+ preview_path.mkdir(exist_ok=True, parents=True)
+ a = batch["target"]["image"][0, chosen]
+ a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%")
+ b = batch["target"]["image"][0, context_index]
+ b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%")
+ vis = add_border(add_border(hcat(a, b)), 1, 0)
+ vis = add_label(vis, f"Distance: {delta} frames")
+ save_image(add_border(vis), preview_path / f"{scene}.png")
+ break
+ else:
+ # This happens if no starting frame produces a valid evaluation example.
+ self.index[scene] = None
+
+ def save_index(self) -> None:
+ self.cfg.output_path.mkdir(exist_ok=True, parents=True)
+ with (self.cfg.output_path / "evaluation_index.json").open("w") as f:
+ json.dump(
+ {k: None if v is None else asdict(v) for k, v in self.index.items()}, f
+ )
diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..75894fa795653dfcd4e0594aecfb221a2c079e86
--- /dev/null
+++ b/src/evaluation/metrics.py
@@ -0,0 +1,134 @@
+from functools import cache
+
+import torch
+from einops import reduce
+from jaxtyping import Float
+from lpips import LPIPS
+from skimage.metrics import structural_similarity
+from torch import Tensor
+
+
+@torch.no_grad()
+def compute_psnr(
+ ground_truth: Float[Tensor, "batch channel height width"],
+ predicted: Float[Tensor, "batch channel height width"],
+) -> Float[Tensor, " batch"]:
+ ground_truth = ground_truth.clip(min=0, max=1)
+ predicted = predicted.clip(min=0, max=1)
+ mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean")
+ return -10 * mse.log10()
+
+
+@cache
+def get_lpips(device: torch.device) -> LPIPS:
+ return LPIPS(net="vgg").to(device)
+
+
+@torch.no_grad()
+def compute_lpips(
+ ground_truth: Float[Tensor, "batch channel height width"],
+ predicted: Float[Tensor, "batch channel height width"],
+) -> Float[Tensor, " batch"]:
+ value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True)
+ return value[:, 0, 0, 0]
+
+
+@torch.no_grad()
+def compute_ssim(
+ ground_truth: Float[Tensor, "batch channel height width"],
+ predicted: Float[Tensor, "batch channel height width"],
+) -> Float[Tensor, " batch"]:
+ ssim = [
+ structural_similarity(
+ gt.detach().cpu().numpy(),
+ hat.detach().cpu().numpy(),
+ win_size=11,
+ gaussian_weights=True,
+ channel_axis=0,
+ data_range=1.0,
+ )
+ for gt, hat in zip(ground_truth, predicted)
+ ]
+ return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device)
+
+
+def compute_geodesic_distance_from_two_matrices(m1, m2):
+ batch = m1.shape[0]
+ m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3
+
+ cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2
+ cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).to(m1.device)))
+ cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).to(m1.device)) * -1)
+
+ theta = torch.acos(cos)
+
+ # theta = torch.min(theta, 2*np.pi - theta)
+
+ return theta
+
+
+def angle_error_mat(R1, R2):
+ cos = (torch.trace(torch.mm(R1.T, R2)) - 1) / 2
+ cos = torch.clamp(cos, -1.0, 1.0) # numerical errors can make it out of bounds
+ return torch.rad2deg(torch.abs(torch.acos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = torch.norm(v1) * torch.norm(v2)
+ cos_theta = torch.dot(v1, v2) / n
+ cos_theta = torch.clamp(cos_theta, -1.0, 1.0) # numerical errors can make it out of bounds
+ return torch.rad2deg(torch.acos(cos_theta))
+
+
+def compute_translation_error(t1, t2):
+ return torch.norm(t1 - t2)
+
+
+@torch.no_grad()
+def compute_pose_error(pose_gt, pose_pred):
+ R_gt = pose_gt[:3, :3]
+ t_gt = pose_gt[:3, 3]
+
+ R = pose_pred[:3, :3]
+ t = pose_pred[:3, 3]
+
+ error_t = angle_error_vec(t, t_gt)
+ error_t = torch.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_t_scale = compute_translation_error(t, t_gt)
+ error_R = angle_error_mat(R, R_gt)
+ return error_t, error_t_scale, error_R
+
+@torch.no_grad()
+def abs_relative_difference(output, target, valid_mask=None):
+ actual_output = output
+ actual_target = target
+ abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target
+ if valid_mask is not None:
+ abs_relative_diff[~valid_mask] = 0
+ n = valid_mask.sum((-1, -2))
+ else:
+ n = output.shape[-1] * output.shape[-2]
+ abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n
+ return abs_relative_diff.mean()
+
+# adapt from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py
+@torch.no_grad()
+def threshold_percentage(output, target, threshold_val, valid_mask=None):
+ d1 = output / target
+ d2 = target / output
+ max_d1_d2 = torch.max(d1, d2)
+ zero = torch.zeros_like(output)
+ one = torch.ones_like(output)
+ bit_mat = torch.where(max_d1_d2 < threshold_val, one, zero)
+ if valid_mask is not None:
+ bit_mat[~valid_mask] = 0
+ n = valid_mask.sum((-1, -2))
+ else:
+ n = output.shape[-1] * output.shape[-2]
+ count_mat = torch.sum(bit_mat, (-1, -2))
+ threshold_mat = count_mat / n
+ return threshold_mat.mean()
+
+@torch.no_grad()
+def delta1_acc(pred, gt, valid_mask):
+ return threshold_percentage(pred, gt, 1.25, valid_mask)
\ No newline at end of file
diff --git a/src/geometry/camera_emb.py b/src/geometry/camera_emb.py
new file mode 100644
index 0000000000000000000000000000000000000000..39beeae7fc9b78743d70db702185d3cb07562184
--- /dev/null
+++ b/src/geometry/camera_emb.py
@@ -0,0 +1,29 @@
+from einops import rearrange
+
+from .projection import sample_image_grid, get_local_rays
+from ..misc.sht import rsh_cart_2, rsh_cart_4, rsh_cart_6, rsh_cart_8
+
+
+def get_intrinsic_embedding(context, degree=0, downsample=1, merge_hw=False):
+ assert degree in [0, 2, 4, 8]
+
+ b, v, _, h, w = context["image"].shape
+ device = context["image"].device
+ tgt_h, tgt_w = h // downsample, w // downsample
+ xy_ray, _ = sample_image_grid((tgt_h, tgt_w), device)
+ xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1) # [b, v, h, w, 2]
+ directions = get_local_rays(xy_ray, rearrange(context["intrinsics"], "b v i j -> b v () () i j"),)
+
+ if degree == 2:
+ directions = rsh_cart_2(directions)
+ elif degree == 4:
+ directions = rsh_cart_4(directions)
+ elif degree == 8:
+ directions = rsh_cart_8(directions)
+
+ if merge_hw:
+ directions = rearrange(directions, "b v h w d -> b v (h w) d")
+ else:
+ directions = rearrange(directions, "b v h w d -> b v d h w")
+
+ return directions
diff --git a/src/geometry/epipolar_lines.py b/src/geometry/epipolar_lines.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea46bfbdfefb77851e747dd16d1eade1970d2c
--- /dev/null
+++ b/src/geometry/epipolar_lines.py
@@ -0,0 +1,292 @@
+import itertools
+from typing import Iterable, Literal, Optional, TypedDict
+
+import torch
+from einops import einsum, repeat
+from jaxtyping import Bool, Float
+from torch import Tensor
+from torch.utils.data.dataloader import default_collate
+
+from .projection import (
+ get_world_rays,
+ homogenize_points,
+ homogenize_vectors,
+ intersect_rays,
+ project_camera_space,
+)
+
+
+def _is_in_bounds(
+ xy: Float[Tensor, "*batch 2"],
+ epsilon: float = 1e-6,
+) -> Bool[Tensor, " *batch"]:
+ """Check whether the specified XY coordinates are within the normalized image plane,
+ which has a range from 0 to 1 in each direction.
+ """
+ return (xy >= -epsilon).all(dim=-1) & (xy <= 1 + epsilon).all(dim=-1)
+
+
+def _is_in_front_of_camera(
+ xyz: Float[Tensor, "*batch 3"],
+ epsilon: float = 1e-6,
+) -> Bool[Tensor, " *batch"]:
+ """Check whether the specified points in camera space are in front of the camera."""
+ return xyz[..., -1] > -epsilon
+
+
+def _is_positive_t(
+ t: Float[Tensor, " *batch"],
+ epsilon: float = 1e-6,
+) -> Bool[Tensor, " *batch"]:
+ """Check whether the specified t value is positive."""
+ return t > -epsilon
+
+
+class PointProjection(TypedDict):
+ t: Float[Tensor, " *batch"] # ray parameter, as in xyz = origin + t * direction
+ xy: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1)
+
+ # A "valid" projection satisfies two conditions:
+ # 1. It is in front of the camera (i.e., its 3D Z coordinate is positive).
+ # 2. It is within the image frame (i.e., its 2D coordinates are between 0 and 1).
+ valid: Bool[Tensor, " *batch"]
+
+
+def _intersect_image_coordinate(
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+ origins: Float[Tensor, "*#batch 3"],
+ directions: Float[Tensor, "*#batch 3"],
+ dimension: Literal["x", "y"],
+ coordinate_value: float,
+) -> PointProjection:
+ """Compute the intersection of the projection of a camera-space ray with a line
+ that's parallel to the image frame, either horizontally or vertically.
+ """
+
+ # Define shorthands.
+ dim = "xy".index(dimension)
+ other_dim = 1 - dim
+ fs = intrinsics[..., dim, dim] # focal length, same coordinate
+ fo = intrinsics[..., other_dim, other_dim] # focal length, other coordinate
+ cs = intrinsics[..., dim, 2] # principal point, same coordinate
+ co = intrinsics[..., other_dim, 2] # principal point, other coordinate
+ os = origins[..., dim] # ray origin, same coordinate
+ oo = origins[..., other_dim] # ray origin, other coordinate
+ ds = directions[..., dim] # ray direction, same coordinate
+ do = directions[..., other_dim] # ray direction, other coordinate
+ oz = origins[..., 2] # ray origin, z coordinate
+ dz = directions[..., 2] # ray direction, z coordinate
+ c = (coordinate_value - cs) / fs # coefficient (computed once and factored out)
+
+ # Compute the value of t at the intersection.
+ # Note: Infinite values of t are fine. No need to handle division by zero.
+ t_numerator = c * oz - os
+ t_denominator = ds - c * dz
+ t = t_numerator / t_denominator
+
+ # Compute the value of the other coordinate at the intersection.
+ # Note: Infinite coordinate values are fine. No need to handle division by zero.
+ coordinate_numerator = fo * (oo * (c * dz - ds) + do * (os - c * oz))
+ coordinate_denominator = dz * os - ds * oz
+ coordinate_other = co + coordinate_numerator / coordinate_denominator
+ coordinate_same = torch.ones_like(coordinate_other) * coordinate_value
+ xy = [coordinate_same]
+ xy.insert(other_dim, coordinate_other)
+ xy = torch.stack(xy, dim=-1)
+ xyz = origins + t[..., None] * directions
+
+ # These will all have exactly the same batch shape (no broadcasting necessary). In
+ # terms of jaxtyping annotations, they all match *batch, not just *#batch.
+ return {
+ "t": t,
+ "xy": xy,
+ "valid": _is_in_bounds(xy) & _is_in_front_of_camera(xyz) & _is_positive_t(t),
+ }
+
+
+def _compare_projections(
+ intersections: Iterable[PointProjection],
+ reduction: Literal["min", "max"],
+) -> PointProjection:
+ intersections = {k: v.clone() for k, v in default_collate(intersections).items()}
+ t = intersections["t"]
+ xy = intersections["xy"]
+ valid = intersections["valid"]
+
+ # Make sure out-of-bounds values are not chosen.
+ lowest_priority = {
+ "min": torch.inf,
+ "max": -torch.inf,
+ }[reduction]
+ t[~valid] = lowest_priority
+
+ # Run the reduction (either t.min() or t.max()).
+ reduced, selector = getattr(t, reduction)(dim=0)
+
+ # Index the results.
+ return {
+ "t": reduced,
+ "xy": xy.gather(0, repeat(selector, "... -> () ... xy", xy=2))[0],
+ "valid": valid.gather(0, selector[None])[0],
+ }
+
+
+def _compute_point_projection(
+ xyz: Float[Tensor, "*#batch 3"],
+ t: Float[Tensor, "*#batch"],
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+) -> PointProjection:
+ xy = project_camera_space(xyz, intrinsics)
+ return {
+ "t": t,
+ "xy": xy,
+ "valid": _is_in_bounds(xy) & _is_in_front_of_camera(xyz) & _is_positive_t(t),
+ }
+
+
+class RaySegmentProjection(TypedDict):
+ t_min: Float[Tensor, " *batch"] # ray parameter
+ t_max: Float[Tensor, " *batch"] # ray parameter
+ xy_min: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1)
+ xy_max: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1)
+
+ # Whether the segment overlaps the image. If not, the above values are meaningless.
+ overlaps_image: Bool[Tensor, " *batch"]
+
+
+def project_rays(
+ origins: Float[Tensor, "*#batch 3"],
+ directions: Float[Tensor, "*#batch 3"],
+ extrinsics: Float[Tensor, "*#batch 4 4"],
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+ near: Optional[Float[Tensor, "*#batch"]] = None,
+ far: Optional[Float[Tensor, "*#batch"]] = None,
+ epsilon: float = 1e-6,
+) -> RaySegmentProjection:
+ # Transform the rays into camera space.
+ world_to_cam = torch.linalg.inv(extrinsics)
+ origins = homogenize_points(origins)
+ origins = einsum(world_to_cam, origins, "... i j, ... j -> ... i")
+ directions = homogenize_vectors(directions)
+ directions = einsum(world_to_cam, directions, "... i j, ... j -> ... i")
+ origins = origins[..., :3]
+ directions = directions[..., :3]
+
+ # Compute intersections with the image's frame.
+ frame_intersections = (
+ _intersect_image_coordinate(intrinsics, origins, directions, "x", 0.0),
+ _intersect_image_coordinate(intrinsics, origins, directions, "x", 1.0),
+ _intersect_image_coordinate(intrinsics, origins, directions, "y", 0.0),
+ _intersect_image_coordinate(intrinsics, origins, directions, "y", 1.0),
+ )
+ frame_intersection_min = _compare_projections(frame_intersections, "min")
+ frame_intersection_max = _compare_projections(frame_intersections, "max")
+
+ if near is None:
+ # Compute the ray's projection at zero depth. If an origin's depth (z value) is
+ # within epsilon of zero, this can mean one of two things:
+ # 1. The origin is at the camera's position. In this case, use the direction
+ # instead (the ray is probably coming from the camera).
+ # 2. The origin isn't at the camera's position, and randomly happens to be on
+ # the plane at zero depth. In this case, its projection is outside the image
+ # plane, and is thus marked as invalid.
+ origins_for_projection = origins.clone()
+ mask_depth_zero = origins_for_projection[..., -1] < epsilon
+ mask_at_camera = origins_for_projection.norm(dim=-1) < epsilon
+ origins_for_projection[mask_at_camera] = directions[mask_at_camera]
+ projection_at_zero = _compute_point_projection(
+ origins_for_projection,
+ torch.zeros_like(frame_intersection_min["t"]),
+ intrinsics,
+ )
+ projection_at_zero["valid"][mask_depth_zero & ~mask_at_camera] = False
+ else:
+ # If a near plane is specified, use it instead.
+ t_near = near.broadcast_to(frame_intersection_min["t"].shape)
+ projection_at_zero = _compute_point_projection(
+ origins + near[..., None] * directions,
+ t_near,
+ intrinsics,
+ )
+
+ if far is None:
+ # Compute the ray's projection at infinite depth. Using the projection function
+ # with directions (vectors) instead of points may seem wonky, but is equivalent
+ # to projecting the point at (origins + infinity * directions).
+ projection_at_infinity = _compute_point_projection(
+ directions,
+ torch.ones_like(frame_intersection_min["t"]) * torch.inf,
+ intrinsics,
+ )
+ else:
+ # If a far plane is specified, use it instead.
+ t_far = far.broadcast_to(frame_intersection_min["t"].shape)
+ projection_at_infinity = _compute_point_projection(
+ origins + far[..., None] * directions,
+ t_far,
+ intrinsics,
+ )
+
+ # Build the result by handling cases for ray intersection.
+ result = {
+ "t_min": torch.empty_like(projection_at_zero["t"]),
+ "t_max": torch.empty_like(projection_at_infinity["t"]),
+ "xy_min": torch.empty_like(projection_at_zero["xy"]),
+ "xy_max": torch.empty_like(projection_at_infinity["xy"]),
+ "overlaps_image": torch.empty_like(projection_at_zero["valid"]),
+ }
+
+ for min_valid, max_valid in itertools.product([True, False], [True, False]):
+ min_mask = projection_at_zero["valid"] ^ (not min_valid)
+ max_mask = projection_at_infinity["valid"] ^ (not max_valid)
+ mask = min_mask & max_mask
+ min_value = projection_at_zero if min_valid else frame_intersection_min
+ max_value = projection_at_infinity if max_valid else frame_intersection_max
+ result["t_min"][mask] = min_value["t"][mask]
+ result["t_max"][mask] = max_value["t"][mask]
+ result["xy_min"][mask] = min_value["xy"][mask]
+ result["xy_max"][mask] = max_value["xy"][mask]
+ result["overlaps_image"][mask] = (min_value["valid"] & max_value["valid"])[mask]
+
+ return result
+
+
+class RaySegmentProjection(TypedDict):
+ t_min: Float[Tensor, " *batch"] # ray parameter
+ t_max: Float[Tensor, " *batch"] # ray parameter
+ xy_min: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1)
+ xy_max: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1)
+
+ # Whether the segment overlaps the image. If not, the above values are meaningless.
+ overlaps_image: Bool[Tensor, " *batch"]
+
+
+def lift_to_3d(
+ origins: Float[Tensor, "*#batch 3"],
+ directions: Float[Tensor, "*#batch 3"],
+ xy: Float[Tensor, "*#batch 2"],
+ extrinsics: Float[Tensor, "*#batch 4 4"],
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+) -> Float[Tensor, "*batch 3"]:
+ """Calculate the 3D positions that correspond to the specified 2D points on the
+ epipolar lines defined by the origins and directions. The extrinsics and intrinsics
+ are for the images the 2D points lie on.
+ """
+
+ xy_origins, xy_directions = get_world_rays(xy, extrinsics, intrinsics)
+ return intersect_rays(origins, directions, xy_origins, xy_directions)
+
+
+def get_depth(
+ origins: Float[Tensor, "*#batch 3"],
+ directions: Float[Tensor, "*#batch 3"],
+ xy: Float[Tensor, "*#batch 2"],
+ extrinsics: Float[Tensor, "*#batch 4 4"],
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+) -> Float[Tensor, " *batch"]:
+ """Calculate the depths that correspond to the specified 2D points on the epipolar
+ lines defined by the origins and directions. The extrinsics and intrinsics are for
+ the images the 2D points lie on.
+ """
+ xyz = lift_to_3d(origins, directions, xy, extrinsics, intrinsics)
+ return (xyz - origins).norm(dim=-1)
diff --git a/src/geometry/projection.py b/src/geometry/projection.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7fcf6f4cc439f4110f8f1ca8fb1c653231cd024
--- /dev/null
+++ b/src/geometry/projection.py
@@ -0,0 +1,261 @@
+from math import prod
+
+import torch
+from einops import einsum, rearrange, reduce, repeat
+from jaxtyping import Bool, Float, Int64
+from torch import Tensor
+
+
+def homogenize_points(
+ points: Float[Tensor, "*batch dim"],
+) -> Float[Tensor, "*batch dim+1"]:
+ """Convert batched points (xyz) to (xyz1)."""
+ return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+
+
+def homogenize_vectors(
+ vectors: Float[Tensor, "*batch dim"],
+) -> Float[Tensor, "*batch dim+1"]:
+ """Convert batched vectors (xyz) to (xyz0)."""
+ return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1)
+
+
+def transform_rigid(
+ homogeneous_coordinates: Float[Tensor, "*#batch dim"],
+ transformation: Float[Tensor, "*#batch dim dim"],
+) -> Float[Tensor, "*batch dim"]:
+ """Apply a rigid-body transformation to points or vectors."""
+ return einsum(transformation, homogeneous_coordinates, "... i j, ... j -> ... i")
+
+
+def transform_cam2world(
+ homogeneous_coordinates: Float[Tensor, "*#batch dim"],
+ extrinsics: Float[Tensor, "*#batch dim dim"],
+) -> Float[Tensor, "*batch dim"]:
+ """Transform points from 3D camera coordinates to 3D world coordinates."""
+ return transform_rigid(homogeneous_coordinates, extrinsics)
+
+
+def transform_world2cam(
+ homogeneous_coordinates: Float[Tensor, "*#batch dim"],
+ extrinsics: Float[Tensor, "*#batch dim dim"],
+) -> Float[Tensor, "*batch dim"]:
+ """Transform points from 3D world coordinates to 3D camera coordinates."""
+ return transform_rigid(homogeneous_coordinates, extrinsics.inverse())
+
+
+def project_camera_space(
+ points: Float[Tensor, "*#batch dim"],
+ intrinsics: Float[Tensor, "*#batch dim dim"],
+ epsilon: float = torch.finfo(torch.float32).eps,
+ infinity: float = 1e8,
+) -> Float[Tensor, "*batch dim-1"]:
+ points = points / (points[..., -1:] + epsilon)
+ points = points.nan_to_num(posinf=infinity, neginf=-infinity)
+ points = einsum(intrinsics, points, "... i j, ... j -> ... i")
+ return points[..., :-1]
+
+
+def project(
+ points: Float[Tensor, "*#batch dim"],
+ extrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
+ intrinsics: Float[Tensor, "*#batch dim dim"],
+ epsilon: float = torch.finfo(torch.float32).eps,
+) -> tuple[
+ Float[Tensor, "*batch dim-1"], # xy coordinates
+ Bool[Tensor, " *batch"], # whether points are in front of the camera
+]:
+ points = homogenize_points(points)
+ points = transform_world2cam(points, extrinsics)[..., :-1]
+ in_front_of_camera = points[..., -1] >= 0
+ return project_camera_space(points, intrinsics, epsilon=epsilon), in_front_of_camera
+
+
+def unproject(
+ coordinates: Float[Tensor, "*#batch dim"],
+ z: Float[Tensor, "*#batch"],
+ intrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
+) -> Float[Tensor, "*batch dim+1"]:
+ """Unproject 2D camera coordinates with the given Z values."""
+
+ # Apply the inverse intrinsics to the coordinates.
+ coordinates = homogenize_points(coordinates)
+ ray_directions = einsum(
+ intrinsics.inverse(), coordinates, "... i j, ... j -> ... i"
+ )
+
+ # Apply the supplied depth values.
+ return ray_directions * z[..., None]
+
+
+def get_world_rays(
+ coordinates: Float[Tensor, "*#batch dim"],
+ extrinsics: Float[Tensor, "*#batch dim+2 dim+2"],
+ intrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
+) -> tuple[
+ Float[Tensor, "*batch dim+1"], # origins
+ Float[Tensor, "*batch dim+1"], # directions
+]:
+ # Get camera-space ray directions.
+ directions = unproject(
+ coordinates,
+ torch.ones_like(coordinates[..., 0]),
+ intrinsics,
+ )
+ directions = directions / directions.norm(dim=-1, keepdim=True)
+
+ # Transform ray directions to world coordinates.
+ directions = homogenize_vectors(directions)
+ directions = transform_cam2world(directions, extrinsics)[..., :-1]
+
+ # Tile the ray origins to have the same shape as the ray directions.
+ origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape)
+
+ return origins, directions
+
+
+def get_local_rays(
+ coordinates: Float[Tensor, "*#batch dim"],
+ intrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
+) -> Float[Tensor, "*batch dim+1"]:
+ # Get camera-space ray directions.
+ directions = unproject(
+ coordinates,
+ torch.ones_like(coordinates[..., 0]),
+ intrinsics,
+ )
+ directions = directions / directions.norm(dim=-1, keepdim=True)
+ return directions
+
+
+def sample_image_grid(
+ shape: tuple[int, ...],
+ device: torch.device = torch.device("cpu"),
+) -> tuple[
+ Float[Tensor, "*shape dim"], # float coordinates (xy indexing)
+ Int64[Tensor, "*shape dim"], # integer indices (ij indexing)
+]:
+ """Get normalized (range 0 to 1) coordinates and integer indices for an image."""
+
+ # Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a
+ # (row, col) coordinate.
+ indices = [torch.arange(length, device=device) for length in shape]
+ stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1)
+
+ # Each entry is a floating-point coordinate in the range (0, 1). In the 2D case,
+ # each entry is an (x, y) coordinate.
+ coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)]
+ coordinates = reversed(coordinates)
+ coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1)
+
+ return coordinates, stacked_indices
+
+
+def sample_training_rays(
+ image: Float[Tensor, "batch view channel ..."],
+ intrinsics: Float[Tensor, "batch view dim dim"],
+ extrinsics: Float[Tensor, "batch view dim+1 dim+1"],
+ num_rays: int,
+) -> tuple[
+ Float[Tensor, "batch ray dim"], # origins
+ Float[Tensor, "batch ray dim"], # directions
+ Float[Tensor, "batch ray 3"], # sampled color
+]:
+ device = extrinsics.device
+ b, v, _, *grid_shape = image.shape
+
+ # Generate all possible target rays.
+ xy, _ = sample_image_grid(tuple(grid_shape), device)
+ origins, directions = get_world_rays(
+ rearrange(xy, "... d -> ... () () d"),
+ extrinsics,
+ intrinsics,
+ )
+ origins = rearrange(origins, "... b v xy -> b (v ...) xy", b=b, v=v)
+ directions = rearrange(directions, "... b v xy -> b (v ...) xy", b=b, v=v)
+ pixels = rearrange(image, "b v c ... -> b (v ...) c")
+
+ # Sample random rays.
+ num_possible_rays = v * prod(grid_shape)
+ ray_indices = torch.randint(num_possible_rays, (b, num_rays), device=device)
+ batch_indices = repeat(torch.arange(b, device=device), "b -> b n", n=num_rays)
+
+ return (
+ origins[batch_indices, ray_indices],
+ directions[batch_indices, ray_indices],
+ pixels[batch_indices, ray_indices],
+ )
+
+
+def intersect_rays(
+ origins_x: Float[Tensor, "*#batch 3"],
+ directions_x: Float[Tensor, "*#batch 3"],
+ origins_y: Float[Tensor, "*#batch 3"],
+ directions_y: Float[Tensor, "*#batch 3"],
+ eps: float = 1e-5,
+ inf: float = 1e10,
+) -> Float[Tensor, "*batch 3"]:
+ """Compute the least-squares intersection of rays. Uses the math from here:
+ https://math.stackexchange.com/a/1762491/286022
+ """
+
+ # Broadcast the rays so their shapes match.
+ shape = torch.broadcast_shapes(
+ origins_x.shape,
+ directions_x.shape,
+ origins_y.shape,
+ directions_y.shape,
+ )
+ origins_x = origins_x.broadcast_to(shape)
+ directions_x = directions_x.broadcast_to(shape)
+ origins_y = origins_y.broadcast_to(shape)
+ directions_y = directions_y.broadcast_to(shape)
+
+ # Detect and remove batch elements where the directions are parallel.
+ parallel = einsum(directions_x, directions_y, "... xyz, ... xyz -> ...") > 1 - eps
+ origins_x = origins_x[~parallel]
+ directions_x = directions_x[~parallel]
+ origins_y = origins_y[~parallel]
+ directions_y = directions_y[~parallel]
+
+ # Stack the rays into (2, *shape).
+ origins = torch.stack([origins_x, origins_y], dim=0)
+ directions = torch.stack([directions_x, directions_y], dim=0)
+ dtype = origins.dtype
+ device = origins.device
+
+ # Compute n_i * n_i^T - eye(3) from the equation.
+ n = einsum(directions, directions, "r b i, r b j -> r b i j")
+ n = n - torch.eye(3, dtype=dtype, device=device).broadcast_to((2, 1, 3, 3))
+
+ # Compute the left-hand side of the equation.
+ lhs = reduce(n, "r b i j -> b i j", "sum")
+
+ # Compute the right-hand side of the equation.
+ rhs = einsum(n, origins, "r b i j, r b j -> r b i")
+ rhs = reduce(rhs, "r b i -> b i", "sum")
+
+ # Left-matrix-multiply both sides by the pseudo-inverse of lhs to find p.
+ result = torch.linalg.lstsq(lhs, rhs).solution
+
+ # Handle the case of parallel lines by setting depth to infinity.
+ result_all = torch.ones(shape, dtype=dtype, device=device) * inf
+ result_all[~parallel] = result
+ return result_all
+
+
+def get_fov(intrinsics: Float[Tensor, "batch 3 3"]) -> Float[Tensor, "batch 2"]:
+ intrinsics_inv = intrinsics.inverse()
+
+ def process_vector(vector):
+ vector = torch.tensor(vector, dtype=torch.float32, device=intrinsics.device)
+ vector = einsum(intrinsics_inv, vector, "b i j, j -> b i")
+ return vector / vector.norm(dim=-1, keepdim=True)
+
+ left = process_vector([0, 0.5, 1])
+ right = process_vector([1, 0.5, 1])
+ top = process_vector([0.5, 0, 1])
+ bottom = process_vector([0.5, 1, 1])
+ fov_x = (left * right).sum(dim=-1).acos()
+ fov_y = (top * bottom).sum(dim=-1).acos()
+ return torch.stack((fov_x, fov_y), dim=-1)
diff --git a/src/geometry/ptc_geometry.py b/src/geometry/ptc_geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cef303162fce19b7a9d2e7bb2ad520ad7c2cd96
--- /dev/null
+++ b/src/geometry/ptc_geometry.py
@@ -0,0 +1,385 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# geometry utilitary functions
+# --------------------------------------------------------
+import torch
+import numpy as np
+from scipy.spatial import cKDTree as KDTree
+
+from ..model.encoder.backbone.croco.misc import invalid_to_zeros, invalid_to_nans
+
+# from dust3r.utils.device import to_numpy
+
+
+def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
+ """ Output a (H,W,2) array of int32
+ with output[j,i,0] = i + origin[0]
+ output[j,i,1] = j + origin[1]
+ """
+ if device is None:
+ # numpy
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
+ else:
+ # torch
+ arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
+ meshgrid, stack = torch.meshgrid, torch.stack
+ ones = lambda *a: torch.ones(*a, device=device)
+
+ tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
+ grid = meshgrid(tw, th, indexing='xy')
+ if homogeneous:
+ grid = grid + (ones((H, W)),)
+ if unsqueeze is not None:
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
+ if cat_dim is not None:
+ grid = stack(grid, cat_dim)
+ return grid
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+ """ Apply a geometric transformation to a list of 3-D points.
+
+ H: 3x3 or 4x4 projection matrix (typically a Homography)
+ p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
+
+ ncol: int. number of columns of the result (2 or 3)
+ norm: float. if != 0, the resut is projected on the z=norm plane.
+
+ Returns an array of projected 2d points.
+ """
+ assert Trf.ndim >= 2
+ if isinstance(Trf, np.ndarray):
+ pts = np.asarray(pts)
+ elif isinstance(Trf, torch.Tensor):
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+ # adapt shape if necessary
+ output_reshape = pts.shape[:-1]
+ ncol = ncol or pts.shape[-1]
+
+ # optimized code
+ if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
+ Trf.ndim == 3 and pts.ndim == 4):
+ d = pts.shape[3]
+ if Trf.shape[-1] == d:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+ elif Trf.shape[-1] == d+1:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
+ else:
+ raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
+ else:
+ if Trf.ndim >= 3:
+ n = Trf.ndim-2
+ assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+ if pts.ndim > Trf.ndim:
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+ elif pts.ndim == 2:
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+ pts = pts[:, None, :]
+
+ if pts.shape[-1]+1 == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+ elif pts.shape[-1] == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf
+ else:
+ pts = Trf @ pts.T
+ if pts.ndim >= 2:
+ pts = pts.swapaxes(-1, -2)
+
+ if norm:
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
+ if norm != 1:
+ pts *= norm
+
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
+ return res
+
+
+def inv(mat):
+ """ Invert a torch or numpy matrix
+ """
+ if isinstance(mat, torch.Tensor):
+ return torch.linalg.inv(mat)
+ if isinstance(mat, np.ndarray):
+ return np.linalg.inv(mat)
+ raise ValueError(f'bad matrix type = {type(mat)}')
+
+
+def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
+ """
+ Args:
+ - depthmap (BxHxW array):
+ - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
+ Returns:
+ pointmap of absolute coordinates (BxHxWx3 array)
+ """
+
+ if len(depth.shape) == 4:
+ B, H, W, n = depth.shape
+ else:
+ B, H, W = depth.shape
+ n = None
+
+ if len(pseudo_focal.shape) == 3: # [B,H,W]
+ pseudo_focalx = pseudo_focaly = pseudo_focal
+ elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
+ pseudo_focalx = pseudo_focal[:, 0]
+ if pseudo_focal.shape[1] == 2:
+ pseudo_focaly = pseudo_focal[:, 1]
+ else:
+ pseudo_focaly = pseudo_focalx
+ else:
+ raise NotImplementedError("Error, unknown input focal shape format.")
+
+ assert pseudo_focalx.shape == depth.shape[:3]
+ assert pseudo_focaly.shape == depth.shape[:3]
+ grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
+
+ # set principal point
+ if pp is None:
+ grid_x = grid_x - (W-1)/2
+ grid_y = grid_y - (H-1)/2
+ else:
+ grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
+ grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
+
+ if n is None:
+ pts3d = torch.empty((B, H, W, 3), device=depth.device)
+ pts3d[..., 0] = depth * grid_x / pseudo_focalx
+ pts3d[..., 1] = depth * grid_y / pseudo_focaly
+ pts3d[..., 2] = depth
+ else:
+ pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
+ pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
+ pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
+ pts3d[..., 2, :] = depth
+ return pts3d
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: there are no skew terms
+ assert camera_intrinsics[0, 1] == 0.0
+ assert camera_intrinsics[1, 0] == 0.0
+ if pseudo_focal is None:
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ else:
+ assert pseudo_focal.shape == (H, W)
+ fu = fv = pseudo_focal
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ z_cam = depthmap
+ x_cam = (u - cu) * z_cam / fu
+ y_cam = (v - cv) * z_cam / fv
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ # Mask for valid coordinates
+ valid_mask = (depthmap > 0.0)
+ return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute coordinates (invalid depth values)
+ X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+ return X_world, valid_mask
+
+
+def colmap_to_opencv_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] -= 0.5
+ K[1, 2] -= 0.5
+ return K
+
+
+def opencv_to_colmap_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] += 0.5
+ K[1, 2] += 0.5
+ return K
+
+
+def obtain_pointcloud_center(pts, valid_mask=None):
+ """
+ Args:
+ - pts (BxNx3 torch.array): pointmap of absolute coordinates
+ - valid_mask (BxN torch.array): mask specifying valid pixels.
+ Returns:
+ center of the point cloud (3 torch.array)
+ """
+ depth = pts[..., 2]
+ # only choose the 0.02-0.98 quantile as the valid depth range
+ valid_depth_mask = ((depth > depth.quantile(0.02, dim=1, keepdim=True)) &
+ (depth < depth.quantile(0.98, dim=1, keepdim=True)))
+ if valid_mask is not None:
+ valid_mask = valid_depth_mask & valid_mask
+ else:
+ valid_mask = valid_depth_mask
+
+ # pts: (B, N, 3), valid_mask: (B, N)
+ all_pts, nnz = invalid_to_zeros(pts, valid_mask, ndim=3)
+ all_dis = all_pts.norm(dim=-1)
+ center_dis = all_dis.sum(dim=1) / (nnz + 1e-8)
+ return center_dis # (B,)
+
+def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None):
+ """ renorm pointmaps pts1, pts2 with norm_mode
+ """
+ assert pts1.ndim >= 3 and pts1.shape[-1] == 3
+ assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
+ norm_mode, dis_mode = norm_mode.split('_')
+
+ if norm_mode == 'avg':
+ # gather all points together (joint normalization)
+ nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
+ nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+ # compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+ if dis_mode == 'dis':
+ pass # do nothing
+ elif dis_mode == 'log1p':
+ all_dis = torch.log1p(all_dis)
+ elif dis_mode == 'warp-log1p':
+ # actually warp input points before normalizing them
+ log_dis = torch.log1p(all_dis)
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
+ H1, W1 = pts1.shape[1:-1]
+ pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1)
+ if pts2 is not None:
+ H2, W2 = pts2.shape[1:-1]
+ pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1)
+ all_dis = log_dis # this is their true distance afterwards
+ else:
+ raise ValueError(f'bad {dis_mode=}')
+
+ norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
+ else:
+ # gather all points together (joint normalization)
+ nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
+ nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+ # compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+
+ if norm_mode == 'avg':
+ norm_factor = all_dis.nanmean(dim=1)
+ elif norm_mode == 'median':
+ norm_factor = all_dis.nanmedian(dim=1).values.detach()
+ elif norm_mode == 'sqrt':
+ norm_factor = all_dis.sqrt().nanmean(dim=1)**2
+ else:
+ raise ValueError(f'bad {norm_mode=}')
+
+ norm_factor = norm_factor.clip(min=1e-8)
+ while norm_factor.ndim < pts1.ndim:
+ norm_factor.unsqueeze_(-1)
+
+ res = pts1 / norm_factor
+ if pts2 is not None:
+ res = (res, pts2 / norm_factor)
+ return res
+
+
+@torch.no_grad()
+def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
+ # set invalid points to NaN
+ _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
+ _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
+ _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
+
+ # compute median depth overall (ignoring nans)
+ if quantile == 0.5:
+ shift_z = torch.nanmedian(_z, dim=-1).values
+ else:
+ shift_z = torch.nanquantile(_z, quantile, dim=-1)
+ return shift_z # (B,)
+
+
+@torch.no_grad()
+def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
+ # set invalid points to NaN
+ _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
+ _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
+ _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
+
+ # compute median center
+ _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
+ if z_only:
+ _center[..., :2] = 0 # do not center X and Y
+
+ # compute median norm
+ _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
+ scale = torch.nanmedian(_norm, dim=1).values
+ return _center[:, None, :, :], scale[:, None, None, None]
+
+
+def find_reciprocal_matches(P1, P2):
+ """
+ returns 3 values:
+ 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
+ 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
+ 3 - reciprocal_in_P2.sum(): the number of matches
+ """
+ tree1 = KDTree(P1)
+ tree2 = KDTree(P2)
+
+ _, nn1_in_P2 = tree2.query(P1, workers=8)
+ _, nn2_in_P1 = tree1.query(P2, workers=8)
+
+ reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
+ reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
+ assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
+ return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
+
+
+def get_med_dist_between_poses(poses):
+ from scipy.spatial.distance import pdist
+ return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
diff --git a/src/global_cfg.py b/src/global_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc8571b0adc3dbc3c41ddb0b364698782ff94ced
--- /dev/null
+++ b/src/global_cfg.py
@@ -0,0 +1,19 @@
+from typing import Optional
+
+from omegaconf import DictConfig
+
+cfg: Optional[DictConfig] = None
+
+
+def get_cfg() -> DictConfig:
+ global cfg
+ return cfg
+
+
+def set_cfg(new_cfg: DictConfig) -> None:
+ global cfg
+ cfg = new_cfg
+
+
+def get_seed() -> int:
+ return cfg.seed
diff --git a/src/loss/__init__.py b/src/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7f511ba3e4aa6d6f381a4bcf5738e9166e2bb5c
--- /dev/null
+++ b/src/loss/__init__.py
@@ -0,0 +1,26 @@
+from .loss import Loss
+from .loss_depth import LossDepth, LossDepthCfgWrapper
+from .loss_lpips import LossLpips, LossLpipsCfgWrapper
+from .loss_mse import LossMse, LossMseCfgWrapper
+from .loss_opacity import LossOpacity, LossOpacityCfgWrapper
+from .loss_depth_gt import LossDepthGT, LossDepthGTCfgWrapper
+from .loss_lod import LossLOD, LossLODCfgWrapper
+from .loss_depth_consis import LossDepthConsis, LossDepthConsisCfgWrapper
+from .loss_normal_consis import LossNormalConsis, LossNormalConsisCfgWrapper
+from .loss_chamfer_distance import LossChamferDistance, LossChamferDistanceCfgWrapper
+LOSSES = {
+ LossDepthCfgWrapper: LossDepth,
+ LossLpipsCfgWrapper: LossLpips,
+ LossMseCfgWrapper: LossMse,
+ LossOpacityCfgWrapper: LossOpacity,
+ LossDepthGTCfgWrapper: LossDepthGT,
+ LossLODCfgWrapper: LossLOD,
+ LossDepthConsisCfgWrapper: LossDepthConsis,
+ LossNormalConsisCfgWrapper: LossNormalConsis,
+ LossChamferDistanceCfgWrapper: LossChamferDistance,
+}
+
+LossCfgWrapper = LossDepthCfgWrapper | LossLpipsCfgWrapper | LossMseCfgWrapper | LossOpacityCfgWrapper | LossDepthGTCfgWrapper | LossLODCfgWrapper | LossDepthConsisCfgWrapper | LossNormalConsisCfgWrapper | LossChamferDistanceCfgWrapper
+
+def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]:
+ return [LOSSES[type(cfg)](cfg) for cfg in cfgs]
diff --git a/src/loss/loss.py b/src/loss/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..42690697b6a6134930771f382caa9424eb6c13c8
--- /dev/null
+++ b/src/loss/loss.py
@@ -0,0 +1,37 @@
+from abc import ABC, abstractmethod
+from dataclasses import fields
+from typing import Generic, TypeVar
+
+from jaxtyping import Float
+from torch import Tensor, nn
+
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+
+T_cfg = TypeVar("T_cfg")
+T_wrapper = TypeVar("T_wrapper")
+
+
+class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]):
+ cfg: T_cfg
+ name: str
+
+ def __init__(self, cfg: T_wrapper) -> None:
+ super().__init__()
+
+ # Extract the configuration from the wrapper.
+ (field,) = fields(type(cfg))
+ self.cfg = getattr(cfg, field.name)
+ self.name = field.name
+
+ @abstractmethod
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ depth_dict: dict,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ pass
diff --git a/src/loss/loss_chamfer_distance.py b/src/loss/loss_chamfer_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..17f0bc2e07824ac50e5308e342b3fc626c7d8222
--- /dev/null
+++ b/src/loss/loss_chamfer_distance.py
@@ -0,0 +1,75 @@
+from dataclasses import dataclass
+
+import torch
+from einops import reduce
+from jaxtyping import Float
+from torch import Tensor
+
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+from typing import Generic, TypeVar
+from dataclasses import fields
+import torch.nn.functional as F
+import sys
+from pytorch3d.loss import chamfer_distance
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# from src.loss.depth_anything.dpt import DepthAnything
+from src.misc.utils import vis_depth_map
+
+T_cfg = TypeVar("T_cfg")
+T_wrapper = TypeVar("T_wrapper")
+
+
+@dataclass
+class LossChamferDistanceCfg:
+ weight: float
+ down_sample_ratio: float
+ sigma_image: float | None
+
+
+@dataclass
+class LossChamferDistanceCfgWrapper:
+ chamfer_distance: LossChamferDistanceCfg
+
+class LossChamferDistance(Loss[LossChamferDistanceCfg, LossChamferDistanceCfgWrapper]):
+ def __init__(self, cfg: T_wrapper) -> None:
+ super().__init__(cfg)
+
+ # Extract the configuration from the wrapper.
+ (field,) = fields(type(cfg))
+ self.cfg = getattr(cfg, field.name)
+ self.name = field.name
+
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ depth_dict: dict,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ # Scale the depth between the near and far planes.
+ b, v, h, w, _ = depth_dict['distill_infos']['pts_all'].shape
+ pred_pts = depth_dict['distill_infos']['pts_all'].flatten(0, 1)
+
+ conf_mask = depth_dict['distill_infos']['conf_mask']
+ gaussian_meas = gaussians.means
+
+ pred_pts = pred_pts.view(b, v, h, w, -1)
+ conf_mask = conf_mask.view(b, v, h, w)
+
+ pts_mask = torch.abs(gaussian_meas[..., -1]) < 1e2 #
+ # conf_mask = conf_mask & pts_mask
+
+ cd_losses = 0.0
+ for b_idx in range(b):
+ batch_pts, batch_conf, batch_gaussian = pred_pts[b_idx], conf_mask[b_idx], gaussian_meas[b_idx][pts_mask[b_idx]]
+ batch_pts = batch_pts[batch_conf]
+ batch_pts = batch_pts[torch.randperm(batch_pts.shape[0])[:int(batch_pts.shape[0] * self.cfg.down_sample_ratio)]]
+ batch_gaussian = batch_gaussian[torch.randperm(batch_gaussian.shape[0])[:int(batch_gaussian.shape[0] * self.cfg.down_sample_ratio)]]
+ cd_loss = chamfer_distance(batch_pts.unsqueeze(0), batch_gaussian.unsqueeze(0))[0]
+ cd_losses = cd_losses + cd_loss
+ return self.cfg.weight * torch.nan_to_num(cd_losses / b, nan=0.0)
\ No newline at end of file
diff --git a/src/loss/loss_depth.py b/src/loss/loss_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..81929387682d0e65386b3f913ee10d1363c3cf8f
--- /dev/null
+++ b/src/loss/loss_depth.py
@@ -0,0 +1,118 @@
+from dataclasses import dataclass
+
+import torch
+from einops import reduce
+from jaxtyping import Float
+from torch import Tensor
+
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+from typing import Generic, TypeVar
+from dataclasses import fields
+import torch.nn.functional as F
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# from src.loss.depth_anything.dpt import DepthAnything
+from src.misc.utils import vis_depth_map
+
+T_cfg = TypeVar("T_cfg")
+T_wrapper = TypeVar("T_wrapper")
+
+
+@dataclass
+class LossDepthCfg:
+ weight: float
+ sigma_image: float | None
+ use_second_derivative: bool
+
+
+@dataclass
+class LossDepthCfgWrapper:
+ depth: LossDepthCfg
+
+
+class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]):
+ def __init__(self, cfg: T_wrapper) -> None:
+ super().__init__(cfg)
+
+ # Extract the configuration from the wrapper.
+ (field,) = fields(type(cfg))
+ self.cfg = getattr(cfg, field.name)
+ self.name = field.name
+
+ model_configs = {
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
+ }
+ encoder = 'vits' # or 'vitb', 'vits'
+ depth_anything = DepthAnything(model_configs[encoder])
+ depth_anything.load_state_dict(torch.load(f'src/loss/depth_anything/depth_anything_{encoder}14.pth'))
+
+ self.depth_anything = depth_anything
+ for param in self.depth_anything.parameters():
+ param.requires_grad = False
+
+ def disp_rescale(self, disp: Float[Tensor, "B H W"]):
+ disp = disp.flatten(1, 2)
+ disp_median = torch.median(disp, dim=-1, keepdim=True)[0] # (B, V, 1)
+ disp_var = (disp - disp_median).abs().mean(dim=-1, keepdim=True) # (B, V, 1)
+ disp = (disp - disp_median) / (disp_var + 1e-6)
+ return disp
+
+ def smooth_l1_loss(self, pred, target, beta=1.0, reduction='none'):
+ diff = pred - target
+ abs_diff = torch.abs(diff)
+
+ loss = torch.where(abs_diff < beta, 0.5 * diff ** 2 / beta, abs_diff - 0.5 * beta)
+
+ if reduction == 'mean':
+ return loss.mean()
+ elif reduction == 'sum':
+ return loss.sum()
+ elif reduction == 'none':
+ return loss
+ else:
+ raise ValueError("Invalid reduction type. Choose from 'mean', 'sum', or 'none'.")
+
+ def ctx_depth_loss(self,
+ depth_map: Float[Tensor, "B V H W C"],
+ depth_conf: Float[Tensor, "B V H W"],
+ batch: BatchedExample,
+ cxt_depth_weight: float = 0.01,
+ alpha: float = 0.2):
+ B, V, _, H, W = batch["context"]["image"].shape
+ ctx_imgs = batch["context"]["image"].view(B * V, 3, H, W).float()
+ da_output = self.depth_anything(ctx_imgs)
+ da_output = self.disp_rescale(da_output)
+
+ disp_context = 1.0 / depth_map.flatten(0, 1).squeeze(-1).clamp(1e-3) # (B * V, H, W)
+ context_output = self.disp_rescale(disp_context)
+
+ depth_conf = depth_conf.flatten(0, 1).flatten(1, 2) # (B * V)
+
+ return cxt_depth_weight * (self.smooth_l1_loss(context_output*depth_conf, da_output*depth_conf, reduction='none') - alpha * torch.log(depth_conf)).mean()
+
+
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ # Scale the depth between the near and far planes.
+ target_imgs = batch["target"]["image"]
+ B, V, _, H, W = target_imgs.shape
+ target_imgs = target_imgs.view(B * V, 3, H, W)
+ da_output = self.depth_anything(target_imgs.float())
+ da_output = self.disp_rescale(da_output)
+
+ disp_gs = 1.0 / prediction.depth.flatten(0, 1).clamp(1e-3).float()
+ gs_output = self.disp_rescale(disp_gs)
+
+
+ return self.cfg.weight * torch.nan_to_num(F.smooth_l1_loss(da_output, gs_output), nan=0.0)
\ No newline at end of file
diff --git a/src/loss/loss_depth_consis.py b/src/loss/loss_depth_consis.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd9eada7dcb547994a44d1e9be2aad6e9f25414
--- /dev/null
+++ b/src/loss/loss_depth_consis.py
@@ -0,0 +1,140 @@
+from dataclasses import dataclass
+
+import torch
+from einops import reduce
+from jaxtyping import Float
+from torch import Tensor
+
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+from typing import Generic, Literal, Optional, TypeVar
+from dataclasses import fields
+import torch.nn.functional as F
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# from src.loss.depth_anything.dpt import DepthAnything
+from src.misc.utils import vis_depth_map
+
+T_cfg = TypeVar("T_cfg")
+T_wrapper = TypeVar("T_wrapper")
+
+
+@dataclass
+class LossDepthConsisCfg:
+ weight: float
+ sigma_image: float | None
+ use_second_derivative: bool
+ loss_type: Literal['MSE', 'EdgeAwareLogL1', 'PearsonDepth'] = 'MSE'
+ detach: bool = False
+ conf: bool = False
+ not_use_valid_mask: bool = False
+ apply_after_step: int = 0
+
+@dataclass
+class LossDepthConsisCfgWrapper:
+ depth_consis: LossDepthConsisCfg
+
+
+class LogL1(torch.nn.Module):
+ """Log-L1 loss"""
+
+ def __init__(
+ self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
+ ):
+ super().__init__()
+ self.implementation = implementation
+
+ def forward(self, pred, gt):
+ if self.implementation == "scalar":
+ return torch.log(1 + torch.abs(pred - gt)).mean()
+ else:
+ return torch.log(1 + torch.abs(pred - gt))
+
+class EdgeAwareLogL1(torch.nn.Module):
+ """Gradient aware Log-L1 loss"""
+
+ def __init__(
+ self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
+ ):
+ super().__init__()
+ self.implementation = implementation
+ self.logl1 = LogL1(implementation="per-pixel")
+
+ def forward(self, pred: Tensor, gt: Tensor, rgb: Tensor, mask: Optional[Tensor]):
+ logl1 = self.logl1(pred, gt)
+
+ grad_img_x = torch.mean(
+ torch.abs(rgb[..., :, :-1, :] - rgb[..., :, 1:, :]), -1, keepdim=True
+ )
+ grad_img_y = torch.mean(
+ torch.abs(rgb[..., :-1, :, :] - rgb[..., 1:, :, :]), -1, keepdim=True
+ )
+ lambda_x = torch.exp(-grad_img_x)
+ lambda_y = torch.exp(-grad_img_y)
+
+ loss_x = lambda_x * logl1[..., :, :-1, :]
+ loss_y = lambda_y * logl1[..., :-1, :, :]
+
+ if self.implementation == "per-pixel":
+ if mask is not None:
+ loss_x[~mask[..., :, :-1, :]] = 0
+ loss_y[~mask[..., :-1, :, :]] = 0
+ return loss_x[..., :-1, :, :] + loss_y[..., :, :-1, :]
+
+ if mask is not None:
+ assert mask.shape[:2] == pred.shape[:2]
+ loss_x = loss_x[mask[..., :, :-1, :]]
+ loss_y = loss_y[mask[..., :-1, :, :]]
+
+ if self.implementation == "scalar":
+ return loss_x.mean() + loss_y.mean()
+
+class LossDepthConsis(Loss[LossDepthConsisCfg, LossDepthConsisCfgWrapper]):
+ def __init__(self, cfg: T_wrapper) -> None:
+ super().__init__(cfg)
+
+ # Extract the configuration from the wrapper.
+ (field,) = fields(type(cfg))
+ self.cfg = getattr(cfg, field.name)
+ self.name = field.name
+
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ depth_dict: dict,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+
+ # Before the specified step, don't apply the loss.
+ if global_step < self.cfg.apply_after_step:
+ return torch.tensor(0.0, dtype=torch.float32, device=prediction.depth.device)
+
+ # Scale the depth between the near and far planes.
+ # conf_valid_mask = depth_dict['conf_valid_mask']
+ rendered_depth = prediction.depth
+ gt_rgb = (batch["context"]["image"] + 1) / 2
+ valid_mask = depth_dict["distill_infos"]['conf_mask']
+
+ if batch['context']['valid_mask'].sum() > 0:
+ valid_mask = batch['context']['valid_mask']
+ # if self.cfg.conf:
+ # valid_mask = valid_mask & conf_valid_mask
+ if self.cfg.not_use_valid_mask:
+ valid_mask = torch.ones_like(valid_mask, device=valid_mask.device)
+ pred_depth = depth_dict['depth'].squeeze(-1)
+ if self.cfg.detach:
+ pred_depth = pred_depth.detach()
+ if self.cfg.loss_type == 'MSE':
+ depth_loss = F.mse_loss(rendered_depth, pred_depth, reduction='none')[valid_mask].mean()
+ elif self.cfg.loss_type == 'EdgeAwareLogL1':
+ rendered_depth = rendered_depth.flatten(0, 1).unsqueeze(-1)
+ pred_depth = pred_depth.flatten(0, 1).unsqueeze(-1)
+ gt_rgb = gt_rgb.flatten(0, 1).permute(0, 2, 3, 1)
+ valid_mask = valid_mask.flatten(0, 1).unsqueeze(-1)
+ depth_loss = EdgeAwareLogL1()(rendered_depth, pred_depth, gt_rgb, valid_mask)
+ return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0)
\ No newline at end of file
diff --git a/src/loss/loss_depth_gt.py b/src/loss/loss_depth_gt.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a4caeae8b1dfe19da9e1f4729096806ca46bc2d
--- /dev/null
+++ b/src/loss/loss_depth_gt.py
@@ -0,0 +1,93 @@
+from dataclasses import dataclass
+
+import torch
+from einops import reduce
+from jaxtyping import Float
+from torch import Tensor
+
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+from typing import Generic, Literal, TypeVar
+from dataclasses import fields
+import torch.nn.functional as F
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# from src.loss.depth_anything.dpt import DepthAnything
+from src.misc.utils import vis_depth_map
+
+T_cfg = TypeVar("T_cfg")
+T_wrapper = TypeVar("T_wrapper")
+
+
+@dataclass
+class LossDepthGTCfg:
+ weight: float
+ type: Literal["l1", "mse", "silog", "gradient", "l1+gradient"] | None
+
+@dataclass
+class LossDepthGTCfgWrapper:
+ depthgt: LossDepthGTCfg
+
+
+class LossDepthGT(Loss[LossDepthGTCfg, LossDepthGTCfgWrapper]):
+ def gradient_loss(self, gs_depth, target_depth, target_valid_mask):
+ diff = gs_depth - target_depth
+
+ grad_x_diff = diff[:, :, :, 1:] - diff[:, :, :, :-1]
+ grad_y_diff = diff[:, :, 1:, :] - diff[:, :, :-1, :]
+
+ mask_x = target_valid_mask[:, :, :, 1:] * target_valid_mask[:, :, :, :-1]
+ mask_y = target_valid_mask[:, :, 1:, :] * target_valid_mask[:, :, :-1, :]
+
+ grad_x_diff = grad_x_diff * mask_x
+ grad_y_diff = grad_y_diff * mask_y
+
+ grad_x_diff = grad_x_diff.clamp(min=-100, max=100)
+ grad_y_diff = grad_y_diff.clamp(min=-100, max=100)
+
+ loss_x = grad_x_diff.abs().sum()
+ loss_y = grad_y_diff.abs().sum()
+ num_valid = mask_x.sum() + mask_y.sum()
+
+ if num_valid == 0:
+ gradient_loss = 0
+ else:
+ gradient_loss = (loss_x + loss_y) / (num_valid + 1e-6)
+
+ return gradient_loss
+
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ # Scale the depth between the near and far planes.
+
+ # prediction: B, H, W, C
+ # target: B, H, W, C
+ # mask: B, H, W
+
+ target_depth = batch["target"]["depth"]
+ target_valid_mask = batch["target"]["valid_mask"]
+ gs_depth = prediction.depth.clamp(1e-3)
+
+ if self.cfg.type == "l1":
+ depth_loss = torch.abs(target_depth[target_valid_mask] - gs_depth[target_valid_mask]).mean()
+ elif self.cfg.type == "mse":
+ depth_loss = F.mse_loss(target_depth[target_valid_mask], gs_depth[target_valid_mask])
+ elif self.cfg.type == "silog":
+ depth_loss = torch.log(gs_depth[target_valid_mask]) ** 2 + (gs_depth[target_valid_mask] - target_depth[target_valid_mask]) ** 2 - 0.5
+ depth_loss = depth_loss.mean()
+ elif self.cfg.type == "gradient":
+ depth_loss = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
+ elif self.cfg.type == "l1+gradient":
+ depth_loss_l1 = torch.abs(target_depth[target_valid_mask] - gs_depth[target_valid_mask]).mean()
+ depth_loss_gradient = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
+ depth_loss = depth_loss_l1 + depth_loss_gradient
+
+ return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0, posinf=0.0, neginf=0.0)
\ No newline at end of file
diff --git a/src/loss/loss_distill.py b/src/loss/loss_distill.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c32da2af139f5dad0758cec0d73bad64e9831d1
--- /dev/null
+++ b/src/loss/loss_distill.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import copy, deepcopy
+
+from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from src.model.encoder.vggt.utils.rotation import mat_to_quat
+from src.utils.point import get_normal_map
+
+def extri_intri_to_pose_encoding(
+ extrinsics,
+ intrinsics,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+):
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+ This function transforms camera parameters into a unified pose encoding format,
+ which can be used for various downstream tasks like pose prediction or representation.
+
+ Args:
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+ where B is batch size and S is sequence length.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+ Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for computing field of view values. For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+ Returns:
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ """
+
+ # extrinsics: BxSx3x4
+ # intrinsics: BxSx3x3
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
+ T = extrinsics[:, :, :3, 3] # BxSx3
+
+ quat = mat_to_quat(R)
+ # Note the order of h and w here
+ # H, W = image_size_hw
+ # fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+ # fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+ fov_h = 2 * torch.atan(0.5 / intrinsics[..., 1, 1])
+ fov_w = 2 * torch.atan(0.5 / intrinsics[..., 0, 0])
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+ else:
+ raise NotImplementedError
+
+ return pose_encoding
+
+def huber_loss(x, y, delta=1.0):
+ """Calculate element-wise Huber loss between x and y"""
+ diff = x - y
+ abs_diff = diff.abs()
+ flag = (abs_diff <= delta).to(diff.dtype)
+ return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
+
+class DistillLoss(nn.Module):
+ def __init__(self, delta=1.0, gamma=0.6, weight_pose=1.0, weight_depth=1.0, weight_normal=1.0):
+ super().__init__()
+ self.delta = delta
+ self.gamma = gamma
+ self.weight_pose = weight_pose
+ self.weight_depth = weight_depth
+ self.weight_normal = weight_normal
+
+ def camera_loss_single(self, cur_pred_pose_enc, gt_pose_encoding, loss_type="l1"):
+ if loss_type == "l1":
+ loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).abs()
+ loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).abs()
+ loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).abs()
+ elif loss_type == "l2":
+ loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).norm(dim=-1, keepdim=True)
+ loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).norm(dim=-1)
+ loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).norm(dim=-1)
+ elif loss_type == "huber":
+ loss_T = huber_loss(cur_pred_pose_enc[..., :3], gt_pose_encoding[..., :3])
+ loss_R = huber_loss(cur_pred_pose_enc[..., 3:7], gt_pose_encoding[..., 3:7])
+ loss_fl = huber_loss(cur_pred_pose_enc[..., 7:], gt_pose_encoding[..., 7:])
+ else:
+ raise ValueError(f"Unknown loss type: {loss_type}")
+
+ loss_T = torch.nan_to_num(loss_T, nan=0.0, posinf=0.0, neginf=0.0)
+ loss_R = torch.nan_to_num(loss_R, nan=0.0, posinf=0.0, neginf=0.0)
+ loss_fl = torch.nan_to_num(loss_fl, nan=0.0, posinf=0.0, neginf=0.0)
+
+ loss_T = torch.clamp(loss_T, min=-100, max=100)
+ loss_R = torch.clamp(loss_R, min=-100, max=100)
+ loss_fl = torch.clamp(loss_fl, min=-100, max=100)
+
+ loss_T = loss_T.mean()
+ loss_R = loss_R.mean()
+ loss_fl = loss_fl.mean()
+
+ return loss_T, loss_R, loss_fl
+
+ def forward(self, distill_infos, pred_pose_enc_list, prediction, batch):
+ loss_pose = 0.0
+
+ if pred_pose_enc_list is not None:
+ num_predictions = len(pred_pose_enc_list)
+ pesudo_gt_pose_enc = distill_infos['pred_pose_enc_list']
+ for i in range(num_predictions):
+ i_weight = self.gamma ** (num_predictions - i - 1)
+ cur_pred_pose_enc = pred_pose_enc_list[i]
+ cur_pesudo_gt_pose_enc = pesudo_gt_pose_enc[i]
+ loss_pose += i_weight * huber_loss(cur_pred_pose_enc, cur_pesudo_gt_pose_enc).mean()
+ loss_pose = loss_pose / num_predictions
+ loss_pose = torch.nan_to_num(loss_pose, nan=0.0, posinf=0.0, neginf=0.0)
+
+ pred_depth = prediction.depth.flatten(0, 1)
+ pesudo_gt_depth = distill_infos['depth_map'].flatten(0, 1).squeeze(-1)
+ conf_mask = distill_infos['conf_mask'].flatten(0, 1)
+
+ if batch['context']['valid_mask'].sum() > 0:
+ conf_mask = batch['context']['valid_mask'].flatten(0, 1)
+
+ loss_depth = F.mse_loss(pred_depth[conf_mask], pesudo_gt_depth[conf_mask], reduction='none').mean()
+
+ render_normal = get_normal_map(pred_depth, batch["context"]["intrinsics"].flatten(0, 1))
+ pred_normal = get_normal_map(pesudo_gt_depth, batch["context"]["intrinsics"].flatten(0, 1))
+
+ alpha1_loss = (1 - (render_normal[conf_mask] * pred_normal[conf_mask]).sum(-1)).mean()
+ alpha2_loss = F.l1_loss(render_normal[conf_mask], pred_normal[conf_mask], reduction='mean')
+ loss_normal = (alpha1_loss + alpha2_loss) / 2
+
+ loss_distill = loss_pose * self.weight_pose + loss_depth * self.weight_depth + loss_normal * self.weight_normal
+ loss_distill = torch.nan_to_num(loss_distill, nan=0.0, posinf=0.0, neginf=0.0)
+
+ loss_dict = {
+ "loss_distill": loss_distill,
+ "loss_pose": loss_pose * self.weight_pose,
+ "loss_depth": loss_depth * self.weight_depth,
+ "loss_normal": loss_normal * self.weight_normal
+ }
+
+ return loss_dict
diff --git a/src/loss/loss_huber.py b/src/loss/loss_huber.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb69c3b10e7b89b2edcdd562a17a7a4124de589
--- /dev/null
+++ b/src/loss/loss_huber.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+from copy import copy, deepcopy
+
+from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from src.model.encoder.vggt.utils.rotation import mat_to_quat
+
+def extri_intri_to_pose_encoding(
+ extrinsics,
+ intrinsics,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+):
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+ This function transforms camera parameters into a unified pose encoding format,
+ which can be used for various downstream tasks like pose prediction or representation.
+
+ Args:
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+ where B is batch size and S is sequence length.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+ Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for computing field of view values. For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+ Returns:
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ """
+
+ # extrinsics: BxSx3x4
+ # intrinsics: BxSx3x3
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
+ T = extrinsics[:, :, :3, 3] # BxSx3
+
+ quat = mat_to_quat(R)
+ # Note the order of h and w here
+ # H, W = image_size_hw
+ # fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+ # fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+ fov_h = 2 * torch.atan(0.5 / intrinsics[..., 1, 1])
+ fov_w = 2 * torch.atan(0.5 / intrinsics[..., 0, 0])
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+ else:
+ raise NotImplementedError
+
+ return pose_encoding
+
+def huber_loss(x, y, delta=1.0):
+ """Calculate element-wise Huber loss between x and y"""
+ diff = x - y
+ abs_diff = diff.abs()
+ flag = (abs_diff <= delta).to(diff.dtype)
+ return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
+
+class HuberLoss(nn.Module):
+ def __init__(self, alpha=1.0, delta=1.0, gamma=0.6, weight_T=1.0, weight_R=1.0, weight_fl=0.5):
+ super().__init__()
+ self.alpha = alpha
+ self.delta = delta
+ self.gamma = gamma
+ self.weight_T = weight_T
+ self.weight_R = weight_R
+ self.weight_fl = weight_fl
+
+ def camera_loss_single(self, cur_pred_pose_enc, gt_pose_encoding, loss_type="l1"):
+ if loss_type == "l1":
+ loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).abs()
+ loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).abs()
+ loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).abs()
+ elif loss_type == "l2":
+ loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).norm(dim=-1, keepdim=True)
+ loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).norm(dim=-1)
+ loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).norm(dim=-1)
+ elif loss_type == "huber":
+ loss_T = huber_loss(cur_pred_pose_enc[..., :3], gt_pose_encoding[..., :3])
+ loss_R = huber_loss(cur_pred_pose_enc[..., 3:7], gt_pose_encoding[..., 3:7])
+ loss_fl = huber_loss(cur_pred_pose_enc[..., 7:], gt_pose_encoding[..., 7:])
+ else:
+ raise ValueError(f"Unknown loss type: {loss_type}")
+
+ loss_T = torch.nan_to_num(loss_T, nan=0.0, posinf=0.0, neginf=0.0)
+ loss_R = torch.nan_to_num(loss_R, nan=0.0, posinf=0.0, neginf=0.0)
+ loss_fl = torch.nan_to_num(loss_fl, nan=0.0, posinf=0.0, neginf=0.0)
+
+ loss_T = torch.clamp(loss_T, min=-100, max=100)
+ loss_R = torch.clamp(loss_R, min=-100, max=100)
+ loss_fl = torch.clamp(loss_fl, min=-100, max=100)
+
+ loss_T = loss_T.mean()
+ loss_R = loss_R.mean()
+ loss_fl = loss_fl.mean()
+
+ return loss_T, loss_R, loss_fl
+
+ def forward(self, pred_pose_enc_list, batch):
+ context_extrinsics = batch["context"]["extrinsics"]
+ context_intrinsics = batch["context"]["intrinsics"]
+ image_size_hw = batch["context"]["image"].shape[-2:]
+
+ # transform extrinsics and intrinsics to pose_enc
+ GT_pose_enc = extri_intri_to_pose_encoding(context_extrinsics, context_intrinsics, image_size_hw)
+ num_predictions = len(pred_pose_enc_list)
+ loss_T = loss_R = loss_fl = 0
+
+ for i in range(num_predictions):
+ i_weight = self.gamma ** (num_predictions - i - 1)
+
+ cur_pred_pose_enc = pred_pose_enc_list[i]
+
+ loss_T_i, loss_R_i, loss_fl_i = self.camera_loss_single(cur_pred_pose_enc.clone(), GT_pose_enc.clone(), loss_type="huber")
+ loss_T += i_weight * loss_T_i
+ loss_R += i_weight * loss_R_i
+ loss_fl += i_weight * loss_fl_i
+
+ loss_T = loss_T / num_predictions
+ loss_R = loss_R / num_predictions
+ loss_fl = loss_fl / num_predictions
+ loss_camera = loss_T * self.weight_T + loss_R * self.weight_R + loss_fl * self.weight_fl
+
+ loss_dict = {
+ "loss_camera": loss_camera,
+ "loss_T": loss_T,
+ "loss_R": loss_R,
+ "loss_fl": loss_fl
+ }
+
+ # with torch.no_grad():
+ # # compute auc
+ # last_pred_pose_enc = pred_pose_enc_list[-1]
+
+ # last_pred_extrinsic, _ = pose_encoding_to_extri_intri(last_pred_pose_enc.detach(), image_size_hw, pose_encoding_type='absT_quaR_FoV', build_intrinsics=False)
+
+ # rel_rangle_deg, rel_tangle_deg = camera_to_rel_deg(last_pred_extrinsic.float(), context_extrinsics.float(), context_extrinsics.device)
+
+ # if rel_rangle_deg.numel() == 0 and rel_tangle_deg.numel() == 0:
+ # rel_rangle_deg = torch.FloatTensor([0]).to(context_extrinsics.device).to(context_extrinsics.dtype)
+ # rel_tangle_deg = torch.FloatTensor([0]).to(context_extrinsics.device).to(context_extrinsics.dtype)
+
+ # thresholds = [5, 15]
+ # for threshold in thresholds:
+ # loss_dict[f"Rac_{threshold}"] = (rel_rangle_deg < threshold).float().mean()
+ # loss_dict[f"Tac_{threshold}"] = (rel_tangle_deg < threshold).float().mean()
+
+ # _, normalized_histogram = calculate_auc(
+ # rel_rangle_deg, rel_tangle_deg, max_threshold=30, return_list=True
+ # )
+
+ # auc_thresholds = [30, 10, 5, 3]
+ # for auc_threshold in auc_thresholds:
+ # cur_auc = torch.cumsum(
+ # normalized_histogram[:auc_threshold], dim=0
+ # ).mean()
+ # loss_dict[f"Auc_{auc_threshold}"] = cur_auc
+
+ return loss_dict
+
diff --git a/src/loss/loss_lod.py b/src/loss/loss_lod.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af7af4cd768c9e2b828c2ba206d51c9d50e305e
--- /dev/null
+++ b/src/loss/loss_lod.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from jaxtyping import Float
+from lpips import LPIPS
+from torch import Tensor
+
+from src.dataset.types import BatchedExample
+from src.misc.nn_module_tools import convert_to_buffer
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+
+
+@dataclass
+class LossLODCfg:
+ mse_weight: float
+ lpips_weight: float
+
+@dataclass
+class LossLODCfgWrapper:
+ lod: LossLODCfg
+
+WEIGHT_LEVEL_MAPPING = {0: 0.1, 1: 0.1, 2: 0.2, 3: 0.6}
+
+class LossLOD(Loss[LossLODCfg, LossLODCfgWrapper]):
+ lpips: LPIPS
+
+ def __init__(self, cfg: LossLODCfgWrapper) -> None:
+ super().__init__(cfg)
+
+ self.lpips = LPIPS(net="vgg")
+ convert_to_buffer(self.lpips, persistent=False)
+
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ image = batch["target"]["image"]
+ # breakpoint()
+ def mse_loss(x, y):
+ delta = x - y
+ return torch.nan_to_num((delta**2).mean().mean(), nan=0.0, posinf=0.0, neginf=0.0)
+ # Before the specified step, don't apply the loss.
+ lod_rendering = prediction.lod_rendering
+ loss = 0.0
+ for level in lod_rendering.keys():
+ # level_weight
+ # breakpoint()
+ # if level != 3:
+ # continue
+ rendered_imgs = lod_rendering[level]['rendered_imgs'].flatten(0, 1)
+ _h, _w = rendered_imgs.shape[2:]
+ resized_image = F.interpolate(image.clone().flatten(0, 1), size=(_h, _w), mode='bilinear', align_corners=False)
+ level_mse_loss = mse_loss(rendered_imgs, resized_image)
+ level_lpips_loss = self.lpips.forward(rendered_imgs, resized_image, normalize=True).mean()
+
+ loss += (torch.nan_to_num(level_mse_loss, nan=0.0, posinf=0.0, neginf=0.0) * self.cfg.mse_weight + torch.nan_to_num(level_lpips_loss, nan=0.0, posinf=0.0, neginf=0.0) * self.cfg.lpips_weight)
+ return loss / len(lod_rendering.keys())
diff --git a/src/loss/loss_lpips.py b/src/loss/loss_lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..13411fa94c3b40d9875c1f48e97d69929d070766
--- /dev/null
+++ b/src/loss/loss_lpips.py
@@ -0,0 +1,76 @@
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from jaxtyping import Float
+from lpips import LPIPS
+from torch import Tensor
+
+from src.dataset.types import BatchedExample
+from src.misc.nn_module_tools import convert_to_buffer
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+
+
+@dataclass
+class LossLpipsCfg:
+ weight: float
+ apply_after_step: int
+ conf: bool = False
+ alpha: bool = False
+ mask: bool = False
+
+
+@dataclass
+class LossLpipsCfgWrapper:
+ lpips: LossLpipsCfg
+
+
+class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]):
+ lpips: LPIPS
+
+ def __init__(self, cfg: LossLpipsCfgWrapper) -> None:
+ super().__init__(cfg)
+
+ self.lpips = LPIPS(net="vgg")
+ convert_to_buffer(self.lpips, persistent=False)
+
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ depth_dict: dict | None,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ image = (batch["context"]["image"] + 1) / 2
+
+ # Before the specified step, don't apply the loss.
+ if global_step < self.cfg.apply_after_step:
+ return torch.tensor(0, dtype=torch.float32, device=image.device)
+
+ if self.cfg.mask or self.cfg.alpha or self.cfg.conf:
+ if self.cfg.mask:
+ mask = batch["context"]["valid_mask"]
+ elif self.cfg.alpha:
+ mask = prediction.alpha
+ elif self.cfg.conf:
+ mask = depth_dict['conf_valid_mask']
+ b, v, c, h, w = prediction.color.shape
+ expanded_mask = mask.unsqueeze(2).expand(-1, -1, c, -1, -1)
+ masked_pred = prediction.color * expanded_mask
+ masked_img = image * expanded_mask
+
+ loss = self.lpips.forward(
+ rearrange(masked_pred, "b v c h w -> (b v) c h w"),
+ rearrange(masked_img, "b v c h w -> (b v) c h w"),
+ normalize=True,
+ )
+ else:
+ loss = self.lpips.forward(
+ rearrange(prediction.color, "b v c h w -> (b v) c h w"),
+ rearrange(image, "b v c h w -> (b v) c h w"),
+ normalize=True,
+ )
+ return self.cfg.weight * torch.nan_to_num(loss.mean(), nan=0.0, posinf=0.0, neginf=0.0)
diff --git a/src/loss/loss_mse.py b/src/loss/loss_mse.py
new file mode 100644
index 0000000000000000000000000000000000000000..c72cb8735b0f88ff560085066d3eb7f86152ca5f
--- /dev/null
+++ b/src/loss/loss_mse.py
@@ -0,0 +1,59 @@
+from dataclasses import dataclass
+
+from jaxtyping import Float
+from torch import Tensor
+import torch
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+
+
+@dataclass
+class LossMseCfg:
+ weight: float
+ conf: bool = False
+ mask: bool = False
+ alpha: bool = False
+
+
+@dataclass
+class LossMseCfgWrapper:
+ mse: LossMseCfg
+
+
+class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]):
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ depth_dict: dict | None,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ # Get alpha and valid mask from inputs
+ alpha = prediction.alpha
+ # valid_mask = torch.ones_like(alpha, device=alpha.device).bool()
+ valid_mask = batch['context']['valid_mask']
+
+ # # only for objaverse
+ # if batch['context']['valid_mask'].sum() > 0:
+ # valid_mask = batch['context']['valid_mask']
+
+ # Determine which mask to use based on config
+ if self.cfg.mask:
+ mask = valid_mask
+ elif self.cfg.alpha:
+ mask = alpha
+ elif self.cfg.conf:
+ mask = depth_dict['conf_valid_mask']
+ else:
+ mask = torch.ones_like(alpha, device=alpha.device).bool()
+
+ # Rearrange and mask predicted and ground truth images
+ pred_img = prediction.color.permute(0, 1, 3, 4, 2)[mask]
+ gt_img = ((batch["context"]["image"][:, batch["using_index"]] + 1) / 2).permute(0, 1, 3, 4, 2)[mask]
+
+ delta = pred_img - gt_img
+
+ return self.cfg.weight * torch.nan_to_num((delta**2).mean(), nan=0.0, posinf=0.0, neginf=0.0)
diff --git a/src/loss/loss_normal_consis.py b/src/loss/loss_normal_consis.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b8933bbbc1291fc7fdcc53256078805038f8b34
--- /dev/null
+++ b/src/loss/loss_normal_consis.py
@@ -0,0 +1,134 @@
+from dataclasses import dataclass
+
+import torch
+from einops import reduce
+from jaxtyping import Float
+from torch import Tensor
+
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+from typing import Generic, TypeVar
+from dataclasses import fields
+import torch.nn.functional as F
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# from src.loss.depth_anything.dpt import DepthAnything
+from src.misc.utils import vis_depth_map
+import open3d as o3d
+T_cfg = TypeVar("T_cfg")
+T_wrapper = TypeVar("T_wrapper")
+
+@dataclass
+class LossNormalConsisCfg:
+ normal_weight: float
+ smooth_weight: float
+ sigma_image: float | None
+ use_second_derivative: bool
+ detach: bool = False
+ conf: bool = False
+ not_use_valid_mask: bool = False
+
+@dataclass
+class LossNormalConsisCfgWrapper:
+ normal_consis: LossNormalConsisCfg
+
+class TVLoss(torch.nn.Module):
+ """TV loss"""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, pred):
+ """
+ Args:
+ pred: [batch, H, W, 3]
+
+ Returns:
+ tv_loss: [batch]
+ """
+ h_diff = pred[..., :, :-1, :] - pred[..., :, 1:, :]
+ w_diff = pred[..., :-1, :, :] - pred[..., 1:, :, :]
+ return torch.mean(torch.abs(h_diff)) + torch.mean(torch.abs(w_diff))
+
+
+class LossNormalConsis(Loss[LossNormalConsisCfg, LossNormalConsisCfgWrapper]):
+ def __init__(self, cfg: T_wrapper) -> None:
+ super().__init__(cfg)
+
+ # Extract the configuration from the wrapper.
+ (field,) = fields(type(cfg))
+ self.cfg = getattr(cfg, field.name)
+ self.name = field.name
+
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ depth_dict: dict,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ # Scale the depth between the near and far planes.
+ conf_valid_mask = depth_dict['conf_valid_mask'].flatten(0, 1)
+ valid_mask = batch["context"]["valid_mask"][:, batch["using_index"]].flatten(0, 1)
+ if self.cfg.conf:
+ valid_mask = valid_mask & conf_valid_mask
+ if self.cfg.not_use_valid_mask:
+ valid_mask = torch.ones_like(valid_mask, device=valid_mask.device)
+ render_normal = self.get_normal_map(prediction.depth.flatten(0, 1), batch["context"]["intrinsics"].flatten(0, 1))
+ pred_normal = self.get_normal_map(depth_dict['depth'].flatten(0, 1).squeeze(-1), batch["context"]["intrinsics"].flatten(0, 1))
+ if self.cfg.detach:
+ pred_normal = pred_normal.detach()
+ alpha1_loss = (1 - (render_normal * pred_normal).sum(-1)).mean()
+ alpha2_loss = F.l1_loss(render_normal, pred_normal, reduction='mean')
+ normal_smooth_loss = TVLoss()(render_normal)
+ normal_loss = (alpha1_loss + alpha2_loss) / 2
+ return self.cfg.normal_weight * torch.nan_to_num(normal_loss, nan=0.0) + self.cfg.smooth_weight * torch.nan_to_num(normal_smooth_loss, nan=0.0)
+
+ def get_normal_map(self, depth_map: torch.Tensor, intrinsic: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (torch.Tensor): Depth map of shape (H, W).
+ intrinsic (torch.Tensor): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Camera coordinates (H, W, 3)
+ """
+ B, H, W = depth_map.shape
+ assert intrinsic.shape == (B, 3, 3), "Intrinsic matrix must be Bx3x3"
+ assert (intrinsic[:, 0, 1] == 0).all() and (intrinsic[:, 1, 0] == 0).all(), "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu = intrinsic[:, 0, 0] * W # (B,)
+ fv = intrinsic[:, 1, 1] * H # (B,)
+ cu = intrinsic[:, 0, 2] * W # (B,)
+ cv = intrinsic[:, 1, 2] * H # (B,)
+
+ # Generate grid of pixel coordinates
+ u = torch.arange(W, device=depth_map.device)[None, None, :].expand(B, H, W)
+ v = torch.arange(H, device=depth_map.device)[None, :, None].expand(B, H, W)
+
+ # Unproject to camera coordinates (B, H, W)
+ x_cam = (u - cu[:, None, None]) * depth_map / fu[:, None, None]
+ y_cam = (v - cv[:, None, None]) * depth_map / fv[:, None, None]
+ z_cam = depth_map
+
+ # Stack to form camera coordinates (B, H, W, 3)
+ cam_coords = torch.stack((x_cam, y_cam, z_cam), dim=-1).to(dtype=torch.float32)
+
+ output = torch.zeros_like(cam_coords)
+ # Calculate dx using batch dimension (B, H-2, W-2, 3)
+ dx = cam_coords[:, 2:, 1:-1] - cam_coords[:, :-2, 1:-1]
+ # Calculate dy using batch dimension (B, H-2, W-2, 3)
+ dy = cam_coords[:, 1:-1, 2:] - cam_coords[:, 1:-1, :-2]
+ # Cross product and normalization (B, H-2, W-2, 3)
+ normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
+ # Assign the computed normal map to the output tensor
+ output[:, 1:-1, 1:-1, :] = normal_map
+
+ return output
\ No newline at end of file
diff --git a/src/loss/loss_opacity.py b/src/loss/loss_opacity.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e3aa78a5091b03784d803e1ca2fea986721e5a5
--- /dev/null
+++ b/src/loss/loss_opacity.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import Float
+from torch import Tensor
+import torch
+import torch.nn.functional as F
+from src.dataset.types import BatchedExample
+from src.model.decoder.decoder import DecoderOutput
+from src.model.types import Gaussians
+from .loss import Loss
+
+
+@dataclass
+class LossOpacityCfg:
+ weight: float
+ type: Literal["exp", "mean", "exp+mean"] = "exp+mean"
+
+
+@dataclass
+class LossOpacityCfgWrapper:
+ opacity: LossOpacityCfg
+
+
+class LossOpacity(Loss[LossOpacityCfg, LossOpacityCfgWrapper]):
+ def forward(
+ self,
+ prediction: DecoderOutput,
+ batch: BatchedExample,
+ gaussians: Gaussians,
+ depth_dict: dict | None,
+ global_step: int,
+ ) -> Float[Tensor, ""]:
+ alpha = prediction.alpha
+ valid_mask = batch['context']['valid_mask'].float()
+ opacity_loss = F.mse_loss(alpha, valid_mask, reduction='none').mean()
+ # if self.cfg.type == "exp":
+ # opacity_loss = torch.exp(-(gaussians.opacities - 0.5) ** 2 / 0.05).mean()
+ # elif self.cfg.type == "mean":
+ # opacity_loss = gaussians.opacities.mean()
+ # elif self.cfg.type == "exp+mean":
+ # opacity_loss = 0.5 * torch.exp(-(gaussians.opacities - 0.5) ** 2 / 0.05).mean() + gaussians.opacities.mean()
+ return self.cfg.weight * torch.nan_to_num(opacity_loss, nan=0.0, posinf=0.0, neginf=0.0)
diff --git a/src/loss/loss_point.py b/src/loss/loss_point.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c6792c86913381c7197e39733299ebd9205078
--- /dev/null
+++ b/src/loss/loss_point.py
@@ -0,0 +1,358 @@
+import torch
+import torch.nn as nn
+from copy import copy, deepcopy
+
+from src.geometry.ptc_geometry import geotrf, inv, normalize_pointcloud, depthmap_to_pts3d
+# from torchmetrics.functional.regression import pearson_corrcoef
+# from pytorch3d.loss import chamfer_distance
+
+
+def get_pred_pts3d(gt, pred, use_pose=False):
+ if 'depth' in pred and 'pseudo_focal' in pred:
+ try:
+ pp = gt['camera_intrinsics'][..., :2, 2]
+ except KeyError:
+ pp = None
+ pts3d = depthmap_to_pts3d(**pred, pp=pp)
+
+ elif 'pts3d' in pred:
+ # pts3d from my camera
+ pts3d = pred['pts3d']
+
+ elif 'pts3d_in_other_view' in pred:
+ # pts3d from the other camera, already transformed
+ assert use_pose is True
+ return pred['pts3d_in_other_view'] # return!
+
+ if use_pose:
+ camera_pose = pred.get('camera_pose')
+ assert camera_pose is not None
+ pts3d = geotrf(camera_pose, pts3d)
+
+ return pts3d
+
+
+class LLoss (nn.Module):
+ """ L-norm loss
+ """
+
+ def __init__(self, reduction='mean'):
+ super().__init__()
+ self.reduction = reduction
+
+ def forward(self, a, b):
+ assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
+ dist = self.distance(a, b)
+ assert dist.ndim == a.ndim-1 # one dimension less
+ if self.reduction == 'none':
+ return dist
+ if self.reduction == 'sum':
+ return dist.sum()
+ if self.reduction == 'mean':
+ return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
+ raise ValueError(f'bad {self.reduction=} mode')
+
+ def distance(self, a, b):
+ raise NotImplementedError()
+
+
+class L21Loss (LLoss):
+ """ Euclidean distance between 3d points """
+
+ def distance(self, a, b):
+ return torch.norm(a - b, dim=-1) # normalized L2 distance
+
+
+class MultiLoss (nn.Module):
+ """ Easily combinable losses (also keep track of individual loss values):
+ loss = MyLoss1() + 0.1*MyLoss2()
+ Usage:
+ Inherit from this class and override get_name() and compute_loss()
+ """
+
+ def __init__(self):
+ super().__init__()
+ self._alpha = 1
+ self._loss2 = None
+
+ def compute_loss(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def get_name(self):
+ raise NotImplementedError()
+
+ def __mul__(self, alpha):
+ assert isinstance(alpha, (int, float))
+ res = copy(self)
+ res._alpha = alpha
+ return res
+ __rmul__ = __mul__ # same
+
+ def __add__(self, loss2):
+ assert isinstance(loss2, MultiLoss)
+ res = cur = copy(self)
+ # find the end of the chain
+ while cur._loss2 is not None:
+ cur = cur._loss2
+ cur._loss2 = loss2
+ return res
+
+ def __repr__(self):
+ name = self.get_name()
+ if self._alpha != 1:
+ name = f'{self._alpha:g}*{name}'
+ if self._loss2:
+ name = f'{name} + {self._loss2}'
+ return name
+
+ def forward(self, *args, **kwargs):
+ loss = self.compute_loss(*args, **kwargs)
+ if isinstance(loss, tuple):
+ loss, details = loss
+ elif loss.ndim == 0:
+ details = {self.get_name(): float(loss)}
+ else:
+ details = {}
+ loss = loss * self._alpha
+
+ if self._loss2:
+ loss2, details2 = self._loss2(*args, **kwargs)
+ loss = loss + loss2
+ details |= details2
+
+ return loss, details
+
+
+class Criterion (nn.Module):
+ def __init__(self, criterion=None):
+ super().__init__()
+ assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'+bb()
+ self.criterion = copy(criterion)
+
+ def get_name(self):
+ return f'{type(self).__name__}({self.criterion})'
+
+ def with_reduction(self, mode):
+ res = loss = deepcopy(self)
+ while loss is not None:
+ assert isinstance(loss, Criterion)
+ loss.criterion.reduction = 'none' # make it return the loss for each sample
+ loss = loss._loss2 # we assume loss is a Multiloss
+ return res
+
+
+class ConfLoss (MultiLoss):
+ """ Weighted regression by learned confidence.
+ Assuming the input pixel_loss is a pixel-level regression loss.
+
+ Principle:
+ high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
+ low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
+
+ alpha: hyperparameter
+ """
+
+ def __init__(self, pixel_loss, alpha=1):
+ super().__init__()
+ assert alpha > 0
+ self.alpha = alpha
+ self.pixel_loss = pixel_loss.with_reduction('none')
+
+ def get_name(self):
+ return f'ConfLoss({self.pixel_loss})'
+
+ def get_conf_log(self, x):
+ return x, torch.log(x)
+
+ def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
+ # compute per-pixel loss
+ ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
+ if loss1.numel() == 0:
+ print('NO VALID POINTS in img1', force=True)
+ if loss2.numel() == 0:
+ print('NO VALID POINTS in img2', force=True)
+
+ # weight by confidence
+ conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
+ conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
+ conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
+ conf_loss2 = loss2 * conf2 - self.alpha * log_conf2
+
+ # average + nan protection (in case of no valid pixels at all)
+ conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
+ conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0
+
+ return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)
+
+
+class Regr3D(nn.Module):
+ """ Ensure that all 3D points are correct.
+ Asymmetric loss: view1 is supposed to be the anchor.
+
+ P1 = RT1 @ D1
+ P2 = RT2 @ D2
+ loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
+ loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
+ = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
+ """
+
+ def __init__(self, norm_mode='avg_dis', alpha=0.2, gt_scale=False):
+ super().__init__()
+ self.norm_mode = norm_mode
+ self.alpha = alpha
+ self.gt_scale = gt_scale
+
+ def get_conf_log(self, x):
+ return x, torch.log(x)
+
+ def forward(self, gt_pts1, gt_pts2, pr_pts1, pr_pts2, conf1=None, conf2=None, dist_clip=None, disable_view1=False):
+ valid1 = valid2 = torch.ones_like(conf1, dtype=torch.bool)
+ if dist_clip is not None:
+ # points that are too far-away == invalid
+ dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
+ dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
+ valid1 = (dis1 <= dist_clip)
+ valid2 = (dis2 <= dist_clip)
+ else:
+ dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
+ dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
+
+ # only keep the points norm whithin the range of 1% to 99% of each batch
+ # Flatten along the H and W dimensions
+ dis1_flat = dis1.view(dis1.shape[0], -1)
+ dis2_flat = dis2.view(dis2.shape[0], -1)
+
+ # Compute the 0.1% and 99.9% quantiles for each batch
+ # quantiles_1 = torch.quantile(dis1_flat, torch.tensor([0.01, 0.99]).to(dis1_flat.device), dim=1)
+ # quantiles_2 = torch.quantile(dis2_flat, torch.tensor([0.01, 0.99]).to(dis2_flat.device), dim=1)
+ quantiles_1 = torch.quantile(dis1_flat, torch.tensor([0.002, 0.998]).to(dis1_flat.device), dim=1)
+ quantiles_2 = torch.quantile(dis2_flat, torch.tensor([0.002, 0.998]).to(dis2_flat.device), dim=1)
+
+ # Create masks based on the quantiles
+ valid1 = (dis1 >= quantiles_1[0].view(-1, 1, 1)) & (dis1 <= quantiles_1[1].view(-1, 1, 1))
+ valid2 = (dis2 >= quantiles_2[0].view(-1, 1, 1)) & (dis2 <= quantiles_2[1].view(-1, 1, 1))
+
+ # set min confidence to 3
+ valid1 = valid1 & (conf1 >= 3)
+ valid2 = valid2 & (conf2 >= 3)
+
+ # normalize 3d points
+ if self.norm_mode:
+ pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
+ if self.norm_mode and not self.gt_scale:
+ gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)
+
+ loss1 = torch.norm(pr_pts1 - gt_pts1, dim=-1)
+ loss2 = torch.norm(pr_pts2 - gt_pts2, dim=-1)
+ # loss1 = (pr_pts1[..., -1] - gt_pts1[..., -1]).abs()
+ # loss2 = (pr_pts2[..., -1] - gt_pts2[..., -1]).abs()
+
+ loss1, loss2 = loss1[valid1], loss2[valid2]
+
+ if disable_view1:
+ return loss2.mean()
+ return loss1.mean() + loss2.mean()
+
+ # conf1, conf2 = conf1[valid1], conf2[valid2]
+ # conf1, conf2 = conf1.softmax(dim=-1), conf2.softmax(dim=-1)
+ # loss1 = (loss1 * conf1).sum()
+ # loss2 = (loss2 * conf2).sum()
+ # return loss1 + loss2
+ #
+ # # weight by confidence
+ # conf1, log_conf1 = self.get_conf_log(conf1[valid1])
+ # conf2, log_conf2 = self.get_conf_log(conf2[valid2])
+ # conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
+ # conf_loss2 = loss2 * conf2 - self.alpha * log_conf2
+ #
+ # # average + nan protection (in case of no valid pixels at all)
+ # conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
+ # conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0
+ #
+ # return conf_loss1 + conf_loss2
+
+ # def forward(self, gt_pts1, gt_pts2, pr_pts1, pr_pts2, conf1=None, conf2=None, dist_clip=None, disable_view1=False):
+ # # valid1 = valid2 = torch.ones_like(conf1, dtype=torch.bool)
+ # if dist_clip is not None:
+ # # points that are too far-away == invalid
+ # dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
+ # dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
+ # valid1 = (dis1 <= dist_clip)
+ # valid2 = (dis2 <= dist_clip)
+ # else:
+ # dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
+ # dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
+ #
+ # # only keep the points norm whithin the range of 1% to 99% of each batch
+ # # Flatten along the H and W dimensions
+ # dis1_flat = dis1.view(dis1.shape[0], -1)
+ # dis2_flat = dis2.view(dis2.shape[0], -1)
+ #
+ # # Compute the 0.1% and 99.9% quantiles for each batch
+ # quantiles_1 = torch.quantile(dis1_flat, torch.tensor([0.1, 0.9]).to(dis1_flat.device), dim=1)
+ # quantiles_2 = torch.quantile(dis2_flat, torch.tensor([0.1, 0.9]).to(dis2_flat.device), dim=1)
+ # # quantiles_1 = torch.quantile(dis1_flat, torch.tensor([0.002, 0.998]).to(dis1_flat.device), dim=1)
+ # # quantiles_2 = torch.quantile(dis2_flat, torch.tensor([0.002, 0.998]).to(dis2_flat.device), dim=1)
+ #
+ # # Create masks based on the quantiles
+ # valid1 = (dis1 >= quantiles_1[0].view(-1, 1, 1)) & (dis1 <= quantiles_1[1].view(-1, 1, 1))
+ # valid2 = (dis2 >= quantiles_2[0].view(-1, 1, 1)) & (dis2 <= quantiles_2[1].view(-1, 1, 1))
+ #
+ # # set min opacity to 3
+ # valid1 = valid1 & (conf1 >= 0.2)
+ # valid2 = valid2 & (conf2 >= 0.2)
+ #
+ # # # normalize 3d points
+ # # if self.norm_mode:
+ # # pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
+ # # if self.norm_mode and not self.gt_scale:
+ # # gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)
+ #
+ # # L1 loss
+ # # loss1 = (pr_pts1[..., -1] - gt_pts1[..., -1]).abs()
+ # # loss2 = (pr_pts2[..., -1] - gt_pts2[..., -1]).abs()
+ #
+ # # L2 loss
+ # loss1 = torch.norm(pr_pts1 - gt_pts1, dim=-1)
+ # loss2 = torch.norm(pr_pts2 - gt_pts2, dim=-1)
+ # loss1, loss2 = loss1[valid1], loss2[valid2]
+ #
+ # # Pearson correlation coefficient loss
+ # # pr_pts1, pr_pts2 = pr_pts1[valid1], pr_pts2[valid2]
+ # # gt_pts1, gt_pts2 = gt_pts1[valid1], gt_pts2[valid2]
+ # # loss1 = 1 - pearson_corrcoef(pr_pts1.view(-1, 3), gt_pts1.view(-1, 3))
+ # # loss2 = 1 - pearson_corrcoef(pr_pts2.view(-1, 3), gt_pts2.view(-1, 3))
+ #
+ # # # Chamfer distance loss
+ # # pr_pts = torch.cat([pr_pts1.flatten(1, 2), pr_pts2.flatten(1, 2)], dim=1)
+ # # gt_pts = torch.cat([gt_pts1.flatten(1, 2), gt_pts2.flatten(1, 2)], dim=1)
+ # # valid_mask = torch.cat([valid1.flatten(1, 2), valid2.flatten(1, 2)], dim=1)
+ # # nan_pts_pr, nnz = invalid_to_zeros(pr_pts, valid_mask, ndim=3)
+ # # nan_pts_gt, nnz = invalid_to_zeros(gt_pts, valid_mask, ndim=3)
+ # #
+ # # loss, _ = chamfer_distance(nan_pts_pr, nan_pts_gt, batch_reduction=None, point_reduction=None)
+ # # loss1, loss2 = loss[0], loss[1]
+ # # return loss1.sum() / valid_mask.sum()
+ #
+ # if disable_view1:
+ # return loss2.mean()
+ # return loss1.mean() + loss2.mean()
+ #
+ # # conf1, conf2 = conf1[valid1], conf2[valid2]
+ # # conf1, conf2 = conf1.softmax(dim=-1), conf2.softmax(dim=-1)
+ # # loss1 = (loss1 * conf1).sum()
+ # # loss2 = (loss2 * conf2).sum()
+ # # return loss1 + loss2
+ # #
+ # # # weight by confidence
+ # # conf1, log_conf1 = self.get_conf_log(conf1[valid1])
+ # # conf2, log_conf2 = self.get_conf_log(conf2[valid2])
+ # # conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
+ # # conf_loss2 = loss2 * conf2 - self.alpha * log_conf2
+ # #
+ # # # average + nan protection (in case of no valid pixels at all)
+ # # conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
+ # # conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0
+ # #
+ # # return conf_loss1 + conf_loss2
+
diff --git a/src/loss/loss_ssim.py b/src/loss/loss_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f9e475b346a5d69bc9b2aeca56ae4a783c8747f
--- /dev/null
+++ b/src/loss/loss_ssim.py
@@ -0,0 +1,357 @@
+# Copyright 2020 by Gongfan Fang, Zhejiang University.
+# All rights reserved.
+# Modified by Botao Ye from https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py.
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor:
+ r"""Create 1-D gauss kernel
+ Args:
+ size (int): the size of gauss kernel
+ sigma (float): sigma of normal distribution
+ Returns:
+ torch.Tensor: 1D kernel (1 x 1 x size)
+ """
+ coords = torch.arange(size, dtype=torch.float)
+ coords -= size // 2
+
+ g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
+ g /= g.sum()
+
+ return g.unsqueeze(0).unsqueeze(0)
+
+
+def gaussian_filter(input: Tensor, win: Tensor) -> Tensor:
+ r""" Blur input with 1-D kernel
+ Args:
+ input (torch.Tensor): a batch of tensors to be blurred
+ window (torch.Tensor): 1-D gauss kernel
+ Returns:
+ torch.Tensor: blurred tensors
+ """
+ assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
+ if len(input.shape) == 4:
+ conv = F.conv2d
+ elif len(input.shape) == 5:
+ conv = F.conv3d
+ else:
+ raise NotImplementedError(input.shape)
+
+ C = input.shape[1]
+ out = input
+ for i, s in enumerate(input.shape[2:]):
+ if s >= win.shape[-1]:
+ out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
+ else:
+ warnings.warn(
+ f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
+ )
+
+ return out
+
+
+def _ssim(
+ X: Tensor,
+ Y: Tensor,
+ data_range: float,
+ win: Tensor,
+ size_average: bool = True,
+ K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+ retrun_seprate: bool = False,
+) -> Tuple[Tensor, Tensor, Tensor | None, Tensor | None, Tensor | None]:
+ r""" Calculate ssim index for X and Y
+
+ Args:
+ X (torch.Tensor): images
+ Y (torch.Tensor): images
+ data_range (float or int): value range of input images. (usually 1.0 or 255)
+ win (torch.Tensor): 1-D gauss kernel
+ size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
+ retrun_seprate (bool, optional): if True, return brightness, contrast, and structure similarity maps as well
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: ssim results.
+ """
+ K1, K2 = K
+ # batch, channel, [depth,] height, width = X.shape
+ compensation = 1.0
+
+ C1 = (K1 * data_range) ** 2
+ C2 = (K2 * data_range) ** 2
+
+ win = win.to(X.device, dtype=X.dtype)
+
+ mu1 = gaussian_filter(X, win)
+ mu2 = gaussian_filter(Y, win)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
+ sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
+ sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
+
+ cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1
+ ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
+ ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
+ cs = torch.flatten(cs_map, 2).mean(-1)
+
+ brightness = contrast = structure = torch.zeros_like(ssim_per_channel)
+ if retrun_seprate:
+ epsilon = torch.finfo(torch.float32).eps**2
+ sigma1_sq = sigma1_sq.clamp(min=epsilon)
+ sigma2_sq = sigma2_sq.clamp(min=epsilon)
+ sigma12 = torch.sign(sigma12) * torch.minimum(
+ torch.sqrt(sigma1_sq * sigma2_sq), torch.abs(sigma12))
+
+ C3 = C2 / 2
+ sigma1_sigma2 = torch.sqrt(sigma1_sq) * torch.sqrt(sigma2_sq)
+ brightness_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
+ contrast_map = (2 * sigma1_sigma2 + C2) / (sigma1_sq + sigma2_sq + C2)
+ structure_map = (sigma12 + C3) / (sigma1_sigma2 + C3)
+
+ contrast_map = contrast_map.clamp(max=0.98)
+ structure_map = structure_map.clamp(max=0.98)
+
+ brightness = brightness_map.flatten(2).mean(-1)
+ contrast = contrast_map.flatten(2).mean(-1)
+ structure = structure_map.flatten(2).mean(-1)
+
+ return ssim_per_channel, cs, brightness, contrast, structure
+
+
+def ssim(
+ X: Tensor,
+ Y: Tensor,
+ data_range: float = 255,
+ size_average: bool = True,
+ win_size: int = 11,
+ win_sigma: float = 1.5,
+ win: Optional[Tensor] = None,
+ K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+ nonnegative_ssim: bool = False,
+ retrun_seprate: bool = False,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ r""" interface of ssim
+ Args:
+ X (torch.Tensor): a batch of images, (N,C,H,W)
+ Y (torch.Tensor): a batch of images, (N,C,H,W)
+ data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
+ size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
+ win_size: (int, optional): the size of gauss kernel
+ win_sigma: (float, optional): sigma of normal distribution
+ win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
+ K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+ nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
+ retrun_seprate (bool, optional): if True, return brightness, contrast, and structure similarity maps as well
+
+ Returns:
+ torch.Tensor: ssim results
+ """
+ if not X.shape == Y.shape:
+ raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
+
+ for d in range(len(X.shape) - 1, 1, -1):
+ X = X.squeeze(dim=d)
+ Y = Y.squeeze(dim=d)
+
+ if len(X.shape) not in (4, 5):
+ raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
+
+ #if not X.type() == Y.type():
+ # raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
+
+ if win is not None: # set win_size
+ win_size = win.shape[-1]
+
+ if not (win_size % 2 == 1):
+ raise ValueError("Window size should be odd.")
+
+ if win is None:
+ win = _fspecial_gauss_1d(win_size, win_sigma)
+ win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
+
+ ssim_per_channel, cs, brightness, contrast, structure \
+ = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K, retrun_seprate=retrun_seprate)
+
+ if nonnegative_ssim:
+ ssim_per_channel = torch.relu(ssim_per_channel)
+
+ if size_average:
+ return ssim_per_channel.mean(), brightness.mean(), contrast.mean(), structure.mean()
+ else:
+ return ssim_per_channel.mean(1), brightness.mean(1), contrast.mean(1), structure.mean(1)
+
+
+def ms_ssim(
+ X: Tensor,
+ Y: Tensor,
+ data_range: float = 255,
+ size_average: bool = True,
+ win_size: int = 11,
+ win_sigma: float = 1.5,
+ win: Optional[Tensor] = None,
+ weights: Optional[List[float]] = None,
+ K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
+) -> Tensor:
+ r""" interface of ms-ssim
+ Args:
+ X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
+ Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
+ data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
+ size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
+ win_size: (int, optional): the size of gauss kernel
+ win_sigma: (float, optional): sigma of normal distribution
+ win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
+ weights (list, optional): weights for different levels
+ K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+ Returns:
+ torch.Tensor: ms-ssim results
+ """
+ if not X.shape == Y.shape:
+ raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
+
+ for d in range(len(X.shape) - 1, 1, -1):
+ X = X.squeeze(dim=d)
+ Y = Y.squeeze(dim=d)
+
+ #if not X.type() == Y.type():
+ # raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
+
+ if len(X.shape) == 4:
+ avg_pool = F.avg_pool2d
+ elif len(X.shape) == 5:
+ avg_pool = F.avg_pool3d
+ else:
+ raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
+
+ if win is not None: # set win_size
+ win_size = win.shape[-1]
+
+ if not (win_size % 2 == 1):
+ raise ValueError("Window size should be odd.")
+
+ smaller_side = min(X.shape[-2:])
+ assert smaller_side > (win_size - 1) * (
+ 2 ** 4
+ ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))
+
+ if weights is None:
+ weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
+ weights_tensor = X.new_tensor(weights)
+
+ if win is None:
+ win = _fspecial_gauss_1d(win_size, win_sigma)
+ win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
+
+ levels = weights_tensor.shape[0]
+ mcs = []
+ for i in range(levels):
+ ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)
+
+ if i < levels - 1:
+ mcs.append(torch.relu(cs))
+ padding = [s % 2 for s in X.shape[2:]]
+ X = avg_pool(X, kernel_size=2, padding=padding)
+ Y = avg_pool(Y, kernel_size=2, padding=padding)
+
+ ssim_per_channel = torch.relu(ssim_per_channel) # type: ignore # (batch, channel)
+ mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0) # (level, batch, channel)
+ ms_ssim_val = torch.prod(mcs_and_ssim ** weights_tensor.view(-1, 1, 1), dim=0)
+
+ if size_average:
+ return ms_ssim_val.mean()
+ else:
+ return ms_ssim_val.mean(1)
+
+
+class SSIM(torch.nn.Module):
+ def __init__(
+ self,
+ data_range: float = 255,
+ size_average: bool = True,
+ win_size: int = 11,
+ win_sigma: float = 1.5,
+ channel: int = 3,
+ spatial_dims: int = 2,
+ K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+ nonnegative_ssim: bool = False,
+ ) -> None:
+ r""" class for ssim
+ Args:
+ data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
+ size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
+ win_size: (int, optional): the size of gauss kernel
+ win_sigma: (float, optional): sigma of normal distribution
+ channel (int, optional): input channels (default: 3)
+ K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+ nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
+ """
+
+ super(SSIM, self).__init__()
+ self.win_size = win_size
+ self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
+ self.size_average = size_average
+ self.data_range = data_range
+ self.K = K
+ self.nonnegative_ssim = nonnegative_ssim
+
+ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
+ return ssim(
+ X,
+ Y,
+ data_range=self.data_range,
+ size_average=self.size_average,
+ win=self.win,
+ K=self.K,
+ nonnegative_ssim=self.nonnegative_ssim,
+ )
+
+
+class MS_SSIM(torch.nn.Module):
+ def __init__(
+ self,
+ data_range: float = 255,
+ size_average: bool = True,
+ win_size: int = 11,
+ win_sigma: float = 1.5,
+ channel: int = 3,
+ spatial_dims: int = 2,
+ weights: Optional[List[float]] = None,
+ K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+ ) -> None:
+ r""" class for ms-ssim
+ Args:
+ data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
+ size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
+ win_size: (int, optional): the size of gauss kernel
+ win_sigma: (float, optional): sigma of normal distribution
+ channel (int, optional): input channels (default: 3)
+ weights (list, optional): weights for different levels
+ K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+ """
+
+ super(MS_SSIM, self).__init__()
+ self.win_size = win_size
+ self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
+ self.size_average = size_average
+ self.data_range = data_range
+ self.weights = weights
+ self.K = K
+
+ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
+ return ms_ssim(
+ X,
+ Y,
+ data_range=self.data_range,
+ size_average=self.size_average,
+ win=self.win,
+ weights=self.weights,
+ K=self.K,
+ )
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..975fe992020ca5dd379a18c34f3f016bdc802625
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,158 @@
+import os
+from pathlib import Path
+
+import hydra
+import torch
+import wandb
+import random
+from colorama import Fore
+from jaxtyping import install_import_hook
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.plugins.environments import SLURMEnvironment
+from lightning.pytorch.strategies import DeepSpeedStrategy
+from omegaconf import DictConfig, OmegaConf
+from hydra.core.hydra_config import HydraConfig
+
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from src.model.model import get_model
+from src.misc.weight_modify import checkpoint_filter_fn
+
+import warnings
+warnings.filterwarnings("ignore")
+
+# Configure beartype and jaxtyping.
+with install_import_hook(
+ ("src",),
+ ("beartype", "beartype"),
+):
+ from src.config import load_typed_root_config
+ from src.dataset.data_module import DataModule
+ from src.global_cfg import set_cfg
+ from src.loss import get_losses
+ from src.misc.LocalLogger import LocalLogger
+ from src.misc.step_tracker import StepTracker
+ from src.misc.wandb_tools import update_checkpoint_path
+ from src.model.decoder import get_decoder
+ from src.model.encoder import get_encoder
+ from src.model.model_wrapper import ModelWrapper
+
+
+def cyan(text: str) -> str:
+ return f"{Fore.CYAN}{text}{Fore.RESET}"
+
+
+@hydra.main(
+ version_base=None,
+ config_path="../config",
+ config_name="main",
+)
+def train(cfg_dict: DictConfig):
+ cfg = load_typed_root_config(cfg_dict)
+ set_cfg(cfg_dict)
+
+ # Set up the output directory.
+ output_dir = Path(
+ hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]
+ )
+ output_dir.mkdir(parents=True, exist_ok=True)
+ print(cyan(f"Saving outputs to {output_dir}."))
+
+ cfg.train.output_path = output_dir
+
+ # Set up logging with wandb.
+ callbacks = []
+ if cfg_dict.wandb.mode != "disabled":
+ logger = WandbLogger(
+ project=cfg_dict.wandb.project,
+ mode=cfg_dict.wandb.mode,
+ name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})",
+ tags=cfg_dict.wandb.get("tags", None),
+ log_model=False,
+ save_dir=output_dir,
+ config=OmegaConf.to_container(cfg_dict),
+ )
+ callbacks.append(LearningRateMonitor("step", True))
+
+ # On rank != 0, wandb.run is None.
+ if wandb.run is not None:
+ wandb.run.log_code("src")
+ else:
+ logger = LocalLogger()
+
+ # Set up checkpointing.
+ callbacks.append(
+ ModelCheckpoint(
+ output_dir / "checkpoints",
+ every_n_train_steps=cfg.checkpointing.every_n_train_steps,
+ save_top_k=cfg.checkpointing.save_top_k,
+ save_weights_only=cfg.checkpointing.save_weights_only,
+ monitor="info/global_step",
+ mode="max",
+ )
+ )
+ callbacks[-1].CHECKPOINT_EQUALS_CHAR = '_'
+
+ # Prepare the checkpoint for loading.
+ checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)
+
+ # This allows the current step to be shared with the data loader processes.
+ step_tracker = StepTracker()
+
+ trainer = Trainer(
+ max_epochs=-1,
+ num_nodes=cfg.trainer.num_nodes,
+ # num_sanity_val_steps=0,
+ accelerator="gpu",
+ logger=logger,
+ devices="auto",
+ strategy=(
+ "ddp_find_unused_parameters_true"
+ if torch.cuda.device_count() > 1
+ else "auto"
+ ),
+ # strategy="deepspeed_stage_1",
+ callbacks=callbacks,
+ val_check_interval=cfg.trainer.val_check_interval,
+ check_val_every_n_epoch=None,
+ enable_progress_bar=False,
+ gradient_clip_val=cfg.trainer.gradient_clip_val,
+ max_steps=cfg.trainer.max_steps,
+ precision=cfg.trainer.precision,
+ accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
+ # plugins=[SLURMEnvironment(requeue_signal=signal.SIGUSR1)], # Uncomment for SLURM auto resubmission.
+ inference_mode=False if (cfg.mode == "test" and cfg.test.align_pose) else True,
+ )
+ torch.manual_seed(cfg_dict.seed + trainer.global_rank)
+
+ model = get_model(cfg.model.encoder, cfg.model.decoder)
+
+ model_wrapper = ModelWrapper(
+ cfg.optimizer,
+ cfg.test,
+ cfg.train,
+ model,
+ get_losses(cfg.loss),
+ step_tracker
+ )
+ data_module = DataModule(
+ cfg.dataset,
+ cfg.data_loader,
+ step_tracker,
+ global_rank=trainer.global_rank,
+ )
+
+ if cfg.mode == "train":
+ trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
+ else:
+ trainer.test(
+ model_wrapper,
+ datamodule=data_module,
+ ckpt_path=checkpoint_path,
+ )
+
+
+if __name__ == "__main__":
+ train()
diff --git a/src/misc/LocalLogger.py b/src/misc/LocalLogger.py
new file mode 100644
index 0000000000000000000000000000000000000000..8761a7aa04c24f10a9ef40e66e7d5cbb0a2591b8
--- /dev/null
+++ b/src/misc/LocalLogger.py
@@ -0,0 +1,48 @@
+import os
+from pathlib import Path
+from typing import Any, Optional
+
+from lightning.pytorch.loggers.logger import Logger
+from lightning.pytorch.utilities import rank_zero_only
+from PIL import Image
+
+LOG_PATH = Path("outputs/local")
+
+
+class LocalLogger(Logger):
+ def __init__(self) -> None:
+ super().__init__()
+ self.experiment = None
+ os.system(f"rm -r {LOG_PATH}")
+
+ @property
+ def name(self):
+ return "LocalLogger"
+
+ @property
+ def version(self):
+ return 0
+
+ @rank_zero_only
+ def log_hyperparams(self, params):
+ pass
+
+ @rank_zero_only
+ def log_metrics(self, metrics, step):
+ pass
+
+ @rank_zero_only
+ def log_image(
+ self,
+ key: str,
+ images: list[Any],
+ step: Optional[int] = None,
+ **kwargs,
+ ):
+ # The function signature is the same as the wandb logger's, but the step is
+ # actually required.
+ assert step is not None
+ for index, image in enumerate(images):
+ path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.jpg"
+ path.parent.mkdir(exist_ok=True, parents=True)
+ Image.fromarray(image).save(path)
diff --git a/src/misc/benchmarker.py b/src/misc/benchmarker.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cad5bcd413d08de482b58046187e4f887f3a970
--- /dev/null
+++ b/src/misc/benchmarker.py
@@ -0,0 +1,37 @@
+import json
+from collections import defaultdict
+from contextlib import contextmanager
+from pathlib import Path
+from time import time
+
+import numpy as np
+import torch
+
+
+class Benchmarker:
+ def __init__(self):
+ self.execution_times = defaultdict(list)
+
+ @contextmanager
+ def time(self, tag: str, num_calls: int = 1):
+ try:
+ start_time = time()
+ yield
+ finally:
+ end_time = time()
+ for _ in range(num_calls):
+ self.execution_times[tag].append((end_time - start_time) / num_calls)
+
+ def dump(self, path: Path) -> None:
+ path.parent.mkdir(exist_ok=True, parents=True)
+ with path.open("w") as f:
+ json.dump(dict(self.execution_times), f)
+
+ def dump_memory(self, path: Path) -> None:
+ path.parent.mkdir(exist_ok=True, parents=True)
+ with path.open("w") as f:
+ json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f)
+
+ def summarize(self) -> None:
+ for tag, times in self.execution_times.items():
+ print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call")
diff --git a/src/misc/cam_utils.py b/src/misc/cam_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e194236f38e381a3b813643105b38c5f494cff5d
--- /dev/null
+++ b/src/misc/cam_utils.py
@@ -0,0 +1,218 @@
+import cv2
+import numpy as np
+import torch
+from jaxtyping import Float
+from torch import Tensor
+import torch.nn.functional as F
+
+
+def decompose_extrinsic_RT(E: torch.Tensor):
+ """
+ Decompose the standard extrinsic matrix into RT.
+ Batched I/O.
+ """
+ return E[:, :3, :]
+
+
+def compose_extrinsic_RT(RT: torch.Tensor):
+ """
+ Compose the standard form extrinsic matrix from RT.
+ Batched I/O.
+ """
+ return torch.cat([
+ RT,
+ torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
+ ], dim=1)
+
+
+def camera_normalization(pivotal_pose: torch.Tensor, poses: torch.Tensor):
+ # [1, 4, 4], [N, 4, 4]
+
+ canonical_camera_extrinsics = torch.tensor([[
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1],
+ ]], dtype=torch.float32, device=pivotal_pose.device)
+ pivotal_pose_inv = torch.inverse(pivotal_pose)
+ camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)
+
+ # normalize all views
+ poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
+
+ return poses
+
+
+####### Pose update from delta
+
+def rt2mat(R, T):
+ mat = np.eye(4)
+ mat[0:3, 0:3] = R
+ mat[0:3, 3] = T
+ return mat
+
+
+def skew_sym_mat(x):
+ device = x.device
+ dtype = x.dtype
+ ssm = torch.zeros(3, 3, device=device, dtype=dtype)
+ ssm[0, 1] = -x[2]
+ ssm[0, 2] = x[1]
+ ssm[1, 0] = x[2]
+ ssm[1, 2] = -x[0]
+ ssm[2, 0] = -x[1]
+ ssm[2, 1] = x[0]
+ return ssm
+
+
+def SO3_exp(theta):
+ device = theta.device
+ dtype = theta.dtype
+
+ W = skew_sym_mat(theta)
+ W2 = W @ W
+ angle = torch.norm(theta)
+ I = torch.eye(3, device=device, dtype=dtype)
+ if angle < 1e-5:
+ return I + W + 0.5 * W2
+ else:
+ return (
+ I
+ + (torch.sin(angle) / angle) * W
+ + ((1 - torch.cos(angle)) / (angle**2)) * W2
+ )
+
+
+def V(theta):
+ dtype = theta.dtype
+ device = theta.device
+ I = torch.eye(3, device=device, dtype=dtype)
+ W = skew_sym_mat(theta)
+ W2 = W @ W
+ angle = torch.norm(theta)
+ if angle < 1e-5:
+ V = I + 0.5 * W + (1.0 / 6.0) * W2
+ else:
+ V = (
+ I
+ + W * ((1.0 - torch.cos(angle)) / (angle**2))
+ + W2 * ((angle - torch.sin(angle)) / (angle**3))
+ )
+ return V
+
+
+def SE3_exp(tau):
+ dtype = tau.dtype
+ device = tau.device
+
+ rho = tau[:3]
+ theta = tau[3:]
+ R = SO3_exp(theta)
+ t = V(theta) @ rho
+
+ T = torch.eye(4, device=device, dtype=dtype)
+ T[:3, :3] = R
+ T[:3, 3] = t
+ return T
+
+
+def update_pose(cam_trans_delta: Float[Tensor, "batch 3"],
+ cam_rot_delta: Float[Tensor, "batch 3"],
+ extrinsics: Float[Tensor, "batch 4 4"],
+ # original_rot: Float[Tensor, "batch 3 3"],
+ # original_trans: Float[Tensor, "batch 3"],
+ # converged_threshold: float = 1e-4
+ ):
+ # extrinsics is c2w, here we need w2c as input, so we need to invert it
+ bs = cam_trans_delta.shape[0]
+
+ tau = torch.cat([cam_trans_delta, cam_rot_delta], dim=-1)
+ T_w2c = extrinsics.inverse()
+
+ new_w2c_list = []
+ for i in range(bs):
+ new_w2c = SE3_exp(tau[i]) @ T_w2c[i]
+ new_w2c_list.append(new_w2c)
+
+ new_w2c = torch.stack(new_w2c_list, dim=0)
+ return new_w2c.inverse()
+
+ # converged = tau.norm() < converged_threshold
+ # camera.update_RT(new_R, new_T)
+ #
+ # camera.cam_rot_delta.data.fill_(0)
+ # camera.cam_trans_delta.data.fill_(0)
+ # return converged
+
+
+####### Pose estimation
+def inv(mat):
+ """ Invert a torch or numpy matrix
+ """
+ if isinstance(mat, torch.Tensor):
+ return torch.linalg.inv(mat)
+ if isinstance(mat, np.ndarray):
+ return np.linalg.inv(mat)
+ raise ValueError(f'bad matrix type = {type(mat)}')
+
+
+def get_pnp_pose(pts3d, opacity, K, H, W, opacity_threshold=0.3):
+ pixels = np.mgrid[:W, :H].T.astype(np.float32)
+ pts3d = pts3d.cpu().numpy()
+ opacity = opacity.cpu().numpy()
+ K = K.cpu().numpy()
+
+ K[0, :] = K[0, :] * W
+ K[1, :] = K[1, :] * H
+
+ mask = opacity > opacity_threshold
+
+ res = cv2.solvePnPRansac(pts3d[mask], pixels[mask], K, None,
+ iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
+ success, R, T, inliers = res
+
+ assert success
+
+ R = cv2.Rodrigues(R)[0] # world to cam
+ pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world
+
+ return torch.from_numpy(pose.astype(np.float32))
+
+
+def pose_auc(errors, thresholds):
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0.0, errors]
+ recall = np.r_[0.0, recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e) / t)
+ return aucs
+
+
+def rotation_6d_to_matrix(d6):
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d.
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
\ No newline at end of file
diff --git a/src/misc/collation.py b/src/misc/collation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65cc4f978c049e249117a09995018f4fb102b1d
--- /dev/null
+++ b/src/misc/collation.py
@@ -0,0 +1,15 @@
+from typing import Callable, Dict, Union
+
+from torch import Tensor
+
+Tree = Union[Dict[str, "Tree"], Tensor]
+
+
+def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree:
+ """Merge nested dictionaries of tensors."""
+ if isinstance(trees[0], Tensor):
+ return merge_fn(trees)
+ else:
+ return {
+ key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0]
+ }
diff --git a/src/misc/colmap_utils.py b/src/misc/colmap_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9fcb6f16511934664fa23d2919456c35323d5ca
--- /dev/null
+++ b/src/misc/colmap_utils.py
@@ -0,0 +1,347 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import collections
+import struct
+
+import numpy as np
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"]
+)
+Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
+)
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
+)
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
+}
+CAMERA_MODEL_IDS = dict(
+ [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
+)
+CAMERA_MODEL_NAMES = dict(
+ [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
+)
+
+
+def qvec2rotmat(qvec):
+ return np.array(
+ [
+ [
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
+ ],
+ [
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
+ ],
+ [
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
+ ],
+ ]
+ )
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = (
+ np.array(
+ [
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
+ ]
+ )
+ / 3.0
+ )
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
+
+
+def read_points3D_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ xyzs = None
+ rgbs = None
+ errors = None
+ num_points = 0
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ num_points += 1
+
+ xyzs = np.empty((num_points, 3))
+ rgbs = np.empty((num_points, 3))
+ errors = np.empty((num_points, 1))
+ count = 0
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = np.array(float(elems[7]))
+ xyzs[count] = xyz
+ rgbs[count] = rgb
+ errors[count] = error
+ count += 1
+
+ return xyzs, rgbs, errors
+
+
+def read_points3D_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+
+ xyzs = np.empty((num_points, 3))
+ rgbs = np.empty((num_points, 3))
+ errors = np.empty((num_points, 1))
+
+ for p_id in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd"
+ )
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
+ 0
+ ]
+ track_elems = read_next_bytes(
+ fid,
+ num_bytes=8 * track_length,
+ format_char_sequence="ii" * track_length,
+ )
+ xyzs[p_id] = xyz
+ rgbs[p_id] = rgb
+ errors[p_id] = error
+ return xyzs, rgbs, errors
+
+
+def read_intrinsics_text(path):
+ """
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ assert (
+ model == "PINHOLE"
+ ), "While the loader support other types, the rest of the code assumes PINHOLE"
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(
+ id=camera_id, model=model, width=width, height=height, params=params
+ )
+ return cameras
+
+
+def read_extrinsics_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi"
+ )
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
+ 0
+ ]
+ x_y_id_s = read_next_bytes(
+ fid,
+ num_bytes=24 * num_points2D,
+ format_char_sequence="ddq" * num_points2D,
+ )
+ xys = np.column_stack(
+ [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]
+ )
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def read_intrinsics_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ"
+ )
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(
+ fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params
+ )
+ cameras[camera_id] = Camera(
+ id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params),
+ )
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def read_extrinsics_text(path):
+ """
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack(
+ [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]
+ )
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def read_colmap_bin_array(path):
+ """
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
+
+ :param path: path to the colmap binary file.
+ :return: nd array with the floating point values in the value
+ """
+ with open(path, "rb") as fid:
+ width, height, channels = np.genfromtxt(
+ fid, delimiter="&", max_rows=1, usecols=(0, 1, 2), dtype=int
+ )
+ fid.seek(0)
+ num_delimiter = 0
+ byte = fid.read(1)
+ while True:
+ if byte == b"&":
+ num_delimiter += 1
+ if num_delimiter >= 3:
+ break
+ byte = fid.read(1)
+ array = np.fromfile(fid, np.float32)
+ array = array.reshape((width, height, channels), order="F")
+ return np.transpose(array, (1, 0, 2)).squeeze()
diff --git a/src/misc/discrete_probability_distribution.py b/src/misc/discrete_probability_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9fcb228098da9a12ce7f7b59fe43c3562083f5
--- /dev/null
+++ b/src/misc/discrete_probability_distribution.py
@@ -0,0 +1,33 @@
+import torch
+from einops import reduce
+from jaxtyping import Float, Int64
+from torch import Tensor
+
+
+def sample_discrete_distribution(
+ pdf: Float[Tensor, "*batch bucket"],
+ num_samples: int,
+ eps: float = torch.finfo(torch.float32).eps,
+) -> tuple[
+ Int64[Tensor, "*batch sample"], # index
+ Float[Tensor, "*batch sample"], # probability density
+]:
+ *batch, bucket = pdf.shape
+ normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum"))
+ cdf = normalized_pdf.cumsum(dim=-1)
+ samples = torch.rand((*batch, num_samples), device=pdf.device)
+ index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1)
+ return index, normalized_pdf.gather(dim=-1, index=index)
+
+
+def gather_discrete_topk(
+ pdf: Float[Tensor, "*batch bucket"],
+ num_samples: int,
+ eps: float = torch.finfo(torch.float32).eps,
+) -> tuple[
+ Int64[Tensor, "*batch sample"], # index
+ Float[Tensor, "*batch sample"], # probability density
+]:
+ normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum"))
+ index = pdf.topk(k=num_samples, dim=-1).indices
+ return index, normalized_pdf.gather(dim=-1, index=index)
diff --git a/src/misc/heterogeneous_pairings.py b/src/misc/heterogeneous_pairings.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d9b7fd1ef8676ad47135110ecd0795bf86a363
--- /dev/null
+++ b/src/misc/heterogeneous_pairings.py
@@ -0,0 +1,43 @@
+import torch
+from einops import repeat
+from jaxtyping import Int
+from torch import Tensor
+
+Index = Int[Tensor, "n n-1"]
+
+
+def generate_heterogeneous_index(
+ n: int,
+ device: torch.device = torch.device("cpu"),
+) -> tuple[Index, Index]:
+ """Generate indices for all pairs except self-pairs."""
+ arange = torch.arange(n, device=device)
+
+ # Generate an index that represents the item itself.
+ index_self = repeat(arange, "h -> h w", w=n - 1)
+
+ # Generate an index that represents the other items.
+ index_other = repeat(arange, "w -> h w", h=n).clone()
+ index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu()
+ index_other = index_other[:, :-1]
+
+ return index_self, index_other
+
+
+def generate_heterogeneous_index_transpose(
+ n: int,
+ device: torch.device = torch.device("cpu"),
+) -> tuple[Index, Index]:
+ """Generate an index that can be used to "transpose" the heterogeneous index.
+ Applying the index a second time inverts the "transpose."
+ """
+ arange = torch.arange(n, device=device)
+ ones = torch.ones((n, n), device=device, dtype=torch.int64)
+
+ index_self = repeat(arange, "w -> h w", h=n).clone()
+ index_self = index_self + ones.triu()
+
+ index_other = repeat(arange, "h -> h w", w=n)
+ index_other = index_other - (1 - ones.triu())
+
+ return index_self[:, :-1], index_other[:, :-1]
diff --git a/src/misc/image_io.py b/src/misc/image_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6994df1e3f06c3842b6d13364afe3a75e3df3a
--- /dev/null
+++ b/src/misc/image_io.py
@@ -0,0 +1,211 @@
+import io
+import os
+from pathlib import Path
+from typing import Union
+
+import cv2
+import numpy as np
+import skvideo
+import torch
+import torchvision.transforms as tf
+from einops import rearrange, repeat
+from jaxtyping import Float, UInt8
+
+from matplotlib import pyplot as plt
+from matplotlib.figure import Figure
+from PIL import Image
+from torch import Tensor
+
+FloatImage = Union[
+ Float[Tensor, "height width"],
+ Float[Tensor, "channel height width"],
+ Float[Tensor, "batch channel height width"],
+]
+
+
+def fig_to_image(
+ fig: Figure,
+ dpi: int = 100,
+ device: torch.device = torch.device("cpu"),
+) -> Float[Tensor, "3 height width"]:
+ buffer = io.BytesIO()
+ fig.savefig(buffer, format="raw", dpi=dpi)
+ buffer.seek(0)
+ data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
+ h = int(fig.bbox.bounds[3])
+ w = int(fig.bbox.bounds[2])
+ data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4)
+ buffer.close()
+ return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3]
+
+
+def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
+ # Handle batched images.
+ if image.ndim == 4:
+ image = rearrange(image, "b c h w -> c h (b w)")
+
+ # Handle single-channel images.
+ if image.ndim == 2:
+ image = rearrange(image, "h w -> () h w")
+
+ # Ensure that there are 3 or 4 channels.
+ channel, _, _ = image.shape
+ if channel == 1:
+ image = repeat(image, "() h w -> c h w", c=3)
+ assert image.shape[0] in (3, 4)
+
+ image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8)
+ return rearrange(image, "c h w -> h w c").cpu().numpy()
+
+
+def save_image(
+ image: FloatImage,
+ path: Union[Path, str],
+) -> None:
+ """Save an image. Assumed to be in range 0-1."""
+
+ # Create the parent directory if it doesn't already exist.
+ path = Path(path)
+ path.parent.mkdir(exist_ok=True, parents=True)
+
+ # Save the image.
+ Image.fromarray(prep_image(image)).save(path)
+
+
+def load_image(
+ path: Union[Path, str],
+) -> Float[Tensor, "3 height width"]:
+ return tf.ToTensor()(Image.open(path))[:3]
+
+
+def save_video(tensor, save_path, fps=10):
+ """
+ Save a tensor of shape (N, C, H, W) as a video file using imageio.
+ Args:
+ tensor: Tensor of shape (N, C, H, W) in range [0, 1]
+ save_path: Path to save the video file
+ fps: Frames per second for the video
+ """
+ # Convert tensor to numpy array and adjust dimensions
+ video = tensor.cpu().detach().numpy() # (N, C, H, W)
+ video = np.transpose(video, (0, 2, 3, 1)) # (N, H, W, C)
+
+ # Scale to [0, 255] and convert to uint8
+ video = (video * 255).astype(np.uint8)
+
+ # Ensure the directory exists
+ import os
+
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ # Use imageio to write video (handles codec compatibility automatically)
+ import imageio
+
+ writer = imageio.get_writer(save_path, fps=fps)
+
+ for frame in video:
+ writer.append_data(frame)
+
+ writer.close()
+
+
+def save_interpolated_video(
+ pred_extrinsics, pred_intrinsics, b, h, w, gaussians, save_path, decoder_func, t=10
+):
+ # Interpolate between neighboring frames
+ # t: Number of extra views to interpolate between each pair
+ interpolated_extrinsics = []
+ interpolated_intrinsics = []
+
+ # For each pair of neighboring frame
+ for i in range(pred_extrinsics.shape[1] - 1):
+ # Add the current frame
+ interpolated_extrinsics.append(pred_extrinsics[:, i : i + 1])
+ interpolated_intrinsics.append(pred_intrinsics[:, i : i + 1])
+
+ # Interpolate between current and next frame
+ for j in range(1, t + 1):
+ alpha = j / (t + 1)
+
+ # Interpolate extrinsics
+ start_extrinsic = pred_extrinsics[:, i]
+ end_extrinsic = pred_extrinsics[:, i + 1]
+
+ # Separate rotation and translation
+ start_rot = start_extrinsic[:, :3, :3]
+ end_rot = end_extrinsic[:, :3, :3]
+ start_trans = start_extrinsic[:, :3, 3]
+ end_trans = end_extrinsic[:, :3, 3]
+
+ # Interpolate translation (linear)
+ interp_trans = (1 - alpha) * start_trans + alpha * end_trans
+
+ # Interpolate rotation (spherical)
+ start_rot_flat = start_rot.reshape(b, 9)
+ end_rot_flat = end_rot.reshape(b, 9)
+ interp_rot_flat = (1 - alpha) * start_rot_flat + alpha * end_rot_flat
+ interp_rot = interp_rot_flat.reshape(b, 3, 3)
+
+ # Normalize rotation matrix to ensure it's orthogonal
+ u, _, v = torch.svd(interp_rot)
+ interp_rot = torch.bmm(u, v.transpose(1, 2))
+
+ # Combine interpolated rotation and translation
+ interp_extrinsic = (
+ torch.eye(4, device=pred_extrinsics.device).unsqueeze(0).repeat(b, 1, 1)
+ )
+ interp_extrinsic[:, :3, :3] = interp_rot
+ interp_extrinsic[:, :3, 3] = interp_trans
+
+ # Interpolate intrinsics (linear)
+ start_intrinsic = pred_intrinsics[:, i]
+ end_intrinsic = pred_intrinsics[:, i + 1]
+ interp_intrinsic = (1 - alpha) * start_intrinsic + alpha * end_intrinsic
+
+ # Add interpolated frame
+ interpolated_extrinsics.append(interp_extrinsic.unsqueeze(1))
+ interpolated_intrinsics.append(interp_intrinsic.unsqueeze(1))
+
+ # Concatenate all frames
+ pred_all_extrinsic = torch.cat(interpolated_extrinsics, dim=1)
+ pred_all_intrinsic = torch.cat(interpolated_intrinsics, dim=1)
+
+ # Add the last frame
+ interpolated_extrinsics.append(pred_all_extrinsic[:, -1:])
+ interpolated_intrinsics.append(pred_all_intrinsic[:, -1:])
+
+ # Update K to reflect the new number of frames
+ num_frames = pred_all_extrinsic.shape[1]
+
+ # Render interpolated views
+ interpolated_output = decoder_func.forward(
+ gaussians,
+ pred_all_extrinsic,
+ pred_all_intrinsic.float(),
+ torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 0.1,
+ torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 100,
+ (h, w),
+ )
+
+ # Convert to video format
+ video = interpolated_output.color[0].clip(min=0, max=1)
+ depth = interpolated_output.depth[0]
+
+ # Normalize depth for visualization
+ # to avoid `quantile() input tensor is too large`
+ num_views = pred_extrinsics.shape[1]
+ depth_norm = (depth - depth[::num_views].quantile(0.01)) / (
+ depth[::num_views].quantile(0.99) - depth[::num_views].quantile(0.01)
+ )
+ depth_norm = plt.cm.turbo(depth_norm.cpu().numpy())
+ depth_colored = (
+ torch.from_numpy(depth_norm[..., :3]).permute(0, 3, 1, 2).to(depth.device)
+ )
+ depth_colored = depth_colored.clip(min=0, max=1)
+
+ # Save depth video
+ save_video(depth_colored, os.path.join(save_path, f"depth.mp4"))
+ # Save video
+ save_video(video, os.path.join(save_path, f"rgb.mp4"))
+
+ return os.path.join(save_path, f"rgb.mp4"), os.path.join(save_path, f"depth.mp4")
diff --git a/src/misc/nn_module_tools.py b/src/misc/nn_module_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..21570e2880349a8c6083118d76183e90a501e617
--- /dev/null
+++ b/src/misc/nn_module_tools.py
@@ -0,0 +1,16 @@
+from torch import nn
+
+
+def convert_to_buffer(module: nn.Module, persistent: bool = True):
+ # Recurse over child modules.
+ for name, child in list(module.named_children()):
+ convert_to_buffer(child, persistent)
+
+ # Also re-save buffers to change persistence.
+ for name, parameter_or_buffer in (
+ *module.named_parameters(recurse=False),
+ *module.named_buffers(recurse=False),
+ ):
+ value = parameter_or_buffer.detach().clone()
+ delattr(module, name)
+ module.register_buffer(name, value, persistent=persistent)
diff --git a/src/misc/sh_rotation.py b/src/misc/sh_rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdca2dafbe845d1bc7e2a8df2f07bdf933e64b98
--- /dev/null
+++ b/src/misc/sh_rotation.py
@@ -0,0 +1,111 @@
+from math import isqrt
+
+import torch
+from e3nn.o3 import matrix_to_angles, wigner_D
+from einops import einsum
+from jaxtyping import Float
+from torch import Tensor
+
+
+def rotate_sh(
+ sh_coefficients: Float[Tensor, "*#batch n"],
+ rotations: Float[Tensor, "*#batch 3 3"],
+) -> Float[Tensor, "*batch n"]:
+ device = sh_coefficients.device
+ dtype = sh_coefficients.dtype
+
+ # change the basis from YZX -> XYZ to fit the convention of e3nn
+ P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]],
+ dtype=sh_coefficients.dtype, device=sh_coefficients.device)
+ inversed_P = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0], ],
+ dtype=sh_coefficients.dtype, device=sh_coefficients.device)
+ permuted_rotation_matrix = inversed_P @ rotations @ P
+
+ *_, n = sh_coefficients.shape
+ alpha, beta, gamma = matrix_to_angles(permuted_rotation_matrix)
+ result = []
+ for degree in range(isqrt(n)):
+ with torch.device(device):
+ sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype)
+ sh_rotated = einsum(
+ sh_rotations,
+ sh_coefficients[..., degree**2 : (degree + 1) ** 2],
+ "... i j, ... j -> ... i",
+ )
+ result.append(sh_rotated)
+
+ return torch.cat(result, dim=-1)
+
+
+# def rotate_sh(
+# sh_coefficients: Float[Tensor, "*#batch n"],
+# rotations: Float[Tensor, "*#batch 3 3"],
+# ) -> Float[Tensor, "*batch n"]:
+# device = sh_coefficients.device
+# dtype = sh_coefficients.dtype
+#
+# *_, n = sh_coefficients.shape
+# alpha, beta, gamma = matrix_to_angles(rotations)
+# result = []
+# for degree in range(isqrt(n)):
+# with torch.device(device):
+# sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype)
+# sh_rotated = einsum(
+# sh_rotations,
+# sh_coefficients[..., degree**2 : (degree + 1) ** 2],
+# "... i j, ... j -> ... i",
+# )
+# result.append(sh_rotated)
+#
+# return torch.cat(result, dim=-1)
+
+
+if __name__ == "__main__":
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ from e3nn.o3 import spherical_harmonics
+ from matplotlib import cm
+ from scipy.spatial.transform.rotation import Rotation as R
+
+ device = torch.device("cuda")
+
+ # Generate random spherical harmonics coefficients.
+ degree = 4
+ coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device)
+
+ def plot_sh(sh_coefficients, path: Path) -> None:
+ phi = torch.linspace(0, torch.pi, 100, device=device)
+ theta = torch.linspace(0, 2 * torch.pi, 100, device=device)
+ phi, theta = torch.meshgrid(phi, theta, indexing="xy")
+ x = torch.sin(phi) * torch.cos(theta)
+ y = torch.sin(phi) * torch.sin(theta)
+ z = torch.cos(phi)
+ xyz = torch.stack([x, y, z], dim=-1)
+ sh = spherical_harmonics(list(range(degree + 1)), xyz, True)
+ result = einsum(sh, sh_coefficients, "... n, n -> ...")
+ result = (result - result.min()) / (result.max() - result.min())
+
+ # Set the aspect ratio to 1 so our sphere looks spherical
+ fig = plt.figure(figsize=plt.figaspect(1.0))
+ ax = fig.add_subplot(111, projection="3d")
+ ax.plot_surface(
+ x.cpu().numpy(),
+ y.cpu().numpy(),
+ z.cpu().numpy(),
+ rstride=1,
+ cstride=1,
+ facecolors=cm.seismic(result.cpu().numpy()),
+ )
+ # Turn off the axis planes
+ ax.set_axis_off()
+ path.parent.mkdir(exist_ok=True, parents=True)
+ plt.savefig(path)
+
+ for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)):
+ rotation = torch.tensor(
+ R.from_euler("x", angle.item()).as_matrix(), device=device
+ )
+ plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png"))
+
+ print("Done!")
diff --git a/src/misc/sht.py b/src/misc/sht.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b89273a8f20b4da5ba296b175c856c974df0984
--- /dev/null
+++ b/src/misc/sht.py
@@ -0,0 +1,1637 @@
+"""Real spherical harmonics in Cartesian form for PyTorch.
+
+This is an autogenerated file. See
+https://github.com/cheind/torch-spherical-harmonics
+for more information.
+"""
+
+import torch
+
+
+def rsh_cart_0(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 0.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,1) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ ],
+ -1,
+ )
+
+
+def rsh_cart_1(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 1.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,4) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ ],
+ -1,
+ )
+
+
+def rsh_cart_2(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 2.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,9) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ ],
+ -1,
+ )
+
+
+def rsh_cart_3(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 3.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,16) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ ],
+ -1,
+ )
+
+
+def rsh_cart_4(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 4.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,25) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ ],
+ -1,
+ )
+
+
+def rsh_cart_5(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 5.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,36) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ ],
+ -1,
+ )
+
+
+def rsh_cart_6(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 6.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,49) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 4.09910463115149 * x**4 * xy
+ - 13.6636821038383 * xy**3
+ + 4.09910463115149 * xy * y**4,
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+ 0.00584892228263444
+ * y
+ * (3.0 * x2 - y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0701870673916132
+ * xy
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.221950995245231
+ * y
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ -1.48328138624466
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.86469659985043
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.953538034014426 * z2
+ - 0.317846011338142,
+ 0.221950995245231
+ * x
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ 0.0350935336958066
+ * (x2 - y2)
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.00584892228263444
+ * x
+ * (x2 - 3.0 * y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 0.683184105191914 * x2**3
+ + 10.2477615778787 * x2 * y4
+ - 10.2477615778787 * x4 * y2
+ - 0.683184105191914 * y2**3,
+ ],
+ -1,
+ )
+
+
+def rsh_cart_7(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 7.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,64) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 4.09910463115149 * x**4 * xy
+ - 13.6636821038383 * xy**3
+ + 4.09910463115149 * xy * y**4,
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+ 0.00584892228263444
+ * y
+ * (3.0 * x2 - y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0701870673916132
+ * xy
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.221950995245231
+ * y
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ -1.48328138624466
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.86469659985043
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.953538034014426 * z2
+ - 0.317846011338142,
+ 0.221950995245231
+ * x
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ 0.0350935336958066
+ * (x2 - y2)
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.00584892228263444
+ * x
+ * (x2 - 3.0 * y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 0.683184105191914 * x2**3
+ + 10.2477615778787 * x2 * y4
+ - 10.2477615778787 * x4 * y2
+ - 0.683184105191914 * y2**3,
+ -0.707162732524596
+ * y
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 9.98394571852353e-5
+ * y
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00239614697244565
+ * xy
+ * (x2 - y2)
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
+ 0.00397356022507413
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.0561946276120613
+ * xy
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.206472245902897
+ * y
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 1.24862677781952 * z * (1.5 * z2 - 0.5)
+ - 1.68564615005635
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 2.02901851395672
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.499450711127808 * z,
+ 0.206472245902897
+ * x
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 0.0280973138060306
+ * (x2 - y2)
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.00397356022507413
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.000599036743111412
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ * (-6.0 * x2 * y2 + x4 + y4),
+ 9.98394571852353e-5
+ * x
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -0.707162732524596
+ * x
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ ],
+ -1,
+ )
+
+
+# @torch.jit.script
+def rsh_cart_8(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 8.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,81) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ # z4 = z2**2
+ return torch.stack(
+ [
+ 0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 4.09910463115149 * x**4 * xy
+ - 13.6636821038383 * xy**3
+ + 4.09910463115149 * xy * y**4,
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+ 0.00584892228263444
+ * y
+ * (3.0 * x2 - y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0701870673916132
+ * xy
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.221950995245231
+ * y
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ -1.48328138624466
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.86469659985043
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.953538034014426 * z2
+ - 0.317846011338142,
+ 0.221950995245231
+ * x
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ 0.0350935336958066
+ * (x2 - y2)
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.00584892228263444
+ * x
+ * (x2 - 3.0 * y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 0.683184105191914 * x2**3
+ + 10.2477615778787 * x2 * y4
+ - 10.2477615778787 * x4 * y2
+ - 0.683184105191914 * y2**3,
+ -0.707162732524596
+ * y
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 9.98394571852353e-5
+ * y
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00239614697244565
+ * xy
+ * (x2 - y2)
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
+ 0.00397356022507413
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.0561946276120613
+ * xy
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.206472245902897
+ * y
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 1.24862677781952 * z * (1.5 * z2 - 0.5)
+ - 1.68564615005635
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 2.02901851395672
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.499450711127808 * z,
+ 0.206472245902897
+ * x
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 0.0280973138060306
+ * (x2 - y2)
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.00397356022507413
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.000599036743111412
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ * (-6.0 * x2 * y2 + x4 + y4),
+ 9.98394571852353e-5
+ * x
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -0.707162732524596
+ * x
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
+ -2.91570664069932
+ * yz
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 7.87853281621404e-6
+ * (1013512.5 * z2 - 67567.5)
+ * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 5.10587282657803e-5
+ * y
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00147275890257803
+ * xy
+ * (x2 - y2)
+ * (
+ 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ - 14293.125 * z2
+ + 1299.375
+ ),
+ 0.0028519853513317
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
+ + 3.0
+ * z
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ )
+ - 560.0 * z
+ ),
+ 0.0463392770473559
+ * xy
+ * (
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ + 2.5
+ * z
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ )
+ + 137.8125 * z2
+ - 19.6875
+ ),
+ 0.193851103820053
+ * y
+ * (
+ 3.2 * z * (1.5 - 7.5 * z2)
+ - 2.51428571428571
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ + 2.14285714285714
+ * z
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ )
+ + 5.48571428571429 * z
+ ),
+ 1.48417251362228
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.86581687426801
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 2.1808249179756
+ * z
+ * (
+ 1.14285714285714 * z * (1.5 * z2 - 0.5)
+ - 1.54285714285714
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 1.85714285714286
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.457142857142857 * z
+ )
+ - 0.954110901614325 * z2
+ + 0.318036967204775,
+ 0.193851103820053
+ * x
+ * (
+ 3.2 * z * (1.5 - 7.5 * z2)
+ - 2.51428571428571
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ + 2.14285714285714
+ * z
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ )
+ + 5.48571428571429 * z
+ ),
+ 0.0231696385236779
+ * (x2 - y2)
+ * (
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ + 2.5
+ * z
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ )
+ + 137.8125 * z2
+ - 19.6875
+ ),
+ 0.0028519853513317
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
+ + 3.0
+ * z
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ )
+ - 560.0 * z
+ ),
+ 0.000368189725644507
+ * (-6.0 * x2 * y2 + x4 + y4)
+ * (
+ 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ - 14293.125 * z2
+ + 1299.375
+ ),
+ 5.10587282657803e-5
+ * x
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 7.87853281621404e-6
+ * (1013512.5 * z2 - 67567.5)
+ * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -2.91570664069932
+ * xz
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ -20.4099464848952 * x2**3 * y2
+ - 20.4099464848952 * x2 * y2**3
+ + 0.72892666017483 * x4**2
+ + 51.0248662122381 * x4 * y4
+ + 0.72892666017483 * y4**2,
+ ],
+ -1,
+ )
+
+
+__all__ = [
+ "rsh_cart_0",
+ "rsh_cart_1",
+ "rsh_cart_2",
+ "rsh_cart_3",
+ "rsh_cart_4",
+ "rsh_cart_5",
+ "rsh_cart_6",
+ "rsh_cart_7",
+ "rsh_cart_8",
+]
+
+
+from typing import Optional
+import torch
+
+
+class SphHarm(torch.nn.Module):
+ def __init__(self, m, n, dtype=torch.float32) -> None:
+ super().__init__()
+ self.dtype = dtype
+ m = torch.tensor(list(range(-m + 1, m)))
+ n = torch.tensor(list(range(n)))
+ self.is_normalized = False
+ vals = torch.cartesian_prod(m, n).T
+ vals = vals[:, vals[0] <= vals[1]]
+ m, n = vals.unbind(0)
+
+ self.register_buffer("m", tensor=m)
+ self.register_buffer("n", tensor=n)
+ self.register_buffer("l_max", tensor=torch.max(self.n))
+
+ f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
+ self.register_buffer("f_a", tensor=f_a)
+ self.register_buffer("f_b", tensor=f_b)
+ self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
+ self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
+ self.register_buffer("initial_value", tensor=initial_value)
+
+ @property
+ def device(self):
+ return next(self.buffers()).device
+
+ def forward(self, points: torch.Tensor) -> torch.Tensor:
+ """Computes the spherical harmonics."""
+ # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
+ B, N, D = points.shape
+ dtype = points.dtype
+ theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
+ cos_colatitude = torch.cos(phi)
+ legendre = self._gen_associated_legendre(cos_colatitude)
+ vals = torch.stack([self.m.abs(), self.n], dim=0)
+ vals = torch.cat(
+ [
+ vals.repeat(1, theta.shape[0]),
+ torch.arange(theta.shape[0], device=theta.device)
+ .unsqueeze(0)
+ .repeat_interleave(vals.shape[1], dim=1),
+ ],
+ dim=0,
+ )
+ legendre_vals = legendre[vals[0], vals[1], vals[2]]
+ legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
+ angle = torch.outer(self.m.abs(), theta)
+ vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
+ harmonics = torch.complex(
+ legendre_vals * torch.real(vandermonde),
+ legendre_vals * torch.imag(vandermonde),
+ )
+
+ # Negative order.
+ m = self.m.unsqueeze(-1)
+ harmonics = torch.where(
+ m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
+ )
+ harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
+ return harmonics
+
+ def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
+ """Generates mask for recurrence relation on the remaining entries.
+
+ The remaining entries are with respect to the diagonal and offdiagonal
+ entries.
+
+ Args:
+ l_max: see `gen_normalized_legendre`.
+ Returns:
+ torch.Tensors representing the mask used by the recurrence relations.
+ """
+
+ # Computes all coefficients.
+ m_mat, l_mat = torch.meshgrid(
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
+ indexing="ij",
+ )
+ if self.is_normalized:
+ c0 = l_mat * l_mat
+ c1 = m_mat * m_mat
+ c2 = 2.0 * l_mat
+ c3 = (l_mat - 1.0) * (l_mat - 1.0)
+ d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
+ d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
+ else:
+ d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
+ d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
+
+ d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
+ d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
+
+ d_zeros = torch.zeros(
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
+ )
+ d_zeros[d0_mask_indices] = d0[d0_mask_indices]
+ d0_mask = d_zeros
+
+ d_zeros = torch.zeros(
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
+ )
+ d_zeros[d1_mask_indices] = d1[d1_mask_indices]
+ d1_mask = d_zeros
+
+ # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
+ i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
+ j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
+ k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
+ mask = (i + j - k == 0).to(self.dtype)
+ d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
+ d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
+ return (d0_mask_3d, d1_mask_3d)
+
+ def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ coeff_0 = self.d0_mask_3d[i]
+ coeff_1 = self.d1_mask_3d[i]
+ h = torch.einsum(
+ "ij,ijk->ijk",
+ coeff_0,
+ torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
+ ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
+ p_val = p_val + h
+ return p_val
+
+ def _init_legendre(self):
+ a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
+ b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
+ if self.is_normalized:
+ # The initial value p(0,0).
+ initial_value: torch.Tensor = torch.tensor(
+ 0.5 / (torch.pi**0.5), device=self.device
+ )
+ f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
+ f_b = torch.sqrt(2.0 * b_idx + 3.0)
+ else:
+ # The initial value p(0,0).
+ initial_value = torch.tensor(1.0, device=self.device)
+ f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
+ f_b = 2.0 * b_idx + 1.0
+
+ d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
+ return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
+
+ def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Computes associated Legendre functions (ALFs) of the first kind.
+
+ The ALFs of the first kind are used in spherical harmonics. The spherical
+ harmonic of degree `l` and order `m` can be written as
+ `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
+ normalization factor and θ and φ are the colatitude and longitude,
+ repectively. `N_l^m` is chosen in the way that the spherical harmonics form
+ a set of orthonormal basis function of L^2(S^2). For the computational
+ efficiency of spherical harmonics transform, the normalization factor is
+ used in the computation of the ALFs. In addition, normalizing `P_l^m`
+ avoids overflow/underflow and achieves better numerical stability. Three
+ recurrence relations are used in the computation.
+
+ Args:
+ l_max: The maximum degree of the associated Legendre function. Both the
+ degrees and orders are `[0, 1, 2, ..., l_max]`.
+ x: A vector of type `float32`, `float64` containing the sampled points in
+ spherical coordinates, at which the ALFs are computed; `x` is essentially
+ `cos(θ)`. For the numerical integration used by the spherical harmonics
+ transforms, `x` contains the quadrature points in the interval of
+ `[-1, 1]`. There are several approaches to provide the quadrature points:
+ Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
+ method (`scipy.special.roots_chebyu`), and Driscoll & Healy
+ method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
+ transforms and convolutions on the 2-sphere." Advances in applied
+ mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
+ points are nearly equal-spaced along θ and provide exact discrete
+ orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
+ operation, `W` is a diagonal matrix containing the quadrature weights,
+ and `I` is the identity matrix. The Gauss-Chebyshev points are equally
+ spaced, which only provide approximate discrete orthogonality. The
+ Driscoll & Healy qudarture points are equally spaced and provide the
+ exact discrete orthogonality. The number of sampling points is required to
+ be twice as the number of frequency points (modes) in the Driscoll & Healy
+ approach, which enables FFT and achieves a fast spherical harmonics
+ transform.
+ is_normalized: True if the associated Legendre functions are normalized.
+ With normalization, `N_l^m` is applied such that the spherical harmonics
+ form a set of orthonormal basis functions of L^2(S^2).
+
+ Returns:
+ The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
+ of the ALFs at `x`; the dimensions in the sequence of order, degree, and
+ evalution points.
+ """
+ p = torch.zeros(
+ (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
+ )
+ p[0, 0] = self.initial_value
+
+ # Compute the diagonal entries p(l,l) with recurrence.
+ y = torch.cumprod(
+ torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
+ )
+ p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
+ # torch.diag_indices(l_max + 1)
+ diag_indices = torch.stack(
+ [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
+ )
+ p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag
+
+ diag_indices = torch.stack(
+ [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
+ )
+
+ # Compute the off-diagonal entries with recurrence.
+ p_offdiag = torch.einsum(
+ "ij,ij->ij",
+ torch.einsum("i,j->ij", self.f_b, x),
+ p[(diag_indices[0], diag_indices[1])],
+ ) # p[torch.diag_indices(l_max)])
+ p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
+ p_offdiag
+ )
+
+ # Compute the remaining entries with recurrence.
+ if self.l_max > 1:
+ for i in range(2, self.l_max + 1):
+ p = self._recursive(i, p, x)
+ return p
diff --git a/src/misc/step_tracker.py b/src/misc/step_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..7298ffcf5cb028799d67c4dedf7ef1a4bf6fe802
--- /dev/null
+++ b/src/misc/step_tracker.py
@@ -0,0 +1,23 @@
+from multiprocessing import RLock
+
+import torch
+from jaxtyping import Int64
+from torch import Tensor
+from torch.multiprocessing import Manager
+
+
+class StepTracker:
+ lock: RLock
+ step: Int64[Tensor, ""]
+
+ def __init__(self):
+ self.lock = Manager().RLock()
+ self.step = torch.tensor(0, dtype=torch.int64).share_memory_()
+
+ def set_step(self, step: int) -> None:
+ with self.lock:
+ self.step.fill_(step)
+
+ def get_step(self) -> int:
+ with self.lock:
+ return self.step.item()
diff --git a/src/misc/utils.py b/src/misc/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e29b8f7e6fa06937b8196988a2a5ef0beb7280b
--- /dev/null
+++ b/src/misc/utils.py
@@ -0,0 +1,73 @@
+import torch
+
+from src.visualization.color_map import apply_color_map_to_image
+import torch.distributed as dist
+
+def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
+ mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
+ std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
+ return tensor.mul(std).add(mean)
+
+
+# Color-map the result.
+def vis_depth_map(result, near=None, far=None):
+ if near is None and far is None:
+ far = result.view(-1)[:16_000_000].quantile(0.99).log()
+ try:
+ near = result[result > 0][:16_000_000].quantile(0.01).log()
+ except:
+ print("No valid depth values found.")
+ near = torch.zeros_like(far)
+ else:
+ near = near.log()
+ far = far.log()
+
+ result = result.log()
+ result = 1 - (result - near) / (far - near)
+ return apply_color_map_to_image(result, "turbo")
+
+
+def confidence_map(result):
+ # far = result.view(-1)[:16_000_000].quantile(0.99).log()
+ # try:
+ # near = result[result > 0][:16_000_000].quantile(0.01).log()
+ # except:
+ # print("No valid depth values found.")
+ # near = torch.zeros_like(far)
+ # result = result.log()
+ # result = 1 - (result - near) / (far - near)
+ result = result / result.view(-1).max()
+ return apply_color_map_to_image(result, "magma")
+
+
+def get_overlap_tag(overlap):
+ if 0.05 <= overlap <= 0.3:
+ overlap_tag = "small"
+ elif overlap <= 0.55:
+ overlap_tag = "medium"
+ elif overlap <= 0.8:
+ overlap_tag = "large"
+ else:
+ overlap_tag = "ignore"
+
+ return overlap_tag
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
\ No newline at end of file
diff --git a/src/misc/wandb_tools.py b/src/misc/wandb_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5dc362ad13eef89fcde805695095dbff8ef375c
--- /dev/null
+++ b/src/misc/wandb_tools.py
@@ -0,0 +1,62 @@
+from pathlib import Path
+
+import wandb
+
+
+def version_to_int(artifact) -> int:
+ """Convert versions of the form vX to X. For example, v12 to 12."""
+ return int(artifact.version[1:])
+
+
+def download_checkpoint(
+ run_id: str,
+ download_dir: Path,
+ version: str | None,
+) -> Path:
+ api = wandb.Api()
+ run = api.run(run_id)
+
+ # Find the latest saved model checkpoint.
+ chosen = None
+ for artifact in run.logged_artifacts():
+ if artifact.type != "model" or artifact.state != "COMMITTED":
+ continue
+
+ # If no version is specified, use the latest.
+ if version is None:
+ if chosen is None or version_to_int(artifact) > version_to_int(chosen):
+ chosen = artifact
+
+ # If a specific verison is specified, look for it.
+ elif version == artifact.version:
+ chosen = artifact
+ break
+
+ # Download the checkpoint.
+ download_dir.mkdir(exist_ok=True, parents=True)
+ root = download_dir / run_id
+ chosen.download(root=root)
+ return root / "model.ckpt"
+
+
+def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None:
+ if path is None:
+ return None
+
+ if not str(path).startswith("wandb://"):
+ return Path(path)
+
+ run_id, *version = path[len("wandb://") :].split(":")
+ if len(version) == 0:
+ version = None
+ elif len(version) == 1:
+ version = version[0]
+ else:
+ raise ValueError("Invalid version specifier!")
+
+ project = wandb_cfg["project"]
+ return download_checkpoint(
+ f"{project}/{run_id}",
+ Path("checkpoints"),
+ version,
+ )
diff --git a/src/misc/weight_modify.py b/src/misc/weight_modify.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ed4771c239933615764e91800e513567f42feb
--- /dev/null
+++ b/src/misc/weight_modify.py
@@ -0,0 +1,197 @@
+import logging
+from typing import List, Dict
+
+import math
+import torch
+from torch import nn as nn
+import torch.nn.functional as F
+
+
+_logger = logging.getLogger(__name__)
+
+
+def resample_patch_embed(
+ patch_embed,
+ new_size: List[int],
+ interpolation: str = 'bicubic',
+ antialias: bool = True,
+ verbose: bool = False,
+):
+ """Resample the weights of the patch embedding kernel to target resolution.
+ We resample the patch embedding kernel by approximately inverting the effect
+ of patch resizing.
+
+ Code based on:
+ https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
+
+ With this resizing, we can for example load a B/8 filter into a B/16 model
+ and, on 2x larger input image, the result will match.
+
+ Args:
+ patch_embed: original parameter to be resized.
+ new_size (tuple(int, int): target shape (height, width)-only.
+ interpolation (str): interpolation for resize
+ antialias (bool): use anti-aliasing filter in resize
+ verbose (bool): log operation
+ Returns:
+ Resized patch embedding kernel.
+ """
+ import numpy as np
+ try:
+ import functorch
+ vmap = functorch.vmap
+ except ImportError:
+ if hasattr(torch, 'vmap'):
+ vmap = torch.vmap
+ else:
+ assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."
+
+ assert len(patch_embed.shape) == 4, "Four dimensions expected"
+ assert len(new_size) == 2, "New shape should only be hw"
+ old_size = patch_embed.shape[-2:]
+ if tuple(old_size) == tuple(new_size):
+ return patch_embed
+
+ if verbose:
+ _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
+
+ def resize(x_np, _new_size):
+ x_tf = torch.Tensor(x_np)[None, None, ...]
+ x_upsampled = F.interpolate(
+ x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
+ return x_upsampled
+
+ def get_resize_mat(_old_size, _new_size):
+ mat = []
+ for i in range(np.prod(_old_size)):
+ basis_vec = np.zeros(_old_size)
+ basis_vec[np.unravel_index(i, _old_size)] = 1.
+ mat.append(resize(basis_vec, _new_size).reshape(-1))
+ return np.stack(mat).T
+
+ resize_mat = get_resize_mat(old_size, new_size)
+ resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
+
+ def resample_kernel(kernel):
+ resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
+ return resampled_kernel.reshape(new_size)
+
+ v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
+ orig_dtype = patch_embed.dtype
+ patch_embed = patch_embed.float()
+ patch_embed = v_resample_kernel(patch_embed)
+ patch_embed = patch_embed.to(orig_dtype)
+ return patch_embed
+
+
+def adapt_input_conv(in_chans, conv_weight):
+ conv_type = conv_weight.dtype
+ conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
+ O, I, J, K = conv_weight.shape
+ if in_chans == 1:
+ if I > 3:
+ assert conv_weight.shape[1] % 3 == 0
+ # For models with space2depth stems
+ conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+ conv_weight = conv_weight.sum(dim=2, keepdim=False)
+ else:
+ conv_weight = conv_weight.sum(dim=1, keepdim=True)
+ elif in_chans != 3:
+ if I != 3:
+ raise NotImplementedError('Weight format not supported by conversion.')
+ else:
+ # NOTE this strategy should be better than random init, but there could be other combinations of
+ # the original RGB input layer weights that'd work better for specific cases.
+ repeat = int(math.ceil(in_chans / 3))
+ conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+ conv_weight *= (3 / float(in_chans))
+
+ # instead of assigning the same weight to all channels, we can assign higher weight for original RGB channels
+ # conv_weight[:, :3, :, :] = conv_weight[:, :3, :, :] * 0.5
+ # conv_weight[:, 3:, :, :] = conv_weight[:, 3:, :, :] * 0.5 * (3 / float(in_chans - 3))
+
+ conv_weight = conv_weight.to(conv_type)
+ return conv_weight
+
+
+def adapt_head_conv(conv_weight):
+ conv_type = conv_weight.dtype
+ conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
+ O, I, J, K = conv_weight.shape
+
+ conv_weight_new = torch.chunk(conv_weight, 6, dim=1)
+ conv_weight_new = [conv_weight_new.mean(dim=1, keepdim=True) for conv_weight_new in conv_weight_new]
+ conv_weight_new = torch.cat(conv_weight_new, dim=1) * 0.5
+ conv_weight = torch.cat([conv_weight, conv_weight_new], dim=1)
+ conv_weight = conv_weight.to(conv_type)
+ return conv_weight
+
+
+def adapt_linear(conv_weight):
+ conv_type = conv_weight.dtype
+ conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
+ O, I = conv_weight.shape
+
+ conv_weight_new = torch.tensor_split(conv_weight, 81, dim=1)
+ conv_weight_new = [conv_weight_new.mean(dim=1, keepdim=True) for conv_weight_new in conv_weight_new]
+ conv_weight_new = torch.cat(conv_weight_new, dim=1)
+ # conv_weight = torch.cat([conv_weight, conv_weight_new], dim=1)
+ conv_weight = torch.cat([conv_weight * 0.5, conv_weight_new * 0.5], dim=1)
+ conv_weight = conv_weight.to(conv_type)
+ return conv_weight
+
+
+def checkpoint_filter_fn(
+ state_dict: Dict[str, torch.Tensor],
+ model: nn.Module,
+ interpolation: str = 'bicubic',
+ antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ # state_dict = state_dict.get('model', state_dict)
+ # state_dict = state_dict.get('state_dict', state_dict)
+ prefix = ''
+
+ if prefix:
+ # filter on & remove prefix string from keys
+ state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
+
+ for k, v in state_dict.items():
+ if 'patch_embed.proj.weight' in k:
+ O, I, H, W = model.backbone.patch_embed.proj.weight.shape
+ if len(v.shape) < 4:
+ # For old models that I trained prior to conv based patchification
+ O, I, H, W = model.backbone.patch_embed.proj.weight.shape
+ v = v.reshape(O, -1, H, W)
+ if v.shape[-1] != W or v.shape[-2] != H:
+ v = resample_patch_embed(
+ v,
+ (H, W),
+ interpolation=interpolation,
+ antialias=antialias,
+ verbose=True,
+ )
+ if v.shape[1] != I:
+ v = adapt_input_conv(I, v)
+ # elif 'downstream_head1.dpt.head.0.weight' in k or 'downstream_head2.dpt.head.0.weight' in k:
+ # v = adapt_head_conv(v)
+
+ elif 'decoder_embed.weight' in k:
+ O, I = model.backbone.decoder_embed.weight.shape
+ if v.shape[1] != I:
+ v = adapt_linear(v)
+
+ out_dict[k] = v
+
+ # add prefix to make our model happy
+ prefix = 'backbone.'
+ out_dict = {prefix + k if 'downstream_head' not in k else k: v for k, v in out_dict.items()}
+
+ # # remove the conf head weights
+ out_dict['downstream_head1.dpt.head.4.weight'] = out_dict['downstream_head1.dpt.head.4.weight'][0:3]
+ out_dict['downstream_head1.dpt.head.4.bias'] = out_dict['downstream_head1.dpt.head.4.bias'][0:3]
+ out_dict['downstream_head2.dpt.head.4.weight'] = out_dict['downstream_head2.dpt.head.4.weight'][0:3]
+ out_dict['downstream_head2.dpt.head.4.bias'] = out_dict['downstream_head2.dpt.head.4.bias'][0:3]
+
+ return out_dict
diff --git a/src/model/decoder/__init__.py b/src/model/decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d757213bdfaeb3994ab67cbf47852f1a2b5c1d3
--- /dev/null
+++ b/src/model/decoder/__init__.py
@@ -0,0 +1,12 @@
+from .decoder import Decoder
+from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg
+
+DECODERS = {
+ "splatting_cuda": DecoderSplattingCUDA,
+}
+
+DecoderCfg = DecoderSplattingCUDACfg
+
+
+def get_decoder(decoder_cfg: DecoderCfg) -> Decoder:
+ return DECODERS[decoder_cfg.name](decoder_cfg)
diff --git a/src/model/decoder/cuda_splatting.py b/src/model/decoder/cuda_splatting.py
new file mode 100644
index 0000000000000000000000000000000000000000..367c8cc13d998c97b1cdc96c6d076f04fda3be4c
--- /dev/null
+++ b/src/model/decoder/cuda_splatting.py
@@ -0,0 +1,244 @@
+from math import isqrt
+from typing import Literal
+
+import torch
+from diff_gaussian_rasterization import (
+ GaussianRasterizationSettings,
+ GaussianRasterizer,
+)
+from einops import einsum, rearrange, repeat
+from jaxtyping import Float, Bool
+from torch import Tensor
+
+from ...geometry.projection import get_fov, homogenize_points
+
+
+def get_projection_matrix(
+ near: Float[Tensor, " batch"],
+ far: Float[Tensor, " batch"],
+ fov_x: Float[Tensor, " batch"],
+ fov_y: Float[Tensor, " batch"],
+) -> Float[Tensor, "batch 4 4"]:
+ """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z
+ axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after
+ transformation and that Z is flipped.
+ """
+ tan_fov_x = (0.5 * fov_x).tan()
+ tan_fov_y = (0.5 * fov_y).tan()
+
+ top = tan_fov_y * near
+ bottom = -top
+ right = tan_fov_x * near
+ left = -right
+
+ (b,) = near.shape
+ result = torch.zeros((b, 4, 4), dtype=torch.float32, device=near.device)
+ result[:, 0, 0] = 2 * near / (right - left)
+ result[:, 1, 1] = 2 * near / (top - bottom)
+ result[:, 0, 2] = (right + left) / (right - left)
+ result[:, 1, 2] = (top + bottom) / (top - bottom)
+ result[:, 3, 2] = 1
+ result[:, 2, 2] = far / (far - near)
+ result[:, 2, 3] = -(far * near) / (far - near)
+ return result
+
+
+def render_cuda(
+ extrinsics: Float[Tensor, "batch 4 4"],
+ intrinsics: Float[Tensor, "batch 3 3"],
+ near: Float[Tensor, " batch"],
+ far: Float[Tensor, " batch"],
+ image_shape: tuple[int, int],
+ background_color: Float[Tensor, "batch 3"],
+ gaussian_means: Float[Tensor, "batch gaussian 3"],
+ gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
+ gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
+ gaussian_opacities: Float[Tensor, "batch gaussian"],
+ scale_invariant: bool = True,
+ use_sh: bool = True,
+ cam_rot_delta: Float[Tensor, "batch 3"] | None = None,
+ cam_trans_delta: Float[Tensor, "batch 3"] | None = None,
+ voxel_masks: Bool[Tensor, "batch gaussian"] | None = None,
+) -> tuple[Float[Tensor, "batch 3 height width"], Float[Tensor, "batch height width"]]:
+ assert use_sh or gaussian_sh_coefficients.shape[-1] == 1
+
+ # Make sure everything is in a range where numerical issues don't appear.
+ if scale_invariant:
+ scale = 1 / near
+ extrinsics = extrinsics.clone()
+ extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None]
+ gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2)
+ gaussian_means = gaussian_means * scale[:, None, None]
+ near = near * scale
+ far = far * scale
+
+ _, _, _, n = gaussian_sh_coefficients.shape
+ degree = isqrt(n) - 1
+ shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()
+
+ b, _, _ = extrinsics.shape
+ h, w = image_shape
+
+ fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
+ tan_fov_x = (0.5 * fov_x).tan()
+ tan_fov_y = (0.5 * fov_y).tan()
+
+ projection_matrix = get_projection_matrix(near, far, fov_x, fov_y)
+ projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
+ view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
+ full_projection = view_matrix @ projection_matrix
+
+ all_images = []
+ all_radii = []
+ all_depths = []
+ for i in range(b):
+ # Set up a tensor for the gradients of the screen-space means.
+ mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
+ try:
+ mean_gradients.retain_grad()
+ except Exception:
+ pass
+
+ settings = GaussianRasterizationSettings(
+ image_height=h,
+ image_width=w,
+ tanfovx=tan_fov_x[i].item(),
+ tanfovy=tan_fov_y[i].item(),
+ bg=background_color[i],
+ scale_modifier=1.0,
+ viewmatrix=view_matrix[i],
+ projmatrix=full_projection[i],
+ projmatrix_raw=projection_matrix[i],
+ sh_degree=degree,
+ campos=extrinsics[i, :3, 3],
+ prefiltered=False, # This matches the original usage.
+ debug=False,
+ )
+ rasterizer = GaussianRasterizer(settings)
+
+ row, col = torch.triu_indices(3, 3)
+
+ if voxel_masks is not None:
+ voxel_mask = voxel_masks[i]
+ image, radii, depth, opacity, n_touched = rasterizer(
+ means3D=gaussian_means[i][voxel_mask],
+ means2D=mean_gradients[voxel_mask],
+ shs=shs[i][voxel_mask] if use_sh else None,
+ colors_precomp=None if use_sh else shs[i, :, 0, :][voxel_mask],
+ opacities=gaussian_opacities[i][voxel_mask, ..., None],
+ cov3D_precomp=gaussian_covariances[i, :, row, col][voxel_mask],
+ theta=cam_rot_delta[i] if cam_rot_delta is not None else None,
+ rho=cam_trans_delta[i] if cam_trans_delta is not None else None,
+ )
+ else:
+ image, radii, depth, opacity, n_touched = rasterizer(
+ means3D=gaussian_means[i],
+ means2D=mean_gradients,
+ shs=shs[i] if use_sh else None,
+ colors_precomp=None if use_sh else shs[i, :, 0, :],
+ opacities=gaussian_opacities[i, ..., None],
+ cov3D_precomp=gaussian_covariances[i, :, row, col],
+ theta=cam_rot_delta[i] if cam_rot_delta is not None else None,
+ rho=cam_trans_delta[i] if cam_trans_delta is not None else None,
+ )
+ all_images.append(image)
+ all_radii.append(radii)
+ all_depths.append(depth.squeeze(0))
+ return torch.stack(all_images), torch.stack(all_depths)
+
+
+def render_cuda_orthographic(
+ extrinsics: Float[Tensor, "batch 4 4"],
+ width: Float[Tensor, " batch"],
+ height: Float[Tensor, " batch"],
+ near: Float[Tensor, " batch"],
+ far: Float[Tensor, " batch"],
+ image_shape: tuple[int, int],
+ background_color: Float[Tensor, "batch 3"],
+ gaussian_means: Float[Tensor, "batch gaussian 3"],
+ gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
+ gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
+ gaussian_opacities: Float[Tensor, "batch gaussian"],
+ fov_degrees: float = 0.1,
+ use_sh: bool = True,
+ dump: dict | None = None,
+) -> Float[Tensor, "batch 3 height width"]:
+ b, _, _ = extrinsics.shape
+ h, w = image_shape
+ assert use_sh or gaussian_sh_coefficients.shape[-1] == 1
+
+ _, _, _, n = gaussian_sh_coefficients.shape
+ degree = isqrt(n) - 1
+ shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()
+
+ # Create fake "orthographic" projection by moving the camera back and picking a
+ # small field of view.
+ fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad()
+ tan_fov_x = (0.5 * fov_x).tan()
+ distance_to_near = (0.5 * width) / tan_fov_x
+ tan_fov_y = 0.5 * height / distance_to_near
+ fov_y = (2 * tan_fov_y).atan()
+ near = near + distance_to_near
+ far = far + distance_to_near
+ move_back = torch.eye(4, dtype=torch.float32, device=extrinsics.device)
+ move_back[2, 3] = -distance_to_near
+ extrinsics = extrinsics @ move_back
+
+ # Escape hatch for visualization/figures.
+ if dump is not None:
+ dump["extrinsics"] = extrinsics
+ dump["fov_x"] = fov_x
+ dump["fov_y"] = fov_y
+ dump["near"] = near
+ dump["far"] = far
+
+ projection_matrix = get_projection_matrix(
+ near, far, repeat(fov_x, "-> b", b=b), fov_y
+ )
+ projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
+ view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
+ full_projection = view_matrix @ projection_matrix
+
+ all_images = []
+ all_radii = []
+ for i in range(b):
+ # Set up a tensor for the gradients of the screen-space means.
+ mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
+ try:
+ mean_gradients.retain_grad()
+ except Exception:
+ pass
+
+ settings = GaussianRasterizationSettings(
+ image_height=h,
+ image_width=w,
+ tanfovx=tan_fov_x,
+ tanfovy=tan_fov_y,
+ bg=background_color[i],
+ scale_modifier=1.0,
+ viewmatrix=view_matrix[i],
+ projmatrix=full_projection[i],
+ projmatrix_raw=projection_matrix[i],
+ sh_degree=degree,
+ campos=extrinsics[i, :3, 3],
+ prefiltered=False, # This matches the original usage.
+ debug=False,
+ )
+ rasterizer = GaussianRasterizer(settings)
+
+ row, col = torch.triu_indices(3, 3)
+
+ image, radii, depth, opacity, n_touched = rasterizer(
+ means3D=gaussian_means[i],
+ means2D=mean_gradients,
+ shs=shs[i] if use_sh else None,
+ colors_precomp=None if use_sh else shs[i, :, 0, :],
+ opacities=gaussian_opacities[i, ..., None],
+ cov3D_precomp=gaussian_covariances[i, :, row, col],
+ )
+ all_images.append(image)
+ all_radii.append(radii)
+ return torch.stack(all_images)
+
+
+DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]
diff --git a/src/model/decoder/decoder.py b/src/model/decoder/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..296d3301b7b11c09da1d674a62050ed9f672fcb6
--- /dev/null
+++ b/src/model/decoder/decoder.py
@@ -0,0 +1,46 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Generic, Literal, TypeVar
+
+from jaxtyping import Float
+from torch import Tensor, nn
+
+from ..types import Gaussians
+
+DepthRenderingMode = Literal[
+ "depth",
+ "log",
+ "disparity",
+ "relative_disparity",
+]
+
+
+@dataclass
+class DecoderOutput:
+ color: Float[Tensor, "batch view 3 height width"]
+ depth: Float[Tensor, "batch view height width"] | None
+ alpha: Float[Tensor, "batch view height width"] | None
+ lod_rendering: dict | None
+
+T = TypeVar("T")
+
+
+class Decoder(nn.Module, ABC, Generic[T]):
+ cfg: T
+
+ def __init__(self, cfg: T) -> None:
+ super().__init__()
+ self.cfg = cfg
+
+ @abstractmethod
+ def forward(
+ self,
+ gaussians: Gaussians,
+ extrinsics: Float[Tensor, "batch view 4 4"],
+ intrinsics: Float[Tensor, "batch view 3 3"],
+ near: Float[Tensor, "batch view"],
+ far: Float[Tensor, "batch view"],
+ image_shape: tuple[int, int],
+ depth_mode: DepthRenderingMode | None = None,
+ ) -> DecoderOutput:
+ pass
diff --git a/src/model/decoder/decoder_splatting_cuda.py b/src/model/decoder/decoder_splatting_cuda.py
new file mode 100644
index 0000000000000000000000000000000000000000..59681e6bb95b0a129437f913edd9c60331d49f88
--- /dev/null
+++ b/src/model/decoder/decoder_splatting_cuda.py
@@ -0,0 +1,111 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from einops import rearrange, repeat
+from jaxtyping import Float
+from torch import Tensor
+import torchvision
+
+from ..types import Gaussians
+# from .cuda_splatting import DepthRenderingMode, render_cuda
+from .decoder import Decoder, DecoderOutput
+from math import sqrt
+from gsplat import rasterization
+
+from ...misc.utils import vis_depth_map
+
+DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]
+
+@dataclass
+class DecoderSplattingCUDACfg:
+ name: Literal["splatting_cuda"]
+ background_color: list[float]
+ make_scale_invariant: bool
+
+
+class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
+ background_color: Float[Tensor, "3"]
+
+ def __init__(
+ self,
+ cfg: DecoderSplattingCUDACfg,
+ ) -> None:
+ super().__init__(cfg)
+ self.make_scale_invariant = cfg.make_scale_invariant
+ self.register_buffer(
+ "background_color",
+ torch.tensor(cfg.background_color, dtype=torch.float32),
+ persistent=False,
+ )
+
+ def rendering_fn(
+ self,
+ gaussians: Gaussians,
+ extrinsics: Float[Tensor, "batch view 4 4"],
+ intrinsics: Float[Tensor, "batch view 3 3"],
+ near: Float[Tensor, "batch view"],
+ far: Float[Tensor, "batch view"],
+ image_shape: tuple[int, int],
+ depth_mode: DepthRenderingMode | None = None,
+ cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
+ cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
+ ) -> DecoderOutput:
+ B, V, _, _ = intrinsics.shape
+ H, W = image_shape
+ rendered_imgs, rendered_depths, rendered_alphas = [], [], []
+ xyzs, opacitys, rotations, scales, features = gaussians.means, gaussians.opacities, gaussians.rotations, gaussians.scales, gaussians.harmonics.permute(0, 1, 3, 2).contiguous()
+ covariances = gaussians.covariances
+ for i in range(B):
+ xyz_i = xyzs[i].float()
+ feature_i = features[i].float()
+ covar_i = covariances[i].float()
+ scale_i = scales[i].float()
+ rotation_i = rotations[i].float()
+ opacity_i = opacitys[i].squeeze().float()
+ test_w2c_i = extrinsics[i].float().inverse() # (V, 4, 4)
+ test_intr_i_normalized = intrinsics[i].float()
+ # Denormalize the intrinsics into standred format
+ test_intr_i = test_intr_i_normalized.clone()
+ test_intr_i[:, 0] = test_intr_i_normalized[:, 0] * W
+ test_intr_i[:, 1] = test_intr_i_normalized[:, 1] * H
+ sh_degree = (int(sqrt(feature_i.shape[-2])) - 1)
+
+ rendering_list = []
+ rendering_depth_list = []
+ rendering_alpha_list = []
+ for j in range(V):
+ rendering, alpha, _ = rasterization(xyz_i, rotation_i, scale_i, opacity_i, feature_i,
+ test_w2c_i[j:j+1], test_intr_i[j:j+1], W, H, sh_degree=sh_degree,
+ # near_plane=near[i].mean(), far_plane=far[i].mean(),
+ render_mode="RGB+D", packed=False,
+ near_plane=1e-10,
+ backgrounds=self.background_color.unsqueeze(0).repeat(1, 1),
+ radius_clip=0.1,
+ covars=covar_i,
+ rasterize_mode='classic') # (V, H, W, 3)
+ rendering_img, rendering_depth = torch.split(rendering, [3, 1], dim=-1)
+ rendering_img = rendering_img.clamp(0.0, 1.0)
+ rendering_list.append(rendering_img.permute(0, 3, 1, 2))
+ rendering_depth_list.append(rendering_depth)
+ rendering_alpha_list.append(alpha)
+ rendered_depths.append(torch.cat(rendering_depth_list, dim=0).squeeze())
+ rendered_imgs.append(torch.cat(rendering_list, dim=0))
+ rendered_alphas.append(torch.cat(rendering_alpha_list, dim=0).squeeze())
+ return DecoderOutput(torch.stack(rendered_imgs), torch.stack(rendered_depths), torch.stack(rendered_alphas), lod_rendering=None)
+
+ def forward(
+ self,
+ gaussians: Gaussians,
+ extrinsics: Float[Tensor, "batch view 4 4"],
+ intrinsics: Float[Tensor, "batch view 3 3"],
+ near: Float[Tensor, "batch view"],
+ far: Float[Tensor, "batch view"],
+ image_shape: tuple[int, int],
+ depth_mode: DepthRenderingMode | None = None,
+ cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
+ cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
+ ) -> DecoderOutput:
+
+ return self.rendering_fn(gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode, cam_rot_delta, cam_trans_delta)
+
diff --git a/src/model/encoder/__init__.py b/src/model/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4da77b7a2bd380ead96672170135b88ca2ff66c9
--- /dev/null
+++ b/src/model/encoder/__init__.py
@@ -0,0 +1,19 @@
+from typing import Optional, Union
+
+from .encoder import Encoder
+from .visualization.encoder_visualizer import EncoderVisualizer
+from .anysplat import EncoderAnySplat, EncoderAnySplatCfg
+
+ENCODERS = {
+ "anysplat": (EncoderAnySplat, None),
+}
+
+EncoderCfg = Union[EncoderAnySplatCfg]
+
+
+def get_encoder(cfg: EncoderCfg) -> tuple[Encoder, Optional[EncoderVisualizer]]:
+ encoder, visualizer = ENCODERS[cfg.name]
+ encoder = encoder(cfg)
+ if visualizer is not None:
+ visualizer = visualizer(cfg.visualizer, encoder)
+ return encoder, visualizer
diff --git a/src/model/encoder/anysplat.py b/src/model/encoder/anysplat.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e6c73497b696ac4689a9120fc2f3c7bbd20c9d9
--- /dev/null
+++ b/src/model/encoder/anysplat.py
@@ -0,0 +1,593 @@
+import copy
+
+# VGGT parts
+import os
+import sys
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import List, Literal, Optional
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from einops import rearrange
+from huggingface_hub import PyTorchModelHubMixin
+from jaxtyping import Float
+from src.dataset.shims.bounds_shim import apply_bounds_shim
+from src.dataset.shims.normalize_shim import apply_normalize_shim
+from src.dataset.shims.patch_shim import apply_patch_shim
+from src.dataset.types import BatchedExample, DataShim
+from src.geometry.projection import sample_image_grid
+
+from src.model.encoder.heads.vggt_dpt_gs_head import VGGT_DPT_GS_Head
+from src.model.encoder.vggt.utils.geometry import (
+ batchify_unproject_depth_map_to_point_map,
+ unproject_depth_map_to_point_map,
+)
+from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from src.utils.geometry import get_rel_pos # used for model hub
+from torch import nn, Tensor
+from torch_scatter import scatter_add, scatter_max
+
+from ..types import Gaussians
+from .backbone import Backbone, BackboneCfg, get_backbone
+
+from .backbone.croco.misc import transpose_to_landscape
+from .common.gaussian_adapter import (
+ GaussianAdapter,
+ GaussianAdapterCfg,
+ UnifiedGaussianAdapter,
+)
+from .encoder import Encoder, EncoderOutput
+from .heads import head_factory
+from .visualization.encoder_visualizer_epipolar_cfg import EncoderVisualizerEpipolarCfg
+
+root_path = os.path.abspath(".")
+sys.path.append(root_path)
+from src.model.encoder.heads.head_modules import TransformerBlockSelfAttn
+from src.model.encoder.vggt.heads.dpt_head import DPTHead
+from src.model.encoder.vggt.layers.mlp import Mlp
+from src.model.encoder.vggt.models.vggt import VGGT
+
+inf = float("inf")
+
+
+@dataclass
+class OpacityMappingCfg:
+ initial: float
+ final: float
+ warm_up: int
+
+
+@dataclass
+class GSHeadParams:
+ dec_depth: int = 23
+ patch_size: tuple[int, int] = (14, 14)
+ enc_embed_dim: int = 2048
+ dec_embed_dim: int = 2048
+ feature_dim: int = 256
+ depth_mode = ("exp", -inf, inf)
+ conf_mode = True
+
+
+@dataclass
+class EncoderAnySplatCfg:
+ name: Literal["anysplat"]
+ anchor_feat_dim: int
+ voxel_size: float
+ n_offsets: int
+ d_feature: int
+ add_view: bool
+ num_monocular_samples: int
+ backbone: BackboneCfg
+ visualizer: EncoderVisualizerEpipolarCfg
+ gaussian_adapter: GaussianAdapterCfg
+ apply_bounds_shim: bool
+ opacity_mapping: OpacityMappingCfg
+ gaussians_per_pixel: int
+ num_surfaces: int
+ gs_params_head_type: str
+ input_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
+ input_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
+ pretrained_weights: str = ""
+ pose_free: bool = True
+ pred_pose: bool = True
+ gt_pose_to_pts: bool = False
+ gs_prune: bool = False
+ opacity_threshold: float = 0.001
+ gs_keep_ratio: float = 1.0
+ pred_head_type: Literal["depth", "point"] = "point"
+ freeze_backbone: bool = False
+ freeze_module: Literal[
+ "all",
+ "global",
+ "frame",
+ "patch_embed",
+ "patch_embed+frame",
+ "patch_embed+global",
+ "global+frame",
+ "None",
+ ] = "None"
+ distill: bool = False
+ render_conf: bool = False
+ opacity_conf: bool = False
+ conf_threshold: float = 0.1
+ intermediate_layer_idx: Optional[List[int]] = None
+ voxelize: bool = False
+
+
+def rearrange_head(feat, patch_size, H, W):
+ B = feat.shape[0]
+ feat = feat.transpose(-1, -2).view(B, -1, H // patch_size, W // patch_size)
+ feat = F.pixel_shuffle(feat, patch_size) # B,D,H,W
+ feat = rearrange(feat, "b d h w -> b (h w) d")
+ return feat
+
+
+class EncoderAnySplat(Encoder[EncoderAnySplatCfg]):
+ backbone: nn.Module
+ gaussian_adapter: GaussianAdapter
+
+ def __init__(self, cfg: EncoderAnySplatCfg) -> None:
+ super().__init__(cfg)
+ model_full = VGGT.from_pretrained("facebook/VGGT-1B")
+ # model_full = VGGT()
+ self.aggregator = model_full.aggregator.to(torch.bfloat16)
+ self.freeze_backbone = cfg.freeze_backbone
+ self.distill = cfg.distill
+ self.pred_pose = cfg.pred_pose
+
+ self.camera_head = model_full.camera_head
+ if self.cfg.pred_head_type == "depth":
+ self.depth_head = model_full.depth_head
+ else:
+ self.point_head = model_full.point_head
+
+ if self.distill:
+ self.distill_aggregator = copy.deepcopy(self.aggregator)
+ self.distill_camera_head = copy.deepcopy(self.camera_head)
+ self.distill_depth_head = copy.deepcopy(self.depth_head)
+ for module in [
+ self.distill_aggregator,
+ self.distill_camera_head,
+ self.distill_depth_head,
+ ]:
+ for param in module.parameters():
+ param.requires_grad = False
+ param.data = param.data.cpu()
+
+ if self.freeze_backbone:
+ # Freeze backbone components
+ if self.cfg.pred_head_type == "depth":
+ for module in [self.aggregator, self.camera_head, self.depth_head]:
+ for param in module.parameters():
+ param.requires_grad = False
+ else:
+ for module in [self.aggregator, self.camera_head, self.point_head]:
+ for param in module.parameters():
+ param.requires_grad = False
+ else:
+ # aggregator freeze
+ freeze_module = self.cfg.freeze_module
+ if freeze_module == "None":
+ pass
+
+ elif freeze_module == "all":
+ for param in self.aggregator.parameters():
+ param.requires_grad = False
+
+ else:
+ module_pairs = {
+ "patch_embed+frame": ["patch_embed", "frame"],
+ "patch_embed+global": ["patch_embed", "global"],
+ "global+frame": ["global", "frame"],
+ }
+
+ if freeze_module in module_pairs:
+ for name, param in self.aggregator.named_parameters():
+ if any(m in name for m in module_pairs[freeze_module]):
+ param.requires_grad = False
+ else:
+ for name, param in self.named_parameters():
+ param.requires_grad = (
+ freeze_module not in name and "distill" not in name
+ )
+
+ self.pose_free = cfg.pose_free
+ if self.pose_free:
+ self.gaussian_adapter = UnifiedGaussianAdapter(cfg.gaussian_adapter)
+ else:
+ self.gaussian_adapter = GaussianAdapter(cfg.gaussian_adapter)
+
+ self.raw_gs_dim = 1 + self.gaussian_adapter.d_in # 1 for opacity
+ self.voxel_size = cfg.voxel_size
+ self.gs_params_head_type = cfg.gs_params_head_type
+ # fake backbone for head parameters
+ head_params = GSHeadParams()
+ self.gaussian_param_head = VGGT_DPT_GS_Head(
+ dim_in=2048,
+ patch_size=head_params.patch_size,
+ output_dim=self.raw_gs_dim + 1,
+ activation="norm_exp",
+ conf_activation="expp1",
+ features=head_params.feature_dim,
+ )
+
+ def map_pdf_to_opacity(
+ self,
+ pdf: Float[Tensor, " *batch"],
+ global_step: int,
+ ) -> Float[Tensor, " *batch"]:
+ # https://www.desmos.com/calculator/opvwti3ba9
+
+ # Figure out the exponent.
+ cfg = self.cfg.opacity_mapping
+ x = cfg.initial + min(global_step / cfg.warm_up, 1) * (cfg.final - cfg.initial)
+ exponent = 2**x
+
+ # Map the probability density to an opacity.
+ return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent))
+
+ def normalize_pts3d(self, pts3ds, valid_masks, original_extrinsics=None):
+ # normalize pts_all
+ B = pts3ds.shape[0]
+ pts3d_norms = []
+ scale_factors = []
+ for bs in range(B):
+ pts3d, valid_mask = pts3ds[bs], valid_masks[bs]
+ if original_extrinsics is not None:
+ camera_c2w = original_extrinsics[bs]
+ first_camera_w2c = (
+ camera_c2w[0].inverse().unsqueeze(0).repeat(pts3d.shape[0], 1, 1)
+ )
+
+ pts3d_homo = torch.cat(
+ [pts3d, torch.ones_like(pts3d[:, :, :, :1])], dim=-1
+ )
+ transformed_pts3d = torch.bmm(
+ first_camera_w2c, pts3d_homo.flatten(1, 2).transpose(1, 2)
+ ).transpose(1, 2)[..., :3]
+ scene_scale = torch.norm(
+ transformed_pts3d.flatten(0, 1)[valid_mask.flatten(0, 2).bool()],
+ dim=-1,
+ ).mean()
+ else:
+ transformed_pts3d = pts3d[valid_mask]
+ dis = transformed_pts3d.norm(dim=-1)
+ scene_scale = dis.mean().clip(min=1e-8)
+ # pts3d_norm[bs] = pts3d[bs] / scene_scale
+ pts3d_norms.append(pts3d / scene_scale)
+ scale_factors.append(scene_scale)
+ return torch.stack(pts3d_norms, dim=0), torch.stack(scale_factors, dim=0)
+
+ def align_pts_all_with_pts3d(
+ self, pts_all, pts3d, valid_mask, original_extrinsics=None
+ ):
+ # align pts_all with pts3d
+ B = pts_all.shape[0]
+
+ # follow vggt's normalization implementation
+ pts3d_norm, scale_factor = self.normalize_pts3d(
+ pts3d, valid_mask, original_extrinsics
+ ) # check if this is correct
+ pts_all = pts_all * scale_factor.view(B, 1, 1, 1, 1)
+
+ return pts_all
+
+ def pad_tensor_list(self, tensor_list, pad_shape, value=0.0):
+ padded = []
+ for t in tensor_list:
+ pad_len = pad_shape[0] - t.shape[0]
+ if pad_len > 0:
+ padding = torch.full(
+ (pad_len, *t.shape[1:]), value, device=t.device, dtype=t.dtype
+ )
+ t = torch.cat([t, padding], dim=0)
+ padded.append(t)
+ return torch.stack(padded)
+
+ def voxelizaton_with_fusion(self, img_feat, pts3d, voxel_size, conf=None):
+ # img_feat: B*V, C, H, W
+ # pts3d: B*V, 3, H, W
+ V, C, H, W = img_feat.shape
+ pts3d_flatten = pts3d.permute(0, 2, 3, 1).flatten(0, 2)
+
+ voxel_indices = (pts3d_flatten / voxel_size).round().int() # [B*V*N, 3]
+ unique_voxels, inverse_indices, counts = torch.unique(
+ voxel_indices, dim=0, return_inverse=True, return_counts=True
+ )
+
+ # Flatten confidence scores and features
+ conf_flat = conf.flatten() # [B*V*N]
+ anchor_feats_flat = img_feat.permute(0, 2, 3, 1).flatten(0, 2) # [B*V*N, ...]
+
+ # Compute softmax weights per voxel
+ conf_voxel_max, _ = scatter_max(conf_flat, inverse_indices, dim=0)
+ conf_exp = torch.exp(conf_flat - conf_voxel_max[inverse_indices])
+ voxel_weights = scatter_add(
+ conf_exp, inverse_indices, dim=0
+ ) # [num_unique_voxels]
+ weights = (conf_exp / (voxel_weights[inverse_indices] + 1e-6)).unsqueeze(
+ -1
+ ) # [B*V*N, 1]
+
+ # Compute weighted average of positions and features
+ weighted_pts = pts3d_flatten * weights
+ weighted_feats = anchor_feats_flat.squeeze(1) * weights
+
+ # Aggregate per voxel
+ voxel_pts = scatter_add(
+ weighted_pts, inverse_indices, dim=0
+ ) # [num_unique_voxels, 3]
+ voxel_feats = scatter_add(
+ weighted_feats, inverse_indices, dim=0
+ ) # [num_unique_voxels, feat_dim]
+
+ return voxel_pts, voxel_feats
+
+ def forward(
+ self,
+ image: torch.Tensor,
+ global_step: int = 0,
+ visualization_dump: Optional[dict] = None,
+ ) -> Gaussians:
+ device = image.device
+ b, v, _, h, w = image.shape
+ distill_infos = {}
+ if self.distill:
+ distill_image = image.clone().detach()
+ for module in [
+ self.distill_aggregator,
+ self.distill_camera_head,
+ self.distill_depth_head,
+ ]:
+ for param in module.parameters():
+ param.data = param.data.to(device, non_blocking=True)
+
+ with torch.no_grad():
+ # Process with bfloat16 precision
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+ distill_aggregated_tokens_list, distill_patch_start_idx = (
+ self.distill_aggregator(
+ distill_image.to(torch.bfloat16),
+ intermediate_layer_idx=self.cfg.intermediate_layer_idx,
+ )
+ )
+
+ # Process with default precision
+ with torch.amp.autocast("cuda", enabled=False):
+ # Get camera pose information
+ distill_pred_pose_enc_list = self.distill_camera_head(
+ distill_aggregated_tokens_list
+ )
+ last_distill_pred_pose_enc = distill_pred_pose_enc_list[-1]
+ distill_extrinsic, distill_intrinsic = pose_encoding_to_extri_intri(
+ last_distill_pred_pose_enc, image.shape[-2:]
+ )
+
+ # Get depth information
+ distill_depth_map, distill_depth_conf = self.distill_depth_head(
+ distill_aggregated_tokens_list,
+ images=distill_image,
+ patch_start_idx=distill_patch_start_idx,
+ )
+
+ # Convert depth to 3D points
+ distill_pts_all = batchify_unproject_depth_map_to_point_map(
+ distill_depth_map, distill_extrinsic, distill_intrinsic
+ )
+ # Store results
+ distill_infos["pred_pose_enc_list"] = distill_pred_pose_enc_list
+ distill_infos["pts_all"] = distill_pts_all
+ distill_infos["depth_map"] = distill_depth_map
+
+ conf_threshold = torch.quantile(
+ distill_depth_conf.flatten(2, 3), 0.3, dim=-1, keepdim=True
+ ) # Get threshold for each view
+ conf_mask = distill_depth_conf > conf_threshold.unsqueeze(-1)
+ distill_infos["conf_mask"] = conf_mask
+
+ for module in [
+ self.distill_aggregator,
+ self.distill_camera_head,
+ self.distill_depth_head,
+ ]:
+ for param in module.parameters():
+ param.data = param.data.cpu()
+ # Clean up to save memory
+ del distill_aggregated_tokens_list, distill_patch_start_idx
+ del distill_pred_pose_enc_list, last_distill_pred_pose_enc
+ del distill_extrinsic, distill_intrinsic
+ del distill_depth_map, distill_depth_conf
+ torch.cuda.empty_cache()
+
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+ aggregated_tokens_list, patch_start_idx = self.aggregator(
+ image.to(torch.bfloat16),
+ intermediate_layer_idx=self.cfg.intermediate_layer_idx,
+ )
+
+ with torch.amp.autocast("cuda", enabled=False):
+ pred_pose_enc_list = self.camera_head(aggregated_tokens_list)
+ last_pred_pose_enc = pred_pose_enc_list[-1]
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(
+ last_pred_pose_enc, image.shape[-2:]
+ ) # only for debug
+
+ if self.cfg.pred_head_type == "point":
+ pts_all, pts_conf = self.point_head(
+ aggregated_tokens_list,
+ images=image,
+ patch_start_idx=patch_start_idx,
+ )
+ elif self.cfg.pred_head_type == "depth":
+ depth_map, depth_conf = self.depth_head(
+ aggregated_tokens_list,
+ images=image,
+ patch_start_idx=patch_start_idx,
+ )
+ pts_all = batchify_unproject_depth_map_to_point_map(
+ depth_map, extrinsic, intrinsic
+ )
+ else:
+ raise ValueError(f"Invalid pred_head_type: {self.cfg.pred_head_type}")
+
+ if self.cfg.render_conf:
+ conf_valid = torch.quantile(
+ depth_conf.flatten(0, 1), self.cfg.conf_threshold
+ )
+ conf_valid_mask = depth_conf > conf_valid
+ else:
+ conf_valid_mask = torch.ones_like(depth_conf, dtype=torch.bool)
+
+ # dpt style gs_head input format
+ out = self.gaussian_param_head(
+ aggregated_tokens_list,
+ pts_all.flatten(0, 1).permute(0, 3, 1, 2),
+ image,
+ patch_start_idx=patch_start_idx,
+ image_size=(h, w),
+ )
+
+ del aggregated_tokens_list, patch_start_idx
+ torch.cuda.empty_cache()
+
+ pts_flat = pts_all.flatten(2, 3)
+ scene_scale = pts_flat.norm(dim=-1).mean().clip(min=1e-8)
+
+ anchor_feats, conf = out[:, :, : self.raw_gs_dim], out[:, :, self.raw_gs_dim]
+
+ neural_feats_list, neural_pts_list = [], []
+ if self.cfg.voxelize:
+ for b_i in range(b):
+ neural_pts, neural_feats = self.voxelizaton_with_fusion(
+ anchor_feats[b_i],
+ pts_all[b_i].permute(0, 3, 1, 2).contiguous(),
+ self.voxel_size,
+ conf=conf[b_i],
+ )
+ neural_feats_list.append(neural_feats)
+ neural_pts_list.append(neural_pts)
+ else:
+ for b_i in range(b):
+ neural_feats_list.append(
+ anchor_feats[b_i].permute(0, 2, 3, 1)[conf_valid_mask[b_i]]
+ )
+ neural_pts_list.append(pts_all[b_i][conf_valid_mask[b_i]])
+
+ max_voxels = max(f.shape[0] for f in neural_feats_list)
+ neural_feats = self.pad_tensor_list(
+ neural_feats_list, (max_voxels,), value=-1e10
+ )
+
+ neural_pts = self.pad_tensor_list(
+ neural_pts_list, (max_voxels,), -1e4
+ ) # -1 == invalid voxel
+
+ depths = neural_pts[..., -1].unsqueeze(-1)
+ densities = neural_feats[..., 0].sigmoid()
+
+ assert len(densities.shape) == 2, "the shape of densities should be (B, N)"
+ assert neural_pts.shape[1] > 1, "the number of voxels should be greater than 1"
+
+ opacity = self.map_pdf_to_opacity(densities, global_step).squeeze(-1)
+ if self.cfg.opacity_conf:
+ shift = torch.quantile(depth_conf, self.cfg.conf_threshold)
+ opacity = opacity * torch.sigmoid(depth_conf - shift)[
+ conf_valid_mask
+ ].unsqueeze(
+ 0
+ ) # little bit hacky
+
+ # GS Prune, but only works when bs = 1
+ # if want to support bs > 1, need to random prune gaussians based on the rank of opacity like LongLRM
+ # Note: we not prune gaussians here, but we will try it in the future
+ if self.cfg.gs_prune and b == 1:
+ opacity_threshold = self.cfg.opacity_threshold
+ gaussian_usage = opacity > opacity_threshold # (B, N)
+
+ print(
+ f"based on opacity threshold {opacity_threshold}, pruned {gaussian_usage.shape[1] - neural_pts.shape[1]} gaussians out of {gaussian_usage.shape[1]}"
+ )
+
+ if (gaussian_usage.sum() / gaussian_usage.numel()) > self.cfg.gs_keep_ratio:
+ # rank by opacity
+ num_keep = int(gaussian_usage.shape[1] * self.cfg.gs_keep_ratio)
+ idx_sort = opacity.argsort(dim=1, descending=True)
+ keep_idx = idx_sort[:, :num_keep]
+ gaussian_usage = torch.zeros_like(gaussian_usage, dtype=torch.bool)
+ gaussian_usage.scatter_(1, keep_idx, True)
+
+ neural_pts = neural_pts[gaussian_usage].view(b, -1, 3).contiguous()
+ depths = depths[gaussian_usage].view(b, -1, 1).contiguous()
+ neural_feats = (
+ neural_feats[gaussian_usage].view(b, -1, self.raw_gs_dim).contiguous()
+ )
+ opacity = opacity[gaussian_usage].view(b, -1).contiguous()
+
+ print(
+ f"finally pruned {gaussian_usage.shape[1] - neural_pts.shape[1]} gaussians out of {gaussian_usage.shape[1]}"
+ )
+
+ gaussians = self.gaussian_adapter.forward(
+ neural_pts,
+ depths,
+ opacity,
+ neural_feats[..., 1:].squeeze(2),
+ )
+
+ if visualization_dump is not None:
+ visualization_dump["depth"] = rearrange(
+ pts_all[..., -1].flatten(2, 3).unsqueeze(-1).unsqueeze(-1),
+ "b v (h w) srf s -> b v h w srf s",
+ h=h,
+ w=w,
+ )
+
+ infos = {}
+ infos["scene_scale"] = scene_scale
+ infos["voxelize_ratio"] = densities.shape[1] / (h * w * v)
+
+ print(
+ f"scene scale: {scene_scale:.3f}, pixel-wise num: {h*w*v}, after voxelize: {neural_pts.shape[1]}, voxelize ratio: {infos['voxelize_ratio']:.3f}"
+ )
+ print(
+ f"Gaussians attributes: \n"
+ f"opacities: mean: {gaussians.opacities.mean()}, min: {gaussians.opacities.min()}, max: {gaussians.opacities.max()} \n"
+ f"scales: mean: {gaussians.scales.mean()}, min: {gaussians.scales.min()}, max: {gaussians.scales.max()}"
+ )
+
+ print("B:", b, "V:", v, "H:", h, "W:", w)
+ extrinsic_padding = (
+ torch.tensor([0, 0, 0, 1], device=device, dtype=extrinsic.dtype)
+ .view(1, 1, 1, 4)
+ .repeat(b, v, 1, 1)
+ )
+ intrinsic = intrinsic.clone() # Create a new tensor
+ intrinsic = torch.stack(
+ [intrinsic[:, :, 0] / w, intrinsic[:, :, 1] / h, intrinsic[:, :, 2]], dim=2
+ )
+
+ return EncoderOutput(
+ gaussians=gaussians,
+ pred_pose_enc_list=pred_pose_enc_list,
+ pred_context_pose=dict(
+ extrinsic=torch.cat([extrinsic, extrinsic_padding], dim=2).inverse(),
+ intrinsic=intrinsic,
+ ),
+ depth_dict=dict(depth=depth_map, conf_valid_mask=conf_valid_mask),
+ infos=infos,
+ distill_infos=distill_infos,
+ )
+
+ def get_data_shim(self) -> DataShim:
+ def data_shim(batch: BatchedExample) -> BatchedExample:
+ batch = apply_normalize_shim(
+ batch,
+ self.cfg.input_mean,
+ self.cfg.input_std,
+ )
+
+ return batch
+
+ return data_shim
diff --git a/src/model/encoder/backbone/__init__.py b/src/model/encoder/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6a0d24132a5fa7ea3d42ed06141d55b93d21036
--- /dev/null
+++ b/src/model/encoder/backbone/__init__.py
@@ -0,0 +1,21 @@
+from typing import Any
+import torch.nn as nn
+
+from .backbone import Backbone
+from .backbone_croco_multiview import AsymmetricCroCoMulti
+from .backbone_dino import BackboneDino, BackboneDinoCfg
+from .backbone_resnet import BackboneResnet, BackboneResnetCfg
+from .backbone_croco import AsymmetricCroCo, BackboneCrocoCfg
+
+BACKBONES: dict[str, Backbone[Any]] = {
+ "resnet": BackboneResnet,
+ "dino": BackboneDino,
+ "croco": AsymmetricCroCo,
+ "croco_multi": AsymmetricCroCoMulti,
+}
+
+BackboneCfg = BackboneResnetCfg | BackboneDinoCfg | BackboneCrocoCfg
+
+
+def get_backbone(cfg: BackboneCfg, d_in: int = 3) -> nn.Module:
+ return BACKBONES[cfg.name](cfg, d_in)
diff --git a/src/model/encoder/backbone/backbone.py b/src/model/encoder/backbone/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..2644e44eb38279165314fcb47bd3a7757abb76cc
--- /dev/null
+++ b/src/model/encoder/backbone/backbone.py
@@ -0,0 +1,29 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+from jaxtyping import Float
+from torch import Tensor, nn
+
+from src.dataset.types import BatchedViews
+
+T = TypeVar("T")
+
+
+class Backbone(nn.Module, ABC, Generic[T]):
+ cfg: T
+
+ def __init__(self, cfg: T) -> None:
+ super().__init__()
+ self.cfg = cfg
+
+ @abstractmethod
+ def forward(
+ self,
+ context: BatchedViews,
+ ) -> Float[Tensor, "batch view d_out height width"]:
+ pass
+
+ @property
+ @abstractmethod
+ def d_out(self) -> int:
+ pass
diff --git a/src/model/encoder/backbone/backbone_croco.py b/src/model/encoder/backbone/backbone_croco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eb26b5cfff2ab4fa5961d18d7166d37cde54805
--- /dev/null
+++ b/src/model/encoder/backbone/backbone_croco.py
@@ -0,0 +1,281 @@
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from torch import nn
+
+from .croco.blocks import DecoderBlock
+from .croco.croco import CroCoNet
+from .croco.misc import fill_default_args, freeze_all_params, transpose_to_landscape, is_symmetrized, interleave, \
+ make_batch_symmetric
+from .croco.patch_embed import get_patch_embed
+from .backbone import Backbone
+from src.geometry.camera_emb import get_intrinsic_embedding
+
+
+inf = float('inf')
+
+
+croco_params = {
+ 'ViTLarge_BaseDecoder': {
+ 'enc_depth': 24,
+ 'dec_depth': 12,
+ 'enc_embed_dim': 1024,
+ 'dec_embed_dim': 768,
+ 'enc_num_heads': 16,
+ 'dec_num_heads': 12,
+ 'pos_embed': 'RoPE100',
+ 'img_size': (512, 512),
+ },
+}
+
+default_dust3r_params = {
+ 'enc_depth': 24,
+ 'dec_depth': 12,
+ 'enc_embed_dim': 1024,
+ 'dec_embed_dim': 768,
+ 'enc_num_heads': 16,
+ 'dec_num_heads': 12,
+ 'pos_embed': 'RoPE100',
+ 'patch_embed_cls': 'PatchEmbedDust3R',
+ 'img_size': (512, 512),
+ 'head_type': 'dpt',
+ 'output_mode': 'pts3d',
+ 'depth_mode': ('exp', -inf, inf),
+ 'conf_mode': ('exp', 1, inf)
+}
+
+
+@dataclass
+class BackboneCrocoCfg:
+ name: Literal["croco", "croco_multi"]
+ model: Literal["ViTLarge_BaseDecoder", "ViTBase_SmallDecoder", "ViTBase_BaseDecoder"] # keep interface for the last two models, but they are not supported
+ patch_embed_cls: str = 'PatchEmbedDust3R' # PatchEmbedDust3R or ManyAR_PatchEmbed
+ asymmetry_decoder: bool = True
+ intrinsics_embed_loc: Literal["encoder", "decoder", "none"] = 'none'
+ intrinsics_embed_degree: int = 0
+ intrinsics_embed_type: Literal["pixelwise", "linear", "token"] = 'token' # linear or dpt
+
+
+class AsymmetricCroCo(CroCoNet):
+ """ Two siamese encoders, followed by two decoders.
+ The goal is to output 3d points directly, both images in view1's frame
+ (hence the asymmetry).
+ """
+
+ def __init__(self, cfg: BackboneCrocoCfg, d_in: int) -> None:
+
+ self.intrinsics_embed_loc = cfg.intrinsics_embed_loc
+ self.intrinsics_embed_degree = cfg.intrinsics_embed_degree
+ self.intrinsics_embed_type = cfg.intrinsics_embed_type
+ self.intrinsics_embed_encoder_dim = 0
+ self.intrinsics_embed_decoder_dim = 0
+ if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
+ self.intrinsics_embed_encoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3
+ elif self.intrinsics_embed_loc == 'decoder' and self.intrinsics_embed_type == 'pixelwise':
+ self.intrinsics_embed_decoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3
+
+ self.patch_embed_cls = cfg.patch_embed_cls
+ self.croco_args = fill_default_args(croco_params[cfg.model], CroCoNet.__init__)
+
+ super().__init__(**croco_params[cfg.model])
+
+ if cfg.asymmetry_decoder:
+ self.dec_blocks2 = deepcopy(self.dec_blocks) # This is used in DUSt3R and MASt3R
+
+ if self.intrinsics_embed_type == 'linear' or self.intrinsics_embed_type == 'token':
+ self.intrinsic_encoder = nn.Linear(9, 1024)
+
+ # self.set_freeze(freeze)
+
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768, in_chans=3):
+ in_chans = in_chans + self.intrinsics_embed_encoder_dim
+ self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans)
+
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
+ self.dec_depth = dec_depth
+ self.dec_embed_dim = dec_embed_dim
+ # transfer from encoder to decoder
+ enc_embed_dim = enc_embed_dim + self.intrinsics_embed_decoder_dim
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
+ # transformer for the decoder
+ self.dec_blocks = nn.ModuleList([
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
+ for i in range(dec_depth)])
+ # final norm layer
+ self.dec_norm = norm_layer(dec_embed_dim)
+
+ def load_state_dict(self, ckpt, **kw):
+ # duplicate all weights for the second decoder if not present
+ new_ckpt = dict(ckpt)
+ if not any(k.startswith('dec_blocks2') for k in ckpt):
+ for key, value in ckpt.items():
+ if key.startswith('dec_blocks'):
+ new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
+ return super().load_state_dict(new_ckpt, **kw)
+
+ def set_freeze(self, freeze): # this is for use by downstream models
+ assert freeze in ['none', 'mask', 'encoder'], f"unexpected freeze={freeze}"
+ to_be_frozen = {
+ 'none': [],
+ 'mask': [self.mask_token],
+ 'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
+ 'encoder_decoder': [self.mask_token, self.patch_embed, self.enc_blocks, self.enc_norm, self.decoder_embed, self.dec_blocks, self.dec_blocks2, self.dec_norm],
+ }
+ freeze_all_params(to_be_frozen[freeze])
+
+ def _set_prediction_head(self, *args, **kwargs):
+ """ No prediction head """
+ return
+
+ def _encode_image(self, image, true_shape, intrinsics_embed=None):
+ # embed the image into patches (x has size B x Npatches x C)
+ x, pos = self.patch_embed(image, true_shape=true_shape)
+
+ if intrinsics_embed is not None:
+
+ if self.intrinsics_embed_type == 'linear':
+ x = x + intrinsics_embed
+ elif self.intrinsics_embed_type == 'token':
+ x = torch.cat((x, intrinsics_embed), dim=1)
+ add_pose = pos[:, 0:1, :].clone()
+ add_pose[:, :, 0] += (pos[:, -1, 0].unsqueeze(-1) + 1)
+ pos = torch.cat((pos, add_pose), dim=1)
+
+ # add positional embedding without cls token
+ assert self.enc_pos_embed is None
+
+ # now apply the transformer encoder and normalization
+ for blk in self.enc_blocks:
+ x = blk(x, pos)
+
+ x = self.enc_norm(x)
+ return x, pos, None
+
+ def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2, intrinsics_embed1=None, intrinsics_embed2=None):
+ if img1.shape[-2:] == img2.shape[-2:]:
+ out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
+ torch.cat((true_shape1, true_shape2), dim=0),
+ torch.cat((intrinsics_embed1, intrinsics_embed2), dim=0) if intrinsics_embed1 is not None else None)
+ out, out2 = out.chunk(2, dim=0)
+ pos, pos2 = pos.chunk(2, dim=0)
+ else:
+ out, pos, _ = self._encode_image(img1, true_shape1, intrinsics_embed1)
+ out2, pos2, _ = self._encode_image(img2, true_shape2, intrinsics_embed2)
+ return out, out2, pos, pos2
+
+ def _encode_symmetrized(self, view1, view2, force_asym=False):
+ img1 = view1['img']
+ img2 = view2['img']
+ B = img1.shape[0]
+ # Recover true_shape when available, otherwise assume that the img shape is the true one
+ shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
+ shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
+ # warning! maybe the images have different portrait/landscape orientations
+
+ intrinsics_embed1 = view1.get('intrinsics_embed', None)
+ intrinsics_embed2 = view2.get('intrinsics_embed', None)
+
+ if force_asym or not is_symmetrized(view1, view2):
+ feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2, intrinsics_embed1, intrinsics_embed2)
+ else:
+ # computing half of forward pass!'
+ feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
+ feat1, feat2 = interleave(feat1, feat2)
+ pos1, pos2 = interleave(pos1, pos2)
+
+ return (shape1, shape2), (feat1, feat2), (pos1, pos2)
+
+ def _decoder(self, f1, pos1, f2, pos2, extra_embed1=None, extra_embed2=None):
+ final_output = [(f1, f2)] # before projection
+
+ if extra_embed1 is not None:
+ f1 = torch.cat((f1, extra_embed1), dim=-1)
+ if extra_embed2 is not None:
+ f2 = torch.cat((f2, extra_embed2), dim=-1)
+
+ # project to decoder dim
+ f1 = self.decoder_embed(f1)
+ f2 = self.decoder_embed(f2)
+
+ final_output.append((f1, f2))
+ for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
+ # img1 side
+ f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
+ # img2 side
+ f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
+ # store the result
+ final_output.append((f1, f2))
+
+ # normalize last output
+ del final_output[1] # duplicate with final_output[0]
+ final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
+ return zip(*final_output)
+
+ def _downstream_head(self, head_num, decout, img_shape):
+ B, S, D = decout[-1].shape
+ # img_shape = tuple(map(int, img_shape))
+ head = getattr(self, f'head{head_num}')
+ return head(decout, img_shape)
+
+ def forward(self,
+ context: dict,
+ symmetrize_batch=False,
+ return_views=False,
+ ):
+ b, v, _, h, w = context["image"].shape
+ device = context["image"].device
+
+ view1, view2 = ({'img': context["image"][:, 0]},
+ {'img': context["image"][:, 1]})
+
+ # camera embedding in the encoder
+ if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
+ intrinsic_emb = get_intrinsic_embedding(context, degree=self.intrinsics_embed_degree)
+ view1['img'] = torch.cat((view1['img'], intrinsic_emb[:, 0]), dim=1)
+ view2['img'] = torch.cat((view2['img'], intrinsic_emb[:, 1]), dim=1)
+
+ if self.intrinsics_embed_loc == 'encoder' and (self.intrinsics_embed_type == 'token' or self.intrinsics_embed_type == 'linear'):
+ intrinsic_embedding = self.intrinsic_encoder(context["intrinsics"].flatten(2))
+ view1['intrinsics_embed'] = intrinsic_embedding[:, 0].unsqueeze(1)
+ view2['intrinsics_embed'] = intrinsic_embedding[:, 1].unsqueeze(1)
+
+ if symmetrize_batch:
+ instance_list_view1, instance_list_view2 = [0 for _ in range(b)], [1 for _ in range(b)]
+ view1['instance'] = instance_list_view1
+ view2['instance'] = instance_list_view2
+ view1['idx'] = instance_list_view1
+ view2['idx'] = instance_list_view2
+ view1, view2 = make_batch_symmetric(view1, view2)
+
+ # encode the two images --> B,S,D
+ (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2, force_asym=False)
+ else:
+ # encode the two images --> B,S,D
+ (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2, force_asym=True)
+
+ if self.intrinsics_embed_loc == 'decoder':
+ # FIXME: downsample is hardcoded to 16
+ intrinsic_emb = get_intrinsic_embedding(context, degree=self.intrinsics_embed_degree, downsample=16, merge_hw=True)
+ dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2, intrinsic_emb[:, 0], intrinsic_emb[:, 1])
+ else:
+ dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)
+
+ if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'token':
+ dec1, dec2 = list(dec1), list(dec2)
+ for i in range(len(dec1)):
+ dec1[i] = dec1[i][:, :-1]
+ dec2[i] = dec2[i][:, :-1]
+
+ if return_views:
+ return dec1, dec2, shape1, shape2, view1, view2
+ return dec1, dec2, shape1, shape2
+
+ @property
+ def patch_size(self) -> int:
+ return 16
+
+ @property
+ def d_out(self) -> int:
+ return 1024
diff --git a/src/model/encoder/backbone/backbone_croco_multiview.py b/src/model/encoder/backbone/backbone_croco_multiview.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d5689e0899c3b77dd49a9bed97bfa6760ecea6e
--- /dev/null
+++ b/src/model/encoder/backbone/backbone_croco_multiview.py
@@ -0,0 +1,245 @@
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from einops import rearrange
+from torch import nn
+
+from .croco.blocks import DecoderBlock
+from .croco.croco import CroCoNet
+from .croco.misc import fill_default_args, freeze_all_params, transpose_to_landscape, is_symmetrized, interleave, \
+ make_batch_symmetric
+from .croco.patch_embed import get_patch_embed
+from .backbone import Backbone
+from src.geometry.camera_emb import get_intrinsic_embedding
+
+inf = float('inf')
+
+
+croco_params = {
+ 'ViTLarge_BaseDecoder': {
+ 'enc_depth': 24,
+ 'dec_depth': 12,
+ 'enc_embed_dim': 1024,
+ 'dec_embed_dim': 768,
+ 'enc_num_heads': 16,
+ 'dec_num_heads': 12,
+ 'pos_embed': 'RoPE100',
+ 'img_size': (512, 512),
+ },
+}
+
+default_dust3r_params = {
+ 'enc_depth': 24,
+ 'dec_depth': 12,
+ 'enc_embed_dim': 1024,
+ 'dec_embed_dim': 768,
+ 'enc_num_heads': 16,
+ 'dec_num_heads': 12,
+ 'pos_embed': 'RoPE100',
+ 'patch_embed_cls': 'PatchEmbedDust3R',
+ 'img_size': (512, 512),
+ 'head_type': 'dpt',
+ 'output_mode': 'pts3d',
+ 'depth_mode': ('exp', -inf, inf),
+ 'conf_mode': ('exp', 1, inf)
+}
+
+
+@dataclass
+class BackboneCrocoCfg:
+ name: Literal["croco"]
+ model: Literal["ViTLarge_BaseDecoder", "ViTBase_SmallDecoder", "ViTBase_BaseDecoder"] # keep interface for the last two models, but they are not supported
+ patch_embed_cls: str = 'PatchEmbedDust3R' # PatchEmbedDust3R or ManyAR_PatchEmbed
+ asymmetry_decoder: bool = True
+ intrinsics_embed_loc: Literal["encoder", "decoder", "none"] = 'none'
+ intrinsics_embed_degree: int = 0
+ intrinsics_embed_type: Literal["pixelwise", "linear", "token"] = 'token' # linear or dpt
+
+
+class AsymmetricCroCoMulti(CroCoNet):
+ """ Two siamese encoders, followed by two decoders.
+ The goal is to output 3d points directly, both images in view1's frame
+ (hence the asymmetry).
+ """
+
+ def __init__(self, cfg: BackboneCrocoCfg, d_in: int) -> None:
+
+ self.intrinsics_embed_loc = cfg.intrinsics_embed_loc
+ self.intrinsics_embed_degree = cfg.intrinsics_embed_degree
+ self.intrinsics_embed_type = cfg.intrinsics_embed_type
+ self.intrinsics_embed_encoder_dim = 0
+ self.intrinsics_embed_decoder_dim = 0
+ if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
+ self.intrinsics_embed_encoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3
+ elif self.intrinsics_embed_loc == 'decoder' and self.intrinsics_embed_type == 'pixelwise':
+ self.intrinsics_embed_decoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3
+
+ self.patch_embed_cls = cfg.patch_embed_cls
+ self.croco_args = fill_default_args(croco_params[cfg.model], CroCoNet.__init__)
+
+ super().__init__(**croco_params[cfg.model])
+
+ if cfg.asymmetry_decoder:
+ self.dec_blocks2 = deepcopy(self.dec_blocks) # This is used in DUSt3R and MASt3R
+
+ if self.intrinsics_embed_type == 'linear' or self.intrinsics_embed_type == 'token':
+ self.intrinsic_encoder = nn.Linear(9, 1024)
+
+ # self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
+ # self.set_freeze(freeze)
+
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768, in_chans=3):
+ in_chans = in_chans + self.intrinsics_embed_encoder_dim
+ self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans)
+
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
+ self.dec_depth = dec_depth
+ self.dec_embed_dim = dec_embed_dim
+ # transfer from encoder to decoder
+ enc_embed_dim = enc_embed_dim + self.intrinsics_embed_decoder_dim
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
+ # transformer for the decoder
+ self.dec_blocks = nn.ModuleList([
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
+ for i in range(dec_depth)])
+ # final norm layer
+ self.dec_norm = norm_layer(dec_embed_dim)
+
+ def load_state_dict(self, ckpt, **kw):
+ # duplicate all weights for the second decoder if not present
+ new_ckpt = dict(ckpt)
+ if not any(k.startswith('dec_blocks2') for k in ckpt):
+ for key, value in ckpt.items():
+ if key.startswith('dec_blocks'):
+ new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
+ return super().load_state_dict(new_ckpt, **kw)
+
+ def set_freeze(self, freeze): # this is for use by downstream models
+ assert freeze in ['none', 'mask', 'encoder'], f"unexpected freeze={freeze}"
+ to_be_frozen = {
+ 'none': [],
+ 'mask': [self.mask_token],
+ 'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
+ 'encoder_decoder': [self.mask_token, self.patch_embed, self.enc_blocks, self.enc_norm, self.decoder_embed, self.dec_blocks, self.dec_blocks2, self.dec_norm],
+ }
+ freeze_all_params(to_be_frozen[freeze])
+
+ def _set_prediction_head(self, *args, **kwargs):
+ """ No prediction head """
+ return
+
+ def _encode_image(self, image, true_shape, intrinsics_embed=None):
+ # embed the image into patches (x has size B x Npatches x C)
+ x, pos = self.patch_embed(image, true_shape=true_shape)
+
+ if intrinsics_embed is not None:
+
+ if self.intrinsics_embed_type == 'linear':
+ x = x + intrinsics_embed
+ elif self.intrinsics_embed_type == 'token':
+ x = torch.cat((x, intrinsics_embed), dim=1)
+ add_pose = pos[:, 0:1, :].clone()
+ add_pose[:, :, 0] += (pos[:, -1, 0].unsqueeze(-1) + 1)
+ pos = torch.cat((pos, add_pose), dim=1)
+
+ # add positional embedding without cls token
+ assert self.enc_pos_embed is None
+
+ # now apply the transformer encoder and normalization
+ for blk in self.enc_blocks:
+ x = blk(x, pos)
+
+ x = self.enc_norm(x)
+ return x, pos, None
+
+ def _decoder(self, feat, pose, extra_embed=None):
+ b, v, l, c = feat.shape
+ final_output = [feat] # before projection
+ if extra_embed is not None:
+ feat = torch.cat((feat, extra_embed), dim=-1)
+
+ # project to decoder dim
+ f = rearrange(feat, "b v l c -> (b v) l c")
+ f = self.decoder_embed(f)
+ f = rearrange(f, "(b v) l c -> b v l c", b=b, v=v)
+ final_output.append(f)
+
+ def generate_ctx_views(x):
+ b, v, l, c = x.shape
+ ctx_views = x.unsqueeze(1).expand(b, v, v, l, c)
+ mask = torch.arange(v).unsqueeze(0) != torch.arange(v).unsqueeze(1)
+ ctx_views = ctx_views[:, mask].reshape(b, v, v - 1, l, c) # B, V, V-1, L, C
+ ctx_views = ctx_views.flatten(2, 3) # B, V, (V-1)*L, C
+ return ctx_views.contiguous()
+
+ pos_ctx = generate_ctx_views(pose)
+ for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
+ feat_current = final_output[-1]
+ feat_current_ctx = generate_ctx_views(feat_current)
+ # img1 side
+ f1, _ = blk1(feat_current[:, 0].contiguous(), feat_current_ctx[:, 0].contiguous(), pose[:, 0].contiguous(), pos_ctx[:, 0].contiguous())
+ f1 = f1.unsqueeze(1)
+ # img2 side
+ f2, _ = blk2(rearrange(feat_current[:, 1:], "b v l c -> (b v) l c"),
+ rearrange(feat_current_ctx[:, 1:], "b v l c -> (b v) l c"),
+ rearrange(pose[:, 1:], "b v l c -> (b v) l c"),
+ rearrange(pos_ctx[:, 1:], "b v l c -> (b v) l c"))
+ f2 = rearrange(f2, "(b v) l c -> b v l c", b=b, v=v-1)
+ # store the result
+ final_output.append(torch.cat((f1, f2), dim=1))
+
+ # normalize last output
+ del final_output[1] # duplicate with final_output[0]
+ last_feat = rearrange(final_output[-1], "b v l c -> (b v) l c")
+ last_feat = self.dec_norm(last_feat)
+ final_output[-1] = rearrange(last_feat, "(b v) l c -> b v l c", b=b, v=v)
+ return final_output
+
+ def forward(self,
+ context: dict,
+ symmetrize_batch=False,
+ return_views=False,
+ ):
+ b, v, _, h, w = context["image"].shape
+ images_all = context["image"]
+
+ # camera embedding in the encoder
+ if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
+ intrinsic_embedding = get_intrinsic_embedding(context, degree=self.intrinsics_embed_degree)
+ images_all = torch.cat((images_all, intrinsic_embedding), dim=2)
+
+ intrinsic_embedding_all = None
+ if self.intrinsics_embed_loc == 'encoder' and (self.intrinsics_embed_type == 'token' or self.intrinsics_embed_type == 'linear'):
+ intrinsic_embedding = self.intrinsic_encoder(context["intrinsics"].flatten(2))
+ intrinsic_embedding_all = rearrange(intrinsic_embedding, "b v c -> (b v) c").unsqueeze(1)
+
+ # step 1: encoder input images
+ images_all = rearrange(images_all, "b v c h w -> (b v) c h w")
+ shape_all = torch.tensor(images_all.shape[-2:])[None].repeat(b*v, 1)
+
+ feat, pose, _ = self._encode_image(images_all, shape_all, intrinsic_embedding_all)
+
+ feat = rearrange(feat, "(b v) l c -> b v l c", b=b, v=v)
+ pose = rearrange(pose, "(b v) l c -> b v l c", b=b, v=v)
+
+ # step 2: decoder
+ dec_feat = self._decoder(feat, pose)
+ shape = rearrange(shape_all, "(b v) c -> b v c", b=b, v=v)
+ images = rearrange(images_all, "(b v) c h w -> b v c h w", b=b, v=v)
+
+ if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'token':
+ dec_feat = list(dec_feat)
+ for i in range(len(dec_feat)):
+ dec_feat[i] = dec_feat[i][:, :, :-1]
+
+ return dec_feat, shape, images
+
+ @property
+ def patch_size(self) -> int:
+ return 16
+
+ @property
+ def d_out(self) -> int:
+ return 1024
diff --git a/src/model/encoder/backbone/backbone_dino.py b/src/model/encoder/backbone/backbone_dino.py
new file mode 100644
index 0000000000000000000000000000000000000000..a69fce1c633e25b7e93cba45f0a5895587fd293b
--- /dev/null
+++ b/src/model/encoder/backbone/backbone_dino.py
@@ -0,0 +1,79 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from einops import rearrange, repeat
+from jaxtyping import Float
+from torch import Tensor, nn
+
+from src.dataset.types import BatchedViews
+from .backbone import Backbone
+from .backbone_resnet import BackboneResnet, BackboneResnetCfg
+
+
+@dataclass
+class BackboneDinoCfg:
+ name: Literal["dino"]
+ model: Literal["dino_vits16", "dino_vits8", "dino_vitb16", "dino_vitb8"]
+ d_out: int
+
+
+class BackboneDino(Backbone[BackboneDinoCfg]):
+ def __init__(self, cfg: BackboneDinoCfg, d_in: int) -> None:
+ super().__init__(cfg)
+ assert d_in == 3
+ self.dino = torch.hub.load("facebookresearch/dino:main", cfg.model)
+ self.resnet_backbone = BackboneResnet(
+ BackboneResnetCfg("resnet", "dino_resnet50", 4, False, cfg.d_out),
+ d_in,
+ )
+ self.global_token_mlp = nn.Sequential(
+ nn.Linear(768, 768),
+ nn.ReLU(),
+ nn.Linear(768, cfg.d_out),
+ )
+ self.local_token_mlp = nn.Sequential(
+ nn.Linear(768, 768),
+ nn.ReLU(),
+ nn.Linear(768, cfg.d_out),
+ )
+
+ def forward(
+ self,
+ context: BatchedViews,
+ ) -> Float[Tensor, "batch view d_out height width"]:
+ # Compute features from the DINO-pretrained resnet50.
+ resnet_features = self.resnet_backbone(context)
+
+ # Compute features from the DINO-pretrained ViT.
+ b, v, _, h, w = context["image"].shape
+ assert h % self.patch_size == 0 and w % self.patch_size == 0
+ tokens = rearrange(context["image"], "b v c h w -> (b v) c h w")
+ tokens = self.dino.get_intermediate_layers(tokens)[0]
+ global_token = self.global_token_mlp(tokens[:, 0])
+ local_tokens = self.local_token_mlp(tokens[:, 1:])
+
+ # Repeat the global token to match the image shape.
+ global_token = repeat(global_token, "(b v) c -> b v c h w", b=b, v=v, h=h, w=w)
+
+ # Repeat the local tokens to match the image shape.
+ local_tokens = repeat(
+ local_tokens,
+ "(b v) (h w) c -> b v c (h hps) (w wps)",
+ b=b,
+ v=v,
+ h=h // self.patch_size,
+ hps=self.patch_size,
+ w=w // self.patch_size,
+ wps=self.patch_size,
+ )
+
+ return resnet_features + local_tokens + global_token
+
+ @property
+ def patch_size(self) -> int:
+ return int("".join(filter(str.isdigit, self.cfg.model)))
+
+ @property
+ def d_out(self) -> int:
+ return self.cfg.d_out
diff --git a/src/model/encoder/backbone/backbone_resnet.py b/src/model/encoder/backbone/backbone_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c0c320ad500acae43fb3483965876007f16e7b
--- /dev/null
+++ b/src/model/encoder/backbone/backbone_resnet.py
@@ -0,0 +1,100 @@
+import functools
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from einops import rearrange
+from jaxtyping import Float
+from torch import Tensor, nn
+from torchvision.models import ResNet
+
+from src.dataset.types import BatchedViews
+from .backbone import Backbone
+
+
+@dataclass
+class BackboneResnetCfg:
+ name: Literal["resnet"]
+ model: Literal[
+ "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "dino_resnet50"
+ ]
+ num_layers: int
+ use_first_pool: bool
+ d_out: int
+
+
+class BackboneResnet(Backbone[BackboneResnetCfg]):
+ model: ResNet
+
+ def __init__(self, cfg: BackboneResnetCfg, d_in: int) -> None:
+ super().__init__(cfg)
+
+ assert d_in == 3
+
+ norm_layer = functools.partial(
+ nn.InstanceNorm2d,
+ affine=False,
+ track_running_stats=False,
+ )
+
+ if cfg.model == "dino_resnet50":
+ self.model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50")
+ else:
+ self.model = getattr(torchvision.models, cfg.model)(norm_layer=norm_layer)
+
+ # Set up projections
+ self.projections = nn.ModuleDict({})
+ for index in range(1, cfg.num_layers):
+ key = f"layer{index}"
+ block = getattr(self.model, key)
+ conv_index = 1
+ try:
+ while True:
+ d_layer_out = getattr(block[-1], f"conv{conv_index}").out_channels
+ conv_index += 1
+ except AttributeError:
+ pass
+ self.projections[key] = nn.Conv2d(d_layer_out, cfg.d_out, 1)
+
+ # Add a projection for the first layer.
+ self.projections["layer0"] = nn.Conv2d(
+ self.model.conv1.out_channels, cfg.d_out, 1
+ )
+
+ def forward(
+ self,
+ context: BatchedViews,
+ ) -> Float[Tensor, "batch view d_out height width"]:
+ # Merge the batch dimensions.
+ b, v, _, h, w = context["image"].shape
+ x = rearrange(context["image"], "b v c h w -> (b v) c h w")
+
+ # Run the images through the resnet.
+ x = self.model.conv1(x)
+ x = self.model.bn1(x)
+ x = self.model.relu(x)
+ features = [self.projections["layer0"](x)]
+
+ # Propagate the input through the resnet's layers.
+ for index in range(1, self.cfg.num_layers):
+ key = f"layer{index}"
+ if index == 0 and self.cfg.use_first_pool:
+ x = self.model.maxpool(x)
+ x = getattr(self.model, key)(x)
+ features.append(self.projections[key](x))
+
+ # Upscale the features.
+ features = [
+ F.interpolate(f, (h, w), mode="bilinear", align_corners=True)
+ for f in features
+ ]
+ features = torch.stack(features).sum(dim=0)
+
+ # Separate batch dimensions.
+ return rearrange(features, "(b v) c h w -> b v c h w", b=b, v=v)
+
+ @property
+ def d_out(self) -> int:
+ return self.cfg.d_out
diff --git a/src/model/encoder/backbone/croco/README.md b/src/model/encoder/backbone/croco/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..af8a9f0dfce215af0e9a4a17fbdabe3a0c18f0ed
--- /dev/null
+++ b/src/model/encoder/backbone/croco/README.md
@@ -0,0 +1,7 @@
+Most of the code under src/model/encoder/backbone/croco/ is from the original CROCO implementation.
+The code is not modified in any way except the relative module path.
+The original code can be found at [croco Github Repo](https://github.com/naver/croco/tree/743ee71a2a9bf57cea6832a9064a70a0597fcfcb/models).
+
+
+Except:
+- 'misc.py', 'patch_embed.py' is from DUSt3R.
\ No newline at end of file
diff --git a/src/model/encoder/backbone/croco/__init__.py b/src/model/encoder/backbone/croco/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/model/encoder/backbone/croco/blocks.py b/src/model/encoder/backbone/croco/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..18133524f0ae265b0bd8d062d7c9eeaa63858a9b
--- /dev/null
+++ b/src/model/encoder/backbone/croco/blocks.py
@@ -0,0 +1,241 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# Main encoder/decoder blocks
+# --------------------------------------------------------
+# References:
+# timm
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
+
+
+import torch
+import torch.nn as nn
+
+from itertools import repeat
+import collections.abc
+
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+to_2tuple = _ntuple(2)
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+class Attention(nn.Module):
+
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(self, x, xpos):
+ B, N, C = x.shape
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
+ q, k, v = [qkv[:,:,i] for i in range(3)]
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x, xpos):
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+class CrossAttention(nn.Module):
+
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.rope = rope
+
+ def forward(self, query, key, value, qpos, kpos):
+ B, Nq, C = query.shape
+ Nk = key.shape[1]
+ Nv = value.shape[1]
+
+ q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
+ k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
+ v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
+
+ if self.rope is not None:
+ q = self.rope(q, qpos)
+ k = self.rope(k, kpos)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class DecoderBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.norm3 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
+
+ def forward(self, x, y, xpos, ypos):
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
+ y_ = self.norm_y(y)
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
+ return x, y
+
+
+# patch embedding
+class PositionGetter(object):
+ """ return positions of patches """
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, h, w, device):
+ if not (h,w) in self.cache_positions:
+ x = torch.arange(w, device=device)
+ y = torch.arange(h, device=device)
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
+ return pos
+
+class PatchEmbed(nn.Module):
+ """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ self.position_getter = PositionGetter()
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+ torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+ x = self.proj(x)
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x, pos
+
+ def _init_weights(self):
+ w = self.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
diff --git a/src/model/encoder/backbone/croco/croco.py b/src/model/encoder/backbone/croco/croco.py
new file mode 100644
index 0000000000000000000000000000000000000000..74bcb8521a1ef0f198f911ab3e7dac166b281f5c
--- /dev/null
+++ b/src/model/encoder/backbone/croco/croco.py
@@ -0,0 +1,249 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# CroCo model during pretraining
+# --------------------------------------------------------
+
+
+
+import torch
+import torch.nn as nn
+torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
+from functools import partial
+
+from .blocks import Block, DecoderBlock, PatchEmbed
+from .pos_embed import get_2d_sincos_pos_embed, RoPE2D
+from .masking import RandomMask
+
+
+class CroCoNet(nn.Module):
+
+ def __init__(self,
+ img_size=224, # input image size
+ patch_size=16, # patch_size
+ mask_ratio=0.9, # ratios of masked tokens
+ enc_embed_dim=768, # encoder feature dimension
+ enc_depth=12, # encoder depth
+ enc_num_heads=12, # encoder number of heads in the transformer block
+ dec_embed_dim=512, # decoder feature dimension
+ dec_depth=8, # decoder depth
+ dec_num_heads=16, # decoder number of heads in the transformer block
+ mlp_ratio=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
+ pos_embed='cosine', # positional embedding (either cosine or RoPE100)
+ ):
+
+ super(CroCoNet, self).__init__()
+
+ # patch embeddings (with initialization done as in MAE)
+ self._set_patch_embed(img_size, patch_size, enc_embed_dim)
+
+ # mask generations
+ self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
+
+ self.pos_embed = pos_embed
+ if pos_embed=='cosine':
+ # positional embedding of the encoder
+ enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
+ self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
+ # positional embedding of the decoder
+ dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
+ self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
+ # pos embedding in each block
+ self.rope = None # nothing for cosine
+ elif pos_embed.startswith('RoPE'): # eg RoPE100
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
+ if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
+ freq = float(pos_embed[len('RoPE'):])
+ self.rope = RoPE2D(freq=freq)
+ else:
+ raise NotImplementedError('Unknown pos_embed '+pos_embed)
+
+ # transformer for the encoder
+ self.enc_depth = enc_depth
+ self.enc_embed_dim = enc_embed_dim
+ self.enc_blocks = nn.ModuleList([
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
+ for i in range(enc_depth)])
+ self.enc_norm = norm_layer(enc_embed_dim)
+
+ # masked tokens
+ self._set_mask_token(dec_embed_dim)
+
+ # decoder
+ self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
+
+ # prediction head
+ self._set_prediction_head(dec_embed_dim, patch_size)
+
+ # initializer weights
+ self.initialize_weights()
+
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
+
+ def _set_mask_generator(self, num_patches, mask_ratio):
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
+
+ def _set_mask_token(self, dec_embed_dim):
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
+
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
+ self.dec_depth = dec_depth
+ self.dec_embed_dim = dec_embed_dim
+ # transfer from encoder to decoder
+ self.decoder_embed = nn.Linear(enc_embed_dim+0, dec_embed_dim, bias=True)
+ # transformer for the decoder
+ self.dec_blocks = nn.ModuleList([
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
+ for i in range(dec_depth)])
+ # final norm layer
+ self.dec_norm = norm_layer(dec_embed_dim)
+
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
+
+
+ def initialize_weights(self):
+ # patch embed
+ self.patch_embed._init_weights()
+ # mask tokens
+ if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
+ # linears and layer norms
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
+ """
+ image has B x 3 x img_size x img_size
+ do_mask: whether to perform masking or not
+ return_all_blocks: if True, return the features at the end of every block
+ instead of just the features from the last block (eg for some prediction heads)
+ """
+ # embed the image into patches (x has size B x Npatches x C)
+ # and get position if each return patch (pos has size B x Npatches x 2)
+ x, pos = self.patch_embed(image)
+ # add positional embedding without cls token
+ if self.enc_pos_embed is not None:
+ x = x + self.enc_pos_embed[None,...]
+ # apply masking
+ B,N,C = x.size()
+ if do_mask:
+ masks = self.mask_generator(x)
+ x = x[~masks].view(B, -1, C)
+ posvis = pos[~masks].view(B, -1, 2)
+ else:
+ B,N,C = x.size()
+ masks = torch.zeros((B,N), dtype=bool)
+ posvis = pos
+ # now apply the transformer encoder and normalization
+ if return_all_blocks:
+ out = []
+ for blk in self.enc_blocks:
+ x = blk(x, posvis)
+ out.append(x)
+ out[-1] = self.enc_norm(out[-1])
+ return out, pos, masks
+ else:
+ for blk in self.enc_blocks:
+ x = blk(x, posvis)
+ x = self.enc_norm(x)
+ return x, pos, masks
+
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
+ """
+ return_all_blocks: if True, return the features at the end of every block
+ instead of just the features from the last block (eg for some prediction heads)
+
+ masks1 can be None => assume image1 fully visible
+ """
+ # encoder to decoder layer
+ visf1 = self.decoder_embed(feat1)
+ f2 = self.decoder_embed(feat2)
+ # append masked tokens to the sequence
+ B,Nenc,C = visf1.size()
+ if masks1 is None: # downstreams
+ f1_ = visf1
+ else: # pretraining
+ Ntotal = masks1.size(1)
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
+ f1_[~masks1] = visf1.view(B * Nenc, C)
+ # add positional embedding
+ if self.dec_pos_embed is not None:
+ f1_ = f1_ + self.dec_pos_embed
+ f2 = f2 + self.dec_pos_embed
+ # apply Transformer blocks
+ out = f1_
+ out2 = f2
+ if return_all_blocks:
+ _out, out = out, []
+ for blk in self.dec_blocks:
+ _out, out2 = blk(_out, out2, pos1, pos2)
+ out.append(_out)
+ out[-1] = self.dec_norm(out[-1])
+ else:
+ for blk in self.dec_blocks:
+ out, out2 = blk(out, out2, pos1, pos2)
+ out = self.dec_norm(out)
+ return out
+
+ def patchify(self, imgs):
+ """
+ imgs: (B, 3, H, W)
+ x: (B, L, patch_size**2 *3)
+ """
+ p = self.patch_embed.patch_size[0]
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum('nchpwq->nhwpqc', x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+
+ return x
+
+ def unpatchify(self, x, channels=3):
+ """
+ x: (N, L, patch_size**2 *channels)
+ imgs: (N, 3, H, W)
+ """
+ patch_size = self.patch_embed.patch_size[0]
+ h = w = int(x.shape[1]**.5)
+ assert h * w == x.shape[1]
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
+ return imgs
+
+ def forward(self, img1, img2):
+ """
+ img1: tensor of size B x 3 x img_size x img_size
+ img2: tensor of size B x 3 x img_size x img_size
+
+ out will be B x N x (3*patch_size*patch_size)
+ masks are also returned as B x N just in case
+ """
+ # encoder of the masked first image
+ feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
+ # encoder of the second image
+ feat2, pos2, _ = self._encode_image(img2, do_mask=False)
+ # decoder
+ decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
+ # prediction head
+ out = self.prediction_head(decfeat)
+ # get target
+ target = self.patchify(img1)
+ return out, mask1, target
diff --git a/src/model/encoder/backbone/croco/curope/__init__.py b/src/model/encoder/backbone/croco/curope/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e3d48a162760260826080f6366838e83e26878
--- /dev/null
+++ b/src/model/encoder/backbone/croco/curope/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+from .curope2d import cuRoPE2D
diff --git a/src/model/encoder/backbone/croco/curope/curope.cpp b/src/model/encoder/backbone/croco/curope/curope.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8fe9058e05aa1bf3f37b0d970edc7312bc68455b
--- /dev/null
+++ b/src/model/encoder/backbone/croco/curope/curope.cpp
@@ -0,0 +1,69 @@
+/*
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+*/
+
+#include
+
+// forward declaration
+void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
+
+void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
+{
+ const int B = tokens.size(0);
+ const int N = tokens.size(1);
+ const int H = tokens.size(2);
+ const int D = tokens.size(3) / 4;
+
+ auto tok = tokens.accessor();
+ auto pos = positions.accessor();
+
+ for (int b = 0; b < B; b++) {
+ for (int x = 0; x < 2; x++) { // y and then x (2d)
+ for (int n = 0; n < N; n++) {
+
+ // grab the token position
+ const int p = pos[b][n][x];
+
+ for (int h = 0; h < H; h++) {
+ for (int d = 0; d < D; d++) {
+ // grab the two values
+ float u = tok[b][n][h][d+0+x*2*D];
+ float v = tok[b][n][h][d+D+x*2*D];
+
+ // grab the cos,sin
+ const float inv_freq = fwd * p / powf(base, d/float(D));
+ float c = cosf(inv_freq);
+ float s = sinf(inv_freq);
+
+ // write the result
+ tok[b][n][h][d+0+x*2*D] = u*c - v*s;
+ tok[b][n][h][d+D+x*2*D] = v*c + u*s;
+ }
+ }
+ }
+ }
+ }
+}
+
+void rope_2d( torch::Tensor tokens, // B,N,H,D
+ const torch::Tensor positions, // B,N,2
+ const float base,
+ const float fwd )
+{
+ TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
+ TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
+ TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
+ TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
+ TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
+ TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
+
+ if (tokens.is_cuda())
+ rope_2d_cuda( tokens, positions, base, fwd );
+ else
+ rope_2d_cpu( tokens, positions, base, fwd );
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
+}
diff --git a/src/model/encoder/backbone/croco/curope/curope2d.py b/src/model/encoder/backbone/croco/curope/curope2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49c12f8c529e9a889b5ac20c5767158f238e17d
--- /dev/null
+++ b/src/model/encoder/backbone/croco/curope/curope2d.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import torch
+
+try:
+ import curope as _kernels # run `python setup.py install`
+except ModuleNotFoundError:
+ from . import curope as _kernels # run `python setup.py build_ext --inplace`
+
+
+class cuRoPE2D_func (torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, tokens, positions, base, F0=1):
+ ctx.save_for_backward(positions)
+ ctx.saved_base = base
+ ctx.saved_F0 = F0
+ # tokens = tokens.clone() # uncomment this if inplace doesn't work
+ _kernels.rope_2d( tokens, positions, base, F0 )
+ ctx.mark_dirty(tokens)
+ return tokens
+
+ @staticmethod
+ def backward(ctx, grad_res):
+ positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
+ _kernels.rope_2d( grad_res, positions, base, -F0 )
+ ctx.mark_dirty(grad_res)
+ return grad_res, None, None, None
+
+
+class cuRoPE2D(torch.nn.Module):
+ def __init__(self, freq=100.0, F0=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+
+ def forward(self, tokens, positions):
+ cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
+ return tokens
\ No newline at end of file
diff --git a/src/model/encoder/backbone/croco/curope/kernels.cu b/src/model/encoder/backbone/croco/curope/kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7156cd1bb935cb1f0be45e58add53f9c21505c20
--- /dev/null
+++ b/src/model/encoder/backbone/croco/curope/kernels.cu
@@ -0,0 +1,108 @@
+/*
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+*/
+
+#include
+#include
+#include
+#include
+
+#define CHECK_CUDA(tensor) {\
+ TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
+ TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
+void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
+
+
+template < typename scalar_t >
+__global__ void rope_2d_cuda_kernel(
+ //scalar_t* __restrict__ tokens,
+ torch::PackedTensorAccessor32 tokens,
+ const int64_t* __restrict__ pos,
+ const float base,
+ const float fwd )
+ // const int N, const int H, const int D )
+{
+ // tokens shape = (B, N, H, D)
+ const int N = tokens.size(1);
+ const int H = tokens.size(2);
+ const int D = tokens.size(3);
+
+ // each block update a single token, for all heads
+ // each thread takes care of a single output
+ extern __shared__ float shared[];
+ float* shared_inv_freq = shared + D;
+
+ const int b = blockIdx.x / N;
+ const int n = blockIdx.x % N;
+
+ const int Q = D / 4;
+ // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
+ // u_Y v_Y u_X v_X
+
+ // shared memory: first, compute inv_freq
+ if (threadIdx.x < Q)
+ shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
+ __syncthreads();
+
+ // start of X or Y part
+ const int X = threadIdx.x < D/2 ? 0 : 1;
+ const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
+
+ // grab the cos,sin appropriate for me
+ const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
+ const float cos = cosf(freq);
+ const float sin = sinf(freq);
+ /*
+ float* shared_cos_sin = shared + D + D/4;
+ if ((threadIdx.x % (D/2)) < Q)
+ shared_cos_sin[m+0] = cosf(freq);
+ else
+ shared_cos_sin[m+Q] = sinf(freq);
+ __syncthreads();
+ const float cos = shared_cos_sin[m+0];
+ const float sin = shared_cos_sin[m+Q];
+ */
+
+ for (int h = 0; h < H; h++)
+ {
+ // then, load all the token for this head in shared memory
+ shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
+ __syncthreads();
+
+ const float u = shared[m];
+ const float v = shared[m+Q];
+
+ // write output
+ if ((threadIdx.x % (D/2)) < Q)
+ tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
+ else
+ tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
+ }
+}
+
+void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
+{
+ const int B = tokens.size(0); // batch size
+ const int N = tokens.size(1); // sequence length
+ const int H = tokens.size(2); // number of heads
+ const int D = tokens.size(3); // dimension per head
+
+ TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
+ TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
+ TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
+ TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
+
+ // one block for each layer, one thread per local-max
+ const int THREADS_PER_BLOCK = D;
+ const int N_BLOCKS = B * N; // each block takes care of H*D values
+ const int SHARED_MEM = sizeof(float) * (D + D/4);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
+ rope_2d_cuda_kernel <<>> (
+ //tokens.data_ptr(),
+ tokens.packed_accessor32(),
+ pos.data_ptr(),
+ base, fwd); //, N, H, D );
+ }));
+}
diff --git a/src/model/encoder/backbone/croco/curope/setup.py b/src/model/encoder/backbone/croco/curope/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..230632ed05e309200e8f93a3a852072333975009
--- /dev/null
+++ b/src/model/encoder/backbone/croco/curope/setup.py
@@ -0,0 +1,34 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+from setuptools import setup
+from torch import cuda
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+# compile for all possible CUDA architectures
+all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
+# alternatively, you can list cuda archs that you want, eg:
+# all_cuda_archs = [
+ # '-gencode', 'arch=compute_70,code=sm_70',
+ # '-gencode', 'arch=compute_75,code=sm_75',
+ # '-gencode', 'arch=compute_80,code=sm_80',
+ # '-gencode', 'arch=compute_86,code=sm_86'
+# ]
+
+setup(
+ name = 'curope',
+ ext_modules = [
+ CUDAExtension(
+ name='curope',
+ sources=[
+ "curope.cpp",
+ "kernels.cu",
+ ],
+ extra_compile_args = dict(
+ nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
+ cxx=['-O3'])
+ )
+ ],
+ cmdclass = {
+ 'build_ext': BuildExtension
+ })
diff --git a/src/model/encoder/backbone/croco/masking.py b/src/model/encoder/backbone/croco/masking.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb0d36f53efb4d42f3270db515235dceea8a44c2
--- /dev/null
+++ b/src/model/encoder/backbone/croco/masking.py
@@ -0,0 +1,25 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# Masking utils
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+
+class RandomMask(nn.Module):
+ """
+ random masking
+ """
+
+ def __init__(self, num_patches, mask_ratio):
+ super().__init__()
+ self.num_patches = num_patches
+ self.num_mask = int(mask_ratio * self.num_patches)
+
+ def __call__(self, x):
+ noise = torch.rand(x.size(0), self.num_patches, device=x.device)
+ argsort = torch.argsort(noise, dim=1)
+ return argsort < self.num_mask
diff --git a/src/model/encoder/backbone/croco/misc.py b/src/model/encoder/backbone/croco/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7e882197d44d2c6d778f0b79bd90675f2a124ff
--- /dev/null
+++ b/src/model/encoder/backbone/croco/misc.py
@@ -0,0 +1,138 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utilitary functions for DUSt3R
+# --------------------------------------------------------
+import torch
+
+
+def fill_default_args(kwargs, func):
+ import inspect # a bit hacky but it works reliably
+ signature = inspect.signature(func)
+
+ for k, v in signature.parameters.items():
+ if v.default is inspect.Parameter.empty:
+ continue
+ kwargs.setdefault(k, v.default)
+
+ return kwargs
+
+
+def freeze_all_params(modules):
+ for module in modules:
+ try:
+ for n, param in module.named_parameters():
+ param.requires_grad = False
+ except AttributeError:
+ # module is directly a parameter
+ module.requires_grad = False
+
+
+def is_symmetrized(gt1, gt2):
+ x = gt1['instance']
+ y = gt2['instance']
+ if len(x) == len(y) and len(x) == 1:
+ return False # special case of batchsize 1
+ ok = True
+ for i in range(0, len(x), 2):
+ ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i])
+ return ok
+
+
+def flip(tensor):
+ """ flip so that tensor[0::2] <=> tensor[1::2] """
+ return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
+
+
+def interleave(tensor1, tensor2):
+ res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
+ res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
+ return res1, res2
+
+
+def _interleave_imgs(img1, img2):
+ res = {}
+ for key, value1 in img1.items():
+ value2 = img2[key]
+ if isinstance(value1, torch.Tensor):
+ value = torch.stack((value1, value2), dim=1).flatten(0, 1)
+ else:
+ value = [x for pair in zip(value1, value2) for x in pair]
+ res[key] = value
+ return res
+
+
+def make_batch_symmetric(view1, view2):
+ view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
+ return view1, view2
+
+
+def transpose_to_landscape(head, activate=True):
+ """ Predict in the correct aspect-ratio,
+ then transpose the result in landscape
+ and stack everything back together.
+ """
+ def wrapper_no(decout, true_shape, ray_embedding=None):
+ B = len(true_shape)
+ assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
+ H, W = true_shape[0].cpu().tolist()
+ res = head(decout, (H, W), ray_embedding=ray_embedding)
+ return res
+
+ def wrapper_yes(decout, true_shape, ray_embedding=None):
+ B = len(true_shape)
+ # by definition, the batch is in landscape mode so W >= H
+ H, W = int(true_shape.min()), int(true_shape.max())
+
+ height, width = true_shape.T
+ is_landscape = (width >= height)
+ is_portrait = ~is_landscape
+
+ # true_shape = true_shape.cpu()
+ if is_landscape.all():
+ return head(decout, (H, W), ray_embedding=ray_embedding)
+ if is_portrait.all():
+ return transposed(head(decout, (W, H), ray_embedding=ray_embedding))
+
+ # batch is a mix of both portraint & landscape
+ def selout(ar): return [d[ar] for d in decout]
+ l_result = head(selout(is_landscape), (H, W), ray_embedding=ray_embedding)
+ p_result = transposed(head(selout(is_portrait), (W, H), ray_embedding=ray_embedding))
+
+ # allocate full result
+ result = {}
+ for k in l_result | p_result:
+ x = l_result[k].new(B, *l_result[k].shape[1:])
+ x[is_landscape] = l_result[k]
+ x[is_portrait] = p_result[k]
+ result[k] = x
+
+ return result
+
+ return wrapper_yes if activate else wrapper_no
+
+
+def transposed(dic):
+ return {k: v.swapaxes(1, 2) for k, v in dic.items()}
+
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = float('nan')
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = 0
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+ else:
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr, nnz
diff --git a/src/model/encoder/backbone/croco/patch_embed.py b/src/model/encoder/backbone/croco/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c88f4685fedcbaa9a6b686536cdef47c27c6a7
--- /dev/null
+++ b/src/model/encoder/backbone/croco/patch_embed.py
@@ -0,0 +1,70 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# PatchEmbed implementation for DUST3R,
+# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio
+# --------------------------------------------------------
+import torch
+
+from .blocks import PatchEmbed
+
+
+def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3):
+ assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
+ patch_embed = eval(patch_embed_cls)(img_size, patch_size, in_chans, enc_embed_dim)
+ return patch_embed
+
+
+class PatchEmbedDust3R(PatchEmbed):
+ def forward(self, x, **kw):
+ B, C, H, W = x.shape
+ assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
+ assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
+ x = self.proj(x)
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x, pos
+
+
+class ManyAR_PatchEmbed (PatchEmbed):
+ """ Handle images with non-square aspect ratio.
+ All images in the same batch have the same aspect ratio.
+ true_shape = [(height, width) ...] indicates the actual shape of each image.
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ self.embed_dim = embed_dim
+ super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)
+
+ def forward(self, img, true_shape):
+ B, C, H, W = img.shape
+ assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
+ assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
+ assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
+ assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"
+
+ # size expressed in tokens
+ W //= self.patch_size[0]
+ H //= self.patch_size[1]
+ n_tokens = H * W
+
+ height, width = true_shape.T
+ is_landscape = (width >= height)
+ is_portrait = ~is_landscape
+
+ # allocate result
+ x = img.new_zeros((B, n_tokens, self.embed_dim))
+ pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)
+
+ # linear projection, transposed if necessary
+ x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
+ x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()
+
+ pos[is_landscape] = self.position_getter(1, H, W, pos.device)
+ pos[is_portrait] = self.position_getter(1, W, H, pos.device)
+
+ x = self.norm(x)
+ return x, pos
diff --git a/src/model/encoder/backbone/croco/pos_embed.py b/src/model/encoder/backbone/croco/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..eacb29dc11ce773a8d2381d58dc72693e687ca8f
--- /dev/null
+++ b/src/model/encoder/backbone/croco/pos_embed.py
@@ -0,0 +1,159 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if n_cls_token>0:
+ pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
+
+#----------------------------------------------------------
+# RoPE2D: RoPE implementation in 2D
+#----------------------------------------------------------
+
+try:
+ from .curope import cuRoPE2D
+ RoPE2D = cuRoPE2D
+except ImportError:
+ print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
+
+ class RoPE2D(torch.nn.Module):
+
+ def __init__(self, freq=100.0, F0=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D,seq_len,device,dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D,seq_len,device,dtype] = (cos,sin)
+ return self.cache[D,seq_len,device,dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim==2
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
+ output:
+ * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
+ """
+ assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
+ D = tokens.size(3) // 2
+ assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
+ cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
+ # split features into two along the feature dimension, and apply rope1d on each half
+ y, x = tokens.chunk(2, dim=-1)
+ y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
+ x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
+ tokens = torch.cat((y, x), dim=-1)
+ return tokens
\ No newline at end of file
diff --git a/src/model/encoder/common/gaussian_adapter.py b/src/model/encoder/common/gaussian_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..17315195b810903fe8aaf8cc0bbf857b39905890
--- /dev/null
+++ b/src/model/encoder/common/gaussian_adapter.py
@@ -0,0 +1,180 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from einops import einsum, rearrange
+from jaxtyping import Float
+from torch import Tensor, nn
+
+from src.geometry.projection import get_world_rays
+from src.misc.sh_rotation import rotate_sh
+from .gaussians import build_covariance
+
+from ...types import Gaussians
+
+@dataclass
+class GaussianAdapterCfg:
+ gaussian_scale_min: float
+ gaussian_scale_max: float
+ sh_degree: int
+
+
+class GaussianAdapter(nn.Module):
+ cfg: GaussianAdapterCfg
+
+ def __init__(self, cfg: GaussianAdapterCfg):
+ super().__init__()
+ self.cfg = cfg
+
+ # Create a mask for the spherical harmonics coefficients. This ensures that at
+ # initialization, the coefficients are biased towards having a large DC
+ # component and small view-dependent components.
+ self.register_buffer(
+ "sh_mask",
+ torch.ones((self.d_sh,), dtype=torch.float32),
+ persistent=False,
+ )
+ for degree in range(1, self.cfg.sh_degree + 1):
+ self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree
+
+ def forward(
+ self,
+ extrinsics: Float[Tensor, "*#batch 4 4"],
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+ coordinates: Float[Tensor, "*#batch 2"],
+ depths: Float[Tensor, "*#batch"],
+ opacities: Float[Tensor, "*#batch"],
+ raw_gaussians: Float[Tensor, "*#batch _"],
+ image_shape: tuple[int, int],
+ eps: float = 1e-8,
+ ) -> Gaussians:
+ device = extrinsics.device
+ scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
+
+ # Map scale features to valid scale range.
+ scale_min = self.cfg.gaussian_scale_min
+ scale_max = self.cfg.gaussian_scale_max
+ scales = scale_min + (scale_max - scale_min) * scales.sigmoid()
+ h, w = image_shape
+ pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=device)
+ multiplier = self.get_scale_multiplier(intrinsics, pixel_size)
+ scales = scales * depths[..., None] * multiplier[..., None]
+
+ # Normalize the quaternion features to yield a valid quaternion.
+ rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
+
+ sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
+ sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask
+
+ # Create world-space covariance matrices.
+ covariances = build_covariance(scales, rotations)
+ c2w_rotations = extrinsics[..., :3, :3]
+ covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2)
+
+ # Compute Gaussian means.
+ origins, directions = get_world_rays(coordinates, extrinsics, intrinsics)
+ means = origins + directions * depths[..., None]
+
+ return Gaussians(
+ means=means,
+ covariances=covariances,
+ # harmonics=rotate_sh(sh, c2w_rotations[..., None, :, :]),
+ harmonics=sh,
+ opacities=opacities,
+ # Note: These aren't yet rotated into world space, but they're only used for
+ # exporting Gaussians to ply files. This needs to be fixed...
+ scales=scales,
+ rotations=rotations.broadcast_to((*scales.shape[:-1], 4)),
+ )
+
+ def get_scale_multiplier(
+ self,
+ intrinsics: Float[Tensor, "*#batch 3 3"],
+ pixel_size: Float[Tensor, "*#batch 2"],
+ multiplier: float = 0.1,
+ ) -> Float[Tensor, " *batch"]:
+ xy_multipliers = multiplier * einsum(
+ intrinsics[..., :2, :2].inverse(),
+ pixel_size,
+ "... i j, j -> ... i",
+ )
+ return xy_multipliers.sum(dim=-1)
+
+ @property
+ def d_sh(self) -> int:
+ return (self.cfg.sh_degree + 1) ** 2
+
+ @property
+ def d_in(self) -> int:
+ return 7 + 3 * self.d_sh
+
+
+class UnifiedGaussianAdapter(GaussianAdapter):
+ def forward(
+ self,
+ means: Float[Tensor, "*#batch 3"],
+ # levels: Float[Tensor, "*#batch"],
+ depths: Float[Tensor, "*#batch"],
+ opacities: Float[Tensor, "*#batch"],
+ raw_gaussians: Float[Tensor, "*#batch _"],
+ eps: float = 1e-8,
+ intrinsics: Optional[Float[Tensor, "*#batch 3 3"]] = None,
+ coordinates: Optional[Float[Tensor, "*#batch 2"]] = None,
+ ) -> Gaussians:
+ scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
+
+ scales = 0.001 * F.softplus(scales)
+ scales = scales.clamp_max(0.3)
+
+ # Normalize the quaternion features to yield a valid quaternion.
+ rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
+
+ sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
+ sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask
+ # print(scales.max())
+ covariances = build_covariance(scales, rotations)
+
+ return Gaussians(
+ means=means.float(),
+ # levels=levels.int(),
+ covariances=covariances.float(),
+ harmonics=sh.float(),
+ opacities=opacities.float(),
+ scales=scales.float(),
+ rotations=rotations.float(),
+ )
+
+class Unet3dGaussianAdapter(GaussianAdapter):
+ def forward(
+ self,
+ means: Float[Tensor, "*#batch 3"],
+ depths: Float[Tensor, "*#batch"],
+ opacities: Float[Tensor, "*#batch"],
+ raw_gaussians: Float[Tensor, "*#batch _"],
+ eps: float = 1e-8,
+ intrinsics: Optional[Float[Tensor, "*#batch 3 3"]] = None,
+ coordinates: Optional[Float[Tensor, "*#batch 2"]] = None,
+ ) -> Gaussians:
+ scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
+
+ scales = 0.001 * F.softplus(scales)
+ scales = scales.clamp_max(0.3)
+
+ # Normalize the quaternion features to yield a valid quaternion.
+ rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
+
+ sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
+ sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask
+
+ covariances = build_covariance(scales, rotations)
+
+ return Gaussians(
+ means=means,
+ covariances=covariances,
+ harmonics=sh,
+ opacities=opacities,
+ scales=scales,
+ rotations=rotations,
+ )
+
diff --git a/src/model/encoder/common/gaussians.py b/src/model/encoder/common/gaussians.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bff0519fa2d8c7e93d66ede9da9e5166e554a5e
--- /dev/null
+++ b/src/model/encoder/common/gaussians.py
@@ -0,0 +1,44 @@
+import torch
+from einops import rearrange
+from jaxtyping import Float
+from torch import Tensor
+
+
+# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
+def quaternion_to_matrix(
+ quaternions: Float[Tensor, "*batch 4"],
+ eps: float = 1e-8,
+) -> Float[Tensor, "*batch 3 3"]:
+ # Order changed to match scipy format!
+ i, j, k, r = torch.unbind(quaternions, dim=-1)
+ two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return rearrange(o, "... (i j) -> ... i j", i=3, j=3)
+
+
+def build_covariance(
+ scale: Float[Tensor, "*#batch 3"],
+ rotation_xyzw: Float[Tensor, "*#batch 4"],
+) -> Float[Tensor, "*batch 3 3"]:
+ scale = scale.diag_embed()
+ rotation = quaternion_to_matrix(rotation_xyzw)
+ return (
+ rotation
+ @ scale
+ @ rearrange(scale, "... i j -> ... j i")
+ @ rearrange(rotation, "... i j -> ... j i")
+ )
diff --git a/src/model/encoder/encoder.py b/src/model/encoder/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a26d3f949264ee362bc0730c837151ad1d118d
--- /dev/null
+++ b/src/model/encoder/encoder.py
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+from torch import nn
+from dataclasses import dataclass
+from src.dataset.types import BatchedViews, DataShim
+from ..types import Gaussians
+from jaxtyping import Float
+from torch import Tensor, nn
+
+T = TypeVar("T")
+
+@dataclass
+class EncoderOutput:
+ gaussians: Gaussians
+ pred_pose_enc_list: list[Float[Tensor, "batch view 6"]] | None
+ pred_context_pose: dict | None
+ depth_dict: dict | None
+ infos: dict | None
+ distill_infos: dict | None
+
+class Encoder(nn.Module, ABC, Generic[T]):
+ cfg: T
+
+ def __init__(self, cfg: T) -> None:
+ super().__init__()
+ self.cfg = cfg
+
+ @abstractmethod
+ def forward(
+ self,
+ context: BatchedViews,
+ ) -> Gaussians:
+ pass
+
+ def get_data_shim(self) -> DataShim:
+ """The default shim doesn't modify the batch."""
+ return lambda x: x
diff --git a/src/model/encoder/heads/__init__.py b/src/model/encoder/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a4f70d377289d23edf9e2f69e8c5134691281a9
--- /dev/null
+++ b/src/model/encoder/heads/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# head factory
+# --------------------------------------------------------
+from .dpt_gs_head import create_gs_dpt_head
+from .linear_head import LinearPts3d
+from .linear_gs_head import create_gs_linear_head
+from .dpt_head import create_dpt_head
+
+
+def head_factory(head_type, output_mode, net, has_conf=False, out_nchan=3):
+ """" build a prediction head for the decoder
+ """
+ if head_type == 'linear' and output_mode == 'pts3d':
+ return LinearPts3d(net, has_conf)
+ elif head_type == 'dpt' and output_mode == 'pts3d':
+ return create_dpt_head(net, has_conf=has_conf)
+ elif head_type == 'dpt' and output_mode == 'gs_params':
+ return create_dpt_head(net, has_conf=False, out_nchan=out_nchan, postprocess_func=None)
+ elif head_type == 'dpt_gs' and output_mode == 'gs_params':
+ return create_gs_dpt_head(net, has_conf=False, out_nchan=out_nchan, postprocess_func=None)
+ elif head_type == 'linear' and output_mode == 'ga_params':
+ return create_gs_linear_head(net, has_conf=False, out_nchan=out_nchan, postprocess_func=None)
+ else:
+ raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
diff --git a/src/model/encoder/heads/dpt_block.py b/src/model/encoder/heads/dpt_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d86c491b0111c94fbd2bb2d6595feca561ed408
--- /dev/null
+++ b/src/model/encoder/heads/dpt_block.py
@@ -0,0 +1,459 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+# --------------------------------------------------------
+# DPT head for ViTs
+# --------------------------------------------------------
+# References:
+# https://github.com/isl-org/DPT
+# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from typing import Union, Tuple, Iterable, List, Optional, Dict
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+def make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand == True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0],
+ out_shape1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1],
+ out_shape2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2],
+ out_shape3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3],
+ out_shape4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+
+ scratch.layer_rn = nn.ModuleList([
+ scratch.layer1_rn,
+ scratch.layer2_rn,
+ scratch.layer3_rn,
+ scratch.layer4_rn,
+ ])
+
+ return scratch
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups = 1
+
+ self.conv1 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ self.conv2 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ width_ratio=1,
+ ):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+ self.width_ratio = width_ratio
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups = 1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ groups=1,
+ )
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ if self.width_ratio != 1:
+ res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
+
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ if self.width_ratio != 1:
+ # and output.shape[3] < self.width_ratio * output.shape[2]
+ #size=(image.shape[])
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
+ shape = 3 * output.shape[3]
+ else:
+ shape = int(self.width_ratio * 2 * output.shape[2])
+ output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
+ else:
+ output = nn.functional.interpolate(output, scale_factor=2,
+ mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+ return output
+
+def make_fusion_block(features, use_bn, width_ratio=1, expand=False):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=expand,
+ align_corners=True,
+ width_ratio=width_ratio,
+ )
+
+class Interpolate(nn.Module):
+ """Interpolation module."""
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x,
+ scale_factor=self.scale_factor,
+ mode=self.mode,
+ align_corners=self.align_corners,
+ )
+
+ return x
+
+class DPTOutputAdapter(nn.Module):
+ """DPT output adapter.
+
+ :param num_cahnnels: Number of output channels
+ :param stride_level: tride level compared to the full-sized image.
+ E.g. 4 for 1/4th the size of the image.
+ :param patch_size_full: Int or tuple of the patch size over the full image size.
+ Patch size for smaller inputs will be computed accordingly.
+ :param hooks: Index of intermediate layers
+ :param layer_dims: Dimension of intermediate layers
+ :param feature_dim: Feature dimension
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
+ :param use_bn: If set to True, activates batch norm
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
+ """
+
+ def __init__(self,
+ num_channels: int = 1,
+ stride_level: int = 1,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ main_tasks: Iterable[str] = ('rgb',),
+ hooks: List[int] = [2, 5, 8, 11],
+ layer_dims: List[int] = [96, 192, 384, 768],
+ feature_dim: int = 256,
+ last_dim: int = 32,
+ use_bn: bool = False,
+ dim_tokens_enc: Optional[int] = None,
+ head_type: str = 'regression',
+ output_width_ratio=1,
+ **kwargs):
+ super().__init__()
+ self.num_channels = num_channels
+ self.stride_level = stride_level
+ self.patch_size = pair(patch_size)
+ self.main_tasks = main_tasks
+ self.hooks = hooks
+ self.layer_dims = layer_dims
+ self.feature_dim = feature_dim
+ self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
+ self.head_type = head_type
+
+ # Actual patch height and width, taking into account stride of input
+ self.P_H = max(1, self.patch_size[0] // stride_level)
+ self.P_W = max(1, self.patch_size[1] // stride_level)
+
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
+
+ self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+
+ if self.head_type == 'regression':
+ # The "DPTDepthModel" head
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
+ )
+ elif self.head_type == 'semseg':
+ # The "DPTSegmentationModel" head
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
+ nn.ReLU(True),
+ nn.Dropout(0.1, False),
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ )
+ elif self.head_type == 'gs_params':
+ # The "DPTSegmentationModel" head
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
+ nn.ReLU(True),
+ nn.Dropout(0.1, False),
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
+ )
+ else:
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
+
+ if self.dim_tokens_enc is not None:
+ self.init(dim_tokens_enc=dim_tokens_enc)
+
+ def init(self, dim_tokens_enc=768):
+ """
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
+ Should be called when setting up MultiMAE.
+
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
+ """
+ #print(dim_tokens_enc)
+
+ # Set up activation postprocessing layers
+ if isinstance(dim_tokens_enc, int):
+ dim_tokens_enc = 4 * [dim_tokens_enc]
+
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
+
+ self.act_1_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[0],
+ out_channels=self.layer_dims[0],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=self.layer_dims[0],
+ out_channels=self.layer_dims[0],
+ kernel_size=4, stride=4, padding=0,
+ bias=True, dilation=1, groups=1,
+ )
+ )
+
+ self.act_2_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[1],
+ out_channels=self.layer_dims[1],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=self.layer_dims[1],
+ out_channels=self.layer_dims[1],
+ kernel_size=2, stride=2, padding=0,
+ bias=True, dilation=1, groups=1,
+ )
+ )
+
+ self.act_3_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[2],
+ out_channels=self.layer_dims[2],
+ kernel_size=1, stride=1, padding=0,
+ )
+ )
+
+ self.act_4_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[3],
+ out_channels=self.layer_dims[3],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=self.layer_dims[3],
+ out_channels=self.layer_dims[3],
+ kernel_size=3, stride=2, padding=1,
+ )
+ )
+
+ self.act_postprocess = nn.ModuleList([
+ self.act_1_postprocess,
+ self.act_2_postprocess,
+ self.act_3_postprocess,
+ self.act_4_postprocess
+ ])
+
+ def adapt_tokens(self, encoder_tokens):
+ # Adapt tokens
+ x = []
+ x.append(encoder_tokens[:, :])
+ x = torch.cat(x, dim=-1)
+ return x
+
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
+ #input_info: Dict):
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+ H, W = image_size
+
+ # Number of patches in height and width
+ N_H = H // (self.stride_level * self.P_H)
+ N_W = W // (self.stride_level * self.P_W)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layers = [encoder_tokens[hook] for hook in self.hooks]
+
+ # Extract only task-relevant tokens and ignore global tokens.
+ layers = [self.adapt_tokens(l) for l in layers]
+
+ # Reshape tokens to spatial representation
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+ # Project layers to chosen feature dim
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+ # Fuse layers using refinement stages
+ path_4 = self.scratch.refinenet4(layers[3])
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+ # Output head
+ out = self.head(path_1)
+
+ return out
diff --git a/src/model/encoder/heads/dpt_gs_head.py b/src/model/encoder/heads/dpt_gs_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeabe742b132e6be13983101413a3fc0f069c2bb
--- /dev/null
+++ b/src/model/encoder/heads/dpt_gs_head.py
@@ -0,0 +1,284 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# dpt head implementation for DUST3R
+# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
+# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
+# the forward function also takes as input a dictionnary img_info with key "height" and "width"
+# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
+# --------------------------------------------------------
+from einops import rearrange
+from typing import List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# import dust3r.utils.path_to_croco
+from .dpt_block import DPTOutputAdapter, Interpolate, make_fusion_block
+from .head_modules import UnetExtractor, AppearanceTransformer, _init_weights
+from .postprocess import postprocess
+
+# class DPTOutputAdapter_fix(DPTOutputAdapter):
+# """
+# Adapt croco's DPTOutputAdapter implementation for dust3r:
+# remove duplicated weigths, and fix forward for dust3r
+# """
+#
+# def init(self, dim_tokens_enc=768):
+# super().init(dim_tokens_enc)
+# # these are duplicated weights
+# del self.act_1_postprocess
+# del self.act_2_postprocess
+# del self.act_3_postprocess
+# del self.act_4_postprocess
+#
+# self.scratch.refinenet1 = make_fusion_block(256 * 2, False, 1, expand=True)
+# self.scratch.refinenet2 = make_fusion_block(256 * 2, False, 1, expand=True)
+# self.scratch.refinenet3 = make_fusion_block(256 * 2, False, 1, expand=True)
+# # self.scratch.refinenet4 = make_fusion_block(256 * 2, False, 1)
+#
+# self.depth_encoder = UnetExtractor(in_channel=3)
+# self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
+# self.out_conv = nn.Conv2d(256+3+4, 256, kernel_size=3, padding=1)
+# self.out_relu = nn.ReLU(inplace=True)
+#
+# self.input_merger = nn.Sequential(
+# # nn.Conv2d(256+3+3+1, 256, kernel_size=3, padding=1),
+# nn.Conv2d(256+3+3, 256, kernel_size=3, padding=1),
+# nn.ReLU(),
+# )
+#
+# def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None):
+# assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+# # H, W = input_info['image_size']
+# image_size = self.image_size if image_size is None else image_size
+# H, W = image_size
+# # Number of patches in height and width
+# N_H = H // (self.stride_level * self.P_H)
+# N_W = W // (self.stride_level * self.P_W)
+#
+# # Hook decoder onto 4 layers from specified ViT layers
+# layers = [encoder_tokens[hook] for hook in self.hooks]
+#
+# # Extract only task-relevant tokens and ignore global tokens.
+# layers = [self.adapt_tokens(l) for l in layers]
+#
+# # Reshape tokens to spatial representation
+# layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+#
+# layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+# # Project layers to chosen feature dim
+# layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+#
+# # get depth features
+# depth_features = self.depth_encoder(depths)
+# depth_feature1, depth_feature2, depth_feature3 = depth_features
+#
+# # Fuse layers using refinement stages
+# path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
+# path_3 = self.scratch.refinenet3(torch.cat([path_4, depth_feature3], dim=1), torch.cat([layers[2], depth_feature3], dim=1))
+# path_2 = self.scratch.refinenet2(torch.cat([path_3, depth_feature2], dim=1), torch.cat([layers[1], depth_feature2], dim=1))
+# path_1 = self.scratch.refinenet1(torch.cat([path_2, depth_feature1], dim=1), torch.cat([layers[0], depth_feature1], dim=1))
+# # path_3 = self.scratch.refinenet3(path_4, layers[2], depth_feature3)
+# # path_2 = self.scratch.refinenet2(path_3, layers[1], depth_feature2)
+# # path_1 = self.scratch.refinenet1(path_2, layers[0], depth_feature1)
+#
+# path_1 = self.feat_up(path_1)
+# path_1 = torch.cat([path_1, imgs, depths], dim=1)
+# if conf is not None:
+# path_1 = torch.cat([path_1, conf], dim=1)
+# path_1 = self.input_merger(path_1)
+#
+# # Output head
+# out = self.head(path_1)
+#
+# return out
+
+
+class DPTOutputAdapter_fix(DPTOutputAdapter):
+ """
+ Adapt croco's DPTOutputAdapter implementation for dust3r:
+ remove duplicated weigths, and fix forward for dust3r
+ """
+
+ def init(self, dim_tokens_enc=768):
+ super().init(dim_tokens_enc)
+ # these are duplicated weights
+ del self.act_1_postprocess
+ del self.act_2_postprocess
+ del self.act_3_postprocess
+ del self.act_4_postprocess
+
+ self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
+ self.input_merger = nn.Sequential(
+ # nn.Conv2d(256+3+3+1, 256, kernel_size=3, padding=1),
+ # nn.Conv2d(3+6, 256, 7, 1, 3),
+ nn.Conv2d(3, 256, 7, 1, 3),
+ nn.ReLU(),
+ )
+
+ def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None):
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+ # H, W = input_info['image_size']
+ image_size = self.image_size if image_size is None else image_size
+ H, W = image_size
+ # Number of patches in height and width
+ N_H = H // (self.stride_level * self.P_H)
+ N_W = W // (self.stride_level * self.P_W)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layers = [encoder_tokens[hook] for hook in self.hooks]
+
+ # Extract only task-relevant tokens and ignore global tokens.
+ layers = [self.adapt_tokens(l) for l in layers]
+
+ # Reshape tokens to spatial representation
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+ # Project layers to chosen feature dim
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+ # Fuse layers using refinement stages
+ path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+ direct_img_feat = self.input_merger(imgs)
+ # imgs = imgs.permute(0, 2, 3, 1).flatten(1, 2).contiguous()
+ # # Pachify
+ # patch_size = self.patch_size
+ # hh = H // patch_size[0]
+ # ww = W // patch_size[1]
+ # direct_img_feat = rearrange(imgs, "b (hh ph ww pw) d -> b (hh ww) (ph pw d)", hh=hh, ww=ww, ph=patch_size[0], pw=patch_size[1])
+
+ # actually, we just do interpolate here
+ # path_1 = self.feat_up(path_1)
+ path_1 = F.interpolate(path_1, size=(H, W), mode='bilinear', align_corners=True)
+ path_1 = path_1 + direct_img_feat
+
+ # path_1 = torch.cat([path_1, imgs], dim=1)
+
+ # Output head
+ out = self.head(path_1)
+
+ return out, [path_4, path_3, path_2]
+
+
+class PixelwiseTaskWithDPT(nn.Module):
+ """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
+
+ def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
+ output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
+ super(PixelwiseTaskWithDPT, self).__init__()
+ self.return_all_layers = True # backbone needs to return all layers
+ self.postprocess = postprocess
+ self.depth_mode = depth_mode
+ self.conf_mode = conf_mode
+
+ assert n_cls_token == 0, "Not implemented"
+ dpt_args = dict(output_width_ratio=output_width_ratio,
+ num_channels=num_channels,
+ **kwargs)
+ if hooks_idx is not None:
+ dpt_args.update(hooks=hooks_idx)
+ self.dpt = DPTOutputAdapter_fix(**dpt_args)
+ dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
+ self.dpt.init(**dpt_init_args)
+
+ def forward(self, x, depths, imgs, img_info, conf=None):
+ out, interm_feats = self.dpt(x, depths, imgs, image_size=(img_info[0], img_info[1]), conf=conf)
+ if self.postprocess:
+ out = self.postprocess(out, self.depth_mode, self.conf_mode)
+ return out, interm_feats
+
+class AttnBasedAppearanceHead(nn.Module):
+ """
+ Attention head Appearence Reconstruction
+ """
+
+ def __init__(self, num_channels, patch_size, feature_dim, last_dim, hooks_idx, dim_tokens, postprocess, depth_mode, conf_mode, head_type='gs_params'):
+ super().__init__()
+
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+
+ self.hooks = hooks_idx
+
+ assert len(set(dim_tokens)) == 1
+
+ self.tokenizer = nn.Linear(3 * self.patch_size[0] ** 2, dim_tokens[0], bias=False)
+
+ self.attn_processor = AppearanceTransformer(num_layers=4, attn_dim=dim_tokens[0] * 2, head_dim=feature_dim)
+
+ self.token_decoder = nn.Sequential(
+ nn.LayerNorm(dim_tokens[0] * 2, bias=False),
+ nn.Linear(
+ dim_tokens[0] * 2, self.num_channels * (self.patch_size[0] ** 2),
+ bias=False,
+ )
+ )
+ self.token_decoder.apply(_init_weights)
+
+
+ def img_pts_tokenizer(self, imgs, pts3d):
+ B, V, _, H, W = imgs.shape
+ pts3d = pts3d.flatten(2, 3).contiguous()
+ imgs = imgs.permute(0, 1, 3, 4, 2).flatten(2, 3).contiguous()
+ mean = pts3d.mean(dim=-2, keepdim=True) # (B, V, 1, 3)
+ z_median = torch.median(torch.norm(pts3d, dim=-1, keepdim=True), dim=2, keepdim=True)[0] # (B, V, 1, 1)
+ pts3d_normed = (pts3d - mean) / (z_median + 1e-8) # (B, V, N, 3)
+
+ input = imgs #torch.cat([pts3d_normed, imgs], dim=-1) # (B, V, H*W, 9)
+ # Pachify
+ patch_size = self.patch_size
+ hh = H // patch_size[0]
+ ww = W // patch_size[1]
+ input = rearrange(input, "b v (hh ph ww pw) d -> (b v) (hh ww) (ph pw d)", hh=hh, ww=ww, ph=patch_size[0], pw=patch_size[1])
+ # Tokenize the input images
+ input_tokens = self.tokenizer(input)
+ return input_tokens
+
+ def forward(self, x, depths, imgs, img_info, conf=None):
+ B, V, H, W = img_info
+ input_tokens = rearrange(self.img_pts_tokenizer(imgs, depths), "(b v) l d -> b (v l) d", b=B, v=V)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layer_tokens = [rearrange(x[hook].detach(), "(b v) l d -> b (v l) d", b=B, v=V) for hook in self.hooks]
+
+ tokens = self.attn_processor(torch.cat([input_tokens, layer_tokens[-1]], dim=-1))
+
+ gaussian_params = self.token_decoder(tokens)
+
+ patch_size = self.patch_size
+ hh = H // patch_size[0]
+ ww = W // patch_size[1]
+ gaussians = rearrange(gaussian_params, "b (v hh ww) (ph pw d) -> b (v hh ph ww pw) d", v=V, hh=hh, ww=ww, ph=patch_size[0], pw=patch_size[1])
+ return gaussians.view(B, V, H*W, -1)
+
+def create_gs_dpt_head(net, has_conf=False, out_nchan=3, postprocess_func=postprocess):
+ """
+ return PixelwiseTaskWithDPT for given net params
+ """
+ assert net.dec_depth > 9
+ l2 = net.dec_depth
+ feature_dim = net.feature_dim
+ last_dim = feature_dim//2
+ ed = net.enc_embed_dim
+ dd = net.dec_embed_dim
+ try:
+ patch_size = net.patch_size
+ except:
+ patch_size = (16, 16)
+
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
+ patch_size=patch_size,
+ feature_dim=feature_dim,
+ last_dim=last_dim,
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
+ dim_tokens=[ed, dd, dd, dd],
+ postprocess=postprocess_func,
+ depth_mode=net.depth_mode,
+ conf_mode=net.conf_mode,
+ head_type='gs_params')
\ No newline at end of file
diff --git a/src/model/encoder/heads/dpt_head.py b/src/model/encoder/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9be9e6e78383dd404909cd8914a64eade4e12bc
--- /dev/null
+++ b/src/model/encoder/heads/dpt_head.py
@@ -0,0 +1,119 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# dpt head implementation for DUST3R
+# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
+# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
+# the forward function also takes as input a dictionnary img_info with key "height" and "width"
+# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
+# --------------------------------------------------------
+from einops import rearrange
+from typing import List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# import dust3r.utils.path_to_croco
+from .dpt_block import DPTOutputAdapter
+from .postprocess import postprocess
+
+
+class DPTOutputAdapter_fix(DPTOutputAdapter):
+ """
+ Adapt croco's DPTOutputAdapter implementation for dust3r:
+ remove duplicated weigths, and fix forward for dust3r
+ """
+
+ def init(self, dim_tokens_enc=768):
+ super().init(dim_tokens_enc)
+ # these are duplicated weights
+ del self.act_1_postprocess
+ del self.act_2_postprocess
+ del self.act_3_postprocess
+ del self.act_4_postprocess
+
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size=None, ray_embedding=None):
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+ # H, W = input_info['image_size']
+ image_size = self.image_size if image_size is None else image_size
+ H, W = image_size
+ # Number of patches in height and width
+ N_H = H // (self.stride_level * self.P_H)
+ N_W = W // (self.stride_level * self.P_W)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layers = [encoder_tokens[hook] for hook in self.hooks]
+
+ # Extract only task-relevant tokens and ignore global tokens.
+ layers = [self.adapt_tokens(l) for l in layers]
+
+ # Reshape tokens to spatial representation
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+ # Project layers to chosen feature dim
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+ # Fuse layers using refinement stages
+ path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+ # if ray_embedding is not None:
+ # ray_embedding = F.interpolate(ray_embedding, size=(path_1.shape[2], path_1.shape[3]), mode='bilinear')
+ # path_1 = torch.cat([path_1, ray_embedding], dim=1)
+
+ # Output head
+ out = self.head(path_1)
+
+ return out
+
+
+class PixelwiseTaskWithDPT(nn.Module):
+ """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
+
+ def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
+ output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
+ super(PixelwiseTaskWithDPT, self).__init__()
+ self.return_all_layers = True # backbone needs to return all layers
+ self.postprocess = postprocess
+ self.depth_mode = depth_mode
+ self.conf_mode = conf_mode
+
+ assert n_cls_token == 0, "Not implemented"
+ dpt_args = dict(output_width_ratio=output_width_ratio,
+ num_channels=num_channels,
+ **kwargs)
+ if hooks_idx is not None:
+ dpt_args.update(hooks=hooks_idx)
+ self.dpt = DPTOutputAdapter_fix(**dpt_args)
+ dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
+ self.dpt.init(**dpt_init_args)
+
+ def forward(self, x, img_info, ray_embedding=None):
+ out = self.dpt(x, image_size=(img_info[0], img_info[1]), ray_embedding=ray_embedding)
+ if self.postprocess:
+ out = self.postprocess(out, self.depth_mode, self.conf_mode)
+ return out
+
+
+def create_dpt_head(net, has_conf=False, out_nchan=3, postprocess_func=postprocess):
+ """
+ return PixelwiseTaskWithDPT for given net params
+ """
+ assert net.dec_depth > 9
+ l2 = net.dec_depth
+ feature_dim = 256
+ last_dim = feature_dim//2
+ ed = net.enc_embed_dim
+ dd = net.dec_embed_dim
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
+ feature_dim=feature_dim,
+ last_dim=last_dim,
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
+ dim_tokens=[ed, dd, dd, dd],
+ postprocess=postprocess_func,
+ depth_mode=net.depth_mode,
+ conf_mode=net.conf_mode,
+ head_type='regression')
diff --git a/src/model/encoder/heads/head_modules.py b/src/model/encoder/heads/head_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..a13d97e088697ecee01f9721058e6394578cc9e1
--- /dev/null
+++ b/src/model/encoder/heads/head_modules.py
@@ -0,0 +1,327 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torch
+import torch.nn as nn
+import xformers.ops as xops
+from einops import rearrange
+from torch.nn import functional as F
+import numbers
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-6):
+ super(RMSNorm, self).__init__()
+ self.eps = eps
+ self.scale = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
+ return self.scale * x / rms
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.Sequential()
+
+ if stride == 1 and in_planes == planes:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.conv1(y)
+ y = self.norm1(y)
+ y = self.relu(y)
+ y = self.conv2(y)
+ y = self.norm2(y)
+ y = self.relu(y)
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class UnetExtractor(nn.Module):
+ def __init__(self, in_channel=3, encoder_dim=[256, 256, 256], norm_fn='group'):
+ super().__init__()
+ self.in_ds = nn.Sequential(
+ nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3),
+ nn.GroupNorm(num_groups=8, num_channels=64),
+ nn.ReLU(inplace=True)
+ )
+
+ self.res1 = nn.Sequential(
+ ResidualBlock(64, encoder_dim[0], stride=2, norm_fn=norm_fn),
+ ResidualBlock(encoder_dim[0], encoder_dim[0], norm_fn=norm_fn)
+ )
+ self.res2 = nn.Sequential(
+ ResidualBlock(encoder_dim[0], encoder_dim[1], stride=2, norm_fn=norm_fn),
+ ResidualBlock(encoder_dim[1], encoder_dim[1], norm_fn=norm_fn)
+ )
+ self.res3 = nn.Sequential(
+ ResidualBlock(encoder_dim[1], encoder_dim[2], stride=2, norm_fn=norm_fn),
+ ResidualBlock(encoder_dim[2], encoder_dim[2], norm_fn=norm_fn),
+ )
+
+ def forward(self, x):
+ x = self.in_ds(x)
+ x1 = self.res1(x)
+ x2 = self.res2(x1)
+ x3 = self.res3(x2)
+
+ return x1, x2, x3
+
+
+class MultiBasicEncoder(nn.Module):
+ def __init__(self, output_dim=[128], encoder_dim=[64, 96, 128]):
+ super(MultiBasicEncoder, self).__init__()
+
+ # output convolution for feature
+ self.conv2 = nn.Sequential(
+ ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1),
+ nn.Conv2d(encoder_dim[2], encoder_dim[2] * 2, 3, padding=1))
+
+ # output convolution for context
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1),
+ nn.Conv2d(encoder_dim[2], dim[2], 3, padding=1))
+ output_list.append(conv_out)
+
+ self.outputs08 = nn.ModuleList(output_list)
+
+ def forward(self, x):
+ feat1, feat2 = self.conv2(x).split(dim=0, split_size=x.shape[0] // 2)
+
+ outputs08 = [f(x) for f in self.outputs08]
+ return outputs08, feat1, feat2
+
+
+
+# attention processor for appreaance head
+
+def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, mlp_ratio=4., mlp_bias=False, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = int(in_features * mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=mlp_bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ """
+ x: (B, L, D)
+ Returns: same shape as input
+ """
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim, head_dim=64, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., use_flashatt_v2=True):
+ super().__init__()
+ assert dim % head_dim == 0, 'dim must be divisible by head_dim'
+ self.num_heads = dim // head_dim
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=False)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.norm_q = RMSNorm(head_dim, eps=1e-5)
+ self.norm_k = RMSNorm(head_dim, eps=1e-5)
+
+ self.use_flashatt_v2 = use_flashatt_v2
+
+ def forward(self, x):
+ """
+ x: (B, L, D)
+ Returns: same shape as input
+ """
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ if self.use_flashatt_v2:
+ qkv = qkv.permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # (B, N, H, C)
+ q, k = self.norm_q(q).to(v.dtype), self.norm_k(k).to(v.dtype)
+ x = xops.memory_efficient_attention(q, k, v, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), p=self.attn_drop_p)
+ x = rearrange(x, 'b n h d -> b n (h d)')
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, head_dim=64, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., use_flashatt_v2=True):
+ super().__init__()
+ assert dim % head_dim == 0, 'dim must be divisible by head_dim'
+ self.num_heads = dim // head_dim
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.k = nn.Linear(dim, dim, bias=qkv_bias)
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
+
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=False)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.norm_q = RMSNorm(head_dim, eps=1e-5)
+ self.norm_k = RMSNorm(head_dim, eps=1e-5)
+
+ self.use_flashatt_v2 = use_flashatt_v2
+
+ def forward(self, x_q, x_kv):
+ """
+ x_q: query input (B, L_q, D)
+ x_kv: key-value input (B, L_kv, D)
+ Returns: same shape as query input (B, L_q, D)
+ """
+ B, N_q, C = x_q.shape
+ _, N_kv, _ = x_kv.shape
+
+ q = self.q(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
+ k = self.k(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
+ v = self.v(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
+
+ if self.use_flashatt_v2:
+ q, k = self.norm_q(q).to(v.dtype), self.norm_k(k).to(v.dtype)
+ x = xops.memory_efficient_attention(
+ q, k, v,
+ op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
+ p=self.attn_drop_p
+ )
+ x = rearrange(x, 'b n h d -> b n (h d)')
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class TransformerBlockSelfAttn(nn.Module):
+ def __init__(self, dim, head_dim, mlp_ratio=4., mlp_bias=False, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flashatt_v2=True):
+ super().__init__()
+ self.norm1 = norm_layer(dim, bias=False)
+ self.attn = SelfAttention(
+ dim, head_dim=head_dim, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, use_flashatt_v2=use_flashatt_v2)
+ self.norm2 = norm_layer(dim, bias=False)
+ self.mlp = Mlp(in_features=dim, mlp_ratio=mlp_ratio, mlp_bias=mlp_bias, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ """
+ x: (B, L, D)
+ Returns: same shape as input
+ """
+ y = self.attn(self.norm1(x))
+ x = x + y
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+class TransformerBlockCrossAttn(nn.Module):
+ def __init__(self, dim, head_dim, mlp_ratio=4., mlp_bias=False, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flashatt_v2=True):
+ super().__init__()
+ self.norm1 = norm_layer(dim, bias=False)
+ self.attn = CrossAttention(
+ dim, head_dim=head_dim, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, use_flashatt_v2=use_flashatt_v2)
+ self.norm2 = norm_layer(dim, bias=False)
+ self.mlp = Mlp(in_features=dim, mlp_ratio=mlp_ratio, mlp_bias=mlp_bias, act_layer=act_layer, drop=drop)
+
+ def forward(self, x_list):
+ """
+ x_q: (B, L_q, D)
+ x_kv: (B, L_kv, D)
+ Returns: same shape as input
+ """
+ x_q, x_kv = x_list
+ y = self.attn(self.norm1(x_q), self.norm1(x_kv))
+ x = x_q + y
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+class AppearanceTransformer(nn.Module):
+ def __init__(self, num_layers, attn_dim, head_dim, ca_incides=[1, 3, 5, 7]):
+ super().__init__()
+ self.attn_dim = attn_dim
+ self.num_layers = num_layers
+ self.blocks = nn.ModuleList()
+ self.ca_incides = ca_incides
+
+ for attn_index in range(num_layers):
+ self.blocks.append(TransformerBlockSelfAttn(self.attn_dim, head_dim))
+ self.blocks[-1].apply(_init_weights)
+
+ def forward(self, x, use_checkpoint=True):
+ """
+ input_tokens: (B, L, D)
+ aggregated_tokens: List of (B, L, D)
+ Returns: B and D remain the same, L might change if there are merge layers
+ """
+ for block in self.blocks:
+ if use_checkpoint:
+ x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
+ else:
+ x = block(x)
+
+ return x
+
+
+if __name__ == '__main__':
+ data = torch.ones((1, 3, 1024, 1024))
+
+ model = UnetExtractor(in_channel=3, encoder_dim=[64, 96, 128])
+
+ x1, x2, x3 = model(data)
+ print(x1.shape, x2.shape, x3.shape)
diff --git a/src/model/encoder/heads/linear_gs_head.py b/src/model/encoder/heads/linear_gs_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e8e3a4347b7d54f5ca1368d0de1d9cd2e240f0
--- /dev/null
+++ b/src/model/encoder/heads/linear_gs_head.py
@@ -0,0 +1,383 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# dpt head implementation for DUST3R
+# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
+# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
+# the forward function also takes as input a dictionnary img_info with key "height" and "width"
+# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
+# --------------------------------------------------------
+from einops import rearrange
+from typing import List, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# import dust3r.utils.path_to_croco
+from .dpt_block import DPTOutputAdapter, Interpolate, make_fusion_block
+from .head_modules import UnetExtractor, AppearanceTransformer, _init_weights
+from .postprocess import postprocess
+import torchvision
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
+
+# class DPTOutputAdapter_fix(DPTOutputAdapter):
+# """
+# Adapt croco's DPTOutputAdapter implementation for dust3r:
+# remove duplicated weigths, and fix forward for dust3r
+# """
+#
+# def init(self, dim_tokens_enc=768):
+# super().init(dim_tokens_enc)
+# # these are duplicated weights
+# del self.act_1_postprocess
+# del self.act_2_postprocess
+# del self.act_3_postprocess
+# del self.act_4_postprocess
+#
+# self.scratch.refinenet1 = make_fusion_block(256 * 2, False, 1, expand=True)
+# self.scratch.refinenet2 = make_fusion_block(256 * 2, False, 1, expand=True)
+# self.scratch.refinenet3 = make_fusion_block(256 * 2, False, 1, expand=True)
+# # self.scratch.refinenet4 = make_fusion_block(256 * 2, False, 1)
+#
+# self.depth_encoder = UnetExtractor(in_channel=3)
+# self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
+# self.out_conv = nn.Conv2d(256+3+4, 256, kernel_size=3, padding=1)
+# self.out_relu = nn.ReLU(inplace=True)
+#
+# self.input_merger = nn.Sequential(
+# # nn.Conv2d(256+3+3+1, 256, kernel_size=3, padding=1),
+# nn.Conv2d(256+3+3, 256, kernel_size=3, padding=1),
+# nn.ReLU(),
+# )
+#
+# def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None):
+# assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+# # H, W = input_info['image_size']
+# image_size = self.image_size if image_size is None else image_size
+# H, W = image_size
+# # Number of patches in height and width
+# N_H = H // (self.stride_level * self.P_H)
+# N_W = W // (self.stride_level * self.P_W)
+#
+# # Hook decoder onto 4 layers from specified ViT layers
+# layers = [encoder_tokens[hook] for hook in self.hooks]
+#
+# # Extract only task-relevant tokens and ignore global tokens.
+# layers = [self.adapt_tokens(l) for l in layers]
+#
+# # Reshape tokens to spatial representation
+# layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+#
+# layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+# # Project layers to chosen feature dim
+# layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+#
+# # get depth features
+# depth_features = self.depth_encoder(depths)
+# depth_feature1, depth_feature2, depth_feature3 = depth_features
+#
+# # Fuse layers using refinement stages
+# path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
+# path_3 = self.scratch.refinenet3(torch.cat([path_4, depth_feature3], dim=1), torch.cat([layers[2], depth_feature3], dim=1))
+# path_2 = self.scratch.refinenet2(torch.cat([path_3, depth_feature2], dim=1), torch.cat([layers[1], depth_feature2], dim=1))
+# path_1 = self.scratch.refinenet1(torch.cat([path_2, depth_feature1], dim=1), torch.cat([layers[0], depth_feature1], dim=1))
+# # path_3 = self.scratch.refinenet3(path_4, layers[2], depth_feature3)
+# # path_2 = self.scratch.refinenet2(path_3, layers[1], depth_feature2)
+# # path_1 = self.scratch.refinenet1(path_2, layers[0], depth_feature1)
+#
+# path_1 = self.feat_up(path_1)
+# path_1 = torch.cat([path_1, imgs, depths], dim=1)
+# if conf is not None:
+# path_1 = torch.cat([path_1, conf], dim=1)
+# path_1 = self.input_merger(path_1)
+#
+# # Output head
+# out = self.head(path_1)
+#
+# return out
+
+
+class DPTOutputAdapter_fix(DPTOutputAdapter):
+ """
+ Adapt croco's DPTOutputAdapter implementation for dust3r:
+ remove duplicated weigths, and fix forward for dust3r
+ """
+
+ def init(self, dim_tokens_enc=768):
+ super().init(dim_tokens_enc)
+ # these are duplicated weights
+ del self.act_1_postprocess
+ del self.act_2_postprocess
+ del self.act_3_postprocess
+ del self.act_4_postprocess
+
+ self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
+ # self.input_merger = nn.Sequential(
+ # # nn.Conv2d(256+3+3+1, 256, kernel_size=3, padding=1),
+ # # nn.Conv2d(3+6, 256, 7, 1, 3),
+ # nn.Conv2d(3, 256, 7, 1, 3),
+ # nn.ReLU(),
+ # )
+
+ def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None):
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+ # H, W = input_info['image_size']
+ image_size = self.image_size if image_size is None else image_size
+ H, W = image_size
+ # Number of patches in height and width
+ N_H = H // (self.stride_level * self.P_H)
+ N_W = W // (self.stride_level * self.P_W)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layers = [encoder_tokens[hook] for hook in self.hooks]
+
+ # Extract only task-relevant tokens and ignore global tokens.
+ layers = [self.adapt_tokens(l) for l in layers]
+
+ # Reshape tokens to spatial representation
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+ # Project layers to chosen feature dim
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+ # Fuse layers using refinement stages
+ path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+ # direct_img_feat = self.input_merger(imgs)
+ # actually, we just do interpolate here
+ # path_1 = self.feat_up(path_1)
+ path_1 = custom_interpolate(path_1, size=(H, W), mode='bilinear', align_corners=True)
+ # path_1 = F.interpolate(path_1, size=(H, W), mode='bilinear', align_corners=True)
+ # path_1 = path_1 + direct_img_feat
+
+ # path_1 = torch.cat([path_1, imgs], dim=1)
+
+ # Output head
+ # out = self.head(path_1)
+ out = path_1
+ return out, [path_4, path_3, path_2]
+
+
+class PixelwiseTaskWithDPT(nn.Module):
+ """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
+
+ def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
+ output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
+ super(PixelwiseTaskWithDPT, self).__init__()
+ self.return_all_layers = True # backbone needs to return all layers
+ self.postprocess = postprocess
+ self.depth_mode = depth_mode
+ self.conf_mode = conf_mode
+
+ assert n_cls_token == 0, "Not implemented"
+ dpt_args = dict(output_width_ratio=output_width_ratio,
+ num_channels=num_channels,
+ **kwargs)
+ if hooks_idx is not None:
+ dpt_args.update(hooks=hooks_idx)
+ self.dpt = DPTOutputAdapter_fix(**dpt_args)
+ dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
+ self.dpt.init(**dpt_init_args)
+
+ def forward(self, x, depths, imgs, img_info, conf=None):
+ out, interm_feats = self.dpt(x, depths, imgs, image_size=(img_info[0], img_info[1]), conf=conf)
+ if self.postprocess:
+ out = self.postprocess(out, self.depth_mode, self.conf_mode)
+ return out, interm_feats
+
+class AttnBasedAppearanceHead(nn.Module):
+ """
+ Attention head Appearence Reconstruction
+ """
+
+ def __init__(self, num_channels, patch_size, feature_dim, last_dim, hooks_idx, dim_tokens, postprocess, depth_mode, conf_mode, head_type='gs_params'):
+ super().__init__()
+
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+
+ self.hooks = hooks_idx
+
+ assert len(set(dim_tokens)) == 1
+
+ self.tokenizer = nn.Linear(3 * self.patch_size[0] ** 2 + 512, dim_tokens[0], bias=False)
+ self.C_feat = 128
+
+ self.vgg_feature_extractor = torchvision.models.vgg16(pretrained=True).features
+ # Freeze the VGG parameters
+ for param in self.vgg_feature_extractor.parameters():
+ param.requires_grad = False
+
+ self.token_decoder = nn.Sequential(
+ nn.Linear(dim_tokens[0] * (len(self.hooks) + 1), self.C_feat * (self.patch_size[0] ** 2)),
+ nn.SiLU(),
+ nn.Linear(self.C_feat * (self.patch_size[0] ** 2), self.C_feat * (self.patch_size[0] ** 2)),
+ )
+
+ self.pixel_linear = nn.Linear(self.C_feat, self.num_channels)
+
+
+ def img_pts_tokenizer(self, imgs):
+ _, _, H, W = imgs.shape
+ # Process images through VGG to extract features
+ # imgs = imgs.permute(0, 2, 3, 1).contiguous()
+ with torch.no_grad():
+ vgg_features = self.vgg_feature_extractor(imgs)
+
+ # 1. concat original images with vgg features and then patchify
+ vgg_features = F.interpolate(vgg_features, size=(H, W), mode='bilinear', align_corners=False)
+ combined = torch.cat([imgs, vgg_features], dim=1) # [B, C+512, H, W]
+ combined = combined.permute(0, 2, 3, 1).contiguous()
+
+ patch_size = self.patch_size
+ hh = H // patch_size[0]
+ ww = W // patch_size[1]
+ input_patches = rearrange(combined, "b (hh ph) (ww pw) c -> b (hh ww) (ph pw c)",
+ hh=hh, ww=ww, ph=patch_size[0], pw=patch_size[1])
+
+ input_tokens = self.tokenizer(input_patches)
+
+ # 2. only use vgg features, use a shallow conv to get the token
+
+ # # Combine original images with VGG features
+ # imgs = torch.cat([imgs, vgg_features], dim=1)
+ # imgs = imgs.permute(0, 2, 3, 1).flatten(1, 2).contiguous()
+
+ # # Pachify
+ # patch_size = self.patch_size
+ # hh = H // patch_size[0]
+ # ww = W // patch_size[1]
+ # input = rearrange(imgs, "b (hh ph ww pw) d -> b (hh ww) (ph pw d)", hh=hh, ww=ww, ph=patch_size[0], pw=patch_size[1])
+ # Tokenize the input images
+ input_tokens = self.tokenizer(input)
+ return input_tokens
+
+ def forward(self, x, depths, imgs, img_info, conf=None):
+ B, V, H, W = img_info
+ input_tokens = self.img_pts_tokenizer(imgs)
+ # Hook decoder onto 4 layers from specified ViT layers
+ layer_tokens = [x[hook] for hook in self.hooks] # [B, S, D]
+ # layer_tokens.append(input_tokens)
+
+ x = self.token_decoder(torch.cat(layer_tokens, dim=-1))
+ x = x.view(B*V, (H // self.patch_size[0]) * (W // self.patch_size[1]), self.patch_size[0]**2, self.C_feat).flatten(1, 2).contiguous()
+ out_flat = self.pixel_linear(x)
+
+ return out_flat.view(B*V, H, W, -1).permute(0, 3, 1, 2)
+
+# class Pixellevel_Linear_Pts3d(nn.Module):
+# """
+# Pixel-level linear head for DUST3R
+# Each pixel outputs: 3D point (+ confidence)
+# """
+
+# def __init__(self, dec_embed_dim, patch_size, depth_mode, conf_mode, has_conf=False, index_hook=[-1]):
+# super().__init__()
+# self.patch_size = patch_size
+# self.depth_mode = depth_mode
+# self.conf_mode = conf_mode
+# self.has_conf = has_conf
+# self.dec_embed_dim = dec_embed_dim
+# self.index_hook = index_hook
+
+# # Total embedding dimension per token (possibly concatenated)
+# D = self.dec_embed_dim * len(self.index_hook)
+# # Ensure divisible into pixel-level features
+# assert D % (self.patch_size**2) == 0, \
+# f"Embedding dim {D} not divisible by patch_size^2 ({self.patch_size**2})"
+# # Feature dimension for each pixel
+# self.C_feat = D // (self.patch_size**2) * 4
+# # Output channels: x,y,z (+ confidence)
+# self.out_dim = 3 + int(self.has_conf)
+
+# self.feat_expand = nn.Sequential(nn.Linear(D, 4*D),
+# nn.SiLU(),
+# nn.Linear(4*D, 4*D)
+# )
+# # Per-pixel linear head
+# self.pixel_linear = nn.Linear(self.C_feat, self.out_dim)
+
+# def setup(self, croconet):
+# pass
+
+# def forward(self, decout, img_shape):
+# H, W = img_shape
+# # Combine specified decoder tokens: B x num_patches x D
+# tokens = [decout[i] for i in self.index_hook]
+# x = torch.cat(tokens, dim=-1) # B, S, D
+# x = self.feat_expand(x)
+# B, S, D = x.shape
+
+# # Validate pixel count
+# assert S * (self.patch_size**2) == H * W, \
+# f"Mismatch: S*ps^2 ({S*self.patch_size**2}) != H*W ({H*W})"
+
+# # 1. Reshape embedding into pixel features
+# # x -> B, S, (ps^2), C_feat -> flatten to B, (S*ps^2), C_feat
+# x = x.view(B, S, self.patch_size**2, self.C_feat)
+# x = x.reshape(B, S * self.patch_size**2, self.C_feat)
+
+# # 2. Per-pixel linear output
+# out_flat = self.pixel_linear(x) # B, S*ps^2, out_dim
+
+# # 3. Reshape to image map: B x out_dim x H x W
+# out = out_flat.permute(0, 2, 1).view(B, self.out_dim, H, W)
+
+# # 4. Postprocess depth/conf
+# return out
+
+def create_gs_linear_head(net, has_conf=False, out_nchan=3, postprocess_func=postprocess):
+ """
+ return PixelwiseTaskWithDPT for given net params
+ """
+ assert net.dec_depth > 9
+ l2 = net.dec_depth
+ feature_dim = net.feature_dim
+ last_dim = feature_dim//2
+ ed = net.enc_embed_dim
+ dd = net.dec_embed_dim
+ try:
+ patch_size = net.patch_size
+ except:
+ patch_size = (16, 16)
+
+ return AttnBasedAppearanceHead(num_channels=out_nchan + has_conf,
+ patch_size=patch_size,
+ feature_dim=feature_dim,
+ last_dim=last_dim,
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
+ dim_tokens=[ed, dd, dd, dd],
+ postprocess=postprocess_func,
+ depth_mode=net.depth_mode,
+ conf_mode=net.conf_mode,
+ head_type='gs_params')
diff --git a/src/model/encoder/heads/linear_head.py b/src/model/encoder/heads/linear_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffcd680d9a165be4b9f742a09d1e8741f82a4a02
--- /dev/null
+++ b/src/model/encoder/heads/linear_head.py
@@ -0,0 +1,73 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# linear head implementation for DUST3R
+# --------------------------------------------------------
+import torch.nn as nn
+import torch.nn.functional as F
+from .postprocess import postprocess
+
+
+class LinearPts3d (nn.Module):
+ """
+ Linear head for dust3r
+ Each token outputs: - 16x16 3D points (+ confidence)
+ """
+
+ def __init__(self, net, has_conf=False):
+ super().__init__()
+ self.patch_size = net.patch_embed.patch_size[0]
+ self.depth_mode = net.depth_mode
+ self.conf_mode = net.conf_mode
+ self.has_conf = has_conf
+
+ self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
+
+ def setup(self, croconet):
+ pass
+
+ def forward(self, decout, img_shape):
+ H, W = img_shape
+ tokens = decout[-1]
+ B, S, D = tokens.shape
+
+ # extract 3D points
+ feat = self.proj(tokens) # B,S,D
+ feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
+ feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
+
+ # permute + norm depth
+ return postprocess(feat, self.depth_mode, self.conf_mode)
+
+
+class LinearGS(nn.Module):
+ """
+ Linear head for GS parameter prediction
+ Each token outputs: - 16x16 3D points (+ confidence)
+ """
+
+ def __init__(self, net, has_conf=False):
+ super().__init__()
+ self.patch_size = net.patch_embed.patch_size[0]
+ self.depth_mode = net.depth_mode
+ self.conf_mode = net.conf_mode
+ self.has_conf = has_conf
+
+ self.proj = nn.Linear(net.dec_embed_dim, (2 + 1 + net.gaussian_adapter.d_in)*self.patch_size**2) # 2 for xy offset, 1 for opacity
+
+ def setup(self, croconet):
+ pass
+
+ def forward(self, decout, img_shape):
+ H, W = img_shape
+ tokens = decout[-1]
+ B, S, D = tokens.shape
+
+ # extract 3D points
+ feat = self.proj(tokens) # B,S,D
+ feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
+ feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
+
+ # permute + norm depth
+ return postprocess(feat, self.depth_mode, self.conf_mode)
diff --git a/src/model/encoder/heads/postprocess.py b/src/model/encoder/heads/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8ef2f34a9a00deee119d4813c3dc9dd51c4f333
--- /dev/null
+++ b/src/model/encoder/heads/postprocess.py
@@ -0,0 +1,77 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# post process function for all heads: extract 3D points/confidence from output
+# --------------------------------------------------------
+import torch
+
+
+def postprocess(out, depth_mode, conf_mode):
+ """
+ extract 3D points/confidence from prediction head output
+ """
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,3
+ res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
+
+ if conf_mode is not None:
+ res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
+ return res
+
+
+def reg_dense_depth(xyz, mode):
+ """
+ extract 3D points from prediction head output
+ """
+ mode, vmin, vmax = mode
+
+ no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
+ # assert no_bounds
+
+ if mode == 'range':
+ xyz = xyz.sigmoid()
+ xyz = (1 - xyz) * vmin + xyz * vmax
+ return xyz
+
+ if mode == 'linear':
+ if no_bounds:
+ return xyz # [-inf, +inf]
+ return xyz.clip(min=vmin, max=vmax)
+
+ if mode == 'exp_direct':
+ xyz = xyz.expm1()
+ return xyz.clip(min=vmin, max=vmax)
+
+ # distance to origin
+ d = xyz.norm(dim=-1, keepdim=True)
+ xyz = xyz / d.clip(min=1e-8)
+
+ if mode == 'square':
+ return xyz * d.square()
+
+ if mode == 'exp':
+ exp_d = d.expm1()
+ if not no_bounds:
+ exp_d = exp_d.clip(min=vmin, max=vmax)
+ xyz = xyz * exp_d
+ # if not no_bounds:
+ # # xyz = xyz.clip(min=vmin, max=vmax)
+ # depth = xyz.clone()[..., 2].clip(min=vmin, max=vmax)
+ # xyz = torch.cat([xyz[..., :2], depth.unsqueeze(-1)], dim=-1)
+ return xyz
+
+ raise ValueError(f'bad {mode=}')
+
+
+def reg_dense_conf(x, mode):
+ """
+ extract confidence from prediction head output
+ """
+ mode, vmin, vmax = mode
+ if mode == 'opacity':
+ return x.sigmoid()
+ if mode == 'exp':
+ return vmin + x.exp().clip(max=vmax-vmin)
+ if mode == 'sigmoid':
+ return (vmax - vmin) * torch.sigmoid(x) + vmin
+ raise ValueError(f'bad {mode=}')
diff --git a/src/model/encoder/heads/vggt_dpt_gs_head.py b/src/model/encoder/heads/vggt_dpt_gs_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b197adca95dc41c2e91194cc95474d01fd228a21
--- /dev/null
+++ b/src/model/encoder/heads/vggt_dpt_gs_head.py
@@ -0,0 +1,195 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# dpt head implementation for DUST3R
+# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
+# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
+# the forward function also takes as input a dictionnary img_info with key "height" and "width"
+# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
+# --------------------------------------------------------
+from einops import rearrange
+from typing import List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# import dust3r.utils.path_to_croco
+from .dpt_block import DPTOutputAdapter, Interpolate, make_fusion_block
+from src.model.encoder.vggt.heads.dpt_head import DPTHead
+from .head_modules import UnetExtractor, AppearanceTransformer, _init_weights
+from .postprocess import postprocess
+
+
+ # def __init__(self,
+ # num_channels: int = 1,
+ # stride_level: int = 1,
+ # patch_size: Union[int, Tuple[int, int]] = 16,
+ # main_tasks: Iterable[str] = ('rgb',),
+ # hooks: List[int] = [2, 5, 8, 11],
+ # layer_dims: List[int] = [96, 192, 384, 768],
+ # feature_dim: int = 256,
+ # last_dim: int = 32,
+ # use_bn: bool = False,
+ # dim_tokens_enc: Optional[int] = None,
+ # head_type: str = 'regression',
+ # output_width_ratio=1,
+
+class VGGT_DPT_GS_Head(DPTHead):
+ def __init__(self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 83,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ):
+ super().__init__(dim_in, patch_size, output_dim, activation, conf_activation, features, out_channels, intermediate_layer_idx, pos_embed, feature_only, down_ratio)
+
+ head_features_1 = 128
+ head_features_2 = 128 if output_dim > 50 else 32 # sh=0, head_features_2 = 32; sh=4, head_features_2 = 128
+ self.input_merger = nn.Sequential(
+ nn.Conv2d(3, head_features_2, 7, 1, 3),
+ nn.ReLU(),
+ )
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(head_features_1, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, patch_start_idx: int = 5, image_size=None, conf=None, frames_chunk_size: int = 8):
+ # H, W = input_info['image_size']
+ B, S, _, H, W = imgs.shape
+ image_size = self.image_size if image_size is None else image_size
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(encoder_tokens, imgs, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ chunk_output = self._forward_impl(
+ encoder_tokens, imgs, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_output)
+
+ # Concatenate results along the sequence dimension
+ return torch.cat(all_preds, dim=1)
+
+ def _forward_impl(self, encoder_tokens: List[torch.Tensor], imgs, patch_start_idx: int = 5, frames_start_idx: int = None, frames_end_idx: int = None):
+
+ if frames_start_idx is not None and frames_end_idx is not None:
+ imgs = imgs[:, frames_start_idx:frames_end_idx]
+
+ B, S, _, H, W = imgs.shape
+
+ patch_h, patch_w = H // self.patch_size[0], W // self.patch_size[1]
+
+ out = []
+ dpt_idx = 0
+ for layer_idx in self.intermediate_layer_idx:
+ # x = encoder_tokens[layer_idx][:, :, patch_start_idx:]
+ if len(encoder_tokens) > 10:
+ x = encoder_tokens[layer_idx][:, :, patch_start_idx:]
+ else:
+ list_idx = self.intermediate_layer_idx.index(layer_idx)
+ x = encoder_tokens[list_idx][:, :, patch_start_idx:]
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx].contiguous()
+
+ x = x.view(B * S, -1, x.shape[-1])
+
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H)
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ direct_img_feat = self.input_merger(imgs.flatten(0,1))
+ out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=True)
+ out = out + direct_img_feat
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H)
+
+ out = self.scratch.output_conv2(out)
+ out = out.view(B, S, *out.shape[1:])
+ return out
+
+
+class PixelwiseTaskWithDPT(nn.Module):
+ """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
+
+ def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
+ output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
+ super(PixelwiseTaskWithDPT, self).__init__()
+ self.return_all_layers = True # backbone needs to return all layers
+ self.postprocess = postprocess
+ self.depth_mode = depth_mode
+ self.conf_mode = conf_mode
+
+ assert n_cls_token == 0, "Not implemented"
+ dpt_args = dict(output_width_ratio=output_width_ratio,
+ num_channels=num_channels,
+ **kwargs)
+ if hooks_idx is not None:
+ dpt_args.update(hooks=hooks_idx)
+ self.dpt = DPTOutputAdapter_fix(**dpt_args)
+ dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
+ self.dpt.init(**dpt_init_args)
+
+ def forward(self, x, depths, imgs, img_info, conf=None):
+ out, interm_feats = self.dpt(x, depths, imgs, image_size=(img_info[0], img_info[1]), conf=conf)
+ if self.postprocess:
+ out = self.postprocess(out, self.depth_mode, self.conf_mode)
+ return out, interm_feats
+
+def create_gs_dpt_head(net, has_conf=False, out_nchan=3, postprocess_func=postprocess):
+ """
+ return PixelwiseTaskWithDPT for given net params
+ """
+ assert net.dec_depth > 9
+ l2 = net.dec_depth
+ feature_dim = net.feature_dim
+ last_dim = feature_dim//2
+ ed = net.enc_embed_dim
+ dd = net.dec_embed_dim
+ try:
+ patch_size = net.patch_size
+ except:
+ patch_size = (16, 16)
+
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
+ patch_size=patch_size,
+ feature_dim=feature_dim,
+ last_dim=last_dim,
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
+ dim_tokens=[ed, dd, dd, dd],
+ postprocess=postprocess_func,
+ depth_mode=net.depth_mode,
+ conf_mode=net.conf_mode,
+ head_type='gs_params')
\ No newline at end of file
diff --git a/src/model/encoder/vggt/heads/camera_head.py b/src/model/encoder/vggt/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..50147de678f0e5638d766dd6900ef3e0525fc13c
--- /dev/null
+++ b/src/model/encoder/vggt/heads/camera_head.py
@@ -0,0 +1,168 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from src.model.encoder.vggt.layers import Mlp
+from src.model.encoder.vggt.layers.block import Block
+from src.model.encoder.vggt.heads.head_act import activate_pose
+
+
+class CameraHead(nn.Module):
+ """
+ CameraHead predicts camera parameters from token representations using iterative refinement.
+
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+ """
+
+ def __init__(
+ self,
+ dim_in: int = 2048,
+ trunk_depth: int = 4,
+ pose_encoding_type: str = "absT_quaR_FoV",
+ num_heads: int = 16,
+ mlp_ratio: int = 4,
+ init_values: float = 0.01,
+ trans_act: str = "linear",
+ quat_act: str = "linear",
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
+ ):
+ super().__init__()
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ self.target_dim = 9
+ else:
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
+
+ self.trans_act = trans_act
+ self.quat_act = quat_act
+ self.fl_act = fl_act
+ self.trunk_depth = trunk_depth
+
+ # Build the trunk using a sequence of transformer blocks.
+ self.trunk = nn.Sequential(
+ *[
+ Block(
+ dim=dim_in,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ init_values=init_values,
+ )
+ for _ in range(trunk_depth)
+ ]
+ )
+
+ # Normalizations for camera token and trunk output.
+ self.token_norm = nn.LayerNorm(dim_in)
+ self.trunk_norm = nn.LayerNorm(dim_in)
+
+ # Learnable empty camera pose token.
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+ # Module for producing modulation parameters: shift, scale, and a gate.
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+ # Adaptive layer normalization without affine parameters.
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+ self.pose_branch = Mlp(
+ in_features=dim_in,
+ hidden_features=dim_in // 2,
+ out_features=self.target_dim,
+ drop=0,
+ )
+
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+ """
+ Forward pass to predict camera parameters.
+
+ Args:
+ aggregated_tokens_list (list): List of token tensors from the network;
+ the last tensor is used for prediction.
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+ Returns:
+ list: A list of predicted camera encodings (post-activation) from each iteration.
+ """
+ # Use tokens from the last block for camera prediction.
+ tokens = aggregated_tokens_list[-1]
+
+ # Extract the camera tokens
+ pose_tokens = tokens[:, :, 0]
+
+ pose_tokens = self.token_norm(pose_tokens)
+
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+ return pred_pose_enc_list
+
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+ """
+ Iteratively refine camera pose predictions.
+
+ Args:
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
+ num_iterations (int): Number of refinement iterations.
+
+ Returns:
+ list: List of activated camera encodings from each iteration.
+ """
+ B, S, C = pose_tokens.shape # S is expected to be 1.
+ pred_pose_enc = None
+ pred_pose_enc_list = []
+
+ for _ in range(num_iterations):
+ # Use a learned empty pose for the first iteration.
+ if pred_pose_enc is None:
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+ else:
+ # Detach the previous prediction to avoid backprop through time.
+ pred_pose_enc = pred_pose_enc.detach()
+ module_input = self.embed_pose(pred_pose_enc)
+
+ # Generate modulation parameters and split them into shift, scale, and gate components.
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
+
+ # Adaptive layer normalization and modulation.
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+ pose_tokens_modulated = torch.utils.checkpoint.checkpoint(
+ self.trunk,
+ pose_tokens_modulated,
+ use_reentrant=False,
+ )
+
+ # Compute the delta update for the pose encoding.
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
+
+ if pred_pose_enc is None:
+ pred_pose_enc = pred_pose_enc_delta
+ else:
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+ # Apply final activation functions for translation, quaternion, and field-of-view.
+ activated_pose = activate_pose(
+ pred_pose_enc,
+ trans_act=self.trans_act,
+ quat_act=self.quat_act,
+ fl_act=self.fl_act,
+ )
+ pred_pose_enc_list.append(activated_pose)
+
+ return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Modulate the input tensor using scaling and shifting parameters.
+ """
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
diff --git a/src/model/encoder/vggt/heads/dpt_head.py b/src/model/encoder/vggt/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6df2b7ec28654eef6d1a5e127f4ad6be45dcffa
--- /dev/null
+++ b/src/model/encoder/vggt/heads/dpt_head.py
@@ -0,0 +1,502 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
+
+
+import os
+from typing import List, Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .head_act import activate_head
+from .utils import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ DPT Head for dense prediction tasks.
+
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
+ backbone and produces dense predictions by fusing multi-scale features.
+
+ Args:
+ dim_in (int): Input dimension (channels).
+ patch_size (int, optional): Patch size. Default is 14.
+ output_dim (int, optional): Number of output channels. Default is 4.
+ activation (str, optional): Activation type. Default is "inv_log".
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
+ out_channels (List[int], optional): Output channels for each intermediate layer.
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.conf_activation = conf_activation
+ self.pos_embed = pos_embed
+ self.feature_only = feature_only
+ self.down_ratio = down_ratio
+ self.intermediate_layer_idx = intermediate_layer_idx
+
+ self.norm = nn.LayerNorm(dim_in)
+
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=oc,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for oc in out_channels
+ ]
+ )
+
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ expand=False,
+ )
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = None
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if feature_only:
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+ )
+ conv2_in_channels = head_features_1 // 2
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass through the DPT head, supports processing by chunking frames.
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
+ If None or larger than S, all frames are processed at once. Default: 8.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]:
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+ all_conf = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ if self.feature_only:
+ chunk_output = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_output)
+ else:
+ chunk_preds, chunk_conf = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_preds)
+ all_conf.append(chunk_conf)
+
+ # Concatenate results along the sequence dimension
+ if self.feature_only:
+ return torch.cat(all_preds, dim=1)
+ else:
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
+
+ def _forward_impl(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_start_idx: int = None,
+ frames_end_idx: int = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Implementation of the forward pass through the DPT head.
+
+ This method processes a specific chunk of frames from the sequence.
+
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W].
+ patch_start_idx (int): Starting index for patch tokens.
+ frames_start_idx (int, optional): Starting index for frames to process.
+ frames_end_idx (int, optional): Ending index for frames to process.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
+ """
+ if frames_start_idx is not None and frames_end_idx is not None:
+ images = images[:, frames_start_idx:frames_end_idx]
+
+ B, S, _, H, W = images.shape
+
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ out = []
+ dpt_idx = 0
+
+ for layer_idx in self.intermediate_layer_idx:
+ # x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+ if len(aggregated_tokens_list) > 10:
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+ else:
+ list_idx = self.intermediate_layer_idx.index(layer_idx)
+ x = aggregated_tokens_list[list_idx][:, :, patch_start_idx:]
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx].contiguous()
+
+ x = x.view(B * S, -1, x.shape[-1])
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H)
+
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ # Interpolate fused output to match target image resolution.
+ out = custom_interpolate(
+ out,
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H)
+
+ if self.feature_only:
+ return out.view(B, S, *out.shape[1:])
+
+ out = self.scratch.output_conv2(out)
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
+
+ preds = preds.view(B, S, *preds.shape[1:])
+ conf = conf.view(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+
+################################################################################
+# Modules
+################################################################################
+
+
+def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
diff --git a/src/model/encoder/vggt/heads/head_act.py b/src/model/encoder/vggt/heads/head_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0b332ec90cef68d98e8f9cd4c8a3d61008d46f1
--- /dev/null
+++ b/src/model/encoder/vggt/heads/head_act.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
+ """
+ Activate pose parameters with specified activation functions.
+
+ Args:
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+ trans_act: Activation type for translation component
+ quat_act: Activation type for quaternion component
+ fl_act: Activation type for focal length component
+
+ Returns:
+ Activated pose parameters tensor
+ """
+ T = pred_pose_enc[..., :3]
+ quat = pred_pose_enc[..., 3:7]
+ fl = pred_pose_enc[..., 7:] # or fov
+
+ T = base_pose_act(T, trans_act)
+ quat = base_pose_act(quat, quat_act)
+ fl = base_pose_act(fl, fl_act) # or fov
+
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+ return pred_pose_enc
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+ """
+ Apply basic activation function to pose parameters.
+
+ Args:
+ pose_enc: Tensor containing encoded pose parameters
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated pose parameters
+ """
+ if act_type == "linear":
+ return pose_enc
+ elif act_type == "inv_log":
+ return inverse_log_transform(pose_enc)
+ elif act_type == "exp":
+ return torch.exp(pose_enc)
+ elif act_type == "relu":
+ return F.relu(pose_enc)
+ else:
+ raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+ """
+ Process network output to extract 3D points and confidence values.
+
+ Args:
+ out: Network output tensor (B, C, H, W)
+ activation: Activation type for 3D points
+ conf_activation: Activation type for confidence values
+
+ Returns:
+ Tuple of (3D points tensor, confidence tensor)
+ """
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
+
+ # Split into xyz (first C-1 channels) and confidence (last channel)
+ xyz = fmap[:, :, :, :-1]
+ conf = fmap[:, :, :, -1]
+
+ if activation == "norm_exp":
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ xyz_normed = xyz / d
+ pts3d = xyz_normed * torch.expm1(d)
+ elif activation == "norm":
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+ elif activation == "exp":
+ pts3d = torch.exp(xyz)
+ elif activation == "relu":
+ pts3d = F.relu(xyz)
+ elif activation == "inv_log":
+ pts3d = inverse_log_transform(xyz)
+ elif activation == "xy_inv_log":
+ xy, z = xyz.split([2, 1], dim=-1)
+ z = inverse_log_transform(z)
+ pts3d = torch.cat([xy * z, z], dim=-1)
+ elif activation == "sigmoid":
+ pts3d = torch.sigmoid(xyz)
+ elif activation == "linear":
+ pts3d = xyz
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ if conf_activation == "expp1":
+ conf_out = 1 + conf.exp()
+ elif conf_activation == "expp0":
+ conf_out = conf.exp()
+ elif conf_activation == "sigmoid":
+ conf_out = torch.sigmoid(conf)
+ else:
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+ return pts3d, conf_out
+
+
+def inverse_log_transform(y):
+ """
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ y: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
diff --git a/src/model/encoder/vggt/heads/track_head.py b/src/model/encoder/vggt/heads/track_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec7199bd185060989c236997f93b93f4fc77825
--- /dev/null
+++ b/src/model/encoder/vggt/heads/track_head.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from .dpt_head import DPTHead
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackHead(nn.Module):
+ """
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
+ The tracking is performed iteratively, refining predictions over multiple iterations.
+ """
+
+ def __init__(
+ self,
+ dim_in,
+ patch_size=14,
+ features=128,
+ iters=4,
+ predict_conf=True,
+ stride=2,
+ corr_levels=7,
+ corr_radius=4,
+ hidden_size=384,
+ ):
+ """
+ Initialize the TrackHead module.
+
+ Args:
+ dim_in (int): Input dimension of tokens from the backbone.
+ patch_size (int): Size of image patches used in the vision transformer.
+ features (int): Number of feature channels in the feature extractor output.
+ iters (int): Number of refinement iterations for tracking predictions.
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
+ stride (int): Stride value for the tracker predictor.
+ corr_levels (int): Number of correlation pyramid levels
+ corr_radius (int): Radius for correlation computation, controlling the search area.
+ hidden_size (int): Size of hidden layers in the tracker network.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ # Feature extractor based on DPT architecture
+ # Processes tokens into feature maps for tracking
+ self.feature_extractor = DPTHead(
+ dim_in=dim_in,
+ patch_size=patch_size,
+ features=features,
+ feature_only=True, # Only output features, no activation
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
+ pos_embed=False,
+ )
+
+ # Tracker module that predicts point trajectories
+ # Takes feature maps and predicts coordinates and visibility
+ self.tracker = BaseTrackerPredictor(
+ latent_dim=features, # Match the output_dim of feature extractor
+ predict_conf=predict_conf,
+ stride=stride,
+ corr_levels=corr_levels,
+ corr_radius=corr_radius,
+ hidden_size=hidden_size,
+ )
+
+ self.iters = iters
+
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
+ """
+ Forward pass of the TrackHead.
+
+ Args:
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
+ B = batch size, S = sequence length.
+ patch_start_idx (int): Starting index for patch tokens.
+ query_points (torch.Tensor, optional): Initial query points to track.
+ If None, points are initialized by the tracker.
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
+
+ Returns:
+ tuple:
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
+ """
+ B, S, _, H, W = images.shape
+
+ # Extract features from tokens
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
+
+ # Use default iterations if not specified
+ if iters is None:
+ iters = self.iters
+
+ # Perform tracking using the extracted features
+ coord_preds, vis_scores, conf_scores = self.tracker(
+ query_points=query_points,
+ fmaps=feature_maps,
+ iters=iters,
+ )
+
+ return coord_preds, vis_scores, conf_scores
diff --git a/src/model/encoder/vggt/heads/track_modules/__init__.py b/src/model/encoder/vggt/heads/track_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/src/model/encoder/vggt/heads/track_modules/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/src/model/encoder/vggt/heads/track_modules/base_track_predictor.py b/src/model/encoder/vggt/heads/track_modules/base_track_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a
--- /dev/null
+++ b/src/model/encoder/vggt/heads/track_modules/base_track_predictor.py
@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+from .blocks import EfficientUpdateFormer, CorrBlock
+from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
+from .modules import Mlp
+
+
+class BaseTrackerPredictor(nn.Module):
+ def __init__(
+ self,
+ stride=1,
+ corr_levels=5,
+ corr_radius=4,
+ latent_dim=128,
+ hidden_size=384,
+ use_spaceatt=True,
+ depth=6,
+ max_scale=518,
+ predict_conf=True,
+ ):
+ super(BaseTrackerPredictor, self).__init__()
+ """
+ The base template to create a track predictor
+
+ Modified from https://github.com/facebookresearch/co-tracker/
+ and https://github.com/facebookresearch/vggsfm
+ """
+
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.corr_levels = corr_levels
+ self.corr_radius = corr_radius
+ self.hidden_size = hidden_size
+ self.max_scale = max_scale
+ self.predict_conf = predict_conf
+
+ self.flows_emb_dim = latent_dim // 2
+
+ self.corr_mlp = Mlp(
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
+ hidden_features=self.hidden_size,
+ out_features=self.latent_dim,
+ )
+
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
+
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
+
+ space_depth = depth if use_spaceatt else 0
+ time_depth = depth
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=self.transformer_dim,
+ hidden_size=self.hidden_size,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=use_spaceatt,
+ )
+
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
+
+ # A linear layer to update track feats at each iteration
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
+
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ if predict_conf:
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
+ """
+ query_points: B x N x 2, the number of batches, tracks, and xy
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
+ note HH and WW is the size of feature maps instead of original images
+ """
+ B, N, D = query_points.shape
+ B, S, C, HH, WW = fmaps.shape
+
+ assert D == 2, "Input points must be 2D coordinates"
+
+ # apply a layernorm to fmaps here
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
+
+ # Scale the input query_points because we may downsample the images
+ # by down_ratio or self.stride
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
+ # its query_points should be query_points/4
+ if down_ratio > 1:
+ query_points = query_points / float(down_ratio)
+
+ query_points = query_points / float(self.stride)
+
+ # Init with coords as the query points
+ # It means the search will start from the position of query points at the reference frames
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
+
+ # Sample/extract the features of the query points in the query frame
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
+
+ # init track feats by query feats
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
+ # back up the init coords
+ coords_backup = coords.clone()
+
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
+
+ coord_preds = []
+
+ # Iterative Refinement
+ for _ in range(iters):
+ # Detach the gradients from the last iteration
+ # (in my experience, not very important for performance)
+ coords = coords.detach()
+
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
+
+ corr_dim = fcorrs.shape[3]
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
+ fcorrs_ = self.corr_mlp(fcorrs_)
+
+ # Movement of current coords relative to query points
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
+
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
+
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ # Concatenate them as the input for the transformers
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
+
+ # 2D positional embed
+ # TODO: this can be much simplified
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
+
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
+
+ x = transformer_input + sampled_pos_emb
+
+ # Add the query ref token to the track feats
+ query_ref_token = torch.cat(
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
+ )
+ x = x + query_ref_token.to(x.device).to(x.dtype)
+
+ # B, N, S, C
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
+
+ # Compute the delta coordinates and delta track features
+ delta, _ = self.updateformer(x)
+
+ # BN, S, C
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+
+ # Update the track features
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
+
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
+
+ # B x S x N x 2
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+
+ # Force coord0 as query
+ # because we assume the query points should not be changed
+ coords[:, 0] = coords_backup[:, 0]
+
+ # The predicted tracks are in the original image scale
+ if down_ratio > 1:
+ coord_preds.append(coords * self.stride * down_ratio)
+ else:
+ coord_preds.append(coords * self.stride)
+
+ # B, S, N
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ vis_e = torch.sigmoid(vis_e)
+
+ if self.predict_conf:
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ conf_e = torch.sigmoid(conf_e)
+ else:
+ conf_e = None
+
+ if return_feat:
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
+ else:
+ return coord_preds, vis_e, conf_e
diff --git a/src/model/encoder/vggt/heads/track_modules/blocks.py b/src/model/encoder/vggt/heads/track_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7763f4fd8f515662421db192594380dbb574e5
--- /dev/null
+++ b/src/model/encoder/vggt/heads/track_modules/blocks.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Modified from https://github.com/facebookresearch/co-tracker/
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import bilinear_sampler
+from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+
+ # Add input LayerNorm before linear projection
+ self.input_norm = nn.LayerNorm(input_dim)
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+
+ # Add output LayerNorm before final projection
+ self.output_norm = nn.LayerNorm(hidden_size)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+
+ if self.add_space_attn:
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+ else:
+ self.virual_tracks = None
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ # Apply input LayerNorm
+ input_tensor = self.input_norm(input_tensor)
+ tokens = self.input_transform(input_tensor)
+
+ init_tokens = tokens
+
+ B, _, T, _ = tokens.shape
+
+ if self.add_space_attn:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+
+ if self.add_space_attn:
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ tokens = tokens + init_tokens
+
+ # Apply output LayerNorm before final projection
+ tokens = self.output_norm(tokens)
+ flow = self.flow_head(tokens)
+
+ return flow, None
+
+
+class CorrBlock:
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
+ """
+ Build a pyramid of feature maps from the input.
+
+ fmaps: Tensor (B, S, C, H, W)
+ num_levels: number of pyramid levels (each downsampled by factor 2)
+ radius: search radius for sampling correlation
+ multiple_track_feats: if True, split the target features per pyramid level
+ padding_mode: passed to grid_sample / bilinear_sampler
+ """
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.num_levels = num_levels
+ self.radius = radius
+ self.padding_mode = padding_mode
+ self.multiple_track_feats = multiple_track_feats
+
+ # Build pyramid: each level is half the spatial resolution of the previous
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
+ current_fmaps = fmaps
+ for i in range(num_levels - 1):
+ B, S, C, H, W = current_fmaps.shape
+ # Merge batch & sequence dimensions
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
+ # Avg pool down by factor 2
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
+ _, _, H_new, W_new = current_fmaps.shape
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
+ self.fmaps_pyramid.append(current_fmaps)
+
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
+ # This grid is added to the (scaled) coordinate centroids.
+ r = self.radius
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
+
+ def corr_sample(self, targets, coords):
+ """
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
+ volume, sample it immediately, then discard it. This saves GPU memory.
+
+ Args:
+ targets: Tensor (B, S, N, C) — features for the current targets.
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
+
+ Returns:
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
+ """
+ B, S, N, C = targets.shape
+
+ # If you have multiple track features, split them per level.
+ if self.multiple_track_feats:
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
+
+ out_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ # Get current spatial resolution H, W for this pyramid level.
+ B, S, C, H, W = fmaps.shape
+ # Reshape feature maps for correlation computation:
+ # fmap2s: (B, S, C, H*W)
+ fmap2s = fmaps.view(B, S, C, H * W)
+ # Choose appropriate target features.
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
+
+ # Compute correlation directly
+ corrs = compute_corr_level(fmap1, fmap2s, C)
+ corrs = corrs.view(B, S, N, H, W)
+
+ # Prepare sampling grid:
+ # Scale down the coordinates for the current level.
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
+ # Make sure our precomputed delta grid is on the same device/dtype.
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
+ # Now the grid for grid_sample is:
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
+
+ # Sample from the correlation volume using bilinear interpolation.
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
+ corrs_sampled = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
+ )
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
+ out_pyramid.append(corrs_sampled)
+
+ # Concatenate all levels along the last dimension.
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
+ return out
+
+
+def compute_corr_level(fmap1, fmap2s, C):
+ # fmap1: (B, S, N, C)
+ # fmap2s: (B, S, C, H*W)
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
+ return corrs / math.sqrt(C)
diff --git a/src/model/encoder/vggt/heads/track_modules/modules.py b/src/model/encoder/vggt/heads/track_modules/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b090ddc4a9db01c8dd3564f9053e1ca9cdde93a
--- /dev/null
+++ b/src/model/encoder/vggt/heads/track_modules/modules.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class ResidualBlock(nn.Module):
+ """
+ ResidualBlock: construct a block of two conv layers with residual connections
+ """
+
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+ else:
+ raise NotImplementedError
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
+ self.norm3,
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
+ mlp_ratio=4.0,
+ **block_kwargs
+ ):
+ """
+ Self attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, mask=None):
+ # Prepare the mask for PyTorch's attention (it expects a different format)
+ # attn_mask = mask if mask is not None else None
+ # Normalize before attention
+ x = self.norm1(x)
+
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
+
+ attn_output, _ = self.attn(x, x, x)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ """
+ Cross attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, context, mask=None):
+ # Normalize inputs
+ x = self.norm1(x)
+ context = self.norm_context(context)
+
+ # Apply cross attention
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
diff --git a/src/model/encoder/vggt/heads/track_modules/utils.py b/src/model/encoder/vggt/heads/track_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d01d39cdc10388a04dab5db7cf409b31bde766
--- /dev/null
+++ b/src/model/encoder/vggt/heads/track_modules/utils.py
@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/vggsfm
+# and https://github.com/facebookresearch/co-tracker/tree/main
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Union
+
+
+def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if return_grid:
+ return (
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
+ grid,
+ )
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
+ return pe
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+ left-most image pixel :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+ the left-most pixel :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+ coords = coords.detach().clone()
+ ############################################################
+ # IMPORTANT:
+ coords = coords.to(input.device).to(input.dtype)
+ ############################################################
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ scale = torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
+ )
+ else:
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
+
+ coords.mul_(scale) # coords = coords * scale
+ coords.sub_(1) # coords = coords - 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
diff --git a/src/model/encoder/vggt/heads/utils.py b/src/model/encoder/vggt/heads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7af1f68fa0ce0a48d11a708d53aa20aa8f78ba2
--- /dev/null
+++ b/src/model/encoder/vggt/heads/utils.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Process x and y coordinates separately
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
+
+ # Combine and reshape
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
+
+ return emb.view(H, W, embed_dim) # [H, W, D]
+
+
+def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / omega_0**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb.float()
+
+
+# Inspired by https://github.com/microsoft/moge
+
+
+def create_uv_grid(
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (width, height, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create 2D meshgrid (width x height) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
diff --git a/src/model/encoder/vggt/layers/__init__.py b/src/model/encoder/vggt/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/src/model/encoder/vggt/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/src/model/encoder/vggt/layers/attention.py b/src/model/encoder/vggt/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..805d8150c4dcbcf6ae2b88b997cfeeac4c7e691b
--- /dev/null
+++ b/src/model/encoder/vggt/layers/attention.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F
+
+XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.fused_attn = fused_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.rope is not None:
+ q = self.rope(q, pos)
+ k = self.rope(k, pos)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ )
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
+ assert pos is None
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/src/model/encoder/vggt/layers/block.py b/src/model/encoder/vggt/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..567881a5ca6fd8dcce2b965946ce21e70fb6be02
--- /dev/null
+++ b/src/model/encoder/vggt/layers/block.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ fused_attn=fused_attn,
+ rope=rope,
+ )
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ pos=pos,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, pos=pos)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+ pos=None,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ if pos is not None:
+ # if necessary, apply rope to the subset
+ pos = pos[brange]
+ residual = residual_func(x_subset, pos=pos)
+ else:
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/src/model/encoder/vggt/layers/drop_path.py b/src/model/encoder/vggt/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/src/model/encoder/vggt/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/src/model/encoder/vggt/layers/layer_scale.py b/src/model/encoder/vggt/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/src/model/encoder/vggt/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/src/model/encoder/vggt/layers/mlp.py b/src/model/encoder/vggt/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/src/model/encoder/vggt/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/src/model/encoder/vggt/layers/patch_embed.py b/src/model/encoder/vggt/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/src/model/encoder/vggt/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/src/model/encoder/vggt/layers/rope.py b/src/model/encoder/vggt/layers/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df
--- /dev/null
+++ b/src/model/encoder/vggt/layers/rope.py
@@ -0,0 +1,188 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+# Implementation of 2D Rotary Position Embeddings (RoPE).
+
+# This module provides a clean implementation of 2D Rotary Position Embeddings,
+# which extends the original RoPE concept to handle 2D spatial positions.
+
+# Inspired by:
+# https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# https://github.com/naver-ai/rope-vit
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, Tuple
+
+
+class PositionGetter:
+ """Generates and caches 2D spatial positions for patches in a grid.
+
+ This class efficiently manages the generation of spatial coordinates for patches
+ in a 2D grid, caching results to avoid redundant computations.
+
+ Attributes:
+ position_cache: Dictionary storing precomputed position tensors for different
+ grid dimensions.
+ """
+
+ def __init__(self):
+ """Initializes the position generator with an empty cache."""
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
+ """Generates spatial positions for a batch of patches.
+
+ Args:
+ batch_size: Number of samples in the batch.
+ height: Height of the grid in patches.
+ width: Width of the grid in patches.
+ device: Target device for the position tensor.
+
+ Returns:
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
+ for each position in the grid, repeated for each batch item.
+ """
+ if (height, width) not in self.position_cache:
+ y_coords = torch.arange(height, device=device)
+ x_coords = torch.arange(width, device=device)
+ positions = torch.cartesian_prod(y_coords, x_coords)
+ self.position_cache[height, width] = positions
+
+ cached_positions = self.position_cache[height, width]
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
+
+
+class RotaryPositionEmbedding2D(nn.Module):
+ """2D Rotary Position Embedding implementation.
+
+ This module applies rotary position embeddings to input tokens based on their
+ 2D spatial positions. It handles the position-dependent rotation of features
+ separately for vertical and horizontal dimensions.
+
+ Args:
+ frequency: Base frequency for the position embeddings. Default: 100.0
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
+
+ Attributes:
+ base_frequency: Base frequency for computing position embeddings.
+ scaling_factor: Factor to scale the computed frequencies.
+ frequency_cache: Cache for storing precomputed frequency components.
+ """
+
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
+ """Initializes the 2D RoPE module."""
+ super().__init__()
+ self.base_frequency = frequency
+ self.scaling_factor = scaling_factor
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
+
+ def _compute_frequency_components(
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes frequency components for rotary embeddings.
+
+ Args:
+ dim: Feature dimension (must be even).
+ seq_len: Maximum sequence length.
+ device: Target device for computations.
+ dtype: Data type for the computed tensors.
+
+ Returns:
+ Tuple of (cosine, sine) tensors for frequency components.
+ """
+ cache_key = (dim, seq_len, device, dtype)
+ if cache_key not in self.frequency_cache:
+ # Compute frequency bands
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
+ inv_freq = 1.0 / (self.base_frequency**exponents)
+
+ # Generate position-dependent frequencies
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
+
+ # Compute and cache frequency components
+ angles = angles.to(dtype)
+ angles = torch.cat((angles, angles), dim=-1)
+ cos_components = angles.cos().to(dtype)
+ sin_components = angles.sin().to(dtype)
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
+
+ return self.frequency_cache[cache_key]
+
+ @staticmethod
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
+ """Performs feature rotation by splitting and recombining feature dimensions.
+
+ Args:
+ x: Input tensor to rotate.
+
+ Returns:
+ Rotated feature tensor.
+ """
+ feature_dim = x.shape[-1]
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _apply_1d_rope(
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
+ ) -> torch.Tensor:
+ """Applies 1D rotary position embeddings along one dimension.
+
+ Args:
+ tokens: Input token features.
+ positions: Position indices.
+ cos_comp: Cosine components for rotation.
+ sin_comp: Sine components for rotation.
+
+ Returns:
+ Tokens with applied rotary position embeddings.
+ """
+ # Embed positions with frequency components
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
+
+ # Apply rotation
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
+
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ """Applies 2D rotary position embeddings to input tokens.
+
+ Args:
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
+ The feature dimension (dim) must be divisible by 4.
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
+ the y and x coordinates for each token.
+
+ Returns:
+ Tensor of same shape as input with applied 2D rotary position embeddings.
+
+ Raises:
+ AssertionError: If input dimensions are invalid or positions are malformed.
+ """
+ # Validate inputs
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
+
+ # Compute feature dimension for each spatial direction
+ feature_dim = tokens.size(-1) // 2
+
+ # Get frequency components
+ max_position = int(positions.max()) + 1
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
+
+ # Split features for vertical and horizontal processing
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
+
+ # Apply RoPE separately for each dimension
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
+
+ # Combine processed features
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
diff --git a/src/model/encoder/vggt/layers/swiglu_ffn.py b/src/model/encoder/vggt/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..54fe8e90b7bedf6fbdbf09c6215844e3cc63f857
--- /dev/null
+++ b/src/model/encoder/vggt/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+# try:
+# if XFORMERS_ENABLED:
+# from xformers.ops import SwiGLU
+
+# XFORMERS_AVAILABLE = True
+# warnings.warn("xFormers is available (SwiGLU)")
+# else:
+# warnings.warn("xFormers is disabled (SwiGLU)")
+# raise ImportError
+# except ImportError:
+SwiGLU = SwiGLUFFN
+XFORMERS_AVAILABLE = False
+
+# warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/src/model/encoder/vggt/layers/vision_transformer.py b/src/model/encoder/vggt/layers/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e12f5b4158c4bbd2a0438adb263f310f9c6554b
--- /dev/null
+++ b/src/model/encoder/vggt/layers/vision_transformer.py
@@ -0,0 +1,408 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ qk_norm=False,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ # tricky but makes it work
+ self.use_checkpoint = True
+ self.use_reentrant = False
+ #
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=True, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/src/model/encoder/vggt/models/aggregator.py b/src/model/encoder/vggt/models/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae6387768e6ace9acce173086c539c1e752bf371
--- /dev/null
+++ b/src/model/encoder/vggt/models/aggregator.py
@@ -0,0 +1,365 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, Union, List, Dict, Any
+
+from src.model.encoder.vggt.layers import PatchEmbed
+from src.model.encoder.vggt.layers.block import Block
+from src.model.encoder.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from src.model.encoder.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+class Aggregator(nn.Module):
+ """
+ The Aggregator applies alternating-attention over input frames,
+ as described in VGGT: Visual Geometry Grounded Transformer.
+
+
+ Args:
+ img_size (int): Image size in pixels.
+ patch_size (int): Size of each patch for PatchEmbed.
+ embed_dim (int): Dimension of the token embeddings.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ num_register_tokens (int): Number of register tokens.
+ block_fn (nn.Module): The block type used for attention (Block by default).
+ qkv_bias (bool): Whether to include bias in QKV projections.
+ proj_bias (bool): Whether to include bias in the output projection.
+ ffn_bias (bool): Whether to include bias in MLP layers.
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
+ qk_norm (bool): Whether to apply QK normalization.
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
+ init_values (float): Init scale for layer scale.
+ """
+
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_register_tokens=4,
+ block_fn=Block,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ patch_embed="dinov2_vitl14_reg",
+ aa_order=["frame", "global"],
+ aa_block_size=1,
+ qk_norm=True,
+ rope_freq=100,
+ init_values=0.01,
+ ):
+ super().__init__()
+ self.use_checkpoint = True
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
+
+ # Initialize rotary position embedding if frequency > 0
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
+ self.position_getter = PositionGetter() if self.rope is not None else None
+
+ self.frame_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.global_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.depth = depth
+ self.aa_order = aa_order
+ self.patch_size = patch_size
+ self.aa_block_size = aa_block_size
+
+ # Validate that depth is divisible by aa_block_size
+ if self.depth % self.aa_block_size != 0:
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
+
+ self.aa_block_num = self.depth // self.aa_block_size
+
+ # Note: We have two camera tokens, one for the first frame and one for the rest
+ # The same applies for register tokens
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
+
+ # The patch tokens start after the camera and register tokens
+ self.patch_start_idx = 1 + num_register_tokens
+
+ # Initialize parameters with small values
+ nn.init.normal_(self.camera_token, std=1e-6)
+ nn.init.normal_(self.register_token, std=1e-6)
+
+ # Register normalization constants as buffers
+ for name, value in (
+ ("_resnet_mean", _RESNET_MEAN),
+ ("_resnet_std", _RESNET_STD),
+ ):
+ self.register_buffer(
+ name,
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
+ persistent=False,
+ )
+
+ def __build_patch_embed__(
+ self,
+ patch_embed,
+ img_size,
+ patch_size,
+ num_register_tokens,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ block_chunks=0,
+ init_values=1.0,
+ embed_dim=1024,
+ ):
+ """
+ Build the patch embed layer. If 'conv', we use a
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
+ """
+
+ if "conv" in patch_embed:
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
+ else:
+ vit_models = {
+ "dinov2_vitl14_reg": vit_large,
+ "dinov2_vitb14_reg": vit_base,
+ "dinov2_vits14_reg": vit_small,
+ "dinov2_vitg2_reg": vit_giant2,
+ }
+
+ self.patch_embed = vit_models[patch_embed](
+ img_size=img_size,
+ patch_size=patch_size,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ block_chunks=block_chunks,
+ init_values=init_values,
+ )
+
+ # Disable gradient updates for mask token
+ if hasattr(self.patch_embed, "mask_token"):
+ self.patch_embed.mask_token.requires_grad_(False)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ intermediate_layer_idx: Optional[List[int]] = None
+ ) -> Tuple[List[torch.Tensor], int]:
+ """
+ Args:
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+
+ Returns:
+ (list[torch.Tensor], int):
+ The list of outputs from the attention blocks,
+ and the patch_start_idx indicating where patch tokens begin.
+ """
+ B, S, C_in, H, W = images.shape
+
+ if C_in != 3:
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
+
+ # Normalize images and reshape for patch embed
+ images = (images - self._resnet_mean) / self._resnet_std
+
+ # Reshape to [B*S, C, H, W] for patch embedding
+ images = images.view(B * S, C_in, H, W)
+ patch_tokens = self.patch_embed(images)
+
+ if isinstance(patch_tokens, dict):
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
+
+ _, P, C = patch_tokens.shape
+
+ # Expand camera and register tokens to match batch size and sequence length
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
+
+ # Concatenate special tokens with patch tokens
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
+
+ pos = None
+ if self.rope is not None:
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
+
+ if self.patch_start_idx > 0:
+ # do not use position embedding for special tokens (camera and register tokens)
+ # so set pos to 0 for the special tokens
+ pos = pos + 1
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ # update P because we added special tokens
+ _, P, C = tokens.shape
+
+ frame_idx = 0
+ global_idx = 0
+ output_list = []
+ layer_idx = 0
+
+ # Convert intermediate_layer_idx to a set for O(1) lookup
+ if intermediate_layer_idx is not None:
+ required_layers = set(intermediate_layer_idx)
+ # Always include the last layer for camera_head
+ required_layers.add(self.depth - 1)
+
+ for _ in range(self.aa_block_num):
+ for attn_type in self.aa_order:
+ if attn_type == "frame":
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
+ tokens, B, S, P, C, frame_idx, pos=pos
+ )
+ elif attn_type == "global":
+ tokens, global_idx, global_intermediates = self._process_global_attention(
+ tokens, B, S, P, C, global_idx, pos=pos
+ )
+ else:
+ raise ValueError(f"Unknown attention type: {attn_type}")
+
+ if intermediate_layer_idx is not None:
+ for i in range(len(frame_intermediates)):
+ current_layer = layer_idx + i
+ if current_layer in required_layers:
+ # concat frame and global intermediates, [B x S x P x 2C]
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
+ output_list.append(concat_inter)
+ layer_idx += self.aa_block_size
+
+ else:
+ for i in range(len(frame_intermediates)):
+ # concat frame and global intermediates, [B x S x P x 2C]
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
+ output_list.append(concat_inter)
+
+ del concat_inter
+ del frame_intermediates
+ del global_intermediates
+ return output_list, self.patch_start_idx
+
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
+ """
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
+ """
+ # If needed, reshape tokens or positions:
+ if tokens.shape != (B * S, P, C):
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
+
+ if pos is not None and pos.shape != (B * S, P, 2):
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.use_checkpoint:
+ tokens = torch.utils.checkpoint.checkpoint(
+ self.frame_blocks[frame_idx],
+ tokens,
+ pos,
+ use_reentrant=False,
+ )
+ else:
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
+ frame_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, frame_idx, intermediates
+
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
+ """
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
+ """
+ if tokens.shape != (B, S * P, C):
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
+
+ if pos is not None and pos.shape != (B, S * P, 2):
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.use_checkpoint:
+ tokens = torch.utils.checkpoint.checkpoint(
+ self.global_blocks[global_idx],
+ tokens,
+ pos,
+ use_reentrant=False,
+ )
+ else:
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
+ global_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, global_idx, intermediates
+
+
+def slice_expand_and_flatten(token_tensor, B, S):
+ """
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
+ 1) Uses the first position (index=0) for the first frame only
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
+ 3) Expands both to match batch size B
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
+ followed by (S-1) second-position tokens
+ 5) Flattens to (B*S, X, C) for processing
+
+ Returns:
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
+ """
+
+ # Slice out the "query" tokens => shape (1, 1, ...)
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
+ # Slice out the "other" tokens => shape (1, S-1, ...)
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
+ # Concatenate => shape (B, S, ...)
+ combined = torch.cat([query, others], dim=1)
+
+ # Finally flatten => shape (B*S, ...)
+ combined = combined.view(B * S, *combined.shape[2:])
+ return combined
diff --git a/src/model/encoder/vggt/models/vggt.py b/src/model/encoder/vggt/models/vggt.py
new file mode 100644
index 0000000000000000000000000000000000000000..973e6d465d4e0dad9611058a1bb7966a6c6c48c0
--- /dev/null
+++ b/src/model/encoder/vggt/models/vggt.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from src.model.encoder.vggt.models.aggregator import Aggregator
+from src.model.encoder.vggt.heads.camera_head import CameraHead
+from src.model.encoder.vggt.heads.dpt_head import DPTHead
+from src.model.encoder.vggt.heads.track_head import TrackHead
+
+
+class VGGT(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
+ super().__init__()
+
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ query_points: torch.Tensor = None,
+ ):
+ """
+ Forward pass of the VGGT model.
+
+ Args:
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
+ Default: None
+
+ Returns:
+ dict: A dictionary containing the following predictions:
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
+ - images (torch.Tensor): Original input images, preserved for visualization
+
+ If query_points is provided, also includes:
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
+ """
+
+ # If without batch dimension, add it
+ if len(images.shape) == 4:
+ images = images.unsqueeze(0)
+ if query_points is not None and len(query_points.shape) == 2:
+ query_points = query_points.unsqueeze(0)
+
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
+
+ predictions = {}
+
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.camera_head is not None:
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
+
+ if self.depth_head is not None:
+ depth, depth_conf = self.depth_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["depth"] = depth
+ predictions["depth_conf"] = depth_conf
+
+ if self.point_head is not None:
+ pts3d, pts3d_conf = self.point_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["world_points"] = pts3d
+ predictions["world_points_conf"] = pts3d_conf
+
+ if self.track_head is not None and query_points is not None:
+ track_list, vis, conf = self.track_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
+ )
+ predictions["track"] = track_list[-1] # track of the last iteration
+ predictions["vis"] = vis
+ predictions["conf"] = conf
+
+ predictions["images"] = images
+
+ return predictions
diff --git a/src/model/encoder/vggt/utils/geometry.py b/src/model/encoder/vggt/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..722a3423e4c44211ea42d29bb50834f44a27b485
--- /dev/null
+++ b/src/model/encoder/vggt/utils/geometry.py
@@ -0,0 +1,198 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+
+def batchify_unproject_depth_map_to_point_map(
+ depth_map: torch.Tensor, extrinsics_cam: torch.Tensor, intrinsics_cam: torch.Tensor
+) -> torch.Tensor:
+ """
+ Unproject a batch of depth maps to 3D world coordinates.
+
+ Args:
+ depth_map (torch.Tensor): Batch of depth maps of shape (B, V, H, W, 1) or (B, V, H, W)
+ extrinsics_cam (torch.Tensor): Batch of camera extrinsic matrices of shape (B, V, 3, 4)
+ intrinsics_cam (torch.Tensor): Batch of camera intrinsic matrices of shape (B, V, 3, 3)
+
+ Returns:
+ torch.Tensor: Batch of 3D world coordinates of shape (S, H, W, 3)
+ """
+
+ # Handle both (S, H, W, 1) and (S, H, W) cases
+ if depth_map.dim() == 5:
+ depth_map = depth_map.squeeze(-1) # (S, H, W)
+
+ # Generate batched camera coordinates
+ H, W = depth_map.shape[2:]
+ batch_size, num_views = depth_map.shape[0], depth_map.shape[1]
+
+ # Intrinsic parameters (S, 3, 3)
+ intrinsics_cam, extrinsics_cam, depth_map = intrinsics_cam.flatten(0, 1), extrinsics_cam.flatten(0, 1), depth_map.flatten(0, 1)
+ fu = intrinsics_cam[:, 0, 0] # (S,)
+ fv = intrinsics_cam[:, 1, 1] # (S,)
+ cu = intrinsics_cam[:, 0, 2] # (S,)
+ cv = intrinsics_cam[:, 1, 2] # (S,)
+
+ # Generate grid of pixel coordinates
+ u = torch.arange(W, device=depth_map.device)[None, None, :].expand(batch_size * num_views, H, W) # (S, H, W)
+ v = torch.arange(H, device=depth_map.device)[None, :, None].expand(batch_size * num_views, H, W) # (S, H, W)
+
+ # Unproject to camera coordinates (S, H, W, 3)
+ x_cam = (u - cu[:, None, None]) * depth_map / fu[:, None, None]
+ y_cam = (v - cv[:, None, None]) * depth_map / fv[:, None, None]
+ z_cam = depth_map
+
+ cam_coords = torch.stack((x_cam, y_cam, z_cam), dim=-1) # (S, H, W, 3)
+
+ # Transform to world coordinates
+ cam_to_world = closed_form_inverse_se3(extrinsics_cam) # (S, 4, 4)
+
+ # homo transformation
+ homo_pts = torch.cat((cam_coords, torch.ones_like(cam_coords[..., :1])), dim=-1).flatten(1, 2)
+ world_coords = torch.bmm(cam_to_world, homo_pts.transpose(1, 2)).transpose(1, 2)[:, :, :3].view(batch_size*num_views, H, W, 3)
+
+ return world_coords.view(batch_size, num_views, H, W, 3)
+
+def unproject_depth_map_to_point_map(
+ depth_map: torch.Tensor, extrinsics_cam: torch.Tensor, intrinsics_cam: torch.Tensor
+) -> torch.Tensor:
+ """
+ Unproject a batch of depth maps to 3D world coordinates.
+
+ Args:
+ depth_map (torch.Tensor): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+ extrinsics_cam (torch.Tensor): Batch of camera extrinsic matrices of shape (S, 3, 4)
+ intrinsics_cam (torch.Tensor): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+ Returns:
+ torch.Tensor: Batch of 3D world coordinates of shape (S, H, W, 3)
+ """
+ world_points_list = []
+ for frame_idx in range(depth_map.shape[0]):
+ cur_world_points, _, _ = depth_to_world_coords_points(
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
+ )
+ world_points_list.append(cur_world_points)
+ world_points_array = torch.stack(world_points_list, dim=0)
+
+ return world_points_array
+
+
+def depth_to_world_coords_points(
+ depth_map: torch.Tensor,
+ extrinsic: torch.Tensor,
+ intrinsic: torch.Tensor,
+ eps=1e-8,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Convert a depth map to world coordinates.
+
+ Args:
+ depth_map (torch.Tensor): Depth map of shape (H, W).
+ intrinsic (torch.Tensor): Camera intrinsic matrix of shape (3, 3).
+ extrinsic (torch.Tensor): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: World coordinates (H, W, 3) and valid depth mask (H, W).
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+ # Apply the rotation and translation to the camera coordinates
+ world_coords_points = torch.matmul(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
+
+ return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(depth_map: torch.Tensor, intrinsic: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (torch.Tensor): Depth map of shape (H, W).
+ intrinsic (torch.Tensor): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Camera coordinates (H, W, 3)
+ """
+ H, W = depth_map.shape
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+ # Generate grid of pixel coordinates
+ u, v = torch.meshgrid(torch.arange(W, device=depth_map.device),
+ torch.arange(H, device=depth_map.device),
+ indexing='xy')
+
+ # Unproject to camera coordinates
+ x_cam = (u - cu) * depth_map / fu
+ y_cam = (v - cv) * depth_map / fv
+ z_cam = depth_map
+
+ # Stack to form camera coordinates
+ cam_coords = torch.stack((x_cam, y_cam, z_cam), dim=-1).to(dtype=torch.float32)
+
+ return cam_coords
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4, device=R.device)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
diff --git a/src/model/encoder/vggt/utils/load_fn.py b/src/model/encoder/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee8a307a4f339a07424f79e4dfe5fee1252a9c8c
--- /dev/null
+++ b/src/model/encoder/vggt/utils/load_fn.py
@@ -0,0 +1,145 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from PIL import Image
+from torchvision import transforms as TF
+
+def load_and_preprocess_images(image_path_list, mode="crop"):
+ """
+ A quick start function to load and preprocess images for model input.
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
+
+ Args:
+ image_path_list (list): List of paths to image files
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
+ - "crop" (default): Sets width to 518px and center crops height if needed.
+ - "pad": Preserves all pixels by making the largest dimension 518px
+ and padding the smaller dimension to reach a square shape.
+
+ Returns:
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
+
+ Raises:
+ ValueError: If the input list is empty or if mode is invalid
+
+ Notes:
+ - Images with different dimensions will be padded with white (value=1.0)
+ - A warning is printed when images have different shapes
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
+ and height is center-cropped if larger than 518px
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
+ and the smaller dimension is padded to reach a square shape (518x518)
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
+ """
+ # Check for empty list
+ if len(image_path_list) == 0:
+ raise ValueError("At least 1 image is required")
+
+ # Validate mode
+ if mode not in ["crop", "pad"]:
+ raise ValueError("Mode must be either 'crop' or 'pad'")
+
+ images = []
+ shapes = set()
+ to_tensor = TF.ToTensor()
+ target_size = 448
+
+ # First process all images and collect their shapes
+ for image_path in image_path_list:
+
+ # Open image
+ img = Image.open(image_path)
+
+ # If there's an alpha channel, blend onto white background:
+ if img.mode == "RGBA":
+ # Create white background
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+ # Alpha composite onto the white background
+ img = Image.alpha_composite(background, img)
+
+ # Now convert to "RGB" (this step assigns white for transparent areas)
+ img = img.convert("RGB")
+
+ width, height = img.size
+
+ if mode == "pad":
+ # Make the largest dimension 518px while maintaining aspect ratio
+ if width >= height:
+ new_width = target_size
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
+ else:
+ new_height = target_size
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
+ else: # mode == "crop"
+ # Original behavior: set width to 518px
+ new_width = target_size
+ # Calculate height maintaining aspect ratio, divisible by 14
+ new_height = round(height * (new_width / width) / 14) * 14
+
+ # Resize with new dimensions (width, height)
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+ img = to_tensor(img) # Convert to tensor (0, 1)
+
+ # Center crop height if it's larger than 518 (only in crop mode)
+ if mode == "crop" and new_height > target_size:
+ start_y = (new_height - target_size) // 2
+ img = img[:, start_y : start_y + target_size, :]
+
+ # For pad mode, pad to make a square of target_size x target_size
+ if mode == "pad":
+ h_padding = target_size - img.shape[1]
+ w_padding = target_size - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ # Pad with white (value=1.0)
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+
+ shapes.add((img.shape[1], img.shape[2]))
+ images.append(img)
+
+ # Check if we have different shapes
+ # In theory our model can also work well with different shapes
+ if len(shapes) > 1:
+ print(f"Warning: Found images with different shapes: {shapes}")
+ # Find maximum dimensions
+ max_height = max(shape[0] for shape in shapes)
+ max_width = max(shape[1] for shape in shapes)
+
+ # Pad images if necessary
+ padded_images = []
+ for img in images:
+ h_padding = max_height - img.shape[1]
+ w_padding = max_width - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+ padded_images.append(img)
+ images = padded_images
+
+ images = torch.stack(images) # concatenate images
+
+ # Ensure correct shape when single image
+ if len(image_path_list) == 1:
+ # Verify shape is (1, C, H, W)
+ if images.dim() == 3:
+ images = images.unsqueeze(0)
+
+ return images
diff --git a/src/model/encoder/vggt/utils/pose_enc.py b/src/model/encoder/vggt/utils/pose_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d91eb086f3c216693292aa2daaf123629b12475
--- /dev/null
+++ b/src/model/encoder/vggt/utils/pose_enc.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from .rotation import quat_to_mat, mat_to_quat
+
+
+def extri_intri_to_pose_encoding(
+ extrinsics,
+ intrinsics,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+):
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+ This function transforms camera parameters into a unified pose encoding format,
+ which can be used for various downstream tasks like pose prediction or representation.
+
+ Args:
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+ where B is batch size and S is sequence length.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+ Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for computing field of view values. For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+ Returns:
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ """
+
+ # extrinsics: BxSx3x4
+ # intrinsics: BxSx3x3
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
+ T = extrinsics[:, :, :3, 3] # BxSx3
+
+ quat = mat_to_quat(R)
+ # Note the order of h and w here
+ H, W = image_size_hw
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+ else:
+ raise NotImplementedError
+
+ return pose_encoding
+
+
+def pose_encoding_to_extri_intri(
+ pose_encoding,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+ build_intrinsics=True,
+):
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
+
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
+ reconstructing the full camera parameters from the compact encoding.
+
+ Args:
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
+ where B is batch size and S is sequence length.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for reconstructing intrinsics from field of view values.
+ For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding used. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
+ If False, only extrinsics are returned and intrinsics will be None.
+
+ Returns:
+ tuple: (extrinsics, intrinsics)
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
+ a 3x1 translation vector.
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
+ or None if build_intrinsics is False. Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
+ assumed to be at the center of the image (W/2, H/2).
+ """
+
+ intrinsics = None
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ T = pose_encoding[..., :3]
+ quat = pose_encoding[..., 3:7]
+ fov_h = pose_encoding[..., 7]
+ fov_w = pose_encoding[..., 8]
+
+ R = quat_to_mat(quat)
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
+
+ if build_intrinsics:
+ H, W = image_size_hw
+ fy = (H / 2.0) / (torch.tan(fov_h / 2.0) + 1e-3)
+ fx = (W / 2.0) / (torch.tan(fov_w / 2.0) + 1e-3)
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype)
+ intrinsics[..., 0, 0] = fx
+ intrinsics[..., 1, 1] = fy
+ intrinsics[..., 0, 2] = W / 2
+ intrinsics[..., 1, 2] = H / 2
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
+ else:
+ raise NotImplementedError
+
+ return extrinsics, intrinsics
diff --git a/src/model/encoder/vggt/utils/rotation.py b/src/model/encoder/vggt/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..657583e6915437c824c192d51939990b589a14fa
--- /dev/null
+++ b/src/model/encoder/vggt/utils/rotation.py
@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Quaternion Order: XYZW or say ijkr, scalar-last
+
+ Convert rotations given as quaternions to rotation matrices.
+ Args:
+ quaternions: quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ i, j, k, r = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
diff --git a/src/model/encoder/vggt/utils/visual_track.py b/src/model/encoder/vggt/utils/visual_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154
--- /dev/null
+++ b/src/model/encoder/vggt/utils/visual_track.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import cv2
+import torch
+import numpy as np
+import os
+
+
+def color_from_xy(x, y, W, H, cmap_name="hsv"):
+ """
+ Map (x, y) -> color in (R, G, B).
+ 1) Normalize x,y to [0,1].
+ 2) Combine them into a single scalar c in [0,1].
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
+
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
+ """
+ import matplotlib.cm
+ import matplotlib.colors
+
+ x_norm = x / max(W - 1, 1)
+ y_norm = y / max(H - 1, 1)
+ # Simple combination:
+ c = (x_norm + y_norm) / 2.0
+
+ cmap = matplotlib.cm.get_cmap(cmap_name)
+ # cmap(c) -> (r,g,b,a) in [0,1]
+ rgba = cmap(c)
+ r, g, b = rgba[0], rgba[1], rgba[2]
+ return (r, g, b) # in [0,1], RGB order
+
+
+def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
+ """
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
+ in [0,255]. The color is determined by the (x,y) position in the first
+ visible frame for each track.
+
+ Args:
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
+ image_width, image_height: used for normalizing (x, y).
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
+
+ Returns:
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
+ """
+ S, N, _ = tracks_b.shape
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
+
+ if vis_mask_b is None:
+ # treat all as visible
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
+
+ for i in range(N):
+ # Find first visible frame for track i
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
+ if len(visible_frames) == 0:
+ # track is never visible; just assign black or something
+ track_colors[i] = (0, 0, 0)
+ continue
+
+ first_s = int(visible_frames[0].item())
+ # use that frame's (x,y)
+ x, y = tracks_b[first_s, i].tolist()
+
+ # map (x,y) -> (R,G,B) in [0,1]
+ r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
+ # scale to [0,255]
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
+ track_colors[i] = (r, g, b)
+
+ return track_colors
+
+
+def visualize_tracks_on_images(
+ images,
+ tracks,
+ track_vis_mask=None,
+ out_dir="track_visuals_concat_by_xy",
+ image_format="CHW", # "CHW" or "HWC"
+ normalize_mode="[0,1]",
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
+ frames_per_row=4, # New parameter for grid layout
+ save_grid=True, # Flag to control whether to save the grid image
+):
+ """
+ Visualizes frames in a grid layout with specified frames per row.
+ Each track's color is determined by its (x,y) position
+ in the first visible frame (or frame 0 if always visible).
+ Finally convert the BGR result to RGB before saving.
+ Also saves each individual frame as a separate PNG file.
+
+ Args:
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
+ track_vis_mask: torch.Tensor (S, N) or None.
+ out_dir: folder to save visualizations.
+ image_format: "CHW" or "HWC".
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
+ cmap_name: a matplotlib colormap name for color_from_xy.
+ frames_per_row: number of frames to display in each row of the grid.
+ save_grid: whether to save all frames in one grid image.
+
+ Returns:
+ None (saves images in out_dir).
+ """
+
+ if len(tracks.shape) == 4:
+ tracks = tracks.squeeze(0)
+ images = images.squeeze(0)
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.squeeze(0)
+
+ import matplotlib
+
+ matplotlib.use("Agg") # for non-interactive (optional)
+
+ os.makedirs(out_dir, exist_ok=True)
+
+ S = images.shape[0]
+ _, N, _ = tracks.shape # (S, N, 2)
+
+ # Move to CPU
+ images = images.cpu().clone()
+ tracks = tracks.cpu().clone()
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.cpu().clone()
+
+ # Infer H, W from images shape
+ if image_format == "CHW":
+ # e.g. images[s].shape = (3, H, W)
+ H, W = images.shape[2], images.shape[3]
+ else:
+ # e.g. images[s].shape = (H, W, 3)
+ H, W = images.shape[1], images.shape[2]
+
+ # Pre-compute the color for each track i based on first visible position
+ track_colors_rgb = get_track_colors_by_position(
+ tracks, # shape (S, N, 2)
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
+ image_width=W,
+ image_height=H,
+ cmap_name=cmap_name,
+ )
+
+ # We'll accumulate each frame's drawn image in a list
+ frame_images = []
+
+ for s in range(S):
+ # shape => either (3, H, W) or (H, W, 3)
+ img = images[s]
+
+ # Convert to (H, W, 3)
+ if image_format == "CHW":
+ img = img.permute(1, 2, 0) # (H, W, 3)
+ # else "HWC", do nothing
+
+ img = img.numpy().astype(np.float32)
+
+ # Scale to [0,255] if needed
+ if normalize_mode == "[0,1]":
+ img = np.clip(img, 0, 1) * 255.0
+ elif normalize_mode == "[-1,1]":
+ img = (img + 1.0) * 0.5 * 255.0
+ img = np.clip(img, 0, 255.0)
+ # else no normalization
+
+ # Convert to uint8
+ img = img.astype(np.uint8)
+
+ # For drawing in OpenCV, convert to BGR
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+ # Draw each visible track
+ cur_tracks = tracks[s] # shape (N, 2)
+ if track_vis_mask is not None:
+ valid_indices = torch.where(track_vis_mask[s])[0]
+ else:
+ valid_indices = range(N)
+
+ cur_tracks_np = cur_tracks.numpy()
+ for i in valid_indices:
+ x, y = cur_tracks_np[i]
+ pt = (int(round(x)), int(round(y)))
+
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
+ R, G, B = track_colors_rgb[i]
+ color_bgr = (int(B), int(G), int(R))
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+ # Convert back to RGB for consistent final saving:
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+ # Save individual frame
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+ # Convert to BGR for OpenCV imwrite
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(frame_path, frame_bgr)
+
+ frame_images.append(img_rgb)
+
+ # Only create and save the grid image if save_grid is True
+ if save_grid:
+ # Calculate grid dimensions
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
+
+ # Create a grid of images
+ grid_img = None
+ for row in range(num_rows):
+ start_idx = row * frames_per_row
+ end_idx = min(start_idx + frames_per_row, S)
+
+ # Concatenate this row horizontally
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+ # If this row has fewer than frames_per_row images, pad with black
+ if end_idx - start_idx < frames_per_row:
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+ row_img = np.concatenate([row_img, padding], axis=1)
+
+ # Add this row to the grid
+ if grid_img is None:
+ grid_img = row_img
+ else:
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
+
+ out_path = os.path.join(out_dir, "tracks_grid.png")
+ # Convert back to BGR for OpenCV imwrite
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(out_path, grid_img_bgr)
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
diff --git a/src/model/encoder/visualization/encoder_visualizer.py b/src/model/encoder/visualization/encoder_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5af485b01faed25fbe3557f650d0c3809b216b6d
--- /dev/null
+++ b/src/model/encoder/visualization/encoder_visualizer.py
@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+from jaxtyping import Float
+from torch import Tensor
+
+T_cfg = TypeVar("T_cfg")
+T_encoder = TypeVar("T_encoder")
+
+
+class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
+ cfg: T_cfg
+ encoder: T_encoder
+
+ def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
+ self.cfg = cfg
+ self.encoder = encoder
+
+ @abstractmethod
+ def visualize(
+ self,
+ context: dict,
+ global_step: int,
+ ) -> dict[str, Float[Tensor, "3 _ _"]]:
+ pass
diff --git a/src/model/encoder/visualization/encoder_visualizer_epipolar.py b/src/model/encoder/visualization/encoder_visualizer_epipolar.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c924b951898d2ff86d3aa2ee2747c5f3380a8d8
--- /dev/null
+++ b/src/model/encoder/visualization/encoder_visualizer_epipolar.py
@@ -0,0 +1,526 @@
+from pathlib import Path
+from random import randrange
+from typing import Optional
+
+import numpy as np
+import torch
+import wandb
+from einops import rearrange, reduce, repeat
+from jaxtyping import Bool, Float
+from torch import Tensor
+
+from ....dataset.types import BatchedViews
+from ....misc.heterogeneous_pairings import generate_heterogeneous_index
+from ....visualization.annotation import add_label
+from ....visualization.color_map import apply_color_map, apply_color_map_to_image
+from ....visualization.colors import get_distinct_color
+from ....visualization.drawing.lines import draw_lines
+from ....visualization.drawing.points import draw_points
+from ....visualization.layout import add_border, hcat, vcat
+from ...ply_export import export_ply
+from .encoder_visualizer import EncoderVisualizer
+from .encoder_visualizer_epipolar_cfg import EncoderVisualizerEpipolarCfg
+
+
+def box(
+ image: Float[Tensor, "3 height width"],
+) -> Float[Tensor, "3 new_height new_width"]:
+ return add_border(add_border(image), 1, 0)
+
+
+class EncoderVisualizerEpipolar(
+ EncoderVisualizer[EncoderVisualizerEpipolarCfg, EncoderEpipolar]
+):
+ def visualize(
+ self,
+ context: BatchedViews,
+ global_step: int,
+ ) -> dict[str, Float[Tensor, "3 _ _"]]:
+ # Short-circuit execution when ablating the epipolar transformer.
+ if self.encoder.epipolar_transformer is None:
+ return {}
+
+ visualization_dump = {}
+
+ softmax_weights = []
+
+ def hook(module, input, output):
+ softmax_weights.append(output)
+
+ # Register hooks to grab attention.
+ handles = [
+ layer[0].fn.attend.register_forward_hook(hook)
+ for layer in self.encoder.epipolar_transformer.transformer.layers
+ ]
+
+ result = self.encoder.forward(
+ context,
+ global_step,
+ visualization_dump=visualization_dump,
+ deterministic=True,
+ )
+
+ # De-register hooks.
+ for handle in handles:
+ handle.remove()
+
+ softmax_weights = torch.stack(softmax_weights)
+
+ # Generate high-resolution context images that can be drawn on.
+ context_images = context["image"]
+ _, _, _, h, w = context_images.shape
+ length = min(h, w)
+ min_resolution = self.cfg.min_resolution
+ scale_multiplier = (min_resolution + length - 1) // length
+ if scale_multiplier > 1:
+ context_images = repeat(
+ context_images,
+ "b v c h w -> b v c (h rh) (w rw)",
+ rh=scale_multiplier,
+ rw=scale_multiplier,
+ )
+
+ # This is kind of hacky for now, since we're using it for short experiments.
+ if self.cfg.export_ply and wandb.run is not None:
+ name = wandb.run._name.split(" ")[0]
+ ply_path = Path(f"outputs/gaussians/{name}/{global_step:0>6}.ply")
+ export_ply(
+ context["extrinsics"][0, 0],
+ result.means[0],
+ visualization_dump["scales"][0],
+ visualization_dump["rotations"][0],
+ result.harmonics[0],
+ result.opacities[0],
+ ply_path,
+ )
+
+ return {
+ # "attention": self.visualize_attention(
+ # context_images,
+ # visualization_dump["sampling"],
+ # softmax_weights,
+ # ),
+ "epipolar_samples": self.visualize_epipolar_samples(
+ context_images,
+ visualization_dump["sampling"],
+ ),
+ "epipolar_color_samples": self.visualize_epipolar_color_samples(
+ context_images,
+ context,
+ ),
+ "gaussians": self.visualize_gaussians(
+ context["image"],
+ result.opacities,
+ result.covariances,
+ result.harmonics[..., 0], # Just visualize DC component.
+ ),
+ "overlaps": self.visualize_overlaps(
+ context["image"],
+ visualization_dump["sampling"],
+ visualization_dump.get("is_monocular", None),
+ ),
+ "depth": self.visualize_depth(
+ context,
+ visualization_dump["depth"],
+ ),
+ }
+
+ def visualize_attention(
+ self,
+ context_images: Float[Tensor, "batch view 3 height width"],
+ sampling: EpipolarSampling,
+ attention: Float[Tensor, "layer bvr head 1 sample"],
+ ) -> Float[Tensor, "3 vis_height vis_width"]:
+ device = context_images.device
+
+ # Pick a random batch element, view, and other view.
+ b, v, ov, r, s, _ = sampling.xy_sample.shape
+ rb = randrange(b)
+ rv = randrange(v)
+ rov = randrange(ov)
+ num_samples = self.cfg.num_samples
+ rr = np.random.choice(r, num_samples, replace=False)
+ rr = torch.tensor(rr, dtype=torch.int64, device=device)
+
+ # Visualize the rays in the ray view.
+ ray_view = draw_points(
+ context_images[rb, rv],
+ sampling.xy_ray[rb, rv, rr],
+ 0,
+ radius=4,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+ ray_view = draw_points(
+ ray_view,
+ sampling.xy_ray[rb, rv, rr],
+ [get_distinct_color(i) for i, _ in enumerate(rr)],
+ radius=3,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ # Visualize attention in the sample view.
+ attention = rearrange(
+ attention, "l (b v r) hd () s -> l b v r hd s", b=b, v=v, r=r
+ )
+ attention = attention[:, rb, rv, rr, :, :]
+ num_layers, _, hd, _ = attention.shape
+
+ vis = []
+ for il in range(num_layers):
+ vis_layer = []
+ for ihd in range(hd):
+ # Create colors according to attention.
+ color = [get_distinct_color(i) for i, _ in enumerate(rr)]
+ color = torch.tensor(color, device=attention.device)
+ color = rearrange(color, "r c -> r () c")
+ attn = rearrange(attention[il, :, ihd], "r s -> r s ()")
+ color = rearrange(attn * color, "r s c -> (r s ) c")
+
+ # Draw the alternating bucket lines.
+ vis_layer_head = draw_lines(
+ context_images[rb, self.encoder.sampler.index_v[rv, rov]],
+ rearrange(
+ sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"
+ ),
+ rearrange(
+ sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"
+ ),
+ color,
+ 3,
+ cap="butt",
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+ vis_layer.append(vis_layer_head)
+ vis.append(add_label(vcat(*vis_layer), f"Layer {il}"))
+ vis = add_label(add_border(add_border(hcat(*vis)), 1, 0), "Keys & Values")
+ vis = add_border(hcat(add_label(ray_view), vis, align="top"))
+ return vis
+
+ def visualize_depth(
+ self,
+ context: BatchedViews,
+ multi_depth: Float[Tensor, "batch view height width surface spp"],
+ ) -> Float[Tensor, "3 vis_width vis_height"]:
+ multi_vis = []
+ *_, srf, _ = multi_depth.shape
+ for i in range(srf):
+ depth = multi_depth[..., i, :]
+ depth = depth.mean(dim=-1)
+
+ # Compute relative depth and disparity.
+ near = rearrange(context["near"], "b v -> b v () ()")
+ far = rearrange(context["far"], "b v -> b v () ()")
+ relative_depth = (depth - near) / (far - near)
+ relative_disparity = 1 - (1 / depth - 1 / far) / (1 / near - 1 / far)
+
+ relative_depth = apply_color_map_to_image(relative_depth, "turbo")
+ relative_depth = vcat(*[hcat(*x) for x in relative_depth])
+ relative_depth = add_label(relative_depth, "Depth")
+ relative_disparity = apply_color_map_to_image(relative_disparity, "turbo")
+ relative_disparity = vcat(*[hcat(*x) for x in relative_disparity])
+ relative_disparity = add_label(relative_disparity, "Disparity")
+ multi_vis.append(add_border(hcat(relative_depth, relative_disparity)))
+
+ return add_border(vcat(*multi_vis))
+
+ def visualize_overlaps(
+ self,
+ context_images: Float[Tensor, "batch view 3 height width"],
+ sampling: EpipolarSampling,
+ is_monocular: Optional[Bool[Tensor, "batch view height width"]] = None,
+ ) -> Float[Tensor, "3 vis_width vis_height"]:
+ device = context_images.device
+ b, v, _, h, w = context_images.shape
+ green = torch.tensor([0.235, 0.706, 0.294], device=device)[..., None, None]
+ rb = randrange(b)
+ valid = sampling.valid[rb].float()
+ ds = self.encoder.cfg.epipolar_transformer.downscale
+ valid = repeat(
+ valid,
+ "v ov (h w) -> v ov c (h rh) (w rw)",
+ c=3,
+ h=h // ds,
+ w=w // ds,
+ rh=ds,
+ rw=ds,
+ )
+
+ if is_monocular is not None:
+ is_monocular = is_monocular[rb].float()
+ is_monocular = repeat(is_monocular, "v h w -> v c h w", c=3, h=h, w=w)
+
+ # Select context images in grid.
+ context_images = context_images[rb]
+ index, _ = generate_heterogeneous_index(v)
+ valid = valid * (green + context_images[index]) / 2
+
+ vis = vcat(*(hcat(im, hcat(*v)) for im, v in zip(context_images, valid)))
+ vis = add_label(vis, "Context Overlaps")
+
+ if is_monocular is not None:
+ vis = hcat(vis, add_label(vcat(*is_monocular), "Monocular?"))
+
+ return add_border(vis)
+
+ def visualize_gaussians(
+ self,
+ context_images: Float[Tensor, "batch view 3 height width"],
+ opacities: Float[Tensor, "batch vrspp"],
+ covariances: Float[Tensor, "batch vrspp 3 3"],
+ colors: Float[Tensor, "batch vrspp 3"],
+ ) -> Float[Tensor, "3 vis_height vis_width"]:
+ b, v, _, h, w = context_images.shape
+ rb = randrange(b)
+ context_images = context_images[rb]
+ opacities = repeat(
+ opacities[rb], "(v h w spp) -> spp v c h w", v=v, c=3, h=h, w=w
+ )
+ colors = rearrange(colors[rb], "(v h w spp) c -> spp v c h w", v=v, h=h, w=w)
+
+ # Color-map Gaussian covariawnces.
+ det = covariances[rb].det()
+ det = apply_color_map(det / det.max(), "inferno")
+ det = rearrange(det, "(v h w spp) c -> spp v c h w", v=v, h=h, w=w)
+
+ return add_border(
+ hcat(
+ add_label(box(hcat(*context_images)), "Context"),
+ add_label(box(vcat(*[hcat(*x) for x in opacities])), "Opacities"),
+ add_label(
+ box(vcat(*[hcat(*x) for x in (colors * opacities)])), "Colors"
+ ),
+ add_label(box(vcat(*[hcat(*x) for x in colors])), "Colors (Raw)"),
+ add_label(box(vcat(*[hcat(*x) for x in det])), "Determinant"),
+ )
+ )
+
+ def visualize_probabilities(
+ self,
+ context_images: Float[Tensor, "batch view 3 height width"],
+ sampling: EpipolarSampling,
+ pdf: Float[Tensor, "batch view ray sample"],
+ ) -> Float[Tensor, "3 vis_height vis_width"]:
+ device = context_images.device
+
+ # Pick a random batch element, view, and other view.
+ b, v, ov, r, _, _ = sampling.xy_sample.shape
+ rb = randrange(b)
+ rv = randrange(v)
+ rov = randrange(ov)
+ num_samples = self.cfg.num_samples
+ rr = np.random.choice(r, num_samples, replace=False)
+ rr = torch.tensor(rr, dtype=torch.int64, device=device)
+ colors = [get_distinct_color(i) for i, _ in enumerate(rr)]
+ colors = torch.tensor(colors, dtype=torch.float32, device=device)
+
+ # Visualize the rays in the ray view.
+ ray_view = draw_points(
+ context_images[rb, rv],
+ sampling.xy_ray[rb, rv, rr],
+ 0,
+ radius=4,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+ ray_view = draw_points(
+ ray_view,
+ sampling.xy_ray[rb, rv, rr],
+ colors,
+ radius=3,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ # Visualize probabilities in the sample view.
+ pdf = pdf[rb, rv, rr]
+ pdf = rearrange(pdf, "r s -> r s ()")
+ colors = rearrange(colors, "r c -> r () c")
+ sample_view = draw_lines(
+ context_images[rb, self.encoder.sampler.index_v[rv, rov]],
+ rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ rearrange(pdf * colors, "r s c -> (r s) c"),
+ 6,
+ cap="butt",
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ # Visualize rescaled probabilities in the sample view.
+ pdf_magnified = pdf / reduce(pdf, "r s () -> r () ()", "max")
+ sample_view_magnified = draw_lines(
+ context_images[rb, self.encoder.sampler.index_v[rv, rov]],
+ rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ rearrange(pdf_magnified * colors, "r s c -> (r s) c"),
+ 6,
+ cap="butt",
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ return add_border(
+ hcat(
+ add_label(ray_view, "Rays"),
+ add_label(sample_view, "Samples"),
+ add_label(sample_view_magnified, "Samples (Magnified PDF)"),
+ )
+ )
+
+ def visualize_epipolar_samples(
+ self,
+ context_images: Float[Tensor, "batch view 3 height width"],
+ sampling: EpipolarSampling,
+ ) -> Float[Tensor, "3 vis_height vis_width"]:
+ device = context_images.device
+
+ # Pick a random batch element, view, and other view.
+ b, v, ov, r, s, _ = sampling.xy_sample.shape
+ rb = randrange(b)
+ rv = randrange(v)
+ rov = randrange(ov)
+ num_samples = self.cfg.num_samples
+ rr = np.random.choice(r, num_samples, replace=False)
+ rr = torch.tensor(rr, dtype=torch.int64, device=device)
+
+ # Visualize the rays in the ray view.
+ ray_view = draw_points(
+ context_images[rb, rv],
+ sampling.xy_ray[rb, rv, rr],
+ 0,
+ radius=4,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+ ray_view = draw_points(
+ ray_view,
+ sampling.xy_ray[rb, rv, rr],
+ [get_distinct_color(i) for i, _ in enumerate(rr)],
+ radius=3,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ # Visualize the samples and epipolar lines in the sample view.
+ # First, draw the epipolar line in black.
+ sample_view = draw_lines(
+ context_images[rb, self.encoder.sampler.index_v[rv, rov]],
+ sampling.xy_sample_near[rb, rv, rov, rr, 0],
+ sampling.xy_sample_far[rb, rv, rov, rr, -1],
+ 0,
+ 5,
+ cap="butt",
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ # Create an alternating line color for the buckets.
+ color = repeat(
+ torch.tensor([0, 1], device=device),
+ "ab -> r (s ab) c",
+ r=len(rr),
+ s=(s + 1) // 2,
+ c=3,
+ )
+ color = rearrange(color[:, :s], "r s c -> (r s) c")
+
+ # Draw the alternating bucket lines.
+ sample_view = draw_lines(
+ sample_view,
+ rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ color,
+ 3,
+ cap="butt",
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ # Draw the sample points.
+ sample_view = draw_points(
+ sample_view,
+ rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ 0,
+ radius=4,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+ sample_view = draw_points(
+ sample_view,
+ rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ [get_distinct_color(i // s) for i in range(s * len(rr))],
+ radius=3,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ return add_border(
+ hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
+ )
+
+ def visualize_epipolar_color_samples(
+ self,
+ context_images: Float[Tensor, "batch view 3 height width"],
+ context: BatchedViews,
+ ) -> Float[Tensor, "3 vis_height vis_width"]:
+ device = context_images.device
+
+ sampling = self.encoder.sampler(
+ context["image"],
+ context["extrinsics"],
+ context["intrinsics"],
+ context["near"],
+ context["far"],
+ )
+
+ # Pick a random batch element, view, and other view.
+ b, v, ov, r, s, _ = sampling.xy_sample.shape
+ rb = randrange(b)
+ rv = randrange(v)
+ rov = randrange(ov)
+ num_samples = self.cfg.num_samples
+ rr = np.random.choice(r, num_samples, replace=False)
+ rr = torch.tensor(rr, dtype=torch.int64, device=device)
+
+ # Visualize the rays in the ray view.
+ ray_view = draw_points(
+ context_images[rb, rv],
+ sampling.xy_ray[rb, rv, rr],
+ 0,
+ radius=4,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+ ray_view = draw_points(
+ ray_view,
+ sampling.xy_ray[rb, rv, rr],
+ [get_distinct_color(i) for i, _ in enumerate(rr)],
+ radius=3,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ # Visualize the samples and in the sample view.
+ sample_view = draw_points(
+ context_images[rb, self.encoder.sampler.index_v[rv, rov]],
+ rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ [get_distinct_color(i // s) for i in range(s * len(rr))],
+ radius=4,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+ sample_view = draw_points(
+ sample_view,
+ rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
+ rearrange(sampling.features[rb, rv, rov, rr], "r s c -> (r s) c"),
+ radius=3,
+ x_range=(0, 1),
+ y_range=(0, 1),
+ )
+
+ return add_border(
+ hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
+ )
diff --git a/src/model/encoder/visualization/encoder_visualizer_epipolar_cfg.py b/src/model/encoder/visualization/encoder_visualizer_epipolar_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..e68d57f80c7a875401bf25758bc9e8e4a23d6dab
--- /dev/null
+++ b/src/model/encoder/visualization/encoder_visualizer_epipolar_cfg.py
@@ -0,0 +1,10 @@
+from dataclasses import dataclass
+
+# This is in a separate file to avoid circular imports.
+
+
+@dataclass
+class EncoderVisualizerEpipolarCfg:
+ num_samples: int
+ min_resolution: int
+ export_ply: bool
diff --git a/src/model/encodings/positional_encoding.py b/src/model/encodings/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e36b4bfc39aef2a4a2883e501e7f19d27ad1f663
--- /dev/null
+++ b/src/model/encodings/positional_encoding.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+from einops import einsum, rearrange, repeat
+from jaxtyping import Float
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """For the sake of simplicity, this encodes values in the range [0, 1]."""
+
+ frequencies: Float[Tensor, "frequency phase"]
+ phases: Float[Tensor, "frequency phase"]
+
+ def __init__(self, num_octaves: int):
+ super().__init__()
+ octaves = torch.arange(num_octaves).float()
+
+ # The lowest frequency has a period of 1.
+ frequencies = 2 * torch.pi * 2**octaves
+ frequencies = repeat(frequencies, "f -> f p", p=2)
+ self.register_buffer("frequencies", frequencies, persistent=False)
+
+ # Choose the phases to match sine and cosine.
+ phases = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32)
+ phases = repeat(phases, "p -> f p", f=num_octaves)
+ self.register_buffer("phases", phases, persistent=False)
+
+ def forward(
+ self,
+ samples: Float[Tensor, "*batch dim"],
+ ) -> Float[Tensor, "*batch embedded_dim"]:
+ samples = einsum(samples, self.frequencies, "... d, f p -> ... d f p")
+ return rearrange(torch.sin(samples + self.phases), "... d f p -> ... (d f p)")
+
+ def d_out(self, dimensionality: int):
+ return self.frequencies.numel() * dimensionality
diff --git a/src/model/model/__init__.py b/src/model/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8cb280359ea4397b93e6cd045397298bf272569
--- /dev/null
+++ b/src/model/model/__init__.py
@@ -0,0 +1,21 @@
+from typing import Optional, Union
+
+from ..encoder import Encoder
+from ..encoder.visualization.encoder_visualizer import EncoderVisualizer
+from ..encoder.anysplat import EncoderAnySplat, EncoderAnySplatCfg
+from ..decoder.decoder_splatting_cuda import DecoderSplattingCUDACfg
+from torch import nn
+from .anysplat import AnySplat
+
+MODELS = {
+ "anysplat": AnySplat,
+}
+
+EncoderCfg = Union[EncoderAnySplatCfg]
+DecoderCfg = DecoderSplattingCUDACfg
+
+
+# hard code for now
+def get_model(encoder_cfg: EncoderCfg, decoder_cfg: DecoderCfg) -> nn.Module:
+ model = MODELS['anysplat'](encoder_cfg, decoder_cfg)
+ return model
diff --git a/src/model/model/anysplat.py b/src/model/model/anysplat.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e7b211db7bfc308194be31080d875df3d524c8e
--- /dev/null
+++ b/src/model/model/anysplat.py
@@ -0,0 +1,126 @@
+import os
+from copy import deepcopy
+import time
+from typing import Optional
+from einops import rearrange
+import huggingface_hub
+from omegaconf import DictConfig, OmegaConf
+import torch
+import torch.distributed
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from dataclasses import dataclass
+
+from src.model.encoder.common.gaussian_adapter import GaussianAdapterCfg
+from src.model.decoder.decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg
+from src.model.encoder.anysplat import EncoderAnySplat, EncoderAnySplatCfg, OpacityMappingCfg
+
+class AnySplat(nn.Module, huggingface_hub.PyTorchModelHubMixin):
+ def __init__(
+ self,
+ encoder_cfg: EncoderAnySplatCfg,
+ decoder_cfg: DecoderSplattingCUDACfg,
+ ):
+ super(AnySplat, self).__init__()
+ self.encoder_cfg = encoder_cfg
+ self.decoder_cfg = decoder_cfg
+ self.build_encoder(encoder_cfg)
+ self.build_decoder(decoder_cfg)
+
+ def convert_nested_config(self, cfg_dict: dict, target_class: type):
+ """Convert nested dictionary config to dataclass instance
+
+ Args:
+ cfg_dict: Configuration dictionary or already converted object
+ target_class: Target dataclass type to convert to
+
+ Returns:
+ Instance of target_class
+ """
+ if isinstance(cfg_dict, dict):
+ # Convert dict to dataclass
+ return target_class(**cfg_dict)
+ elif isinstance(cfg_dict, target_class):
+ # Already converted, return as is
+ return cfg_dict
+ elif cfg_dict is None:
+ # Handle None case
+ return None
+ else:
+ raise ValueError(f"Cannot convert {type(cfg_dict)} to {target_class}")
+
+ def convert_config_recursively(self, cfg_obj, conversion_map: dict):
+ """Convert nested configurations recursively using a conversion map
+
+ Args:
+ cfg_obj: Configuration object to convert
+ conversion_map: Dict mapping field names to their target classes
+ e.g., {'gaussian_adapter': GaussianAdapterCfg}
+
+ Returns:
+ Converted configuration object
+ """
+ if not hasattr(cfg_obj, '__dict__'):
+ return cfg_obj
+
+ cfg_dict = cfg_obj.__dict__.copy()
+
+ for field_name, target_class in conversion_map.items():
+ if field_name in cfg_dict:
+ cfg_dict[field_name] = self.convert_nested_config(
+ cfg_dict[field_name],
+ target_class
+ )
+
+ # Return new instance of the same type
+ return type(cfg_obj)(**cfg_dict)
+
+ def convert_encoder_config(self, encoder_cfg: EncoderAnySplatCfg) -> EncoderAnySplatCfg:
+ """Convert all nested configurations in encoder_cfg"""
+ conversion_map = {
+ 'gaussian_adapter': GaussianAdapterCfg,
+ 'opacity_mapping': OpacityMappingCfg,
+ }
+
+ return self.convert_config_recursively(encoder_cfg, conversion_map)
+
+ def build_encoder(self, encoder_cfg: EncoderAnySplatCfg):
+ # Convert nested configurations using the helper method
+ encoder_cfg = self.convert_encoder_config(encoder_cfg)
+ self.encoder = EncoderAnySplat(encoder_cfg)
+
+ def build_decoder(self, decoder_cfg: DecoderSplattingCUDACfg):
+ self.decoder = DecoderSplattingCUDA(decoder_cfg)
+
+ @torch.no_grad()
+ def inference(self,
+ context_image: torch.Tensor,
+ ):
+ self.encoder.distill = False
+ encoder_output = self.encoder(context_image, global_step=0, visualization_dump=None)
+ gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
+ return gaussians, pred_context_pose
+
+ def forward(self,
+ context_image: torch.Tensor,
+ global_step: int = 0,
+ visualization_dump: Optional[dict] = None,
+ near: float = 0.01,
+ far: float = 100.0,
+ ):
+ b, v, c, h, w = context_image.shape
+ device = context_image.device
+ encoder_output = self.encoder(context_image, global_step, visualization_dump=visualization_dump)
+ gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
+ output = self.decoder.forward(
+ gaussians,
+ pred_context_pose['extrinsic'],
+ pred_context_pose["intrinsic"],
+ torch.ones(1, v, device=device) * near,
+ torch.ones(1, v, device=device) * far,
+ (h, w),
+ "depth",
+ )
+ return encoder_output, output
+
diff --git a/src/model/model_wrapper.py b/src/model/model_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e969e38f6df3bc225b7016e02d073cf415076842
--- /dev/null
+++ b/src/model/model_wrapper.py
@@ -0,0 +1,847 @@
+from dataclasses import dataclass
+from pathlib import Path
+import gc
+import random
+from typing import Literal, Optional, Protocol, runtime_checkable, Any
+
+import moviepy.editor as mpy
+import torch
+import torchvision
+import wandb
+from einops import pack, rearrange, repeat
+from jaxtyping import Float
+from lightning.pytorch import LightningModule
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.utilities import rank_zero_only
+from tabulate import tabulate
+from torch import Tensor, nn, optim
+import torch.nn.functional as F
+
+from loss.loss_lpips import LossLpips
+from loss.loss_mse import LossMse
+from model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+
+from ..loss.loss_distill import DistillLoss
+from src.utils.render import generate_path
+from src.utils.point import get_normal_map
+
+from ..loss.loss_huber import HuberLoss, extri_intri_to_pose_encoding
+
+# from model.types import Gaussians
+
+from ..dataset.data_module import get_data_shim
+from ..dataset.types import BatchedExample
+from ..evaluation.metrics import compute_lpips, compute_psnr, compute_ssim, abs_relative_difference, delta1_acc
+from ..global_cfg import get_cfg
+from ..loss import Loss
+from ..loss.loss_point import Regr3D
+from ..loss.loss_ssim import ssim
+from ..misc.benchmarker import Benchmarker
+from ..misc.cam_utils import update_pose, get_pnp_pose, rotation_6d_to_matrix
+from ..misc.image_io import prep_image, save_image, save_video
+from ..misc.LocalLogger import LOG_PATH, LocalLogger
+from ..misc.nn_module_tools import convert_to_buffer
+from ..misc.step_tracker import StepTracker
+from ..misc.utils import inverse_normalize, vis_depth_map, confidence_map, get_overlap_tag
+from ..visualization.annotation import add_label
+from ..visualization.camera_trajectory.interpolation import (
+ interpolate_extrinsics,
+ interpolate_intrinsics,
+)
+from ..visualization.camera_trajectory.wobble import (
+ generate_wobble,
+ generate_wobble_transformation,
+)
+from ..visualization.color_map import apply_color_map_to_image
+from ..visualization.layout import add_border, hcat, vcat
+# from ..visualization.validation_in_3d import render_cameras, render_projections
+from .decoder.decoder import Decoder, DepthRenderingMode
+from .encoder import Encoder
+from .encoder.visualization.encoder_visualizer import EncoderVisualizer
+from .ply_export import export_ply
+
+@dataclass
+class OptimizerCfg:
+ lr: float
+ warm_up_steps: int
+ backbone_lr_multiplier: float
+
+
+@dataclass
+class TestCfg:
+ output_path: Path
+ align_pose: bool
+ pose_align_steps: int
+ rot_opt_lr: float
+ trans_opt_lr: float
+ compute_scores: bool
+ save_image: bool
+ save_video: bool
+ save_compare: bool
+ generate_video: bool
+ mode: Literal["inference", "evaluation"]
+ image_folder: str
+
+
+@dataclass
+class TrainCfg:
+ output_path: Path
+ depth_mode: DepthRenderingMode | None
+ extended_visualization: bool
+ print_log_every_n_steps: int
+ distiller: str
+ distill_max_steps: int
+ pose_loss_alpha: float = 1.0
+ pose_loss_delta: float = 1.0
+ cxt_depth_weight: float = 0.01
+ weight_pose: float = 1.0
+ weight_depth: float = 1.0
+ weight_normal: float = 1.0
+ render_ba: bool = False
+ render_ba_after_step: int = 0
+
+
+@runtime_checkable
+class TrajectoryFn(Protocol):
+ def __call__(
+ self,
+ t: Float[Tensor, " t"],
+ ) -> tuple[
+ Float[Tensor, "batch view 4 4"], # extrinsics
+ Float[Tensor, "batch view 3 3"], # intrinsics
+ ]:
+ pass
+
+
+class ModelWrapper(LightningModule):
+ logger: Optional[WandbLogger]
+ model: nn.Module
+ losses: nn.ModuleList
+ optimizer_cfg: OptimizerCfg
+ test_cfg: TestCfg
+ train_cfg: TrainCfg
+ step_tracker: StepTracker | None
+
+ def __init__(
+ self,
+ optimizer_cfg: OptimizerCfg,
+ test_cfg: TestCfg,
+ train_cfg: TrainCfg,
+ model: nn.Module,
+ losses: list[Loss],
+ step_tracker: StepTracker | None
+ ) -> None:
+ super().__init__()
+ self.optimizer_cfg = optimizer_cfg
+ self.test_cfg = test_cfg
+ self.train_cfg = train_cfg
+ self.step_tracker = step_tracker
+
+ # Set up the model.
+ self.encoder_visualizer = None
+ self.model = model
+ self.data_shim = get_data_shim(self.model.encoder)
+ self.losses = nn.ModuleList(losses)
+
+ if self.model.encoder.pred_pose:
+ self.loss_pose = HuberLoss(alpha=self.train_cfg.pose_loss_alpha, delta=self.train_cfg.pose_loss_delta)
+
+ if self.model.encoder.distill:
+ self.loss_distill = DistillLoss(
+ delta=self.train_cfg.pose_loss_delta,
+ weight_pose=self.train_cfg.weight_pose,
+ weight_depth=self.train_cfg.weight_depth,
+ weight_normal=self.train_cfg.weight_normal
+ )
+
+ # This is used for testing.
+ self.benchmarker = Benchmarker()
+
+ def on_train_epoch_start(self) -> None:
+ # our custom dataset and sampler has to have epoch set by calling set_epoch
+ if hasattr(self.trainer.datamodule.train_loader.dataset, "set_epoch"):
+ self.trainer.datamodule.train_loader.dataset.set_epoch(self.current_epoch)
+ if hasattr(self.trainer.datamodule.train_loader.sampler, "set_epoch"):
+ self.trainer.datamodule.train_loader.sampler.set_epoch(self.current_epoch)
+
+ def on_validation_epoch_start(self) -> None:
+ print(f"Validation epoch start on rank {self.trainer.global_rank}")
+ # our custom dataset and sampler has to have epoch set by calling set_epoch
+ if hasattr(self.trainer.datamodule.val_loader.dataset, "set_epoch"):
+ self.trainer.datamodule.val_loader.dataset.set_epoch(self.current_epoch)
+ if hasattr(self.trainer.datamodule.val_loader.sampler, "set_epoch"):
+ self.trainer.datamodule.val_loader.sampler.set_epoch(self.current_epoch)
+
+ def training_step(self, batch, batch_idx):
+ # combine batch from different dataloaders
+ # torch.cuda.empty_cache()
+ if isinstance(batch, list):
+ batch_combined = None
+ for batch_per_dl in batch:
+ if batch_combined is None:
+ batch_combined = batch_per_dl
+ else:
+ for k in batch_combined.keys():
+ if isinstance(batch_combined[k], list):
+ batch_combined[k] += batch_per_dl[k]
+ elif isinstance(batch_combined[k], dict):
+ for kk in batch_combined[k].keys():
+ batch_combined[k][kk] = torch.cat([batch_combined[k][kk], batch_per_dl[k][kk]], dim=0)
+ else:
+ raise NotImplementedError
+ batch = batch_combined
+
+ batch: BatchedExample = self.data_shim(batch)
+ b, v, c, h, w = batch["context"]["image"].shape
+ context_image = (batch["context"]["image"] + 1) / 2
+
+ # Run the model.
+ visualization_dump = None
+
+ encoder_output, output = self.model(context_image, self.global_step, visualization_dump=visualization_dump)
+ gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
+ pred_context_pose = encoder_output.pred_context_pose
+ infos = encoder_output.infos
+ distill_infos = encoder_output.distill_infos
+
+ num_context_views = pred_context_pose['extrinsic'].shape[1]
+
+ using_index = torch.arange(num_context_views, device=gaussians.means.device)
+ batch["using_index"] = using_index
+
+ target_gt = (batch["context"]["image"] + 1) / 2
+ scene_scale = infos["scene_scale"]
+ self.log("train/scene_scale", infos["scene_scale"])
+ self.log("train/voxelize_ratio", infos["voxelize_ratio"])
+
+ # Compute metrics.
+ psnr_probabilistic = compute_psnr(
+ rearrange(target_gt, "b v c h w -> (b v) c h w"),
+ rearrange(output.color, "b v c h w -> (b v) c h w"),
+ )
+ self.log("train/psnr_probabilistic", psnr_probabilistic.mean())
+
+ consis_absrel = abs_relative_difference(
+ rearrange(output.depth, "b v h w -> (b v) h w"),
+ rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
+ rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
+ )
+ self.log("train/consis_absrel", consis_absrel.mean())
+
+ consis_delta1 = delta1_acc(
+ rearrange(output.depth, "b v h w -> (b v) h w"),
+ rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
+ rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
+ )
+ self.log("train/consis_delta1", consis_delta1.mean())
+
+ # Compute and log loss.
+ total_loss = 0
+
+ depth_dict['distill_infos'] = distill_infos
+ with torch.amp.autocast('cuda', enabled=False):
+ for loss_fn in self.losses:
+ loss = loss_fn.forward(output, batch, gaussians, depth_dict, self.global_step)
+ self.log(f"loss/{loss_fn.name}", loss)
+ total_loss = total_loss + loss
+
+ if depth_dict is not None and "depth" in get_cfg()["loss"].keys() and self.train_cfg.cxt_depth_weight > 0:
+ depth_loss_idx = list(get_cfg()["loss"].keys()).index("depth")
+ depth_loss_fn = self.losses[depth_loss_idx].ctx_depth_loss
+ loss_depth = depth_loss_fn(depth_dict["depth_map"], depth_dict["depth_conf"], batch, cxt_depth_weight=self.train_cfg.cxt_depth_weight)
+ self.log("loss/ctx_depth", loss_depth)
+ total_loss = total_loss + loss_depth
+
+ if distill_infos is not None:
+ # distill ctx pred_pose & depth & normal
+ loss_distill_list = self.loss_distill(distill_infos, pred_pose_enc_list, output, batch)
+ self.log("loss/distill", loss_distill_list['loss_distill'])
+ self.log("loss/distill_pose", loss_distill_list['loss_pose'])
+ self.log("loss/distill_depth", loss_distill_list['loss_depth'])
+ self.log("loss/distill_normal", loss_distill_list['loss_normal'])
+ total_loss = total_loss + loss_distill_list['loss_distill']
+
+ self.log("loss/total", total_loss)
+ print(f"total_loss: {total_loss}")
+
+ # Skip batch if loss is too high after certain step
+ SKIP_AFTER_STEP = 1000
+ LOSS_THRESHOLD = 0.2
+ if self.global_step > SKIP_AFTER_STEP and total_loss > LOSS_THRESHOLD:
+ print(f"Skipping batch with high loss ({total_loss:.6f}) at step {self.global_step} on Rank {self.global_rank}")
+ # set to a really small number
+ return total_loss * 1e-10
+
+ if (
+ self.global_rank == 0
+ and self.global_step % self.train_cfg.print_log_every_n_steps == 0
+ ):
+ print(
+ f"train step {self.global_step}; "
+ f"scene = {[x[:20] for x in batch['scene']]}; "
+ f"context = {batch['context']['index'].tolist()}; "
+ f"loss = {total_loss:.6f}; "
+ )
+
+ self.log("info/global_step", self.global_step) # hack for ckpt monitor
+
+ # Tell the data loader processes about the current step.
+ if self.step_tracker is not None:
+ self.step_tracker.set_step(self.global_step)
+
+ del batch
+ if self.global_step % 50 == 0:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return total_loss
+
+ def on_after_backward(self):
+ total_norm = 0.0
+ counter = 0
+ for p in self.parameters():
+ if p.grad is not None:
+ param_norm = p.grad.detach().data.norm(2)
+ total_norm += param_norm.item() ** 2
+ counter += 1
+ total_norm = (total_norm / counter) ** 0.5
+ self.log("loss/grad_norm", total_norm)
+
+ def test_step(self, batch, batch_idx):
+ batch: BatchedExample = self.data_shim(batch)
+ b, v, _, h, w = batch["target"]["image"].shape
+ assert b == 1
+ if batch_idx % 100 == 0:
+ print(f"Test step {batch_idx:0>6}.")
+
+ # Render Gaussians.
+ with self.benchmarker.time("encoder"):
+ gaussians = self.model.encoder(
+ (batch["context"]["image"]+1)/2,
+ self.global_step,
+ )[0]
+ # export_ply(gaussians.means[0], gaussians.scales[0], gaussians.rotations[0], gaussians.harmonics[0], gaussians.opacities[0], Path("gaussians.ply"))
+ # align the target pose
+ if self.test_cfg.align_pose:
+ output = self.test_step_align(batch, gaussians)
+ else:
+ with self.benchmarker.time("decoder", num_calls=v):
+ output = self.model.decoder.forward(
+ gaussians,
+ batch["target"]["extrinsics"],
+ batch["target"]["intrinsics"],
+ batch["target"]["near"],
+ batch["target"]["far"],
+ (h, w),
+ )
+
+ # compute scores
+ if self.test_cfg.compute_scores:
+ overlap = batch["context"]["overlap"][0]
+ overlap_tag = get_overlap_tag(overlap)
+
+ rgb_pred = output.color[0]
+ rgb_gt = batch["target"]["image"][0]
+ all_metrics = {
+ f"lpips_ours": compute_lpips(rgb_gt, rgb_pred).mean(),
+ f"ssim_ours": compute_ssim(rgb_gt, rgb_pred).mean(),
+ f"psnr_ours": compute_psnr(rgb_gt, rgb_pred).mean(),
+ }
+ methods = ['ours']
+
+ self.log_dict(all_metrics)
+ self.print_preview_metrics(all_metrics, methods, overlap_tag=overlap_tag)
+
+ # Save images.
+ (scene,) = batch["scene"]
+ name = get_cfg()["wandb"]["name"]
+ path = self.test_cfg.output_path / name
+ if self.test_cfg.save_image:
+ for index, color in zip(batch["target"]["index"][0], output.color[0]):
+ save_image(color, path / scene / f"color/{index:0>6}.png")
+
+ if self.test_cfg.save_video:
+ frame_str = "_".join([str(x.item()) for x in batch["context"]["index"][0]])
+ save_video(
+ [a for a in output.color[0]],
+ path / "video" / f"{scene}_frame_{frame_str}.mp4",
+ )
+
+ if self.test_cfg.save_compare:
+ # Construct comparison image.
+ context_img = inverse_normalize(batch["context"]["image"][0])
+ comparison = hcat(
+ add_label(vcat(*context_img), "Context"),
+ add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
+ add_label(vcat(*rgb_pred), "Target (Prediction)"),
+ )
+ save_image(comparison, path / f"{scene}.png")
+
+ def test_step_align(self, batch, gaussians):
+ self.model.encoder.eval()
+ # freeze all parameters
+ for param in self.model.encoder.parameters():
+ param.requires_grad = False
+
+ b, v, _, h, w = batch["target"]["image"].shape
+ output_c2ws = batch["target"]["extrinsics"]
+ with torch.set_grad_enabled(True):
+ cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], requires_grad=True, device=output_c2ws.device))
+ cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], requires_grad=True, device=output_c2ws.device))
+ opt_params = []
+ self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).to(output_c2ws))
+ opt_params.append(
+ {
+ "params": [cam_rot_delta],
+ "lr": 0.005,
+ }
+ )
+ opt_params.append(
+ {
+ "params": [cam_trans_delta],
+ "lr": 0.005,
+ }
+ )
+ pose_optimizer = torch.optim.Adam(opt_params)
+ extrinsics = output_c2ws.clone()
+ with self.benchmarker.time("optimize"):
+ for i in range(self.test_cfg.pose_align_steps):
+ pose_optimizer.zero_grad()
+ dx, drot = cam_trans_delta, cam_rot_delta
+ rot = rotation_6d_to_matrix(
+ drot + self.identity.expand(b, v, -1)
+ ) # (..., 3, 3)
+
+ transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
+ transform[..., :3, :3] = rot
+ transform[..., :3, 3] = dx
+
+ new_extrinsics = torch.matmul(extrinsics, transform)
+ output = self.model.decoder.forward(
+ gaussians,
+ new_extrinsics,
+ batch["target"]["intrinsics"],
+ batch["target"]["near"],
+ batch["target"]["far"],
+ (h, w),
+ # cam_rot_delta=cam_rot_delta,
+ # cam_trans_delta=cam_trans_delta,
+ )
+
+ # Compute and log loss.
+ total_loss = 0
+ for loss_fn in self.losses:
+ loss = loss_fn.forward(output, batch, gaussians, self.global_step)
+ total_loss = total_loss + loss
+
+ total_loss.backward()
+ pose_optimizer.step()
+
+ # Render Gaussians.
+ output = self.model.decoder.forward(
+ gaussians,
+ new_extrinsics,
+ batch["target"]["intrinsics"],
+ batch["target"]["near"],
+ batch["target"]["far"],
+ (h, w),
+ )
+
+ return output
+
+ def on_test_end(self) -> None:
+ name = get_cfg()["wandb"]["name"]
+ self.benchmarker.dump(self.test_cfg.output_path / name / "benchmark.json")
+ self.benchmarker.dump_memory(
+ self.test_cfg.output_path / name / "peak_memory.json"
+ )
+ self.benchmarker.summarize()
+
+ @rank_zero_only
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
+ batch: BatchedExample = self.data_shim(batch)
+
+ if self.global_rank == 0:
+ print(
+ f"validation step {self.global_step}; "
+ f"scene = {batch['scene']}; "
+ f"context = {batch['context']['index'].tolist()}"
+ )
+
+ # Render Gaussians.
+ b, v, _, h, w = batch["context"]["image"].shape
+ assert b == 1
+ visualization_dump = {}
+
+ encoder_output, output = self.model(batch["context"]["image"], self.global_step, visualization_dump=visualization_dump)
+ gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
+ pred_context_pose, distill_infos = encoder_output.pred_context_pose, encoder_output.distill_infos
+ infos = encoder_output.infos
+
+ GS_num = infos['voxelize_ratio'] * (h*w*v)
+ self.log("val/GS_num", GS_num)
+
+ num_context_views = pred_context_pose['extrinsic'].shape[1]
+ num_target_views = batch["target"]["extrinsics"].shape[1]
+ rgb_pred = output.color[0].float()
+ depth_pred = vis_depth_map(output.depth[0])
+
+ # direct depth from gaussian means (used for visualization only)
+ gaussian_means = visualization_dump["depth"][0].squeeze()
+ if gaussian_means.shape[-1] == 3:
+ gaussian_means = gaussian_means.mean(dim=-1)
+
+ # Compute validation metrics.
+ rgb_gt = (batch["context"]["image"][0].float() + 1) / 2
+ psnr = compute_psnr(rgb_gt, rgb_pred).mean()
+ self.log(f"val/psnr", psnr)
+ lpips = compute_lpips(rgb_gt, rgb_pred).mean()
+ self.log(f"val/lpips", lpips)
+ ssim = compute_ssim(rgb_gt, rgb_pred).mean()
+ self.log(f"val/ssim", ssim)
+
+ # depth metrics
+ consis_absrel = abs_relative_difference(
+ rearrange(output.depth, "b v h w -> (b v) h w"),
+ rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
+ )
+ self.log("val/consis_absrel", consis_absrel.mean())
+
+ consis_delta1 = delta1_acc(
+ rearrange(output.depth, "b v h w -> (b v) h w"),
+ rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
+ valid_mask=rearrange(torch.ones_like(output.depth, device=output.depth.device, dtype=torch.bool), "b v h w -> (b v) h w"),
+ )
+ self.log("val/consis_delta1", consis_delta1.mean())
+
+ diff_map = torch.abs(output.depth - depth_dict['depth'].squeeze(-1))
+ self.log("val/consis_mse", diff_map[distill_infos['conf_mask']].mean())
+
+ # Construct comparison image.
+ context_img = inverse_normalize(batch["context"]["image"][0])
+ # context_img_depth = vis_depth_map(gaussian_means)
+ context = []
+ for i in range(context_img.shape[0]):
+ context.append(context_img[i])
+ # context.append(context_img_depth[i])
+
+ colored_diff_map = vis_depth_map(diff_map[0], near=torch.tensor(1e-4, device=diff_map.device), far=torch.tensor(1.0, device=diff_map.device))
+ model_depth_pred = depth_dict["depth"].squeeze(-1)[0]
+ model_depth_pred = vis_depth_map(model_depth_pred)
+
+ render_normal = (get_normal_map(output.depth.flatten(0, 1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.
+ pred_normal = (get_normal_map(depth_dict['depth'].flatten(0, 1).squeeze(-1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.
+
+ comparison = hcat(
+ add_label(vcat(*context), "Context"),
+ add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
+ add_label(vcat(*rgb_pred), "Target (Prediction)"),
+ add_label(vcat(*depth_pred), "Depth (Prediction)"),
+ add_label(vcat(*model_depth_pred), "Depth (VGGT Prediction)"),
+ add_label(vcat(*render_normal), "Normal (Prediction)"),
+ add_label(vcat(*pred_normal), "Normal (VGGT Prediction)"),
+ add_label(vcat(*colored_diff_map), "Diff Map"),
+ )
+
+ comparison = torch.nn.functional.interpolate(
+ comparison.unsqueeze(0),
+ scale_factor=0.5,
+ mode='bicubic',
+ align_corners=False
+ ).squeeze(0)
+
+ self.logger.log_image(
+ "comparison",
+ [prep_image(add_border(comparison))],
+ step=self.global_step,
+ caption=batch["scene"],
+ )
+
+ # self.logger.log_image(
+ # key="comparison",
+ # images=[wandb.Image(prep_image(add_border(comparison)), caption=batch["scene"], file_type="jpg")],
+ # step=self.global_step
+ # )
+
+ # Render projections and construct projection image.
+ # These are disabled for now, since RE10k scenes are effectively unbounded.
+
+ # if isinstance(gaussians, Gaussians):
+ # projections = hcat(
+ # *render_projections(
+ # gaussians,
+ # 256,
+ # extra_label="",
+ # )[0]
+ # )
+ # self.logger.log_image(
+ # "projection",
+ # [prep_image(add_border(projections))],
+ # step=self.global_step,
+ # )
+
+ # Draw cameras.
+ # cameras = hcat(*render_cameras(batch, 256))
+ # self.logger.log_image(
+ # "cameras", [prep_image(add_border(cameras))], step=self.global_step
+ # )
+
+ if self.encoder_visualizer is not None:
+ for k, image in self.encoder_visualizer.visualize(
+ batch["context"], self.global_step
+ ).items():
+ self.logger.log_image(k, [prep_image(image)], step=self.global_step)
+
+ # Run video validation step.
+ self.render_video_interpolation(batch)
+ self.render_video_wobble(batch)
+ if self.train_cfg.extended_visualization:
+ self.render_video_interpolation_exaggerated(batch)
+
+ @rank_zero_only
+ def render_video_wobble(self, batch: BatchedExample) -> None:
+ # Two views are needed to get the wobble radius.
+ _, v, _, _ = batch["context"]["extrinsics"].shape
+ if v != 2:
+ return
+
+ def trajectory_fn(t):
+ origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
+ origin_b = batch["context"]["extrinsics"][:, 1, :3, 3]
+ delta = (origin_a - origin_b).norm(dim=-1)
+ extrinsics = generate_wobble(
+ batch["context"]["extrinsics"][:, 0],
+ delta * 0.25,
+ t,
+ )
+ intrinsics = repeat(
+ batch["context"]["intrinsics"][:, 0],
+ "b i j -> b v i j",
+ v=t.shape[0],
+ )
+ return extrinsics, intrinsics
+
+ return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60)
+
+ @rank_zero_only
+ def render_video_interpolation(self, batch: BatchedExample) -> None:
+ _, v, _, _ = batch["context"]["extrinsics"].shape
+
+ def trajectory_fn(t):
+ extrinsics = interpolate_extrinsics(
+ batch["context"]["extrinsics"][0, 0],
+ (
+ batch["context"]["extrinsics"][0, 1]
+ if v == 2
+ else batch["target"]["extrinsics"][0, 0]
+ ),
+ t,
+ )
+ intrinsics = interpolate_intrinsics(
+ batch["context"]["intrinsics"][0, 0],
+ (
+ batch["context"]["intrinsics"][0, 1]
+ if v == 2
+ else batch["target"]["intrinsics"][0, 0]
+ ),
+ t,
+ )
+ return extrinsics[None], intrinsics[None]
+
+ return self.render_video_generic(batch, trajectory_fn, "rgb")
+
+ @rank_zero_only
+ def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None:
+ # Two views are needed to get the wobble radius.
+ _, v, _, _ = batch["context"]["extrinsics"].shape
+ if v != 2:
+ return
+
+ def trajectory_fn(t):
+ origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
+ origin_b = batch["context"]["extrinsics"][:, 1, :3, 3]
+ delta = (origin_a - origin_b).norm(dim=-1)
+ tf = generate_wobble_transformation(
+ delta * 0.5,
+ t,
+ 5,
+ scale_radius_with_t=False,
+ )
+ extrinsics = interpolate_extrinsics(
+ batch["context"]["extrinsics"][0, 0],
+ (
+ batch["context"]["extrinsics"][0, 1]
+ if v == 2
+ else batch["target"]["extrinsics"][0, 0]
+ ),
+ t * 5 - 2,
+ )
+ intrinsics = interpolate_intrinsics(
+ batch["context"]["intrinsics"][0, 0],
+ (
+ batch["context"]["intrinsics"][0, 1]
+ if v == 2
+ else batch["target"]["intrinsics"][0, 0]
+ ),
+ t * 5 - 2,
+ )
+ return extrinsics @ tf, intrinsics[None]
+
+ return self.render_video_generic(
+ batch,
+ trajectory_fn,
+ "interpolation_exagerrated",
+ num_frames=300,
+ smooth=False,
+ loop_reverse=False,
+ )
+
+ @rank_zero_only
+ def render_video_generic(
+ self,
+ batch: BatchedExample,
+ trajectory_fn: TrajectoryFn,
+ name: str,
+ num_frames: int = 30,
+ smooth: bool = True,
+ loop_reverse: bool = True,
+ ) -> None:
+ # Render probabilistic estimate of scene.
+ encoder_output = self.model.encoder((batch["context"]["image"]+1)/2, self.global_step)
+ gaussians, pred_pose_enc_list = encoder_output.gaussians, encoder_output.pred_pose_enc_list
+
+ t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device)
+ if smooth:
+ t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
+
+ extrinsics, intrinsics = trajectory_fn(t)
+
+ _, _, _, h, w = batch["context"]["image"].shape
+
+ # TODO: Interpolate near and far planes?
+ near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
+ far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
+ output = self.model.decoder.forward(
+ gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
+ )
+ images = [
+ vcat(rgb, depth)
+ for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
+ ]
+
+ video = torch.stack(images)
+ video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
+ if loop_reverse:
+ video = pack([video, video[::-1][1:-1]], "* c h w")[0]
+ visualizations = {
+ f"video/{name}": wandb.Video(video[None], fps=30, format="mp4")
+ }
+
+ # Since the PyTorch Lightning doesn't support video logging, log to wandb directly.
+ try:
+ wandb.log(visualizations)
+ except Exception:
+ assert isinstance(self.logger, LocalLogger)
+ for key, value in visualizations.items():
+ tensor = value._prepare_video(value.data)
+ clip = mpy.ImageSequenceClip(list(tensor), fps=30)
+ dir = LOG_PATH / key
+ dir.mkdir(exist_ok=True, parents=True)
+ clip.write_videofile(
+ str(dir / f"{self.global_step:0>6}.mp4"), logger=None
+ )
+
+ def print_preview_metrics(self, metrics: dict[str, float | Tensor], methods: list[str] | None = None, overlap_tag: str | None = None) -> None:
+ if getattr(self, "running_metrics", None) is None:
+ self.running_metrics = metrics
+ self.running_metric_steps = 1
+ else:
+ s = self.running_metric_steps
+ self.running_metrics = {
+ k: ((s * v) + metrics[k]) / (s + 1)
+ for k, v in self.running_metrics.items()
+ }
+ self.running_metric_steps += 1
+
+ if overlap_tag is not None:
+ if getattr(self, "running_metrics_sub", None) is None:
+ self.running_metrics_sub = {overlap_tag: metrics}
+ self.running_metric_steps_sub = {overlap_tag: 1}
+ elif overlap_tag not in self.running_metrics_sub:
+ self.running_metrics_sub[overlap_tag] = metrics
+ self.running_metric_steps_sub[overlap_tag] = 1
+ else:
+ s = self.running_metric_steps_sub[overlap_tag]
+ self.running_metrics_sub[overlap_tag] = {k: ((s * v) + metrics[k]) / (s + 1)
+ for k, v in self.running_metrics_sub[overlap_tag].items()}
+ self.running_metric_steps_sub[overlap_tag] += 1
+
+ metric_list = ["psnr", "lpips", "ssim"]
+
+ def print_metrics(runing_metric, methods=None):
+ table = []
+ if methods is None:
+ methods = ['ours']
+
+ for method in methods:
+ row = [
+ f"{runing_metric[f'{metric}_{method}']:.3f}"
+ for metric in metric_list
+ ]
+ table.append((method, *row))
+
+ headers = ["Method"] + metric_list
+ table = tabulate(table, headers)
+ print(table)
+
+ print("All Pairs:")
+ print_metrics(self.running_metrics, methods)
+ if overlap_tag is not None:
+ for k, v in self.running_metrics_sub.items():
+ print(f"Overlap: {k}")
+ print_metrics(v, methods)
+
+ def configure_optimizers(self):
+ new_params, new_param_names = [], []
+ pretrained_params, pretrained_param_names = [], []
+ for name, param in self.named_parameters():
+ if not param.requires_grad:
+ continue
+
+ if "gaussian_param_head" in name or "interm" in name:
+ new_params.append(param)
+ new_param_names.append(name)
+ else:
+ pretrained_params.append(param)
+ pretrained_param_names.append(name)
+
+ param_dicts = [
+ {
+ "params": new_params,
+ "lr": self.optimizer_cfg.lr,
+ },
+ {
+ "params": pretrained_params,
+ "lr": self.optimizer_cfg.lr * self.optimizer_cfg.backbone_lr_multiplier,
+ },
+ ]
+ optimizer = torch.optim.AdamW(param_dicts, lr=self.optimizer_cfg.lr, weight_decay=0.05, betas=(0.9, 0.95))
+ warm_up_steps = self.optimizer_cfg.warm_up_steps
+ warm_up = torch.optim.lr_scheduler.LinearLR(
+ optimizer,
+ 1 / warm_up_steps,
+ 1,
+ total_iters=warm_up_steps,
+ )
+
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=get_cfg()["trainer"]["max_steps"], eta_min=self.optimizer_cfg.lr * 0.1)
+ lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warm_up, lr_scheduler], milestones=[warm_up_steps])
+
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler,
+ "interval": "step",
+ "frequency": 1,
+ },
+ }
diff --git a/src/model/ply_export.py b/src/model/ply_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..da22bdafcbfdff9a140c5b8318d5a8a0c68d9ebd
--- /dev/null
+++ b/src/model/ply_export.py
@@ -0,0 +1,74 @@
+from pathlib import Path
+
+import numpy as np
+import torch
+from einops import einsum, rearrange
+from jaxtyping import Float
+from plyfile import PlyData, PlyElement
+from scipy.spatial.transform import Rotation as R
+from torch import Tensor
+
+
+def construct_list_of_attributes(num_rest: int) -> list[str]:
+ attributes = ["x", "y", "z", "nx", "ny", "nz"]
+ for i in range(3):
+ attributes.append(f"f_dc_{i}")
+ for i in range(num_rest):
+ attributes.append(f"f_rest_{i}")
+ attributes.append("opacity")
+ for i in range(3):
+ attributes.append(f"scale_{i}")
+ for i in range(4):
+ attributes.append(f"rot_{i}")
+ return attributes
+
+
+def export_ply(
+ means: Float[Tensor, "gaussian 3"],
+ scales: Float[Tensor, "gaussian 3"],
+ rotations: Float[Tensor, "gaussian 4"],
+ harmonics: Float[Tensor, "gaussian 3 d_sh"],
+ opacities: Float[Tensor, " gaussian"],
+ path: Path,
+ shift_and_scale: bool = False,
+ save_sh_dc_only: bool = True,
+):
+ if shift_and_scale:
+ # Shift the scene so that the median Gaussian is at the origin.
+ means = means - means.median(dim=0).values
+
+ # Rescale the scene so that most Gaussians are within range [-1, 1].
+ scale_factor = means.abs().quantile(0.95, dim=0).max()
+ means = means / scale_factor
+ scales = scales / scale_factor
+
+ # Apply the rotation to the Gaussian rotations.
+ rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
+ rotations = R.from_matrix(rotations).as_quat()
+ x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
+ rotations = np.stack((w, x, y, z), axis=-1)
+
+ # Since current model use SH_degree = 4,
+ # which require large memory to store, we can only save the DC band to save memory.
+ f_dc = harmonics[..., 0]
+ f_rest = harmonics[..., 1:].flatten(start_dim=1)
+
+ dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0 if save_sh_dc_only else f_rest.shape[1])]
+ elements = np.empty(means.shape[0], dtype=dtype_full)
+ attributes = [
+ means.detach().cpu().numpy(),
+ torch.zeros_like(means).detach().cpu().numpy(),
+ f_dc.detach().cpu().contiguous().numpy(),
+ f_rest.detach().cpu().contiguous().numpy(),
+ opacities[..., None].detach().cpu().numpy(),
+ scales.log().detach().cpu().numpy(),
+ rotations,
+ ]
+ if save_sh_dc_only:
+ # remove f_rest from attributes
+ attributes.pop(3)
+
+ attributes = np.concatenate(attributes, axis=1)
+ elements[:] = list(map(tuple, attributes))
+ path.parent.mkdir(exist_ok=True, parents=True)
+ PlyData([PlyElement.describe(elements, "vertex")]).write(path)
diff --git a/src/model/transformer/attention.py b/src/model/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f013081ebb2353ae0b3c1f5425cceec291f2721
--- /dev/null
+++ b/src/model/transformer/attention.py
@@ -0,0 +1,70 @@
+# MIT License
+
+# Copyright (c) 2022 Karl Stelzner
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# This file comes from https://github.com/stelzner/srt
+
+import torch
+from einops import rearrange
+from torch import nn
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, heads=8, dim_head=64, dropout=0.0, selfatt=True, kv_dim=None
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ if selfatt:
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+ else:
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x, z=None):
+ if z is None:
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ else:
+ q = self.to_q(x)
+ k, v = self.to_kv(z).chunk(2, dim=-1)
+ qkv = (q, k, v)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
diff --git a/src/model/transformer/feed_forward.py b/src/model/transformer/feed_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cec1dfc4ac2a389e8f8abdcd5885d46617f73e4
--- /dev/null
+++ b/src/model/transformer/feed_forward.py
@@ -0,0 +1,40 @@
+# MIT License
+
+# Copyright (c) 2022 Karl Stelzner
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# This file comes from https://github.com/stelzner/srt
+
+from torch import nn
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.0):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
diff --git a/src/model/transformer/pre_norm.py b/src/model/transformer/pre_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b1560add080382c86192eb989017bf903a1032f
--- /dev/null
+++ b/src/model/transformer/pre_norm.py
@@ -0,0 +1,35 @@
+# MIT License
+
+# Copyright (c) 2022 Karl Stelzner
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# This file comes from https://github.com/stelzner/srt
+
+from torch import nn
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
diff --git a/src/model/transformer/transformer.py b/src/model/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7226e9b7872f087d7e4c9461566b7e1273887b7c
--- /dev/null
+++ b/src/model/transformer/transformer.py
@@ -0,0 +1,71 @@
+# MIT License
+
+# Copyright (c) 2022 Karl Stelzner
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# This file comes from https://github.com/stelzner/srt
+
+from torch import nn
+
+from .attention import Attention
+from .feed_forward import FeedForward
+from .pre_norm import PreNorm
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads,
+ dim_head,
+ mlp_dim,
+ dropout=0.0,
+ selfatt=True,
+ kv_dim=None,
+ feed_forward_layer=FeedForward,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(
+ dim,
+ Attention(
+ dim,
+ heads=heads,
+ dim_head=dim_head,
+ dropout=dropout,
+ selfatt=selfatt,
+ kv_dim=kv_dim,
+ ),
+ ),
+ PreNorm(dim, feed_forward_layer(dim, mlp_dim, dropout=dropout)),
+ ]
+ )
+ )
+
+ def forward(self, x, z=None, **kwargs):
+ for attn, ff in self.layers:
+ x = attn(x, z=z) + x
+ x = ff(x, **kwargs) + x
+ return x
diff --git a/src/model/types.py b/src/model/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f6929af47b85a0fba1db65d706bf04d91aa50e
--- /dev/null
+++ b/src/model/types.py
@@ -0,0 +1,15 @@
+from dataclasses import dataclass
+
+from jaxtyping import Float
+from torch import Tensor
+
+
+@dataclass
+class Gaussians:
+ means: Float[Tensor, "batch gaussian dim"]
+ covariances: Float[Tensor, "batch gaussian dim dim"]
+ harmonics: Float[Tensor, "batch gaussian 3 d_sh"]
+ opacities: Float[Tensor, "batch gaussian"]
+ scales: Float[Tensor, "batch gaussian 3"]
+ rotations: Float[Tensor, "batch gaussian 4"]
+ # levels: Float[Tensor, "batch gaussian"]
diff --git a/src/post_opt/datasets/colmap.py b/src/post_opt/datasets/colmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..d93b9a08e8b523371f0433ac182158be18c82e74
--- /dev/null
+++ b/src/post_opt/datasets/colmap.py
@@ -0,0 +1,490 @@
+import json
+import os
+from typing import Any, Dict, List, Optional
+
+import cv2
+import imageio.v2 as imageio
+import numpy as np
+import torch
+from PIL import Image
+from pycolmap import SceneManager
+from tqdm import tqdm
+from typing_extensions import assert_never
+
+import sys
+sys.path.append("/cpfs01/user/jianglihan/projects/gsplat/examples/datasets")
+sys.path.append("/cpfs01/user/jianglihan/projects/gsplat/examples")
+sys.path.append("/cpfs01/user/jianglihan/projects/gsplat")
+
+from normalize import (
+ align_principal_axes,
+ similarity_from_cameras,
+ transform_cameras,
+ transform_points,
+)
+
+
+def _get_rel_paths(path_dir: str) -> List[str]:
+ """Recursively get relative paths of files in a directory."""
+ paths = []
+ for dp, dn, fn in os.walk(path_dir):
+ for f in fn:
+ paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
+ return paths
+
+
+def _resize_image_folder(image_dir: str, resized_dir: str, factor: int) -> str:
+ """Resize image folder."""
+ print(f"Downscaling images by {factor}x from {image_dir} to {resized_dir}.")
+ os.makedirs(resized_dir, exist_ok=True)
+
+ image_files = _get_rel_paths(image_dir)
+ for image_file in tqdm(image_files):
+ image_path = os.path.join(image_dir, image_file)
+ resized_path = os.path.join(
+ resized_dir, os.path.splitext(image_file)[0] + ".png"
+ )
+ if os.path.isfile(resized_path):
+ continue
+ image = imageio.imread(image_path)[..., :3]
+ resized_size = (
+ int(round(image.shape[1] / factor)),
+ int(round(image.shape[0] / factor)),
+ )
+ resized_image = np.array(
+ Image.fromarray(image).resize(resized_size, Image.BICUBIC)
+ )
+ imageio.imwrite(resized_path, resized_image)
+ return resized_dir
+
+
+class Parser:
+ """COLMAP parser."""
+
+ def __init__(
+ self,
+ data_dir: str,
+ factor: int = 1,
+ normalize: bool = False,
+ test_every: int = 8,
+ ):
+ self.data_dir = data_dir
+ self.factor = factor
+ self.normalize = normalize
+ self.test_every = test_every
+
+
+
+ colmap_dir = os.path.join(data_dir, "sparse/0/")
+ if not os.path.exists(colmap_dir):
+ colmap_dir = os.path.join(data_dir, "sparse")
+ assert os.path.exists(
+ colmap_dir
+ ), f"COLMAP directory {colmap_dir} does not exist."
+
+ manager = SceneManager(colmap_dir)
+ manager.load_cameras()
+ manager.load_images()
+ manager.load_points3D()
+
+ # Extract extrinsic matrices in world-to-camera format.
+ imdata = manager.images
+ w2c_mats = []
+ camera_ids = []
+ Ks_dict = dict()
+ params_dict = dict()
+ imsize_dict = dict() # width, height
+ mask_dict = dict()
+ bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
+ for k in imdata:
+ im = imdata[k]
+ rot = im.R()
+ trans = im.tvec.reshape(3, 1)
+ w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
+ w2c_mats.append(w2c)
+
+ # support different camera intrinsics
+ camera_id = im.camera_id
+ camera_ids.append(camera_id)
+
+ # camera intrinsics
+ cam = manager.cameras[camera_id]
+ fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+ K[:2, :] /= factor
+ Ks_dict[camera_id] = K
+
+ # Get distortion parameters.
+ type_ = cam.camera_type
+ if type_ == 0 or type_ == "SIMPLE_PINHOLE":
+ params = np.empty(0, dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 1 or type_ == "PINHOLE":
+ params = np.empty(0, dtype=np.float32)
+ camtype = "perspective"
+ if type_ == 2 or type_ == "SIMPLE_RADIAL":
+ params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 3 or type_ == "RADIAL":
+ params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 4 or type_ == "OPENCV":
+ params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 5 or type_ == "OPENCV_FISHEYE":
+ params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
+ camtype = "fisheye"
+ assert (
+ camtype == "perspective" or camtype == "fisheye"
+ ), f"Only perspective and fisheye cameras are supported, got {type_}"
+
+ params_dict[camera_id] = params
+ imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
+ mask_dict[camera_id] = None
+ print(
+ f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
+ )
+
+ if len(imdata) == 0:
+ raise ValueError("No images found in COLMAP.")
+ if not (type_ == 0 or type_ == 1):
+ print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
+
+ w2c_mats = np.stack(w2c_mats, axis=0)
+
+ # Convert extrinsics to camera-to-world.
+ camtoworlds = np.linalg.inv(w2c_mats)
+
+ # Image names from COLMAP. No need for permuting the poses according to
+ # image names anymore.
+ image_names = [imdata[k].name for k in imdata]
+
+ # Previous Nerf results were generated with images sorted by filename,
+ # ensure metrics are reported on the same test set.
+ inds = np.argsort(image_names)
+ image_names = [image_names[i] for i in inds]
+ camtoworlds = camtoworlds[inds]
+ camera_ids = [camera_ids[i] for i in inds]
+
+ # Load extended metadata. Used by Bilarf dataset.
+ self.extconf = {
+ "spiral_radius_scale": 1.0,
+ "no_factor_suffix": False,
+ }
+ extconf_file = os.path.join(data_dir, "ext_metadata.json")
+ if os.path.exists(extconf_file):
+ with open(extconf_file) as f:
+ self.extconf.update(json.load(f))
+
+ # Load bounds if possible (only used in forward facing scenes).
+ self.bounds = np.array([0.01, 1.0])
+ posefile = os.path.join(data_dir, "poses_bounds.npy")
+ if os.path.exists(posefile):
+ self.bounds = np.load(posefile)[:, -2:]
+
+ # Load images.
+ if factor > 1 and not self.extconf["no_factor_suffix"]:
+ image_dir_suffix = f"_{factor}"
+ else:
+ image_dir_suffix = ""
+ colmap_image_dir = os.path.join(data_dir, "images")
+ image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
+ for d in [image_dir, colmap_image_dir]:
+ if not os.path.exists(d):
+ raise ValueError(f"Image folder {d} does not exist.")
+
+ # Downsampled images may have different names vs images used for COLMAP,
+ # so we need to map between the two sorted lists of files.
+ colmap_files = sorted(_get_rel_paths(colmap_image_dir))
+ image_files = sorted(_get_rel_paths(image_dir))
+ if factor > 1 and os.path.splitext(image_files[0])[1].lower() == ".jpg":
+ image_dir = _resize_image_folder(
+ colmap_image_dir, image_dir + "_png", factor=factor
+ )
+ image_files = sorted(_get_rel_paths(image_dir))
+ colmap_to_image = dict(zip(colmap_files, image_files))
+ image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
+
+ # 3D points and {image_name -> [point_idx]}
+ points = manager.points3D.astype(np.float32)
+ points_err = manager.point3D_errors.astype(np.float32)
+ points_rgb = manager.point3D_colors.astype(np.uint8)
+ point_indices = dict()
+
+ image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()}
+ for point_id, data in manager.point3D_id_to_images.items():
+ for image_id, _ in data:
+ image_name = image_id_to_name[image_id]
+ point_idx = manager.point3D_id_to_point3D_idx[point_id]
+ point_indices.setdefault(image_name, []).append(point_idx)
+ point_indices = {
+ k: np.array(v).astype(np.int32) for k, v in point_indices.items()
+ }
+
+ # Normalize the world space.
+ if normalize:
+ T1 = similarity_from_cameras(camtoworlds)
+ camtoworlds = transform_cameras(T1, camtoworlds)
+ points = transform_points(T1, points)
+
+ T2 = align_principal_axes(points)
+ camtoworlds = transform_cameras(T2, camtoworlds)
+ points = transform_points(T2, points)
+
+ transform = T2 @ T1
+
+ # Fix for up side down. We assume more points towards
+ # the bottom of the scene which is true when ground floor is
+ # present in the images.
+ if np.median(points[:, 2]) > np.mean(points[:, 2]):
+ # rotate 180 degrees around x axis such that z is flipped
+ T3 = np.array(
+ [
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, -1.0, 0.0, 0.0],
+ [0.0, 0.0, -1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ]
+ )
+ camtoworlds = transform_cameras(T3, camtoworlds)
+ points = transform_points(T3, points)
+ transform = T3 @ transform
+ else:
+ transform = np.eye(4)
+
+ self.image_names = image_names # List[str], (num_images,)
+ self.image_paths = image_paths # List[str], (num_images,)
+ self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
+ self.camera_ids = camera_ids # List[int], (num_images,)
+ self.Ks_dict = Ks_dict # Dict of camera_id -> K
+ self.params_dict = params_dict # Dict of camera_id -> params
+ self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
+ self.mask_dict = mask_dict # Dict of camera_id -> mask
+ self.points = points # np.ndarray, (num_points, 3)
+ self.points_err = points_err # np.ndarray, (num_points,)
+ self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
+ self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
+ self.transform = transform # np.ndarray, (4, 4)
+
+ # load one image to check the size. In the case of tanksandtemples dataset, the
+ # intrinsics stored in COLMAP corresponds to 2x upsampled images.
+ actual_image = imageio.imread(self.image_paths[0])[..., :3]
+ actual_height, actual_width = actual_image.shape[:2]
+ colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]]
+ s_height, s_width = actual_height / colmap_height, actual_width / colmap_width
+ for camera_id, K in self.Ks_dict.items():
+ K[0, :] *= s_width
+ K[1, :] *= s_height
+ self.Ks_dict[camera_id] = K
+ width, height = self.imsize_dict[camera_id]
+ self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height))
+
+ # undistortion
+ self.mapx_dict = dict()
+ self.mapy_dict = dict()
+ self.roi_undist_dict = dict()
+ for camera_id in self.params_dict.keys():
+ params = self.params_dict[camera_id]
+ if len(params) == 0:
+ continue # no distortion
+ assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
+ assert (
+ camera_id in self.params_dict
+ ), f"Missing params for camera {camera_id}"
+ K = self.Ks_dict[camera_id]
+ width, height = self.imsize_dict[camera_id]
+
+ if camtype == "perspective":
+ K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
+ K, params, (width, height), 0
+ )
+ mapx, mapy = cv2.initUndistortRectifyMap(
+ K, params, None, K_undist, (width, height), cv2.CV_32FC1
+ )
+ mask = None
+ elif camtype == "fisheye":
+ fx = K[0, 0]
+ fy = K[1, 1]
+ cx = K[0, 2]
+ cy = K[1, 2]
+ grid_x, grid_y = np.meshgrid(
+ np.arange(width, dtype=np.float32),
+ np.arange(height, dtype=np.float32),
+ indexing="xy",
+ )
+ x1 = (grid_x - cx) / fx
+ y1 = (grid_y - cy) / fy
+ theta = np.sqrt(x1**2 + y1**2)
+ r = (
+ 1.0
+ + params[0] * theta**2
+ + params[1] * theta**4
+ + params[2] * theta**6
+ + params[3] * theta**8
+ )
+ mapx = (fx * x1 * r + width // 2).astype(np.float32)
+ mapy = (fy * y1 * r + height // 2).astype(np.float32)
+
+ # Use mask to define ROI
+ mask = np.logical_and(
+ np.logical_and(mapx > 0, mapy > 0),
+ np.logical_and(mapx < width - 1, mapy < height - 1),
+ )
+ y_indices, x_indices = np.nonzero(mask)
+ y_min, y_max = y_indices.min(), y_indices.max() + 1
+ x_min, x_max = x_indices.min(), x_indices.max() + 1
+ mask = mask[y_min:y_max, x_min:x_max]
+ K_undist = K.copy()
+ K_undist[0, 2] -= x_min
+ K_undist[1, 2] -= y_min
+ roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min]
+ else:
+ assert_never(camtype)
+
+ self.mapx_dict[camera_id] = mapx
+ self.mapy_dict[camera_id] = mapy
+ self.Ks_dict[camera_id] = K_undist
+ self.roi_undist_dict[camera_id] = roi_undist
+ self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3])
+ self.mask_dict[camera_id] = mask
+
+ # size of the scene measured by cameras
+ camera_locations = camtoworlds[:, :3, 3]
+ scene_center = np.mean(camera_locations, axis=0)
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
+ self.scene_scale = np.max(dists)
+
+
+class Dataset:
+ """A simple dataset class."""
+
+ def __init__(
+ self,
+ # parser: Parser,
+ images: np.ndarray,
+ camtoworlds: np.ndarray,
+ Ks: np.ndarray,
+ split: str = "train",
+ patch_size: Optional[int] = None,
+ load_depths: bool = False,
+ ):
+ # self.parser = parser
+ self.split = split
+ self.patch_size = patch_size
+ self.load_depths = load_depths
+ self.images = images
+ self.camtoworlds = camtoworlds
+ self.Ks = Ks
+ H, W = self.images.shape[-2:]
+ self.Ks[:, 0, :] *= W
+ self.Ks[:, 1, :] *= H
+ self.indices = np.arange(len(images))
+ # indices = np.arange(len(self.parser.image_names))
+ # if split == "train":
+ # self.indices = indices[indices % self.parser.test_every != 0]
+ # else:
+ # self.indices = indices[indices % self.parser.test_every == 0]
+
+ # if split == "train":
+ # self.images = np.load(os.path.join(self.parser.true_data_dir, "context_image.npy"))
+ # self.camtoworlds = np.load(os.path.join(self.parser.true_data_dir, "context_extrinsic.npy"))
+ # self.Ks = np.load(os.path.join(self.parser.true_data_dir, "context_intrinsic.npy"))
+ # H, W = self.images.shape[-2:]
+ # self.Ks[:, 0, :] *= W
+ # self.Ks[:, 1, :] *= H
+ # self.indices = np.arange(len(self.images))
+ # else:
+ # self.images = np.load(os.path.join(self.parser.true_data_dir, "target_image.npy"))
+ # self.camtoworlds = np.load(os.path.join(self.parser.true_data_dir, "target_extrinsic.npy"))
+ # self.Ks = np.load(os.path.join(self.parser.true_data_dir, "target_intrinsic.npy"))
+ # H, W = self.images.shape[-2:]
+ # self.Ks[:, 0, :] *= W
+ # self.Ks[:, 1, :] *= H
+ # self.indices = np.arange(len(self.images))
+
+ def __len__(self):
+ return len(self.indices)
+
+ def __getitem__(self, item: int) -> Dict[str, Any]:
+ index = self.indices[item]
+ image = (self.images[index]*255.0).transpose(1, 2, 0).astype(np.uint8) # (H, W, 3)
+ K = self.Ks[index].copy() # undistorted K
+ params = None
+ camtoworlds = self.camtoworlds[index]
+ mask = None
+
+ if self.patch_size is not None:
+ # Random crop.
+ h, w = image.shape[:2]
+ x = np.random.randint(0, max(w - self.patch_size, 1))
+ y = np.random.randint(0, max(h - self.patch_size, 1))
+ image = image[y : y + self.patch_size, x : x + self.patch_size]
+ K[0, 2] -= x
+ K[1, 2] -= y
+
+ data = {
+ "K": torch.from_numpy(K).float(),
+ "camtoworld": torch.from_numpy(camtoworlds).float(),
+ "image": torch.from_numpy(image).float(),
+ "image_id": item, # the index of the image in the dataset
+ }
+ if mask is not None:
+ data["mask"] = torch.from_numpy(mask).bool()
+
+ if self.load_depths and False:
+ # projected points to image plane to get depths
+ worldtocams = np.linalg.inv(camtoworlds)
+ image_name = self.parser.image_names[index]
+ point_indices = self.parser.point_indices[image_name]
+ points_world = self.parser.points[point_indices]
+ points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
+ points_proj = (K @ points_cam.T).T
+ points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
+ depths = points_cam[:, 2] # (M,)
+ # filter out points outside the image
+ selector = (
+ (points[:, 0] >= 0)
+ & (points[:, 0] < image.shape[1])
+ & (points[:, 1] >= 0)
+ & (points[:, 1] < image.shape[0])
+ & (depths > 0)
+ )
+ points = points[selector]
+ depths = depths[selector]
+ data["points"] = torch.from_numpy(points).float()
+ data["depths"] = torch.from_numpy(depths).float()
+
+ return data
+
+
+if __name__ == "__main__":
+ import argparse
+
+ import imageio.v2 as imageio
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_dir", type=str, default="data/mipnerf360/garden")
+ parser.add_argument("--true_data_dir", type=str, default="/cpfs01/user/jianglihan/projects/anysplat_baselines/demo_data/infer_output/3F_100view/room5")
+ parser.add_argument("--factor", type=int, default=4)
+ args = parser.parse_args()
+
+ # Parse COLMAP data.
+ parser = Parser(
+ data_dir=args.data_dir,
+ true_data_dir=args.true_data_dir,
+ factor=args.factor,
+ normalize=True,
+ test_every=8
+ )
+ dataset = Dataset(parser, split="train", load_depths=True)
+ print(f"Dataset: {len(dataset)} images.")
+
+ writer = imageio.get_writer("results/points.mp4", fps=30)
+ for data in tqdm(dataset, desc="Plotting points"):
+ image = data["image"].numpy().astype(np.uint8)
+ points = data["points"].numpy()
+ depths = data["depths"].numpy()
+ for x, y in points:
+ cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1)
+ writer.append_data(image)
+ writer.close()
diff --git a/src/post_opt/datasets/normalize.py b/src/post_opt/datasets/normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..681623b311625065744813f565666e9a8154a33c
--- /dev/null
+++ b/src/post_opt/datasets/normalize.py
@@ -0,0 +1,143 @@
+import numpy as np
+
+
+def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
+ """
+ reference: nerf-factory
+ Get a similarity transform to normalize dataset
+ from c2w (OpenCV convention) cameras
+ :param c2w: (N, 4)
+ :return T (4,4) , scale (float)
+ """
+ t = c2w[:, :3, 3]
+ R = c2w[:, :3, :3]
+
+ # (1) Rotate the world so that z+ is the up axis
+ # we estimate the up axis by averaging the camera up axes
+ ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
+ world_up = np.mean(ups, axis=0)
+ world_up /= np.linalg.norm(world_up)
+
+ up_camspace = np.array([0.0, -1.0, 0.0])
+ c = (up_camspace * world_up).sum()
+ cross = np.cross(world_up, up_camspace)
+ skew = np.array(
+ [
+ [0.0, -cross[2], cross[1]],
+ [cross[2], 0.0, -cross[0]],
+ [-cross[1], cross[0], 0.0],
+ ]
+ )
+ if c > -1:
+ R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
+ else:
+ # In the unlikely case the original data has y+ up axis,
+ # rotate 180-deg about x axis
+ R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+
+ # R_align = np.eye(3) # DEBUG
+ R = R_align @ R
+ fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
+ t = (R_align @ t[..., None])[..., 0]
+
+ # (2) Recenter the scene.
+ if center_method == "focus":
+ # find the closest point to the origin for each camera's center ray
+ nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
+ translate = -np.median(nearest, axis=0)
+ elif center_method == "poses":
+ # use center of the camera positions
+ translate = -np.median(t, axis=0)
+ else:
+ raise ValueError(f"Unknown center_method {center_method}")
+
+ transform = np.eye(4)
+ transform[:3, 3] = translate
+ transform[:3, :3] = R_align
+
+ # (3) Rescale the scene using camera distances
+ scale_fn = np.max if strict_scaling else np.median
+ scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1))
+ transform[:3, :] *= scale
+
+ return transform
+
+
+def align_principal_axes(point_cloud):
+ # Compute centroid
+ centroid = np.median(point_cloud, axis=0)
+
+ # Translate point cloud to centroid
+ translated_point_cloud = point_cloud - centroid
+
+ # Compute covariance matrix
+ covariance_matrix = np.cov(translated_point_cloud, rowvar=False)
+
+ # Compute eigenvectors and eigenvalues
+ eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
+
+ # Sort eigenvectors by eigenvalues (descending order) so that the z-axis
+ # is the principal axis with the smallest eigenvalue.
+ sort_indices = eigenvalues.argsort()[::-1]
+ eigenvectors = eigenvectors[:, sort_indices]
+
+ # Check orientation of eigenvectors. If the determinant of the eigenvectors is
+ # negative, then we need to flip the sign of one of the eigenvectors.
+ if np.linalg.det(eigenvectors) < 0:
+ eigenvectors[:, 0] *= -1
+
+ # Create rotation matrix
+ rotation_matrix = eigenvectors.T
+
+ # Create SE(3) matrix (4x4 transformation matrix)
+ transform = np.eye(4)
+ transform[:3, :3] = rotation_matrix
+ transform[:3, 3] = -rotation_matrix @ centroid
+
+ return transform
+
+
+def transform_points(matrix, points):
+ """Transform points using an SE(3) matrix.
+
+ Args:
+ matrix: 4x4 SE(3) matrix
+ points: Nx3 array of points
+
+ Returns:
+ Nx3 array of transformed points
+ """
+ assert matrix.shape == (4, 4)
+ assert len(points.shape) == 2 and points.shape[1] == 3
+ return points @ matrix[:3, :3].T + matrix[:3, 3]
+
+
+def transform_cameras(matrix, camtoworlds):
+ """Transform cameras using an SE(3) matrix.
+
+ Args:
+ matrix: 4x4 SE(3) matrix
+ camtoworlds: Nx4x4 array of camera-to-world matrices
+
+ Returns:
+ Nx4x4 array of transformed camera-to-world matrices
+ """
+ assert matrix.shape == (4, 4)
+ assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
+ camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
+ scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1)
+ camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None]
+ return camtoworlds
+
+
+def normalize(camtoworlds, points=None):
+ T1 = similarity_from_cameras(camtoworlds)
+ camtoworlds = transform_cameras(T1, camtoworlds)
+ if points is not None:
+ points = transform_points(T1, points)
+ T2 = align_principal_axes(points)
+ camtoworlds = transform_cameras(T2, camtoworlds)
+ points = transform_points(T2, points)
+ return camtoworlds, points, T2 @ T1
+ else:
+ return camtoworlds, T1
diff --git a/src/post_opt/datasets/traj.py b/src/post_opt/datasets/traj.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8a2d69dc969e78d5ea019465fe39c10618cf5c
--- /dev/null
+++ b/src/post_opt/datasets/traj.py
@@ -0,0 +1,254 @@
+"""
+Code borrowed from
+
+https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/camera_utils.py
+"""
+
+import numpy as np
+import scipy
+
+
+def normalize(x: np.ndarray) -> np.ndarray:
+ """Normalization helper function."""
+ return x / np.linalg.norm(x)
+
+
+def viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray:
+ """Construct lookat view matrix."""
+ vec2 = normalize(lookdir)
+ vec0 = normalize(np.cross(up, vec2))
+ vec1 = normalize(np.cross(vec2, vec0))
+ m = np.stack([vec0, vec1, vec2, position], axis=1)
+ return m
+
+
+def focus_point_fn(poses: np.ndarray) -> np.ndarray:
+ """Calculate nearest point to all focal axes in poses."""
+ directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
+ m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
+ mt_m = np.transpose(m, [0, 2, 1]) @ m
+ focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
+ return focus_pt
+
+
+def average_pose(poses: np.ndarray) -> np.ndarray:
+ """New pose using average position, z-axis, and up vector of input poses."""
+ position = poses[:, :3, 3].mean(0)
+ z_axis = poses[:, :3, 2].mean(0)
+ up = poses[:, :3, 1].mean(0)
+ cam2world = viewmatrix(z_axis, up, position)
+ return cam2world
+
+
+def generate_spiral_path(
+ poses,
+ bounds,
+ n_frames=120,
+ n_rots=2,
+ zrate=0.5,
+ spiral_scale_f=1.0,
+ spiral_scale_r=1.0,
+ focus_distance=0.75,
+):
+ """Calculates a forward facing spiral path for rendering."""
+ # Find a reasonable 'focus depth' for this dataset as a weighted average
+ # of conservative near and far bounds in disparity space.
+ near_bound = bounds.min()
+ far_bound = bounds.max()
+ # All cameras will point towards the world space point (0, 0, -focal).
+ focal = 1 / (((1 - focus_distance) / near_bound + focus_distance / far_bound))
+ focal = focal * spiral_scale_f
+
+ # Get radii for spiral path using 90th percentile of camera positions.
+ positions = poses[:, :3, 3]
+ radii = np.percentile(np.abs(positions), 90, 0)
+ radii = radii * spiral_scale_r
+ radii = np.concatenate([radii, [1.0]])
+
+ # Generate poses for spiral path.
+ render_poses = []
+ cam2world = average_pose(poses)
+ up = poses[:, :3, 1].mean(0)
+ for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=False):
+ t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
+ position = cam2world @ t
+ lookat = cam2world @ [0, 0, -focal, 1.0]
+ z_axis = position - lookat
+ render_poses.append(viewmatrix(z_axis, up, position))
+ render_poses = np.stack(render_poses, axis=0)
+ return render_poses
+
+
+def generate_ellipse_path_z(
+ poses: np.ndarray,
+ n_frames: int = 120,
+ # const_speed: bool = True,
+ variation: float = 0.0,
+ phase: float = 0.0,
+ height: float = 0.0,
+) -> np.ndarray:
+ """Generate an elliptical render path based on the given poses."""
+ # Calculate the focal point for the path (cameras point toward this).
+ center = focus_point_fn(poses)
+ # Path height sits at z=height (in middle of zero-mean capture pattern).
+ offset = np.array([center[0], center[1], height])
+
+ # Calculate scaling for ellipse axes based on input camera positions.
+ sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
+ # Use ellipse that is symmetric about the focal point in xy.
+ low = -sc + offset
+ high = sc + offset
+ # Optional height variation need not be symmetric
+ z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
+ z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
+
+ def get_positions(theta):
+ # Interpolate between bounds with trig functions to get ellipse in x-y.
+ # Optionally also interpolate in z to change camera height along path.
+ return np.stack(
+ [
+ low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
+ low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5),
+ variation
+ * (
+ z_low[2]
+ + (z_high - z_low)[2]
+ * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
+ )
+ + height,
+ ],
+ -1,
+ )
+
+ theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
+ positions = get_positions(theta)
+
+ # if const_speed:
+ # # Resample theta angles so that the velocity is closer to constant.
+ # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
+ # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
+ # positions = get_positions(theta)
+
+ # Throw away duplicated last position.
+ positions = positions[:-1]
+
+ # Set path's up vector to axis closest to average of input pose up vectors.
+ avg_up = poses[:, :3, 1].mean(0)
+ avg_up = avg_up / np.linalg.norm(avg_up)
+ ind_up = np.argmax(np.abs(avg_up))
+ up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
+
+ return np.stack([viewmatrix(center - p, up, p) for p in positions])
+
+
+def generate_ellipse_path_y(
+ poses: np.ndarray,
+ n_frames: int = 120,
+ # const_speed: bool = True,
+ variation: float = 0.0,
+ phase: float = 0.0,
+ height: float = 0.0,
+) -> np.ndarray:
+ """Generate an elliptical render path based on the given poses."""
+ # Calculate the focal point for the path (cameras point toward this).
+ center = focus_point_fn(poses)
+ # Path height sits at y=height (in middle of zero-mean capture pattern).
+ offset = np.array([center[0], height, center[2]])
+
+ # Calculate scaling for ellipse axes based on input camera positions.
+ sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
+ # Use ellipse that is symmetric about the focal point in xy.
+ low = -sc + offset
+ high = sc + offset
+ # Optional height variation need not be symmetric
+ y_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
+ y_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
+
+ def get_positions(theta):
+ # Interpolate between bounds with trig functions to get ellipse in x-z.
+ # Optionally also interpolate in y to change camera height along path.
+ return np.stack(
+ [
+ low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
+ variation
+ * (
+ y_low[1]
+ + (y_high - y_low)[1]
+ * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
+ )
+ + height,
+ low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5),
+ ],
+ -1,
+ )
+
+ theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
+ positions = get_positions(theta)
+
+ # if const_speed:
+ # # Resample theta angles so that the velocity is closer to constant.
+ # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
+ # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
+ # positions = get_positions(theta)
+
+ # Throw away duplicated last position.
+ positions = positions[:-1]
+
+ # Set path's up vector to axis closest to average of input pose up vectors.
+ avg_up = poses[:, :3, 1].mean(0)
+ avg_up = avg_up / np.linalg.norm(avg_up)
+ ind_up = np.argmax(np.abs(avg_up))
+ up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
+
+ return np.stack([viewmatrix(p - center, up, p) for p in positions])
+
+
+def generate_interpolated_path(
+ poses: np.ndarray,
+ n_interp: int,
+ spline_degree: int = 5,
+ smoothness: float = 0.03,
+ rot_weight: float = 0.1,
+):
+ """Creates a smooth spline path between input keyframe camera poses.
+
+ Spline is calculated with poses in format (position, lookat-point, up-point).
+
+ Args:
+ poses: (n, 3, 4) array of input pose keyframes.
+ n_interp: returned path will have n_interp * (n - 1) total poses.
+ spline_degree: polynomial degree of B-spline.
+ smoothness: parameter for spline smoothing, 0 forces exact interpolation.
+ rot_weight: relative weighting of rotation/translation in spline solve.
+
+ Returns:
+ Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
+ """
+
+ def poses_to_points(poses, dist):
+ """Converts from pose matrices to (position, lookat, up) format."""
+ pos = poses[:, :3, -1]
+ lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
+ up = poses[:, :3, -1] + dist * poses[:, :3, 1]
+ return np.stack([pos, lookat, up], 1)
+
+ def points_to_poses(points):
+ """Converts from (position, lookat, up) format to pose matrices."""
+ return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
+
+ def interp(points, n, k, s):
+ """Runs multidimensional B-spline interpolation on the input points."""
+ sh = points.shape
+ pts = np.reshape(points, (sh[0], -1))
+ k = min(k, sh[0] - 1)
+ tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
+ u = np.linspace(0, 1, n, endpoint=False)
+ new_points = np.array(scipy.interpolate.splev(u, tck))
+ new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
+ return new_points
+
+ points = poses_to_points(poses, dist=rot_weight)
+ new_points = interp(
+ points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
+ )
+ return points_to_poses(new_points)
diff --git a/src/post_opt/exporter.py b/src/post_opt/exporter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e44f96c299a9f5391a681fa41a8a697891fdf4
--- /dev/null
+++ b/src/post_opt/exporter.py
@@ -0,0 +1,553 @@
+import math
+import struct
+from io import BytesIO
+from typing import Literal, Optional
+
+import numpy as np
+import torch
+
+
+def sh2rgb(sh: torch.Tensor) -> torch.Tensor:
+ """Convert Sphere Harmonics to RGB
+
+ Args:
+ sh (torch.Tensor): SH tensor
+
+ Returns:
+ torch.Tensor: RGB tensor
+ """
+ C0 = 0.28209479177387814
+ return sh * C0 + 0.5
+
+
+def part1by2_vec(x: torch.Tensor) -> torch.Tensor:
+ """Interleave bits of x with 0s
+
+ Args:
+ x (torch.Tensor): Input tensor. Shape (N,)
+
+ Returns:
+ torch.Tensor: Output tensor. Shape (N,)
+ """
+
+ x = x & 0x000003FF
+ x = (x ^ (x << 16)) & 0xFF0000FF
+ x = (x ^ (x << 8)) & 0x0300F00F
+ x = (x ^ (x << 4)) & 0x030C30C3
+ x = (x ^ (x << 2)) & 0x09249249
+ return x
+
+
+def encode_morton3_vec(
+ x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
+) -> torch.Tensor:
+ """Compute Morton codes for 3D coordinates
+
+ Args:
+ x (torch.Tensor): X coordinates. Shape (N,)
+ y (torch.Tensor): Y coordinates. Shape (N,)
+ z (torch.Tensor): Z coordinates. Shape (N,)
+ Returns:
+ torch.Tensor: Morton codes. Shape (N,)
+ """
+ return (part1by2_vec(z) << 2) + (part1by2_vec(y) << 1) + part1by2_vec(x)
+
+
+def sort_centers(centers: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+ """Sort centers based on Morton codes
+
+ Args:
+ centers (torch.Tensor): Centers. Shape (N, 3)
+ indices (torch.Tensor): Indices. Shape (N,)
+ Returns:
+ torch.Tensor: Sorted indices. Shape (N,)
+ """
+ # Compute min and max values in a single operation
+ min_vals, _ = torch.min(centers, dim=0)
+ max_vals, _ = torch.max(centers, dim=0)
+
+ # Compute the scaling factors
+ lengths = max_vals - min_vals
+ lengths[lengths == 0] = 1 # Prevent division by zero
+
+ # Normalize and scale to 10-bit integer range (0-1024)
+ scaled_centers = ((centers - min_vals) / lengths * 1024).floor().to(torch.int32)
+
+ # Extract x, y, z coordinates
+ x, y, z = scaled_centers[:, 0], scaled_centers[:, 1], scaled_centers[:, 2]
+
+ # Compute Morton codes using vectorized operations
+ morton = encode_morton3_vec(x, y, z)
+
+ # Sort indices based on Morton codes
+ sorted_indices = indices[torch.argsort(morton).to(indices.device)]
+
+ return sorted_indices
+
+
+def pack_unorm(value: torch.Tensor, bits: int) -> torch.Tensor:
+ """Pack a floating point value into an unsigned integer with a given number of bits.
+
+ Args:
+ value (torch.Tensor): Floating point value to pack. Shape (N,)
+ bits (int): Number of bits to pack into.
+
+ Returns:
+ torch.Tensor: Packed value. Shape (N,)
+ """
+
+ t = (1 << bits) - 1
+ packed = torch.clamp((value * t + 0.5).floor(), min=0, max=t)
+ # Convert to integer type
+ return packed.to(torch.int64)
+
+
+def pack_111011(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
+ """Pack three floating point values into a 32-bit integer with 11, 10, and 11 bits.
+
+ Args:
+ x (torch.Tensor): X component. Shape (N,)
+ y (torch.Tensor): Y component. Shape (N,)
+ z (torch.Tensor): Z component. Shape (N,)
+ Returns:
+ torch.Tensor: Packed values. Shape (N,)
+ """
+ # Pack each component using pack_unorm
+ packed_x = pack_unorm(x, 11) << 21
+ packed_y = pack_unorm(y, 10) << 11
+ packed_z = pack_unorm(z, 11)
+
+ # Combine the packed values using bitwise OR
+ return packed_x | packed_y | packed_z
+
+
+def pack_8888(
+ x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, w: torch.Tensor
+) -> torch.Tensor:
+ """Pack four floating point values into a 32-bit integer with 8 bits each.
+
+ Args:
+ x (torch.Tensor): X component. Shape (N,)
+ y (torch.Tensor): Y component. Shape (N,)
+ z (torch.Tensor): Z component. Shape (N,)
+ w (torch.Tensor): W component. Shape (N,)
+ Returns:
+ torch.Tensor: Packed values. Shape (N,)
+ """
+ # Pack each component using pack_unorm
+ packed_x = pack_unorm(x, 8) << 24
+ packed_y = pack_unorm(y, 8) << 16
+ packed_z = pack_unorm(z, 8) << 8
+ packed_w = pack_unorm(w, 8)
+
+ # Combine the packed values using bitwise OR
+ return packed_x | packed_y | packed_z | packed_w
+
+
+def pack_rotation(q: torch.Tensor) -> torch.Tensor:
+ """Pack a quaternion into a 32-bit integer.
+
+ Args:
+ q (torch.Tensor): Quaternions. Shape (N, 4)
+
+ Returns:
+ torch.Tensor: Packed values. Shape (N,)
+ """
+
+ # Normalize each quaternion
+ norms = torch.linalg.norm(q, dim=-1, keepdim=True)
+ q = q / norms
+
+ # Find the largest component index for each quaternion
+ largest_components = torch.argmax(torch.abs(q), dim=-1)
+
+ # Flip quaternions where the largest component is negative
+ batch_indices = torch.arange(q.size(0), device=q.device)
+ largest_values = q[batch_indices, largest_components]
+ flip_mask = largest_values < 0
+ q[flip_mask] *= -1
+
+ # Precomputed indices for the components to pack (excluding largest)
+ precomputed_indices = torch.tensor(
+ [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]], dtype=torch.long, device=q.device
+ )
+
+ # Gather components to pack for each quaternion
+ pack_indices = precomputed_indices[largest_components]
+ components_to_pack = q[batch_indices[:, None], pack_indices]
+
+ # Scale and pack each component into 10-bit integers
+ norm = math.sqrt(2) * 0.5
+ scaled = components_to_pack * norm + 0.5
+ packed = pack_unorm(scaled, 10) # Assuming pack_unorm is vectorized
+
+ # Combine into the final 32-bit integer
+ largest_packed = largest_components.to(torch.int64) << 30
+ c0_packed = packed[:, 0] << 20
+ c1_packed = packed[:, 1] << 10
+ c2_packed = packed[:, 2]
+
+ result = largest_packed | c0_packed | c1_packed | c2_packed
+ return result
+
+
+def splat2ply_bytes_compressed(
+ means: torch.Tensor,
+ scales: torch.Tensor,
+ quats: torch.Tensor,
+ opacities: torch.Tensor,
+ sh0: torch.Tensor,
+ shN: torch.Tensor,
+ chunk_max_size: int = 256,
+ opacity_threshold: float = 1 / 255,
+) -> bytes:
+ """Return the binary compressed Ply file. Used by Supersplat viewer.
+
+ Args:
+ means (torch.Tensor): Splat means. Shape (N, 3)
+ scales (torch.Tensor): Splat scales. Shape (N, 3)
+ quats (torch.Tensor): Splat quaternions. Shape (N, 4)
+ opacities (torch.Tensor): Splat opacities. Shape (N,)
+ sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
+ shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)
+ chunk_max_size (int): Maximum number of splats per chunk. Default: 256
+ opacity_threshold (float): Opacity threshold. Default: 1 / 255
+
+ Returns:
+ bytes: Binary compressed Ply file representing the model.
+ """
+
+ # Filter the splats with too low opacity
+ mask = torch.sigmoid(opacities) > opacity_threshold
+ means = means[mask]
+ scales = scales[mask]
+ sh0_colors = sh2rgb(sh0)
+ sh0_colors = sh0_colors[mask]
+ shN = shN[mask]
+ quats = quats[mask]
+ opacities = opacities[mask]
+
+ num_splats = means.shape[0]
+ n_chunks = num_splats // chunk_max_size + (num_splats % chunk_max_size != 0)
+ indices = torch.arange(num_splats)
+ indices = sort_centers(means, indices)
+
+ float_properties = [
+ "min_x",
+ "min_y",
+ "min_z",
+ "max_x",
+ "max_y",
+ "max_z",
+ "min_scale_x",
+ "min_scale_y",
+ "min_scale_z",
+ "max_scale_x",
+ "max_scale_y",
+ "max_scale_z",
+ "min_r",
+ "min_g",
+ "min_b",
+ "max_r",
+ "max_g",
+ "max_b",
+ ]
+ uint_properties = [
+ "packed_position",
+ "packed_rotation",
+ "packed_scale",
+ "packed_color",
+ ]
+ buffer = BytesIO()
+
+ # Write PLY header
+ buffer.write(b"ply\n")
+ buffer.write(b"format binary_little_endian 1.0\n")
+ buffer.write(f"element chunk {n_chunks}\n".encode())
+ for prop in float_properties:
+ buffer.write(f"property float {prop}\n".encode())
+ buffer.write(f"element vertex {num_splats}\n".encode())
+ for prop in uint_properties:
+ buffer.write(f"property uint {prop}\n".encode())
+ buffer.write(f"element sh {num_splats}\n".encode())
+ for j in range(shN.shape[1]):
+ buffer.write(f"property uchar f_rest_{j}\n".encode())
+ buffer.write(b"end_header\n")
+
+ chunk_data = []
+ splat_data = []
+ sh_data = []
+ for chunk_idx in range(n_chunks):
+ chunk_end_idx = min((chunk_idx + 1) * chunk_max_size, num_splats)
+ chunk_start_idx = chunk_idx * chunk_max_size
+ splat_idxs = indices[chunk_start_idx:chunk_end_idx]
+
+ # Bounds
+ # Means
+ chunk_means = means[splat_idxs]
+ min_means = torch.min(chunk_means, dim=0).values
+ max_means = torch.max(chunk_means, dim=0).values
+ mean_bounds = torch.cat([min_means, max_means])
+ # Scales
+ chunk_scales = scales[splat_idxs]
+ min_scales = torch.min(chunk_scales, dim=0).values
+ max_scales = torch.max(chunk_scales, dim=0).values
+ min_scales = torch.clamp(min_scales, -20, 20)
+ max_scales = torch.clamp(max_scales, -20, 20)
+ scale_bounds = torch.cat([min_scales, max_scales])
+ # Colors
+ chunk_colors = sh0_colors[splat_idxs]
+ min_colors = torch.min(chunk_colors, dim=0).values
+ max_colors = torch.max(chunk_colors, dim=0).values
+ color_bounds = torch.cat([min_colors, max_colors])
+ chunk_data.extend([mean_bounds, scale_bounds, color_bounds])
+
+ # Quantized properties:
+ # Means
+ normalized_means = (chunk_means - min_means) / (max_means - min_means)
+ means_i = pack_111011(
+ normalized_means[:, 0],
+ normalized_means[:, 1],
+ normalized_means[:, 2],
+ )
+ # Quaternions
+ chunk_quats = quats[splat_idxs]
+ quat_i = pack_rotation(chunk_quats)
+ # Scales
+ normalized_scales = (chunk_scales - min_scales) / (max_scales - min_scales)
+ scales_i = pack_111011(
+ normalized_scales[:, 0],
+ normalized_scales[:, 1],
+ normalized_scales[:, 2],
+ )
+ # Colors
+ normalized_colors = (chunk_colors - min_colors) / (max_colors - min_colors)
+ chunk_opacities = opacities[splat_idxs]
+ chunk_opacities = 1 / (1 + torch.exp(-chunk_opacities))
+ chunk_opacities = chunk_opacities.unsqueeze(-1)
+ normalized_colors_i = torch.cat([normalized_colors, chunk_opacities], dim=-1)
+ color_i = pack_8888(
+ normalized_colors_i[:, 0],
+ normalized_colors_i[:, 1],
+ normalized_colors_i[:, 2],
+ normalized_colors_i[:, 3],
+ )
+ splat_data_chunk = torch.stack([means_i, quat_i, scales_i, color_i], dim=1)
+ splat_data_chunk = splat_data_chunk.ravel().to(torch.int64)
+ splat_data.extend([splat_data_chunk])
+
+ # Quantized spherical harmonics
+ shN_chunk = shN[splat_idxs]
+ shN_chunk_quantized = (shN_chunk / 8 + 0.5) * 256
+ shN_chunk_quantized = torch.clamp(torch.trunc(shN_chunk_quantized), 0, 255)
+ shN_chunk_quantized = shN_chunk_quantized.to(torch.uint8)
+ sh_data.extend([shN_chunk_quantized.ravel()])
+
+ float_dtype = np.dtype(np.float32).newbyteorder("<")
+ uint32_dtype = np.dtype(np.uint32).newbyteorder("<")
+ uint8_dtype = np.dtype(np.uint8)
+
+ buffer.write(
+ torch.cat(chunk_data).detach().cpu().numpy().astype(float_dtype).tobytes()
+ )
+ buffer.write(
+ torch.cat(splat_data).detach().cpu().numpy().astype(uint32_dtype).tobytes()
+ )
+ buffer.write(
+ torch.cat(sh_data).detach().cpu().numpy().astype(uint8_dtype).tobytes()
+ )
+
+ return buffer.getvalue()
+
+
+def splat2ply_bytes(
+ means: torch.Tensor,
+ scales: torch.Tensor,
+ quats: torch.Tensor,
+ opacities: torch.Tensor,
+ sh0: torch.Tensor,
+ shN: torch.Tensor,
+) -> bytes:
+ """Return the binary Ply file. Supported by almost all viewers.
+
+ Args:
+ means (torch.Tensor): Splat means. Shape (N, 3)
+ scales (torch.Tensor): Splat scales. Shape (N, 3)
+ quats (torch.Tensor): Splat quaternions. Shape (N, 4)
+ opacities (torch.Tensor): Splat opacities. Shape (N,)
+ sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
+ shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)
+
+ Returns:
+ bytes: Binary Ply file representing the model.
+ """
+
+ num_splats = means.shape[0]
+ buffer = BytesIO()
+
+ # Write PLY header
+ buffer.write(b"ply\n")
+ buffer.write(b"format binary_little_endian 1.0\n")
+ buffer.write(f"element vertex {num_splats}\n".encode())
+ buffer.write(b"property float x\n")
+ buffer.write(b"property float y\n")
+ buffer.write(b"property float z\n")
+ for i, data in enumerate([sh0, shN]):
+ prefix = "f_dc" if i == 0 else "f_rest"
+ for j in range(data.shape[1]):
+ buffer.write(f"property float {prefix}_{j}\n".encode())
+ buffer.write(b"property float opacity\n")
+ for i in range(scales.shape[1]):
+ buffer.write(f"property float scale_{i}\n".encode())
+ for i in range(quats.shape[1]):
+ buffer.write(f"property float rot_{i}\n".encode())
+ buffer.write(b"end_header\n")
+
+ # Concatenate all tensors in the correct order
+ splat_data = torch.cat(
+ [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
+ )
+ # Ensure correct dtype
+ splat_data = splat_data.to(torch.float32)
+
+ # Write binary data
+ float_dtype = np.dtype(np.float32).newbyteorder("<")
+ buffer.write(splat_data.detach().cpu().numpy().astype(float_dtype).tobytes())
+
+ return buffer.getvalue()
+
+
+def splat2splat_bytes(
+ means: torch.Tensor,
+ scales: torch.Tensor,
+ quats: torch.Tensor,
+ opacities: torch.Tensor,
+ sh0: torch.Tensor,
+) -> bytes:
+ """Return the binary Splat file. Supported by antimatter15 viewer.
+
+ Args:
+ means (torch.Tensor): Splat means. Shape (N, 3)
+ scales (torch.Tensor): Splat scales. Shape (N, 3)
+ quats (torch.Tensor): Splat quaternions. Shape (N, 4)
+ opacities (torch.Tensor): Splat opacities. Shape (N,)
+ sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
+
+ Returns:
+ bytes: Binary Splat file representing the model.
+ """
+
+ # Preprocess
+ scales = torch.exp(scales)
+ sh0_color = sh2rgb(sh0)
+ colors = torch.cat([sh0_color, torch.sigmoid(opacities).unsqueeze(-1)], dim=1)
+ colors = (colors * 255).clamp(0, 255).to(torch.uint8)
+
+ rots = (quats / torch.linalg.norm(quats, dim=1, keepdim=True)) * 128 + 128
+ rots = rots.clamp(0, 255).to(torch.uint8)
+
+ # Sort splats
+ num_splats = means.shape[0]
+ indices = sort_centers(means, torch.arange(num_splats))
+
+ # Reorder everything
+ means = means[indices]
+ scales = scales[indices]
+ colors = colors[indices]
+ rots = rots[indices]
+
+ float_dtype = np.dtype(np.float32).newbyteorder("<")
+ means_np = means.detach().cpu().numpy().astype(float_dtype)
+ scales_np = scales.detach().cpu().numpy().astype(float_dtype)
+ colors_np = colors.detach().cpu().numpy().astype(np.uint8)
+ rots_np = rots.detach().cpu().numpy().astype(np.uint8)
+
+ buffer = BytesIO()
+ for i in range(num_splats):
+ buffer.write(means_np[i].tobytes())
+ buffer.write(scales_np[i].tobytes())
+ buffer.write(colors_np[i].tobytes())
+ buffer.write(rots_np[i].tobytes())
+
+ return buffer.getvalue()
+
+
+def export_splats(
+ means: torch.Tensor,
+ scales: torch.Tensor,
+ quats: torch.Tensor,
+ opacities: torch.Tensor,
+ sh0: torch.Tensor,
+ shN: torch.Tensor,
+ format: Literal["ply", "splat", "ply_compressed"] = "ply",
+ save_to: Optional[str] = None,
+) -> bytes:
+ """Export a Gaussian Splats model to bytes.
+ The three supported formats are:
+ - ply: A standard PLY file format. Supported by most viewers.
+ - splat: A custom Splat file format. Supported by antimatter15 viewer.
+ - ply_compressed: A compressed PLY file format. Used by Supersplat viewer.
+
+ Args:
+ means (torch.Tensor): Splat means. Shape (N, 3)
+ scales (torch.Tensor): Splat scales. Shape (N, 3)
+ quats (torch.Tensor): Splat quaternions. Shape (N, 4)
+ opacities (torch.Tensor): Splat opacities. Shape (N,)
+ sh0 (torch.Tensor): Spherical harmonics. Shape (N, 1, 3)
+ shN (torch.Tensor): Spherical harmonics. Shape (N, K, 3)
+ format (str): Export format. Options: "ply", "splat", "ply_compressed". Default: "ply"
+ save_to (str): Output file path. If provided, the bytes will be written to file.
+ """
+ total_splats = means.shape[0]
+ assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
+ assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
+ assert quats.shape == (total_splats, 4), "Quaternions must be of shape (N, 4)"
+ assert opacities.shape == (total_splats,), "Opacities must be of shape (N,)"
+ assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
+ assert (
+ shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
+ ), f"shN must be of shape (N, K, 3), got {shN.shape}"
+
+ # Reshape spherical harmonics
+ sh0 = sh0.squeeze(1) # Shape (N, 3)
+ shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1) # Shape (N, K * 3)
+
+ # Check for NaN or Inf values
+ invalid_mask = (
+ torch.isnan(means).any(dim=1)
+ | torch.isinf(means).any(dim=1)
+ | torch.isnan(scales).any(dim=1)
+ | torch.isinf(scales).any(dim=1)
+ | torch.isnan(quats).any(dim=1)
+ | torch.isinf(quats).any(dim=1)
+ | torch.isnan(opacities).any(dim=0)
+ | torch.isinf(opacities).any(dim=0)
+ | torch.isnan(sh0).any(dim=1)
+ | torch.isinf(sh0).any(dim=1)
+ | torch.isnan(shN).any(dim=1)
+ | torch.isinf(shN).any(dim=1)
+ )
+
+ # Filter out invalid entries
+ valid_mask = ~invalid_mask
+ means = means[valid_mask]
+ scales = scales[valid_mask]
+ quats = quats[valid_mask]
+ opacities = opacities[valid_mask]
+ sh0 = sh0[valid_mask]
+ shN = shN[valid_mask]
+
+ if format == "ply":
+ data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
+ elif format == "splat":
+ data = splat2splat_bytes(means, scales, quats, opacities, sh0)
+ elif format == "ply_compressed":
+ data = splat2ply_bytes_compressed(means, scales, quats, opacities, sh0, shN)
+ else:
+ raise ValueError(f"Unsupported format: {format}")
+
+ if save_to:
+ with open(save_to, "wb") as binary_file:
+ binary_file.write(data)
+
+ return data
diff --git a/src/post_opt/gsplat_viewer.py b/src/post_opt/gsplat_viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..841397c325acc4cfdebeffc9e24ea808d205d1f8
--- /dev/null
+++ b/src/post_opt/gsplat_viewer.py
@@ -0,0 +1,249 @@
+import viser
+from pathlib import Path
+from typing import Literal
+from typing import Tuple, Callable
+from nerfview import Viewer, RenderTabState
+
+
+class GsplatRenderTabState(RenderTabState):
+ # non-controlable parameters
+ total_gs_count: int = 0
+ rendered_gs_count: int = 0
+
+ # controlable parameters
+ max_sh_degree: int = 5
+ near_plane: float = 1e-2
+ far_plane: float = 1e2
+ radius_clip: float = 0.0
+ eps2d: float = 0.3
+ backgrounds: Tuple[float, float, float] = (0.0, 0.0, 0.0)
+ render_mode: Literal[
+ "rgb", "depth(accumulated)", "depth(expected)", "alpha"
+ ] = "rgb"
+ normalize_nearfar: bool = False
+ inverse: bool = True
+ colormap: Literal[
+ "turbo", "viridis", "magma", "inferno", "cividis", "gray"
+ ] = "turbo"
+ rasterize_mode: Literal["classic", "antialiased"] = "classic"
+ camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
+
+
+class GsplatViewer(Viewer):
+ """
+ Viewer for gsplat.
+ """
+
+ def __init__(
+ self,
+ server: viser.ViserServer,
+ render_fn: Callable,
+ output_dir: Path,
+ mode: Literal["rendering", "training"] = "rendering",
+ ):
+ super().__init__(server, render_fn, output_dir, mode)
+ server.gui.set_panel_label("gsplat viewer")
+
+ def _init_rendering_tab(self):
+ self.render_tab_state = GsplatRenderTabState()
+ self._rendering_tab_handles = {}
+ self._rendering_folder = self.server.gui.add_folder("Rendering")
+
+ def _populate_rendering_tab(self):
+ server = self.server
+ with self._rendering_folder:
+ with server.gui.add_folder("Gsplat"):
+ total_gs_count_number = server.gui.add_number(
+ "Total",
+ initial_value=self.render_tab_state.total_gs_count,
+ disabled=True,
+ hint="Total number of splats in the scene.",
+ )
+ rendered_gs_count_number = server.gui.add_number(
+ "Rendered",
+ initial_value=self.render_tab_state.rendered_gs_count,
+ disabled=True,
+ hint="Number of splats rendered.",
+ )
+
+ max_sh_degree_number = server.gui.add_number(
+ "Max SH",
+ initial_value=self.render_tab_state.max_sh_degree,
+ min=0,
+ max=5,
+ step=1,
+ hint="Maximum SH degree used",
+ )
+
+ @max_sh_degree_number.on_update
+ def _(_) -> None:
+ self.render_tab_state.max_sh_degree = int(
+ max_sh_degree_number.value
+ )
+ self.rerender(_)
+
+ near_far_plane_vec2 = server.gui.add_vector2(
+ "Near/Far",
+ initial_value=(
+ self.render_tab_state.near_plane,
+ self.render_tab_state.far_plane,
+ ),
+ min=(1e-3, 1e1),
+ max=(1e1, 1e2),
+ step=1e-3,
+ hint="Near and far plane for rendering.",
+ )
+
+ @near_far_plane_vec2.on_update
+ def _(_) -> None:
+ self.render_tab_state.near_plane = near_far_plane_vec2.value[0]
+ self.render_tab_state.far_plane = near_far_plane_vec2.value[1]
+ self.rerender(_)
+
+ radius_clip_slider = server.gui.add_number(
+ "Radius Clip",
+ initial_value=self.render_tab_state.radius_clip,
+ min=0.0,
+ max=100.0,
+ step=1.0,
+ hint="2D radius clip for rendering.",
+ )
+
+ @radius_clip_slider.on_update
+ def _(_) -> None:
+ self.render_tab_state.radius_clip = radius_clip_slider.value
+ self.rerender(_)
+
+ eps2d_slider = server.gui.add_number(
+ "2D Epsilon",
+ initial_value=self.render_tab_state.eps2d,
+ min=0.0,
+ max=1.0,
+ step=0.01,
+ hint="Epsilon added to the egienvalues of projected 2D covariance matrices.",
+ )
+
+ @eps2d_slider.on_update
+ def _(_) -> None:
+ self.render_tab_state.eps2d = eps2d_slider.value
+ self.rerender(_)
+
+ backgrounds_slider = server.gui.add_rgb(
+ "Background",
+ initial_value=self.render_tab_state.backgrounds,
+ hint="Background color for rendering.",
+ )
+
+ @backgrounds_slider.on_update
+ def _(_) -> None:
+ self.render_tab_state.backgrounds = backgrounds_slider.value
+ self.rerender(_)
+
+ render_mode_dropdown = server.gui.add_dropdown(
+ "Render Mode",
+ ("rgb", "depth(accumulated)", "depth(expected)", "alpha"),
+ initial_value=self.render_tab_state.render_mode,
+ hint="Render mode to use.",
+ )
+
+ @render_mode_dropdown.on_update
+ def _(_) -> None:
+ if "depth" in render_mode_dropdown.value:
+ normalize_nearfar_checkbox.disabled = False
+ else:
+ normalize_nearfar_checkbox.disabled = True
+ if render_mode_dropdown.value == "rgb":
+ inverse_checkbox.disabled = True
+ else:
+ inverse_checkbox.disabled = False
+ self.render_tab_state.render_mode = render_mode_dropdown.value
+ self.rerender(_)
+
+ normalize_nearfar_checkbox = server.gui.add_checkbox(
+ "Normalize Near/Far",
+ initial_value=self.render_tab_state.normalize_nearfar,
+ disabled=True,
+ hint="Normalize depth with near/far plane.",
+ )
+
+ @normalize_nearfar_checkbox.on_update
+ def _(_) -> None:
+ self.render_tab_state.normalize_nearfar = (
+ normalize_nearfar_checkbox.value
+ )
+ self.rerender(_)
+
+ inverse_checkbox = server.gui.add_checkbox(
+ "Inverse",
+ initial_value=self.render_tab_state.inverse,
+ disabled=True,
+ hint="Inverse the depth.",
+ )
+
+ @inverse_checkbox.on_update
+ def _(_) -> None:
+ self.render_tab_state.inverse = inverse_checkbox.value
+ self.rerender(_)
+
+ colormap_dropdown = server.gui.add_dropdown(
+ "Colormap",
+ ("turbo", "viridis", "magma", "inferno", "cividis", "gray"),
+ initial_value=self.render_tab_state.colormap,
+ hint="Colormap used for rendering depth/alpha.",
+ )
+
+ @colormap_dropdown.on_update
+ def _(_) -> None:
+ self.render_tab_state.colormap = colormap_dropdown.value
+ self.rerender(_)
+
+ rasterize_mode_dropdown = server.gui.add_dropdown(
+ "Anti-Aliasing",
+ ("classic", "antialiased"),
+ initial_value=self.render_tab_state.rasterize_mode,
+ hint="Whether to use classic or antialiased rasterization.",
+ )
+
+ @rasterize_mode_dropdown.on_update
+ def _(_) -> None:
+ self.render_tab_state.rasterize_mode = rasterize_mode_dropdown.value
+ self.rerender(_)
+
+ camera_model_dropdown = server.gui.add_dropdown(
+ "Camera",
+ ("pinhole", "ortho", "fisheye"),
+ initial_value=self.render_tab_state.camera_model,
+ hint="Camera model used for rendering.",
+ )
+
+ @camera_model_dropdown.on_update
+ def _(_) -> None:
+ self.render_tab_state.camera_model = camera_model_dropdown.value
+ self.rerender(_)
+
+ self._rendering_tab_handles.update(
+ {
+ "total_gs_count_number": total_gs_count_number,
+ "rendered_gs_count_number": rendered_gs_count_number,
+ "near_far_plane_vec2": near_far_plane_vec2,
+ "radius_clip_slider": radius_clip_slider,
+ "eps2d_slider": eps2d_slider,
+ "backgrounds_slider": backgrounds_slider,
+ "render_mode_dropdown": render_mode_dropdown,
+ "normalize_nearfar_checkbox": normalize_nearfar_checkbox,
+ "inverse_checkbox": inverse_checkbox,
+ "colormap_dropdown": colormap_dropdown,
+ "rasterize_mode_dropdown": rasterize_mode_dropdown,
+ "camera_model_dropdown": camera_model_dropdown,
+ }
+ )
+ super()._populate_rendering_tab()
+
+ def _after_render(self):
+ # Update the GUI elements with current values
+ self._rendering_tab_handles[
+ "total_gs_count_number"
+ ].value = self.render_tab_state.total_gs_count
+ self._rendering_tab_handles[
+ "rendered_gs_count_number"
+ ].value = self.render_tab_state.rendered_gs_count
diff --git a/src/post_opt/image_fitting.py b/src/post_opt/image_fitting.py
new file mode 100644
index 0000000000000000000000000000000000000000..434b7869b87d034de2149e49eb64172e4cffd2c9
--- /dev/null
+++ b/src/post_opt/image_fitting.py
@@ -0,0 +1,189 @@
+import math
+import os
+import time
+from pathlib import Path
+from typing import Literal, Optional
+
+import numpy as np
+import torch
+import tyro
+from PIL import Image
+from torch import Tensor, optim
+
+from gsplat import rasterization, rasterization_2dgs
+
+
+class SimpleTrainer:
+ """Trains random gaussians to fit an image."""
+
+ def __init__(
+ self,
+ gt_image: Tensor,
+ num_points: int = 2000,
+ ):
+ self.device = torch.device("cuda:0")
+ self.gt_image = gt_image.to(device=self.device)
+ self.num_points = num_points
+
+ fov_x = math.pi / 2.0
+ self.H, self.W = gt_image.shape[0], gt_image.shape[1]
+ self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
+ self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)
+
+ self._init_gaussians()
+
+ def _init_gaussians(self):
+ """Random gaussians"""
+ bd = 2
+
+ self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
+ self.scales = torch.rand(self.num_points, 3, device=self.device)
+ d = 3
+ self.rgbs = torch.rand(self.num_points, d, device=self.device)
+
+ u = torch.rand(self.num_points, 1, device=self.device)
+ v = torch.rand(self.num_points, 1, device=self.device)
+ w = torch.rand(self.num_points, 1, device=self.device)
+
+ self.quats = torch.cat(
+ [
+ torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
+ torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
+ torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
+ torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
+ ],
+ -1,
+ )
+ self.opacities = torch.ones((self.num_points), device=self.device)
+
+ self.viewmat = torch.tensor(
+ [
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0, 0.0],
+ [0.0, 0.0, 1.0, 8.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ device=self.device,
+ )
+ self.background = torch.zeros(d, device=self.device)
+
+ self.means.requires_grad = True
+ self.scales.requires_grad = True
+ self.quats.requires_grad = True
+ self.rgbs.requires_grad = True
+ self.opacities.requires_grad = True
+ self.viewmat.requires_grad = False
+
+ def train(
+ self,
+ iterations: int = 1000,
+ lr: float = 0.01,
+ save_imgs: bool = False,
+ model_type: Literal["3dgs", "2dgs"] = "3dgs",
+ ):
+ optimizer = optim.Adam(
+ [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
+ )
+ mse_loss = torch.nn.MSELoss()
+ frames = []
+ times = [0] * 2 # rasterization, backward
+ K = torch.tensor(
+ [
+ [self.focal, 0, self.W / 2],
+ [0, self.focal, self.H / 2],
+ [0, 0, 1],
+ ],
+ device=self.device,
+ )
+
+ if model_type == "3dgs":
+ rasterize_fnc = rasterization
+ elif model_type == "2dgs":
+ rasterize_fnc = rasterization_2dgs
+
+ for iter in range(iterations):
+ start = time.time()
+
+ renders = rasterize_fnc(
+ self.means,
+ self.quats / self.quats.norm(dim=-1, keepdim=True),
+ self.scales,
+ torch.sigmoid(self.opacities),
+ torch.sigmoid(self.rgbs),
+ self.viewmat[None],
+ K[None],
+ self.W,
+ self.H,
+ packed=False,
+ )[0]
+ out_img = renders[0]
+ torch.cuda.synchronize()
+ times[0] += time.time() - start
+ loss = mse_loss(out_img, self.gt_image)
+ optimizer.zero_grad()
+ start = time.time()
+ loss.backward()
+ torch.cuda.synchronize()
+ times[1] += time.time() - start
+ optimizer.step()
+ print(f"Iteration {iter + 1}/{iterations}, Loss: {loss.item()}")
+
+ if save_imgs and iter % 5 == 0:
+ frames.append((out_img.detach().cpu().numpy() * 255).astype(np.uint8))
+ if save_imgs:
+ # save them as a gif with PIL
+ frames = [Image.fromarray(frame) for frame in frames]
+ out_dir = os.path.join(os.getcwd(), "results")
+ os.makedirs(out_dir, exist_ok=True)
+ frames[0].save(
+ f"{out_dir}/training.gif",
+ save_all=True,
+ append_images=frames[1:],
+ optimize=False,
+ duration=5,
+ loop=0,
+ )
+ print(f"Total(s):\nRasterization: {times[0]:.3f}, Backward: {times[1]:.3f}")
+ print(
+ f"Per step(s):\nRasterization: {times[0]/iterations:.5f}, Backward: {times[1]/iterations:.5f}"
+ )
+
+
+def image_path_to_tensor(image_path: Path):
+ import torchvision.transforms as transforms
+
+ img = Image.open(image_path)
+ transform = transforms.ToTensor()
+ img_tensor = transform(img).permute(1, 2, 0)[..., :3]
+ return img_tensor
+
+
+def main(
+ height: int = 256,
+ width: int = 256,
+ num_points: int = 100000,
+ save_imgs: bool = True,
+ img_path: Optional[Path] = None,
+ iterations: int = 1000,
+ lr: float = 0.01,
+ model_type: Literal["3dgs", "2dgs"] = "3dgs",
+) -> None:
+ if img_path:
+ gt_image = image_path_to_tensor(img_path)
+ else:
+ gt_image = torch.ones((height, width, 3)) * 1.0
+ # make top left and bottom right red, blue
+ gt_image[: height // 2, : width // 2, :] = torch.tensor([1.0, 0.0, 0.0])
+ gt_image[height // 2 :, width // 2 :, :] = torch.tensor([0.0, 0.0, 1.0])
+
+ trainer = SimpleTrainer(gt_image=gt_image, num_points=num_points)
+ trainer.train(
+ iterations=iterations,
+ lr=lr,
+ save_imgs=save_imgs,
+ model_type=model_type,
+ )
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
diff --git a/src/post_opt/lib_bilagrid.py b/src/post_opt/lib_bilagrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b10a204569a6480ac345401a8efcfaf154bc321
--- /dev/null
+++ b/src/post_opt/lib_bilagrid.py
@@ -0,0 +1,573 @@
+# # Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of code is borrowed form ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This is a standalone PyTorch implementation of 3D bilateral grid and CP-decomposed 4D bilateral grid.
+To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory.
+
+For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+
+#### Dependencies
+
+In addition to PyTorch and Numpy, please install [tensorly](https://github.com/tensorly/tensorly).
+We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and Numpy 1.25.2.
+
+#### Overview
+
+- For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids
+ for input views. Then, use `slice` function to obtain transformed RGB output and the corresponding affine transformations.
+
+- For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`.
+
+#### Examples
+
+- Bilateral grid for approximating ISP:
+
+
+
+- Low-rank 4D bilateral grid for MR enhancement:
+
+
+
+
+Below is the API reference.
+
+"""
+
+import tensorly as tl
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+tl.set_backend("pytorch")
+
+
+def color_correct(
+ img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255
+) -> torch.Tensor:
+ """
+ Warp `img` to match the colors in `ref_img` using iterative color matching.
+
+ This function performs color correction by warping the colors of the input image
+ to match those of a reference image. It uses a least squares method to find a
+ transformation that maps the input image's colors to the reference image's colors.
+
+ The algorithm iteratively solves a system of linear equations, updating the set of
+ unsaturated pixels in each iteration. This approach helps handle non-linear color
+ transformations and reduces the impact of clipping.
+
+ Args:
+ img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels]
+ ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels]
+ num_iters (int, optional): Number of iterations for the color matching process.
+ Default is 5.
+ eps (float, optional): Small value to determine the range of unclipped pixels.
+ Default is 0.5 / 255.
+
+ Returns:
+ torch.Tensor: Color corrected image with the same shape as the input image.
+
+ Note:
+ - Both input and reference images should be in the range [0, 1].
+ - The function works with any number of channels, but typically used with 3 (RGB).
+ """
+ if img.shape[-1] != ref.shape[-1]:
+ raise ValueError(
+ f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match"
+ )
+ num_channels = img.shape[-1]
+ img_mat = img.reshape([-1, num_channels])
+ ref_mat = ref.reshape([-1, num_channels])
+
+ def is_unclipped(z):
+ return (z >= eps) & (z <= 1 - eps) # z \in [eps, 1-eps].
+
+ mask0 = is_unclipped(img_mat)
+ # Because the set of saturated pixels may change after solving for a
+ # transformation, we repeatedly solve a system `num_iters` times and update
+ # our estimate of which pixels are saturated.
+ for _ in range(num_iters):
+ # Construct the left hand side of a linear system that contains a quadratic
+ # expansion of each pixel of `img`.
+ a_mat = []
+ for c in range(num_channels):
+ a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) # Quadratic term.
+ a_mat.append(img_mat) # Linear term.
+ a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term.
+ a_mat = torch.cat(a_mat, dim=-1)
+ warp = []
+ for c in range(num_channels):
+ # Construct the right hand side of a linear system containing each color
+ # of `ref`.
+ b = ref_mat[:, c]
+ # Ignore rows of the linear system that were saturated in the input or are
+ # saturated in the current corrected color estimate.
+ mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
+ ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
+ mb = torch.where(mask, b, torch.zeros_like(b))
+ w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
+ assert torch.all(torch.isfinite(w))
+ warp.append(w)
+ warp = torch.stack(warp, dim=-1)
+ # Apply the warp to update img_mat.
+ img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1)
+ corrected_img = torch.reshape(img_mat, img.shape)
+ return corrected_img
+
+
+def bilateral_grid_tv_loss(model, config):
+ """Computes total variations of bilateral grids."""
+ total_loss = 0.0
+
+ for bil_grids in model.bil_grids:
+ total_loss += config.bilgrid_tv_loss_mult * total_variation_loss(
+ bil_grids.grids
+ )
+
+ return total_loss
+
+
+def color_affine_transform(affine_mats, rgb):
+ """Applies color affine transformations.
+
+ Args:
+ affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$.
+ rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$.
+
+ Returns:
+ Output transformed colors of shape $(..., 3)$.
+ """
+ return (
+ torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1)
+ + affine_mats[..., 3]
+ )
+
+
+def _num_tensor_elems(t):
+ return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0)
+
+
+def total_variation_loss(x): # noqa: F811
+ """Returns total variation on multi-dimensional tensors.
+
+ Args:
+ x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size.
+ """
+ batch_size = x.shape[0]
+ tv = 0
+ for i in range(2, len(x.shape)):
+ n_res = x.shape[i]
+ idx1 = torch.arange(1, n_res, device=x.device)
+ idx2 = torch.arange(0, n_res - 1, device=x.device)
+ x1 = x.index_select(i, idx1)
+ x2 = x.index_select(i, idx2)
+ count = _num_tensor_elems(x1)
+ tv += torch.pow((x1 - x2), 2).sum() / count
+ return tv / batch_size
+
+
+def slice(bil_grids, xy, rgb, grid_idx):
+ """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`.
+
+ Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size
+ and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`.
+
+ The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and
+ the output color `rgb_out` after applying the afffine transformations.
+
+ In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor.
+ Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`.
+ For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same with the 2-D case.
+
+ .. note::
+ This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement.
+ When `grid_idx` contains a unique index, only a single bilateral grid will used during the slicing. In this case, this function will not
+ perform tensor indexing to avoid data copy and extra memory
+ (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)).
+
+ Args:
+ bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids.
+ xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$.
+ rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$.
+ grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$.
+
+ Returns:
+ A dictionary with keys and values as follows:
+ ```
+ {
+ "rgb": Transformed RGB colors. Shape: (..., 3),
+ "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4)
+ }
+ ```
+ """
+
+ sh_ = rgb.shape
+
+ grid_idx_unique = torch.unique(grid_idx)
+ if len(grid_idx_unique) == 1:
+ # All pixels are from a single view.
+ grid_idx = grid_idx_unique # (1,)
+ xy = xy.unsqueeze(0) # (1, ..., 2)
+ rgb = rgb.unsqueeze(0) # (1, ..., 3)
+ else:
+ # Pixels are randomly sampled from different views.
+ if len(grid_idx.shape) == 4:
+ grid_idx = grid_idx[:, 0, 0, 0] # (chunk_size,)
+ elif len(grid_idx.shape) == 3:
+ grid_idx = grid_idx[:, 0, 0] # (chunk_size,)
+ elif len(grid_idx.shape) == 2:
+ grid_idx = grid_idx[:, 0] # (chunk_size,)
+ else:
+ raise ValueError(
+ "The input to bilateral grid slicing is not supported yet."
+ )
+
+ affine_mats = bil_grids(xy, rgb, grid_idx)
+ rgb = color_affine_transform(affine_mats, rgb)
+
+ return {
+ "rgb": rgb.reshape(*sh_),
+ "rgb_affine_mats": affine_mats.reshape(
+ *sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1]
+ ),
+ }
+
+
+class BilateralGrid(nn.Module):
+ """Class for 3D bilateral grids.
+
+ Holds one or more than one bilateral grids.
+ """
+
+ def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8):
+ """
+ Args:
+ num (int): The number of bilateral grids (i.e., the number of views).
+ grid_X (int): Defines grid width $W$.
+ grid_Y (int): Defines grid height $H$.
+ grid_W (int): Defines grid guidance dimension $L$.
+ """
+ super(BilateralGrid, self).__init__()
+
+ self.grid_width = grid_X
+ """Grid width. Type: int."""
+ self.grid_height = grid_Y
+ """Grid height. Type: int."""
+ self.grid_guidance = grid_W
+ """Grid guidance dimension. Type: int."""
+
+ # Initialize grids.
+ grid = self._init_identity_grid()
+ self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) # (N, 12, L, H, W)
+ """ A 5-D tensor of shape $(N, 12, L, H, W)$."""
+
+ # Weights of BT601 RGB-to-gray.
+ self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]))
+ self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
+ """ A function that converts RGB to gray-scale guidance in $[-1, 1]$."""
+
+ def _init_identity_grid(self):
+ grid = torch.tensor(
+ [
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ ]
+ ).float()
+ grid = grid.repeat(
+ [self.grid_guidance * self.grid_height * self.grid_width, 1]
+ ) # (L * H * W, 12)
+ grid = grid.reshape(
+ 1, self.grid_guidance, self.grid_height, self.grid_width, -1
+ ) # (1, L, H, W, 12)
+ grid = grid.permute(0, 4, 1, 2, 3) # (1, 12, L, H, W)
+ return grid
+
+ def tv_loss(self):
+ """Computes and returns total variation loss on the bilateral grids."""
+ return total_variation_loss(self.grids)
+
+ def forward(self, grid_xy, rgb, idx=None):
+ """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input.
+ For the 2-D, 3-D, and 4-D cases, please refer to `slice`.
+ For the 5-D cases, `idx` will be unused and the first dimension of `xy` should be
+ equal to the number of bilateral grids. Then this function becomes PyTorch's
+ [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
+
+ Args:
+ grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$.
+ rgb (torch.Tensor): The RGB values in the range of $[0,1]$.
+ idx (torch.Tensor): The bilateral grid indices.
+
+ Returns:
+ Sliced affine matrices of shape $(..., 3, 4)$.
+ """
+
+ grids = self.grids
+ input_ndims = len(grid_xy.shape)
+ assert len(rgb.shape) == input_ndims
+
+ if input_ndims > 1 and input_ndims < 5:
+ # Convert input into 5D
+ for i in range(5 - input_ndims):
+ grid_xy = grid_xy.unsqueeze(1)
+ rgb = rgb.unsqueeze(1)
+ assert idx is not None
+ elif input_ndims != 5:
+ raise ValueError(
+ "Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs"
+ )
+
+ grids = self.grids
+ if idx is not None:
+ grids = grids[idx]
+ assert grids.shape[0] == grid_xy.shape[0]
+
+ # Generate slicing coordinates.
+ grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1].
+ grid_z = self.rgb2gray(rgb)
+
+ # print(grid_xy.shape, grid_z.shape)
+ # exit()
+ grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3)
+
+ affine_mats = F.grid_sample(
+ grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border"
+ ) # (N, 12, m, h, w)
+ affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12)
+ affine_mats = affine_mats.reshape(
+ *affine_mats.shape[:-1], 3, 4
+ ) # (N, m, h, w, 3, 4)
+
+ for _ in range(5 - input_ndims):
+ affine_mats = affine_mats.squeeze(1)
+
+ return affine_mats
+
+
+def slice4d(bil_grid4d, xyz, rgb):
+ """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`.
+
+ Args:
+ bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid.
+ xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
+ rgb (torch.Tensor): The RGB values with shape $(..., 3)$.
+
+ Returns:
+ A dictionary with keys and values as follows:
+ ```
+ {
+ "rgb": Transformed radiance RGB colors. Shape: (..., 3),
+ "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4)
+ }
+ ```
+ """
+
+ affine_mats = bil_grid4d(xyz, rgb)
+ rgb = color_affine_transform(affine_mats, rgb)
+
+ return {"rgb": rgb, "rgb_affine_mats": affine_mats}
+
+
+class _ScaledTanh(nn.Module):
+ def __init__(self, s=2.0):
+ super().__init__()
+ self.scaler = s
+
+ def forward(self, x):
+ return torch.tanh(self.scaler * x)
+
+
+class BilateralGridCP4D(nn.Module):
+ """Class for low-rank 4D bilateral grids."""
+
+ def __init__(
+ self,
+ grid_X=16,
+ grid_Y=16,
+ grid_Z=16,
+ grid_W=8,
+ rank=5,
+ learn_gray=True,
+ gray_mlp_width=8,
+ gray_mlp_depth=2,
+ init_noise_scale=1e-6,
+ bound=2.0,
+ ):
+ """
+ Args:
+ grid_X (int): Defines grid width.
+ grid_Y (int): Defines grid height.
+ grid_Z (int): Defines grid depth.
+ grid_W (int): Defines grid guidance dimension.
+ rank (int): Rank of the 4D bilateral grid.
+ learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances.
+ gray_mlp_width (int): The MLP width for learnable guidance.
+ gray_mlp_depth (int): The number of MLP layers for learnable guidance.
+ init_noise_scale (float): The noise scale of the initialized factors.
+ bound (float): The bound of the xyz coordinates.
+ """
+ super(BilateralGridCP4D, self).__init__()
+
+ self.grid_X = grid_X
+ """Grid width. Type: int."""
+ self.grid_Y = grid_Y
+ """Grid height. Type: int."""
+ self.grid_Z = grid_Z
+ """Grid depth. Type: int."""
+ self.grid_W = grid_W
+ """Grid guidance dimension. Type: int."""
+ self.rank = rank
+ """Rank of the 4D bilateral grid. Type: int."""
+ self.learn_gray = learn_gray
+ """Flags of learnable guidance is used. Type: bool."""
+ self.gray_mlp_width = gray_mlp_width
+ """The MLP width for learnable guidance. Type: int."""
+ self.gray_mlp_depth = gray_mlp_depth
+ """The MLP depth for learnable guidance. Type: int."""
+ self.init_noise_scale = init_noise_scale
+ """The noise scale of the initialized factors. Type: float."""
+ self.bound = bound
+ """The bound of the xyz coordinates. Type: float."""
+
+ self._init_cp_factors_parafac()
+
+ self.rgb2gray = None
+ """ A function that converts RGB to gray-scale guidances in $[-1, 1]$.
+ If `learn_gray` is True, this will be an MLP network."""
+
+ if self.learn_gray:
+
+ def rgb2gray_mlp_linear(layer):
+ return nn.Linear(
+ self.gray_mlp_width,
+ self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1,
+ )
+
+ def rgb2gray_mlp_actfn(_):
+ return nn.ReLU(inplace=True)
+
+ self.rgb2gray = nn.Sequential(
+ *(
+ [nn.Linear(3, self.gray_mlp_width)]
+ + [
+ nn_module(layer)
+ for layer in range(1, self.gray_mlp_depth)
+ for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear]
+ ]
+ + [_ScaledTanh(2.0)]
+ )
+ )
+ else:
+ # Weights of BT601/BT470 RGB-to-gray.
+ self.register_buffer(
+ "rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])
+ )
+ self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
+
+ def _init_identity_grid(self):
+ grid = torch.tensor(
+ [
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ ]
+ ).float()
+ grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1])
+ grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1)
+ grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X)
+ return grid
+
+ def _init_cp_factors_parafac(self):
+ # Initialize identity grids.
+ init_grids = self._init_identity_grid()
+ # Random noises are added to avoid singularity.
+ init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids
+ from tensorly.decomposition import parafac
+
+ # Initialize grid CP factors
+ _, facs = parafac(init_grids.clone().detach(), rank=self.rank)
+
+ self.num_facs = len(facs)
+
+ self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False)
+ self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank)
+
+ for i in range(1, self.num_facs):
+ fac = facs[i].T # (rank, grid_size)
+ fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1)
+ self.register_buffer(f"fac_{i}_init", fac)
+
+ fac_resid = torch.zeros_like(fac)
+ self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid))
+
+ def tv_loss(self):
+ """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids."""
+
+ total_loss = 0
+ for i in range(1, self.num_facs):
+ fac = self.get_parameter(f"fac_{i}")
+ total_loss += total_variation_loss(fac)
+
+ return total_loss
+
+ def forward(self, xyz, rgb):
+ """Low-rank 4D bilateral grid slicing.
+
+ Args:
+ xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
+ rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$.
+
+ Returns:
+ Sliced affine matrices with shape $(..., 3, 4)$.
+ """
+ sh_ = xyz.shape
+ xyz = xyz.reshape(-1, 3) # flatten (N, 3)
+ rgb = rgb.reshape(-1, 3) # flatten (N, 3)
+
+ xyz = xyz / self.bound
+ assert self.rgb2gray is not None
+ gray = self.rgb2gray(rgb)
+ xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4)
+ xyzw = xyzw.transpose(0, 1) # (4, N)
+ coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2)
+ coords = coords.unsqueeze(1) # (4, 1, N, 2)
+
+ coef = 1.0
+ for i in range(1, self.num_facs):
+ fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init")
+ coef = coef * F.grid_sample(
+ fac, coords[[i - 1]], align_corners=True, padding_mode="border"
+ ) # [1, rank, 1, N]
+ coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore
+ mat = self.fac_0(coef)
+ return mat.reshape(*sh_[:-1], 3, 4)
diff --git a/src/post_opt/simple_trainer.py b/src/post_opt/simple_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..207f2d63266246140d161a2dda449ac0a2f281dc
--- /dev/null
+++ b/src/post_opt/simple_trainer.py
@@ -0,0 +1,1433 @@
+import json
+import math
+import os
+import time
+from collections import defaultdict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import imageio
+import matplotlib
+import torchvision
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+import tyro
+import viser
+import yaml
+import torchvision
+import sys
+from plyfile import PlyData, PlyElement
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from src.model.types import Gaussians
+from src.post_opt.datasets.colmap import Dataset, Parser
+from src.post_opt.datasets.traj import (
+ generate_ellipse_path_z,
+ generate_interpolated_path,
+ generate_spiral_path,
+)
+from fused_ssim import fused_ssim
+
+from src.utils.image import process_image
+from src.post_opt.exporter import export_splats
+from src.post_opt.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from typing_extensions import Literal, assert_never
+from src.post_opt.utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed
+
+# from gsplat import export_splats
+from gsplat.compression import PngCompression
+from gsplat.distributed import cli
+# from gsplat.optimizers import SelectiveAdam
+# from gsplat.rendering import rasterization
+from gsplat import rasterization
+from gsplat.strategy import DefaultStrategy, MCMCStrategy
+from src.post_opt.gsplat_viewer import GsplatViewer, GsplatRenderTabState
+from nerfview import CameraState, RenderTabState, apply_float_colormap
+
+import torch
+from einops import rearrange
+from jaxtyping import Float
+from torch import Tensor
+from scipy.spatial.transform import Rotation as R
+
+from src.model.model.anysplat import AnySplat
+
+
+# pytorch3d/pytorch3d/transforms/rotation_conversions.py at main · facebookresearch/pytorch3d
+def quaternion_to_matrix(
+ quaternions: Float[Tensor, "*batch 4"],
+ eps: float = 1e-8,
+) -> Float[Tensor, "*batch 3 3"]:
+ # Order changed to match scipy format!
+ i, j, k, r = torch.unbind(quaternions, dim=-1)
+ two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return rearrange(o, "... (i j) -> ... i j", i=3, j=3)
+
+def construct_list_of_attributes(num_rest: int) -> list[str]:
+ attributes = ["x", "y", "z", "nx", "ny", "nz"]
+ for i in range(3):
+ attributes.append(f"f_dc_{i}")
+ for i in range(num_rest):
+ attributes.append(f"f_rest_{i}")
+ attributes.append("opacity")
+ for i in range(3):
+ attributes.append(f"scale_{i}")
+ for i in range(4):
+ attributes.append(f"rot_{i}")
+ return attributes
+
+def export_ply(
+ means: Float[Tensor, "gaussian 3"],
+ scales: Float[Tensor, "gaussian 3"],
+ rotations: Float[Tensor, "gaussian 4"],
+ harmonics: Float[Tensor, "gaussian 3 d_sh"],
+ opacities: Float[Tensor, " gaussian"],
+ path: Path,
+ shift_and_scale: bool = False,
+ save_sh_dc_only: bool = True,
+):
+ if shift_and_scale:
+ # Shift the scene so that the median Gaussian is at the origin.
+ means = means - means.median(dim=0).values
+
+ # Rescale the scene so that most Gaussians are within range [-1, 1].
+ scale_factor = means.abs().quantile(0.95, dim=0).max()
+ means = means / scale_factor
+ scales = scales / scale_factor
+
+ # Apply the rotation to the Gaussian rotations.
+ rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
+ rotations = R.from_matrix(rotations).as_quat()
+ x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
+ rotations = np.stack((w, x, y, z), axis=-1)
+
+ # Since current model use SH_degree = 4,
+ # which require large memory to store, we can only save the DC band to save memory.
+ f_dc = harmonics[..., 0]
+ f_rest = harmonics[..., 1:].flatten(start_dim=1)
+
+ dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0 if save_sh_dc_only else f_rest.shape[1])]
+ elements = np.empty(means.shape[0], dtype=dtype_full)
+ attributes = [
+ means.detach().cpu().numpy(),
+ torch.zeros_like(means).detach().cpu().numpy(),
+ f_dc.detach().cpu().contiguous().numpy(),
+ f_rest.detach().cpu().contiguous().numpy(),
+ opacities[..., None].detach().cpu().numpy(),
+ scales.detach().cpu().numpy(),
+ rotations,
+ ]
+ if save_sh_dc_only:
+ # remove f_rest from attributes
+ attributes.pop(3)
+
+ attributes = np.concatenate(attributes, axis=1)
+ elements[:] = list(map(tuple, attributes))
+ path.parent.mkdir(exist_ok=True, parents=True)
+ PlyData([PlyElement.describe(elements, "vertex")]).write(path)
+
+def colorize_depth_maps(depth_map, min_depth=0.0, max_depth=1.0, cmap="Spectral", valid_mask=None):
+ """
+ Colorize depth maps.
+ """
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
+
+ if isinstance(depth_map, torch.Tensor):
+ depth = depth_map.detach().clone().squeeze().numpy()
+ elif isinstance(depth_map, np.ndarray):
+ depth = depth_map.copy().squeeze()
+ # reshape to [ (B,) H, W ]
+ if depth.ndim < 3:
+ depth = depth[np.newaxis, :, :]
+
+ # colorize
+ cm = matplotlib.colormaps[cmap]
+ # depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
+ depth = ((depth - depth.min()) / (depth.max() - depth.min())).clip(0, 1)
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
+
+ if valid_mask is not None:
+ if isinstance(depth_map, torch.Tensor):
+ valid_mask = valid_mask.detach().numpy()
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
+ if valid_mask.ndim < 3:
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
+ else:
+ valid_mask = valid_mask[:, np.newaxis, :, :]
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
+ img_colored_np[~valid_mask] = 0
+
+ if isinstance(depth_map, torch.Tensor):
+ img_colored = torch.from_numpy(img_colored_np).float()
+ elif isinstance(depth_map, np.ndarray):
+ img_colored = img_colored_np
+
+ return img_colored
+
+def build_covariance(
+ scale: Float[Tensor, "*#batch 3"],
+ rotation_xyzw: Float[Tensor, "*#batch 4"],
+) -> Float[Tensor, "*batch 3 3"]:
+ scale = scale.diag_embed()
+ rotation = quaternion_to_matrix(rotation_xyzw)
+ return (
+ rotation
+ @ scale
+ @ rearrange(scale, "... i j -> ... j i")
+ @ rearrange(rotation, "... i j -> ... j i")
+ )
+
+
+@dataclass
+class Config:
+ # Disable viewer
+ disable_viewer: bool = True
+ # Path to the .pt files. If provide, it will skip training and run evaluation only.
+ ckpt: Optional[List[str]] = None
+ # Name of compression strategy to use
+ compression: Optional[Literal["png"]] = None
+ # Render trajectory path
+ render_traj_path: str = "interp"
+
+ data_dir: str = "data/360_v2/garden"
+ # Downsample factor for the dataset
+ data_factor: int = 4
+ # Directory to save results
+ result_dir: str = "results/garden"
+ # Every N images there is a test image
+ test_every: int = 8
+ # Random crop size for training (experimental)
+ patch_size: Optional[int] = None
+ # A global scaler that applies to the scene size related parameters
+ global_scale: float = 1.0
+ # Normalize the world space
+ normalize_world_space: bool = True
+ # Camera model
+ camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
+
+ # Port for the viewer server
+ port: int = 8080
+
+ # Batch size for training. Learning rates are scaled automatically
+ batch_size: int = 1
+ # A global factor to scale the number of training steps
+ steps_scaler: float = 1.0
+
+ # Number of training steps
+ max_steps: int = 3_000
+ # Steps to evaluate the model
+ eval_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
+ # Steps to save the model
+ save_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
+ # Whether to save ply file (storage size can be large)
+ save_ply: bool = False
+ # Steps to save the model as ply
+ ply_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
+ # Whether to disable video generation during training and evaluation
+ disable_video: bool = False
+
+ # Initialization strategy
+ init_type: str = "sfm"
+ # Initial number of GSs. Ignored if using sfm
+ init_num_pts: int = 100_000
+ # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
+ init_extent: float = 3.0
+ # Degree of spherical harmonics
+ sh_degree: int = 4
+ # Turn on another SH degree every this steps
+ sh_degree_interval: int = 1000
+ # Initial opacity of GS
+ init_opa: float = 0.1
+ # Initial scale of GS
+ init_scale: float = 1.0
+ # Weight for SSIM loss
+ ssim_lambda: float = 0.2
+
+ # Near plane clipping distance
+ near_plane: float = 1e-10
+ # Far plane clipping distance
+ far_plane: float = 1e10
+
+ # Strategy for GS densification
+ strategy: Union[DefaultStrategy, MCMCStrategy] = field(
+ default_factory=DefaultStrategy
+ )
+ # Use packed mode for rasterization, this leads to less memory usage but slightly slower.
+ packed: bool = False
+ # Use sparse gradients for optimization. (experimental)
+ sparse_grad: bool = False
+ # Use visible adam from Taming 3DGS. (experimental)
+ visible_adam: bool = False
+ # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
+ antialiased: bool = False
+
+ # Use random background for training to discourage transparency
+ random_bkgd: bool = False
+
+ # Opacity regularization
+ opacity_reg: float = 0.0
+ # Scale regularization
+ scale_reg: float = 0.0
+
+ # Enable camera optimization.
+ pose_opt: bool = True
+ # Learning rate for camera optimization
+ pose_opt_lr: float = 1e-5
+ # Regularization for camera optimization as weight decay
+ pose_opt_reg: float = 1e-6
+ # Add noise to camera extrinsics. This is only to test the camera pose optimization.
+ pose_noise: float = 0.0
+
+ # Enable appearance optimization. (experimental)
+ app_opt: bool = False
+ # Appearance embedding dimension
+ app_embed_dim: int = 16
+ # Learning rate for appearance optimization
+ app_opt_lr: float = 1e-3
+ # Regularization for appearance optimization as weight decay
+ app_opt_reg: float = 1e-6
+
+ # Enable bilateral grid. (experimental)
+ use_bilateral_grid: bool = False
+ # Shape of the bilateral grid (X, Y, W)
+ bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)
+
+ # Enable depth loss. (experimental)
+ depth_loss: bool = False
+ # Weight for depth loss
+ depth_lambda: float = 1e-2
+
+ # Dump information to tensorboard every this steps
+ tb_every: int = 100
+ # Save training images to tensorboard
+ tb_save_image: bool = False
+
+ lpips_net: Literal["vgg", "alex"] = "vgg"
+
+ lr_means: float = 1.6e-4
+ lr_scales: float = 5e-3
+ lr_quats: float = 1e-3
+ lr_opacities: float = 5e-2
+ lr_sh: float = 2.5e-3
+
+ def adjust_steps(self, factor: float):
+ self.eval_steps = [int(i * factor) for i in self.eval_steps]
+ self.save_steps = [int(i * factor) for i in self.save_steps]
+ self.ply_steps = [int(i * factor) for i in self.ply_steps]
+ self.max_steps = int(self.max_steps * factor)
+ self.sh_degree_interval = int(self.sh_degree_interval * factor)
+
+ strategy = self.strategy
+ if isinstance(strategy, DefaultStrategy):
+ # strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
+ # strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
+ # strategy.reset_every = int(strategy.reset_every * factor)
+ # strategy.refine_every = int(strategy.refine_every * factor)
+
+ strategy.refine_start_iter = 30000
+ strategy.refine_stop_iter = 0
+ strategy.reset_every = 30000
+ strategy.refine_every = 30000
+
+ elif isinstance(strategy, MCMCStrategy):
+ strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
+ strategy.refine_every = int(strategy.refine_every * factor)
+ else:
+ assert_never(strategy)
+
+
+def create_splats_with_optimizers(
+ gaussians: Gaussians,
+ init_num_pts: int = 100_000,
+ init_extent: float = 3.0,
+ init_opacity: float = 0.1,
+ init_scale: float = 1.0,
+ sh_degree: int = 3,
+ sparse_grad: bool = False,
+ visible_adam: bool = False,
+ batch_size: int = 1,
+ feature_dim: Optional[int] = None,
+ device: str = "cuda",
+ world_rank: int = 0,
+ world_size: int = 1,
+ cfg: Config = None,
+) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
+
+ points = gaussians.means[0].detach().float()
+ scales = torch.log(gaussians.scales[0].detach().float())
+ quats = gaussians.rotations[0].detach().float()
+ opacities = torch.logit(gaussians.opacities[0].detach().float())
+ harmonics = gaussians.harmonics[0].detach().float().permute(0, 2, 1).contiguous()
+
+ N = points.shape[0]
+
+ scene_scale = 1.0
+ masks = opacities.sigmoid() > 0.01
+ harmonics = harmonics[masks]
+ params = [
+ # name, value, lr
+ ("means", torch.nn.Parameter(points[masks]), cfg.lr_means * scene_scale),
+ ("scales", torch.nn.Parameter(scales[masks]), cfg.lr_scales),
+ ("quats", torch.nn.Parameter(quats[masks]), cfg.lr_quats),
+ ("opacities", torch.nn.Parameter(opacities[masks]), cfg.lr_opacities),
+ ]
+
+ params.append(("sh0", torch.nn.Parameter(harmonics[:, :1, :]), cfg.lr_sh))
+ params.append(("shN", torch.nn.Parameter(harmonics[:, 1:, :]), cfg.lr_sh/20))
+
+ splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
+ # Scale learning rate based on batch size, reference:
+ # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
+ # Note that this would not make the training exactly equivalent, see
+ # https://arxiv.org/pdf/2402.18824v1
+ BS = batch_size * world_size
+ optimizer_class = None
+ if sparse_grad:
+ optimizer_class = torch.optim.SparseAdam
+ elif visible_adam:
+ optimizer_class = SelectiveAdam
+ else:
+ optimizer_class = torch.optim.Adam
+ optimizers = {
+ name: optimizer_class(
+ [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
+ eps=1e-15 / math.sqrt(BS),
+ # TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
+ )
+ for name, _, lr in params
+ }
+ return splats, optimizers
+
+
+class Runner:
+ """Engine for training and testing."""
+
+ def __init__(
+ self, local_rank: int, world_rank, world_size: int, cfg: Config
+ ) -> None:
+ set_random_seed(42 + local_rank)
+
+ self.cfg = cfg
+ self.world_rank = world_rank
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.device = f"cuda:{local_rank}"
+
+ # Where to dump results.
+ os.makedirs(cfg.result_dir, exist_ok=True)
+
+ # Setup output directories.
+ self.ckpt_dir = f"{cfg.result_dir}/ckpts"
+ os.makedirs(self.ckpt_dir, exist_ok=True)
+ self.stats_dir = f"{cfg.result_dir}/stats"
+ os.makedirs(self.stats_dir, exist_ok=True)
+ self.render_dir = f"{cfg.result_dir}/renders"
+ os.makedirs(self.render_dir, exist_ok=True)
+ self.ply_dir = f"{cfg.result_dir}/ply"
+ os.makedirs(self.ply_dir, exist_ok=True)
+
+ # Tensorboard
+ self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
+
+ # first get the initial 3DGS and camera poses
+ model = AnySplat.from_pretrained("lhjiang/anysplat")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ image_folder = cfg.data_dir
+ image_names = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+ images = [process_image(img_path) for img_path in image_names]
+ ctx_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every != 0]
+ tgt_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every == 0]
+
+ ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
+ tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
+ ctx_images = (ctx_images+1)*0.5
+ tgt_images = (tgt_images+1)*0.5
+ b, v, _, h, w = tgt_images.shape
+
+ # run inference
+ encoder_output = model.encoder(
+ ctx_images,
+ global_step=0,
+ visualization_dump={},
+ )
+ gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
+
+ num_context_view = ctx_images.shape[1]
+ vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
+ aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(vggt_input_image, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
+ with torch.cuda.amp.autocast(enabled=False):
+ fp32_tokens = [token.float() for token in aggregated_tokens_list]
+ pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
+ pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, vggt_input_image.shape[-2:])
+
+ extrinsic_padding = torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype).view(1, 1, 1, 4).repeat(b, vggt_input_image.shape[1], 1, 1)
+ pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()
+
+ pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
+ pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h
+ pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:]
+ pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:]
+
+ scale_factor = pred_context_pose['extrinsic'][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean()
+ pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
+ pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
+ print("scale_factor:", scale_factor)
+
+ # Load data: Training data should contain initial points and colors.
+ # self.parser = Parser(
+ # data_dir=cfg.data_dir,
+ # factor=cfg.data_factor,
+ # normalize=cfg.normalize_world_space,
+ # test_every=cfg.test_every,
+ # )
+ self.trainset = Dataset(
+ split="train",
+ images=ctx_images[0].detach().cpu().numpy(),
+ camtoworlds=pred_all_context_extrinsic[0].detach().cpu().numpy(),
+ Ks=pred_all_context_intrinsic[0].detach().cpu().numpy(),
+ patch_size=cfg.patch_size,
+ load_depths=cfg.depth_loss,
+ )
+ self.valset = Dataset(
+ images=tgt_images[0].detach().cpu().numpy(),
+ camtoworlds=pred_all_target_extrinsic[0].detach().cpu().numpy(),
+ Ks=pred_all_target_intrinsic[0].detach().cpu().numpy(),
+ split="val"
+ )
+
+ # Model
+ feature_dim = 32 if cfg.app_opt else None
+ self.splats, self.optimizers = create_splats_with_optimizers(
+ gaussians=gaussians,
+ init_num_pts=cfg.init_num_pts,
+ init_extent=cfg.init_extent,
+ init_opacity=cfg.init_opa,
+ init_scale=cfg.init_scale,
+ sh_degree=cfg.sh_degree,
+ sparse_grad=cfg.sparse_grad,
+ visible_adam=cfg.visible_adam,
+ batch_size=cfg.batch_size,
+ feature_dim=feature_dim,
+ device=self.device,
+ world_rank=world_rank,
+ world_size=world_size,
+ cfg=cfg,
+ )
+ print("Model initialized. Number of GS:", len(self.splats["means"]))
+
+ # Densification Strategy
+ self.cfg.strategy.check_sanity(self.splats, self.optimizers)
+
+ if isinstance(self.cfg.strategy, DefaultStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state(
+ scene_scale=1.0
+ )
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state()
+ else:
+ assert_never(self.cfg.strategy)
+
+ # Compression Strategy
+ self.compression_method = None
+ if cfg.compression is not None:
+ if cfg.compression == "png":
+ self.compression_method = PngCompression()
+ else:
+ raise ValueError(f"Unknown compression strategy: {cfg.compression}")
+
+ self.pose_optimizers = []
+ if cfg.pose_opt:
+ self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_adjust.zero_init()
+ self.pose_optimizers = [
+ torch.optim.Adam(
+ self.pose_adjust.parameters(),
+ lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size),
+ weight_decay=cfg.pose_opt_reg,
+ )
+ ]
+ if world_size > 1:
+ self.pose_adjust = DDP(self.pose_adjust)
+
+ if cfg.pose_noise > 0.0:
+ self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_perturb.random_init(cfg.pose_noise)
+ if world_size > 1:
+ self.pose_perturb = DDP(self.pose_perturb)
+
+ self.app_optimizers = []
+ if cfg.app_opt:
+ assert feature_dim is not None
+ self.app_module = AppearanceOptModule(
+ len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
+ ).to(self.device)
+ # initialize the last layer to be zero so that the initial output is zero.
+ torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
+ torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
+ self.app_optimizers = [
+ torch.optim.Adam(
+ self.app_module.embeds.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
+ weight_decay=cfg.app_opt_reg,
+ ),
+ torch.optim.Adam(
+ self.app_module.color_head.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
+ ),
+ ]
+ if world_size > 1:
+ self.app_module = DDP(self.app_module)
+
+ self.bil_grid_optimizers = []
+ if cfg.use_bilateral_grid:
+ self.bil_grids = BilateralGrid(
+ len(self.trainset),
+ grid_X=cfg.bilateral_grid_shape[0],
+ grid_Y=cfg.bilateral_grid_shape[1],
+ grid_W=cfg.bilateral_grid_shape[2],
+ ).to(self.device)
+ self.bil_grid_optimizers = [
+ torch.optim.Adam(
+ self.bil_grids.parameters(),
+ lr=2e-3 * math.sqrt(cfg.batch_size),
+ eps=1e-15,
+ ),
+ ]
+
+ # Losses & Metrics.
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
+ self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
+
+ if cfg.lpips_net == "alex":
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="alex", normalize=True
+ ).to(self.device)
+ elif cfg.lpips_net == "vgg":
+ # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="vgg", normalize=False
+ ).to(self.device)
+ else:
+ raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
+
+ # Viewer
+ if not self.cfg.disable_viewer:
+ self.server = viser.ViserServer(port=cfg.port, verbose=False)
+ self.viewer = GsplatViewer(
+ server=self.server,
+ render_fn=self._viewer_render_fn,
+ output_dir=Path(cfg.result_dir),
+ mode="training",
+ )
+
+ def rasterize_splats(
+ self,
+ camtoworlds: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ masks: Optional[Tensor] = None,
+ rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
+ camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Dict]:
+ means = self.splats["means"] # [N, 3]
+ # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
+ # rasterization does normalization internally
+ quats = self.splats["quats"] # [N, 4]
+ scales = torch.exp(self.splats["scales"]) # [N, 3]
+ opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
+
+ image_ids = kwargs.pop("image_ids", None)
+ if self.cfg.app_opt:
+ colors = self.app_module(
+ features=self.splats["features"],
+ embed_ids=image_ids,
+ dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
+ sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
+ )
+ colors = colors + self.splats["colors"]
+ colors = torch.sigmoid(colors)
+ else:
+ colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3]
+
+ if rasterize_mode is None:
+ rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
+ if camera_model is None:
+ camera_model = self.cfg.camera_model
+
+ # covariance = build_covariance(scales[None], quats[None]).squeeze(0)
+ render_colors, render_alphas, info = rasterization(
+ means=means,
+ quats=quats,
+ scales=scales,
+ opacities=opacities,
+ colors=colors,
+ # covars=covariance,
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
+ Ks=Ks, # [C, 3, 3]
+ width=width,
+ height=height,
+ packed=self.cfg.packed,
+ absgrad=(
+ self.cfg.strategy.absgrad
+ if isinstance(self.cfg.strategy, DefaultStrategy)
+ else False
+ ),
+ sparse_grad=self.cfg.sparse_grad,
+ rasterize_mode=rasterize_mode,
+ distributed=self.world_size > 1,
+ camera_model=self.cfg.camera_model,
+ radius_clip=0.1,
+ backgrounds=torch.tensor([0.0, 0.0, 0.0]).cuda().unsqueeze(0).repeat(1, 1),
+ **kwargs,
+ )
+ if masks is not None:
+ render_colors[~masks] = 0
+ return render_colors, render_alphas, info
+
+ def train(self):
+ cfg = self.cfg
+ device = self.device
+ world_rank = self.world_rank
+ world_size = self.world_size
+
+ # Dump cfg.
+ if world_rank == 0:
+ with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
+ yaml.dump(vars(cfg), f)
+
+ max_steps = cfg.max_steps
+ init_step = 0
+
+ schedulers = [
+ # means has a learning rate schedule, that end at 0.01 of the initial value
+ torch.optim.lr_scheduler.ExponentialLR(
+ self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
+ ),
+ ]
+ if cfg.pose_opt:
+ # pose optimization has a learning rate schedule
+ schedulers.append(
+ torch.optim.lr_scheduler.ExponentialLR(
+ self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps)
+ )
+ )
+ if cfg.use_bilateral_grid:
+ # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps.
+ schedulers.append(
+ torch.optim.lr_scheduler.ChainedScheduler(
+ [
+ torch.optim.lr_scheduler.LinearLR(
+ self.bil_grid_optimizers[0],
+ start_factor=0.01,
+ total_iters=1000,
+ ),
+ torch.optim.lr_scheduler.ExponentialLR(
+ self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps)
+ ),
+ ]
+ )
+ )
+
+ trainloader = torch.utils.data.DataLoader(
+ self.trainset,
+ batch_size=cfg.batch_size,
+ shuffle=True,
+ num_workers=4,
+ persistent_workers=True,
+ pin_memory=True,
+ )
+ trainloader_iter = iter(trainloader)
+
+ # Training loop.
+ global_tic = time.time()
+ pbar = tqdm.tqdm(range(init_step, max_steps))
+ for step in pbar:
+ if not cfg.disable_viewer:
+ while self.viewer.state == "paused":
+ time.sleep(0.01)
+ self.viewer.lock.acquire()
+ tic = time.time()
+
+ try:
+ data = next(trainloader_iter)
+ except StopIteration:
+ trainloader_iter = iter(trainloader)
+ data = next(trainloader_iter)
+
+ camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4]
+ Ks = data["K"].to(device) # [1, 3, 3]
+ pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
+ num_train_rays_per_step = (
+ pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
+ )
+ image_ids = data["image_id"].to(device)
+ masks = data["mask"].to(device) if "mask" in data else None # [1, H, W]
+ if cfg.depth_loss:
+ points = data["points"].to(device) # [1, M, 2]
+ depths_gt = data["depths"].to(device) # [1, M]
+
+ height, width = pixels.shape[1:3]
+
+ if cfg.pose_noise:
+ camtoworlds = self.pose_perturb(camtoworlds, image_ids)
+
+ if cfg.pose_opt:
+ camtoworlds = self.pose_adjust(camtoworlds, image_ids)
+
+ # sh schedule
+ # sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree)
+ sh_degree_to_use = cfg.sh_degree
+
+ # forward
+ renders, alphas, info = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=sh_degree_to_use,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ image_ids=image_ids,
+ render_mode="RGB+ED" if cfg.depth_loss else "RGB",
+ masks=masks,
+ )
+ if renders.shape[-1] == 4:
+ colors, depths = renders[..., 0:3], renders[..., 3:4]
+ else:
+ colors, depths = renders, None
+
+ if cfg.use_bilateral_grid:
+ grid_y, grid_x = torch.meshgrid(
+ (torch.arange(height, device=self.device) + 0.5) / height,
+ (torch.arange(width, device=self.device) + 0.5) / width,
+ indexing="ij",
+ )
+ grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
+ colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"]
+
+ if cfg.random_bkgd:
+ bkgd = torch.rand(1, 3, device=device)
+ colors = colors + bkgd * (1.0 - alphas)
+
+ self.cfg.strategy.step_pre_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ )
+
+ # loss
+ l1loss = F.l1_loss(colors, pixels)
+ ssimloss = 1.0 - fused_ssim(
+ colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid"
+ )
+ loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
+ if cfg.depth_loss:
+ # query depths from depth map
+ points = torch.stack(
+ [
+ points[:, :, 0] / (width - 1) * 2 - 1,
+ points[:, :, 1] / (height - 1) * 2 - 1,
+ ],
+ dim=-1,
+ ) # normalize to [-1, 1]
+ grid = points.unsqueeze(2) # [1, M, 1, 2]
+ depths = F.grid_sample(
+ depths.permute(0, 3, 1, 2), grid, align_corners=True
+ ) # [1, 1, M, 1]
+ depths = depths.squeeze(3).squeeze(1) # [1, M]
+ # calculate loss in disparity space
+ disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths))
+ disp_gt = 1.0 / depths_gt # [1, M]
+ depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
+ loss += depthloss * cfg.depth_lambda
+ if cfg.use_bilateral_grid:
+ tvloss = 10 * total_variation_loss(self.bil_grids.grids)
+ loss += tvloss
+
+ # regularizations
+ if cfg.opacity_reg > 0.0:
+ loss = (
+ loss
+ + cfg.opacity_reg
+ * torch.abs(torch.sigmoid(self.splats["opacities"])).mean()
+ )
+ if cfg.scale_reg > 0.0:
+ loss = (
+ loss
+ + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean()
+ )
+
+ loss.backward()
+
+ desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
+ if cfg.depth_loss:
+ desc += f"depth loss={depthloss.item():.6f}| "
+ if cfg.pose_opt and cfg.pose_noise:
+ # monitor the pose error if we inject noise
+ pose_err = F.l1_loss(camtoworlds_gt, camtoworlds)
+ desc += f"pose err={pose_err.item():.6f}| "
+ pbar.set_description(desc)
+
+ # write images (gt and render)
+ # if world_rank == 0 and step % 800 == 0:
+ # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
+ # canvas = canvas.reshape(-1, *canvas.shape[2:])
+ # imageio.imwrite(
+ # f"{self.render_dir}/train_rank{self.world_rank}.png",
+ # (canvas * 255).astype(np.uint8),
+ # )
+
+ if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0:
+ mem = torch.cuda.max_memory_allocated() / 1024**3
+ self.writer.add_scalar("train/loss", loss.item(), step)
+ self.writer.add_scalar("train/l1loss", l1loss.item(), step)
+ self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
+ self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step)
+ self.writer.add_scalar("train/mem", mem, step)
+ if cfg.depth_loss:
+ self.writer.add_scalar("train/depthloss", depthloss.item(), step)
+ if cfg.use_bilateral_grid:
+ self.writer.add_scalar("train/tvloss", tvloss.item(), step)
+ if cfg.tb_save_image:
+ canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
+ canvas = canvas.reshape(-1, *canvas.shape[2:])
+ self.writer.add_image("train/render", canvas, step)
+ self.writer.flush()
+
+ # save checkpoint before updating the model
+ if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1:
+ mem = torch.cuda.max_memory_allocated() / 1024**3
+ stats = {
+ "mem": mem,
+ "ellipse_time": time.time() - global_tic,
+ "num_GS": len(self.splats["means"]),
+ }
+ print("Step: ", step, stats)
+ with open(
+ f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
+ "w",
+ ) as f:
+ json.dump(stats, f)
+ data = {"step": step, "splats": self.splats.state_dict()}
+ if cfg.pose_opt:
+ if world_size > 1:
+ data["pose_adjust"] = self.pose_adjust.module.state_dict()
+ else:
+ data["pose_adjust"] = self.pose_adjust.state_dict()
+ if cfg.app_opt:
+ if world_size > 1:
+ data["app_module"] = self.app_module.module.state_dict()
+ else:
+ data["app_module"] = self.app_module.state_dict()
+ torch.save(
+ data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
+ )
+ if (
+ step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
+ ) and cfg.save_ply:
+
+ if self.cfg.app_opt:
+ # eval at origin to bake the appeareance into the colors
+ rgb = self.app_module(
+ features=self.splats["features"],
+ embed_ids=None,
+ dirs=torch.zeros_like(self.splats["means"][None, :, :]),
+ sh_degree=sh_degree_to_use,
+ )
+ rgb = rgb + self.splats["colors"]
+ rgb = torch.sigmoid(rgb).squeeze(0).unsqueeze(1)
+ sh0 = rgb_to_sh(rgb)
+ shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device)
+ else:
+ sh0 = self.splats["sh0"]
+ shN = self.splats["shN"]
+ # shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device)
+
+ means = self.splats["means"]
+ scales = self.splats["scales"]
+ quats = self.splats["quats"]
+ opacities = self.splats["opacities"]
+
+ # export_splats(
+ # means=means,
+ # scales=scales,
+ # quats=quats,
+ # opacities=opacities,
+ # sh0=sh0,
+ # shN=shN,
+ # format="ply",
+ # save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
+ # )
+ export_ply(
+ means=means,
+ scales=scales,
+ rotations=quats,
+ harmonics=torch.cat([sh0, shN], dim=1).permute(0, 2, 1),
+ opacities=opacities.sigmoid(),
+ path=Path(f"{self.ply_dir}/point_cloud_{step}.ply"),
+ )
+
+ # Turn Gradients into Sparse Tensor before running optimizer
+ if cfg.sparse_grad:
+ assert cfg.packed, "Sparse gradients only work with packed mode."
+ gaussian_ids = info["gaussian_ids"]
+ for k in self.splats.keys():
+ grad = self.splats[k].grad
+ if grad is None or grad.is_sparse:
+ continue
+ self.splats[k].grad = torch.sparse_coo_tensor(
+ indices=gaussian_ids[None], # [1, nnz]
+ values=grad[gaussian_ids], # [nnz, ...]
+ size=self.splats[k].size(), # [N, ...]
+ is_coalesced=len(Ks) == 1,
+ )
+
+ if cfg.visible_adam:
+ gaussian_cnt = self.splats.means.shape[0]
+ if cfg.packed:
+ visibility_mask = torch.zeros_like(
+ self.splats["opacities"], dtype=bool
+ )
+ visibility_mask.scatter_(0, info["gaussian_ids"], 1)
+ else:
+ visibility_mask = (info["radii"] > 0).all(-1).any(0)
+
+ # optimize
+ for optimizer in self.optimizers.values():
+ if cfg.visible_adam:
+ optimizer.step(visibility_mask)
+ else:
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for optimizer in self.pose_optimizers:
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for optimizer in self.app_optimizers:
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for optimizer in self.bil_grid_optimizers:
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for scheduler in schedulers:
+ scheduler.step()
+
+ # Run post-backward steps after backward and optimizer
+ if isinstance(self.cfg.strategy, DefaultStrategy):
+ self.cfg.strategy.step_post_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ packed=cfg.packed,
+ )
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
+ self.cfg.strategy.step_post_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ lr=schedulers[0].get_last_lr()[0],
+ )
+ else:
+ assert_never(self.cfg.strategy)
+
+ # eval the full set
+ if step in [i - 1 for i in cfg.eval_steps]:
+ self.eval(step)
+ # self.render_traj(step)
+
+ # run compression
+ if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
+ self.run_compression(step=step)
+
+ if not cfg.disable_viewer:
+ self.viewer.lock.release()
+ num_train_steps_per_sec = 1.0 / (time.time() - tic)
+ num_train_rays_per_sec = (
+ num_train_rays_per_step * num_train_steps_per_sec
+ )
+ # Update the viewer state.
+ self.viewer.render_tab_state.num_train_rays_per_sec = (
+ num_train_rays_per_sec
+ )
+ # Update the scene.
+ self.viewer.update(step, num_train_rays_per_step)
+
+ @torch.no_grad()
+ def eval(self, step: int, stage: str = "val"):
+ """Entry for evaluation."""
+ print("Running evaluation...")
+ cfg = self.cfg
+ device = self.device
+ world_rank = self.world_rank
+ world_size = self.world_size
+
+ valloader = torch.utils.data.DataLoader(
+ self.valset, batch_size=1, shuffle=False, num_workers=1
+ )
+ ellipse_time = 0
+ metrics = defaultdict(list)
+ for i, data in enumerate(valloader):
+ camtoworlds = data["camtoworld"].to(device)
+ Ks = data["K"].to(device)
+ pixels = data["image"].to(device) / 255.0
+ masks = data["mask"].to(device) if "mask" in data else None
+ height, width = pixels.shape[1:3]
+
+ torch.cuda.synchronize()
+ tic = time.time()
+ render_colors, _, _ = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=cfg.sh_degree,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ # radius_clip=0.1,
+ render_mode="RGB+ED",
+ masks=masks,
+ ) # [1, H, W, 3]
+ torch.cuda.synchronize()
+ ellipse_time += time.time() - tic
+
+ colors = render_colors[..., :3]
+ depths = render_colors[..., 3]
+
+ colors = torch.clamp(colors, 0.0, 1.0)
+ canvas_list = [pixels, colors]
+
+ if world_rank == 0:
+ # write images
+ canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
+ canvas = (canvas * 255).astype(np.uint8)
+ imageio.imwrite(
+ f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
+ canvas,
+ )
+ torchvision.utils.save_image(pixels.permute(0, 3, 1, 2), f"{self.render_dir}/gt_rgb_{stage}_step{step}_{i:04d}.png")
+ torchvision.utils.save_image(colors.permute(0, 3, 1, 2), f"{self.render_dir}/render_rgb_{stage}_step{step}_{i:04d}.png")
+ # save depth & normal map
+
+
+ pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W]
+ colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W]
+
+ metrics["psnr"].append(self.psnr(colors_p, pixels_p))
+ metrics["ssim"].append(self.ssim(colors_p, pixels_p))
+ metrics["lpips"].append(self.lpips(colors_p, pixels_p))
+ if cfg.use_bilateral_grid:
+ cc_colors = color_correct(colors, pixels)
+ cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W]
+ metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p))
+
+ if world_rank == 0:
+ ellipse_time /= len(valloader)
+
+ stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()}
+ stats.update(
+ {
+ "ellipse_time": ellipse_time,
+ "num_GS": len(self.splats["means"]),
+ }
+ )
+ print(
+ f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
+ f"Time: {stats['ellipse_time']:.3f}s/image "
+ f"Number of GS: {stats['num_GS']}"
+ )
+ # save stats as json
+ with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f:
+ json.dump(stats, f)
+ # save stats to tensorboard
+ for k, v in stats.items():
+ self.writer.add_scalar(f"{stage}/{k}", v, step)
+ self.writer.flush()
+
+ @torch.no_grad()
+ def render_traj(self, step: int):
+ """Entry for trajectory rendering."""
+ if self.cfg.disable_video:
+ return
+ print("Running trajectory rendering...")
+ cfg = self.cfg
+ device = self.device
+
+ camtoworlds_all = self.parser.camtoworlds[5:-5]
+ if cfg.render_traj_path == "interp":
+ camtoworlds_all = generate_interpolated_path(
+ camtoworlds_all, 1
+ ) # [N, 3, 4]
+ elif cfg.render_traj_path == "ellipse":
+ height = camtoworlds_all[:, 2, 3].mean()
+ camtoworlds_all = generate_ellipse_path_z(
+ camtoworlds_all, height=height
+ ) # [N, 3, 4]
+ elif cfg.render_traj_path == "spiral":
+ camtoworlds_all = generate_spiral_path(
+ camtoworlds_all,
+ bounds=self.parser.bounds * self.scene_scale,
+ spiral_scale_r=self.parser.extconf["spiral_radius_scale"],
+ )
+ else:
+ raise ValueError(
+ f"Render trajectory type not supported: {cfg.render_traj_path}"
+ )
+
+ camtoworlds_all = np.concatenate(
+ [
+ camtoworlds_all,
+ np.repeat(
+ np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
+ ),
+ ],
+ axis=1,
+ ) # [N, 4, 4]
+
+ camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
+ K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
+ width, height = list(self.parser.imsize_dict.values())[0]
+
+ # save to video
+ video_dir = f"{cfg.result_dir}/videos"
+ os.makedirs(video_dir, exist_ok=True)
+ writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
+ for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
+ camtoworlds = camtoworlds_all[i : i + 1]
+ Ks = K[None]
+
+ renders, _, _ = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=cfg.sh_degree,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ render_mode="RGB+ED",
+ ) # [1, H, W, 4]
+ colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3]
+ depths = renders[..., 3:4] # [1, H, W, 1]
+ depths = (depths - depths.min()) / (depths.max() - depths.min())
+ canvas_list = [colors, depths.repeat(1, 1, 1, 3)]
+
+ # write images
+ canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
+ canvas = (canvas * 255).astype(np.uint8)
+ writer.append_data(canvas)
+ writer.close()
+ print(f"Video saved to {video_dir}/traj_{step}.mp4")
+
+ @torch.no_grad()
+ def run_compression(self, step: int):
+ """Entry for running compression."""
+ print("Running compression...")
+ world_rank = self.world_rank
+
+ compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}"
+ os.makedirs(compress_dir, exist_ok=True)
+
+ self.compression_method.compress(compress_dir, self.splats)
+
+ # evaluate compression
+ splats_c = self.compression_method.decompress(compress_dir)
+ for k in splats_c.keys():
+ self.splats[k].data = splats_c[k].to(self.device)
+ self.eval(step=step, stage="compress")
+
+ @torch.no_grad()
+ def _viewer_render_fn(
+ self, camera_state: CameraState, render_tab_state: RenderTabState
+ ):
+ assert isinstance(render_tab_state, GsplatRenderTabState)
+ if render_tab_state.preview_render:
+ width = render_tab_state.render_width
+ height = render_tab_state.render_height
+ else:
+ width = render_tab_state.viewer_width
+ height = render_tab_state.viewer_height
+ c2w = camera_state.c2w
+ K = camera_state.get_K((width, height))
+ c2w = torch.from_numpy(c2w).float().to(self.device)
+ K = torch.from_numpy(K).float().to(self.device)
+
+ RENDER_MODE_MAP = {
+ "rgb": "RGB",
+ "depth(accumulated)": "D",
+ "depth(expected)": "ED",
+ "alpha": "RGB",
+ }
+
+ render_colors, render_alphas, info = self.rasterize_splats(
+ camtoworlds=c2w[None],
+ Ks=K[None],
+ width=width,
+ height=height,
+ sh_degree=min(render_tab_state.max_sh_degree, self.cfg.sh_degree),
+ near_plane=render_tab_state.near_plane,
+ far_plane=render_tab_state.far_plane,
+ radius_clip=render_tab_state.radius_clip,
+ # radius_clip=0.1,
+ eps2d=render_tab_state.eps2d,
+ backgrounds=torch.tensor([render_tab_state.backgrounds], device=self.device)
+ / 255.0,
+ render_mode=RENDER_MODE_MAP[render_tab_state.render_mode],
+ rasterize_mode=render_tab_state.rasterize_mode,
+ camera_model=render_tab_state.camera_model,
+ ) # [1, H, W, 3]
+ render_tab_state.total_gs_count = len(self.splats["means"])
+ render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()
+
+ if render_tab_state.render_mode == "rgb":
+ # colors represented with sh are not guranteed to be in [0, 1]
+ render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
+ renders = render_colors.cpu().numpy()
+ elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
+ # normalize depth to [0, 1]
+ depth = render_colors[0, ..., 0:1]
+ if render_tab_state.normalize_nearfar:
+ near_plane = render_tab_state.near_plane
+ far_plane = render_tab_state.far_plane
+ else:
+ near_plane = depth.min()
+ far_plane = depth.max()
+ depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10)
+ depth_norm = torch.clip(depth_norm, 0, 1)
+ if render_tab_state.inverse:
+ depth_norm = 1 - depth_norm
+ renders = (
+ apply_float_colormap(depth_norm, render_tab_state.colormap)
+ .cpu()
+ .numpy()
+ )
+ elif render_tab_state.render_mode == "alpha":
+ alpha = render_alphas[0, ..., 0:1]
+ if render_tab_state.inverse:
+ alpha = 1 - alpha
+ renders = (
+ apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()
+ )
+ return renders
+
+
+def main(local_rank: int, world_rank, world_size: int, cfg: Config):
+ if world_size > 1 and not cfg.disable_viewer:
+ cfg.disable_viewer = True
+ if world_rank == 0:
+ print("Viewer is disabled in distributed training.")
+
+ runner = Runner(local_rank, world_rank, world_size, cfg)
+
+ if cfg.ckpt is not None:
+ # run eval only
+ ckpts = [
+ torch.load(file, map_location=runner.device, weights_only=True)
+ for file in cfg.ckpt
+ ]
+ for k in runner.splats.keys():
+ runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
+ step = ckpts[0]["step"]
+ runner.eval(step=step)
+ # runner.render_traj(step=step)
+ if cfg.compression is not None:
+ runner.run_compression(step=step)
+ else:
+ runner.train()
+ runner.eval(step=runner.cfg.max_steps)
+ # runner.render_traj(step=runner.cfg.max_steps)
+ print("Training complete.")
+ # runner.viewer.complete()
+ # if not cfg.disable_viewer:
+ # print("Viewer running... Ctrl+C to exit.")
+ # time.sleep(1000000)
+
+
+if __name__ == "__main__":
+ """
+ Usage:
+
+ ```bash
+ # Single GPU training
+ CUDA_VISIBLE_DEVICES=9 python -m examples.simple_trainer default
+
+ # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps.
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25
+
+ """
+
+ # Config objects we can choose between.
+ # Each is a tuple of (CLI description, config object).
+ configs = {
+ "default": (
+ "Gaussian splatting training using densification heuristics from the original paper.",
+ Config(
+ strategy=DefaultStrategy(verbose=True),
+ ),
+ ),
+ "mcmc": (
+ "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
+ Config(
+ init_opa=0.5,
+ init_scale=0.1,
+ opacity_reg=0.01,
+ scale_reg=0.01,
+ strategy=MCMCStrategy(verbose=True),
+ ),
+ ),
+ }
+ cfg = tyro.extras.overridable_config_cli(configs)
+ cfg.adjust_steps(cfg.steps_scaler)
+
+ # try import extra dependencies
+ if cfg.compression == "png":
+ try:
+ import plas
+ import torchpq
+ except:
+ raise ImportError(
+ "To use PNG compression, you need to install "
+ "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) "
+ "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') "
+ )
+
+ cli(main, cfg, verbose=True)
diff --git a/src/post_opt/simple_viewer.py b/src/post_opt/simple_viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd1191e82d4660194ba6d107b75a9e0c18dcfc06
--- /dev/null
+++ b/src/post_opt/simple_viewer.py
@@ -0,0 +1,275 @@
+import argparse
+import math
+import os
+import time
+
+import imageio
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+import viser
+from pathlib import Path
+from gsplat._helper import load_test_data
+from gsplat.distributed import cli
+from gsplat.rendering import rasterization
+
+from nerfview import CameraState, RenderTabState, apply_float_colormap
+from examples.gsplat_viewer import GsplatViewer, GsplatRenderTabState
+
+
+def main(local_rank: int, world_rank, world_size: int, args):
+ torch.manual_seed(42)
+ device = torch.device("cuda", local_rank)
+
+ if args.ckpt is None:
+ (
+ means,
+ quats,
+ scales,
+ opacities,
+ colors,
+ viewmats,
+ Ks,
+ width,
+ height,
+ ) = load_test_data(device=device, scene_grid=args.scene_grid)
+
+ assert world_size <= 2
+ means = means[world_rank::world_size].contiguous()
+ means.requires_grad = True
+ quats = quats[world_rank::world_size].contiguous()
+ quats.requires_grad = True
+ scales = scales[world_rank::world_size].contiguous()
+ scales.requires_grad = True
+ opacities = opacities[world_rank::world_size].contiguous()
+ opacities.requires_grad = True
+ colors = colors[world_rank::world_size].contiguous()
+ colors.requires_grad = True
+
+ viewmats = viewmats[world_rank::world_size][:1].contiguous()
+ Ks = Ks[world_rank::world_size][:1].contiguous()
+
+ sh_degree = None
+ C = len(viewmats)
+ N = len(means)
+ print("rank", world_rank, "Number of Gaussians:", N, "Number of Cameras:", C)
+
+ # batched render
+ for _ in tqdm.trange(1):
+ render_colors, render_alphas, meta = rasterization(
+ means, # [N, 3]
+ quats, # [N, 4]
+ scales, # [N, 3]
+ opacities, # [N]
+ colors, # [N, S, 3]
+ viewmats, # [C, 4, 4]
+ Ks, # [C, 3, 3]
+ width,
+ height,
+ render_mode="RGB+D",
+ packed=False,
+ distributed=world_size > 1,
+ )
+ C = render_colors.shape[0]
+ assert render_colors.shape == (C, height, width, 4)
+ assert render_alphas.shape == (C, height, width, 1)
+ render_colors.sum().backward()
+
+ render_rgbs = render_colors[..., 0:3]
+ render_depths = render_colors[..., 3:4]
+ render_depths = render_depths / render_depths.max()
+
+ # dump batch images
+ os.makedirs(args.output_dir, exist_ok=True)
+ canvas = (
+ torch.cat(
+ [
+ render_rgbs.reshape(C * height, width, 3),
+ render_depths.reshape(C * height, width, 1).expand(-1, -1, 3),
+ render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3),
+ ],
+ dim=1,
+ )
+ .detach()
+ .cpu()
+ .numpy()
+ )
+ imageio.imsave(
+ f"{args.output_dir}/render_rank{world_rank}.png",
+ (canvas * 255).astype(np.uint8),
+ )
+ else:
+ means, quats, scales, opacities, sh0, shN = [], [], [], [], [], []
+ for ckpt_path in args.ckpt:
+ ckpt = torch.load(ckpt_path, map_location=device)["splats"]
+ means.append(ckpt["means"])
+ quats.append(F.normalize(ckpt["quats"], p=2, dim=-1))
+ scales.append(torch.exp(ckpt["scales"]))
+ opacities.append(torch.sigmoid(ckpt["opacities"]))
+ sh0.append(ckpt["sh0"])
+ shN.append(ckpt["shN"])
+ means = torch.cat(means, dim=0)
+ quats = torch.cat(quats, dim=0)
+ scales = torch.cat(scales, dim=0)
+ opacities = torch.cat(opacities, dim=0)
+ sh0 = torch.cat(sh0, dim=0)
+ shN = torch.cat(shN, dim=0)
+ colors = torch.cat([sh0, shN], dim=-2)
+ sh_degree = int(math.sqrt(colors.shape[-2]) - 1)
+
+ # # crop
+ # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
+ # edges = aabb[3:] - aabb[:3]
+ # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
+ # sel = torch.where(sel)[0]
+ # means, quats, scales, colors, opacities = (
+ # means[sel],
+ # quats[sel],
+ # scales[sel],
+ # colors[sel],
+ # opacities[sel],
+ # )
+
+ # # repeat the scene into a grid (to mimic a large-scale setting)
+ # repeats = args.scene_grid
+ # gridx, gridy = torch.meshgrid(
+ # [
+ # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
+ # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
+ # ],
+ # indexing="ij",
+ # )
+ # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
+ # -1, 3
+ # )
+ # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
+ # means = means.reshape(-1, 3)
+ # quats = quats.repeat(repeats**2, 1)
+ # scales = scales.repeat(repeats**2, 1)
+ # colors = colors.repeat(repeats**2, 1, 1)
+ # opacities = opacities.repeat(repeats**2)
+ print("Number of Gaussians:", len(means))
+
+ # register and open viewer
+ @torch.no_grad()
+ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState):
+ assert isinstance(render_tab_state, GsplatRenderTabState)
+ if render_tab_state.preview_render:
+ width = render_tab_state.render_width
+ height = render_tab_state.render_height
+ else:
+ width = render_tab_state.viewer_width
+ height = render_tab_state.viewer_height
+ c2w = camera_state.c2w
+ K = camera_state.get_K((width, height))
+ c2w = torch.from_numpy(c2w).float().to(device)
+ K = torch.from_numpy(K).float().to(device)
+ viewmat = c2w.inverse()
+
+ RENDER_MODE_MAP = {
+ "rgb": "RGB",
+ "depth(accumulated)": "D",
+ "depth(expected)": "ED",
+ "alpha": "RGB",
+ }
+
+ render_colors, render_alphas, info = rasterization(
+ means, # [N, 3]
+ quats, # [N, 4]
+ scales, # [N, 3]
+ opacities, # [N]
+ colors, # [N, S, 3]
+ viewmat[None], # [1, 4, 4]
+ K[None], # [1, 3, 3]
+ width,
+ height,
+ sh_degree=(
+ min(render_tab_state.max_sh_degree, sh_degree)
+ if sh_degree is not None
+ else None
+ ),
+ near_plane=render_tab_state.near_plane,
+ far_plane=render_tab_state.far_plane,
+ radius_clip=render_tab_state.radius_clip,
+ eps2d=render_tab_state.eps2d,
+ backgrounds=torch.tensor([render_tab_state.backgrounds], device=device)
+ / 255.0,
+ render_mode=RENDER_MODE_MAP[render_tab_state.render_mode],
+ rasterize_mode=render_tab_state.rasterize_mode,
+ camera_model=render_tab_state.camera_model,
+ )
+ render_tab_state.total_gs_count = len(means)
+ render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()
+
+ if render_tab_state.render_mode == "rgb":
+ # colors represented with sh are not guranteed to be in [0, 1]
+ render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
+ renders = render_colors.cpu().numpy()
+ elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
+ # normalize depth to [0, 1]
+ depth = render_colors[0, ..., 0:1]
+ if render_tab_state.normalize_nearfar:
+ near_plane = render_tab_state.near_plane
+ far_plane = render_tab_state.far_plane
+ else:
+ near_plane = depth.min()
+ far_plane = depth.max()
+ depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10)
+ depth_norm = torch.clip(depth_norm, 0, 1)
+ if render_tab_state.inverse:
+ depth_norm = 1 - depth_norm
+ renders = (
+ apply_float_colormap(depth_norm, render_tab_state.colormap)
+ .cpu()
+ .numpy()
+ )
+ elif render_tab_state.render_mode == "alpha":
+ alpha = render_alphas[0, ..., 0:1]
+ if render_tab_state.inverse:
+ alpha = 1 - alpha
+ renders = (
+ apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()
+ )
+ return renders
+
+ server = viser.ViserServer(port=args.port, verbose=False)
+ _ = GsplatViewer(
+ server=server,
+ render_fn=viewer_render_fn,
+ output_dir=Path(args.output_dir),
+ mode="rendering",
+ )
+ print("Viewer running... Ctrl+C to exit.")
+ time.sleep(100000)
+
+
+if __name__ == "__main__":
+ """
+ # Use single GPU to view the scene
+ CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
+ --ckpt results/garden/ckpts/ckpt_6999_rank0.pt \
+ --output_dir results/garden/ \
+ --port 8082
+
+ CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
+ --output_dir results/garden/ \
+ --port 8082
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output_dir", type=str, default="results/", help="where to dump outputs"
+ )
+ parser.add_argument(
+ "--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN"
+ )
+ parser.add_argument(
+ "--ckpt", type=str, nargs="+", default=None, help="path to the .pt file"
+ )
+ parser.add_argument(
+ "--port", type=int, default=8080, help="port for the viewer server"
+ )
+ args = parser.parse_args()
+ assert args.scene_grid % 2 == 1, "scene_grid must be odd"
+
+ cli(main, args, verbose=True)
diff --git a/src/post_opt/utils.py b/src/post_opt/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b79f4244bd6bfcc1c8aff85fd7bfe59608410e41
--- /dev/null
+++ b/src/post_opt/utils.py
@@ -0,0 +1,224 @@
+import random
+
+import numpy as np
+import torch
+from sklearn.neighbors import NearestNeighbors
+from torch import Tensor
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from matplotlib import colormaps
+
+
+class CameraOptModule(torch.nn.Module):
+ """Camera pose optimization module."""
+
+ def __init__(self, n: int):
+ super().__init__()
+ # Delta positions (3D) + Delta rotations (6D)
+ self.embeds = torch.nn.Embedding(n, 9)
+ # Identity rotation in 6D representation
+ self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
+
+ def zero_init(self):
+ torch.nn.init.zeros_(self.embeds.weight)
+
+ def random_init(self, std: float):
+ torch.nn.init.normal_(self.embeds.weight, std=std)
+
+ def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
+ """Adjust camera pose based on deltas.
+
+ Args:
+ camtoworlds: (..., 4, 4)
+ embed_ids: (...,)
+
+ Returns:
+ updated camtoworlds: (..., 4, 4)
+ """
+ assert camtoworlds.shape[:-2] == embed_ids.shape
+ batch_shape = camtoworlds.shape[:-2]
+ pose_deltas = self.embeds(embed_ids) # (..., 9)
+ dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:]
+ rot = rotation_6d_to_matrix(
+ drot + self.identity.expand(*batch_shape, -1)
+ ) # (..., 3, 3)
+ transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1))
+ transform[..., :3, :3] = rot
+ transform[..., :3, 3] = dx
+ return torch.matmul(camtoworlds, transform)
+
+
+class AppearanceOptModule(torch.nn.Module):
+ """Appearance optimization module."""
+
+ def __init__(
+ self,
+ n: int,
+ feature_dim: int,
+ embed_dim: int = 16,
+ sh_degree: int = 3,
+ mlp_width: int = 64,
+ mlp_depth: int = 2,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.sh_degree = sh_degree
+ self.embeds = torch.nn.Embedding(n, embed_dim)
+ layers = []
+ layers.append(
+ torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width)
+ )
+ layers.append(torch.nn.ReLU(inplace=True))
+ for _ in range(mlp_depth - 1):
+ layers.append(torch.nn.Linear(mlp_width, mlp_width))
+ layers.append(torch.nn.ReLU(inplace=True))
+ layers.append(torch.nn.Linear(mlp_width, 3))
+ self.color_head = torch.nn.Sequential(*layers)
+
+ def forward(
+ self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
+ ) -> Tensor:
+ """Adjust appearance based on embeddings.
+
+ Args:
+ features: (N, feature_dim)
+ embed_ids: (C,)
+ dirs: (C, N, 3)
+
+ Returns:
+ colors: (C, N, 3)
+ """
+ from gsplat.cuda._torch_impl import _eval_sh_bases_fast
+
+ C, N = dirs.shape[:2]
+ # Camera embeddings
+ if embed_ids is None:
+ embeds = torch.zeros(C, self.embed_dim, device=features.device)
+ else:
+ embeds = self.embeds(embed_ids) # [C, D2]
+ embeds = embeds[:, None, :].expand(-1, N, -1) # [C, N, D2]
+ # GS features
+ features = features[None, :, :].expand(C, -1, -1) # [C, N, D1]
+ # View directions
+ dirs = F.normalize(dirs, dim=-1) # [C, N, 3]
+ num_bases_to_use = (sh_degree + 1) ** 2
+ num_bases = (self.sh_degree + 1) ** 2
+ sh_bases = torch.zeros(C, N, num_bases, device=features.device) # [C, N, K]
+ sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
+ # Get colors
+ if self.embed_dim > 0:
+ h = torch.cat([embeds, features, sh_bases], dim=-1) # [C, N, D1 + D2 + K]
+ else:
+ h = torch.cat([features, sh_bases], dim=-1)
+ colors = self.color_head(h)
+ return colors
+
+
+def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d.
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def knn(x: Tensor, K: int = 4) -> Tensor:
+ x_np = x.cpu().numpy()
+ model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
+ distances, _ = model.kneighbors(x_np)
+ return torch.from_numpy(distances).to(x)
+
+
+def rgb_to_sh(rgb: Tensor) -> Tensor:
+ C0 = 0.28209479177387814
+ return (rgb - 0.5) / C0
+
+
+def set_random_seed(seed: int):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+
+# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
+def colormap(img, cmap="jet"):
+ W, H = img.shape[:2]
+ dpi = 300
+ fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
+ im = ax.imshow(img, cmap=cmap)
+ ax.set_axis_off()
+ fig.colorbar(im, ax=ax)
+ fig.tight_layout()
+ fig.canvas.draw()
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ img = torch.from_numpy(data).float().permute(2, 0, 1)
+ plt.close()
+ return img
+
+
+def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
+ """Convert single channel to a color img.
+
+ Args:
+ img (torch.Tensor): (..., 1) float32 single channel image.
+ colormap (str): Colormap for img.
+
+ Returns:
+ (..., 3) colored img with colors in [0, 1].
+ """
+ img = torch.nan_to_num(img, 0)
+ if colormap == "gray":
+ return img.repeat(1, 1, 3)
+ img_long = (img * 255).long()
+ img_long_min = torch.min(img_long)
+ img_long_max = torch.max(img_long)
+ assert img_long_min >= 0, f"the min value is {img_long_min}"
+ assert img_long_max <= 255, f"the max value is {img_long_max}"
+ return torch.tensor(
+ colormaps[colormap].colors, # type: ignore
+ device=img.device,
+ )[img_long[..., 0]]
+
+
+def apply_depth_colormap(
+ depth: torch.Tensor,
+ acc: torch.Tensor = None,
+ near_plane: float = None,
+ far_plane: float = None,
+) -> torch.Tensor:
+ """Converts a depth image to color for easier analysis.
+
+ Args:
+ depth (torch.Tensor): (..., 1) float32 depth.
+ acc (torch.Tensor | None): (..., 1) optional accumulation mask.
+ near_plane: Closest depth to consider. If None, use min image value.
+ far_plane: Furthest depth to consider. If None, use max image value.
+
+ Returns:
+ (..., 3) colored depth image with colors in [0, 1].
+ """
+ near_plane = near_plane or float(torch.min(depth))
+ far_plane = far_plane or float(torch.max(depth))
+ depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
+ depth = torch.clip(depth, 0.0, 1.0)
+ img = apply_float_colormap(depth, colormap="turbo")
+ if acc is not None:
+ img = img * acc + (1.0 - acc)
+ return img
diff --git a/src/utils/__init.py b/src/utils/__init.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/utils/ba.py b/src/utils/ba.py
new file mode 100644
index 0000000000000000000000000000000000000000..6844f8fcc332c3ad484ba0816428d5471379413d
--- /dev/null
+++ b/src/utils/ba.py
@@ -0,0 +1,407 @@
+import pycolmap
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from src.model.encoder.vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
+from lightglue import ALIKED, SuperPoint, SIFT
+from src.utils.tensor_to_pycolmap import batch_matrix_to_pycolmap, pycolmap_to_batch_matrix
+
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+def generate_rank_by_dino(
+ images, query_frame_num, image_size=518, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=True
+):
+ """
+ Generate a ranking of frames using DINO ViT features.
+
+ Args:
+ images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
+ query_frame_num: Number of frames to select
+ image_size: Size to resize images to before processing
+ model_name: Name of the DINO model to use
+ device: Device to run the model on
+ spatial_similarity: Whether to use spatial token similarity or CLS token similarity
+
+ Returns:
+ List of frame indices ranked by their representativeness
+ """
+ dino_v2_model = torch.hub.load('facebookresearch/dinov2', model_name)
+ dino_v2_model.eval()
+ dino_v2_model = dino_v2_model.to(device)
+
+ resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
+ resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
+ images_resnet_norm = (images - resnet_mean) / resnet_std
+
+ with torch.no_grad():
+ frame_feat = dino_v2_model(images_resnet_norm, is_training=True)
+
+ if spatial_similarity:
+ frame_feat = frame_feat["x_norm_patchtokens"]
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
+
+ # Compute the similarity matrix
+ frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
+ similarity_matrix = torch.bmm(
+ frame_feat_norm, frame_feat_norm.transpose(-1, -2)
+ )
+ similarity_matrix = similarity_matrix.mean(dim=0)
+ else:
+ frame_feat = frame_feat["x_norm_clstoken"]
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
+ similarity_matrix = torch.mm(
+ frame_feat_norm, frame_feat_norm.transpose(-1, -2)
+ )
+
+ distance_matrix = 100 - similarity_matrix.clone()
+
+ # Ignore self-pairing
+ similarity_matrix.fill_diagonal_(-100)
+ similarity_sum = similarity_matrix.sum(dim=1)
+
+ # Find the most common frame
+ most_common_frame_index = torch.argmax(similarity_sum).item()
+
+ # Conduct FPS sampling starting from the most common frame
+ fps_idx = farthest_point_sampling(
+ distance_matrix, query_frame_num, most_common_frame_index
+ )
+
+ return fps_idx
+
+
+def farthest_point_sampling(
+ distance_matrix, num_samples, most_common_frame_index=0
+):
+ """
+ Farthest point sampling algorithm to select diverse frames.
+
+ Args:
+ distance_matrix: Matrix of distances between frames
+ num_samples: Number of frames to select
+ most_common_frame_index: Index of the first frame to select
+
+ Returns:
+ List of selected frame indices
+ """
+ distance_matrix = distance_matrix.clamp(min=0)
+ N = distance_matrix.size(0)
+
+ # Initialize with the most common frame
+ selected_indices = [most_common_frame_index]
+ check_distances = distance_matrix[selected_indices]
+
+ while len(selected_indices) < num_samples:
+ # Find the farthest point from the current set of selected points
+ farthest_point = torch.argmax(check_distances)
+ selected_indices.append(farthest_point.item())
+
+ check_distances = distance_matrix[farthest_point]
+ # Mark already selected points to avoid selecting them again
+ check_distances[selected_indices] = 0
+
+ # Break if all points have been selected
+ if len(selected_indices) == N:
+ break
+
+ return selected_indices
+
+
+def calculate_index_mappings(query_index, S, device=None):
+ """
+ Construct an order that switches [query_index] and [0]
+ so that the content of query_index would be placed at [0].
+
+ Args:
+ query_index: Index to swap with 0
+ S: Total number of elements
+ device: Device to place the tensor on
+
+ Returns:
+ Tensor of indices with the swapped order
+ """
+ new_order = torch.arange(S)
+ new_order[0] = query_index
+ new_order[query_index] = 0
+ if device is not None:
+ new_order = new_order.to(device)
+ return new_order
+
+
+def switch_tensor_order(tensors, order, dim=1):
+ """
+ Reorder tensors along a specific dimension according to the given order.
+
+ Args:
+ tensors: List of tensors to reorder
+ order: Tensor of indices specifying the new order
+ dim: Dimension along which to reorder
+
+ Returns:
+ List of reordered tensors
+ """
+ return [
+ torch.index_select(tensor, dim, order) if tensor is not None else None
+ for tensor in tensors
+ ]
+
+
+def predict_track(model, images, query_points, dtype=torch.bfloat16, use_tf32_for_track=True, iters=4):
+ """
+ Predict tracks for query points across frames.
+
+ Args:
+ model: VGGT model
+ images: Tensor of images of shape (S, 3, H, W)
+ query_points: Query points to track
+ dtype: Data type for computation
+ use_tf32_for_track: Whether to use TF32 precision for tracking
+ iters: Number of iterations for tracking
+
+ Returns:
+ Predicted tracks, visibility scores, and confidence scores
+ """
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(dtype=dtype):
+ images = images[None] # add batch dimension
+ aggregated_tokens_list, ps_idx = model.aggregator(images)
+
+ if not use_tf32_for_track:
+ track_list, vis_score, conf_score = model.track_head(
+ aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters
+ )
+
+ if use_tf32_for_track:
+ with torch.cuda.amp.autocast(enabled=False):
+ track_list, vis_score, conf_score = model.track_head(
+ aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters
+ )
+
+ pred_track = track_list[-1]
+
+ return pred_track.squeeze(0), vis_score.squeeze(0), conf_score.squeeze(0)
+
+
+def initialize_feature_extractors(max_query_num, det_thres, extractor_method="aliked", device="cuda"):
+ """
+ Initialize feature extractors that can be reused based on a method string.
+
+ Args:
+ max_query_num: Maximum number of keypoints to extract
+ det_thres: Detection threshold for keypoint extraction
+ extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
+ device: Device to run extraction on
+
+ Returns:
+ Dictionary of initialized extractors
+ """
+ extractors = {}
+ methods = extractor_method.lower().split('+')
+ active_extractors = len(methods)
+
+ for method in methods:
+ method = method.strip()
+ if method == "aliked":
+ aliked_max_points = max_query_num // active_extractors
+ aliked_extractor = ALIKED(max_num_keypoints=aliked_max_points, detection_threshold=det_thres)
+ extractors['aliked'] = aliked_extractor.to(device).eval()
+ elif method == "sp":
+ sp_max_points = max_query_num // active_extractors
+ sp_extractor = SuperPoint(max_num_keypoints=sp_max_points, detection_threshold=det_thres)
+ extractors['sp'] = sp_extractor.to(device).eval()
+ elif method == "sift":
+ sift_max_points = max_query_num // active_extractors
+ sift_extractor = SIFT(max_num_keypoints=sift_max_points)
+ extractors['sift'] = sift_extractor.to(device).eval()
+ else:
+ print(f"Warning: Unknown feature extractor '{method}', ignoring.")
+
+ if not extractors:
+ print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.")
+ aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
+ extractors['aliked'] = aliked_extractor.to(device).eval()
+
+ return extractors
+
+
+def extract_keypoints(query_image, extractors):
+ """
+ Extract keypoints using pre-initialized feature extractors.
+
+ Args:
+ query_image: Input image tensor (3xHxW, range [0, 1])
+ extractors: Dictionary of initialized extractors
+
+ Returns:
+ Tensor of keypoint coordinates (1xNx2)
+ """
+ query_points_round = None
+
+ with torch.no_grad():
+ for extractor_name, extractor in extractors.items():
+ query_points_data = extractor.extract(query_image)
+ extractor_points = query_points_data["keypoints"].round()
+
+ if query_points_round is not None:
+ query_points_round = torch.cat([query_points_round, extractor_points], dim=1)
+ else:
+ query_points_round = extractor_points
+
+ return query_points_round
+
+
+def run_vggt_with_ba(model, images, image_names=None, dtype=torch.bfloat16,
+ max_query_num=2048, det_thres=0.005, query_frame_num=3,
+ extractor_method="aliked+sp+sift",
+ max_reproj_error=4,
+ shared_camera=True,
+ ):
+ """
+ Run VGGT with bundle adjustment for pose estimation.
+
+ Args:
+ model: VGGT model
+ images: Tensor of images of shape (S, 3, H, W)
+ image_names: Optional list of image names
+ dtype: Data type for computation
+
+ Returns:
+ Predicted extrinsic camera parameters
+
+ TODO:
+ - [ ] Use VGGT's vit instead of dinov2 for rank generation
+ """
+ device = images.device
+ frame_num = images.shape[0]
+
+ # TODO: use vggt's vit instead of dinov2
+ # Select representative frames for feature extraction
+ query_frame_indexes = generate_rank_by_dino(
+ images, query_frame_num, image_size=518,
+ model_name="dinov2_vitb14_reg", device=device,
+ spatial_similarity=False
+ )
+
+ # Add the first image to the front if not already present
+ if 0 in query_frame_indexes:
+ query_frame_indexes.remove(0)
+ query_frame_indexes = [0, *query_frame_indexes]
+
+ # Get initial pose and depth predictions
+
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
+ aggregated_tokens_list, patch_start_idx = model.aggregator(images, intermediate_layer_idx=model.cfg.intermediate_layer_idx)
+ with torch.cuda.amp.autocast(enabled=False):
+ fp32_tokens = [token.float() for token in aggregated_tokens_list]
+ pred_all_pose_enc = model.camera_head(fp32_tokens)[-1]
+ pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
+ pred_extrinsic = pred_all_extrinsic[0]
+ pred_intrinsic = pred_all_intrinsic[0]
+ depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx)
+
+ world_points = unproject_depth_map_to_point_map(depth_map, pred_extrinsic, pred_intrinsic)
+ world_points = torch.from_numpy(world_points).to(device)
+ world_points_conf = depth_conf.to(device)
+
+ torch.cuda.empty_cache()
+
+ # Lists to store predictions
+ pred_tracks = []
+ pred_vis_scores = []
+ pred_conf_scores = []
+ pred_world_points = []
+ pred_world_points_conf = []
+
+ # Initialize feature extractors
+ extractors = initialize_feature_extractors(max_query_num, det_thres, extractor_method, device)
+
+ # Process each query frame
+ for query_index in query_frame_indexes:
+ query_image = images[query_index]
+ query_points_round = extract_keypoints(query_image, extractors)
+
+ # Reorder images to put query image first
+ reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
+ reorder_images = switch_tensor_order([images], reorder_index, dim=0)[0]
+
+ # Track points across frames
+ reorder_tracks, reorder_vis_score, reorder_conf_score = predict_track(
+ model, reorder_images, query_points_round, dtype=dtype, use_tf32_for_track=True, iters=4
+ )
+
+ # Restore original order
+ pred_track, pred_vis, pred_score = switch_tensor_order(
+ [reorder_tracks, reorder_vis_score, reorder_conf_score], reorder_index, dim=0
+ )
+
+ pred_tracks.append(pred_track)
+ pred_vis_scores.append(pred_vis)
+ pred_conf_scores.append(pred_score)
+
+ # Get corresponding 3D points
+ query_points_round_long = query_points_round.squeeze(0).long()
+ query_world_points = world_points[query_index][
+ query_points_round_long[:, 1], query_points_round_long[:, 0]
+ ]
+ query_world_points_conf = world_points_conf[query_index][
+ query_points_round_long[:, 1], query_points_round_long[:, 0]
+ ]
+
+ pred_world_points.append(query_world_points)
+ pred_world_points_conf.append(query_world_points_conf)
+
+ # Concatenate prediction lists
+ pred_tracks = torch.cat(pred_tracks, dim=1)
+ pred_vis_scores = torch.cat(pred_vis_scores, dim=1)
+ pred_conf_scores = torch.cat(pred_conf_scores, dim=1)
+ pred_world_points = torch.cat(pred_world_points, dim=0)
+ pred_world_points_conf = torch.cat(pred_world_points_conf, dim=0)
+
+ # Filter points by confidence
+ filtered_flag = pred_world_points_conf > 1.5
+
+ if filtered_flag.sum() > 1024:
+ # well if the number of points is too small, we will not filter
+ pred_world_points = pred_world_points[filtered_flag]
+ pred_world_points_conf = pred_world_points_conf[filtered_flag]
+
+ pred_tracks = pred_tracks[:, filtered_flag]
+ pred_vis_scores = pred_vis_scores[:, filtered_flag]
+ pred_conf_scores = pred_conf_scores[:, filtered_flag]
+
+ torch.cuda.empty_cache()
+
+ # Bundle adjustment parameters
+ S, _, H, W = images.shape
+ image_size = torch.tensor([W, H], dtype=pred_tracks.dtype, device=device)
+
+ # Run bundle adjustment
+ reconstruction = batch_matrix_to_pycolmap(
+ pred_world_points,
+ pred_extrinsic,
+ pred_intrinsic,
+ pred_tracks,
+ image_size,
+ max_reproj_error=max_reproj_error,
+ shared_camera=shared_camera
+ )
+
+ ba_options = pycolmap.BundleAdjustmentOptions()
+ pycolmap.bundle_adjustment(reconstruction, ba_options)
+ _, updated_extrinsic, _, _ = pycolmap_to_batch_matrix(
+ reconstruction, device=device, camera_type="SIMPLE_PINHOLE"
+ )
+
+ return updated_extrinsic
+
+
+
+
+
+
+
diff --git a/src/utils/cropping.py b/src/utils/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e635174fe869d9da83c7428fd5c7e81020cff84
--- /dev/null
+++ b/src/utils/cropping.py
@@ -0,0 +1,163 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# croppping utilities
+# --------------------------------------------------------
+import PIL.Image
+import os
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2 # noqa
+import numpy as np # noqa
+from src.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa
+try:
+ lanczos = PIL.Image.Resampling.LANCZOS
+ bicubic = PIL.Image.Resampling.BICUBIC
+except AttributeError:
+ lanczos = PIL.Image.LANCZOS
+ bicubic = PIL.Image.BICUBIC
+
+
+class ImageList:
+ """ Convenience class to aply the same operation to a whole set of images.
+ """
+
+ def __init__(self, images):
+ if not isinstance(images, (tuple, list, set)):
+ images = [images]
+ self.images = []
+ for image in images:
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+ self.images.append(image)
+
+ def __len__(self):
+ return len(self.images)
+
+ def to_pil(self):
+ return tuple(self.images) if len(self.images) > 1 else self.images[0]
+
+ @property
+ def size(self):
+ sizes = [im.size for im in self.images]
+ assert all(sizes[0] == s for s in sizes)
+ return sizes[0]
+
+ def resize(self, *args, **kwargs):
+ return ImageList(self._dispatch('resize', *args, **kwargs))
+
+ def crop(self, *args, **kwargs):
+ return ImageList(self._dispatch('crop', *args, **kwargs))
+
+ def _dispatch(self, func, *args, **kwargs):
+ return [getattr(im, func)(*args, **kwargs) for im in self.images]
+
+
+def rescale_image(image, camera_intrinsics, output_resolution, force=True):
+ """ Jointly rescale a (image, depthmap)
+ so that (out_width, out_height) >= output_res
+ """
+ image = ImageList(image)
+ input_resolution = np.array(image.size) # (W,H)
+ output_resolution = np.array(output_resolution)
+ # define output resolution
+ assert output_resolution.shape == (2,)
+ scale_final = max(output_resolution / image.size) + 1e-8
+ if scale_final >= 1 and not force: # image is already smaller than what is asked
+ return (image.to_pil(), camera_intrinsics)
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
+
+ # first rescale the image so that it contains the crop
+ image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic)
+
+ # no offset here; simple rescaling
+ camera_intrinsics = camera_matrix_of_crop(
+ camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)
+
+ return image.to_pil(), camera_intrinsics
+
+
+def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
+ # Margins to offset the origin
+ margins = np.asarray(input_resolution) * scaling - output_resolution
+ assert np.all(margins >= 0.0)
+ if offset is None:
+ offset = offset_factor * margins
+
+ # Generate new camera parameters
+ output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
+ output_camera_matrix_colmap[:2, :] *= scaling
+ output_camera_matrix_colmap[:2, 2] -= offset
+ output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
+
+ return output_camera_matrix
+
+
+def crop_image(image, camera_intrinsics, crop_bbox):
+ """
+ Return a crop of the input view.
+ """
+ image = ImageList(image)
+ l, t, r, b = crop_bbox
+
+ image = image.crop((l, t, r, b))
+
+ camera_intrinsics = camera_intrinsics.copy()
+ camera_intrinsics[0, 2] -= l
+ camera_intrinsics[1, 2] -= t
+
+ return image.to_pil(), camera_intrinsics
+
+
+def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
+ out_width, out_height = output_resolution
+ l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
+ crop_bbox = (l, t, l + out_width, t + out_height)
+ return crop_bbox
+
+def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True):
+ """ Jointly rescale a (image, depthmap)
+ so that (out_width, out_height) >= output_res
+ """
+ image = ImageList(image)
+ input_resolution = np.array(image.size) # (W,H)
+ output_resolution = np.array(output_resolution)
+ if depthmap is not None:
+ # can also use this with masks instead of depthmaps
+ assert tuple(depthmap.shape[:2]) == image.size[::-1]
+
+ # define output resolution
+ assert output_resolution.shape == (2,)
+ scale_final = max(output_resolution / image.size) + 1e-8
+ if scale_final >= 1 and not force: # image is already smaller than what is asked
+ return (image.to_pil(), depthmap, camera_intrinsics)
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
+
+ # first rescale the image so that it contains the crop
+ # breakpoint()
+ image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic)
+ if depthmap is not None:
+ depthmap = cv2.resize(depthmap, tuple(output_resolution), fx=scale_final,
+ fy=scale_final, interpolation=cv2.INTER_NEAREST)
+
+ # no offset here; simple rescaling
+ camera_intrinsics = camera_matrix_of_crop(
+ camera_intrinsics, input_resolution, tuple(output_resolution), scaling=scale_final)
+
+ return image.to_pil(), depthmap, camera_intrinsics
+
+def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
+ """
+ Return a crop of the input view.
+ """
+ image = ImageList(image)
+ l, t, r, b = crop_bbox
+
+ image = image.crop((l, t, r, b))
+ depthmap = depthmap[t:b, l:r]
+
+ camera_intrinsics = camera_intrinsics.copy()
+ camera_intrinsics[0, 2] -= l
+ camera_intrinsics[1, 2] -= t
+
+ return image.to_pil(), depthmap, camera_intrinsics
\ No newline at end of file
diff --git a/src/utils/device.py b/src/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..bae8a503be3d75b60258135a5134dfdd7154ad59
--- /dev/null
+++ b/src/utils/device.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utilitary functions for DUSt3R
+# --------------------------------------------------------
+import numpy as np
+import torch
+
+
+def todevice(batch, device, callback=None, non_blocking=False):
+ """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
+
+ batch: list, tuple, dict of tensors or other things
+ device: pytorch device or 'numpy'
+ callback: function that would be called on every sub-elements.
+ """
+ if callback:
+ batch = callback(batch)
+
+ if isinstance(batch, dict):
+ return {k: todevice(v, device) for k, v in batch.items()}
+
+ if isinstance(batch, (tuple, list)):
+ return type(batch)(todevice(x, device) for x in batch)
+
+ x = batch
+ if device == "numpy":
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ elif x is not None:
+ if isinstance(x, np.ndarray):
+ x = torch.from_numpy(x)
+ if torch.is_tensor(x):
+ x = x.to(device, non_blocking=non_blocking)
+ return x
+
+
+to_device = todevice # alias
+
+
+def to_numpy(x):
+ return todevice(x, "numpy")
+
+
+def to_cpu(x):
+ return todevice(x, "cpu")
+
+
+def to_cuda(x):
+ return todevice(x, "cuda")
+
+
+def collate_with_cat(whatever, lists=False):
+ if isinstance(whatever, dict):
+ return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
+
+ elif isinstance(whatever, (tuple, list)):
+ if len(whatever) == 0:
+ return whatever
+ elem = whatever[0]
+ T = type(whatever)
+
+ if elem is None:
+ return None
+ if isinstance(elem, (bool, float, int, str)):
+ return whatever
+ if isinstance(elem, tuple):
+ return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
+ if isinstance(elem, dict):
+ return {
+ k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem
+ }
+
+ if isinstance(elem, torch.Tensor):
+ return listify(whatever) if lists else torch.cat(whatever)
+ if isinstance(elem, np.ndarray):
+ return (
+ listify(whatever)
+ if lists
+ else torch.cat([torch.from_numpy(x) for x in whatever])
+ )
+
+ # otherwise, we just chain lists
+ return sum(whatever, T())
+
+
+def listify(elems):
+ return [x for e in elems for x in e]
diff --git a/src/utils/geometry.py b/src/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..2124b269b641716489ee20767efc77534005a2fd
--- /dev/null
+++ b/src/utils/geometry.py
@@ -0,0 +1,428 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# geometry utilitary functions
+# --------------------------------------------------------
+import torch
+import numpy as np
+from scipy.spatial import cKDTree as KDTree
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = float('nan')
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = 0
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+ else:
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr, nnz
+
+
+def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
+ """ Output a (H,W,2) array of int32
+ with output[j,i,0] = i + origin[0]
+ output[j,i,1] = j + origin[1]
+ """
+ if device is None:
+ # numpy
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
+ else:
+ # torch
+ arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
+ meshgrid, stack = torch.meshgrid, torch.stack
+ ones = lambda *a: torch.ones(*a, device=device)
+
+ tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
+ grid = meshgrid(tw, th, indexing='xy')
+ if homogeneous:
+ grid = grid + (ones((H, W)),)
+ if unsqueeze is not None:
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
+ if cat_dim is not None:
+ grid = stack(grid, cat_dim)
+ return grid
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+ """ Apply a geometric transformation to a list of 3-D points.
+
+ H: 3x3 or 4x4 projection matrix (typically a Homography)
+ p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
+
+ ncol: int. number of columns of the result (2 or 3)
+ norm: float. if != 0, the resut is projected on the z=norm plane.
+
+ Returns an array of projected 2d points.
+ """
+ assert Trf.ndim >= 2
+ if isinstance(Trf, np.ndarray):
+ pts = np.asarray(pts)
+ elif isinstance(Trf, torch.Tensor):
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+ # adapt shape if necessary
+ output_reshape = pts.shape[:-1]
+ ncol = ncol or pts.shape[-1]
+
+ # optimized code
+ if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
+ Trf.ndim == 3 and pts.ndim == 4):
+ d = pts.shape[3]
+ if Trf.shape[-1] == d:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+ elif Trf.shape[-1] == d + 1:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
+ else:
+ raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
+ else:
+ if Trf.ndim >= 3:
+ n = Trf.ndim - 2
+ assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+ if pts.ndim > Trf.ndim:
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+ elif pts.ndim == 2:
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+ pts = pts[:, None, :]
+
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+ elif pts.shape[-1] == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf
+ else:
+ pts = Trf @ pts.T
+ if pts.ndim >= 2:
+ pts = pts.swapaxes(-1, -2)
+
+ if norm:
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
+ if norm != 1:
+ pts *= norm
+
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
+ return res
+
+
+def inv(mat):
+ """ Invert a torch or numpy matrix
+ """
+ if isinstance(mat, torch.Tensor):
+ return torch.linalg.inv(mat)
+ if isinstance(mat, np.ndarray):
+ return np.linalg.inv(mat)
+ raise ValueError(f'bad matrix type = {type(mat)}')
+
+
+def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
+ """
+ Args:
+ - depthmap (BxHxW array):
+ - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
+ Returns:
+ pointmap of absolute coordinates (BxHxWx3 array)
+ """
+
+ if len(depth.shape) == 4:
+ B, H, W, n = depth.shape
+ else:
+ B, H, W = depth.shape
+ n = None
+
+ if len(pseudo_focal.shape) == 3: # [B,H,W]
+ pseudo_focalx = pseudo_focaly = pseudo_focal
+ elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
+ pseudo_focalx = pseudo_focal[:, 0]
+ if pseudo_focal.shape[1] == 2:
+ pseudo_focaly = pseudo_focal[:, 1]
+ else:
+ pseudo_focaly = pseudo_focalx
+ else:
+ raise NotImplementedError("Error, unknown input focal shape format.")
+
+ assert pseudo_focalx.shape == depth.shape[:3]
+ assert pseudo_focaly.shape == depth.shape[:3]
+ grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
+
+ # set principal point
+ if pp is None:
+ grid_x = grid_x - (W - 1) / 2
+ grid_y = grid_y - (H - 1) / 2
+ else:
+ grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
+ grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
+
+ if n is None:
+ pts3d = torch.empty((B, H, W, 3), device=depth.device)
+ pts3d[..., 0] = depth * grid_x / pseudo_focalx
+ pts3d[..., 1] = depth * grid_y / pseudo_focaly
+ pts3d[..., 2] = depth
+ else:
+ pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
+ pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
+ pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
+ pts3d[..., 2, :] = depth
+ return pts3d
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: there are no skew terms
+ assert camera_intrinsics[0, 1] == 0.0
+ assert camera_intrinsics[1, 0] == 0.0
+ if pseudo_focal is None:
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ else:
+ assert pseudo_focal.shape == (H, W)
+ fu = fv = pseudo_focal
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ z_cam = depthmap
+ x_cam = (u - cu) * z_cam / fu
+ y_cam = (v - cv) * z_cam / fv
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ # Mask for valid coordinates
+ valid_mask = (depthmap > 0.0)
+ return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose=None, **kw):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+ X_world = X_cam # default
+ if camera_pose is not None:
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute coordinates (invalid depth values)
+ X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+
+ return X_world, valid_mask
+
+
+def colmap_to_opencv_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] -= 0.5
+ K[1, 2] -= 0.5
+ return K
+
+
+def opencv_to_colmap_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] += 0.5
+ K[1, 2] += 0.5
+ return K
+
+
+def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None, ret_factor=False):
+ """ renorm pointmaps pts1, pts2 with norm_mode
+ """
+ assert pts1.ndim >= 3 and pts1.shape[-1] == 3
+ assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
+ norm_mode, dis_mode = norm_mode.split('_')
+
+ if norm_mode == 'avg':
+ # gather all points together (joint normalization)
+ nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
+ nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+ # compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+ if dis_mode == 'dis':
+ pass # do nothing
+ elif dis_mode == 'log1p':
+ all_dis = torch.log1p(all_dis)
+ elif dis_mode == 'warp-log1p':
+ # actually warp input points before normalizing them
+ log_dis = torch.log1p(all_dis)
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
+ H1, W1 = pts1.shape[1:-1]
+ pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1)
+ if pts2 is not None:
+ H2, W2 = pts2.shape[1:-1]
+ pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1)
+ all_dis = log_dis # this is their true distance afterwards
+ else:
+ raise ValueError(f'bad {dis_mode=}')
+
+ norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
+ else:
+ # gather all points together (joint normalization)
+ nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
+ nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+ # compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+
+ if norm_mode == 'avg':
+ norm_factor = all_dis.nanmean(dim=1)
+ elif norm_mode == 'median':
+ norm_factor = all_dis.nanmedian(dim=1).values.detach()
+ elif norm_mode == 'sqrt':
+ norm_factor = all_dis.sqrt().nanmean(dim=1)**2
+ else:
+ raise ValueError(f'bad {norm_mode=}')
+
+ norm_factor = norm_factor.clip(min=1e-8)
+ while norm_factor.ndim < pts1.ndim:
+ norm_factor.unsqueeze_(-1)
+
+ res = pts1 / norm_factor
+ if pts2 is not None:
+ res = (res, pts2 / norm_factor)
+ if ret_factor:
+ res = res + (norm_factor,)
+ return res
+
+
+@torch.no_grad()
+def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
+ # set invalid points to NaN
+ _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
+ _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
+ _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
+
+ # compute median depth overall (ignoring nans)
+ if quantile == 0.5:
+ shift_z = torch.nanmedian(_z, dim=-1).values
+ else:
+ shift_z = torch.nanquantile(_z, quantile, dim=-1)
+ return shift_z # (B,)
+
+
+@torch.no_grad()
+def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
+ # set invalid points to NaN
+ _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
+ _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
+ _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
+
+ # compute median center
+ _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
+ if z_only:
+ _center[..., :2] = 0 # do not center X and Y
+
+ # compute median norm
+ _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
+ scale = torch.nanmedian(_norm, dim=1).values
+ return _center[:, None, :, :], scale[:, None, None, None]
+
+
+def find_reciprocal_matches(P1, P2):
+ """
+ returns 3 values:
+ 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
+ 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
+ 3 - reciprocal_in_P2.sum(): the number of matches
+ """
+ tree1 = KDTree(P1)
+ tree2 = KDTree(P2)
+
+ _, nn1_in_P2 = tree2.query(P1, workers=8)
+ _, nn2_in_P1 = tree1.query(P2, workers=8)
+
+ reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
+ reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
+ assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
+ return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
+
+
+def get_med_dist_between_poses(poses):
+ from scipy.spatial.distance import pdist
+ return np.median(pdist([p[:3, 3].detach().cpu().numpy() for p in poses]))
+
+def get_rel_pos(_rel_xyz, free_space, voxel_size):
+ """
+ Args:
+ _rel_xyz: torch.tensor
+ shape [..., 3]
+ free_space: str
+ something like 'soft'
+ grid: GridBatch
+ fvdb grid
+
+ Returns:
+ rel_pos: torch.tensor
+ shape [..., 3], the position within a voxel (compared with the voxel center)
+ """
+ if free_space == "hard":
+ rel_pos = torch.sigmoid(_rel_xyz) * voxel_size
+ elif free_space == "soft":
+ # free space [-1, 2]
+ rel_pos = (torch.sigmoid(_rel_xyz) * 3 - 1) * voxel_size
+ elif free_space == "soft-2":
+ # free space [-2, 3]
+ rel_pos = (torch.sigmoid(_rel_xyz) * 5 - 2) * voxel_size
+ elif free_space == "soft-3":
+ # free space [-3, 4]
+ rel_pos = (torch.sigmoid(_rel_xyz) * 7 - 3) * voxel_size
+ elif free_space == "soft-4":
+ # free space [-4, 5]
+ rel_pos = (torch.sigmoid(_rel_xyz) * 9 - 4) * voxel_size
+ elif free_space == "soft-5":
+ # free space [-5, 6]
+ rel_pos = (torch.sigmoid(_rel_xyz) * 11 - 5) * voxel_size
+ elif free_space == "tanh-3":
+ # free space [-2.5, 3.5]
+ rel_pos = (torch.tanh(_rel_xyz) * 3 + 0.5) * voxel_size
+ elif free_space == "free-1":
+ rel_pos = _rel_xyz * voxel_size
+ elif free_space == "free-2":
+ rel_pos = _rel_xyz
+ elif free_space == "center":
+ rel_pos = (torch.zeros_like(_rel_xyz) + 0.5) * voxel_size
+ else:
+ raise NotImplementedError
+
+ return rel_pos
diff --git a/src/utils/image.py b/src/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..494502fd2b9ad5dcd0f7baf7929156d002c70c93
--- /dev/null
+++ b/src/utils/image.py
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utilitary functions about images (loading/converting...)
+# --------------------------------------------------------
+import os
+
+import numpy as np
+import PIL.Image
+import torch
+import torchvision.transforms as tvf
+from PIL.ImageOps import exif_transpose
+from PIL import Image
+import torchvision
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2
+
+try:
+ from pillow_heif import register_heif_opener
+
+ register_heif_opener()
+ heif_support_enabled = True
+except ImportError:
+ heif_support_enabled = False
+
+ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+
+def imread_cv2(path, options=cv2.IMREAD_COLOR):
+ """Open an image or a depthmap with opencv-python."""
+ if path.endswith((".exr", "EXR")):
+ options = cv2.IMREAD_ANYDEPTH
+ img = cv2.imread(path, options)
+ if img is None:
+ raise IOError(f"Could not load image={path} with {options=}")
+ if img.ndim == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
+
+
+def rgb(ftensor, true_shape=None):
+ if isinstance(ftensor, list):
+ return [rgb(x, true_shape=true_shape) for x in ftensor]
+ if isinstance(ftensor, torch.Tensor):
+ ftensor = ftensor.detach().cpu().numpy() # H,W,3
+ if ftensor.ndim == 3 and ftensor.shape[0] == 3:
+ ftensor = ftensor.transpose(1, 2, 0)
+ elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
+ ftensor = ftensor.transpose(0, 2, 3, 1)
+ if true_shape is not None:
+ H, W = true_shape
+ ftensor = ftensor[:H, :W]
+ if ftensor.dtype == np.uint8:
+ img = np.float32(ftensor) / 255
+ else:
+ img = (ftensor * 0.5) + 0.5
+ return img.clip(min=0, max=1)
+
+
+def _resize_pil_image(img, long_edge_size):
+ S = max(img.size)
+ if S > long_edge_size:
+ interp = PIL.Image.LANCZOS
+ elif S <= long_edge_size:
+ interp = PIL.Image.BICUBIC
+ new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
+ return img.resize(new_size, interp)
+
+
+def load_images(folder_or_list, size, square_ok=False, verbose=True, rotate_clockwise_90=False, crop_to_landscape=False):
+ """open and convert all images in a list or folder to proper input format for DUSt3R"""
+ if isinstance(folder_or_list, str):
+ if verbose:
+ print(f">> Loading images from {folder_or_list}")
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
+
+ elif isinstance(folder_or_list, list):
+ if verbose:
+ print(f">> Loading a list of {len(folder_or_list)} images")
+ root, folder_content = "", folder_or_list
+
+ else:
+ raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")
+
+ supported_images_extensions = [".jpg", ".jpeg", ".png"]
+ if heif_support_enabled:
+ supported_images_extensions += [".heic", ".heif"]
+ supported_images_extensions = tuple(supported_images_extensions)
+
+ imgs = []
+ for path in folder_content:
+ if not path.lower().endswith(supported_images_extensions):
+ continue
+ img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
+ if rotate_clockwise_90:
+ img = img.rotate(-90, expand=True)
+ if crop_to_landscape:
+ # Crop to a landscape aspect ratio (e.g., 16:9)
+ desired_aspect_ratio = 4 / 3
+ width, height = img.size
+ current_aspect_ratio = width / height
+
+ if current_aspect_ratio > desired_aspect_ratio:
+ # Wider than landscape: crop width
+ new_width = int(height * desired_aspect_ratio)
+ left = (width - new_width) // 2
+ right = left + new_width
+ top = 0
+ bottom = height
+ else:
+ # Taller than landscape: crop height
+ new_height = int(width / desired_aspect_ratio)
+ top = (height - new_height) // 2
+ bottom = top + new_height
+ left = 0
+ right = width
+
+ img = img.crop((left, top, right, bottom))
+
+ W1, H1 = img.size
+ if size == 224:
+ # resize short side to 224 (then crop)
+ img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
+ else:
+ # resize long side to 512
+ img = _resize_pil_image(img, size)
+ W, H = img.size
+ cx, cy = W // 2, H // 2
+ if size == 224:
+ half = min(cx, cy)
+ img = img.crop((cx - half, cy - half, cx + half, cy + half))
+ else:
+ halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
+ if not (square_ok) and W == H:
+ halfh = 3 * halfw / 4
+ img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
+
+ W2, H2 = img.size
+ if verbose:
+ print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
+ imgs.append(
+ dict(
+ img=ImgNorm(img)[None],
+ true_shape=np.int32([img.size[::-1]]),
+ idx=len(imgs),
+ instance=str(len(imgs)),
+ )
+ )
+
+ assert imgs, "no images foud at " + root
+ if verbose:
+ print(f" (Found {len(imgs)} images)")
+ return imgs
+
+def process_image(img_path):
+ img = Image.open(img_path)
+ if img.mode == 'RGBA':
+ # Convert RGBA to RGB by removing alpha channel
+ img = img.convert('RGB')
+ # Resize to maintain aspect ratio and then center crop to 448x448
+ width, height = img.size
+ if width > height:
+ new_height = 448
+ new_width = int(width * (new_height / height))
+ else:
+ new_width = 448
+ new_height = int(height * (new_width / width))
+ img = img.resize((new_width, new_height))
+
+ # Center crop
+ left = (new_width - 448) // 2
+ top = (new_height - 448) // 2
+ right = left + 448
+ bottom = top + 448
+ img = img.crop((left, top, right, bottom))
+ img_tensor = torchvision.transforms.ToTensor()(img) * 2.0 - 1.0 # [-1, 1]
+ return img_tensor
\ No newline at end of file
diff --git a/src/utils/misc.py b/src/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c34192786d91aeb2d67e8b5a8cbf7412a1c682d
--- /dev/null
+++ b/src/utils/misc.py
@@ -0,0 +1,531 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utilitary functions for CroCo
+# --------------------------------------------------------
+# References:
+# MAE: https://github.com/facebookresearch/mae
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import json
+import math
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch import inf
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None, max_iter=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable)
+ space_fmt = ":" + str(len(str(len_iterable))) + "d"
+ log_msg = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_msg.append("max mem: {memory:.0f}")
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for it, obj in enumerate(iterable):
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len_iterable - 1:
+ eta_seconds = iter_time.global_avg * (len_iterable - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len_iterable,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len_iterable,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ if max_iter and it >= max_iter:
+ break
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len_iterable
+ )
+ )
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print("[{}] ".format(now), end="") # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ nodist = args.nodist if hasattr(args, "nodist") else False
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ else:
+ print("Not using distributed mode")
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}): {}, gpu {}".format(
+ args.rank, args.dist_url, args.gpu
+ ),
+ flush=True,
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self, enabled=True):
+ self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)
+
+ def __call__(
+ self,
+ loss,
+ optimizer,
+ clip_grad=None,
+ parameters=None,
+ create_graph=False,
+ update_grad=True,
+ ):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(
+ optimizer
+ ) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ # TODO: FIXME: get_grad_norm_ is a poorly-implemented func that is very slow, and its return not even used
+ # norm = get_grad_norm_(parameters)
+ norm = None
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.0)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(
+ torch.stack(
+ [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
+ ),
+ norm_type,
+ )
+ return total_norm
+
+
+def save_model(
+ args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None
+):
+ output_dir = Path(args.output_dir)
+ if fname is None:
+ fname = str(epoch)
+ checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname)
+ to_save = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scaler": loss_scaler.state_dict(),
+ "args": args,
+ "epoch": epoch,
+ }
+ if best_so_far is not None:
+ to_save["best_so_far"] = best_so_far
+ print(f">> Saving model to {checkpoint_path} ...")
+ save_on_master(to_save, checkpoint_path)
+
+
+def load_model(args, model_without_ddp, optimizer, loss_scaler):
+ args.start_epoch = 0
+ best_so_far = None
+ if args.resume is not None:
+ if args.resume.startswith("https"):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.resume, map_location="cpu", check_hash=True
+ )
+ else:
+ checkpoint = torch.load(args.resume, map_location="cpu")
+ print("Resume checkpoint %s" % args.resume)
+ model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
+ args.start_epoch = checkpoint["epoch"] + 1
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if "scaler" in checkpoint:
+ loss_scaler.load_state_dict(checkpoint["scaler"])
+ if "best_so_far" in checkpoint:
+ best_so_far = checkpoint["best_so_far"]
+ print(" & best_so_far={:g}".format(best_so_far))
+ else:
+ print("")
+ print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end="")
+ return best_so_far
+
+
+def all_reduce_mean(x):
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
+
+
+def _replace(text, src, tgt, rm=""):
+ """Advanced string replacement.
+ Given a text:
+ - replace all elements in src by the corresponding element in tgt
+ - remove all elements in rm
+ """
+ if len(tgt) == 1:
+ tgt = tgt * len(src)
+ assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len"
+ for s, t in zip(src, tgt):
+ text = text.replace(s, t)
+ for c in rm:
+ text = text.replace(c, "")
+ return text
+
+
+def filename(obj):
+ """transform a python obj or cmd into a proper filename.
+ - \1 gets replaced by slash '/'
+ - \2 gets replaced by comma ','
+ """
+ if not isinstance(obj, str):
+ obj = repr(obj)
+ obj = str(obj).replace("()", "")
+ obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"")
+ assert all(len(s) < 256 for s in obj.split(os.sep)), (
+ "filename too long (>256 characters):\n" + obj
+ )
+ return obj
+
+
+def _get_num_layer_for_vit(var_name, enc_depth, dec_depth):
+ if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"):
+ return 0
+ elif var_name.startswith("patch_embed"):
+ return 0
+ elif var_name.startswith("enc_blocks"):
+ layer_id = int(var_name.split(".")[1])
+ return layer_id + 1
+ elif var_name.startswith("decoder_embed") or var_name.startswith(
+ "enc_norm"
+ ): # part of the last black
+ return enc_depth
+ elif var_name.startswith("dec_blocks"):
+ layer_id = int(var_name.split(".")[1])
+ return enc_depth + layer_id + 1
+ elif var_name.startswith("dec_norm"): # part of the last block
+ return enc_depth + dec_depth
+ elif any(var_name.startswith(k) for k in ["head", "prediction_head"]):
+ return enc_depth + dec_depth + 1
+ else:
+ raise NotImplementedError(var_name)
+
+
+def get_parameter_groups(
+ model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]
+):
+ parameter_group_names = {}
+ parameter_group_vars = {}
+ enc_depth, dec_depth = None, None
+ # prepare layer decay values
+ assert layer_decay == 1.0 or 0.0 < layer_decay < 1.0
+ if layer_decay < 1.0:
+ enc_depth = model.enc_depth
+ dec_depth = model.dec_depth if hasattr(model, "dec_blocks") else 0
+ num_layers = enc_depth + dec_depth
+ layer_decay_values = list(
+ layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)
+ )
+
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+
+ # Assign weight decay values
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+ group_name = "no_decay"
+ this_weight_decay = 0.0
+ else:
+ group_name = "decay"
+ this_weight_decay = weight_decay
+
+ # Assign layer ID for LR scaling
+ if layer_decay < 1.0:
+ skip_scale = False
+ layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth)
+ group_name = "layer_%d_%s" % (layer_id, group_name)
+ if name in no_lr_scale_list:
+ skip_scale = True
+ group_name = f"{group_name}_no_lr_scale"
+ else:
+ layer_id = 0
+ skip_scale = True
+
+ if group_name not in parameter_group_names:
+ if not skip_scale:
+ scale = layer_decay_values[layer_id]
+ else:
+ scale = 1.0
+
+ parameter_group_names[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale,
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale,
+ }
+
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+ return list(parameter_group_vars.values())
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
+ 1.0
+ + math.cos(
+ math.pi
+ * (epoch - args.warmup_epochs)
+ / (args.epochs - args.warmup_epochs)
+ )
+ )
+
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+
+ return lr
diff --git a/src/utils/point.py b/src/utils/point.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eb126006427139c38ca607fcc650f2427567e6d
--- /dev/null
+++ b/src/utils/point.py
@@ -0,0 +1,47 @@
+import torch
+from torch import Tensor
+
+def get_normal_map(depth_map: torch.Tensor, intrinsic: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (torch.Tensor): Depth map of shape (H, W).
+ intrinsic (torch.Tensor): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Camera coordinates (H, W, 3)
+ """
+ B, H, W = depth_map.shape
+ assert intrinsic.shape == (B, 3, 3), "Intrinsic matrix must be Bx3x3"
+ assert (intrinsic[:, 0, 1] == 0).all() and (intrinsic[:, 1, 0] == 0).all(), "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu = intrinsic[:, 0, 0] * W # (B,)
+ fv = intrinsic[:, 1, 1] * H # (B,)
+ cu = intrinsic[:, 0, 2] * W # (B,)
+ cv = intrinsic[:, 1, 2] * H # (B,)
+
+ # Generate grid of pixel coordinates
+ u = torch.arange(W, device=depth_map.device)[None, None, :].expand(B, H, W)
+ v = torch.arange(H, device=depth_map.device)[None, :, None].expand(B, H, W)
+
+ # Unproject to camera coordinates (B, H, W)
+ x_cam = (u - cu[:, None, None]) * depth_map / fu[:, None, None]
+ y_cam = (v - cv[:, None, None]) * depth_map / fv[:, None, None]
+ z_cam = depth_map
+
+ # Stack to form camera coordinates (B, H, W, 3)
+ cam_coords = torch.stack((x_cam, y_cam, z_cam), dim=-1).to(dtype=torch.float32)
+
+ output = torch.zeros_like(cam_coords)
+ # Calculate dx using batch dimension (B, H-2, W-2, 3)
+ dx = cam_coords[:, 2:, 1:-1] - cam_coords[:, :-2, 1:-1]
+ # Calculate dy using batch dimension (B, H-2, W-2, 3)
+ dy = cam_coords[:, 1:-1, 2:] - cam_coords[:, 1:-1, :-2]
+ # Cross product and normalization (B, H-2, W-2, 3)
+ normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
+ # Assign the computed normal map to the output tensor
+ output[:, 1:-1, 1:-1, :] = normal_map
+
+ return output
\ No newline at end of file
diff --git a/src/utils/pose.py b/src/utils/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bdba20257c1922a4cd3241c152c23a2382bdb9d
--- /dev/null
+++ b/src/utils/pose.py
@@ -0,0 +1,218 @@
+import torch
+import numpy as np
+from src.model.encoder.vggt.utils.rotation import mat_to_quat
+from src.model.encoder.vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
+
+
+def convert_pt3d_RT_to_opencv(Rot, Trans):
+ """
+ Convert Point3D extrinsic matrices to OpenCV convention.
+
+ Args:
+ Rot: 3D rotation matrix in Point3D format
+ Trans: 3D translation vector in Point3D format
+
+ Returns:
+ extri_opencv: 3x4 extrinsic matrix in OpenCV format
+ """
+ rot_pt3d = np.array(Rot)
+ trans_pt3d = np.array(Trans)
+
+ trans_pt3d[:2] *= -1
+ rot_pt3d[:, :2] *= -1
+ rot_pt3d = rot_pt3d.transpose(1, 0)
+ extri_opencv = np.hstack((rot_pt3d, trans_pt3d[:, None]))
+ return extri_opencv
+
+
+def build_pair_index(N, B=1):
+ """
+ Build indices for all possible pairs of frames.
+
+ Args:
+ N: Number of frames
+ B: Batch size
+
+ Returns:
+ i1, i2: Indices for all possible pairs
+ """
+ i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
+ i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]]
+ return i1, i2
+
+
+def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
+ """
+ Calculate rotation angle error between ground truth and predicted rotations.
+
+ Args:
+ rot_gt: Ground truth rotation matrices
+ rot_pred: Predicted rotation matrices
+ batch_size: Batch size for reshaping the result
+ eps: Small value to avoid numerical issues
+
+ Returns:
+ Rotation angle error in degrees
+ """
+ q_pred = mat_to_quat(rot_pred)
+ q_gt = mat_to_quat(rot_gt)
+
+ loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
+ err_q = torch.arccos(1 - 2 * loss_q)
+
+ rel_rangle_deg = err_q * 180 / np.pi
+
+ if batch_size is not None:
+ rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
+
+ return rel_rangle_deg
+
+
+def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
+ """
+ Calculate translation angle error between ground truth and predicted translations.
+
+ Args:
+ tvec_gt: Ground truth translation vectors
+ tvec_pred: Predicted translation vectors
+ batch_size: Batch size for reshaping the result
+ ambiguity: Whether to handle direction ambiguity
+
+ Returns:
+ Translation angle error in degrees
+ """
+ rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
+ rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
+
+ if ambiguity:
+ rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
+
+ if batch_size is not None:
+ rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
+
+ return rel_tangle_deg
+
+
+def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
+ """
+ Normalize the translation vectors and compute the angle between them.
+
+ Args:
+ t_gt: Ground truth translation vectors
+ t: Predicted translation vectors
+ eps: Small value to avoid division by zero
+ default_err: Default error value for invalid cases
+
+ Returns:
+ Angular error between translation vectors in radians
+ """
+ t_norm = torch.norm(t, dim=1, keepdim=True)
+ t = t / (t_norm + eps)
+
+ t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
+ t_gt = t_gt / (t_gt_norm + eps)
+
+ loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
+ err_t = torch.acos(torch.sqrt(1 - loss_t))
+
+ err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
+ return err_t
+
+
+def calculate_auc(r_error, t_error, max_threshold=30, return_list=False):
+ """
+ Calculate the Area Under the Curve (AUC) for the given error arrays using PyTorch.
+
+ Args:
+ r_error: torch.Tensor representing R error values (Degree)
+ t_error: torch.Tensor representing T error values (Degree)
+ max_threshold: Maximum threshold value for binning the histogram
+ return_list: Whether to return the normalized histogram as well
+
+ Returns:
+ AUC value, and optionally the normalized histogram
+ """
+ error_matrix = torch.stack((r_error, t_error), dim=1)
+ max_errors, _ = torch.max(error_matrix, dim=1)
+ histogram = torch.histc(
+ max_errors, bins=max_threshold + 1, min=0, max=max_threshold
+ )
+ num_pairs = float(max_errors.size(0))
+ normalized_histogram = histogram / num_pairs
+
+ if return_list:
+ return (
+ torch.cumsum(normalized_histogram, dim=0).mean(),
+ normalized_histogram,
+ )
+ return torch.cumsum(normalized_histogram, dim=0).mean()
+
+
+def calculate_auc_np(r_error, t_error, max_threshold=30):
+ """
+ Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.
+
+ Args:
+ r_error: numpy array representing R error values (Degree)
+ t_error: numpy array representing T error values (Degree)
+ max_threshold: Maximum threshold value for binning the histogram
+
+ Returns:
+ AUC value and the normalized histogram
+ """
+ error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
+ max_errors = np.max(error_matrix, axis=1)
+ bins = np.arange(max_threshold + 1)
+ histogram, _ = np.histogram(max_errors, bins=bins)
+ num_pairs = float(len(max_errors))
+ normalized_histogram = histogram.astype(float) / num_pairs
+ return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
+
+
+def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
+ """
+ Compute rotation and translation errors between predicted and ground truth poses.
+
+ Args:
+ pred_se3: Predicted SE(3) transformations
+ gt_se3: Ground truth SE(3) transformations
+ num_frames: Number of frames
+
+ Returns:
+ Rotation and translation angle errors in degrees
+ """
+ pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
+
+ # Compute relative camera poses between pairs
+ # We use closed_form_inverse to avoid potential numerical loss by torch.inverse()
+ relative_pose_gt = closed_form_inverse_se3(gt_se3[pair_idx_i1]).bmm(
+ gt_se3[pair_idx_i2]
+ )
+ relative_pose_pred = closed_form_inverse_se3(pred_se3[pair_idx_i1]).bmm(
+ pred_se3[pair_idx_i2]
+ )
+
+ # Compute the difference in rotation and translation
+ rel_rangle_deg = rotation_angle(
+ relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]
+ )
+ rel_tangle_deg = translation_angle(
+ relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]
+ )
+
+ return rel_rangle_deg, rel_tangle_deg
+
+
+def align_to_first_camera(camera_poses):
+ """
+ Align all camera poses to the first camera's coordinate frame.
+
+ Args:
+ camera_poses: Tensor of shape (N, 4, 4) containing camera poses as SE3 transformations
+
+ Returns:
+ Tensor of shape (N, 4, 4) containing aligned camera poses
+ """
+ first_cam_extrinsic_inv = closed_form_inverse_se3(camera_poses[0][None])
+ aligned_poses = torch.matmul(camera_poses, first_cam_extrinsic_inv)
+ return aligned_poses
\ No newline at end of file
diff --git a/src/utils/render.py b/src/utils/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..63e576230d33ab76b5dba10916c768b098ca83fb
--- /dev/null
+++ b/src/utils/render.py
@@ -0,0 +1,284 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import os
+import enum
+import types
+from typing import List, Mapping, Optional, Text, Tuple, Union
+import copy
+from PIL import Image
+# import mediapy as media
+from matplotlib import cm
+from tqdm import tqdm
+
+import torch
+
+def normalize(x: np.ndarray) -> np.ndarray:
+ """Normalization helper function."""
+ return x / np.linalg.norm(x)
+
+def pad_poses(p: np.ndarray) -> np.ndarray:
+ """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
+ bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
+ return np.concatenate([p[..., :3, :4], bottom], axis=-2)
+
+
+def unpad_poses(p: np.ndarray) -> np.ndarray:
+ """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
+ return p[..., :3, :4]
+
+
+def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Recenter poses around the origin."""
+ cam2world = average_pose(poses)
+ transform = np.linalg.inv(pad_poses(cam2world))
+ poses = transform @ pad_poses(poses)
+ return unpad_poses(poses), transform
+
+
+def average_pose(poses: np.ndarray) -> np.ndarray:
+ """New pose using average position, z-axis, and up vector of input poses."""
+ position = poses[:, :3, 3].mean(0)
+ z_axis = poses[:, :3, 2].mean(0)
+ up = poses[:, :3, 1].mean(0)
+ cam2world = viewmatrix(z_axis, up, position)
+ return cam2world
+
+def viewmatrix(lookdir: np.ndarray, up: np.ndarray,
+ position: np.ndarray) -> np.ndarray:
+ """Construct lookat view matrix."""
+ vec2 = normalize(lookdir)
+ vec0 = normalize(np.cross(up, vec2))
+ vec1 = normalize(np.cross(vec2, vec0))
+ m = np.stack([vec0, vec1, vec2, position], axis=1)
+ return m
+
+def focus_point_fn(poses: np.ndarray) -> np.ndarray:
+ """Calculate nearest point to all focal axes in poses."""
+ directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
+ m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
+ mt_m = np.transpose(m, [0, 2, 1]) @ m
+ focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
+ return focus_pt
+
+def transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Transforms poses so principal components lie on XYZ axes.
+
+ Args:
+ poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
+
+ Returns:
+ A tuple (poses, transform), with the transformed poses and the applied
+ camera_to_world transforms.
+ """
+ t = poses[:, :3, 3]
+ t_mean = t.mean(axis=0)
+ t = t - t_mean
+
+ eigval, eigvec = np.linalg.eig(t.T @ t)
+ # Sort eigenvectors in order of largest to smallest eigenvalue.
+ inds = np.argsort(eigval)[::-1]
+ eigvec = eigvec[:, inds]
+ rot = eigvec.T
+ if np.linalg.det(rot) < 0:
+ rot = np.diag(np.array([1, 1, -1])) @ rot
+
+ transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
+ poses_recentered = unpad_poses(transform @ pad_poses(poses))
+ transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
+
+ # Flip coordinate system if z component of y-axis is negative
+ if poses_recentered.mean(axis=0)[2, 1] < 0:
+ poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
+ transform = np.diag(np.array([1, -1, -1, 1])) @ transform
+
+ return poses_recentered, transform
+ # points = np.random.rand(3,100)
+ # points_h = np.concatenate((points,np.ones_like(points[:1])), axis=0)
+ # (poses_recentered @ points_h)[0]
+ # (transform @ pad_poses(poses) @ points_h)[0,:3]
+ # import pdb; pdb.set_trace()
+
+ # # Just make sure it's it in the [-1, 1]^3 cube
+ # scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
+ # poses_recentered[:, :3, 3] *= scale_factor
+ # transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
+
+ # return poses_recentered, transform
+
+def generate_ellipse_path(poses: np.ndarray,
+ n_frames: int = 120,
+ const_speed: bool = True,
+ z_variation: float = 0.,
+ z_phase: float = 0.) -> np.ndarray:
+ """Generate an elliptical render path based on the given poses."""
+ # Calculate the focal point for the path (cameras point toward this).
+ center = focus_point_fn(poses)
+ # Path height sits at z=0 (in middle of zero-mean capture pattern).
+ offset = np.array([center[0], center[1], 0])
+
+ # Calculate scaling for ellipse axes based on input camera positions.
+ sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
+ # Use ellipse that is symmetric about the focal point in xy.
+ low = -sc + offset
+ high = sc + offset
+ # Optional height variation need not be symmetric
+ z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
+ z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
+
+ def get_positions(theta):
+ # Interpolate between bounds with trig functions to get ellipse in x-y.
+ # Optionally also interpolate in z to change camera height along path.
+ return np.stack([
+ low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),
+ low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),
+ z_variation * (z_low[2] + (z_high - z_low)[2] *
+ (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
+ ], -1)
+
+ theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
+ positions = get_positions(theta)
+
+ #if const_speed:
+
+ # # Resample theta angles so that the velocity is closer to constant.
+ # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
+ # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
+ # positions = get_positions(theta)
+
+ # Throw away duplicated last position.
+ positions = positions[:-1]
+
+ # Set path's up vector to axis closest to average of input pose up vectors.
+ avg_up = poses[:, :3, 1].mean(0)
+ avg_up = avg_up / np.linalg.norm(avg_up)
+ ind_up = np.argmax(np.abs(avg_up))
+ up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
+
+ return np.stack([viewmatrix(p - center, up, p) for p in positions])
+
+
+def generate_path(viewpoint_cameras, n_frames=480):
+ # c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras])
+ c2ws = viewpoint_cameras.cpu().numpy()
+ pose = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1])
+ pose_recenter, colmap_to_world_transform = transform_poses_pca(pose)
+
+ # generate new poses
+ new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames)
+ # warp back to orignal scale
+ new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses)
+
+ return new_poses
+
+ # traj = []
+ # for c2w in new_poses:
+ # c2w = c2w @ np.diag([1, -1, -1, 1])
+ # cam = copy.deepcopy(viewpoint_cameras[0])
+ # cam.image_height = int(cam.image_height / 2) * 2
+ # cam.image_width = int(cam.image_width / 2) * 2
+ # cam.world_view_transform = torch.from_numpy(np.linalg.inv(c2w).T).float().cuda()
+ # cam.full_proj_transform = (cam.world_view_transform.unsqueeze(0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0)
+ # cam.camera_center = cam.world_view_transform.inverse()[3, :3]
+ # traj.append(cam)
+
+ # return traj
+
+def load_img(pth: str) -> np.ndarray:
+ """Load an image and cast to float32."""
+ with open(pth, 'rb') as f:
+ image = np.array(Image.open(f), dtype=np.float32)
+ return image
+
+
+def create_videos(base_dir, input_dir, out_name, num_frames=480):
+ """Creates videos out of the images saved to disk."""
+ # Last two parts of checkpoint path are experiment name and scene name.
+ video_prefix = f'{out_name}'
+
+ zpad = max(5, len(str(num_frames - 1)))
+ idx_to_str = lambda idx: str(idx).zfill(zpad)
+
+ os.makedirs(base_dir, exist_ok=True)
+ render_dist_curve_fn = np.log
+
+ # Load one example frame to get image shape and depth range.
+ depth_file = os.path.join(input_dir, 'vis', f'depth_{idx_to_str(0)}.tiff')
+ depth_frame = load_img(depth_file)
+ shape = depth_frame.shape
+ p = 3
+ distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p])
+ lo, hi = [render_dist_curve_fn(x) for x in distance_limits]
+ print(f'Video shape is {shape[:2]}')
+
+ video_kwargs = {
+ 'shape': shape[:2],
+ 'codec': 'h264',
+ 'fps': 60,
+ 'crf': 18,
+ }
+
+ for k in ['depth', 'normal', 'color']:
+ video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
+ input_format = 'gray' if k == 'alpha' else 'rgb'
+
+
+ file_ext = 'png' if k in ['color', 'normal'] else 'tiff'
+ idx = 0
+
+ if k == 'color':
+ file0 = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.{file_ext}')
+ else:
+ file0 = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(0)}.{file_ext}')
+
+ if not os.path.exists(file0):
+ print(f'Images missing for tag {k}')
+ continue
+ print(f'Making video {video_file}...')
+ with media.VideoWriter(
+ video_file, **video_kwargs, input_format=input_format) as writer:
+ for idx in tqdm(range(num_frames)):
+ # img_file = os.path.join(input_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')
+ if k == 'color':
+ img_file = os.path.join(input_dir, 'renders', f'{idx_to_str(idx)}.{file_ext}')
+ else:
+ img_file = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(idx)}.{file_ext}')
+
+ if not os.path.exists(img_file):
+ ValueError(f'Image file {img_file} does not exist.')
+ img = load_img(img_file)
+ if k in ['color', 'normal']:
+ img = img / 255.
+ elif k.startswith('depth'):
+ img = render_dist_curve_fn(img)
+ img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)
+ img = cm.get_cmap('turbo')(img)[..., :3]
+
+ frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)
+ writer.add_image(frame)
+ idx += 1
+
+def save_img_u8(img, pth):
+ """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG."""
+ with open(pth, 'wb') as f:
+ Image.fromarray(
+ (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save(
+ f, 'PNG')
+
+
+def save_img_f32(depthmap, pth):
+ """Save an image (probably a depthmap) to disk as a float32 TIFF."""
+ with open(pth, 'wb') as f:
+ Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF')
\ No newline at end of file
diff --git a/src/utils/tensor_to_pycolmap.py b/src/utils/tensor_to_pycolmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..1332f6617b933134557555c99c1b040eb1569514
--- /dev/null
+++ b/src/utils/tensor_to_pycolmap.py
@@ -0,0 +1,320 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+import pycolmap
+
+# TODO: frame_idx should start from 1 instead of 0 in colmap
+def batch_matrix_to_pycolmap(
+ points3d,
+ extrinsics,
+ intrinsics,
+ tracks,
+ image_size,
+ masks=None,
+ max_reproj_error=None,
+ max_points3D_val=3000,
+ shared_camera=False,
+ camera_type="SIMPLE_PINHOLE",
+ extra_params=None,
+):
+ """
+ Convert Batched Pytorch Tensors to PyCOLMAP
+
+ Check https://github.com/colmap/pycolmap for more details about its format
+ """
+
+ # points3d: Px3
+ # extrinsics: Nx3x4
+ # intrinsics: Nx3x3
+ # tracks: NxPx2
+ # masks: NxP
+ # image_size: 2, assume all the frames have been padded to the same size
+ # where N is the number of frames and P is the number of tracks
+
+ N, P, _ = tracks.shape
+ assert len(extrinsics) == N
+ assert len(intrinsics) == N
+ assert len(points3d) == P
+ assert image_size.shape[0] == 2
+
+ projected_points_2d, projected_points_cam = project_3D_points(points3d, extrinsics, intrinsics, return_points_cam=True)
+ projected_diff = (projected_points_2d - tracks).norm(dim=-1)
+ projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6
+ reproj_mask = projected_diff < max_reproj_error
+
+ if masks is not None:
+ masks = torch.logical_and(masks, reproj_mask)
+ else:
+ masks = reproj_mask
+
+ extrinsics = extrinsics.cpu().numpy()
+ intrinsics = intrinsics.cpu().numpy()
+
+ if extra_params is not None:
+ extra_params = extra_params.cpu().numpy()
+
+
+ tracks = tracks.cpu().numpy()
+ points3d = points3d.cpu().numpy()
+ image_size = image_size.cpu().numpy()
+
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
+ reconstruction = pycolmap.Reconstruction()
+
+ masks = masks.cpu().numpy()
+
+ inlier_num = masks.sum(0)
+ valid_mask = inlier_num >= 2 # a track is invalid if without two inliers
+ valid_idx = np.nonzero(valid_mask)[0]
+
+ # Only add 3D points that have sufficient 2D points
+ for vidx in valid_idx:
+ reconstruction.add_point3D(
+ points3d[vidx], pycolmap.Track(), np.zeros(3)
+ )
+
+ num_points3D = len(valid_idx)
+
+ camera = None
+ # frame idx
+ for fidx in range(N):
+ # set camera
+ if camera is None or (not shared_camera):
+ if camera_type == "SIMPLE_RADIAL":
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
+ pycolmap_intri = np.array(
+ [
+ focal,
+ intrinsics[fidx][0, 2],
+ intrinsics[fidx][1, 2],
+ extra_params[fidx][0],
+ ]
+ )
+ elif camera_type == "SIMPLE_PINHOLE":
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
+ pycolmap_intri = np.array(
+ [
+ focal,
+ intrinsics[fidx][0, 2],
+ intrinsics[fidx][1, 2],
+ ]
+ )
+ else:
+ raise ValueError(
+ f"Camera type {camera_type} is not supported yet"
+ )
+
+ camera = pycolmap.Camera(
+ model=camera_type,
+ width=image_size[0],
+ height=image_size[1],
+ params=pycolmap_intri,
+ camera_id=fidx,
+ )
+
+ # add camera
+ reconstruction.add_camera(camera)
+
+ # set image
+ cam_from_world = pycolmap.Rigid3d(
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]),
+ extrinsics[fidx][:3, 3],
+ ) # Rot and Trans
+ image = pycolmap.Image(
+ id=fidx,
+ name=f"image_{fidx}",
+ camera_id=camera.camera_id,
+ cam_from_world=cam_from_world,
+ )
+
+ points2D_list = []
+
+ point2D_idx = 0
+ # NOTE point3D_id start by 1
+ for point3D_id in range(1, num_points3D + 1):
+ original_track_idx = valid_idx[point3D_id - 1]
+
+ if (
+ reconstruction.points3D[point3D_id].xyz < max_points3D_val
+ ).all():
+ if masks[fidx][original_track_idx]:
+ # It seems we don't need +0.5 for BA
+ point2D_xy = tracks[fidx][original_track_idx]
+ # Please note when adding the Point2D object
+ # It not only requires the 2D xy location, but also the id to 3D point
+ points2D_list.append(
+ pycolmap.Point2D(point2D_xy, point3D_id)
+ )
+
+ # add element
+ track = reconstruction.points3D[point3D_id].track
+ track.add_element(fidx, point2D_idx)
+ point2D_idx += 1
+
+ assert point2D_idx == len(points2D_list)
+
+ try:
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
+ image.registered = True
+ except:
+ print(f"frame {fidx} is out of BA")
+ image.registered = False
+
+ # add image
+ reconstruction.add_image(image)
+
+ return reconstruction
+
+
+def pycolmap_to_batch_matrix(
+ reconstruction, device="cuda", camera_type="SIMPLE_PINHOLE"
+):
+ """
+ Convert a PyCOLMAP Reconstruction Object to batched PyTorch tensors.
+
+ Args:
+ reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
+ device (str): The device to place the tensors on (default: "cuda").
+ camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
+
+ Returns:
+ tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
+ """
+
+ num_images = len(reconstruction.images)
+ max_points3D_id = max(reconstruction.point3D_ids())
+ points3D = np.zeros((max_points3D_id, 3))
+
+ for point3D_id in reconstruction.points3D:
+ points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
+ points3D = torch.from_numpy(points3D).to(device)
+
+ extrinsics = []
+ intrinsics = []
+
+ extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
+
+ for i in range(num_images):
+ # Extract and append extrinsics
+ pyimg = reconstruction.images[i]
+ pycam = reconstruction.cameras[pyimg.camera_id]
+ matrix = pyimg.cam_from_world.matrix()
+ extrinsics.append(matrix)
+
+ # Extract and append intrinsics
+ calibration_matrix = pycam.calibration_matrix()
+ intrinsics.append(calibration_matrix)
+
+ if camera_type == "SIMPLE_RADIAL":
+ extra_params.append(pycam.params[-1])
+
+ # Convert lists to torch tensors
+ extrinsics = torch.from_numpy(np.stack(extrinsics)).to(device)
+
+ intrinsics = torch.from_numpy(np.stack(intrinsics)).to(device)
+
+ if camera_type == "SIMPLE_RADIAL":
+ extra_params = torch.from_numpy(np.stack(extra_params)).to(device)
+ extra_params = extra_params[:, None]
+
+ return points3D, extrinsics, intrinsics, extra_params
+
+
+
+
+
+def project_3D_points(
+ points3D,
+ extrinsics,
+ intrinsics=None,
+ extra_params=None,
+ return_points_cam=False,
+ default=0,
+ only_points_cam=False,
+):
+ """
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
+ Args:
+ points3D (torch.Tensor): 3D points of shape Px3.
+ extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
+ intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
+ extra_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion.
+ Returns:
+ torch.Tensor: Transformed 2D points of shape BxNx2.
+ """
+ with torch.cuda.amp.autocast(dtype=torch.double):
+ N = points3D.shape[0] # Number of points
+ B = extrinsics.shape[0] # Batch size, i.e., number of cameras
+ points3D_homogeneous = torch.cat(
+ [points3D, torch.ones_like(points3D[..., 0:1])], dim=1
+ ) # Nx4
+ # Reshape for batch processing
+ points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(
+ B, -1, -1
+ ) # BxNx4
+
+ # Step 1: Apply extrinsic parameters
+ # Transform 3D points to camera coordinate system for all cameras
+ points_cam = torch.bmm(
+ extrinsics, points3D_homogeneous.transpose(-1, -2)
+ )
+
+ if only_points_cam:
+ return points_cam
+
+ # Step 2: Apply intrinsic parameters and (optional) distortion
+ points2D = img_from_cam(intrinsics, points_cam, extra_params)
+
+ if return_points_cam:
+ return points2D, points_cam
+ return points2D
+
+
+def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):
+ """
+ Applies intrinsic parameters and optional distortion to the given 3D points.
+
+ Args:
+ intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
+ points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
+ extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
+ default (float, optional): Default value to replace NaNs in the output.
+
+ Returns:
+ points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
+ """
+
+ # Normalize by the third coordinate (homogeneous division)
+ points_cam = points_cam / points_cam[:, 2:3, :]
+ # Extract uv
+ uv = points_cam[:, :2, :]
+
+ # Apply distortion if extra_params are provided
+ if extra_params is not None:
+ uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
+ uv = torch.stack([uu, vv], dim=1)
+
+ # Prepare points_cam for batch matrix multiplication
+ points_cam_homo = torch.cat(
+ (uv, torch.ones_like(uv[:, :1, :])), dim=1
+ ) # Bx3xN
+ # Apply intrinsic parameters using batch matrix multiplication
+ points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN
+
+ # Extract x and y coordinates
+ points2D = points2D_homo[:, :2, :] # Bx2xN
+
+ # Replace NaNs with default value
+ points2D = torch.nan_to_num(points2D, nan=default)
+
+ return points2D.transpose(1, 2) # BxNx2
+
diff --git a/src/utils/transforms.py b/src/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..809802e3c3c2f290b00fb1cd80d724535b554ec8
--- /dev/null
+++ b/src/utils/transforms.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# DUST3R default transforms
+# --------------------------------------------------------
+import torchvision.transforms as tvf
+
+from model.dataset.utils.image import ImgNorm
+
+# define the standard image transforms
+ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
diff --git a/src/utils/viz.py b/src/utils/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..c757e8c835626dabed6ad497338e96465e7dd2e1
--- /dev/null
+++ b/src/utils/viz.py
@@ -0,0 +1,377 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Visualization utilities using trimesh
+# --------------------------------------------------------
+import numpy as np
+import PIL.Image
+import torch
+from scipy.spatial.transform import Rotation
+
+from model.dataset.utils.device import to_numpy
+from model.dataset.utils.geometry import geotrf, get_med_dist_between_poses
+from model.dataset.utils.image import rgb
+
+try:
+ import trimesh
+except ImportError:
+ print("/!\\ module trimesh is not installed, cannot visualize results /!\\")
+
+
+def cat_3d(vecs):
+ if isinstance(vecs, (np.ndarray, torch.Tensor)):
+ vecs = [vecs]
+ return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
+
+
+def show_raw_pointcloud(pts3d, colors, point_size=2):
+ scene = trimesh.Scene()
+
+ pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
+ scene.add_geometry(pct)
+
+ scene.show(line_settings={"point_size": point_size})
+
+
+def pts3d_to_trimesh(img, pts3d, valid=None):
+ H, W, THREE = img.shape
+ assert THREE == 3
+ assert img.shape == pts3d.shape
+
+ vertices = pts3d.reshape(-1, 3)
+
+ # make squares: each pixel == 2 triangles
+ idx = np.arange(len(vertices)).reshape(H, W)
+ idx1 = idx[:-1, :-1].ravel() # top-left corner
+ idx2 = idx[:-1, +1:].ravel() # right-left corner
+ idx3 = idx[+1:, :-1].ravel() # bottom-left corner
+ idx4 = idx[+1:, +1:].ravel() # bottom-right corner
+ faces = np.concatenate(
+ (
+ np.c_[idx1, idx2, idx3],
+ np.c_[
+ idx3, idx2, idx1
+ ], # same triangle, but backward (cheap solution to cancel face culling)
+ np.c_[idx2, idx3, idx4],
+ np.c_[
+ idx4, idx3, idx2
+ ], # same triangle, but backward (cheap solution to cancel face culling)
+ ),
+ axis=0,
+ )
+
+ # prepare triangle colors
+ face_colors = np.concatenate(
+ (
+ img[:-1, :-1].reshape(-1, 3),
+ img[:-1, :-1].reshape(-1, 3),
+ img[+1:, +1:].reshape(-1, 3),
+ img[+1:, +1:].reshape(-1, 3),
+ ),
+ axis=0,
+ )
+
+ # remove invalid faces
+ if valid is not None:
+ assert valid.shape == (H, W)
+ valid_idxs = valid.ravel()
+ valid_faces = valid_idxs[faces].all(axis=-1)
+ faces = faces[valid_faces]
+ face_colors = face_colors[valid_faces]
+
+ assert len(faces) == len(face_colors)
+ return dict(vertices=vertices, face_colors=face_colors, faces=faces)
+
+
+def cat_meshes(meshes):
+ vertices, faces, colors = zip(
+ *[(m["vertices"], m["faces"], m["face_colors"]) for m in meshes]
+ )
+ n_vertices = np.cumsum([0] + [len(v) for v in vertices])
+ for i in range(len(faces)):
+ faces[i][:] += n_vertices[i]
+
+ vertices = np.concatenate(vertices)
+ colors = np.concatenate(colors)
+ faces = np.concatenate(faces)
+ return dict(vertices=vertices, face_colors=colors, faces=faces)
+
+
+def show_duster_pairs(view1, view2, pred1, pred2):
+ import matplotlib.pyplot as pl
+
+ pl.ion()
+
+ for e in range(len(view1["instance"])):
+ i = view1["idx"][e]
+ j = view2["idx"][e]
+ img1 = rgb(view1["img"][e])
+ img2 = rgb(view2["img"][e])
+ conf1 = pred1["conf"][e].squeeze()
+ conf2 = pred2["conf"][e].squeeze()
+ score = conf1.mean() * conf2.mean()
+ print(f">> Showing pair #{e} {i}-{j} {score=:g}")
+ pl.clf()
+ pl.subplot(221).imshow(img1)
+ pl.subplot(223).imshow(img2)
+ pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
+ pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
+ pts1 = pred1["pts3d"][e]
+ pts2 = pred2["pts3d_in_other_view"][e]
+ pl.subplots_adjust(0, 0, 1, 1, 0, 0)
+ if input("show pointcloud? (y/n) ") == "y":
+ show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
+
+
+def auto_cam_size(im_poses):
+ return 0.1 * get_med_dist_between_poses(im_poses)
+
+
+class SceneViz:
+ def __init__(self):
+ self.scene = trimesh.Scene()
+
+ def add_pointcloud(self, pts3d, color, mask=None):
+ pts3d = to_numpy(pts3d)
+ mask = to_numpy(mask)
+ if mask is None:
+ mask = [slice(None)] * len(pts3d)
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
+ pct = trimesh.PointCloud(pts.reshape(-1, 3))
+
+ if isinstance(color, (list, np.ndarray, torch.Tensor)):
+ color = to_numpy(color)
+ col = np.concatenate([p[m] for p, m in zip(color, mask)])
+ assert col.shape == pts.shape
+ pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
+ else:
+ assert len(color) == 3
+ pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
+
+ self.scene.add_geometry(pct)
+ return self
+
+ def add_camera(
+ self,
+ pose_c2w,
+ focal=None,
+ color=(0, 0, 0),
+ image=None,
+ imsize=None,
+ cam_size=0.03,
+ ):
+ pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
+ add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size)
+ return self
+
+ def add_cameras(
+ self, poses, focals=None, images=None, imsizes=None, colors=None, **kw
+ ):
+ def get(arr, idx):
+ return None if arr is None else arr[idx]
+
+ for i, pose_c2w in enumerate(poses):
+ self.add_camera(
+ pose_c2w,
+ get(focals, i),
+ image=get(images, i),
+ color=get(colors, i),
+ imsize=get(imsizes, i),
+ **kw,
+ )
+ return self
+
+ def show(self, point_size=2, viewer=None):
+ return self.scene.show(viewer=viewer, line_settings={"point_size": point_size})
+
+
+def show_raw_pointcloud_with_cams(
+ imgs, pts3d, mask, focals, cams2world, point_size=2, cam_size=0.05, cam_color=None
+):
+ """Visualization of a pointcloud with cameras
+ imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
+ pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
+ focals = (N,) or N-size list of [focal, ...]
+ cams2world = (N,4,4) or N-size list of [(4,4), ...]
+ """
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
+ pts3d = to_numpy(pts3d)
+ imgs = to_numpy(imgs)
+ focals = to_numpy(focals)
+ cams2world = to_numpy(cams2world)
+
+ scene = trimesh.Scene()
+
+ # full pointcloud
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
+ scene.add_geometry(pct)
+
+ # add each camera
+ for i, pose_c2w in enumerate(cams2world):
+ if isinstance(cam_color, list):
+ camera_edge_color = cam_color[i]
+ else:
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
+ add_scene_cam(
+ scene,
+ pose_c2w,
+ camera_edge_color,
+ imgs[i] if i < len(imgs) else None,
+ focals[i],
+ screen_width=cam_size,
+ )
+
+ scene.show(line_settings={"point_size": point_size})
+
+
+def add_scene_cam(
+ scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03
+):
+ if image is not None:
+ H, W, THREE = image.shape
+ assert THREE == 3
+ if image.dtype != np.uint8:
+ image = np.uint8(255 * image)
+ elif imsize is not None:
+ W, H = imsize
+ elif focal is not None:
+ H = W = focal / 1.1
+ else:
+ H = W = 1
+
+ if focal is None:
+ focal = min(H, W) * 1.1 # default value
+ elif isinstance(focal, np.ndarray):
+ focal = focal[0]
+
+ # create fake camera
+ height = focal * screen_width / H
+ width = screen_width * 0.5**0.5
+ rot45 = np.eye(4)
+ rot45[:3, :3] = Rotation.from_euler("z", np.deg2rad(45)).as_matrix()
+ rot45[2, 3] = -height # set the tip of the cone = optical center
+ aspect_ratio = np.eye(4)
+ aspect_ratio[0, 0] = W / H
+ transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
+ cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform)
+
+ # this is the image
+ if image is not None:
+ vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
+ faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
+ img = trimesh.Trimesh(vertices=vertices, faces=faces)
+ uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
+ img.visual = trimesh.visual.TextureVisuals(
+ uv_coords, image=PIL.Image.fromarray(image)
+ )
+ scene.add_geometry(img)
+
+ # this is the camera mesh
+ rot2 = np.eye(4)
+ rot2[:3, :3] = Rotation.from_euler("z", np.deg2rad(2)).as_matrix()
+ vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)]
+ vertices = geotrf(transform, vertices)
+ faces = []
+ for face in cam.faces:
+ if 0 in face:
+ continue
+ a, b, c = face
+ a2, b2, c2 = face + len(cam.vertices)
+ a3, b3, c3 = face + 2 * len(cam.vertices)
+
+ # add 3 pseudo-edges
+ faces.append((a, b, b2))
+ faces.append((a, a2, c))
+ faces.append((c2, b, c))
+
+ faces.append((a, b, b3))
+ faces.append((a, a3, c))
+ faces.append((c3, b, c))
+
+ # no culling
+ faces += [(c, b, a) for a, b, c in faces]
+
+ cam = trimesh.Trimesh(vertices=vertices, faces=faces)
+ cam.visual.face_colors[:, :3] = edge_color
+ scene.add_geometry(cam)
+
+
+def cat(a, b):
+ return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
+
+
+OPENGL = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+
+
+CAM_COLORS = [
+ (255, 0, 0),
+ (0, 0, 255),
+ (0, 255, 0),
+ (255, 0, 255),
+ (255, 204, 0),
+ (0, 204, 204),
+ (128, 255, 255),
+ (255, 128, 255),
+ (255, 255, 128),
+ (0, 0, 0),
+ (128, 128, 128),
+]
+
+
+def uint8(colors):
+ if not isinstance(colors, np.ndarray):
+ colors = np.array(colors)
+ if np.issubdtype(colors.dtype, np.floating):
+ colors *= 255
+ assert 0 <= colors.min() and colors.max() < 256
+ return np.uint8(colors)
+
+
+def segment_sky(image):
+ import cv2
+ from scipy import ndimage
+
+ # Convert to HSV
+ image = to_numpy(image)
+ if np.issubdtype(image.dtype, np.floating):
+ image = np.uint8(255 * image.clip(min=0, max=1))
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+ # Define range for blue color and create mask
+ lower_blue = np.array([0, 0, 100])
+ upper_blue = np.array([30, 255, 255])
+ mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
+
+ # add luminous gray
+ mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
+ mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
+ mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
+
+ # Morphological operations
+ kernel = np.ones((5, 5), np.uint8)
+ mask2 = ndimage.binary_opening(mask, structure=kernel)
+
+ # keep only largest CC
+ _, labels, stats, _ = cv2.connectedComponentsWithStats(
+ mask2.view(np.uint8), connectivity=8
+ )
+ cc_sizes = stats[1:, cv2.CC_STAT_AREA]
+ order = cc_sizes.argsort()[::-1] # bigger first
+ i = 0
+ selection = []
+ while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
+ selection.append(1 + order[i])
+ i += 1
+ mask3 = np.in1d(labels, selection).reshape(labels.shape)
+
+ # Apply mask
+ return torch.from_numpy(mask3)
diff --git a/src/visualization/annotation.py b/src/visualization/annotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebcbdbb668e0a32cfc05ee9c986b029a7ff50a34
--- /dev/null
+++ b/src/visualization/annotation.py
@@ -0,0 +1,49 @@
+from pathlib import Path
+from string import ascii_letters, digits, punctuation
+
+import numpy as np
+import torch
+from einops import rearrange
+from jaxtyping import Float
+from PIL import Image, ImageDraw, ImageFont
+from torch import Tensor
+
+from .layout import vcat
+
+EXPECTED_CHARACTERS = digits + punctuation + ascii_letters
+
+
+def draw_label(
+ text: str,
+ font: Path,
+ font_size: int,
+ device: torch.device = torch.device("cpu"),
+) -> Float[Tensor, "3 height width"]:
+ """Draw a black label on a white background with no border."""
+ try:
+ font = ImageFont.truetype(str(font), font_size)
+ except OSError:
+ font = ImageFont.load_default()
+ left, _, right, _ = font.getbbox(text)
+ width = right - left
+ _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS)
+ height = bottom - top
+ image = Image.new("RGB", (width, height), color="white")
+ draw = ImageDraw.Draw(image)
+ draw.text((0, 0), text, font=font, fill="black")
+ image = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device)
+ return rearrange(image, "h w c -> c h w")
+
+
+def add_label(
+ image: Float[Tensor, "3 width height"],
+ label: str,
+ font: Path = Path("assets/Inter-Regular.otf"),
+ font_size: int = 24,
+) -> Float[Tensor, "3 width_with_label height_with_label"]:
+ return vcat(
+ draw_label(label, font, font_size, image.device),
+ image,
+ align="left",
+ gap=4,
+ )
diff --git a/src/visualization/camera_trajectory/interpolation.py b/src/visualization/camera_trajectory/interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1c151441480e889a6e16bb8e705adbfb94ee60d
--- /dev/null
+++ b/src/visualization/camera_trajectory/interpolation.py
@@ -0,0 +1,255 @@
+import torch
+from einops import einsum, rearrange, reduce
+from jaxtyping import Float
+from scipy.spatial.transform import Rotation as R
+from torch import Tensor
+
+
+def interpolate_intrinsics(
+ initial: Float[Tensor, "*#batch 3 3"],
+ final: Float[Tensor, "*#batch 3 3"],
+ t: Float[Tensor, " time_step"],
+) -> Float[Tensor, "*batch time_step 3 3"]:
+ initial = rearrange(initial, "... i j -> ... () i j")
+ final = rearrange(final, "... i j -> ... () i j")
+ t = rearrange(t, "t -> t () ()")
+ return initial + (final - initial) * t
+
+
+def intersect_rays(
+ a_origins: Float[Tensor, "*#batch dim"],
+ a_directions: Float[Tensor, "*#batch dim"],
+ b_origins: Float[Tensor, "*#batch dim"],
+ b_directions: Float[Tensor, "*#batch dim"],
+) -> Float[Tensor, "*batch dim"]:
+ """Compute the least-squares intersection of rays. Uses the math from here:
+ https://math.stackexchange.com/a/1762491/286022
+ """
+
+ # Broadcast and stack the tensors.
+ a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors(
+ a_origins, a_directions, b_origins, b_directions
+ )
+ origins = torch.stack((a_origins, b_origins), dim=-2)
+ directions = torch.stack((a_directions, b_directions), dim=-2)
+
+ # Compute n_i * n_i^T - eye(3) from the equation.
+ n = einsum(directions, directions, "... n i, ... n j -> ... n i j")
+ n = n - torch.eye(3, dtype=origins.dtype, device=origins.device)
+
+ # Compute the left-hand side of the equation.
+ lhs = reduce(n, "... n i j -> ... i j", "sum")
+
+ # Compute the right-hand side of the equation.
+ rhs = einsum(n, origins, "... n i j, ... n j -> ... n i")
+ rhs = reduce(rhs, "... n i -> ... i", "sum")
+
+ # Left-matrix-multiply both sides by the inverse of lhs to find p.
+ return torch.linalg.lstsq(lhs, rhs).solution
+
+
+def normalize(a: Float[Tensor, "*#batch dim"]) -> Float[Tensor, "*#batch dim"]:
+ return a / a.norm(dim=-1, keepdim=True)
+
+
+def generate_coordinate_frame(
+ y: Float[Tensor, "*#batch 3"],
+ z: Float[Tensor, "*#batch 3"],
+) -> Float[Tensor, "*batch 3 3"]:
+ """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors."""
+ y, z = torch.broadcast_tensors(y, z)
+ return torch.stack([y.cross(z), y, z], dim=-1)
+
+
+def generate_rotation_coordinate_frame(
+ a: Float[Tensor, "*#batch 3"],
+ b: Float[Tensor, "*#batch 3"],
+ eps: float = 1e-4,
+) -> Float[Tensor, "*batch 3 3"]:
+ """Generate a coordinate frame where the Y direction is normal to the plane defined
+ by unit vectors a and b. The other axes are arbitrary."""
+ device = a.device
+
+ # Replace every entry in b that's parallel to the corresponding entry in a with an
+ # arbitrary vector.
+ b = b.detach().clone()
+ parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
+ b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device)
+ parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
+ b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device)
+
+ # Generate the coordinate frame. The initial cross product defines the plane.
+ return generate_coordinate_frame(normalize(a.cross(b)), a)
+
+
+def matrix_to_euler(
+ rotations: Float[Tensor, "*batch 3 3"],
+ pattern: str,
+) -> Float[Tensor, "*batch 3"]:
+ *batch, _, _ = rotations.shape
+ rotations = rotations.reshape(-1, 3, 3)
+ angles_np = R.from_matrix(rotations.detach().cpu().numpy()).as_euler(pattern)
+ rotations = torch.tensor(angles_np, dtype=rotations.dtype, device=rotations.device)
+ return rotations.reshape(*batch, 3)
+
+
+def euler_to_matrix(
+ rotations: Float[Tensor, "*batch 3"],
+ pattern: str,
+) -> Float[Tensor, "*batch 3 3"]:
+ *batch, _ = rotations.shape
+ rotations = rotations.reshape(-1, 3)
+ matrix_np = R.from_euler(pattern, rotations.detach().cpu().numpy()).as_matrix()
+ rotations = torch.tensor(matrix_np, dtype=rotations.dtype, device=rotations.device)
+ return rotations.reshape(*batch, 3, 3)
+
+
+def extrinsics_to_pivot_parameters(
+ extrinsics: Float[Tensor, "*#batch 4 4"],
+ pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"],
+ pivot_point: Float[Tensor, "*#batch 3"],
+) -> Float[Tensor, "*batch 5"]:
+ """Convert the extrinsics to a representation with 5 degrees of freedom:
+ 1. Distance from pivot point in the "X" (look cross pivot axis) direction.
+ 2. Distance from pivot point in the "Y" (pivot axis) direction.
+ 3. Distance from pivot point in the Z (look) direction
+ 4. Angle in plane
+ 5. Twist (rotation not in plane)
+ """
+
+ # The pivot coordinate frame's Z axis is normal to the plane.
+ pivot_axis = pivot_coordinate_frame[..., :, 1]
+
+ # Compute the translation elements of the pivot parametrization.
+ translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2])
+ origin = extrinsics[..., :3, 3]
+ delta = pivot_point - origin
+ translation = einsum(translation_frame, delta, "... i j, ... i -> ... j")
+
+ # Add the rotation elements of the pivot parametrization.
+ inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3]
+ y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1)
+
+ return torch.cat([translation, y[..., None], z[..., None]], dim=-1)
+
+
+def pivot_parameters_to_extrinsics(
+ parameters: Float[Tensor, "*#batch 5"],
+ pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"],
+ pivot_point: Float[Tensor, "*#batch 3"],
+) -> Float[Tensor, "*batch 4 4"]:
+ translation, y, z = parameters.split((3, 1, 1), dim=-1)
+
+ euler = torch.cat((y, torch.zeros_like(y), z), dim=-1)
+ rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ")
+
+ # The pivot coordinate frame's Z axis is normal to the plane.
+ pivot_axis = pivot_coordinate_frame[..., :, 1]
+
+ translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2])
+ delta = einsum(translation_frame, translation, "... i j, ... j -> ... i")
+ origin = pivot_point - delta
+
+ *batch, _ = origin.shape
+ extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device)
+ extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone()
+ extrinsics[..., 3, 3] = 1
+ extrinsics[..., :3, :3] = rotation
+ extrinsics[..., :3, 3] = origin
+ return extrinsics
+
+
+def interpolate_circular(
+ a: Float[Tensor, "*#batch"],
+ b: Float[Tensor, "*#batch"],
+ t: Float[Tensor, "*#batch"],
+) -> Float[Tensor, " *batch"]:
+ a, b, t = torch.broadcast_tensors(a, b, t)
+
+ tau = 2 * torch.pi
+ a = a % tau
+ b = b % tau
+
+ # Consider piecewise edge cases.
+ d = (b - a).abs()
+ a_left = a - tau
+ d_left = (b - a_left).abs()
+ a_right = a + tau
+ d_right = (b - a_right).abs()
+ use_d = (d < d_left) & (d < d_right)
+ use_d_left = (d_left < d_right) & (~use_d)
+ use_d_right = (~use_d) & (~use_d_left)
+
+ result = a + (b - a) * t
+ result[use_d_left] = (a_left + (b - a_left) * t)[use_d_left]
+ result[use_d_right] = (a_right + (b - a_right) * t)[use_d_right]
+
+ return result
+
+
+def interpolate_pivot_parameters(
+ initial: Float[Tensor, "*#batch 5"],
+ final: Float[Tensor, "*#batch 5"],
+ t: Float[Tensor, " time_step"],
+) -> Float[Tensor, "*batch time_step 5"]:
+ initial = rearrange(initial, "... d -> ... () d")
+ final = rearrange(final, "... d -> ... () d")
+ t = rearrange(t, "t -> t ()")
+ ti, ri = initial.split((3, 2), dim=-1)
+ tf, rf = final.split((3, 2), dim=-1)
+
+ t_lerp = ti + (tf - ti) * t
+ r_lerp = interpolate_circular(ri, rf, t)
+
+ return torch.cat((t_lerp, r_lerp), dim=-1)
+
+
+@torch.no_grad()
+def interpolate_extrinsics(
+ initial: Float[Tensor, "*#batch 4 4"],
+ final: Float[Tensor, "*#batch 4 4"],
+ t: Float[Tensor, " time_step"],
+ eps: float = 1e-4,
+) -> Float[Tensor, "*batch time_step 4 4"]:
+ """Interpolate extrinsics by rotating around their "focus point," which is the
+ least-squares intersection between the look vectors of the initial and final
+ extrinsics.
+ """
+
+ initial = initial.type(torch.float64)
+ final = final.type(torch.float64)
+ t = t.type(torch.float64)
+
+ # Based on the dot product between the look vectors, pick from one of two cases:
+ # 1. Look vectors are parallel: interpolate about their origins' midpoint.
+ # 3. Look vectors aren't parallel: interpolate about their focus point.
+ initial_look = initial[..., :3, 2]
+ final_look = final[..., :3, 2]
+ dot_products = einsum(initial_look, final_look, "... i, ... i -> ...")
+ parallel_mask = (dot_products.abs() - 1).abs() < eps
+
+ # Pick focus points.
+ initial_origin = initial[..., :3, 3]
+ final_origin = final[..., :3, 3]
+ pivot_point = 0.5 * (initial_origin + final_origin)
+ pivot_point[~parallel_mask] = intersect_rays(
+ initial_origin[~parallel_mask],
+ initial_look[~parallel_mask],
+ final_origin[~parallel_mask],
+ final_look[~parallel_mask],
+ )
+
+ # Convert to pivot parameters.
+ pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps)
+ initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point)
+ final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point)
+
+ # Interpolate the pivot parameters.
+ interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t)
+
+ # Convert back.
+ return pivot_parameters_to_extrinsics(
+ interpolated_params.type(torch.float32),
+ rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32),
+ rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32),
+ )
diff --git a/src/visualization/camera_trajectory/spin.py b/src/visualization/camera_trajectory/spin.py
new file mode 100644
index 0000000000000000000000000000000000000000..eadddcbc9b33075e902d610a3fc40ef0ba7bde2e
--- /dev/null
+++ b/src/visualization/camera_trajectory/spin.py
@@ -0,0 +1,37 @@
+import numpy as np
+import torch
+from einops import repeat
+from jaxtyping import Float
+from scipy.spatial.transform import Rotation as R
+from torch import Tensor
+
+
+def generate_spin(
+ num_frames: int,
+ device: torch.device,
+ elevation: float,
+ radius: float,
+) -> Float[Tensor, "frame 4 4"]:
+ # Translate back along the camera's look vector.
+ tf_translation = torch.eye(4, dtype=torch.float32, device=device)
+ tf_translation[:2] *= -1
+ tf_translation[2, 3] = -radius
+
+ # Generate the transformation for the azimuth.
+ phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
+ rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1)
+
+ azimuth = R.from_rotvec(rotation_vectors).as_matrix()
+ azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device)
+ tf_azimuth = torch.eye(4, dtype=torch.float32, device=device)
+ tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone()
+ tf_azimuth[:, :3, :3] = azimuth
+
+ # Generate the transformation for the elevation.
+ deg_elevation = np.deg2rad(elevation)
+ elevation = R.from_rotvec(np.array([deg_elevation, 0, 0], dtype=np.float32))
+ elevation = torch.tensor(elevation.as_matrix())
+ tf_elevation = torch.eye(4, dtype=torch.float32, device=device)
+ tf_elevation[:3, :3] = elevation
+
+ return tf_azimuth @ tf_elevation @ tf_translation
diff --git a/src/visualization/camera_trajectory/wobble.py b/src/visualization/camera_trajectory/wobble.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd71c8f8e3d808561894f993f6cd4988e469b39
--- /dev/null
+++ b/src/visualization/camera_trajectory/wobble.py
@@ -0,0 +1,32 @@
+import torch
+from einops import rearrange
+from jaxtyping import Float
+from torch import Tensor
+
+
+@torch.no_grad()
+def generate_wobble_transformation(
+ radius: Float[Tensor, "*#batch"],
+ t: Float[Tensor, " time_step"],
+ num_rotations: int = 1,
+ scale_radius_with_t: bool = True,
+) -> Float[Tensor, "*batch time_step 4 4"]:
+ # Generate a translation in the image plane.
+ tf = torch.eye(4, dtype=torch.float32, device=t.device)
+ tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
+ radius = radius[..., None]
+ if scale_radius_with_t:
+ radius = radius * t
+ tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius
+ tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius
+ return tf
+
+
+@torch.no_grad()
+def generate_wobble(
+ extrinsics: Float[Tensor, "*#batch 4 4"],
+ radius: Float[Tensor, "*#batch"],
+ t: Float[Tensor, " time_step"],
+) -> Float[Tensor, "*batch time_step 4 4"]:
+ tf = generate_wobble_transformation(radius, t)
+ return rearrange(extrinsics, "... i j -> ... () i j") @ tf
diff --git a/src/visualization/color_map.py b/src/visualization/color_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..2926fffd896610ebbc1b5786170dcba4cbdb1a73
--- /dev/null
+++ b/src/visualization/color_map.py
@@ -0,0 +1,48 @@
+import torch
+from colorspacious import cspace_convert
+from einops import rearrange
+from jaxtyping import Float
+from matplotlib import cm
+from torch import Tensor
+
+
+def apply_color_map(
+ x: Float[Tensor, " *batch"],
+ color_map: str = "inferno",
+) -> Float[Tensor, "*batch 3"]:
+ cmap = cm.get_cmap(color_map)
+
+ # Convert to NumPy so that Matplotlib color maps can be used.
+ mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3]
+
+ # Convert back to the original format.
+ return torch.tensor(mapped, device=x.device, dtype=torch.float32)
+
+
+def apply_color_map_to_image(
+ image: Float[Tensor, "*batch height width"],
+ color_map: str = "inferno",
+) -> Float[Tensor, "*batch 3 height with"]:
+ image = apply_color_map(image, color_map)
+ return rearrange(image, "... h w c -> ... c h w")
+
+
+def apply_color_map_2d(
+ x: Float[Tensor, "*#batch"],
+ y: Float[Tensor, "*#batch"],
+) -> Float[Tensor, "*batch 3"]:
+ red = cspace_convert((189, 0, 0), "sRGB255", "CIELab")
+ blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab")
+ white = cspace_convert((255, 255, 255), "sRGB255", "CIELab")
+ x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None]
+ y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None]
+
+ # Interpolate between red and blue on the x axis.
+ interpolated = x_np * red + (1 - x_np) * blue
+
+ # Interpolate between color and white on the y axis.
+ interpolated = y_np * interpolated + (1 - y_np) * white
+
+ # Convert to RGB.
+ rgb = cspace_convert(interpolated, "CIELab", "sRGB1")
+ return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1)
diff --git a/src/visualization/colors.py b/src/visualization/colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef82964e0bf71f147d41d95a6e389b1f22ba57b5
--- /dev/null
+++ b/src/visualization/colors.py
@@ -0,0 +1,32 @@
+from PIL import ImageColor
+
+# https://sashamaps.net/docs/resources/20-colors/
+DISTINCT_COLORS = [
+ "#e6194b",
+ "#3cb44b",
+ "#ffe119",
+ "#4363d8",
+ "#f58231",
+ "#911eb4",
+ "#46f0f0",
+ "#f032e6",
+ "#bcf60c",
+ "#fabebe",
+ "#008080",
+ "#e6beff",
+ "#9a6324",
+ "#fffac8",
+ "#800000",
+ "#aaffc3",
+ "#808000",
+ "#ffd8b1",
+ "#000075",
+ "#808080",
+ "#ffffff",
+ "#000000",
+]
+
+
+def get_distinct_color(index: int) -> tuple[float, float, float]:
+ hex = DISTINCT_COLORS[index % len(DISTINCT_COLORS)]
+ return tuple(x / 255 for x in ImageColor.getcolor(hex, "RGB"))
diff --git a/src/visualization/drawing/cameras.py b/src/visualization/drawing/cameras.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a60beb53a1a24931cd4b822e5ebe65b6425a6d6
--- /dev/null
+++ b/src/visualization/drawing/cameras.py
@@ -0,0 +1,195 @@
+from typing import Optional
+
+import torch
+from einops import einsum, rearrange, repeat
+from jaxtyping import Float
+from torch import Tensor
+
+from ...geometry.projection import unproject
+from ..annotation import add_label
+from .lines import draw_lines
+from .types import Scalar, sanitize_scalar
+
+
+def draw_cameras(
+ resolution: int,
+ extrinsics: Float[Tensor, "batch 4 4"],
+ intrinsics: Float[Tensor, "batch 3 3"],
+ color: Float[Tensor, "batch 3"],
+ near: Optional[Scalar] = None,
+ far: Optional[Scalar] = None,
+ margin: float = 0.1, # relative to AABB
+ frustum_scale: float = 0.05, # relative to image resolution
+) -> Float[Tensor, "3 3 height width"]:
+ device = extrinsics.device
+
+ # Compute scene bounds.
+ minima, maxima = compute_aabb(extrinsics, intrinsics, near, far)
+ scene_minima, scene_maxima = compute_equal_aabb_with_margin(
+ minima, maxima, margin=margin
+ )
+ span = (scene_maxima - scene_minima).max()
+
+ # Compute frustum locations.
+ corner_depth = (span * frustum_scale)[None]
+ frustum_corners = unproject_frustum_corners(extrinsics, intrinsics, corner_depth)
+ if near is not None:
+ near_corners = unproject_frustum_corners(extrinsics, intrinsics, near)
+ if far is not None:
+ far_corners = unproject_frustum_corners(extrinsics, intrinsics, far)
+
+ # Project the cameras onto each axis-aligned plane.
+ projections = []
+ for projected_axis in range(3):
+ image = torch.zeros(
+ (3, resolution, resolution),
+ dtype=torch.float32,
+ device=device,
+ )
+ image_x_axis = (projected_axis + 1) % 3
+ image_y_axis = (projected_axis + 2) % 3
+
+ def project(points: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 2"]:
+ x = points[..., image_x_axis]
+ y = points[..., image_y_axis]
+ return torch.stack([x, y], dim=-1)
+
+ x_range, y_range = torch.stack(
+ (project(scene_minima), project(scene_maxima)), dim=-1
+ )
+
+ # Draw near and far planes.
+ if near is not None:
+ projected_near_corners = project(near_corners)
+ image = draw_lines(
+ image,
+ rearrange(projected_near_corners, "b p xy -> (b p) xy"),
+ rearrange(projected_near_corners.roll(1, 1), "b p xy -> (b p) xy"),
+ color=0.25,
+ width=2,
+ x_range=x_range,
+ y_range=y_range,
+ )
+ if far is not None:
+ projected_far_corners = project(far_corners)
+ image = draw_lines(
+ image,
+ rearrange(projected_far_corners, "b p xy -> (b p) xy"),
+ rearrange(projected_far_corners.roll(1, 1), "b p xy -> (b p) xy"),
+ color=0.25,
+ width=2,
+ x_range=x_range,
+ y_range=y_range,
+ )
+ if near is not None and far is not None:
+ image = draw_lines(
+ image,
+ rearrange(projected_near_corners, "b p xy -> (b p) xy"),
+ rearrange(projected_far_corners, "b p xy -> (b p) xy"),
+ color=0.25,
+ width=2,
+ x_range=x_range,
+ y_range=y_range,
+ )
+
+ # Draw the camera frustums themselves.
+ projected_origins = project(extrinsics[:, :3, 3])
+ projected_frustum_corners = project(frustum_corners)
+ start = [
+ repeat(projected_origins, "b xy -> (b p) xy", p=4),
+ rearrange(projected_frustum_corners.roll(1, 1), "b p xy -> (b p) xy"),
+ ]
+ start = rearrange(torch.cat(start, dim=0), "(r b p) xy -> (b r p) xy", r=2, p=4)
+ image = draw_lines(
+ image,
+ start,
+ repeat(projected_frustum_corners, "b p xy -> (b r p) xy", r=2),
+ color=repeat(color, "b c -> (b r p) c", r=2, p=4),
+ width=2,
+ x_range=x_range,
+ y_range=y_range,
+ )
+
+ x_name = "XYZ"[image_x_axis]
+ y_name = "XYZ"[image_y_axis]
+ image = add_label(image, f"{x_name}{y_name} Projection")
+
+ # TODO: Draw axis indicators.
+ projections.append(image)
+
+ return torch.stack(projections)
+
+
+def compute_aabb(
+ extrinsics: Float[Tensor, "batch 4 4"],
+ intrinsics: Float[Tensor, "batch 3 3"],
+ near: Optional[Scalar] = None,
+ far: Optional[Scalar] = None,
+) -> tuple[
+ Float[Tensor, "3"], # minima of the scene
+ Float[Tensor, "3"], # maxima of the scene
+]:
+ """Compute an axis-aligned bounding box for the camera frustums."""
+
+ device = extrinsics.device
+
+ # These points are included in the AABB.
+ points = [extrinsics[:, :3, 3]]
+
+ if near is not None:
+ near = sanitize_scalar(near, device)
+ corners = unproject_frustum_corners(extrinsics, intrinsics, near)
+ points.append(rearrange(corners, "b p xyz -> (b p) xyz"))
+
+ if far is not None:
+ far = sanitize_scalar(far, device)
+ corners = unproject_frustum_corners(extrinsics, intrinsics, far)
+ points.append(rearrange(corners, "b p xyz -> (b p) xyz"))
+
+ points = torch.cat(points, dim=0)
+ return points.min(dim=0).values, points.max(dim=0).values
+
+
+def compute_equal_aabb_with_margin(
+ minima: Float[Tensor, "*#batch 3"],
+ maxima: Float[Tensor, "*#batch 3"],
+ margin: float = 0.1,
+) -> tuple[
+ Float[Tensor, "*batch 3"], # minima of the scene
+ Float[Tensor, "*batch 3"], # maxima of the scene
+]:
+ midpoint = (maxima + minima) * 0.5
+ span = (maxima - minima).max() * (1 + margin)
+ scene_minima = midpoint - 0.5 * span
+ scene_maxima = midpoint + 0.5 * span
+ return scene_minima, scene_maxima
+
+
+def unproject_frustum_corners(
+ extrinsics: Float[Tensor, "batch 4 4"],
+ intrinsics: Float[Tensor, "batch 3 3"],
+ depth: Float[Tensor, "#batch"],
+) -> Float[Tensor, "batch 4 3"]:
+ device = extrinsics.device
+
+ # Get coordinates for the corners. Following them in a circle makes a rectangle.
+ xy = torch.linspace(0, 1, 2, device=device)
+ xy = torch.stack(torch.meshgrid(xy, xy, indexing="xy"), dim=-1)
+ xy = rearrange(xy, "i j xy -> (i j) xy")
+ xy = xy[torch.tensor([0, 1, 3, 2], device=device)]
+
+ # Get ray directions in camera space.
+ directions = unproject(
+ xy,
+ torch.ones(1, dtype=torch.float32, device=device),
+ rearrange(intrinsics, "b i j -> b () i j"),
+ )
+
+ # Divide by the z coordinate so that multiplying by depth will produce orthographic
+ # depth (z depth) as opposed to Euclidean depth (distance from the camera).
+ directions = directions / directions[..., -1:]
+ directions = einsum(extrinsics[..., :3, :3], directions, "b i j, b r j -> b r i")
+
+ origins = rearrange(extrinsics[:, :3, 3], "b xyz -> b () xyz")
+ depth = rearrange(depth, "b -> b () ()")
+ return origins + depth * directions
diff --git a/src/visualization/drawing/coordinate_conversion.py b/src/visualization/drawing/coordinate_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..42ac5aca7cf903b61e6e9f9018c8aa1811300fa3
--- /dev/null
+++ b/src/visualization/drawing/coordinate_conversion.py
@@ -0,0 +1,44 @@
+from typing import Optional, Protocol, runtime_checkable
+
+import torch
+from jaxtyping import Float
+from torch import Tensor
+
+from .types import Pair, sanitize_pair
+
+
+@runtime_checkable
+class ConversionFunction(Protocol):
+ def __call__(
+ self,
+ xy: Float[Tensor, "*batch 2"],
+ ) -> Float[Tensor, "*batch 2"]:
+ pass
+
+
+def generate_conversions(
+ shape: tuple[int, int],
+ device: torch.device,
+ x_range: Optional[Pair] = None,
+ y_range: Optional[Pair] = None,
+) -> tuple[
+ ConversionFunction, # conversion from world coordinates to pixel coordinates
+ ConversionFunction, # conversion from pixel coordinates to world coordinates
+]:
+ h, w = shape
+ x_range = sanitize_pair((0, w) if x_range is None else x_range, device)
+ y_range = sanitize_pair((0, h) if y_range is None else y_range, device)
+ minima, maxima = torch.stack((x_range, y_range), dim=-1)
+ wh = torch.tensor((w, h), dtype=torch.float32, device=device)
+
+ def convert_world_to_pixel(
+ xy: Float[Tensor, "*batch 2"],
+ ) -> Float[Tensor, "*batch 2"]:
+ return (xy - minima) / (maxima - minima) * wh
+
+ def convert_pixel_to_world(
+ xy: Float[Tensor, "*batch 2"],
+ ) -> Float[Tensor, "*batch 2"]:
+ return xy / wh * (maxima - minima) + minima
+
+ return convert_world_to_pixel, convert_pixel_to_world
diff --git a/src/visualization/drawing/lines.py b/src/visualization/drawing/lines.py
new file mode 100644
index 0000000000000000000000000000000000000000..85ce39825f456f5511049b2b1237836ded600416
--- /dev/null
+++ b/src/visualization/drawing/lines.py
@@ -0,0 +1,83 @@
+from typing import Literal, Optional
+
+import torch
+from einops import einsum, repeat
+from jaxtyping import Float
+from torch import Tensor
+
+from .coordinate_conversion import generate_conversions
+from .rendering import render_over_image
+from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector
+
+
+def draw_lines(
+ image: Float[Tensor, "3 height width"],
+ start: Vector,
+ end: Vector,
+ color: Vector,
+ width: Scalar,
+ cap: Literal["butt", "round", "square"] = "round",
+ num_msaa_passes: int = 1,
+ x_range: Optional[Pair] = None,
+ y_range: Optional[Pair] = None,
+) -> Float[Tensor, "3 height width"]:
+ device = image.device
+ start = sanitize_vector(start, 2, device)
+ end = sanitize_vector(end, 2, device)
+ color = sanitize_vector(color, 3, device)
+ width = sanitize_scalar(width, device)
+ (num_lines,) = torch.broadcast_shapes(
+ start.shape[0],
+ end.shape[0],
+ color.shape[0],
+ width.shape,
+ )
+
+ # Convert world-space points to pixel space.
+ _, h, w = image.shape
+ world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range)
+ start = world_to_pixel(start)
+ end = world_to_pixel(end)
+
+ def color_function(
+ xy: Float[Tensor, "point 2"],
+ ) -> Float[Tensor, "point 4"]:
+ # Define a vector between the start and end points.
+ delta = end - start
+ delta_norm = delta.norm(dim=-1, keepdim=True)
+ u_delta = delta / delta_norm
+
+ # Define a vector between each sample and the start point.
+ indicator = xy - start[:, None]
+
+ # Determine whether each sample is inside the line in the parallel direction.
+ extra = 0.5 * width[:, None] if cap == "square" else 0
+ parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s")
+ parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra)
+
+ # Determine whether each sample is inside the line perpendicularly.
+ perpendicular = indicator - parallel[..., None] * u_delta[:, None]
+ perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None]
+
+ inside_line = parallel_inside_line & perpendicular_inside_line
+
+ # Compute round caps.
+ if cap == "round":
+ near_start = indicator.norm(dim=-1) < 0.5 * width[:, None]
+ inside_line |= near_start
+ end_indicator = indicator = xy - end[:, None]
+ near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None]
+ inside_line |= near_end
+
+ # Determine the sample's color.
+ selectable_color = color.broadcast_to((num_lines, 3))
+ arrangement = inside_line * torch.arange(num_lines, device=device)[:, None]
+ top_color = selectable_color.gather(
+ dim=0,
+ index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3),
+ )
+ rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1)
+
+ return rgba
+
+ return render_over_image(image, color_function, device, num_passes=num_msaa_passes)
diff --git a/src/visualization/drawing/points.py b/src/visualization/drawing/points.py
new file mode 100644
index 0000000000000000000000000000000000000000..671db100d34cd9121cb2778dcdb7252ec915bb2d
--- /dev/null
+++ b/src/visualization/drawing/points.py
@@ -0,0 +1,59 @@
+from typing import Optional
+
+import torch
+from einops import repeat
+from jaxtyping import Float
+from torch import Tensor
+
+from .coordinate_conversion import generate_conversions
+from .rendering import render_over_image
+from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector
+
+
+def draw_points(
+ image: Float[Tensor, "3 height width"],
+ points: Vector,
+ color: Vector = [1, 1, 1],
+ radius: Scalar = 1,
+ inner_radius: Scalar = 0,
+ num_msaa_passes: int = 1,
+ x_range: Optional[Pair] = None,
+ y_range: Optional[Pair] = None,
+) -> Float[Tensor, "3 height width"]:
+ device = image.device
+ points = sanitize_vector(points, 2, device)
+ color = sanitize_vector(color, 3, device)
+ radius = sanitize_scalar(radius, device)
+ inner_radius = sanitize_scalar(inner_radius, device)
+ (num_points,) = torch.broadcast_shapes(
+ points.shape[0],
+ color.shape[0],
+ radius.shape,
+ inner_radius.shape,
+ )
+
+ # Convert world-space points to pixel space.
+ _, h, w = image.shape
+ world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range)
+ points = world_to_pixel(points)
+
+ def color_function(
+ xy: Float[Tensor, "point 2"],
+ ) -> Float[Tensor, "point 4"]:
+ # Define a vector between the start and end points.
+ delta = xy[:, None] - points[None]
+ delta_norm = delta.norm(dim=-1)
+ mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None])
+
+ # Determine the sample's color.
+ selectable_color = color.broadcast_to((num_points, 3))
+ arrangement = mask * torch.arange(num_points, device=device)
+ top_color = selectable_color.gather(
+ dim=0,
+ index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3),
+ )
+ rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1)
+
+ return rgba
+
+ return render_over_image(image, color_function, device, num_passes=num_msaa_passes)
diff --git a/src/visualization/drawing/rendering.py b/src/visualization/drawing/rendering.py
new file mode 100644
index 0000000000000000000000000000000000000000..65842a8ec5b9c6b109d19ef9509bc04ca0437ea7
--- /dev/null
+++ b/src/visualization/drawing/rendering.py
@@ -0,0 +1,152 @@
+from typing import Protocol, runtime_checkable
+
+import torch
+from einops import rearrange, reduce
+from jaxtyping import Bool, Float
+from torch import Tensor
+
+
+@runtime_checkable
+class ColorFunction(Protocol):
+ def __call__(
+ self,
+ xy: Float[Tensor, "point 2"],
+ ) -> Float[Tensor, "point 4"]: # RGBA color
+ pass
+
+
+def generate_sample_grid(
+ shape: tuple[int, int],
+ device: torch.device,
+) -> Float[Tensor, "height width 2"]:
+ h, w = shape
+ x = torch.arange(w, device=device) + 0.5
+ y = torch.arange(h, device=device) + 0.5
+ x, y = torch.meshgrid(x, y, indexing="xy")
+ return torch.stack([x, y], dim=-1)
+
+
+def detect_msaa_pixels(
+ image: Float[Tensor, "batch 4 height width"],
+) -> Bool[Tensor, "batch height width"]:
+ b, _, h, w = image.shape
+
+ mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device)
+
+ # Detect horizontal differences.
+ horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1)
+ mask[:, :, 1:] |= horizontal
+ mask[:, :, :-1] |= horizontal
+
+ # Detect vertical differences.
+ vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1)
+ mask[:, 1:, :] |= vertical
+ mask[:, :-1, :] |= vertical
+
+ # Detect diagonal (top left to bottom right) differences.
+ tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1)
+ mask[:, 1:, 1:] |= tlbr
+ mask[:, :-1, :-1] |= tlbr
+
+ # Detect diagonal (top right to bottom left) differences.
+ trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1)
+ mask[:, :-1, 1:] |= trbl
+ mask[:, 1:, :-1] |= trbl
+
+ return mask
+
+
+def reduce_straight_alpha(
+ rgba: Float[Tensor, "batch 4 height width"],
+) -> Float[Tensor, "batch 4"]:
+ color, alpha = rgba.split((3, 1), dim=1)
+
+ # Color becomes a weighted average of color (weighted by alpha).
+ weighted_color = reduce(color * alpha, "b c h w -> b c", "sum")
+ alpha_sum = reduce(alpha, "b c h w -> b c", "sum")
+ color = weighted_color / (alpha_sum + 1e-10)
+
+ # Alpha becomes mean alpha.
+ alpha = reduce(alpha, "b c h w -> b c", "mean")
+
+ return torch.cat((color, alpha), dim=-1)
+
+
+@torch.no_grad()
+def run_msaa_pass(
+ xy: Float[Tensor, "batch height width 2"],
+ color_function: ColorFunction,
+ scale: float,
+ subdivision: int,
+ remaining_passes: int,
+ device: torch.device,
+ batch_size: int = int(2**16),
+) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha)
+ # Sample the color function.
+ b, h, w, _ = xy.shape
+ color = [
+ color_function(batch)
+ for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size)
+ ]
+ color = torch.cat(color, dim=0)
+ color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w)
+
+ # If any MSAA passes remain, subdivide.
+ if remaining_passes > 0:
+ mask = detect_msaa_pixels(color)
+ batch_index, row_index, col_index = torch.where(mask)
+ xy = xy[batch_index, row_index, col_index]
+
+ offsets = generate_sample_grid((subdivision, subdivision), device)
+ offsets = (offsets / subdivision - 0.5) * scale
+
+ color_fine = run_msaa_pass(
+ xy[:, None, None] + offsets,
+ color_function,
+ scale / subdivision,
+ subdivision,
+ remaining_passes - 1,
+ device,
+ batch_size=batch_size,
+ )
+ color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine)
+
+ return color
+
+
+@torch.no_grad()
+def render(
+ shape: tuple[int, int],
+ color_function: ColorFunction,
+ device: torch.device,
+ subdivision: int = 8,
+ num_passes: int = 2,
+) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha)
+ xy = generate_sample_grid(shape, device)
+ return run_msaa_pass(
+ xy[None],
+ color_function,
+ 1.0,
+ subdivision,
+ num_passes,
+ device,
+ )[0]
+
+
+def render_over_image(
+ image: Float[Tensor, "3 height width"],
+ color_function: ColorFunction,
+ device: torch.device,
+ subdivision: int = 8,
+ num_passes: int = 1,
+) -> Float[Tensor, "3 height width"]:
+ _, h, w = image.shape
+ overlay = render(
+ (h, w),
+ color_function,
+ device,
+ subdivision=subdivision,
+ num_passes=num_passes,
+ )
+ color, alpha = overlay.split((3, 1), dim=0)
+ return image * (1 - alpha) + color * alpha
diff --git a/src/visualization/drawing/types.py b/src/visualization/drawing/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bbbe3a068052ac389a8332304ddbcea78762b86
--- /dev/null
+++ b/src/visualization/drawing/types.py
@@ -0,0 +1,67 @@
+from typing import Iterable, Union
+
+import torch
+from einops import repeat
+from jaxtyping import Float, Shaped
+from torch import Tensor
+
+Real = Union[float, int]
+
+Vector = Union[
+ Real,
+ Iterable[Real],
+ Shaped[Tensor, "dim"],
+ Shaped[Tensor, "batch dim"],
+]
+
+
+def sanitize_vector(
+ vector: Vector,
+ dim: int,
+ device: torch.device,
+) -> Float[Tensor, "*#batch dim"]:
+ if isinstance(vector, Tensor):
+ vector = vector.type(torch.float32).to(device)
+ else:
+ vector = torch.tensor(vector, dtype=torch.float32, device=device)
+ while vector.ndim < 2:
+ vector = vector[None]
+ if vector.shape[-1] == 1:
+ vector = repeat(vector, "... () -> ... c", c=dim)
+ assert vector.shape[-1] == dim
+ assert vector.ndim == 2
+ return vector
+
+
+Scalar = Union[
+ Real,
+ Iterable[Real],
+ Shaped[Tensor, ""],
+ Shaped[Tensor, " batch"],
+]
+
+
+def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]:
+ if isinstance(scalar, Tensor):
+ scalar = scalar.type(torch.float32).to(device)
+ else:
+ scalar = torch.tensor(scalar, dtype=torch.float32, device=device)
+ while scalar.ndim < 1:
+ scalar = scalar[None]
+ assert scalar.ndim == 1
+ return scalar
+
+
+Pair = Union[
+ Iterable[Real],
+ Shaped[Tensor, "2"],
+]
+
+
+def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]:
+ if isinstance(pair, Tensor):
+ pair = pair.type(torch.float32).to(device)
+ else:
+ pair = torch.tensor(pair, dtype=torch.float32, device=device)
+ assert pair.shape == (2,)
+ return pair
diff --git a/src/visualization/layout.py b/src/visualization/layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca44b8a0814f8e601d82aaea9dd737d8881c972c
--- /dev/null
+++ b/src/visualization/layout.py
@@ -0,0 +1,228 @@
+"""This file contains useful layout utilities for images. They are:
+
+- add_border: Add a border to an image.
+- cat/hcat/vcat: Join images by arranging them in a line. If the images have different
+ sizes, they are aligned as specified (start, end, center). Allows you to specify a gap
+ between images.
+
+Images are assumed to be float32 tensors with shape (channel, height, width).
+"""
+
+from typing import Any, Generator, Iterable, Literal, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+from torch import Tensor
+
+Alignment = Literal["start", "center", "end"]
+Axis = Literal["horizontal", "vertical"]
+Color = Union[
+ int,
+ float,
+ Iterable[int],
+ Iterable[float],
+ Float[Tensor, "#channel"],
+ Float[Tensor, ""],
+]
+
+
+def _sanitize_color(color: Color) -> Float[Tensor, "#channel"]:
+ # Convert tensor to list (or individual item).
+ if isinstance(color, torch.Tensor):
+ color = color.tolist()
+
+ # Turn iterators and individual items into lists.
+ if isinstance(color, Iterable):
+ color = list(color)
+ else:
+ color = [color]
+
+ return torch.tensor(color, dtype=torch.float32)
+
+
+def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]:
+ it = iter(iterable)
+ yield next(it)
+ for item in it:
+ yield delimiter
+ yield item
+
+
+def _get_main_dim(main_axis: Axis) -> int:
+ return {
+ "horizontal": 2,
+ "vertical": 1,
+ }[main_axis]
+
+
+def _get_cross_dim(main_axis: Axis) -> int:
+ return {
+ "horizontal": 1,
+ "vertical": 2,
+ }[main_axis]
+
+
+def _compute_offset(base: int, overlay: int, align: Alignment) -> slice:
+ assert base >= overlay
+ offset = {
+ "start": 0,
+ "center": (base - overlay) // 2,
+ "end": base - overlay,
+ }[align]
+ return slice(offset, offset + overlay)
+
+
+def overlay(
+ base: Float[Tensor, "channel base_height base_width"],
+ overlay: Float[Tensor, "channel overlay_height overlay_width"],
+ main_axis: Axis,
+ main_axis_alignment: Alignment,
+ cross_axis_alignment: Alignment,
+) -> Float[Tensor, "channel base_height base_width"]:
+ # The overlay must be smaller than the base.
+ _, base_height, base_width = base.shape
+ _, overlay_height, overlay_width = overlay.shape
+ assert base_height >= overlay_height and base_width >= overlay_width
+
+ # Compute spacing on the main dimension.
+ main_dim = _get_main_dim(main_axis)
+ main_slice = _compute_offset(
+ base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment
+ )
+
+ # Compute spacing on the cross dimension.
+ cross_dim = _get_cross_dim(main_axis)
+ cross_slice = _compute_offset(
+ base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment
+ )
+
+ # Combine the slices and paste the overlay onto the base accordingly.
+ selector = [..., None, None]
+ selector[main_dim] = main_slice
+ selector[cross_dim] = cross_slice
+ result = base.clone()
+ result[selector] = overlay
+ return result
+
+
+def cat(
+ main_axis: Axis,
+ *images: Iterable[Float[Tensor, "channel _ _"]],
+ align: Alignment = "center",
+ gap: int = 8,
+ gap_color: Color = 1,
+) -> Float[Tensor, "channel height width"]:
+ """Arrange images in a line. The interface resembles a CSS div with flexbox."""
+ device = images[0].device
+ gap_color = _sanitize_color(gap_color).to(device)
+
+ # Find the maximum image side length in the cross axis dimension.
+ cross_dim = _get_cross_dim(main_axis)
+ cross_axis_length = max(image.shape[cross_dim] for image in images)
+
+ # Pad the images.
+ padded_images = []
+ for image in images:
+ # Create an empty image with the correct size.
+ padded_shape = list(image.shape)
+ padded_shape[cross_dim] = cross_axis_length
+ base = torch.ones(padded_shape, dtype=torch.float32, device=device)
+ base = base * gap_color[:, None, None]
+ padded_images.append(overlay(base, image, main_axis, "start", align))
+
+ # Intersperse separators if necessary.
+ if gap > 0:
+ # Generate a separator.
+ c, _, _ = images[0].shape
+ separator_size = [gap, gap]
+ separator_size[cross_dim - 1] = cross_axis_length
+ separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device)
+ separator = separator * gap_color[:, None, None]
+
+ # Intersperse the separator between the images.
+ padded_images = list(_intersperse(padded_images, separator))
+
+ return torch.cat(padded_images, dim=_get_main_dim(main_axis))
+
+
+def hcat(
+ *images: Iterable[Float[Tensor, "channel _ _"]],
+ align: Literal["start", "center", "end", "top", "bottom"] = "start",
+ gap: int = 8,
+ gap_color: Color = 1,
+):
+ """Shorthand for a horizontal linear concatenation."""
+ return cat(
+ "horizontal",
+ *images,
+ align={
+ "start": "start",
+ "center": "center",
+ "end": "end",
+ "top": "start",
+ "bottom": "end",
+ }[align],
+ gap=gap,
+ gap_color=gap_color,
+ )
+
+
+def vcat(
+ *images: Iterable[Float[Tensor, "channel _ _"]],
+ align: Literal["start", "center", "end", "left", "right"] = "start",
+ gap: int = 8,
+ gap_color: Color = 1,
+):
+ """Shorthand for a horizontal linear concatenation."""
+ return cat(
+ "vertical",
+ *images,
+ align={
+ "start": "start",
+ "center": "center",
+ "end": "end",
+ "left": "start",
+ "right": "end",
+ }[align],
+ gap=gap,
+ gap_color=gap_color,
+ )
+
+
+def add_border(
+ image: Float[Tensor, "channel height width"],
+ border: int = 8,
+ color: Color = 1,
+) -> Float[Tensor, "channel new_height new_width"]:
+ color = _sanitize_color(color).to(image)
+ c, h, w = image.shape
+ result = torch.empty(
+ (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device
+ )
+ result[:] = color[:, None, None]
+ result[:, border : h + border, border : w + border] = image
+ return result
+
+
+def resize(
+ image: Float[Tensor, "channel height width"],
+ shape: Optional[tuple[int, int]] = None,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+) -> Float[Tensor, "channel new_height new_width"]:
+ assert (shape is not None) + (width is not None) + (height is not None) == 1
+ _, h, w = image.shape
+
+ if width is not None:
+ shape = (int(h * width / w), width)
+ elif height is not None:
+ shape = (height, int(w * height / h))
+
+ return F.interpolate(
+ image[None],
+ shape,
+ mode="bilinear",
+ align_corners=False,
+ antialias="bilinear",
+ )[0]
diff --git a/src/visualization/validation_in_3d.py b/src/visualization/validation_in_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fceafe9493cb33260c0446d7fd07fd2a6a27bb1
--- /dev/null
+++ b/src/visualization/validation_in_3d.py
@@ -0,0 +1,115 @@
+import torch
+from jaxtyping import Float, Shaped
+from torch import Tensor
+
+from ..model.decoder.cuda_splatting import render_cuda_orthographic
+from ..model.types import Gaussians
+from ..visualization.annotation import add_label
+from ..visualization.drawing.cameras import draw_cameras
+from .drawing.cameras import compute_equal_aabb_with_margin
+
+
+def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]:
+ shapes = torch.stack([torch.tensor(x.shape) for x in images])
+ padded_shape = shapes.max(dim=0)[0]
+ results = [
+ torch.ones(padded_shape.tolist(), dtype=x.dtype, device=x.device)
+ for x in images
+ ]
+ for image, result in zip(images, results):
+ slices = [slice(0, x) for x in image.shape]
+ result[slices] = image[slices]
+ return results
+
+
+def render_projections(
+ gaussians: Gaussians,
+ resolution: int,
+ margin: float = 0.1,
+ draw_label: bool = True,
+ extra_label: str = "",
+) -> Float[Tensor, "batch 3 3 height width"]:
+ device = gaussians.means.device
+ b, _, _ = gaussians.means.shape
+
+ # Compute the minima and maxima of the scene.
+ minima = gaussians.means.min(dim=1).values
+ maxima = gaussians.means.max(dim=1).values
+ scene_minima, scene_maxima = compute_equal_aabb_with_margin(
+ minima, maxima, margin=margin
+ )
+
+ projections = []
+ for look_axis in range(3):
+ right_axis = (look_axis + 1) % 3
+ down_axis = (look_axis + 2) % 3
+
+ # Define the extrinsics for rendering.
+ extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device)
+ extrinsics[:, right_axis, 0] = 1
+ extrinsics[:, down_axis, 1] = 1
+ extrinsics[:, look_axis, 2] = 1
+ extrinsics[:, right_axis, 3] = 0.5 * (
+ scene_minima[:, right_axis] + scene_maxima[:, right_axis]
+ )
+ extrinsics[:, down_axis, 3] = 0.5 * (
+ scene_minima[:, down_axis] + scene_maxima[:, down_axis]
+ )
+ extrinsics[:, look_axis, 3] = scene_minima[:, look_axis]
+ extrinsics[:, 3, 3] = 1
+
+ # Define the intrinsics for rendering.
+ extents = scene_maxima - scene_minima
+ far = extents[:, look_axis]
+ near = torch.zeros_like(far)
+ width = extents[:, right_axis]
+ height = extents[:, down_axis]
+
+ projection = render_cuda_orthographic(
+ extrinsics,
+ width,
+ height,
+ near,
+ far,
+ (resolution, resolution),
+ torch.zeros((b, 3), dtype=torch.float32, device=device),
+ gaussians.means,
+ gaussians.covariances,
+ gaussians.harmonics,
+ gaussians.opacities,
+ fov_degrees=10.0,
+ )
+ if draw_label:
+ right_axis_name = "XYZ"[right_axis]
+ down_axis_name = "XYZ"[down_axis]
+ label = f"{right_axis_name}{down_axis_name} Projection {extra_label}"
+ projection = torch.stack([add_label(x, label) for x in projection])
+
+ projections.append(projection)
+
+ return torch.stack(pad(projections), dim=1)
+
+
+def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]:
+ # Define colors for context and target views.
+ num_context_views = batch["context"]["extrinsics"].shape[1]
+ num_target_views = batch["target"]["extrinsics"].shape[1]
+ color = torch.ones(
+ (num_target_views + num_context_views, 3),
+ dtype=torch.float32,
+ device=batch["target"]["extrinsics"].device,
+ )
+ color[num_context_views:, 1:] = 0
+
+ return draw_cameras(
+ resolution,
+ torch.cat(
+ (batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0])
+ ),
+ torch.cat(
+ (batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0])
+ ),
+ color,
+ torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])),
+ torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])),
+ )
diff --git a/src/visualization/video_render.py b/src/visualization/video_render.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a6153e8d5d98b947f533fae19087122e56bf6f
--- /dev/null
+++ b/src/visualization/video_render.py
@@ -0,0 +1,162 @@
+from typing import Protocol, runtime_checkable
+
+import cv2
+import torch
+from einops import repeat, pack
+from jaxtyping import Float
+from torch import Tensor
+
+from .camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics
+from .camera_trajectory.wobble import generate_wobble, generate_wobble_transformation
+from .layout import vcat
+from ..dataset.types import BatchedExample
+from ..misc.image_io import save_video
+from ..misc.utils import vis_depth_map
+from ..model.decoder import Decoder
+from ..model.types import Gaussians
+
+
+@runtime_checkable
+class TrajectoryFn(Protocol):
+ def __call__(
+ self,
+ t: Float[Tensor, " t"],
+ ) -> tuple[
+ Float[Tensor, "batch view 4 4"], # extrinsics
+ Float[Tensor, "batch view 3 3"], # intrinsics
+ ]:
+ pass
+
+
+def render_video_wobble(
+ gaussians: Gaussians,
+ decoder: Decoder,
+ batch: BatchedExample,
+ num_frames: int = 60,
+ smooth: bool = True,
+ loop_reverse: bool = True,
+ add_depth: bool = False,
+) -> Tensor:
+ # Two views are needed to get the wobble radius,use the first and the last view
+ _, v, _, _ = batch["context"]["extrinsics"].shape
+
+ def trajectory_fn(t):
+ origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
+ origin_b = batch["context"]["extrinsics"][:, -1, :3, 3]
+ delta = (origin_a - origin_b).norm(dim=-1)
+ extrinsics = generate_wobble(
+ batch["context"]["extrinsics"][:, 0],
+ delta * 0.25,
+ t,
+ )
+ intrinsics = repeat(
+ batch["context"]["intrinsics"][:, 0],
+ "b i j -> b v i j",
+ v=t.shape[0],
+ )
+ return extrinsics, intrinsics
+
+ return render_video_generic(gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth)
+
+
+def render_video_interpolation(
+ gaussians: Gaussians,
+ decoder: Decoder,
+ batch: BatchedExample,
+ num_frames: int = 60,
+ smooth: bool = True,
+ loop_reverse: bool = True,
+ add_depth: bool = False,
+) -> Tensor:
+ _, v, _, _ = batch["context"]["extrinsics"].shape
+
+ def trajectory_fn(t):
+ extrinsics = interpolate_extrinsics(
+ batch["context"]["extrinsics"][0, 0],
+ batch["context"]["extrinsics"][0, -1],
+ t,
+ )
+ intrinsics = interpolate_intrinsics(
+ batch["context"]["intrinsics"][0, 0],
+ batch["context"]["intrinsics"][0, -1],
+ t,
+ )
+ return extrinsics[None], intrinsics[None]
+
+ return render_video_generic(gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth)
+
+
+def render_video_interpolation_exaggerated(
+ gaussians: Gaussians,
+ decoder: Decoder,
+ batch: BatchedExample,
+ num_frames: int = 300,
+ smooth: bool = False,
+ loop_reverse: bool = False,
+ add_depth: bool = False,
+) -> Tensor:
+ # Two views are needed to get the wobble radius.
+ _, v, _, _ = batch["context"]["extrinsics"].shape
+
+ def trajectory_fn(t):
+ origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
+ origin_b = batch["context"]["extrinsics"][:, -1, :3, 3]
+ delta = (origin_a - origin_b).norm(dim=-1)
+ tf = generate_wobble_transformation(
+ delta * 0.5,
+ t,
+ 5,
+ scale_radius_with_t=False,
+ )
+ extrinsics = interpolate_extrinsics(
+ batch["context"]["extrinsics"][0, 0],
+ batch["context"]["extrinsics"][0, -1],
+ t * 5 - 2,
+ )
+ intrinsics = interpolate_intrinsics(
+ batch["context"]["intrinsics"][0, 0],
+ batch["context"]["extrinsics"][0, -1],
+ t * 5 - 2,
+ )
+ return extrinsics @ tf, intrinsics[None]
+
+ return render_video_generic(gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth)
+
+
+def render_video_generic(
+ gaussians: Gaussians,
+ decoder: Decoder,
+ batch: BatchedExample,
+ trajectory_fn: TrajectoryFn,
+ num_frames: int = 30,
+ smooth: bool = True,
+ loop_reverse: bool = True,
+ add_depth: bool = False,
+) -> Tensor:
+ device = gaussians.means.device
+
+ t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=device)
+ if smooth:
+ t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
+
+ extrinsics, intrinsics = trajectory_fn(t)
+
+ _, _, _, h, w = batch["context"]["image"].shape
+
+ near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
+ far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
+ output = decoder.forward(
+ gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
+ )
+ images = [
+ vcat(rgb, depth) if add_depth else rgb
+ for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
+ ]
+
+ video = torch.stack(images)
+ # video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
+ if loop_reverse:
+ # video = pack([video, video[::-1][1:-1]], "* c h w")[0]
+ video = pack([video, video.flip(dims=(0,))[1:-1]], "* c h w")[0]
+
+ return video