kairunwen's picture
Update Code
57746f1
"""
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import MinkowskiEngine as ME
import numpy as np
def assign_feats(sp, x):
    """Build a new SparseTensor that reuses ``sp``'s coordinates with features ``x``.

    Args:
        sp: Source ``ME.SparseTensor``; only its coordinate map key and
            coordinate manager are reused.
        x: Feature tensor for the new sparse tensor. It is cast to float32.
            Presumably one row per coordinate of ``sp`` — TODO confirm with callers.

    Returns:
        ``ME.SparseTensor`` with ``x.float()`` as features on ``sp``'s coordinates.
    """
    feats32 = x.float()
    shared_key = sp.coordinate_map_key
    shared_manager = sp.coordinate_manager
    return ME.SparseTensor(
        features=feats32,
        coordinate_map_key=shared_key,
        coordinate_manager=shared_manager,
    )
class MinkConvBN(nn.Module):
    """Sparse convolution followed by batch normalization, with no activation.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        dilation: Convolution dilation (default 1).
        bias: Whether the convolution has a bias term (default False).
        dimension: Spatial dimension passed to MinkowskiEngine (default 3).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        dilation=1,
        bias=False,
        dimension=3,
    ):
        super().__init__()
        conv = ME.MinkowskiConvolution(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            dimension=dimension,
        )
        norm = ME.MinkowskiBatchNorm(out_channels)
        # Kept as an nn.Sequential named ``conv_layers`` so state_dict keys stay
        # ``conv_layers.0.*`` (conv) and ``conv_layers.1.*`` (norm).
        self.conv_layers = nn.Sequential(conv, norm)

    def forward(self, x):
        """Apply conv -> BN to the sparse tensor ``x`` and return the result."""
        return self.conv_layers(x)
class MinkConvBNRelu(nn.Module):
    """Sparse convolution -> batch norm -> in-place ReLU.

    If the stack produces float16 features (presumably under mixed-precision
    execution), they are re-wrapped as float32 on the same coordinates before
    being returned.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        dilation: Convolution dilation (default 1).
        bias: Whether the convolution has a bias term (default False).
        dimension: Spatial dimension passed to MinkowskiEngine (default 3).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        dilation=1,
        bias=False,
        dimension=3,
    ):
        super().__init__()
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            dimension=dimension,
        )
        self.conv_layers = nn.Sequential(
            ME.MinkowskiConvolution(**conv_kwargs),
            ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(inplace=True),
        )

    def forward(self, x):
        """Run conv -> BN -> ReLU on ``x``; half-precision features are upcast to float32."""
        y = self.conv_layers(x)
        is_half = y.F.dtype == torch.float16
        return assign_feats(y, y.F.float()) if is_half else y
class MinkDeConvBNRelu(nn.Module):
    """Transposed sparse convolution -> batch norm -> ReLU.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels.
        kernel_size: Transposed-convolution kernel size (required).
        stride: Transposed-convolution stride (required).
        dilation: Dilation (default 1).
        bias: Whether the transposed convolution has a bias term (default False).
        dimension: Spatial dimension passed to MinkowskiEngine (default 3).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation=1,
        bias=False,
        dimension=3,
    ):
        super().__init__()
        layers = [
            ME.MinkowskiConvolutionTranspose(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                bias=bias,
                dimension=dimension,
            ),
            ME.MinkowskiBatchNorm(out_channels),
            # Not in-place, unlike MinkConvBNRelu.
            ME.MinkowskiReLU(),
        ]
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply transposed conv -> BN -> ReLU to the sparse tensor ``x``."""
        return self.conv_layers(x)
class MinkResBlock(nn.Module):
    """Basic two-convolution residual block on 3-D sparse tensors.

    Main path:
        conv3x3(stride, dilation) -> BN -> ReLU -> conv3x3(dilation) -> BN.
    The main path is added to the shortcut and passed through a final ReLU.

    The shortcut is the identity when ``stride == 1`` and
    ``in_channels == out_channels``. Otherwise the raw input does not match
    the output's channel count / tensor stride. In that case it is projected
    with a 1x1 convolution (same stride) followed by batch norm, the same
    projection idea used by ``MinkResBlock_BottleNeck``.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels.
        stride: Stride of the first convolution (and of the shortcut projection).
        dilation: Dilation of both 3x3 convolutions.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super(MinkResBlock, self).__init__()
        self.conv1 = ME.MinkowskiConvolution(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            bias=False,
            dimension=3,
        )
        self.norm1 = ME.MinkowskiBatchNorm(out_channels)
        self.conv2 = ME.MinkowskiConvolution(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            dilation=dilation,
            bias=False,
            dimension=3,
        )
        self.norm2 = ME.MinkowskiBatchNorm(out_channels)
        self.relu = ME.MinkowskiReLU(inplace=True)
        # Only create a projection when the identity shortcut cannot match the
        # output. For the identity case no parameters are added, so existing
        # checkpoints and initialization order are unaffected.
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                    dimension=3,
                ),
                ME.MinkowskiBatchNorm(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))`` for the sparse tensor ``x``."""
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class SparseTensorLinear(nn.Module):
    """Apply an ``nn.Linear`` row-wise to a SparseTensor's feature matrix.

    Coordinates are untouched. The projected features are re-attached to the
    input's coordinate map via ``assign_feats``, which also casts them to float32.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels.
        bias: Whether the linear layer has a bias term (default False).
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, sp):
        """Project ``sp.F`` and return a SparseTensor on ``sp``'s coordinates."""
        projected = self.linear(sp.F)
        # assign_feats already performs the float32 cast.
        return assign_feats(sp, projected)
