""" | |
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import MinkowskiEngine as ME | |
import numpy as np | |


def assign_feats(sp, x):
    # Wrap the feature matrix x in a new SparseTensor that reuses the
    # coordinate map (and coordinate manager) of the input sparse tensor sp.
    return ME.SparseTensor(
        features=x.float(),
        coordinate_map_key=sp.coordinate_map_key,
        coordinate_manager=sp.coordinate_manager,
    )


class MinkConvBN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        dilation=1,
        bias=False,
        dimension=3,
    ):
        super().__init__()
        self.conv_layers = nn.Sequential(
            ME.MinkowskiConvolution(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                bias=bias,
                dimension=dimension,
            ),
            ME.MinkowskiBatchNorm(out_channels),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        return x


class MinkConvBNRelu(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        dilation=1,
        bias=False,
        dimension=3,
    ):
        super().__init__()
        self.conv_layers = nn.Sequential(
            ME.MinkowskiConvolution(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                bias=bias,
                dimension=dimension,
            ),
            ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        if x.F.dtype == torch.float16:
            # Cast half-precision features back to float32 for downstream layers.
            x = assign_feats(x, x.F.float())
        return x


class MinkDeConvBNRelu(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation=1,
        bias=False,
        dimension=3,
    ):
        super().__init__()
        self.conv_layers = nn.Sequential(
            ME.MinkowskiConvolutionTranspose(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                bias=bias,
                dimension=dimension,
            ),
            ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        return x


class MinkResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super(MinkResBlock, self).__init__()
        self.conv1 = ME.MinkowskiConvolution(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            bias=False,
            dimension=3,
        )
        self.norm1 = ME.MinkowskiBatchNorm(out_channels)
        self.conv2 = ME.MinkowskiConvolution(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            dilation=dilation,
            bias=False,
            dimension=3,
        )
        self.norm2 = ME.MinkowskiBatchNorm(out_channels)
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out += residual
        out = self.relu(out)
        return out


class SparseTensorLinear(nn.Module):
    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, sp):
        x = self.linear(sp.F)
        return assign_feats(sp, x.float())


class SparseTensorLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, sp):
        x = self.norm(sp.F)
        return assign_feats(sp, x.float())


class MinkResBlock_v2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        d_2 = out_channels // 4
        self.unary_1 = torch.nn.Sequential(
            SparseTensorLinear(in_channels, d_2, bias=False),
            ME.MinkowskiBatchNorm(d_2),
            ME.MinkowskiReLU(),
        )
        self.unary_2 = torch.nn.Sequential(
            SparseTensorLinear(d_2, out_channels, bias=False),
            ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(),
        )
        self.spconv = ME.MinkowskiConvolution(
            in_channels=d_2,
            out_channels=d_2,
            kernel_size=5,
            stride=1,
            dilation=1,
            bias=False,
            dimension=3,
        )
        if in_channels != out_channels:
            self.shortcut_op = torch.nn.Sequential(
                SparseTensorLinear(in_channels, out_channels, bias=False),
                ME.MinkowskiBatchNorm(out_channels),
            )
        else:
            self.shortcut_op = nn.Identity()

    def forward(self, x):
        # x: ME.SparseTensor with features of shape [N, in_channels]
        shortcut = x
        x = self.unary_1(x)
        x = self.spconv(x)
        x = self.unary_2(x)
        shortcut = self.shortcut_op(shortcut)
        x += shortcut
        return x


class MinkResBlock_BottleNeck(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MinkResBlock_BottleNeck, self).__init__()
        bottle_neck = out_channels // 4
        self.conv1x1a = MinkConvBNRelu(
            in_channels, bottle_neck, kernel_size=1, stride=1
        )
        self.conv3x3 = MinkConvBNRelu(bottle_neck, bottle_neck, kernel_size=3, stride=1)
        self.conv1x1b = MinkConvBN(bottle_neck, out_channels, kernel_size=1, stride=1)
        if in_channels != out_channels:
            self.conv1x1c = MinkConvBN(
                in_channels, out_channels, kernel_size=1, stride=1
            )
        else:
            self.conv1x1c = None
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.conv1x1a(x)
        out = self.conv3x3(out)
        out = self.conv1x1b(out)
        if self.conv1x1c is not None:
            residual = self.conv1x1c(residual)
        out = self.relu(out + residual)
        return out
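

# Minimal usage sketch (not part of the original module): builds a toy sparse
# tensor from random integer coordinates and runs it through a few of the
# blocks defined above. Assumes a working MinkowskiEngine install; the channel
# widths, coordinate range, and layer choices here are illustrative only.
if __name__ == "__main__":
    n_points = 1000
    coords = np.random.randint(0, 100, size=(n_points, 3))
    # Prepend the batch index that MinkowskiEngine expects (single batch here);
    # duplicate coordinates are merged by the SparseTensor's quantization mode.
    coords = ME.utils.batched_coordinates([coords])
    feats = torch.randn(n_points, 3)
    sp = ME.SparseTensor(features=feats, coordinates=coords)

    stem = MinkConvBNRelu(in_channels=3, out_channels=32)
    res = MinkResBlock_BottleNeck(in_channels=32, out_channels=32)
    down = MinkConvBNRelu(in_channels=32, out_channels=64, kernel_size=2, stride=2)
    up = MinkDeConvBNRelu(in_channels=64, out_channels=32, kernel_size=2, stride=2)

    out = up(down(res(stem(sp))))
    print(out.F.shape)  # [n_output_voxels, 32] on the upsampled coordinate map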