# Scraped page header (not part of the original source): "Spaces: Runtime error"
"""
SparseUNet driven by MinkowskiEngine.

Modified from chrischoy/SpatioTemporalSegmentation.
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import torch | |
import torch.nn as nn | |
try: | |
import MinkowskiEngine as ME | |
except ImportError: | |
ME = None | |
from pointcept.models.builder import MODELS | |
def offset2batch(offset):
    """Expand cumulative per-sample end offsets into a per-point batch index.

    ``offset[i]`` is the exclusive end index of sample ``i`` in the flattened
    point list. Sample ``i`` therefore owns ``offset[i] - offset[i - 1]``
    points, and sample 0 owns ``offset[0]`` points.

    Args:
        offset: 1-D integer tensor of non-decreasing cumulative point counts.

    Returns:
        A ``torch.long`` tensor on ``offset.device`` that holds, for every
        point, the index of the sample it belongs to. For example,
        ``[2, 5, 6] -> [0, 0, 1, 1, 1, 2]``. An empty ``offset`` yields an
        empty tensor.

    Raises:
        RuntimeError: if ``offset`` decreases somewhere (negative count).
    """
    # Per-sample point counts. Prepending 0 makes the first count offset[0].
    counts = torch.diff(offset, prepend=offset.new_zeros(1)).long()
    # Expand on-device in one vectorized call. This avoids a Python list with
    # one entry per point and the per-sample host syncs the loop forced on GPU.
    return torch.arange(counts.numel(), device=offset.device).repeat_interleave(counts)
class BasicBlock(nn.Module):
    """Two-convolution residual block operating on Minkowski sparse tensors.

    Main path: conv(k=3, ``stride``/``dilation``) -> BN -> ReLU ->
    conv(k=3, stride 1) -> BN. The (optionally ``downsample``-d) input is then
    added and a final ReLU is applied. The block outputs
    ``planes * expansion`` channels, with ``expansion == 1``.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        dilation: dilation of both convolutions.
        downsample: optional module applied to the input before the residual
            addition, used when the shortcut needs a channel/stride change.
        bn_momentum: momentum of both batch-norm layers.
        dimension: spatial dimension; must be set to a positive value
            (the ``-1`` default fails the assertion).
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        bn_momentum=0.1,
        dimension=-1,
    ):
        super().__init__()
        assert dimension > 0
        # Settings shared by both 3-wide convolutions.
        conv_kwargs = dict(kernel_size=3, dilation=dilation, dimension=dimension)
        self.conv1 = ME.MinkowskiConvolution(inplanes, planes, stride=stride, **conv_kwargs)
        self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.conv2 = ME.MinkowskiConvolution(planes, planes, stride=1, **conv_kwargs)
        self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to sparse tensor ``x``."""
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.norm2(self.conv2(y))
        # The shortcut is computed after the main path, matching the original op order.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
class Bottleneck(nn.Module):
    """Three-convolution bottleneck residual block on Minkowski sparse tensors.

    Main path: conv(k=1) reduce -> BN -> ReLU -> conv(k=3, ``stride``/``dilation``)
    -> BN -> ReLU -> conv(k=1) expand to ``planes * expansion`` channels -> BN.
    The (optionally ``downsample``-d) input is then added and a final ReLU is
    applied. ``expansion == 4``.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the output has ``planes * 4`` channels.
        stride: stride of the middle (k=3) convolution.
        dilation: dilation of the middle convolution.
        downsample: optional module applied to the input before the residual
            addition.
        bn_momentum: momentum of all three batch-norm layers.
        dimension: spatial dimension; must be set to a positive value
            (the ``-1`` default fails the assertion).
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        bn_momentum=0.1,
        dimension=-1,
    ):
        super().__init__()
        assert dimension > 0
        widened = planes * self.expansion
        self.conv1 = ME.MinkowskiConvolution(
            inplanes, planes, kernel_size=1, dimension=dimension
        )
        self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.conv2 = ME.MinkowskiConvolution(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            dimension=dimension,
        )
        self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
        self.conv3 = ME.MinkowskiConvolution(
            planes, widened, kernel_size=1, dimension=dimension
        )
        self.norm3 = ME.MinkowskiBatchNorm(widened, momentum=bn_momentum)
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Apply the bottleneck block to sparse tensor ``x``."""
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.norm3(self.conv3(y))
        # The shortcut is computed after the main path, matching the original op order.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
class MinkUNetBase(nn.Module):
    """Sparse-voxel U-Net assembled from MinkowskiEngine layers.

    Encoder: a ``kernel_size=5`` stem convolution (``conv0p1s1``), then four
    ``kernel_size=2, stride=2`` convolutions (tensor stride 2, 4, 8, 16). Each
    is followed by BN, ReLU and a stack of residual blocks (``block1`` ..
    ``block4``).

    Decoder: four stride-2 transposed convolutions back to tensor stride 1.
    Each is followed by BN, ReLU, a channel concatenation (``ME.cat``) with the
    encoder feature of the same stride (the stem output at stride 1), and a
    block stack (``block5`` .. ``block8``). A final ``kernel_size=1``
    convolution maps to ``out_channels``.

    Subclasses configure the network with class attributes:

    * ``BLOCK``: residual block class exposing an ``expansion`` attribute
      (e.g. ``BasicBlock`` or ``Bottleneck``). Required.
    * ``LAYERS``: number of blocks in each of the eight block stacks.
    * ``PLANES``: base width of each of the eight stages. A stack built with
      width ``p`` outputs ``p * BLOCK.expansion`` channels.
    * ``INIT_DIM``: output width of the stem convolution.

    ``DILATIONS`` and ``OUT_TENSOR_STRIDE`` are declared but not read by this class.
    """

    BLOCK = None
    DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
    INIT_DIM = 32
    OUT_TENSOR_STRIDE = 1

    def __init__(self, in_channels, out_channels, dimension=3):
        """Build every layer of the network.

        Args:
            in_channels: number of per-point input feature channels.
            out_channels: number of channels produced by the final convolution.
            dimension: spatial dimension passed to every Minkowski layer.
        """
        super().__init__()
        assert ME is not None, "Please follow `README.md` to install MinkowskiEngine.`"
        self.D = dimension
        assert self.BLOCK is not None
        # ``self.inplanes`` tracks the channel count entering the next layer;
        # ``_make_layer`` advances it after every block stack. The stem output
        # (INIT_DIM channels) is concatenated back in just before ``block8``.
        # Layers are registered in forward order, which also fixes the
        # state_dict keys and the order ``weight_initialization`` visits them.
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=self.D
        )
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        # ---- Encoder: each stage halves the resolution (stride 2). ----
        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0])
        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1])
        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2])
        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D
        )
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3])

        # ---- Decoder: upsample, then concatenate the matching skip feature. ----
        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])
        # Upsampled width + width of the block3 output it is concatenated with.
        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4])
        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])
        # Upsampled width + width of the block2 output.
        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5])
        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])
        # Upsampled width + width of the block1 output.
        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6])
        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=self.D
        )
        self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])
        # Upsampled width + width of the stem output.
        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7])
        self.final = ME.MinkowskiConvolution(
            self.PLANES[7] * self.BLOCK.expansion,
            out_channels,
            kernel_size=1,
            bias=True,
            dimension=self.D,
        )
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.weight_initialization()

    def weight_initialization(self):
        """Initialize convolution kernels (Kaiming normal) and BN affine params (1 / 0)."""
        for m in self.modules():
            # NOTE(review): this only matches ME.MinkowskiConvolution. If the
            # installed MinkowskiEngine does not make MinkowskiConvolutionTranspose
            # a subclass of it, the transposed-conv kernels keep their default
            # init. Confirm before relying on this.
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")
            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1):
        """Build a stack of ``blocks`` residual blocks of width ``planes``.

        A 1x1 convolution + BN shortcut (``downsample``) is added to the first
        block when its stride or channel count differs from the incoming
        ``self.inplanes``. Afterwards ``self.inplanes`` is set to the stack's
        output width, ``planes * block.expansion``.

        Args:
            block: residual block class (exposes ``expansion``).
            planes: base width of every block in the stack.
            blocks: number of blocks; the first one is always created.
            stride: stride of the first block (later blocks use stride 1).
            dilation: dilation passed to every block.
            bn_momentum: batch-norm momentum for every block and for the
                shortcut BN. This was previously accepted but silently ignored;
                the default 0.1 matches the previously effective value.

        Returns:
            ``nn.Sequential`` of the blocks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    dimension=self.D,
                ),
                ME.MinkowskiBatchNorm(planes * block.expansion, momentum=bn_momentum),
            )
        layers = [
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                bn_momentum=bn_momentum,
                dimension=self.D,
            )
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    stride=1,
                    dilation=dilation,
                    bn_momentum=bn_momentum,
                    dimension=self.D,
                )
            )
        return nn.Sequential(*layers)

    def forward(self, data_dict):
        """Run the network on a batch of points.

        Args:
            data_dict: mapping with
                ``"grid_coord"``: integer voxel coordinates, one row per point
                    (presumably shape (N, D); TODO confirm),
                ``"feat"``: per-point input features (presumably (N, in_channels)),
                ``"offset"``: cumulative per-sample point counts, expanded to a
                    per-point batch index by ``offset2batch``.

        Returns:
            The final sparse output sliced back onto the input ``TensorField``
            (``.slice(in_field).F``), presumably one row of ``out_channels``
            features per input point.
        """
        grid_coord = data_dict["grid_coord"]
        feat = data_dict["feat"]
        offset = data_dict["offset"]
        batch = offset2batch(offset)
        # MinkowskiEngine coordinates are [batch_index, *spatial] per row.
        in_field = ME.TensorField(
            feat,
            coordinates=torch.cat([batch.unsqueeze(-1).int(), grid_coord.int()], dim=1),
            quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=feat.device,
        )
        x = in_field.sparse()

        # ---- Encoder (skip features are kept for the decoder) ----
        out = self.conv0p1s1(x)
        out = self.bn0(out)
        out_p1 = self.relu(out)

        out = self.conv1p1s2(out_p1)
        out = self.bn1(out)
        out = self.relu(out)
        out_b1p2 = self.block1(out)

        out = self.conv2p2s2(out_b1p2)
        out = self.bn2(out)
        out = self.relu(out)
        out_b2p4 = self.block2(out)

        out = self.conv3p4s2(out_b2p4)
        out = self.bn3(out)
        out = self.relu(out)
        out_b3p8 = self.block3(out)

        # tensor_stride=16
        out = self.conv4p8s2(out_b3p8)
        out = self.bn4(out)
        out = self.relu(out)
        out = self.block4(out)

        # ---- Decoder ----
        # tensor_stride=8
        out = self.convtr4p16s2(out)
        out = self.bntr4(out)
        out = self.relu(out)
        out = ME.cat(out, out_b3p8)
        out = self.block5(out)

        # tensor_stride=4
        out = self.convtr5p8s2(out)
        out = self.bntr5(out)
        out = self.relu(out)
        out = ME.cat(out, out_b2p4)
        out = self.block6(out)

        # tensor_stride=2
        out = self.convtr6p4s2(out)
        out = self.bntr6(out)
        out = self.relu(out)
        out = ME.cat(out, out_b1p2)
        out = self.block7(out)

        # tensor_stride=1
        out = self.convtr7p2s2(out)
        out = self.bntr7(out)
        out = self.relu(out)
        out = ME.cat(out, out_p1)
        out = self.block8(out)

        return self.final(out).slice(in_field).F
