# NOTE: this file was extracted from a web file viewer; the page header read
# "Spaces: Running on Zero". Listed file size: 7,427 bytes; identifier shown
# next to it: 9e15541. The viewer's line-number gutter (1-208) was removed.
from typing import Callable
from numpy import pi
import torch
import torch.nn as nn
import numpy as np
import torch.autograd.profiler as profiler
# TODO: rethink encoding mode
def encoding_mode(
encoding_mode: str, d_min: float, d_max: float, inv_z: bool, EPS: float
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
def _z(xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor) -> torch.Tensor:
if inv_z:
z = (1 / z.clamp_min(EPS) - 1 / d_max) / (1 / d_min - 1 / d_max)
else:
z = (z - d_min) / (d_max - d_min)
z = 2 * z - 1
return torch.cat(
(xy, z), dim=-1
) ## concatenates the normalized x, y, and z coordinates
def _distance(xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor):
if inv_z:
distance = (1 / distance.clamp_min(EPS) - 1 / d_max) / (
1 / d_min - 1 / d_max
)
else:
distance = (distance - d_min) / (d_max - d_min)
distance = 2 * distance - 1
return torch.cat(
(xy, distance), dim=-1
) ## Apply the positional encoder to the concatenated xy and depth/distance coordinates (it enables the model to capture more complex spatial dependencies without a significant increase in model complexity or training data)
match encoding_mode:
case "z":
return _z
case "distance":
return _distance
case _:
return _z
class PositionalEncoding(torch.nn.Module):
    """
    Implement NeRF's positional encoding.

    Every input coordinate is mapped to ``sin(f * x)`` and ``cos(f * x)`` for
    ``num_freqs`` frequencies ``f = freq_factor * 2 ** k``; the raw input is
    optionally prepended. ``d_out`` holds the resulting feature size for an
    input of size ``d_in``.
    """

    def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True):
        """
        :param num_freqs number of frequency octaves
        :param d_in size of the last dimension of the input
        :param freq_factor base frequency; octave k uses freq_factor * 2 ** k
        :param include_input if True, the raw input is prepended to the output
        """
        super().__init__()
        self.num_freqs = num_freqs
        self.d_in = d_in
        self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
        self.d_out = self.num_freqs * 2 * d_in
        self.include_input = include_input
        if include_input:
            self.d_out += d_in
        # f1 f1 f2 f2 ... to multiply x by
        self.register_buffer(
            "_freqs",
            torch.repeat_interleave(self.freqs, 2).view(1, 2 * self.num_freqs, 1),
        )
        # 0 pi/2 0 pi/2 ... so that
        # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...)
        _phases = torch.zeros(2 * self.num_freqs)
        _phases[1::2] = np.pi * 0.5
        self.register_buffer("_phases", _phases.view(1, 2 * self.num_freqs, 1))

    def forward(self, x):
        """
        Apply positional encoding.

        :param x (..., d_in); the common case is (batch, self.d_in)
        :return (..., self.d_out); for each (frequency, phase) pair in the
            order sin f1, cos f1, sin f2, ... all input coordinates are listed
        """
        with profiler.record_function("positional_enc"):
            # (..., 1, d_in) broadcasts against the (1, 2 * num_freqs, 1)
            # buffers, so no explicit repeat of x is needed.
            embed = torch.sin(
                torch.addcmul(self._phases, x.unsqueeze(-2), self._freqs)
            )
            # An explicit size (instead of -1) keeps this valid for empty
            # batches, where -1 cannot be inferred.
            embed = embed.reshape(*x.shape[:-1], 2 * self.num_freqs * x.shape[-1])
            if self.include_input:
                embed = torch.cat((x, embed), dim=-1)
            return embed

    @classmethod
    def from_conf(cls, conf, d_in=3):
        """Construct from a config object exposing ``get(key, default)``."""
        # PyHocon construction
        return cls(
            conf.get("num_freqs", 6),
            d_in,
            conf.get("freq_factor", np.pi),
            conf.get("include_input", True),
        )
