mac9087 commited on
Commit
95289e5
·
verified ·
1 Parent(s): 06fbbd8

Create nerf_renderer.py

Browse files
Files changed (1) hide show
  1. tsr/models/nerf_renderer.py +180 -0
tsr/models/nerf_renderer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, reduce
7
+
8
+ from ..utils import (
9
+ BaseModule,
10
+ chunk_batch,
11
+ get_activation,
12
+ rays_intersect_bbox,
13
+ scale_tensor,
14
+ )
15
+
16
+
17
+ class TriplaneNeRFRenderer(BaseModule):
18
+ @dataclass
19
+ class Config(BaseModule.Config):
20
+ radius: float
21
+
22
+ feature_reduction: str = "concat"
23
+ density_activation: str = "trunc_exp"
24
+ density_bias: float = -1.0
25
+ color_activation: str = "sigmoid"
26
+ num_samples_per_ray: int = 128
27
+ randomized: bool = False
28
+
29
+ cfg: Config
30
+
31
+ def configure(self) -> None:
32
+ assert self.cfg.feature_reduction in ["concat", "mean"]
33
+ self.chunk_size = 0
34
+
35
+ def set_chunk_size(self, chunk_size: int):
36
+ assert (
37
+ chunk_size >= 0
38
+ ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
+ self.chunk_size = chunk_size
40
+
41
+ def query_triplane(
42
+ self,
43
+ decoder: torch.nn.Module,
44
+ positions: torch.Tensor,
45
+ triplane: torch.Tensor,
46
+ ) -> Dict[str, torch.Tensor]:
47
+ input_shape = positions.shape[:-1]
48
+ positions = positions.view(-1, 3)
49
+
50
+ # positions in (-radius, radius)
51
+ # normalized to (-1, 1) for grid sample
52
+ positions = scale_tensor(
53
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
54
+ )
55
+
56
+ def _query_chunk(x):
57
+ indices2D: torch.Tensor = torch.stack(
58
+ (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
59
+ dim=-3,
60
+ )
61
+ out: torch.Tensor = F.grid_sample(
62
+ rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
63
+ rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
64
+ align_corners=False,
65
+ mode="bilinear",
66
+ )
67
+ if self.cfg.feature_reduction == "concat":
68
+ out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
69
+ elif self.cfg.feature_reduction == "mean":
70
+ out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ net_out: Dict[str, torch.Tensor] = decoder(out)
75
+ return net_out
76
+
77
+ if self.chunk_size > 0:
78
+ net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
79
+ else:
80
+ net_out = _query_chunk(positions)
81
+
82
+ net_out["density_act"] = get_activation(self.cfg.density_activation)(
83
+ net_out["density"] + self.cfg.density_bias
84
+ )
85
+ net_out["color"] = get_activation(self.cfg.color_activation)(
86
+ net_out["features"]
87
+ )
88
+
89
+ net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
90
+
91
+ return net_out
92
+
93
+ def _forward(
94
+ self,
95
+ decoder: torch.nn.Module,
96
+ triplane: torch.Tensor,
97
+ rays_o: torch.Tensor,
98
+ rays_d: torch.Tensor,
99
+ **kwargs,
100
+ ):
101
+ rays_shape = rays_o.shape[:-1]
102
+ rays_o = rays_o.view(-1, 3)
103
+ rays_d = rays_d.view(-1, 3)
104
+ n_rays = rays_o.shape[0]
105
+
106
+ t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
107
+ t_near, t_far = t_near[rays_valid], t_far[rays_valid]
108
+
109
+ t_vals = torch.linspace(
110
+ 0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
111
+ )
112
+ t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
113
+ z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
114
+
115
+ xyz = (
116
+ rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
117
+ ) # (N_rays, N_sample, 3)
118
+
119
+ mlp_out = self.query_triplane(
120
+ decoder=decoder,
121
+ positions=xyz,
122
+ triplane=triplane,
123
+ )
124
+
125
+ eps = 1e-10
126
+ # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
127
+ deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples)
128
+ alpha = 1 - torch.exp(
129
+ -deltas * mlp_out["density_act"][..., 0]
130
+ ) # (N_rays, N_samples)
131
+ accum_prod = torch.cat(
132
+ [
133
+ torch.ones_like(alpha[:, :1]),
134
+ torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
135
+ ],
136
+ dim=-1,
137
+ )
138
+ weights = alpha * accum_prod # (N_rays, N_samples)
139
+ comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
140
+ opacity_ = weights.sum(dim=-1) # (N_rays)
141
+
142
+ comp_rgb = torch.zeros(
143
+ n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
144
+ )
145
+ opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
146
+ comp_rgb[rays_valid] = comp_rgb_
147
+ opacity[rays_valid] = opacity_
148
+
149
+ comp_rgb += 1 - opacity[..., None]
150
+ comp_rgb = comp_rgb.view(*rays_shape, 3)
151
+
152
+ return comp_rgb
153
+
154
+ def forward(
155
+ self,
156
+ decoder: torch.nn.Module,
157
+ triplane: torch.Tensor,
158
+ rays_o: torch.Tensor,
159
+ rays_d: torch.Tensor,
160
+ ) -> Dict[str, torch.Tensor]:
161
+ if triplane.ndim == 4:
162
+ comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
163
+ else:
164
+ comp_rgb = torch.stack(
165
+ [
166
+ self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
167
+ for i in range(triplane.shape[0])
168
+ ],
169
+ dim=0,
170
+ )
171
+
172
+ return comp_rgb
173
+
174
+ def train(self, mode=True):
175
+ self.randomized = mode and self.cfg.randomized
176
+ return super().train(mode=mode)
177
+
178
+ def eval(self):
179
+ self.randomized = False
180
+ return super().eval()