class MinkUNet14(MinkUNetBase):
    """MinkUNet with a single ``BasicBlock`` in each of the eight stages."""

    BLOCK = BasicBlock
    LAYERS = (1,) * 8


class MinkUNet18(MinkUNetBase):
    """MinkUNet with two ``BasicBlock``s in each of the eight stages."""

    BLOCK = BasicBlock
    LAYERS = (2,) * 8


class MinkUNet34(MinkUNetBase):
    """MinkUNet of ``BasicBlock``s: encoder stages (2, 3, 4, 6), decoder stages 2 each."""

    BLOCK = BasicBlock
    LAYERS = (2, 3, 4, 6) + (2,) * 4


class MinkUNet50(MinkUNetBase):
    """MinkUNet of ``Bottleneck`` blocks: encoder stages (2, 3, 4, 6), decoder stages 2 each."""

    BLOCK = Bottleneck
    LAYERS = (2, 3, 4, 6) + (2,) * 4


class MinkUNet101(MinkUNetBase):
    """MinkUNet of ``Bottleneck`` blocks: encoder stages (2, 3, 4, 23), decoder stages 2 each."""

    BLOCK = Bottleneck
    LAYERS = (2, 3, 4, 23) + (2,) * 4
# Width variants. The four encoder widths are (32, 64, 128, 256) in every
# variant; the second tuple lists the four decoder-stage widths.
class MinkUNet14A(MinkUNet14):
    """MinkUNet14 with decoder widths (128, 128, 96, 96)."""

    PLANES = (32, 64, 128, 256) + (128, 128, 96, 96)


