# scenedino/common/positional_encoding.py
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.autograd.profiler as profiler


# TODO: rethink encoding mode
def encoding_mode(
    encoding_mode: str, d_min: float, d_max: float, inv_z: bool, EPS: float
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Return a closure that normalizes a depth coordinate to [-1, 1] and
    appends it to the xy coordinates.

    ``inv_z`` switches between linear and inverse-depth normalization;
    ``EPS`` guards the division in the inverse case.
    """
def _z(xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor) -> torch.Tensor:
if inv_z:
z = (1 / z.clamp_min(EPS) - 1 / d_max) / (1 / d_min - 1 / d_max)
else:
z = (z - d_min) / (d_max - d_min)
z = 2 * z - 1
        # Append the normalized z coordinate to the xy coordinates.
        return torch.cat((xy, z), dim=-1)
    def _distance(
        xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor
    ) -> torch.Tensor:
if inv_z:
distance = (1 / distance.clamp_min(EPS) - 1 / d_max) / (
1 / d_min - 1 / d_max
)
else:
distance = (distance - d_min) / (d_max - d_min)
distance = 2 * distance - 1
        # Append the normalized ray distance (Euclidean distance along the ray,
        # used instead of z-depth) to the xy coordinates.
        return torch.cat((xy, distance), dim=-1)
match encoding_mode:
case "z":
return _z
case "distance":
return _distance
        case _:
            # Fall back to z-depth normalization for unknown modes.
            return _z
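

# Illustrative usage sketch (the argument values below are assumptions for the
# example, not taken from the repo config):
#
#     encode = encoding_mode("z", d_min=3.0, d_max=80.0, inv_z=True, EPS=1e-3)
#     pts = encode(xy, z, distance)  # xy: (..., 2), z: (..., 1) -> (..., 3)

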
class PositionalEncoding(torch.nn.Module):
"""
Implement NeRF's positional encoding
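
    Example (a minimal sketch: d_out = 2 * num_freqs * d_in, plus d_in when
    include_input is set):

        >>> enc = PositionalEncoding(num_freqs=2, d_in=3)
        >>> enc(torch.zeros(5, 3)).shape
        torch.Size([5, 15])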
"""
def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True):
super().__init__()
self.num_freqs = num_freqs
self.d_in = d_in
self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
self.d_out = self.num_freqs * 2 * d_in
self.include_input = include_input
if include_input:
self.d_out += d_in
# f1 f1 f2 f2 ... to multiply x by
self.register_buffer(
"_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1)
)
# 0 pi/2 0 pi/2 ... so that
# (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...)
_phases = torch.zeros(2 * self.num_freqs)
_phases[1::2] = np.pi * 0.5
self.register_buffer("_phases", _phases.view(1, -1, 1))
def forward(self, x):
"""
Apply positional encoding (new implementation)
        :param x: (batch, self.d_in)
        :return: (batch, self.d_out)
"""
with profiler.record_function("positional_enc"):
embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1)
embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs))
embed = embed.view(x.shape[0], -1)
if self.include_input:
embed = torch.cat((x, embed), dim=-1)
return embed
@classmethod
def from_conf(cls, conf, d_in=3):
# PyHocon construction
return cls(
conf.get("num_freqs", 6),
d_in,
conf.get("freq_factor", np.pi),
conf.get("include_input", True),
)
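

# from_conf only relies on dict-style .get, so a plain dict can stand in for a
# PyHocon ConfigTree in a quick sketch (values here are illustrative):
#
#     enc = PositionalEncoding.from_conf({"num_freqs": 4, "include_input": False})
#     assert enc.d_out == 4 * 2 * 3  # include_input=False: no raw input appended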


def token_decoding(filter: nn.Module, pos_offset: float = 0.0):
    # NOTE: pos_offset is currently unused inside the closure.
    def _decode(xyz: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        """Decode tokens into density for the given points.

        Args:
            xyz (torch.Tensor): points, (n_pts, 3)
            tokens (torch.Tensor): tokens, (n_pts, n_tokens, d_in + 2)
        """
        n_pts, n_tokens, _ = tokens.shape
        with profiler.record_function("token_decoding"):
            z = xyz[..., 2]  # z component of each point, (n_pts,)
            scale = tokens[..., 0]  # (n_pts, n_tokens)
            token_pos_offset = tokens[..., 1]  # (n_pts, n_tokens)
            weights = tokens[..., 2:]  # (n_pts, n_tokens, d_in)
            # Normalize z per token: positions = 2 * (z - t_o) / s - 1,
            # so z == t_o maps to -1.0 and z == t_o + s maps to 1.0.
            positions = (
                2.0 * (z.unsqueeze(1).repeat(1, n_tokens) - token_pos_offset) / scale
                - 1.0
            )  # (n_pts, n_tokens)
            individual_densities = filter(positions, weights)  # (n_pts, n_tokens)
            densities = individual_densities.sum(-1)  # (n_pts,)
        return densities

    return _decode
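

# Illustrative wiring (a sketch; FourierFilter below provides the expected
# filter signature, and the token layout follows _decode's docstring):
#
#     decode = token_decoding(FourierFilter(num_freqs=6))
#     sigma = decode(xyz, tokens)  # xyz: (n_pts, 3), tokens: (n_pts, n_tokens, 14)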


class FourierFilter(nn.Module):
    # TODO: add filter functions
    def __init__(
        self,
        num_freqs=6,
        d_in=3,
        freq_factor=np.pi,
        include_input=True,  # NOTE: currently unused; kept for interface parity
        filter_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    ):
super().__init__()
self.num_freqs = num_freqs
self.d_in = d_in
self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
self.register_buffer(
"_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1)
)
# 0 pi/2 0 pi/2 ... so that
# (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...)
_phases = torch.zeros(2 * self.num_freqs)
_phases[1::2] = np.pi * 0.5
self.register_buffer("_phases", _phases.view(1, -1, 1))
self.filter_fn = filter_fn
def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Predict density for given normalized points using Fourier features
Args:
positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens)
weights (torch.Tensor): weights for each point (n_pts, n_tokens, num_freqs * 2)
Returns:
torch.Tensor: aggregated density for each point (n_pts)
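
        Example (a minimal sketch; zero inputs just exercise the shapes):

            >>> f = FourierFilter(num_freqs=2)
            >>> f(torch.zeros(4, 3), torch.zeros(4, 3, 4)).shape
            torch.Size([4, 3])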
"""
        with profiler.record_function("fourier_filter"):
positions = positions.unsqueeze(1).repeat(
1, self.num_freqs * 2, 1
) # n_pts, num_freqs * 2, n_tokens
densities = weights.permute(0, 2, 1) * torch.sin(
torch.addcmul(self._phases, positions, self._freqs)
) # n_pts, num_freqs * 2, n_tokens
if self.filter_fn is not None:
densities = self.filter_fn(densities, positions)
return densities.sum(-2) # n_pts, n_tokens
@classmethod
def from_conf(cls, conf, d_in=3):
# PyHocon construction
return cls(
conf.get("num_freqs", 6),
d_in,
conf.get("freq_factor", np.pi),
)


class LogisticFilter(nn.Module):
def __init__(self, slope: float) -> None:
super().__init__()
self.slope = slope
    def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Predict the density as weighted logistic bumps: sigmoid(x) * sigmoid(-x)
        is the derivative of the sigmoid, a bell-shaped bump centered at x == 0.

        Args:
            positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens)
            weights (torch.Tensor): weights for each point, (n_pts, n_tokens, 1)

        Returns:
            torch.Tensor: density for each point, (n_pts, n_tokens)
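
        Example (a minimal sketch; the bump peaks where slope * position + 1 == 0,
        i.e. at position == -1 / slope):

            >>> f = LogisticFilter(slope=10.0)
            >>> f(torch.full((2, 3), -0.1), torch.ones(2, 3, 1)).shape
            torch.Size([2, 3])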
"""
        with profiler.record_function("logistic_filter"):
            weights = weights.squeeze(-1)  # (n_pts, n_tokens)
            # Shifted argument: sigmoid(x) * sigmoid(-x) peaks at x == 0, so the
            # bump is centered where slope * positions + 1 == 0.
            sigmoid_pos = self.slope * positions + 1.0
            return (
                weights * torch.sigmoid(sigmoid_pos) * torch.sigmoid(-sigmoid_pos)
            )  # (n_pts, n_tokens)
@classmethod
def from_conf(cls, conf):
# PyHocon construction
return cls(conf.get("slope", 10.0))
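

if __name__ == "__main__":
    # Minimal smoke test (a sketch checking shapes only; all values are random
    # and illustrative, not taken from the repo config).
    enc = PositionalEncoding(num_freqs=6, d_in=3)
    print(enc(torch.rand(10, 3)).shape)  # torch.Size([10, 39])

    decode = token_decoding(FourierFilter(num_freqs=6))
    xyz = torch.rand(10, 3)
    tokens = torch.rand(10, 4, 2 + 6 * 2)  # scale, offset, and 12 Fourier weights
    print(decode(xyz, tokens).shape)  # torch.Size([10])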