def token_decoding(filter: nn.Module, pos_offset: float = 0.0):
    """Build a decoder that turns per-point tokens into one density per point.

    Args:
        filter: Callable invoked as ``filter(positions, weights)``; its output
            is summed over the last dimension.
        pos_offset: Accepted for interface compatibility; the decoder does not
            currently read it.

    Returns:
        A function ``_decode(xyz, tokens)``.
    """

    def _decode(xyz: torch.Tensor, tokens: torch.Tensor):
        """Decode tokens into density for given points

        Args:
            xyz (torch.Tensor): points in xyz, (n_pts, 3); only the z channel
                (last-axis index 2) is used
            tokens (torch.Tensor): tokens (n_pts, n_tokens, d_in + 2); along
                the last axis: index 0 is the scale, index 1 the position
                offset, and the remainder the weights handed to ``filter``

        Returns:
            torch.Tensor: ``filter`` output summed over its last dimension
        """
        with profiler.record_function("positional_enc"):
            z = xyz[..., 2]  # n_pts

            scale = tokens[..., 0]  # n_pts, n_tokens
            token_pos_offset = tokens[..., 1]  # n_pts, n_tokens
            weights = tokens[..., 2:]  # n_pts, n_tokens, d_in

            # ((z - t_o) / s) * 2.0 - 1.0, so z = t_o => -1.0 and
            # z = t_o + s => 1.0. z is broadcast across the token axis.
            positions = (
                2.0 * (z.unsqueeze(-1) - token_pos_offset) / scale - 1.0
            )  # n_pts, n_tokens
            individual_densities = filter(positions, weights)  # n_pts, n_tokens

            densities = individual_densities.sum(-1)  # n_pts

            return densities

    return _decode
class FourierFilter(nn.Module):
# TODO: add filter functions
def __init__(
self,
num_freqs=6,
d_in=3,
freq_factor=np.pi,
include_input=True,
filter_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
):
super().__init__()
self.num_freqs = num_freqs
self.d_in = d_in
self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
self.register_buffer(
"_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1)
)
# 0 pi/2 0 pi/2 ... so that
# (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...)
_phases = torch.zeros(2 * self.num_freqs)
_phases[1::2] = np.pi * 0.5
self.register_buffer("_phases", _phases.view(1, -1, 1))
self.filter_fn = filter_fn
def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Predict density for given normalized points using Fourier features
Args:
positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens)
weights (torch.Tensor): weights for each point (n_pts, n_tokens, num_freqs * 2)
Returns:
torch.Tensor: aggregated density for each point (n_pts)
"""
with profiler.record_function("positional_enc"):
positions = positions.unsqueeze(1).repeat(
1, self.num_freqs * 2, 1
) # n_pts, num_freqs * 2, n_tokens
densities = weights.permute(0, 2, 1) * torch.sin(
torch.addcmul(self._phases, positions, self._freqs)
) # n_pts, num_freqs * 2, n_tokens
if self.filter_fn is not None:
densities = self.filter_fn(densities, positions)
return densities.sum(-2) # n_pts, n_tokens
@classmethod
def from_conf(cls, conf, d_in=3):
# PyHocon construction
return cls(
conf.get("num_freqs", 6),
d_in,
conf.get("freq_factor", np.pi),
)
class LogisticFilter(nn.Module):
    """Weighted bump built from the product of two opposing sigmoids."""

    def __init__(self, slope: float) -> None:
        super().__init__()
        self.slope = slope

    def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Predict the density as a weighted product of two sigmoids

        With ``u = slope * positions + 1.0`` the result is
        ``weights * sigmoid(u) * sigmoid(-u)``.

        Args:
            positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens)
            weights (torch.Tensor): weights for each point (n_pts, n_tokens, d_in);
                a trailing dimension of size 1 is squeezed away

        Returns:
            torch.Tensor: density for each point (n_pts, n_tokens)
        """
        with profiler.record_function("positional_enc"):
            token_weights = weights.squeeze(-1)  # n_pts, n_tokens
            shifted = self.slope * positions + 1.0
            rising = torch.sigmoid(shifted)
            falling = torch.sigmoid(-shifted)
            return token_weights * rising * falling  # n_pts, n_tokens

    @classmethod
    def from_conf(cls, conf):
        """Construct from a config object exposing ``get(key, default)``."""
        # PyHocon construction
        slope = conf.get("slope", 10.0)
        return cls(slope)
|