class MinkUNet14B(MinkUNet14):
    """MinkUNet14 with decoder widths (128, 128, 128, 128)."""

    PLANES = (32, 64, 128, 256) + (128, 128, 128, 128)


class MinkUNet14C(MinkUNet14):
    """MinkUNet14 with decoder widths (192, 192, 128, 128)."""

    PLANES = (32, 64, 128, 256) + (192, 192, 128, 128)


class MinkUNet14D(MinkUNet14):
    """MinkUNet14 with decoder widths (384, 384, 384, 384)."""

    PLANES = (32, 64, 128, 256) + (384, 384, 384, 384)


class MinkUNet18A(MinkUNet18):
    """MinkUNet18 with decoder widths (128, 128, 96, 96)."""

    PLANES = (32, 64, 128, 256) + (128, 128, 96, 96)


class MinkUNet18B(MinkUNet18):
    """MinkUNet18 with decoder widths (128, 128, 128, 128)."""

    PLANES = (32, 64, 128, 256) + (128, 128, 128, 128)


class MinkUNet18D(MinkUNet18):
    """MinkUNet18 with decoder widths (384, 384, 384, 384)."""

    PLANES = (32, 64, 128, 256) + (384, 384, 384, 384)


class MinkUNet34A(MinkUNet34):
    """MinkUNet34 with decoder widths (256, 128, 96, 96)."""

    # NOTE(review): identical to MinkUNet34C and to the MinkUNetBase default
    # PLANES; confirm that 34A is meant to differ (e.g. narrower last stages).
    PLANES = (32, 64, 128, 256) + (256, 128, 96, 96)


class MinkUNet34B(MinkUNet34):
    """MinkUNet34 with decoder widths (256, 128, 64, 32)."""

    PLANES = (32, 64, 128, 256) + (256, 128, 64, 32)


class MinkUNet34C(MinkUNet34):
    """MinkUNet34 with decoder widths (256, 128, 96, 96)."""

    PLANES = (32, 64, 128, 256) + (256, 128, 96, 96)