from typing import Optional

from torch import nn
import torch

# import torch_scatter
import torch.autograd.profiler as profiler

from scenedino.common import util


# Resnet Blocks
class ResnetBlockFC(nn.Module):
    """
    Fully connected ResNet block class.
    Taken from the DVR code.

    :param size_in (int): input dimension
    :param size_out (int): output dimension
    :param size_h (int): hidden dimension
    :param beta (float): Softplus beta for the activation; if <= 0, ReLU is used instead
    """
    def __init__(self, size_in, size_out=None, size_h=None, beta=0.0):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in
        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)

        # Init
        nn.init.constant_(self.fc_0.bias, 0.0)
        nn.init.kaiming_normal_(self.fc_0.weight, a=0, mode="fan_in")
        nn.init.constant_(self.fc_1.bias, 0.0)
        nn.init.zeros_(self.fc_1.weight)

        if beta > 0:
            self.activation = nn.Softplus(beta=beta)
        else:
            self.activation = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
            # The shortcut is created with bias=False, so only its weight is initialized
            # (initializing the non-existent bias would raise an error).
            nn.init.kaiming_normal_(self.shortcut.weight, a=0, mode="fan_in")

    def forward(self, x):
        with profiler.record_function("resblock"):
            net = self.fc_0(self.activation(x))
            dx = self.fc_1(self.activation(net))
            if self.shortcut is not None:
                x_s = self.shortcut(x)
            else:
                x_s = x
            return x_s + dx
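

# A minimal usage sketch of ResnetBlockFC (illustrative only; the sizes below are
# arbitrary assumptions, not taken from the surrounding training code):
#
#     block = ResnetBlockFC(size_in=64, size_out=128, beta=0.0)
#     y = block(torch.randn(1024, 64))  # -> (1024, 128): residual path plus linear shortcut
#
# Because fc_1 is zero-initialized, a freshly constructed block initially behaves like
# its shortcut (the identity when size_in == size_out).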


class ResnetFC(nn.Module):
    def __init__(
        self,
        d_in,
        view_number: Optional[int] = None,
        d_out=4,
        n_blocks=5,
        d_latent=0,
        d_hidden=128,
        beta=0.0,
        combine_layer=1000,
        combine_type="average",
        use_spade=False,
    ):
""" | |
:param d_in input size | |
:param d_out output size | |
:param n_blocks number of Resnet blocks | |
:param d_latent latent size, added in each resnet block (0 = disable) | |
:param d_hidden hiddent dimension throughout network | |
:param beta softplus beta, 100 is reasonable; if <=0 uses ReLU activations instead | |
""" | |
        super().__init__()
        if d_in > 0:
            self.lin_in = nn.Linear(d_in, d_hidden)
            nn.init.constant_(self.lin_in.bias, 0.0)
            nn.init.kaiming_normal_(self.lin_in.weight, a=0, mode="fan_in")

        self.lin_out = nn.Linear(d_hidden, d_out)
        nn.init.constant_(self.lin_out.bias, 0.0)
        nn.init.kaiming_normal_(self.lin_out.weight, a=0, mode="fan_in")

        self.n_blocks = n_blocks
        self.d_latent = d_latent
        self.d_in = d_in
        self.view_number = view_number
        self.d_out = d_out
        self.d_hidden = d_hidden
        self.combine_layer = combine_layer
        self.combine_type = combine_type
        self.use_spade = use_spade

        self.blocks = nn.ModuleList(
            [ResnetBlockFC(d_hidden, beta=beta) for _ in range(n_blocks)]
        )

        if d_latent != 0:
            n_lin_z = min(combine_layer, n_blocks)
            self.lin_z = nn.ModuleList(
                [nn.Linear(d_latent, d_hidden) for _ in range(n_lin_z)]
            )
            for i in range(n_lin_z):
                nn.init.constant_(self.lin_z[i].bias, 0.0)
                nn.init.kaiming_normal_(self.lin_z[i].weight, a=0, mode="fan_in")

            if self.use_spade:
                self.scale_z = nn.ModuleList(
                    [nn.Linear(d_latent, d_hidden) for _ in range(n_lin_z)]
                )
                for i in range(n_lin_z):
                    nn.init.constant_(self.scale_z[i].bias, 0.0)
                    nn.init.kaiming_normal_(self.scale_z[i].weight, a=0, mode="fan_in")

        if beta > 0:
            self.activation = nn.Softplus(beta=beta)
        else:
            self.activation = nn.ReLU()

    def forward(
        self,
        sampled_features,
        combine_inner_dims=(1,),
        combine_index=None,
        dim_size=None,
        **kwargs
    ):
""" | |
:param zx (..., d_latent + d_in) | |
:param combine_inner_dims Combining dimensions for use with multiview inputs. | |
Tensor will be reshaped to (-1, combine_inner_dims, ...) and reduced using combine_type | |
on dim 1, at combine_layer | |
""" | |
        with profiler.record_function("resnetfc_infer"):
            if self.view_number is not None:
                zx = sampled_features[..., self.view_number, :]
            else:
                zx = sampled_features
            assert zx.size(-1) == self.d_latent + self.d_in

            if self.d_latent > 0:
                z = zx[..., : self.d_latent]
                x = zx[..., self.d_latent :]
            else:
                x = zx
            if self.d_in > 0:
                x = self.lin_in(x)
            else:
                x = torch.zeros(self.d_hidden, device=zx.device)

            for blkid in range(self.n_blocks):
                if blkid == self.combine_layer:
                    # The following implements camera frustum culling and requires torch_scatter:
                    # if combine_index is not None:
                    #     combine_type = (
                    #         "mean"
                    #         if self.combine_type == "average"
                    #         else self.combine_type
                    #     )
                    #     if dim_size is not None:
                    #         assert isinstance(dim_size, int)
                    #     x = torch_scatter.scatter(
                    #         x,
                    #         combine_index,
                    #         dim=0,
                    #         dim_size=dim_size,
                    #         reduce=combine_type,
                    #     )
                    # else:
                    x = util.combine_interleaved(
                        x, combine_inner_dims, self.combine_type
                    )

                if self.d_latent > 0 and blkid < self.combine_layer:
                    tz = self.lin_z[blkid](z)
                    if self.use_spade:
                        sz = self.scale_z[blkid](z)
                        x = sz * x + tz
                    else:
                        x = x + tz

                x = self.blocks[blkid](x)

            out = self.lin_out(self.activation(x))
            # if kwargs["head_name"] == "singleviewhead":  # so that resnetfc.py only produces a single-view feature map for the PGT loss
            #     return out[:, 0, :]  # take the first feature map (the visualization frame) for mono-camera evaluation
            return out
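
    # Illustrative sketch of the multi-view combine (the shapes and layout below are
    # assumptions for illustration, not taken from the training code): with two views per
    # point stored in consecutive rows, combine_layer=3 and combine_inner_dims=(2,), the
    # features are reshaped to (-1, 2, d_hidden) before block 3 and reduced over the view
    # dimension with combine_type, so the remaining blocks operate on one fused feature
    # per point:
    #
    #     mlp = ResnetFC(d_in=39, d_latent=64, combine_layer=3, combine_type="average")
    #     zx = torch.randn(2 * 4096, 64 + 39)        # two views of 4096 points each
    #     out = mlp(zx, combine_inner_dims=(2,))     # -> (4096, 4) after the combine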

    ## For forward_hook argument matching in the multi-view BTS model
    # def from_conf(cls, conf, d_in, d_out):  ## default
    @classmethod
    def from_conf(cls, conf, d_in, d_out, d_latent=0):
        # d_latent is accepted only to match the hook's call signature; it is not forwarded.
        return cls(d_in=d_in, d_out=d_out, **conf)

    ## Default constructor from the original resnetfc.py
    @classmethod
    def from_conf2(cls, conf, d_in, **kwargs):
        # PyHocon construction
        return cls(
            d_in,
            n_blocks=conf.get("n_blocks", 5),
            d_hidden=conf.get("d_hidden", 128),
            beta=conf.get("beta", 0.0),
            combine_layer=conf.get("combine_layer", 1000),
            combine_type=conf.get("combine_type", "average"),  # average | max
            use_spade=conf.get("use_spade", False),
            **kwargs
        )
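
    # Usage sketch: `conf` is a dict-like config (e.g. PyHocon); the keys below are
    # assumptions for illustration, not taken from an actual config file.
    #
    #     conf = {"n_blocks": 5, "d_hidden": 128, "d_latent": 64, "combine_layer": 3}
    #     mlp = ResnetFC.from_conf(conf, d_in=39, d_out=4)     # all conf keys become kwargs
    #     mlp2 = ResnetFC.from_conf2(conf, d_in=39, d_out=4)   # reads known keys, passes d_out through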

    # @classmethod  ## For both multi- and single-view BTS models (integrated from both classmethods)
    # def from_conf(cls, conf, d_in, d_out, **kwargs):
    #     # PyHocon construction
    #     return cls(
    #         d_out=d_out,
    #         d_in=d_in,
    #         n_blocks=conf.get("n_blocks", 5),
    #         d_hidden=conf.get("d_hidden", 128),
    #         beta=conf.get("beta", 0.0),
    #         combine_layer=conf.get("combine_layer", 1000),
    #         combine_type=conf.get("combine_type", "average"),  # average | max
    #         use_spade=conf.get("use_spade", False),
    #         **kwargs
    #     )
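

# Minimal smoke test (a sketch, not part of the original module): the dimensions below are
# arbitrary assumptions chosen only to exercise the forward pass; with the default
# combine_layer=1000 and n_blocks=5 the multi-view combine is never reached. Running it
# still requires the scenedino package for the `util` import at the top of this file.
if __name__ == "__main__":
    torch.manual_seed(0)
    mlp = ResnetFC(d_in=39, d_latent=64, d_out=4, n_blocks=5, d_hidden=128)
    # (batch, points, d_latent + d_in): latent code concatenated with positional input
    zx = torch.cat([torch.randn(2, 4096, 64), torch.randn(2, 4096, 39)], dim=-1)
    out = mlp(zx)
    print(out.shape)  # expected: torch.Size([2, 4096, 4])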