"""
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""

import numpy as np
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_

import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
from Swin3D.sparse_dl.attn.attn_coff import (
    SelfAttnAIOFunction,
    PosEmb,
    TableDims,
    IndexMode,
    PrecisionMode,
)
from Swin3D.sparse_dl.knn import KNN

from .mink_layers import (
    assign_feats,
    SparseTensorLayerNorm,
    SparseTensorLinear,
)


def query_knn_feature(
    K, src_xyz, query_xyz, src_feat, src_offset, query_offset, return_idx=False
):
    """
    Gather features from the K nearest neighbors of each query point.
    """
    # Handle the default query set before touching query_xyz, so the
    # contiguity assert below cannot dereference None.
    if query_xyz is None:
        query_xyz = src_xyz
        query_offset = src_offset
    assert (
        src_xyz.is_contiguous()
        and query_xyz.is_contiguous()
        and src_feat.is_contiguous()
    )

    idx, _ = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset)

    n, m, c = src_xyz.shape[0], query_xyz.shape[0], src_feat.shape[1]
    grouped_feat = src_feat[idx.view(-1).long(), :].view(m, K, c)

    if return_idx:
        return grouped_feat, idx
    else:
        return grouped_feat
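
# Shape sketch for query_knn_feature (illustrative, not executed):
#   src_xyz:   (N, 3) contiguous CUDA tensor of source point coordinates
#   query_xyz: (M, 3) contiguous CUDA tensor of query coordinates (or None)
#   src_feat:  (N, C) contiguous CUDA tensor of source features
#   src_offset / query_offset: inclusive prefix sums of per-batch point
#       counts, as produced by get_offset below
#   returns:   (M, K, C) features of the K nearest source points per query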


def knn_linear_interpolation(
    src_xyz, query_xyz, src_feat, src_offset, query_offset, K=3
):
    """
    Interpolate features with inverse-distance weights over the KNN neighborhood.
    """
    N, C = query_xyz.shape[0], src_feat.shape[1]
    assert (
        src_xyz.is_contiguous()
        and query_xyz.is_contiguous()
        and src_feat.is_contiguous()
    )

    idx, dist = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset)
    # Inverse-distance weights, normalized over the K neighbors.
    weight = 1.0 / (dist + 1e-8)
    norm = torch.sum(weight, dim=1, keepdim=True)
    weight = weight / norm
    query_feat = torch.zeros((N, C), dtype=src_feat.dtype, device=src_feat.device)
    for i in range(K):
        query_feat += src_feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1)
    return query_feat
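
# The loop above implements inverse-distance (Shepard) interpolation: for a
# query point q with K neighbors holding features f_i at distances d_i,
#   f(q) = sum_i w_i * f_i,  w_i = (1 / (d_i + 1e-8)) / sum_j (1 / (d_j + 1e-8))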


def sparse_self_attention(
    w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: str = "v1"
):
    """
    Args:
        w_w_id [torch.Tensor]: within-window index of each non-empty voxel,
            sorted by window id. Window ids are ordered and start from 0; the
            within-window id is a sparse offset into kernel_size ** 3.
        w_sizes [torch.Tensor]: number of non-empty voxels per window, [W].
    Outputs:
        x_offset, y_offset [M]: sparse indices of the coefficient matrix,
            i.e. the (att_a_id, att_b_id) pair of every intra-window
            coefficient, expressed as global voxel indices.

    Spaces:
        W: total number of windows
        N: total number of input voxels
        M: total number of output coefficients
    """
    w_sizes_2 = w_sizes**2

    # w2n_indices: for each window, the global index of its first voxel.
    w_cumsum = torch.cumsum(w_sizes, dim=-1)
    w2n_indices = torch.cat(
        [torch.zeros(1, dtype=w_cumsum.dtype, device=w_cumsum.device), w_cumsum[:-1]]
    )

    # w2m_indices: for each window, the global index of its first coefficient.
    w2_cumsum = torch.cumsum(w_sizes_2, dim=-1)
    w2m_indices = torch.cat(
        [torch.zeros(1, dtype=w2_cumsum.dtype, device=w2_cumsum.device), w2_cumsum[:-1]]
    )

    # m2w_indices: window id of each coefficient;
    # m2w_offset: global index of that window's first coefficient.
    m2w_indices = torch.zeros(
        [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device
    )
    m2w_offset = torch.zeros(
        [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device
    )
    m2w_indices[w2m_indices[1:]] = 1
    m2w_offset[w2m_indices[1:]] = w_sizes_2[:-1]
    m2w_indices = torch.cumsum(m2w_indices, dim=-1)
    m2w_offset = torch.cumsum(m2w_offset, dim=-1)

    m_indices = torch.arange(
        0, w2_cumsum[-1], dtype=w_sizes.dtype, device=w_sizes.device
    )

    m2n_indices = w2n_indices[m2w_indices]

    # m_offset: index of each coefficient within its own window.
    m_offset = m_indices - m2w_offset
    m2w_sizes = w_sizes[m2w_indices]

    # Decompose the per-window coefficient index into a (row, column) pair of
    # global voxel indices.
    y_offset = m2n_indices + m_offset % m2w_sizes
    x_offset = m2n_indices + torch.div(m_offset, m2w_sizes, rounding_mode="floor")

    if protocol == "v1":
        return x_offset, y_offset
    elif protocol == "v2":
        return x_offset, y_offset, m2w_indices, w_sizes, w2n_indices, w2m_indices
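
# Worked example (hand-checked sketch): for w_sizes = [2, 3], the voxels are
# numbered 0..1 (window 0) and 2..4 (window 1), and the returned (x, y) pairs
# enumerate every intra-window combination:
#   window 0: (0,0) (0,1) (1,0) (1,1)
#   window 1: (2,2) (2,3) (2,4) (3,2) (3,3) (3,4) (4,2) (4,3) (4,4)
# giving M = sum(w_sizes**2) = 13 coefficients.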


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class GridCoordsDown(nn.Module):
    """
    Downsample the grid coordinates:
    for each downsampled grid cell, keep the point nearest to the average
    position of the points inside it.
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride
        self.avg_pool = ME.MinkowskiAvgPooling(
            kernel_size=self.stride, stride=self.stride, dimension=3
        )
        self.unpool = ME.MinkowskiPoolingTranspose(
            kernel_size=stride, stride=stride, dimension=3
        )
        self.max_pool = ME.MinkowskiMaxPooling(
            kernel_size=self.stride, stride=self.stride, dimension=3
        )

    def forward(self, coords_sp, sp, return_map=False):
        device = sp.C.device

        N = sp.shape[0]
        avg_coords_sp = self.avg_pool(coords_sp)
        dist_sp = self.unpool(avg_coords_sp) - coords_sp
        dist = dist_sp.F
        # Negate the distance so that max-pooling selects the minimum distance.
        dist = -torch.sqrt((dist**2).sum(dim=1)).unsqueeze(1)
        dist_sp = assign_feats(dist_sp, dist)
        min_dist_sp = self.max_pool(dist_sp)
        map_pair = sp.coordinate_manager.kernel_map(
            dist_sp.coordinate_map_key,
            min_dist_sp.coordinate_map_key,
            stride=self.stride,
            kernel_size=self.stride,
            is_pool=True,
        )[0]
        in_map, out_map = map_pair
        broad_min_dist_sp = self.unpool(min_dist_sp)
        mask = (broad_min_dist_sp.F == dist_sp.F).squeeze(1)
        in_map = in_map[mask].long()
        out_map = out_map[mask].long()
        downsample_map = torch.zeros(N, dtype=torch.long, device=device) - 1
        downsample_map[out_map] = in_map
        assert (downsample_map >= 0).all()
        assert (dist_sp.F[downsample_map] == min_dist_sp.F).all()
        new_coords = coords_sp.F[downsample_map]
        new_coords_sp = assign_feats(sp, new_coords)
        if return_map:
            return new_coords_sp, downsample_map
        else:
            return new_coords_sp
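
# In short: for every voxel produced by the strided pooling, GridCoordsDown
# keeps the coordinates of the input point closest to the mean position of its
# pooling window, so downsampled voxels carry real point coordinates rather
# than grid-aligned or averaged ones.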


def get_offset(batch):
    # Convert per-point batch indices into inclusive prefix sums of the
    # per-batch point counts.
    offset = []
    bs = batch.max() + 1
    for i in range(bs):
        offset.append(torch.sum(batch == i))
    offset = torch.cuda.IntTensor(offset)
    offset = offset.cumsum(dim=0).int()
    return offset
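
# Example: batch = [0, 0, 0, 1, 1] gives offset = [3, 5]; entry i is the
# exclusive end index of batch i, the offset convention consumed by the KNN
# helpers above.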


class GridDownsample(nn.Module):
    """
    Downsample voxels by `stride`,
    using grid max-pooling with `kernel_size`.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sp_pool = ME.MinkowskiMaxPooling(
            kernel_size=kernel_size, stride=stride, dimension=3
        )
        self.coords_pool = GridCoordsDown(stride=stride)
        self.norm = SparseTensorLayerNorm(in_channels)
        self.linear = SparseTensorLinear(in_channels, out_channels)

    def forward(self, sp, coords_sp):
        sp_down = self.sp_pool(self.linear(self.norm(sp)))
        coords_sp_down = self.coords_pool(coords_sp, sp_down)
        return sp_down, coords_sp_down

    def extra_repr(self) -> str:
        return f"kernel_size={self.kernel_size}, stride={self.stride}, in_channels={self.in_channels}, out_channels={self.out_channels}"


class GridKNNDownsample(nn.Module):
    """
    Downsample voxels by `stride`;
    aggregate features by max-pooling over the K nearest neighbors.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = 16
        self.sp_pool = ME.MinkowskiMaxPooling(
            kernel_size=stride, stride=stride, dimension=3
        )
        self.coords_pool = GridCoordsDown(stride=stride)
        self.norm = nn.LayerNorm(in_channels)
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.pool = nn.MaxPool1d(self.k)

    def forward(self, sp, coords_sp):
        sp_down = self.sp_pool(sp)

        coords_sp_down = self.coords_pool(coords_sp, sp_down)
        offset = get_offset(sp.C[:, 0])
        n_offset = get_offset(sp_down.C[:, 0])

        xyz = coords_sp.F[:, 1:4].detach().contiguous()
        n_xyz = coords_sp_down.F[:, 1:4].detach().contiguous()
        feats = query_knn_feature(self.k, xyz, n_xyz, sp.F, offset, n_offset)
        m, k, c = feats.shape
        feats = (
            self.linear(self.norm(feats.view(m * k, c)).view(m, k, c))
            .transpose(1, 2)
            .contiguous()
        )
        feats = self.pool(feats).squeeze(-1)
        sp = assign_feats(sp_down, feats.float())
        coords_sp = coords_sp_down
        return sp, coords_sp

    def extra_repr(self) -> str:
        return f"kernel_size={self.k}, stride={self.stride}, in_channels={self.in_channels}, out_channels={self.out_channels}"


class Upsample(nn.Module):
    """
    Upsample with KNN inverse-distance interpolation,
    followed by an attention block when self.attn is set.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_heads,
        window_size,
        quant_size,
        attn=True,
        up_k=3,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.linear1 = nn.Sequential(
            nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels)
        )
        self.linear2 = nn.Sequential(
            nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels)
        )
        self.up_k = up_k
        self.attn = attn and window_size > 0
        if self.attn:
            self.block = BasicLayer(
                dim=out_channels,
                depth=1,
                num_heads=num_heads,
                window_size=window_size,
                quant_size=quant_size,
                drop_path=0.1,
                downsample=None,
                out_channels=None,
                cRSE=cRSE,
                fp16_mode=fp16_mode,
            )

    def forward(self, sp, coords_sp, sp_up, coords_sp_up):
        feats = sp.F
        support_feats = sp_up.F
        xyz = coords_sp.F[:, 1:4].detach().contiguous()
        support_xyz = coords_sp_up.F[:, 1:4].detach().contiguous()
        offset = get_offset(sp.C[:, 0])
        support_offset = get_offset(sp_up.C[:, 0])

        feats = self.linear1(support_feats) + knn_linear_interpolation(
            xyz, support_xyz, self.linear2(feats), offset, support_offset, K=self.up_k
        )
        sp_up = assign_feats(sp_up, feats)
        if self.attn:
            sp_up, _, _ = self.block(sp_up, coords_sp_up)
        return sp_up

    def extra_repr(self) -> str:
        return f"up_k={self.up_k}, in_channels={self.in_channels}, out_channels={self.out_channels}, attn={self.attn}"


class WindowAttention(nn.Module):
    """
    Window-based multi-head self-attention (W-MSA) module with cRSE.
    Designed for sparse structures; supports both shifted and non-shifted
    windows.

    Args:
        dim (int): Number of input channels.
        window_size (int): Size of the (cubic) window.
        quant_size (int): quant_size for the finer cRSE table.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        cRSE (str | 'XYZ', 'XYZ_RGB', 'XYZ_RGB_NORM'): cRSE mode. Default: 'XYZ_RGB'
        fp16_mode (int | 0, 1, 2): fp16 mode for the attention module. Default: 0
            0: fp32 forward and fp32 backward
            1: fp16 forward and fp32 backward
            2: fp16 forward and fp16 backward
    """

    def __init__(
        self,
        dim,
        window_size,
        quant_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.color_windowsize = 2
        self.normal_windowsize = 2

        self.fp16_mode = fp16_mode

        table_offsets = []
        self.cRSE = cRSE
        if "XYZ" in cRSE:
            self.xyz_quant_size = quant_size
            quant_grid_length_xyz = window_size * self.xyz_quant_size
            table_shape_xyz = (3, 2 * quant_grid_length_xyz, num_heads, head_dim)
            self.query_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.query_xyz_table, std=0.02)
            self.key_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.key_xyz_table, std=0.02)
            self.value_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.value_xyz_table, std=0.02)
            table_offsets += [np.prod(table_shape_xyz[1:])] * 3

        if "RGB" in cRSE:
            self.color_quant_size = quant_size * 2
            quant_grid_length_rgb = self.color_windowsize * self.color_quant_size
            table_shape_rgb = (3, 2 * quant_grid_length_rgb, num_heads, head_dim)
            self.query_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.query_rgb_table, std=0.02)
            self.key_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.key_rgb_table, std=0.02)
            self.value_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.value_rgb_table, std=0.02)
            table_offsets += [np.prod(table_shape_rgb[1:])] * 3

        if "NORM" in cRSE:
            self.normal_quant_size = quant_size * 2
            quant_grid_length_norm = self.normal_windowsize * self.normal_quant_size
            table_shape_norm = (3, 2 * quant_grid_length_norm, num_heads, head_dim)
            self.query_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.query_norm_table, std=0.02)
            self.key_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.key_norm_table, std=0.02)
            self.value_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.value_norm_table, std=0.02)
            table_offsets += [np.prod(table_shape_norm[1:])] * 3

        self.table_offsets = table_offsets

        self.quant_size = quant_size

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feats: torch.Tensor, attn_args):
        """Forward function.

        Args:
            feats: [N, C] voxel features
            attn_args: arguments for computing attention
        """
        num_v, _ = feats.shape
        num_sc = self.dim // self.num_heads

        (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            n2n_indices,
            w2m_indices,
            n_coords,
        ) = attn_args

        # Compute q, k, v and split heads: [3, N, num_heads, head_dim].
        qkv = self.qkv(feats)
        qkv = (
            qkv.reshape(num_v, 3, self.num_heads, num_sc)
            .permute(1, 0, 2, 3)
            .contiguous()
        )
        query, key, value = qkv[0], qkv[1], qkv[2]
        query = query * self.scale

        # Gather the quantized cRSE signals and their flattened lookup tables.
        table_offsets = torch.IntTensor(self.table_offsets).cuda()
        query_table, key_table, value_table = [], [], []
        n_cRSE = []
        if "XYZ" in self.cRSE:
            n_xyz = n_coords[:, 0:3]
            n_xyz = n_xyz * self.quant_size
            n_cRSE.append(n_xyz)
            query_table.append(self.query_xyz_table.view(-1))
            key_table.append(self.key_xyz_table.view(-1))
            value_table.append(self.value_xyz_table.view(-1))
        if "RGB" in self.cRSE:
            n_rgb = n_coords[:, 3:6]
            n_rgb = n_rgb * self.color_quant_size
            n_cRSE.append(n_rgb)
            query_table.append(self.query_rgb_table.view(-1))
            key_table.append(self.key_rgb_table.view(-1))
            value_table.append(self.value_rgb_table.view(-1))
        if "NORM" in self.cRSE:
            n_norm = n_coords[:, 6:9]
            n_norm = n_norm * self.normal_quant_size
            n_cRSE.append(n_norm)
            query_table.append(self.query_norm_table.view(-1))
            key_table.append(self.key_norm_table.view(-1))
            value_table.append(self.value_norm_table.view(-1))

        n_cRSE = torch.cat(n_cRSE, dim=1)

        indices = [m2w_indices, w_sizes, w2m_indices, w2n_indices, n2n_indices, n_cRSE]
        query_table = torch.cat(query_table)
        key_table = torch.cat(key_table)
        value_table = torch.cat(value_table)

        if self.fp16_mode == 0:
            # fp32 forward and fp32 backward
            fp16_mode = PrecisionMode.HALF_NONE
        elif self.fp16_mode == 1:
            # fp16 forward and fp32 backward
            fp16_mode = PrecisionMode.HALF_FORWARD
        elif self.fp16_mode == 2:
            # fp16 forward and fp16 backward
            fp16_mode = PrecisionMode.HALF_ALL

        updated_values = SelfAttnAIOFunction.apply(
            query,
            key,
            value,
            query_table,
            key_table,
            value_table,
            table_offsets,
            indices,
            PosEmb.SEPARATE,
            TableDims.D0,
            IndexMode.INDIRECT,
            fp16_mode,
        )

        updated_values = updated_values.flatten(1)
        updated_feats = updated_values.view(num_v, self.dim)

        updated_feats = self.proj(updated_feats)
        updated_feats = self.proj_drop(updated_feats)

        return updated_feats
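
    # The fused kernel computes scaled dot-product attention per window and
    # head (softmax(q k^T) v, with q pre-scaled by head_dim ** -0.5 above),
    # with the quantized cRSE signals indexing the learned query/key/value
    # lookup tables; see SelfAttnAIOFunction for the exact formulation.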


class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        quant_size,
        drop_path=0.0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.window_size = window_size

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            quant_size=quant_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            cRSE=cRSE,
            fp16_mode=fp16_mode,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer
        )

    def forward(self, feats, attn_args):
        short_cut = feats
        feats = self.norm1(feats)
        feats = self.attn(feats, attn_args)

        feats = short_cut + self.drop_path(feats)
        feats = feats + self.drop_path(self.mlp(self.norm2(feats)))

        return feats
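
    # Pre-norm residual structure, as in Swin:
    #   feats = feats + DropPath(W-MSA(LN(feats)))
    #   feats = feats + DropPath(MLP(LN(feats)))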


class BasicLayer(nn.Module):
    """A basic Swin3D layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        quant_size (int): quant_size for the finer cRSE table.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        cRSE (str | 'XYZ', 'XYZ_RGB', 'XYZ_RGB_NORM'): cRSE mode. Default: 'XYZ_RGB'
        fp16_mode (int | 0, 1, 2): fp16 mode for the attention module. Default: 0
            0: fp32 forward and fp32 backward
            1: fp16 forward and fp32 backward
            2: fp16 forward and fp16 backward
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size,
        quant_size,
        out_channels=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        down_stride=2,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.window_size = window_size
        self.depth = depth
        self.dim = dim
        self.num_heads = num_heads
        self.quant_size = quant_size
        self.cRSE = cRSE
        self.fp16_mode = fp16_mode

        self.shift_size = window_size // 2

        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim,
                    num_heads,
                    window_size,
                    quant_size,
                    drop_path=(
                        drop_path[i] if isinstance(drop_path, list) else drop_path
                    ),
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    norm_layer=norm_layer,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(depth)
            ]
        )

        self.pool = ME.MinkowskiMaxPooling(
            kernel_size=self.window_size, stride=self.window_size, dimension=3
        )

        if downsample is not None:
            if out_channels is None:
                out_channels = dim * 2
            self.downsample = downsample(
                dim, out_channels, kernel_size=down_stride, stride=down_stride
            )
        else:
            self.downsample = None

    def get_map_pair(self, sp):
        """
        Use Minkowski pooling to compute the windows and
        the mapping from voxels to windows.
        """
        window_size = [self.window_size] * 3
        pool_sp = self.pool(sp)
        windows = pool_sp.C
        window_N = windows.shape[0]

        stride_in = sp.coordinate_map_key.get_tensor_stride()
        x, y, z = [
            torch.arange(window_size[i], device=self.device) * stride_in[i]
            for i in range(3)
        ]
        x, y, z = torch.meshgrid(x, y, z)
        i = torch.zeros_like(x, device=self.device)
        local_window = torch.stack([i, x, y, z], dim=-1).flatten(0, -2)
        all_windows = windows.unsqueeze(1) + local_window.unsqueeze(0)
        all_windows = all_windows.flatten(0, -2).int()
        cm = sp.coordinate_manager
        query_key, (map, inverse_map) = cm.insert_and_map(
            all_windows, tensor_stride=stride_in
        )
        map_pair = cm.kernel_map(query_key, sp.coordinate_map_key, kernel_size=1)[0]
        return map_pair, window_N
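
    # get_map_pair enumerates, for each pooled window origin, all
    # window_size**3 voxel slots inside the window, registers them with the
    # coordinate manager, and uses a size-1 kernel map to match the non-empty
    # input voxels to those slots.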

    def get_window_mapping(self, sp):
        """
        Calculate the relationships inside each window:
            w_w_id: non-empty index inside the window (sorted by window)
            w_w_xyz: xyz inside the window (sorted by window)
            nempty_num: number of non-empty voxels in each window
            sort_idx: sort voxels by window id, to gather the points inside the same window
            inv_sort_idx: inverse of sort_idx
        """
        map_pair, window_N = self.get_map_pair(sp)
        window_size = self.window_size
        nW = window_size**3
        in_map, out_map = map_pair
        in_map, sort_idx = torch.sort(in_map)

        out_map = out_map[sort_idx]
        sort_idx = out_map.long()
        inv_sort_idx = torch.zeros_like(sort_idx)
        inv_sort_idx[sort_idx] = torch.arange(
            sort_idx.shape[0], dtype=sort_idx.dtype, device=self.device
        )
        N = window_N * nW
        v2w_mask = torch.zeros(N, dtype=torch.bool, device=self.device)
        w_id = (
            torch.arange(window_N, dtype=torch.long, device=self.device)
            .unsqueeze(1)
            .repeat(1, nW)
            .view(-1)
        )
        w_w_id = (
            torch.arange(nW, dtype=torch.long, device=self.device)
            .unsqueeze(0)
            .repeat(window_N, 1)
            .view(-1)
        )
        v2w_mask[in_map.long()] = True
        nempty_num = v2w_mask.view(-1, nW).sum(dim=-1)
        w_id = w_id[in_map.long()]
        w_w_id = w_w_id[in_map.long()]
        w_w_xyz = torch.stack(
            [
                w_w_id // window_size // window_size,
                w_w_id // window_size % window_size,
                w_w_id % window_size,
            ],
            dim=-1,
        )
        return w_w_id, w_w_xyz, nempty_num, sort_idx, inv_sort_idx
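
    # w_w_id linearizes the within-window position as x * ws**2 + y * ws + z
    # (ws = window_size); the stack above decodes it back into (x, y, z).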

    def get_index01(self, sp, local_xyz, colors):
        """
        Calculate the arguments for sparse attention.
        """
        (
            w_w_id,
            w_w_xyz,
            nempty_num,
            n2n_indices,
            inv_sort_idx,
        ) = self.get_window_mapping(sp)
        local_xyz = local_xyz[n2n_indices]
        colors = colors[n2n_indices]

        # Concatenate within-window positions with the color/normal signals
        # used for cRSE.
        n_coords = w_w_xyz + local_xyz
        n_coords = torch.cat([n_coords, colors], dim=1)
        (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            w2m_indices,
        ) = sparse_self_attention(w_w_id, nempty_num, protocol="v2")
        return (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            n2n_indices,
            w2m_indices,
            n_coords,
        )

    def get_shifted_sp(self, sp):
        """
        Get the shifted sparse tensor for shifted-window attention.
        """
        stride_in = sp.coordinate_map_key.get_tensor_stride()
        shift_size = self.shift_size * stride_in[0]
        shifted_C = sp.C.clone()
        shifted_C[:, 1:] += shift_size
        shifted_sp = SparseTensor(
            features=sp.F,
            coordinates=shifted_C,
            device=self.device,
            tensor_stride=stride_in,
        )
        return shifted_sp

    def get_window_pos(self, sp):
        stride_in = sp.coordinate_map_key.get_tensor_stride()
        return (sp.C[:, 1:] / stride_in[0]) % self.window_size

    def forward(self, sp, coords_sp):
        """
        xyz: position of each point
        colors: other signals for cRSE, including colors and normals
        local_xyz: relative position of each point inside its voxel
            (used for the finer cRSE table)
        """
        colors = coords_sp.F[:, 4:]
        xyz = coords_sp.F[:, :4]
        local_xyz = (xyz - coords_sp.C)[
            :, 1:
        ] / coords_sp.coordinate_map_key.get_tensor_stride()[0]
        self.device = sp.device
        sp_shift = self.get_shifted_sp(sp)

        attn_args = self.get_index01(sp, local_xyz, colors)
        attn_args_shift = self.get_index01(sp_shift, local_xyz, colors)

        feats = sp.F
        # Alternate between regular and shifted windows, as in Swin.
        for i, blk in enumerate(self.blocks):
            attn_args_blk = attn_args if i % 2 == 0 else attn_args_shift
            feats = blk(feats, attn_args_blk)

        sp = assign_feats(sp, feats)
        if self.downsample is not None:
            sp_down, coords_sp = self.downsample(sp, coords_sp)
            return sp, sp_down, coords_sp
        else:
            return sp, sp, coords_sp

    def extra_repr(self) -> str:
        return f"window_size={self.window_size}, depth={self.depth}, channel={self.dim}, num_heads={self.num_heads}, quant_size={self.quant_size}, cRSE={self.cRSE}, fp16_mode={self.fp16_mode}"