# ziqima's picture
# initial commit
# 4893ce0
"""
SparseUNet Driven by MinkowskiEngine
Modified from chrischoy/SpatioTemporalSegmentation
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import torch
import torch.nn as nn
try:
import MinkowskiEngine as ME
except ImportError:
ME = None
from pointcept.models.builder import MODELS
def offset2batch(offset):
    """Expand cumulative per-sample end offsets into a per-point batch index.

    ``offset[i]`` is the exclusive end index of sample ``i`` in the
    concatenated point list. Sample ``i`` therefore owns
    ``offset[i] - offset[i - 1]`` points, or ``offset[0]`` points for
    ``i == 0``.

    Example: ``offset = [2, 5, 6]`` -> ``[0, 0, 1, 1, 1, 2]``.

    Args:
        offset: 1-D integer tensor of non-decreasing cumulative counts.

    Returns:
        A ``torch.long`` tensor on ``offset.device`` holding one batch index
        per point. It is empty when ``offset`` is empty.
    """
    # Per-sample point counts: first difference with an implicit leading 0.
    counts = torch.diff(offset, prepend=offset.new_zeros(1)).long()
    # Vectorised equivalent of concatenating ``[i] * count_i`` lists. It stays
    # on offset's device instead of building Python lists on the host, and it
    # handles an empty ``offset`` (torch.cat of an empty list would raise).
    return torch.arange(counts.numel(), device=offset.device).repeat_interleave(counts)
class BasicBlock(nn.Module):
    """Residual block of two kernel-3 sparse convolutions with batch norm.

    Main branch: conv1 (may be strided) -> BN -> ReLU -> conv2 -> BN.
    The block input, passed through ``downsample`` when one is given, is
    added to the main branch and the sum goes through a final ReLU. The
    output has ``planes * expansion`` (== ``planes``) channels.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        bn_momentum=0.1,
        dimension=-1,
    ):
        super().__init__()
        assert dimension > 0
        # Both convolutions share kernel size, dilation and dimensionality;
        # only the first one is allowed to be strided.
        shared = dict(kernel_size=3, dilation=dilation, dimension=dimension)
        self.conv1 = ME.MinkowskiConvolution(inplanes, planes, stride=stride, **shared)
        self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.conv2 = ME.MinkowskiConvolution(planes, planes, stride=1, **shared)
        self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        # The shortcut is computed after the main branch, as in the original order.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """Residual bottleneck block: kernel-1 -> kernel-3 -> kernel-1 sparse convs.

    Main branch:
    conv1 (k=1) -> BN -> ReLU -> conv2 (k=3, may be strided/dilated) -> BN -> ReLU
    -> conv3 (k=1, widens to ``planes * expansion``) -> BN.
    The block input, passed through ``downsample`` when one is given, is
    added to the main branch and the sum goes through a final ReLU.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        bn_momentum=0.1,
        dimension=-1,
    ):
        super().__init__()
        assert dimension > 0
        out_planes = planes * self.expansion
        self.conv1 = ME.MinkowskiConvolution(
            inplanes, planes, kernel_size=1, dimension=dimension
        )
        self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.conv2 = ME.MinkowskiConvolution(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            dimension=dimension,
        )
        self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.conv3 = ME.MinkowskiConvolution(
            planes, out_planes, kernel_size=1, dimension=dimension
        )
        self.norm3 = ME.MinkowskiBatchNorm(out_planes, momentum=bn_momentum)
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Apply the bottleneck block to ``x`` and return the activated sum."""
        out = x
        # The first two conv/BN pairs are each followed by a ReLU.
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            out = self.relu(norm(conv(out)))
        out = self.norm3(self.conv3(out))
        # The shortcut is computed after the main branch, as in the original order.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class MinkUNetBase(nn.Module):
    """Sparse U-Net backbone built on MinkowskiEngine.

    Encoder:
        A kernel-5 stem, then four stride-2 kernel-2 convolutions
        (tensor stride 1 -> 2 -> 4 -> 8 -> 16). Each is followed by a stage
        of ``BLOCK`` residual blocks.

    Decoder:
        Four stride-2 transposed convolutions back to tensor stride 1. Each
        output is concatenated with the encoder feature map of the same
        stride (U-Net skip) and then refined by a stage of ``BLOCK`` blocks.

    A final kernel-1 convolution maps to ``out_channels``. The result is
    sliced back onto the input field.

    Subclasses configure the network through class attributes:
        BLOCK: residual block class exposing an ``expansion`` attribute.
        LAYERS: number of blocks per stage. Entries 0-3 are the encoder
            stages (block1-block4); entries 4-7 are the decoder stages
            (block5-block8).
        PLANES: channel widths. Entries 0-3 are for the encoder blocks;
            entries 4-7 are for the transposed convolutions and the decoder
            blocks.
        INIT_DIM: width of the stem convolution.
    """

    BLOCK = None
    DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)  # NOTE(review): not read by this class
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
    INIT_DIM = 32
    OUT_TENSOR_STRIDE = 1  # NOTE(review): not read by this class

    def __init__(self, in_channels, out_channels, dimension=3):
        """Build all layers.

        Args:
            in_channels: number of input feature channels.
            out_channels: number of channels produced by the final convolution.
            dimension: spatial dimensionality passed to every ME layer.
        """
        super().__init__()
        assert ME is not None, "Please follow `README.md` to install MinkowskiEngine.`"
        self.D = dimension
        assert self.BLOCK is not None
        # ``self.inplanes`` tracks the channel count entering the next layer.
        # ``_make_layer`` updates it after building each stage. The stem
        # output (``out_p1`` in forward) is concatenated back in before block8.
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=self.D
        )
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)
        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0])
        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1])
        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2])
        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3])
        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])
        # block5 consumes the upsampled features concatenated with block3's output.
        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4])
        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])
        # block6 consumes the upsampled features concatenated with block2's output.
        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5])
        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])
        # block7 consumes the upsampled features concatenated with block1's output.
        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6])
        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])
        # block8 consumes the upsampled features concatenated with the stem output.
        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7])
        self.final = ME.MinkowskiConvolution(
            self.PLANES[7] * self.BLOCK.expansion,
            out_channels,
            kernel_size=1,
            bias=True,
            dimension=self.D,
        )
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.weight_initialization()

    def weight_initialization(self):
        """Kaiming-init every ``MinkowskiConvolution`` kernel; reset BN affine params.

        NOTE(review): transposed convolutions are only matched if
        ``MinkowskiConvolutionTranspose`` subclasses ``MinkowskiConvolution``.
        Confirm against the installed MinkowskiEngine version.
        """
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")
            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1):
        """Build one stage of ``blocks`` residual blocks.

        The first block receives ``self.inplanes`` channels. When its stride
        or channel count differs from the output, it gets a kernel-1 conv + BN
        shortcut projection. Afterwards ``self.inplanes`` is set to
        ``planes * block.expansion``.

        Args:
            block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            planes: base width passed to each block.
            blocks: number of blocks. The first block is always built, even
                when this is < 1.
            stride: stride of the first block and of the shortcut projection.
            dilation: dilation forwarded to every block.
            bn_momentum: momentum for every BatchNorm in the stage, including
                the shortcut projection. It was previously accepted but ignored;
                the default equals the BatchNorm default of 0.1.

        Returns:
            ``nn.Sequential`` of the built blocks.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    dimension=self.D,
                ),
                ME.MinkowskiBatchNorm(out_planes, momentum=bn_momentum),
            )
        layers = [
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                bn_momentum=bn_momentum,
                dimension=self.D,
            )
        ]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    stride=1,
                    dilation=dilation,
                    bn_momentum=bn_momentum,
                    dimension=self.D,
                )
            )
        return nn.Sequential(*layers)

    def forward(self, data_dict):
        """Run the U-Net on a batch of points.

        Args:
            data_dict: mapping with three keys.
                ``"grid_coord"``: integer grid coordinates, one row per point.
                ``"feat"``: per-point input features.
                ``"offset"``: cumulative per-sample end indices (see
                ``offset2batch``).
                Shapes are presumably (N, D), (N, in_channels) and (B,);
                TODO confirm.

        Returns:
            Feature tensor from ``final``, sliced back onto the input
            ``TensorField`` (``.F``).
        """
        grid_coord = data_dict["grid_coord"]
        feat = data_dict["feat"]
        offset = data_dict["offset"]
        batch = offset2batch(offset)
        # ME expects the batch index as the first coordinate column. Points
        # that share a coordinate are averaged when quantised to a sparse tensor.
        in_field = ME.TensorField(
            feat,
            coordinates=torch.cat([batch.unsqueeze(-1).int(), grid_coord.int()], dim=1),
            quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=feat.device,
        )
        x = in_field.sparse()
        # Encoder. Each out_* tensor is kept for a decoder skip connection.
        out = self.conv0p1s1(x)
        out = self.bn0(out)
        out_p1 = self.relu(out)
        out = self.conv1p1s2(out_p1)
        out = self.bn1(out)
        out = self.relu(out)
        out_b1p2 = self.block1(out)
        out = self.conv2p2s2(out_b1p2)
        out = self.bn2(out)
        out = self.relu(out)
        out_b2p4 = self.block2(out)
        out = self.conv3p4s2(out_b2p4)
        out = self.bn3(out)
        out = self.relu(out)
        out_b3p8 = self.block3(out)
        # tensor_stride=16
        out = self.conv4p8s2(out_b3p8)
        out = self.bn4(out)
        out = self.relu(out)
        out = self.block4(out)
        # Decoder.
        # tensor_stride=8
        out = self.convtr4p16s2(out)
        out = self.bntr4(out)
        out = self.relu(out)
        out = ME.cat(out, out_b3p8)
        out = self.block5(out)
        # tensor_stride=4
        out = self.convtr5p8s2(out)
        out = self.bntr5(out)
        out = self.relu(out)
        out = ME.cat(out, out_b2p4)
        out = self.block6(out)
        # tensor_stride=2
        out = self.convtr6p4s2(out)
        out = self.bntr6(out)
        out = self.relu(out)
        out = ME.cat(out, out_b1p2)
        out = self.block7(out)
        # tensor_stride=1
        out = self.convtr7p2s2(out)
        out = self.bntr7(out)
        out = self.relu(out)
        out = ME.cat(out, out_p1)
        out = self.block8(out)
        return self.final(out).slice(in_field).F
