Create nerf_renderer.py

tsr/models/nerf_renderer.py (new file, +180 −0)
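The new module implements the triplane NeRF rendering path: camera rays are clipped against the scene bounding box (`rays_intersect_bbox`), sampled at uniform depths, projected onto three axis-aligned feature planes, decoded to density and color, and alpha-composited onto a white background.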
```python
from dataclasses import dataclass
from typing import Dict

import torch
import torch.nn.functional as F
from einops import rearrange, reduce

from ..utils import (
    BaseModule,
    chunk_batch,
    get_activation,
    rays_intersect_bbox,
    scale_tensor,
)


class TriplaneNeRFRenderer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        radius: float

        feature_reduction: str = "concat"
        density_activation: str = "trunc_exp"
        density_bias: float = -1.0
        color_activation: str = "sigmoid"
        num_samples_per_ray: int = 128
        randomized: bool = False

    cfg: Config

    def configure(self) -> None:
        assert self.cfg.feature_reduction in ["concat", "mean"]
        self.chunk_size = 0

    def set_chunk_size(self, chunk_size: int):
        assert (
            chunk_size >= 0
        ), "chunk_size must be a non-negative integer (0 for no chunking)."
        self.chunk_size = chunk_size
```
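`query_triplane` below is the feature-lookup core: positions in (−radius, radius) are rescaled to the (−1, 1) range expected by `F.grid_sample`, each point is projected onto the (x, y), (x, z), and (y, z) planes, and the three bilinear lookups are fused by concatenation or averaging before being passed to the decoder. Setting a positive chunk size via `set_chunk_size` bounds how many positions are decoded per call, trading throughput for peak memory.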
```python
    def query_triplane(
        self,
        decoder: torch.nn.Module,
        positions: torch.Tensor,
        triplane: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        input_shape = positions.shape[:-1]
        positions = positions.view(-1, 3)

        # positions in (-radius, radius), normalized to (-1, 1) for grid_sample
        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        def _query_chunk(x):
            # Project each 3D point onto the (x, y), (x, z), and (y, z) planes.
            indices2D: torch.Tensor = torch.stack(
                (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
                dim=-3,
            )
            # Bilinearly sample per-plane features at the projected coordinates.
            out: torch.Tensor = F.grid_sample(
                rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
                rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
                align_corners=False,
                mode="bilinear",
            )
            # Fuse the three per-plane features by concatenation or averaging.
            if self.cfg.feature_reduction == "concat":
                out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
            elif self.cfg.feature_reduction == "mean":
                out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
            else:
                raise NotImplementedError

            net_out: Dict[str, torch.Tensor] = decoder(out)
            return net_out

        if self.chunk_size > 0:
            net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
        else:
            net_out = _query_chunk(positions)

        net_out["density_act"] = get_activation(self.cfg.density_activation)(
            net_out["density"] + self.cfg.density_bias
        )
        net_out["color"] = get_activation(self.cfg.color_activation)(
            net_out["features"]
        )

        net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}

        return net_out
```
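`_forward` then volume-renders the decoded samples. With per-sample density $\sigma_i$ (after the `trunc_exp` activation and bias), color $c_i$, step size $\delta_i$, and the small constant $\epsilon$ guarding the cumulative product, the compositing below computes

$$\alpha_i = 1 - \exp(-\sigma_i\,\delta_i), \qquad w_i = \alpha_i \prod_{j<i}\left(1 - \alpha_j + \epsilon\right), \qquad C = \sum_i w_i\, c_i + \Bigl(1 - \sum_i w_i\Bigr),$$

where the trailing term composites residual transmittance onto a white background. Note that $\delta_i$ here is the spacing of the normalized `t_vals`, not the metric spacing of `z_vals`; the commented-out line keeps the metric alternative.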
```python
    def _forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        **kwargs,
    ):
        rays_shape = rays_o.shape[:-1]
        rays_o = rays_o.view(-1, 3)
        rays_d = rays_d.view(-1, 3)
        n_rays = rays_o.shape[0]

        # Keep only the rays that actually hit the bounding box.
        t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
        t_near, t_far = t_near[rays_valid], t_far[rays_valid]

        # Uniform samples in normalized ray coordinates, mapped to [t_near, t_far].
        t_vals = torch.linspace(
            0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
        )
        t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
        z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None]  # (N_rays, N_samples)

        xyz = (
            rays_o[rays_valid][:, None, :]
            + z_vals[..., None] * rays_d[rays_valid][..., None, :]
        )  # (N_rays, N_samples, 3)

        mlp_out = self.query_triplane(
            decoder=decoder,
            positions=xyz,
            triplane=triplane,
        )

        eps = 1e-10
        # deltas = z_vals[:, 1:] - z_vals[:, :-1]  # (N_rays, N_samples)
        deltas = t_vals[1:] - t_vals[:-1]  # (N_samples,), broadcast over rays
        alpha = 1 - torch.exp(
            -deltas * mlp_out["density_act"][..., 0]
        )  # (N_rays, N_samples)
        # Transmittance: how much of each sample survives occlusion by earlier samples.
        accum_prod = torch.cat(
            [
                torch.ones_like(alpha[:, :1]),
                torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
            ],
            dim=-1,
        )
        weights = alpha * accum_prod  # (N_rays, N_samples)
        comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2)  # (N_rays, 3)
        opacity_ = weights.sum(dim=-1)  # (N_rays,)

        # Scatter results for valid rays back into the full ray set.
        comp_rgb = torch.zeros(
            n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
        )
        opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
        comp_rgb[rays_valid] = comp_rgb_
        opacity[rays_valid] = opacity_

        # Composite residual transmittance onto a white background.
        comp_rgb += 1 - opacity[..., None]
        comp_rgb = comp_rgb.view(*rays_shape, 3)

        return comp_rgb
```
```python
    def forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
    ) -> torch.Tensor:
        if triplane.ndim == 4:
            # Single scene: triplane is (3, Cp, Hp, Wp).
            comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
        else:
            # Batched scenes: render each (triplane, rays) pair independently.
            comp_rgb = torch.stack(
                [
                    self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
                    for i in range(triplane.shape[0])
                ],
                dim=0,
            )

        return comp_rgb

    def train(self, mode=True):
        # NOTE: randomized sampling is configured here but not used by _forward above.
        self.randomized = mode and self.cfg.randomized
        return super().train(mode=mode)

    def eval(self):
        self.randomized = False
        return super().eval()
```
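For reference, a minimal usage sketch. It assumes `BaseModule` subclasses in this repo can be constructed from a plain config dict; `DummyDecoder` and all shapes below are invented for illustration. The only hard requirement `query_triplane` imposes is that the decoder returns `"density"` and `"features"` tensors.

```python
# Minimal sketch, assuming TriplaneNeRFRenderer(cfg_dict) construction as used
# elsewhere in this repo; DummyDecoder and all shapes below are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyDecoder(nn.Module):
    """Hypothetical decoder: triplane features -> {"density", "features"}."""

    def __init__(self, in_dim: int, hidden: int = 64):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, hidden), nn.SiLU())
        self.density_head = nn.Linear(hidden, 1)  # raw density (pre-activation)
        self.feature_head = nn.Linear(hidden, 3)  # raw color features (pre-sigmoid)

    def forward(self, x: torch.Tensor) -> dict:
        h = self.backbone(x)
        return {"density": self.density_head(h), "features": self.feature_head(h)}


renderer = TriplaneNeRFRenderer({"radius": 1.0, "num_samples_per_ray": 64})
renderer.set_chunk_size(8192)  # decode at most 8192 points per call

triplane = torch.randn(3, 40, 64, 64)  # (planes, Cp, Hp, Wp)
decoder = DummyDecoder(in_dim=3 * 40)  # "concat" reduction fuses 3 * Cp channels

rays_o = torch.zeros(32, 32, 3)
rays_o[..., 2] = 2.0  # camera 2 units along +z, outside the unit-radius box
rays_d = F.normalize(-rays_o, dim=-1)  # all rays look at the origin

rgb = renderer(decoder, triplane, rays_o, rays_d)  # (32, 32, 3), white background
```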