jev-aleks's picture
scenedino init
9e15541
from typing import Optional
from torch import nn
import torch
# import torch_scatter
import torch.autograd.profiler as profiler
from scenedino.common import util
# Resnet Blocks
class ResnetBlockFC(nn.Module):
    """
    Fully connected ResNet block (pre-activation form).
    Taken from DVR code.

    The block computes ``shortcut(x) + fc_1(act(fc_0(act(x))))``. ``fc_1`` is
    zero-initialised, so a freshly constructed block returns ``shortcut(x)``
    (or ``x`` itself when no projection is needed).

    :param size_in (int): input dimension
    :param size_out (int): output dimension; defaults to ``size_in``
    :param size_h (int): hidden dimension; defaults to ``min(size_in, size_out)``
    :param beta (float): Softplus beta; if <= 0 a ReLU activation is used instead
    """

    def __init__(self, size_in, size_out=None, size_h=None, beta=0.0):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in
        if size_h is None:
            size_h = min(size_in, size_out)
        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)

        # Init: Kaiming for the first layer, zeros for the second so that the
        # residual branch contributes nothing at initialisation.
        nn.init.constant_(self.fc_0.bias, 0.0)
        nn.init.kaiming_normal_(self.fc_0.weight, a=0, mode="fan_in")
        nn.init.constant_(self.fc_1.bias, 0.0)
        nn.init.zeros_(self.fc_1.weight)

        if beta > 0:
            self.activation = nn.Softplus(beta=beta)
        else:
            self.activation = nn.ReLU()

        if size_in == size_out:
            # Identity skip connection, no projection required.
            self.shortcut = None
        else:
            # Linear projection so the skip path matches the output width.
            # It is created with bias=False, hence ``shortcut.bias`` is None
            # and only the weight can (and must) be initialised here.
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
            nn.init.kaiming_normal_(self.shortcut.weight, a=0, mode="fan_in")

    def forward(self, x):
        """
        :param x tensor whose last dimension is size_in
        :return tensor with the same leading dimensions and last dimension size_out
        """
        with profiler.record_function("resblock"):
            net = self.fc_0(self.activation(x))
            dx = self.fc_1(self.activation(net))

            if self.shortcut is not None:
                x_s = self.shortcut(x)
            else:
                x_s = x
            return x_s + dx