@MODELS.register_module()
class MinkUNet14(MinkUNetBase):
    """MinkUNet with a single ``BasicBlock`` in each of the 8 stages."""

    LAYERS = (1,) * 8
    BLOCK = BasicBlock


@MODELS.register_module()
class MinkUNet18(MinkUNetBase):
    """MinkUNet with two ``BasicBlock``s in each of the 8 stages."""

    LAYERS = (2,) * 8
    BLOCK = BasicBlock


@MODELS.register_module()
class MinkUNet34(MinkUNetBase):
    """MinkUNet with ``BasicBlock``s; encoder stages hold 2/3/4/6 blocks."""

    # Four encoder stages, then four decoder stages of two blocks each.
    LAYERS = (2, 3, 4, 6) + (2,) * 4
    BLOCK = BasicBlock


@MODELS.register_module()
class MinkUNet50(MinkUNetBase):
    """MinkUNet with ``Bottleneck`` blocks; encoder stages hold 2/3/4/6 blocks."""

    # Four encoder stages, then four decoder stages of two blocks each.
    LAYERS = (2, 3, 4, 6) + (2,) * 4
    BLOCK = Bottleneck


@MODELS.register_module()
class MinkUNet101(MinkUNetBase):
    """MinkUNet with ``Bottleneck`` blocks; encoder stages hold 2/3/4/23 blocks."""

    # Four encoder stages, then four decoder stages of two blocks each.
    LAYERS = (2, 3, 4, 23) + (2,) * 4
    BLOCK = Bottleneck