class SparseTensorLayerNorm(nn.Module):
    """Apply ``nn.LayerNorm`` to a SparseTensor's feature matrix.

    Coordinates are untouched. The normalized features are re-attached to the
    input's coordinate map via ``assign_feats``, which also casts them to float32.

    Args:
        dim: Normalized shape passed to ``nn.LayerNorm`` (the feature width).
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, sp):
        """Layer-normalize ``sp.F`` and return a SparseTensor on ``sp``'s coordinates."""
        normalized = self.norm(sp.F)
        # assign_feats already performs the float32 cast.
        return assign_feats(sp, normalized)
class MinkResBlock_v2(nn.Module):
    """Bottleneck residual block on sparse tensors.

    Main path:
        linear(in -> out/4) -> BN -> ReLU -> sparse conv (k=5, out/4 -> out/4)
        -> linear(out/4 -> out) -> BN -> ReLU.
    The main path is added to the shortcut: the identity when
    ``in_channels == out_channels``, otherwise linear(in -> out) -> BN.
    No activation is applied after the addition.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels. The bottleneck width is
            ``out_channels // 4``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        d_2 = out_channels // 4
        # Registered as ``unary_1`` (paired with ``unary_2``) because that is
        # the attribute ``forward`` reads. Under any other name forward fails
        # with AttributeError.
        self.unary_1 = torch.nn.Sequential(
            SparseTensorLinear(in_channels, d_2, bias=False),
            ME.MinkowskiBatchNorm(d_2),
            ME.MinkowskiReLU(),
        )
        self.unary_2 = torch.nn.Sequential(
            SparseTensorLinear(d_2, out_channels, bias=False),
            ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(),
        )
        self.spconv = ME.MinkowskiConvolution(
            in_channels=d_2,
            out_channels=d_2,
            kernel_size=5,
            stride=1,
            dilation=1,
            bias=False,
            dimension=3,
        )
        if in_channels != out_channels:
            self.shortcut_op = torch.nn.Sequential(
                SparseTensorLinear(in_channels, out_channels, bias=False),
                ME.MinkowskiBatchNorm(out_channels),
            )
        else:
            self.shortcut_op = nn.Identity()

    @property
    def conv1(self):
        """Read-only backward-compatible alias for ``unary_1``.

        This is a property, so it adds no duplicate state_dict entries.
        """
        return self.unary_1

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)`` for the sparse tensor ``x``.

        ``x`` is presumably an ``ME.SparseTensor`` with ``in_channels`` features
        (TODO confirm with callers). The result has ``out_channels`` features on
        the same coordinates: every op here is stride 1.
        """
        shortcut = x
        x = self.unary_1(x)
        x = self.spconv(x)
        x = self.unary_2(x)
        shortcut = self.shortcut_op(shortcut)
        x += shortcut
        return x
class MinkResBlock_BottleNeck(nn.Module):
    """ResNet-style bottleneck residual block on sparse tensors.

    Main path:
        1x1 conv-BN-ReLU (in -> out/4) -> 3x3 conv-BN-ReLU (out/4 -> out/4)
        -> 1x1 conv-BN (out/4 -> out).
    The main path is added to the shortcut and passed through a final ReLU.
    The shortcut is the identity when ``in_channels == out_channels``,
    otherwise a 1x1 conv-BN projection. All convolutions use stride 1.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels. The bottleneck width is
            ``out_channels // 4``.
    """

    def __init__(self, in_channels, out_channels):
        super(MinkResBlock_BottleNeck, self).__init__()
        mid = out_channels // 4
        self.conv1x1a = MinkConvBNRelu(in_channels, mid, kernel_size=1, stride=1)
        self.conv3x3 = MinkConvBNRelu(mid, mid, kernel_size=3, stride=1)
        self.conv1x1b = MinkConvBN(mid, out_channels, kernel_size=1, stride=1)
        needs_projection = in_channels != out_channels
        self.conv1x1c = (
            MinkConvBN(in_channels, out_channels, kernel_size=1, stride=1)
            if needs_projection
            else None
        )
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))`` for the sparse tensor ``x``."""
        main = self.conv1x1b(self.conv3x3(self.conv1x1a(x)))
        shortcut = x if self.conv1x1c is None else self.conv1x1c(x)
        return self.relu(main + shortcut)