class ResnetFC(nn.Module):
    """
    MLP made of an optional input projection, a stack of ResnetBlockFC blocks
    and an output projection, with optional per-block injection of a latent
    code (additive, or scale-and-shift when ``use_spade`` is set).
    """

    def __init__(
        self,
        d_in,
        view_number: Optional[int] = None,
        d_out=4,
        n_blocks=5,
        d_latent=0,
        d_hidden=128,
        beta=0.0,
        combine_layer=1000,
        combine_type="average",
        use_spade=False,
    ):
        """
        :param d_in input size (0 = no input projection; the hidden state starts at zero)
        :param view_number if not None, index selected on the second-to-last
        dimension of the input in forward()
        :param d_out output size
        :param n_blocks number of Resnet blocks
        :param d_latent latent size, added in each resnet block (0 = disable)
        :param d_hidden hidden dimension throughout network
        :param beta softplus beta, 100 is reasonable; if <=0 uses ReLU activations instead
        :param combine_layer index of the block before which the hidden state is
        reduced via util.combine_interleaved; latents are only injected in blocks
        with a smaller index
        :param combine_type reduction name handed to util.combine_interleaved
        :param use_spade if True, the latent also produces a multiplicative scale
        """
        super().__init__()

        # Plain configuration values.
        self.n_blocks = n_blocks
        self.d_latent = d_latent
        self.d_in = d_in
        self.view_number = view_number
        self.d_out = d_out
        self.d_hidden = d_hidden
        self.combine_layer = combine_layer
        self.combine_type = combine_type
        self.use_spade = use_spade

        # Input / output projections.
        if d_in > 0:
            self.lin_in = nn.Linear(d_in, d_hidden)
            self._init_linear(self.lin_in)

        self.lin_out = nn.Linear(d_hidden, d_out)
        self._init_linear(self.lin_out)

        # Residual trunk.
        self.blocks = nn.ModuleList(
            [ResnetBlockFC(d_hidden, beta=beta) for _ in range(n_blocks)]
        )

        # Latent injection layers: one per block that runs before combine_layer.
        if d_latent != 0:
            n_lin_z = min(combine_layer, n_blocks)

            self.lin_z = nn.ModuleList(
                [nn.Linear(d_latent, d_hidden) for _ in range(n_lin_z)]
            )
            for shift_layer in self.lin_z:
                self._init_linear(shift_layer)

            if use_spade:
                self.scale_z = nn.ModuleList(
                    [nn.Linear(d_latent, d_hidden) for _ in range(n_lin_z)]
                )
                for scale_layer in self.scale_z:
                    self._init_linear(scale_layer)

        self.activation = nn.Softplus(beta=beta) if beta > 0 else nn.ReLU()

    @staticmethod
    def _init_linear(layer):
        """Zero the bias and apply fan-in Kaiming-normal init to the weight."""
        nn.init.constant_(layer.bias, 0.0)
        nn.init.kaiming_normal_(layer.weight, a=0, mode="fan_in")

    def forward(
        self,
        sampled_features,
        combine_inner_dims=(1,),
        combine_index=None,
        dim_size=None,
        **kwargs
    ):
        """
        :param sampled_features tensor whose last dimension is d_latent + d_in
        (after selecting view_number on the second-to-last dimension, if set);
        the first d_latent channels are the latent code, the rest the input
        :param combine_inner_dims Combining dimensions for use with multiview inputs,
        passed to util.combine_interleaved together with combine_type at combine_layer
        :param combine_index unused (only referenced by the disabled torch_scatter path)
        :param dim_size unused (only referenced by the disabled torch_scatter path)
        :param kwargs accepted and ignored
        :return tensor whose last dimension is d_out
        """
        with profiler.record_function("resnetfc_infer"):
            if self.view_number is None:
                zx = sampled_features
            else:
                zx = sampled_features[..., self.view_number, :]

            assert zx.size(-1) == self.d_latent + self.d_in

            has_latent = self.d_latent > 0
            if has_latent:
                # Split channels into latent code and network input.
                z = zx[..., : self.d_latent]
                x = zx[..., self.d_latent :]
            else:
                x = zx

            if self.d_in > 0:
                x = self.lin_in(x)
            else:
                # No input channels: start from a zero hidden vector.
                x = torch.zeros(self.d_hidden, device=zx.device)

            for blkid in range(self.n_blocks):
                if blkid == self.combine_layer:
                    # The following implements camera frustum culling, requires torch_scatter
                    # if combine_index is not None:
                    #     combine_type = (
                    #         "mean"
                    #         if self.combine_type == "average"
                    #         else self.combine_type
                    #     )
                    #     if dim_size is not None:
                    #         assert isinstance(dim_size, int)
                    #     x = torch_scatter.scatter(
                    #         x,
                    #         combine_index,
                    #         dim=0,
                    #         dim_size=dim_size,
                    #         reduce=combine_type,
                    #     )
                    # else:
                    x = util.combine_interleaved(
                        x, combine_inner_dims, self.combine_type
                    )

                if has_latent and blkid < self.combine_layer:
                    shift = self.lin_z[blkid](z)
                    if self.use_spade:
                        scale = self.scale_z[blkid](z)
                        x = scale * x + shift
                    else:
                        x = x + shift

                x = self.blocks[blkid](x)

            out = self.lin_out(self.activation(x))
            # if kwargs["head_name"] == "singleviewhead":  ## To recognize resnetfc.py that it only creates singleview feature map for pgt loss
            #     return out[:,0,:]  ## Take 1st feature map as viz frame as evaluation purpose mono camera
            return out

    @classmethod  ## For forward_hook arguments matching: For multi view BTS model
    # def from_conf(cls, conf, d_in, d_out):  ## default
    def from_conf(cls, conf, d_in, d_out, d_latent=0):
        """
        Build the network from a mapping of constructor keyword arguments.

        :param conf mapping expanded as keyword arguments of the constructor
        :param d_in input size
        :param d_out output size
        :param d_latent accepted for signature compatibility but NOT forwarded;
        the latent size comes from conf if it has a "d_latent" entry, otherwise
        the constructor default is used
        """
        return cls(d_in=d_in, d_out=d_out, **conf)

    @classmethod  ## default for original resnetfc.py
    def from_conf2(cls, conf, d_in, **kwargs):
        """
        Build the network from a PyHocon-style config object exposing ``get``.

        :param conf config object queried with conf.get(key, default)
        :param d_in input size
        :param kwargs extra constructor keyword arguments (e.g. d_out, d_latent)
        """
        # PyHocon construction
        settings = {
            "n_blocks": conf.get("n_blocks", 5),
            "d_hidden": conf.get("d_hidden", 128),
            "beta": conf.get("beta", 0.0),
            "combine_layer": conf.get("combine_layer", 1000),
            "combine_type": conf.get("combine_type", "average"),  # average | max
            "use_spade": conf.get("use_spade", False),
        }
        return cls(d_in, **settings, **kwargs)
# @classmethod ## For both multi and single view BTS model (integrated from both classmethod)
# def from_conf(cls, conf, d_in, d_out, **kwargs):
# # PyHocon construction
# return cls(
# d_out = d_out,
# d_in = d_in,
# n_blocks = conf.get("n_blocks", 5),
# d_hidden = conf.get("d_hidden", 128),
# beta = conf.get("beta", 0.0),
# combine_layer = conf.get("combine_layer", 1000),
# combine_type = conf.get("combine_type", "average"), # average | max
# use_spade = conf.get("use_spade", False),
# **kwargs
# )