@MODELS.register_module()
class MinkUNet14A(MinkUNet14):
    """MinkUNet14 with decoder widths 128/128/96/96."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128, 128, 96, 96)


@MODELS.register_module()
class MinkUNet14B(MinkUNet14):
    """MinkUNet14 with decoder widths 128/128/128/128."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128, 128, 128, 128)


@MODELS.register_module()
class MinkUNet14C(MinkUNet14):
    """MinkUNet14 with decoder widths 192/192/128/128."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (192, 192, 128, 128)


@MODELS.register_module()
class MinkUNet14D(MinkUNet14):
    """MinkUNet14 with decoder widths 384/384/384/384."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (384, 384, 384, 384)
@MODELS.register_module()
class MinkUNet18A(MinkUNet18):
    """MinkUNet18 with decoder widths 128/128/96/96."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128, 128, 96, 96)


@MODELS.register_module()
class MinkUNet18B(MinkUNet18):
    """MinkUNet18 with decoder widths 128/128/128/128."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128, 128, 128, 128)


@MODELS.register_module()
class MinkUNet18D(MinkUNet18):
    """MinkUNet18 with decoder widths 384/384/384/384."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (384, 384, 384, 384)
@MODELS.register_module()
class MinkUNet34A(MinkUNet34):
    """MinkUNet34 with decoder widths 256/128/96/96."""

    # NOTE(review): as written, this PLANES tuple is identical to
    # MinkUNet34C's. Confirm that is intended before changing it, because
    # saved checkpoints depend on these widths.
    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (256, 128, 96, 96)


@MODELS.register_module()
class MinkUNet34B(MinkUNet34):
    """MinkUNet34 with decoder widths 256/128/64/32."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (256, 128, 64, 32)


@MODELS.register_module()
class MinkUNet34C(MinkUNet34):
    """MinkUNet34 with decoder widths 256/128/96/96."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (256, 128, 96, 96)