# Spaces: Running on Zero
from typing import Callable

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
from numpy import pi
# TODO: rethink encoding mode
def encoding_mode(
    encoding_mode: str, d_min: float, d_max: float, inv_z: bool, EPS: float
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Build a function that assembles the normalized ``(xy, depth)`` tensor.

    The returned callable takes ``(xy, z, distance)``, rescales one of the two
    depth-like inputs to the range [-1, 1] and concatenates it to ``xy`` along
    the last dimension.

    Args:
        encoding_mode: ``"distance"`` selects the ``distance`` input as the
            depth channel; ``"z"`` and any other value select ``z``.
        d_min: Depth value at the near end of the normalization range.
        d_max: Depth value at the far end of the normalization range.
        inv_z: If true, normalize linearly in inverse depth (``d_min`` maps to
            +1 and ``d_max`` to -1); otherwise normalize linearly in depth
            (``d_min`` maps to -1 and ``d_max`` to +1).
        EPS: Lower clamp applied to the depth before inverting it, to avoid
            division by zero. Only used when ``inv_z`` is true.

    Returns:
        A callable ``(xy, z, distance) -> torch.Tensor``.
    """

    def _normalize(depth: torch.Tensor) -> torch.Tensor:
        # Shared by both modes: map depth to [0, 1], then stretch to [-1, 1].
        if inv_z:
            depth = (1 / depth.clamp_min(EPS) - 1 / d_max) / (1 / d_min - 1 / d_max)
        else:
            depth = (depth - d_min) / (d_max - d_min)
        return 2 * depth - 1

    def _z(xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor) -> torch.Tensor:
        # ``distance`` is accepted only to keep a uniform call signature.
        return torch.cat((xy, _normalize(z)), dim=-1)

    def _distance(
        xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor
    ) -> torch.Tensor:
        # ``z`` is accepted only to keep a uniform call signature.
        return torch.cat((xy, _normalize(distance)), dim=-1)

    # Unknown modes deliberately fall back to the "z" behaviour.
    if encoding_mode == "distance":
        return _distance
    return _z
class PositionalEncoding(torch.nn.Module):
    """
    Implement NeRF's positional encoding

    Every input coordinate is expanded into ``sin``/``cos`` pairs at
    ``num_freqs`` frequencies ``freq_factor * 2**k`` (k = 0 .. num_freqs - 1).
    With ``include_input`` the raw input is prepended to the encoding, so the
    output width ``d_out`` is ``2 * num_freqs * d_in (+ d_in)``.
    """

    def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True):
        """
        :param num_freqs number of frequency bands
        :param d_in number of input coordinates per sample
        :param freq_factor base frequency; band k uses freq_factor * 2**k
        :param include_input if True, prepend the raw input to the encoding
        """
        super().__init__()
        self.num_freqs = num_freqs
        self.d_in = d_in
        self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
        self.d_out = self.num_freqs * 2 * d_in
        self.include_input = include_input
        if include_input:
            self.d_out += d_in
        # f1 f1 f2 f2 ... to multiply x by
        self.register_buffer(
            "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1)
        )
        # 0 pi/2 0 pi/2 ... so that
        # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...)
        _phases = torch.zeros(2 * self.num_freqs)
        _phases[1::2] = np.pi * 0.5
        self.register_buffer("_phases", _phases.view(1, -1, 1))

    def forward(self, x):
        """
        Apply positional encoding (new implementation)
        :param x (batch, self.d_in)
        :return (batch, self.d_out)
        """
        with profiler.record_function("positional_enc"):
            # (batch, 2 * num_freqs, d_in): one copy of x per sin/cos slot
            embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1)
            # sin(phase + x * freq) yields sin for even slots, cos for odd ones
            embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs))
            embed = embed.view(x.shape[0], -1)
            if self.include_input:
                embed = torch.cat((x, embed), dim=-1)
            return embed

    @classmethod
    def from_conf(cls, conf, d_in=3):
        """
        Alternate constructor reading the options from a config mapping
        :param conf object with a dict-like ``get(key, default)``; reads
        "num_freqs", "freq_factor" and "include_input"
        :param d_in number of input coordinates per sample
        :return a new instance of cls
        """
        # PyHocon construction
        return cls(
            conf.get("num_freqs", 6),
            d_in,
            conf.get("freq_factor", np.pi),
            conf.get("include_input", True),
        )
def token_decoding(filter: nn.Module, pos_offset: float = 0.0):
    """Build a decoder that turns per-point tokens into one density per point.

    Args:
        filter: Callable (typically an ``nn.Module``) invoked as
            ``filter(positions, weights)`` and expected to return one density
            per token, shape ``(n_pts, n_tokens)``.
        pos_offset: Currently unused; kept so existing callers keep working.

    Returns:
        A callable ``_decode(xyz, tokens) -> torch.Tensor``.
    """

    def _decode(xyz: torch.Tensor, tokens: torch.Tensor):
        """Decode tokens into density for given points

        Args:
            xyz (torch.Tensor): points in xyz n_pts, 3
            tokens (torch.Tensor): tokens n_pts, n_tokens, d_in + 2
                (channel 0: scale, channel 1: position offset, rest: weights)

        Returns:
            torch.Tensor: summed density per point, n_pts
        """
        with profiler.record_function("positional_enc"):
            # The z coordinate is the third component of an (n_pts, 3) tensor.
            z = xyz[..., 2]  # n_pts
            scale = tokens[..., 0]  # n_pts, n_tokens
            token_pos_offset = tokens[..., 1]  # n_pts, n_tokens
            weights = tokens[..., 2:]  # n_pts, n_tokens, d_in

            # Broadcast z against the per-token offsets instead of repeating it.
            # ((z - t_o) / s) * 2.0 - 1.0:  z = t_o => -1.0,  z = t_o + s => 1.0
            positions = (
                2.0 * (z.unsqueeze(-1) - token_pos_offset) / scale - 1.0
            )  # n_pts, n_tokens

            individual_densities = filter(positions, weights)  # n_pts, n_tokens
            densities = individual_densities.sum(-1)  # n_pts
        return densities

    return _decode
class FourierFilter(nn.Module): | |
# TODO: add filter functions | |
def __init__( | |
self, | |
num_freqs=6, | |
d_in=3, | |
freq_factor=np.pi, | |
include_input=True, | |
filter_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, | |
): | |
super().__init__() | |
self.num_freqs = num_freqs | |
self.d_in = d_in | |
self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs) | |
self.register_buffer( | |
"_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1) | |
) | |
# 0 pi/2 0 pi/2 ... so that | |
# (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...) | |
_phases = torch.zeros(2 * self.num_freqs) | |
_phases[1::2] = np.pi * 0.5 | |
self.register_buffer("_phases", _phases.view(1, -1, 1)) | |
self.filter_fn = filter_fn | |
def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: | |
"""Predict density for given normalized points using Fourier features | |
Args: | |
positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens) | |
weights (torch.Tensor): weights for each point (n_pts, n_tokens, num_freqs * 2) | |
Returns: | |
torch.Tensor: aggregated density for each point (n_pts) | |
""" | |
with profiler.record_function("positional_enc"): | |
positions = positions.unsqueeze(1).repeat( | |
1, self.num_freqs * 2, 1 | |
) # n_pts, num_freqs * 2, n_tokens | |
densities = weights.permute(0, 2, 1) * torch.sin( | |
torch.addcmul(self._phases, positions, self._freqs) | |
) # n_pts, num_freqs * 2, n_tokens | |
if self.filter_fn is not None: | |
densities = self.filter_fn(densities, positions) | |
return densities.sum(-2) # n_pts, n_tokens | |
def from_conf(cls, conf, d_in=3): | |
# PyHocon construction | |
return cls( | |
conf.get("num_freqs", 6), | |
d_in, | |
conf.get("freq_factor", np.pi), | |
) | |
class LogisticFilter(nn.Module):
    """Density filter built from a product of two opposing sigmoids."""

    def __init__(self, slope: float) -> None:
        """
        Args:
            slope: factor applied to the positions before the sigmoids
        """
        super().__init__()
        self.slope = slope

    def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Predict the density as sum of weighted logistic functions

        Args:
            positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens)
            weights (torch.Tensor): weights for each point (n_pts, n_tokens, d_in)

        Returns:
            torch.Tensor: density for each point (n_pts, n_tokens)
        """
        with profiler.record_function("positional_enc"):
            # Drops a trailing singleton weight channel; a wider last
            # dimension is left untouched by squeeze(-1).
            weights = weights.squeeze(-1)  # n_pts, n_tokens
            sigmoid_pos = self.slope * positions + 1.0
            # sigmoid(s) * sigmoid(-s) peaks at s = 0 and decays on both sides
            return (
                weights * torch.sigmoid(sigmoid_pos) * torch.sigmoid(-sigmoid_pos)
            )  # n_pts, n_tokens

    @classmethod
    def from_conf(cls, conf):
        """Alternate constructor reading "slope" (default 10.0) from ``conf``.

        Args:
            conf: object with a dict-like ``get(key, default)``

        Returns:
            A new instance of ``cls``.
        """
        # PyHocon construction
        return cls(conf.get("slope", 10.0))