# diffusionsfm/model/diffuser_dpt.py
import ipdb # noqa: F401
import numpy as np
import torch
import torch.nn as nn
from diffusionsfm.model.dit import DiT
from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino
from diffusionsfm.model.blocks import _make_fusion_block, _make_scratch
from diffusionsfm.model.scheduler import NoiseScheduler
# Functional nearest-neighbor upsampling, used to build the DPT feature pyramid.
def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int):
    """Upsample `x` (NCHW) by `scale_factor` using nearest-neighbor interpolation."""
    s = scale_factor
    # (N, C, H, W) -> (N, C, H, W, 1, 1) -> (N, C, H, W, s, s) -> (N, C, H, s, W, s)
    # -> (N, C, s * H, s * W): each input pixel is replicated into an s x s block.
    return (
        x.reshape(*x.shape, 1, 1)
        .expand(*x.shape, s, s)
        .transpose(-2, -3)
        .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))
    )
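

# Example (illustrative): a (1, 1, 2, 2) input upsampled by s=2 becomes (1, 1, 4, 4);
# torch.arange(4.0).reshape(1, 1, 2, 2) maps to rows [0, 0, 1, 1], [0, 0, 1, 1],
# [2, 2, 3, 3], [2, 2, 3, 3].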
class ProjectReadout(nn.Module):
    """DPT-style readout projection: fuses the CLS token into every patch token."""

    def __init__(self, in_features, start_index=1):
        super().__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        # x: (B, 1 + num_patches, C); token 0 is the CLS/readout token.
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)
        return self.project(features)
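
# Example (illustrative): with in_features=768 and a 16 x 16 patch grid, an input
# of shape (B, 257, 768) (CLS token + 256 patches) projects to (B, 256, 768).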
class RayDiffuserDPT(nn.Module):
    """Ray diffusion model with a DPT-style decoder head.

    A DiT backbone denoises per-pixel ray (and optionally depth) representations
    conditioned on image features; a DPT fusion pyramid then upsamples the
    transformer activations back to full resolution.
    """

def __init__(
self,
model_type="dit",
depth=8,
width=16,
hidden_size=1152,
P=1,
max_num_images=1,
noise_scheduler=None,
freeze_encoder=True,
feature_extractor="dino",
append_ndc=True,
use_unconditional=False,
diffuse_depths=False,
depth_resolution=1,
encoder_features=False,
use_homogeneous=False,
freeze_transformer=False,
cond_depth_mask=False,
):
super().__init__()
if noise_scheduler is None:
self.noise_scheduler = NoiseScheduler()
else:
self.noise_scheduler = noise_scheduler
        self.diffuse_depths = diffuse_depths
        self.depth_resolution = depth_resolution
        self.use_homogeneous = use_homogeneous

        # Base ray coordinates: 3 channels (4 with homogeneous coordinates),
        # plus one coordinate block per depth_resolution^2 sub-sample, plus an
        # optional depth channel.
        self.ray_dim = 3
        if self.use_homogeneous:
            self.ray_dim += 1
        self.ray_dim += self.ray_dim * self.depth_resolution**2
        if self.diffuse_depths:
            self.ray_dim += 1
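        # E.g. with the defaults (depth_resolution=1, no homogeneous coordinates,
        # no diffused depths): ray_dim = 3 + 3 * 1**2 = 6, matching the
        # (B, N, 6, H, W) rays documented in forward().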
self.append_ndc = append_ndc
self.width = width
self.max_num_images = max_num_images
self.model_type = model_type
self.use_unconditional = use_unconditional
self.cond_depth_mask = cond_depth_mask
self.encoder_features = encoder_features
if feature_extractor == "dino":
self.feature_extractor = SpatialDino(
freeze_weights=freeze_encoder,
num_patches_x=width,
num_patches_y=width,
activation_hooks=self.encoder_features,
)
self.feature_dim = self.feature_extractor.feature_dim
elif feature_extractor == "vae":
self.feature_extractor = PretrainedVAE(
freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
)
self.feature_dim = self.feature_extractor.feature_dim
        else:
            raise ValueError(f"Unknown feature extractor {feature_extractor}")
if self.use_unconditional:
self.register_parameter(
"null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1))
)
        # DiT input: image features and ray features, channel-concatenated,
        # plus 2 NDC channels if appended.
        self.input_dim = self.feature_dim * 2
        if self.append_ndc:
            self.input_dim += 2
        if model_type == "dit":
            self.ray_predictor = DiT(
                in_channels=self.input_dim,
                out_channels=self.ray_dim,
                width=width,
                depth=depth,
                hidden_size=hidden_size,
                max_num_images=max_num_images,
                P=P,
            )
        else:
            raise ValueError(f"Unknown model type {model_type}")
if freeze_transformer:
for param in self.ray_predictor.parameters():
param.requires_grad = False
        # DPT fusion head: channel width shared by all fusion blocks.
        self.f = 256
if self.encoder_features:
feature_lens = [
self.feature_extractor.feature_dim,
self.feature_extractor.feature_dim,
self.ray_predictor.hidden_size,
self.ray_predictor.hidden_size,
]
else:
feature_lens = [self.ray_predictor.hidden_size] * 4
self.scratch = _make_scratch(feature_lens, 256, groups=1, expand=False)
self.scratch.refinenet1 = _make_fusion_block(
self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=128
)
self.scratch.refinenet2 = _make_fusion_block(
self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=64
)
self.scratch.refinenet3 = _make_fusion_block(
self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=32
)
self.scratch.refinenet4 = _make_fusion_block(
self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=16
)
        # Patchify the (optionally mask-augmented) noisy ray map with a 16 x 16
        # strided conv so it matches the feature extractor's token grid.
        self.scratch.input_conv = nn.Conv2d(
            self.ray_dim + int(self.cond_depth_mask),
            self.feature_dim,
            kernel_size=16,
            stride=16,
            padding=0,
        )
        # Decode fused features to per-pixel noise over the ray channels.
        self.scratch.output_conv = nn.Sequential(
nn.Conv2d(self.f, self.f // 2, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(self.f // 2, 32, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(32, self.ray_dim, kernel_size=1, stride=1, padding=0),
nn.Identity(),
)
        if self.encoder_features:
            # Readout projections for the two DINO skip connections (DPT-style).
            self.project_opers = nn.ModuleList([
                ProjectReadout(in_features=self.feature_extractor.feature_dim),
                ProjectReadout(in_features=self.feature_extractor.feature_dim),
            ])
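
    # Example (illustrative; the argument values here are assumptions, and
    # constructing the model may download pretrained DINO weights):
    #   model = RayDiffuserDPT(width=16, depth=8, feature_extractor="dino")
    #   # forward() then expects (B, N, 3, H, W) images, (B, N, 6, H, W) rays,
    #   # and (B, N, 2, 16, 16) NDC coordinates at the token grid.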
    def forward_noise(self, x, t, epsilon=None, zero_out_mask=None):
"""
Applies forward diffusion (adds noise) to the input.
If a mask is provided, the noise is only applied to the masked inputs.
"""
t = t.reshape(-1, 1, 1, 1, 1)
if epsilon is None:
epsilon = torch.randn_like(x)
else:
epsilon = epsilon.reshape(x.shape)
alpha_bar = self.noise_scheduler.alphas_cumprod[t]
x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon
if zero_out_mask is not None and self.cond_depth_mask:
x_noise = zero_out_mask * x_noise
return x_noise, epsilon
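
    # Example (illustrative): noising clean rays at a single timestep.
    #   x_noise, eps = model.forward_noise(rays, torch.tensor([10]))
    # Since alphas_cumprod decays from ~1 toward 0, small t leaves x_noise
    # close to the clean input while large t approaches pure noise.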
def forward(
self,
features=None,
images=None,
rays=None,
rays_noisy=None,
t=None,
ndc_coordinates=None,
unconditional_mask=None,
encoder_patches=16,
depth_mask=None,
multiview_unconditional=False,
indices=None,
):
"""
Args:
images: (B, N, 3, H, W).
t: (B,).
rays: (B, N, 6, H, W).
rays_noisy: (B, N, 6, H, W).
ndc_coordinates: (B, N, 2, H, W).
unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples
and 0 else.
"""
        if features is None:
            # The VAE expects 256x256 images while DINO expects 224x224 images.
            # Both feature extractors support autoresize=True, but ideally this
            # would be False, with resizing handled in the dataloader.
            features = self.feature_extractor(images, autoresize=True)
B = features.shape[0]
if unconditional_mask is not None and self.use_unconditional:
null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1)
unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1)
features = (
features * (1 - unconditional_mask) + null_token * unconditional_mask
)
        # Broadcast scalar timesteps to a tensor on the right device.
        if isinstance(t, (int, np.int64)):
            t = torch.ones(1, dtype=int).to(features.device) * t
        else:
            t = t.reshape(B)
        if rays_noisy is None:
            zero_out_mask = depth_mask.unsqueeze(2) if self.cond_depth_mask else None
            rays_noisy, epsilon = self.forward_noise(rays, t, zero_out_mask=zero_out_mask)
        else:
            epsilon = None
# DOWNSAMPLE RAYS
B, N, C, H, W = rays_noisy.shape
if self.cond_depth_mask:
if depth_mask is None:
depth_mask = torch.ones_like(rays_noisy[:, :, 0])
ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2)
else:
ray_repr = rays_noisy
        # Patchify: (B * N, ray_dim (+ mask), H, W) -> (B * N, feature_dim, H/16, W/16).
        ray_repr = self.scratch.input_conv(ray_repr.reshape(B * N, -1, H, W))
_, CP, HP, WP = ray_repr.shape
ray_repr = ray_repr.reshape(B, N, CP, HP, WP)
scene_features = torch.cat([features, ray_repr], dim=2)
if self.append_ndc:
scene_features = torch.cat([scene_features, ndc_coordinates], dim=2)
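        # Per-token channel layout: feature_dim image features + feature_dim ray
        # features (+ 2 NDC channels). E.g. 768-dim features (an assumption; the
        # width depends on the backbone) give 768 + 768 + 2 = 1538 channels.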
# DIT FORWARD PASS
activations = self.ray_predictor(
scene_features,
t,
return_dpt_activations=True,
multiview_unconditional=multiview_unconditional,
)
# PROJECT ENCODER ACTIVATIONS & RESHAPE
        if self.encoder_features:
            # Replace the two shallowest DPT levels with projected DINO activations.
            for i in range(2):
                name = f"encoder{i+1}"
if indices is not None:
act = self.feature_extractor.activations[name][indices]
else:
act = self.feature_extractor.activations[name]
act = self.project_opers[i](act).permute(0, 2, 1)
act = act.reshape(
(
B * N,
self.feature_extractor.feature_dim,
encoder_patches,
encoder_patches,
)
)
activations[i] = act
        # UPSAMPLE ACTIVATIONS
        # Level i is upsampled by 2**(3 - i), i.e. factors 8, 4, 2, 1 from the
        # shallowest to the deepest level, forming the DPT feature pyramid.
        for i, act in enumerate(activations):
            k = 3 - i
            activations[i] = nearest_neighbor_upsample(act, 2**k)
# FUSION BLOCKS
layer_1_rn = self.scratch.layer1_rn(activations[0])
layer_2_rn = self.scratch.layer2_rn(activations[1])
layer_3_rn = self.scratch.layer3_rn(activations[2])
layer_4_rn = self.scratch.layer4_rn(activations[3])
        # RESHAPE TIMESTEPS
        # Broadcast to one timestep per image so fusion blocks see shape (B * N,).
        if t.shape[0] == B:
            t = t.unsqueeze(-1).repeat((1, N)).reshape(B * N)
        elif t.shape[0] == 1 and B > 1:
            t = t.repeat((B * N))
        else:
            raise ValueError(f"Unexpected timestep shape {t.shape} for batch size {B}")
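        # Coarse-to-fine fusion: refinenet4 starts from the deepest (16 x 16)
        # level, and each stage doubles the resolution while merging the matching
        # skip activation, reaching 128 x 128 before output_conv.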
path_4 = self.scratch.refinenet4(layer_4_rn, t=t)
path_3 = self.scratch.refinenet3(path_4, activation=layer_3_rn, t=t)
path_2 = self.scratch.refinenet2(path_3, activation=layer_2_rn, t=t)
path_1 = self.scratch.refinenet1(path_2, activation=layer_1_rn, t=t)
        epsilon_pred = self.scratch.output_conv(path_1)
        # Restore the (B, N, C, H, W) layout of the input rays.
        epsilon_pred = epsilon_pred.reshape((B, N, C, H, W))
        return epsilon_pred, epsilon
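

# Example (illustrative sketch, not from the repo; shapes follow the forward()
# docstring, and the timestep range of 100 is an assumption about the scheduler):
#   model = RayDiffuserDPT(max_num_images=4)
#   images = torch.randn(2, 4, 3, 224, 224)   # (B, N, 3, H, W)
#   rays = torch.randn(2, 4, 6, 256, 256)     # (B, N, 6, H, W)
#   ndc = torch.randn(2, 4, 2, 16, 16)        # NDC coordinates at the token grid
#   t = torch.randint(0, 100, (2,))
#   eps_pred, eps = model(images=images, rays=rays, t=t, ndc_coordinates=ndc)
#   loss = torch.nn.functional.mse_loss(eps_pred, eps)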