"""
SPVCNN

Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""

import torch
import torch.nn as nn

try:
    import torchsparse
    import torchsparse.nn as spnn
    import torchsparse.nn.functional as F
    from torchsparse.nn.utils import get_kernel_offsets
    from torchsparse import PointTensor, SparseTensor
except ImportError:
    torchsparse = None

from pointcept.models.utils import offset2batch
from pointcept.models.builder import MODELS


def initial_voxelize(z):
    # Hash the floored point coordinates and deduplicate to obtain one entry per voxel.
    pc_hash = F.sphash(torch.floor(z.C).int())
    sparse_hash = torch.unique(pc_hash)
    idx_query = F.sphashquery(pc_hash, sparse_hash)
    counts = F.spcount(idx_query.int(), len(sparse_hash))

    # Average the coordinates and features of all points that fall into the same voxel.
    inserted_coords = F.spvoxelize(torch.floor(z.C), idx_query, counts)
    inserted_coords = torch.round(inserted_coords).int()
    inserted_feat = F.spvoxelize(z.F, idx_query, counts)

    new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
    new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
    # Cache the point-to-voxel mapping at stride 1 for later reuse.
    z.additional_features["idx_query"][1] = idx_query
    z.additional_features["counts"][1] = counts
    return new_tensor


def point_to_voxel(x, z):
    # Reuse the cached point-to-voxel mapping for stride x.s when available;
    # otherwise rebuild it by hashing the point coordinates at that stride.
    if (
        z.additional_features is None
        or z.additional_features.get("idx_query") is None
        or z.additional_features["idx_query"].get(x.s) is None
    ):
        pc_hash = F.sphash(
            torch.cat(
                [
                    torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
                    z.C[:, -1].int().view(-1, 1),
                ],
                1,
            )
        )
        sparse_hash = F.sphash(x.C)
        idx_query = F.sphashquery(pc_hash, sparse_hash)
        counts = F.spcount(idx_query.int(), x.C.shape[0])
        z.additional_features["idx_query"][x.s] = idx_query
        z.additional_features["counts"][x.s] = counts
    else:
        idx_query = z.additional_features["idx_query"][x.s]
        counts = z.additional_features["counts"][x.s]

    # Scatter-average point features into the voxel grid of x.
    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
    new_tensor = SparseTensor(inserted_feat, x.C, x.s)
    new_tensor.cmaps = x.cmaps
    new_tensor.kmaps = x.kmaps

    return new_tensor


def voxel_to_point(x, z, nearest=False):
    # Devoxelize voxel features x back onto the points in z, using trilinear
    # interpolation weights (or nearest-voxel lookup when nearest=True).
    if (
        z.idx_query is None
        or z.weights is None
        or z.idx_query.get(x.s) is None
        or z.weights.get(x.s) is None
    ):
        off = spnn.utils.get_kernel_offsets(2, x.s, 1, device=z.F.device)
        old_hash = F.sphash(
            torch.cat(
                [
                    torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
                    z.C[:, -1].int().view(-1, 1),
                ],
                1,
            ),
            off,
        )
        pc_hash = F.sphash(x.C.to(z.F.device))
        idx_query = F.sphashquery(old_hash, pc_hash)
        weights = (
            F.calc_ti_weights(z.C, idx_query, scale=x.s[0])
            .transpose(0, 1)
            .contiguous()
        )
        idx_query = idx_query.transpose(0, 1).contiguous()
        if nearest:
            weights[:, 1:] = 0.0
            idx_query[:, 1:] = -1
        new_feat = F.spdevoxelize(x.F, idx_query, weights)
        new_tensor = PointTensor(
            new_feat, z.C, idx_query=z.idx_query, weights=z.weights
        )
        new_tensor.additional_features = z.additional_features
        # Cache the interpolation indices and weights for this stride.
        new_tensor.idx_query[x.s] = idx_query
        new_tensor.weights[x.s] = weights
        z.idx_query[x.s] = idx_query
        z.weights[x.s] = weights
    else:
        new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s))
        new_tensor = PointTensor(
            new_feat, z.C, idx_query=z.idx_query, weights=z.weights
        )
        new_tensor.additional_features = z.additional_features

    return new_tensor


class BasicConvolutionBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
        )

    def forward(self, x):
        out = self.net(x)
        return out


class BasicDeconvolutionBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride, transposed=True),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
        )

    def forward(self, x):
        return self.net(x)


class ResidualBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, stride=1),
            spnn.BatchNorm(outc),
        )

        if inc == outc and stride == 1:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
                spnn.BatchNorm(outc),
            )

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        out = self.relu(self.net(x) + self.downsample(x))
        return out


@MODELS.register_module()
class SPVCNN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        base_channels=32,
        channels=(32, 64, 128, 256, 256, 128, 96, 96),
        layers=(2, 2, 2, 2, 2, 2, 2, 2),
    ):
        super().__init__()

        assert (
            torchsparse is not None
        ), "Please follow `README.md` to install torchsparse."
        assert len(layers) % 2 == 0
        assert len(layers) == len(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.base_channels = base_channels
        self.channels = channels
        self.layers = layers
        self.num_stages = len(layers) // 2

        self.stem = nn.Sequential(
            spnn.Conv3d(in_channels, base_channels, kernel_size=3, stride=1),
            spnn.BatchNorm(base_channels),
            spnn.ReLU(True),
            spnn.Conv3d(base_channels, base_channels, kernel_size=3, stride=1),
            spnn.BatchNorm(base_channels),
            spnn.ReLU(True),
        )

        self.stage1 = nn.Sequential(
            *[
                BasicConvolutionBlock(
                    base_channels, base_channels, ks=2, stride=2, dilation=1
                ),
                ResidualBlock(base_channels, channels[0], ks=3, stride=1, dilation=1),
            ]
            + [
                ResidualBlock(channels[0], channels[0], ks=3, stride=1, dilation=1)
                for _ in range(layers[0] - 1)
            ]
        )

        self.stage2 = nn.Sequential(
            *[
                BasicConvolutionBlock(
                    channels[0], channels[0], ks=2, stride=2, dilation=1
                ),
                ResidualBlock(channels[0], channels[1], ks=3, stride=1, dilation=1),
            ]
            + [
                ResidualBlock(channels[1], channels[1], ks=3, stride=1, dilation=1)
                for _ in range(layers[1] - 1)
            ]
        )

        self.stage3 = nn.Sequential(
            *[
                BasicConvolutionBlock(
                    channels[1], channels[1], ks=2, stride=2, dilation=1
                ),
                ResidualBlock(channels[1], channels[2], ks=3, stride=1, dilation=1),
            ]
            + [
                ResidualBlock(channels[2], channels[2], ks=3, stride=1, dilation=1)
                for _ in range(layers[2] - 1)
            ]
        )

        self.stage4 = nn.Sequential(
            *[
                BasicConvolutionBlock(
                    channels[2], channels[2], ks=2, stride=2, dilation=1
                ),
                ResidualBlock(channels[2], channels[3], ks=3, stride=1, dilation=1),
            ]
            + [
                ResidualBlock(channels[3], channels[3], ks=3, stride=1, dilation=1)
                for _ in range(layers[3] - 1)
            ]
        )

        self.up1 = nn.ModuleList(
            [
                BasicDeconvolutionBlock(channels[3], channels[4], ks=2, stride=2),
                nn.Sequential(
                    *[
                        ResidualBlock(
                            channels[4] + channels[2],
                            channels[4],
                            ks=3,
                            stride=1,
                            dilation=1,
                        )
                    ]
                    + [
                        ResidualBlock(
                            channels[4], channels[4], ks=3, stride=1, dilation=1
                        )
                        for _ in range(layers[4] - 1)
                    ]
                ),
            ]
        )

        self.up2 = nn.ModuleList(
            [
                BasicDeconvolutionBlock(channels[4], channels[5], ks=2, stride=2),
                nn.Sequential(
                    *[
                        ResidualBlock(
                            channels[5] + channels[1],
                            channels[5],
                            ks=3,
                            stride=1,
                            dilation=1,
                        )
                    ]
                    + [
                        ResidualBlock(
                            channels[5], channels[5], ks=3, stride=1, dilation=1
                        )
                        for _ in range(layers[5] - 1)
                    ]
                ),
            ]
        )

        self.up3 = nn.ModuleList(
            [
                BasicDeconvolutionBlock(channels[5], channels[6], ks=2, stride=2),
                nn.Sequential(
                    *[
                        ResidualBlock(
                            channels[6] + channels[0],
                            channels[6],
                            ks=3,
                            stride=1,
                            dilation=1,
                        )
                    ]
                    + [
                        ResidualBlock(
                            channels[6], channels[6], ks=3, stride=1, dilation=1
                        )
                        for _ in range(layers[6] - 1)
                    ]
                ),
            ]
        )

        self.up4 = nn.ModuleList(
            [
                BasicDeconvolutionBlock(channels[6], channels[7], ks=2, stride=2),
                nn.Sequential(
                    *[
                        ResidualBlock(
                            channels[7] + base_channels,
                            channels[7],
                            ks=3,
                            stride=1,
                            dilation=1,
                        )
                    ]
                    + [
                        ResidualBlock(
                            channels[7], channels[7], ks=3, stride=1, dilation=1
                        )
                        for _ in range(layers[7] - 1)
                    ]
                ),
            ]
        )

        self.classifier = nn.Sequential(nn.Linear(channels[7], out_channels))

        self.point_transforms = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(base_channels, channels[3]),
                    nn.BatchNorm1d(channels[3]),
                    nn.ReLU(True),
                ),
                nn.Sequential(
                    nn.Linear(channels[3], channels[5]),
                    nn.BatchNorm1d(channels[5]),
                    nn.ReLU(True),
                ),
                nn.Sequential(
                    nn.Linear(channels[5], channels[7]),
                    nn.BatchNorm1d(channels[7]),
                    nn.ReLU(True),
                ),
            ]
        )

        self.weight_initialization()
        self.dropout = nn.Dropout(0.3, True)

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, data_dict):
        grid_coord = data_dict["grid_coord"]
        feat = data_dict["feat"]
        offset = data_dict["offset"]
        batch = offset2batch(offset)

        # Build a PointTensor whose coordinates are (x, y, z, batch_index).
        z = PointTensor(
            feat,
            torch.cat(
                [grid_coord.float(), batch.unsqueeze(-1).float()], dim=1
            ).contiguous(),
        )
        x0 = initial_voxelize(z)

        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)
        z0.F = z0.F  # identity: no point-wise transform at the stem stage

        # Encoder: four sparse-convolution stages with stride-2 downsampling.
        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        z1 = voxel_to_point(x4, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)

        # Decoder: transposed convolutions with U-Net skip connections,
        # interleaved with point-voxel feature fusion.
        y1 = point_to_voxel(x4, z1)
        y1.F = self.dropout(y1.F)
        y1 = self.up1[0](y1)
        y1 = torchsparse.cat([y1, x3])
        y1 = self.up1[1](y1)

        y2 = self.up2[0](y1)
        y2 = torchsparse.cat([y2, x2])
        y2 = self.up2[1](y2)
        z2 = voxel_to_point(y2, z1)
        z2.F = z2.F + self.point_transforms[1](z1.F)

        y3 = point_to_voxel(y2, z2)
        y3.F = self.dropout(y3.F)
        y3 = self.up3[0](y3)
        y3 = torchsparse.cat([y3, x1])
        y3 = self.up3[1](y3)

        y4 = self.up4[0](y3)
        y4 = torchsparse.cat([y4, x0])
        y4 = self.up4[1](y4)
        z3 = voxel_to_point(y4, z2)
        z3.F = z3.F + self.point_transforms[2](z2.F)

        out = self.classifier(z3.F)
        return out
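

# A minimal usage sketch (not part of the model definition): it illustrates the
# input layout `forward` expects -- integer "grid_coord" voxel coordinates,
# per-point "feat" features, and a cumulative "offset" marking the end of each
# sample in the batch. It assumes torchsparse is installed and a CUDA device is
# available; the channel sizes and point counts below are illustrative only.
if __name__ == "__main__":
    device = "cuda"
    model = SPVCNN(in_channels=4, out_channels=13).to(device)
    n_points = 1024
    data_dict = {
        "grid_coord": torch.randint(0, 64, (n_points, 3), device=device),
        "feat": torch.randn(n_points, 4, device=device),
        "offset": torch.tensor([512, 1024], device=device),  # two point clouds
    }
    logits = model(data_dict)  # per-point logits, shape (n_points, 13)
    print(logits.shape)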