""" # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """ import numpy as np import torch import torch.nn as nn from timm.models.layers import DropPath, trunc_normal_ import MinkowskiEngine as ME from MinkowskiEngine import SparseTensor from Swin3D.sparse_dl.attn.attn_coff import ( SelfAttnAIOFunction, PosEmb, TableDims, IndexMode, PrecisionMode, ) import Swin3D.sparse_dl.knn from Swin3D.sparse_dl.knn import KNN from .mink_layers import ( assign_feats, SparseTensorLayerNorm, SparseTensorLinear, ) def query_knn_feature( K, src_xyz, query_xyz, src_feat, src_offset, query_offset, return_idx=False ): """ gather feature in the KNN neighborhood """ assert ( src_xyz.is_contiguous() and query_xyz.is_contiguous() and src_feat.is_contiguous() ) if query_xyz is None: query_xyz = src_xyz query_offset = src_offset idx, _ = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset) n, m, c = src_xyz.shape[0], query_xyz.shape[0], src_feat.shape[1] grouped_feat = src_feat[idx.view(-1).long(), :].view(m, K, c) if return_idx: return grouped_feat, idx else: return grouped_feat def knn_linear_interpolation( src_xyz, query_xyz, src_feat, src_offset, query_offset, K=3 ): """ interpolation feature using distance in KNN neighborhood """ N, C = query_xyz.shape[0], src_feat.shape[1] assert ( src_xyz.is_contiguous() and query_xyz.is_contiguous() and src_feat.is_contiguous() ) # (N, K) idx, dist = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset) weight = 1.0 / (dist + 1e-8) norm = torch.sum(weight, dim=1, keepdim=True) weight = weight / norm query_feat = torch.zeros((N, C), dtype=src_feat.dtype, device=src_feat.device) for i in range(K): query_feat += src_feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) return query_feat def sparse_self_attention( w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: str = "v1" ): """ Args: indices [torch.Tensor]: sparse window index with shape [N, 2], N is the total number of non-empty voxels with indices (window_id, within_window_id). window_id is ordered and starts from 0; within_window_id is a sparse index to indicate the offset of kernel_size ** 3. feats [torch.Tensor]: sprase features of each non-empty voxel with shape [N, C] Outputs: [M, 3]: sparse indices of cofficient matrix (window_id, att_a_id, att_b_id). 
def sparse_self_attention(
    w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: str = "v1"
):
    """
    Args:
        indices [torch.Tensor]: sparse window index with shape [N, 2], N is the
            total number of non-empty voxels with indices (window_id, within_window_id).
            window_id is ordered and starts from 0; within_window_id is a sparse index
            to indicate the offset of kernel_size ** 3.
        feats [torch.Tensor]: sparse features of each non-empty voxel with shape [N, C]

    Outputs:
        [M, 3]: sparse indices of the coefficient matrix (window_id, att_a_id, att_b_id).
            att_a_id and att_b_id are the within_window_id
        [M, 1]: the sparse coefficient matrix

    Spaces:
        W: total number of windows
        N: total number of input voxels
        M: total number of output coefficients
    """
    w_sizes_2 = w_sizes**2

    # w2n_indices - [W], mapping window index to window global offset in input
    # space
    w_cumsum = torch.cumsum(w_sizes, dim=-1)
    w2n_indices = torch.cat(
        [torch.zeros(1, dtype=w_cumsum.dtype, device=w_cumsum.device), w_cumsum[:-1]]
    )

    # w2m indices - [W], mapping window index to window global offset in output
    # space
    w2_cumsum = torch.cumsum(w_sizes_2, dim=-1)
    w2m_indices = torch.cat(
        [torch.zeros(1, dtype=w2_cumsum.dtype, device=w2_cumsum.device), w2_cumsum[:-1]]
    )

    # m2w indices - [M], mapping element global offset to the window index
    m2w_indices = torch.zeros(
        [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device
    )
    m2w_offset = torch.zeros(
        [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device
    )
    m2w_indices[w2m_indices[1:]] = 1
    m2w_offset[w2m_indices[1:]] = w_sizes_2[:-1]
    m2w_indices = torch.cumsum(m2w_indices, dim=-1)
    m2w_offset = torch.cumsum(m2w_offset, dim=-1)

    # m_indices - [M], element global offset in output space
    m_indices = torch.arange(
        0, w2_cumsum[-1], dtype=w_sizes.dtype, device=w_sizes.device
    )

    # m2n_indices - [M], mapping element global offset to the window global offset
    # in input space
    m2n_indices = w2n_indices[m2w_indices]

    m_offset = m_indices - m2w_offset
    m2w_sizes = w_sizes[m2w_indices]

    # print_log_main("m_offset:", m_offset, m_offset.shape)
    # print_log_main("m2n_indices:", m2n_indices, m2n_indices.shape)

    y_offset = m2n_indices + m_offset % m2w_sizes
    x_offset = m2n_indices + torch.div(m_offset, m2w_sizes, rounding_mode="floor")

    # print_log_main("=================================")
    # print_log_main(w_sizes[:5])
    # print_log_main(x_offset[:50])
    # print_log_main(y_offset[:50])

    # coord = torch.stack([m2w_indices, w_w_id[x_offset], w_w_id[y_offset]], axis=-1)

    if protocol == "v1":
        return x_offset, y_offset
    elif protocol == "v2":
        return x_offset, y_offset, m2w_indices, w_sizes, w2n_indices, w2m_indices
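
# Worked example (comment only): for two windows with w_sizes = [2, 3], the
# voxels are numbered 0-1 (window 0) and 2-4 (window 1), and the
# M = 2**2 + 3**2 = 13 attention pairs are enumerated window by window:
#   x_offset = [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
#   y_offset = [0, 1, 0, 1, 2, 3, 4, 2, 3, 4, 2, 3, 4]
# i.e. (x_offset[m], y_offset[m]) is the (query, key) voxel pair of the m-th
# sparse attention coefficient, indexed in the input-voxel space.
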
class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class GridCoordsDown(nn.Module):
    """
    Downsample the grid coordinates;
    keep the point nearest to the average point of the downsampled grid.
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride
        self.avg_pool = ME.MinkowskiAvgPooling(
            kernel_size=self.stride, stride=self.stride, dimension=3
        )
        self.unpool = ME.MinkowskiPoolingTranspose(
            kernel_size=stride, stride=stride, dimension=3
        )
        self.max_pool = ME.MinkowskiMaxPooling(
            kernel_size=self.stride, stride=self.stride, dimension=3
        )

    def forward(self, coords_sp, sp, return_map=False):
        device = sp.C.device
        # is_pool = True means pooling map
        # is_pool = False means conv map (query as center)
        N = sp.shape[0]
        avg_coords_sp = self.avg_pool(coords_sp)
        dist_sp = self.unpool(avg_coords_sp) - coords_sp
        dist = dist_sp.F
        dist = -torch.sqrt((dist**2).sum(dim=1)).unsqueeze(1)
        dist_sp = assign_feats(dist_sp, dist)
        min_dist_sp = self.max_pool(dist_sp)
        map_pair = sp.coordinate_manager.kernel_map(
            dist_sp.coordinate_map_key,
            min_dist_sp.coordinate_map_key,
            stride=self.stride,
            kernel_size=self.stride,
            is_pool=True,
        )[0]
        in_map, out_map = map_pair
        broad_min_dist_sp = self.unpool(min_dist_sp)
        mask = (broad_min_dist_sp.F == dist_sp.F).squeeze(1)
        in_map = in_map[mask].long()
        out_map = out_map[mask].long()
        downsample_map = torch.zeros(N, dtype=torch.long, device=device) - 1
        downsample_map[out_map] = in_map
        assert (downsample_map >= 0).all()
        assert (dist_sp.F[downsample_map] == min_dist_sp.F).all()
        new_coords = coords_sp.F[downsample_map]
        new_coords_sp = assign_feats(sp, new_coords)
        if return_map:
            return new_coords_sp, downsample_map
        else:
            return new_coords_sp


def get_offset(batch):
    offset = []
    bs = batch.max() + 1
    for i in range(bs):
        offset.append(torch.sum(batch == i))
    offset = torch.cuda.IntTensor(offset)
    offset = offset.cumsum(dim=0).int()
    return offset


class GridDownsample(nn.Module):
    """
    Downsample voxels with the given stride,
    using grid max-pooling with kernel_size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sp_pool = ME.MinkowskiMaxPooling(
            kernel_size=kernel_size, stride=stride, dimension=3
        )
        self.coords_pool = GridCoordsDown(stride=stride)
        self.norm = SparseTensorLayerNorm(in_channels)
        self.linear = SparseTensorLinear(in_channels, out_channels)

    def forward(self, sp, coords_sp):
        sp_down = self.sp_pool(self.linear(self.norm(sp)))
        coords_sp_down = self.coords_pool(coords_sp, sp_down)
        return sp_down, coords_sp_down

    def extra_repr(self) -> str:
        return f"kernel_size={self.kernel_size}, stride={self.stride}, in_channels={self.in_channels}, out_channels={self.out_channels}"


class GridKNNDownsample(nn.Module):
    """
    Downsample voxels with the given stride,
    using KNN max-pooling over the k nearest input voxels.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = 16
        self.sp_pool = ME.MinkowskiMaxPooling(
            kernel_size=stride, stride=stride, dimension=3
        )
        self.coords_pool = GridCoordsDown(stride=stride)
        self.norm = nn.LayerNorm(in_channels)
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.pool = nn.MaxPool1d(self.k)

    def forward(self, sp, coords_sp):
        # calculate the voxel
        sp_down = self.sp_pool(sp)
        # for downsampled cRSE
        coords_sp_down = self.coords_pool(coords_sp, sp_down)
        offset = get_offset(sp.C[:, 0])
        n_offset = get_offset(sp_down.C[:, 0])

        xyz = coords_sp.F[:, 1:4].detach().contiguous()
        n_xyz = coords_sp_down.F[:, 1:4].detach().contiguous()
        feats = query_knn_feature(self.k, xyz, n_xyz, sp.F, offset, n_offset)
        m, k, c = feats.shape
        feats = (
            self.linear(self.norm(feats.view(m * k, c)).view(m, k, c))
            .transpose(1, 2)
            .contiguous()
        )
        feats = self.pool(feats).squeeze(-1)
        sp = assign_feats(sp_down, feats.float())
        coords_sp = coords_sp_down
        return sp, coords_sp

    def extra_repr(self) -> str:
        return f"kernel_size={self.k}, stride={self.stride}, in_channels={self.in_channels}, out_channels={self.out_channels}"
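
# Usage sketch (hypothetical channel sizes, not executed): both downsample
# variants take the sparse feature tensor `sp` and the cRSE coordinate tensor
# `coords_sp` and return their downsampled counterparts, e.g.
#
#   down = GridDownsample(in_channels=96, out_channels=192, kernel_size=2, stride=2)
#   sp_down, coords_sp_down = down(sp, coords_sp)
#
# GridKNNDownsample follows the same interface but aggregates the k=16 nearest
# input voxels around each pooled voxel instead of a fixed 2x2x2 grid.
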
class Upsample(nn.Module):
    """
    Upsample using KNN linear interpolation,
    followed by an attention block according to self.attn.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_heads,
        window_size,
        quant_size,
        attn=True,
        up_k=3,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.linear1 = nn.Sequential(
            nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels)
        )
        self.linear2 = nn.Sequential(
            nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels)
        )
        self.up_k = up_k
        self.attn = attn and window_size > 0
        if self.attn:
            self.block = BasicLayer(
                dim=out_channels,
                depth=1,
                num_heads=num_heads,
                window_size=window_size,
                quant_size=quant_size,
                drop_path=0.1,
                downsample=None,
                out_channels=None,
                cRSE=cRSE,
                fp16_mode=fp16_mode,
            )

    def forward(self, sp, coords_sp, sp_up, coords_sp_up):
        feats = sp.F
        support_feats = sp_up.F
        xyz = coords_sp.F[:, 1:4].detach().contiguous()
        support_xyz = coords_sp_up.F[:, 1:4].detach().contiguous()
        offset = get_offset(sp.C[:, 0])
        support_offset = get_offset(sp_up.C[:, 0])

        feats = self.linear1(support_feats) + knn_linear_interpolation(
            xyz, support_xyz, self.linear2(feats), offset, support_offset, K=self.up_k
        )
        sp_up = assign_feats(sp_up, feats)
        if self.attn:
            sp_up, _, _ = self.block(sp_up, coords_sp_up)
        return sp_up

    def extra_repr(self) -> str:
        return f"up_k={self.up_k}, in_channels={self.in_channels}, out_channels={self.out_channels}, attn={self.attn}"
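
# Note on the decoder fusion above (comment only): Upsample.forward combines
# the skip branch and the coarse branch as
#   feats = linear1(sp_up.F) + knn_linear_interpolation(..., linear2(sp.F), ...)
# i.e. the coarse features are projected to `out_channels`, interpolated onto
# the finer voxel set with K=up_k neighbors, and added to the projected skip
# features before the optional attention block refines them.
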
class WindowAttention(nn.Module):
    """
    Window-based multi-head self-attention (W-MSA) module with cRSE,
    designed for sparse structures.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (int): The size of the cubic local window.
        quant_size (int): quantization size for the finer cRSE table
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        cRSE (str | 'XYZ', 'XYZ_RGB', 'XYZ_RGB_NORM'): cRSE mode. Default: 'XYZ_RGB'
        fp16_mode (int | 0, 1, 2): fp16 mode for the attention module, Default: 0
            0: fp32 forward and fp32 backward
            1: fp16 forward and fp32 backward
            2: fp16 forward and fp16 backward
    """

    def __init__(
        self,
        dim,
        window_size,
        quant_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # color in [-1, 1], color_windowsize = 2
        # normal in [-1, 1], normal_windowsize = 2
        self.color_windowsize = 2
        self.normal_windowsize = 2

        self.fp16_mode = fp16_mode

        table_offsets = []
        self.cRSE = cRSE
        if "XYZ" in cRSE:
            self.xyz_quant_size = quant_size
            quant_grid_length_xyz = window_size * self.xyz_quant_size
            table_shape_xyz = (3, 2 * quant_grid_length_xyz, num_heads, head_dim)
            self.query_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.query_xyz_table, std=0.02)
            self.key_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.key_xyz_table, std=0.02)
            self.value_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.value_xyz_table, std=0.02)
            table_offsets += [np.prod(table_shape_xyz[1:])] * 3

        if "RGB" in cRSE:
            self.color_quant_size = quant_size * 2
            quant_grid_length_rgb = self.color_windowsize * self.color_quant_size
            table_shape_rgb = (3, 2 * quant_grid_length_rgb, num_heads, head_dim)
            self.query_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.query_rgb_table, std=0.02)
            self.key_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.key_rgb_table, std=0.02)
            self.value_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.value_rgb_table, std=0.02)
            table_offsets += [np.prod(table_shape_rgb[1:])] * 3

        if "NORM" in cRSE:
            self.normal_quant_size = quant_size * 2
            quant_grid_length_norm = self.normal_windowsize * self.normal_quant_size
            table_shape_norm = (3, 2 * quant_grid_length_norm, num_heads, head_dim)
            self.query_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.query_norm_table, std=0.02)
            self.key_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.key_norm_table, std=0.02)
            self.value_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.value_norm_table, std=0.02)
            table_offsets += [np.prod(table_shape_norm[1:])] * 3

        self.table_offsets = table_offsets

        self.quant_size = quant_size

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)

        self.softmax = nn.Softmax(dim=-1)
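
    # Shape check (comment only, hypothetical hyper-parameters): with
    # window_size=5, quant_size=4, num_heads=6 and dim=48 (head_dim=8),
    # quant_grid_length_xyz = 5 * 4 = 20 and each of the three XYZ cRSE tables
    # has shape (3, 2 * 20, 6, 8) = (3, 40, 6, 8); the RGB/NORM tables use the
    # signal window size 2 and quant_size * 2, giving (3, 2 * 16, 6, 8).
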
    def forward(self, feats: torch.Tensor, attn_args):
        """Forward function.

        Args:
            feats: N, C
            attn_args: arguments for computing attention
        """
        num_v, _ = feats.shape
        num_sc = self.dim // self.num_heads

        (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            n2n_indices,
            w2m_indices,
            n_coords,
        ) = attn_args

        # Query, Key, Value
        qkv = self.qkv(feats)
        qkv = (
            qkv.reshape(num_v, 3, self.num_heads, num_sc)
            .permute(1, 0, 2, 3)
            .contiguous()
        )
        query, key, value = qkv[0], qkv[1], qkv[2]  # [N, num_heads, C//num_heads]
        query = query * self.scale

        table_offsets = torch.IntTensor(self.table_offsets).cuda()
        query_table, key_table, value_table = [], [], []
        n_cRSE = []
        if "XYZ" in self.cRSE:
            n_xyz = n_coords[:, 0:3]
            n_xyz = n_xyz * self.quant_size
            n_cRSE.append(n_xyz)
            query_table.append(self.query_xyz_table.view(-1))
            key_table.append(self.key_xyz_table.view(-1))
            value_table.append(self.value_xyz_table.view(-1))
        if "RGB" in self.cRSE:
            n_rgb = n_coords[:, 3:6]
            n_rgb = n_rgb * self.color_quant_size
            n_cRSE.append(n_rgb)
            query_table.append(self.query_rgb_table.view(-1))
            key_table.append(self.key_rgb_table.view(-1))
            value_table.append(self.value_rgb_table.view(-1))
        if "NORM" in self.cRSE:
            n_norm = n_coords[:, 6:9]
            n_norm = n_norm * self.normal_quant_size
            n_cRSE.append(n_norm)
            query_table.append(self.query_norm_table.view(-1))
            key_table.append(self.key_norm_table.view(-1))
            value_table.append(self.value_norm_table.view(-1))

        n_cRSE = torch.cat(n_cRSE, dim=1)

        indices = [m2w_indices, w_sizes, w2m_indices, w2n_indices, n2n_indices, n_cRSE]
        query_table = torch.cat(query_table)
        key_table = torch.cat(key_table)
        value_table = torch.cat(value_table)

        if self.fp16_mode == 0:
            # do not use fp16
            # cast q,k,v to fp32 in forward and backward
            fp16_mode = PrecisionMode.HALF_NONE
        elif self.fp16_mode == 1:
            # use fp16 only in forward
            fp16_mode = PrecisionMode.HALF_FORWARD
        elif self.fp16_mode == 2:
            # use fp16 both in forward and backward
            fp16_mode = PrecisionMode.HALF_ALL

        updated_values = SelfAttnAIOFunction.apply(
            query,
            key,
            value,
            query_table,
            key_table,
            value_table,
            table_offsets,
            indices,
            PosEmb.SEPARATE,
            TableDims.D0,
            IndexMode.INDIRECT,
            fp16_mode,
        )

        updated_values = updated_values.flatten(1)
        updated_feats = updated_values.view(num_v, self.dim)

        updated_feats = self.proj(updated_feats)
        updated_feats = self.proj_drop(updated_feats)  # [N, C]

        return updated_feats


class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        quant_size,
        drop_path=0.0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.window_size = window_size

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            quant_size=quant_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            cRSE=cRSE,
            fp16_mode=fp16_mode,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer
        )

    def forward(self, feats, attn_args):
        # feats: [N, c]
        short_cut = feats
        feats = self.norm1(feats)
        feats = self.attn(feats, attn_args)  # [N, c]

        feats = short_cut + self.drop_path(feats)
        feats = feats + self.drop_path(self.mlp(self.norm2(feats)))

        return feats
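
# Construction sketch (hypothetical hyper-parameters, not executed): a single
# block applies pre-norm window attention and a pre-norm MLP, each with a
# residual connection and DropPath:
#
#   blk = SwinTransformerBlock(
#       dim=96, num_heads=6, window_size=5, quant_size=4,
#       drop_path=0.1, cRSE="XYZ_RGB", fp16_mode=0,
#   )
#   feats = blk(feats, attn_args)  # feats: [N, 96], attn_args from BasicLayer.get_index01
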
class BasicLayer(nn.Module):
    """A basic Swin3D layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        quant_size (int): quantization size for the finer cRSE table
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        cRSE (str | 'XYZ', 'XYZ_RGB', 'XYZ_RGB_NORM'): cRSE mode. Default: 'XYZ_RGB'
        fp16_mode (int | 0, 1, 2): fp16 mode for the attention module, Default: 0
            0: fp32 forward and fp32 backward
            1: fp16 forward and fp32 backward
            2: fp16 forward and fp16 backward
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size,
        quant_size,
        out_channels=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        down_stride=2,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.window_size = window_size
        self.depth = depth
        self.dim = dim
        self.num_heads = num_heads
        self.quant_size = quant_size
        self.cRSE = cRSE
        self.fp16_mode = fp16_mode

        self.shift_size = window_size // 2
        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim,
                    num_heads,
                    window_size,
                    quant_size,
                    drop_path=(
                        drop_path[i] if isinstance(drop_path, list) else drop_path
                    ),
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    norm_layer=norm_layer,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(depth)
            ]
        )

        self.pool = ME.MinkowskiMaxPooling(
            kernel_size=self.window_size, stride=self.window_size, dimension=3
        )

        if downsample is not None:
            if out_channels is None:
                out_channels = dim * 2
            self.downsample = downsample(
                dim, out_channels, kernel_size=down_stride, stride=down_stride
            )
        else:
            self.downsample = None

    def get_map_pair(self, sp):
        """
        Use Minkowski pooling to calculate the windows
        and get the mapping from voxel to window.
        """
        window_size = [self.window_size] * 3
        pool_sp = self.pool(sp)
        windows = pool_sp.C
        window_N = windows.shape[0]

        stride_in = sp.coordinate_map_key.get_tensor_stride()
        x, y, z = [
            torch.arange(window_size[i], device=self.device) * stride_in[i]
            for i in range(3)
        ]
        x, y, z = torch.meshgrid(x, y, z)
        i = torch.zeros_like(x, device=self.device)
        local_window = torch.stack([i, x, y, z], dim=-1).flatten(0, -2)
        all_windows = windows.unsqueeze(1) + local_window.unsqueeze(0)
        all_windows = all_windows.flatten(0, -2).int()
        cm = sp.coordinate_manager
        query_key, (map, inverse_map) = cm.insert_and_map(
            all_windows, tensor_stride=stride_in
        )
        map_pair = cm.kernel_map(query_key, sp.coordinate_map_key, kernel_size=1)[0]
        return map_pair, window_N

    def get_window_mapping(self, sp):
        """
        Calculate the relationship within each window:
        w_w_id: non-empty idx inside the window (sorted by window)
        w_w_xyz: xyz inside the window (sorted by window)
        nempty_num: non-empty voxel number in each window
        sort_idx: sort voxels according to window_id, to gather the points
            inside the same window
        inv_sort_idx: inverse sort index
        """
        map_pair, window_N = self.get_map_pair(sp)
        window_size = self.window_size
        nW = window_size**3
        in_map, out_map = map_pair
        in_map, sort_idx = torch.sort(in_map)
        # assert out_map == arange(out_map.shape[0])
        out_map = out_map[sort_idx]
        sort_idx = out_map.long()
        inv_sort_idx = torch.zeros_like(sort_idx)
        inv_sort_idx[sort_idx] = torch.arange(
            sort_idx.shape[0], dtype=sort_idx.dtype, device=self.device
        )
        N = window_N * nW
        v2w_mask = torch.zeros(N, dtype=torch.bool, device=self.device)
        w_id = (
            torch.arange(window_N, dtype=torch.long, device=self.device)
            .unsqueeze(1)
            .repeat(1, nW)
            .view(-1)
        )
        w_w_id = (
            torch.arange(nW, dtype=torch.long, device=self.device)
            .unsqueeze(0)
            .repeat(window_N, 1)
            .view(-1)
        )
        v2w_mask[in_map.long()] = True
        nempty_num = v2w_mask.view(-1, nW).sum(dim=-1)
        w_id = w_id[in_map.long()]
        w_w_id = w_w_id[in_map.long()]
        w_w_xyz = torch.stack(
            [
                w_w_id // window_size // window_size,
                w_w_id // window_size % window_size,
                w_w_id % window_size,
            ],
            dim=-1,
        )
        return w_w_id, w_w_xyz, nempty_num, sort_idx, inv_sort_idx
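
    # Decoding example (comment only): for window_size = 2, nW = 8 and a voxel
    # with w_w_id = 5 maps to w_w_xyz = (5 // 4, 5 // 2 % 2, 5 % 2) = (1, 0, 1),
    # i.e. the within-window id enumerates the kernel_size ** 3 cells with the
    # last coordinate varying fastest.
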
    def get_index01(self, sp, local_xyz, colors):
        """
        Calculate the arguments for sparse attention.
        """
        (
            w_w_id,
            w_w_xyz,
            nempty_num,
            n2n_indices,
            inv_sort_idx,
        ) = self.get_window_mapping(sp)
        local_xyz = local_xyz[n2n_indices]
        colors = colors[n2n_indices]
        # recover the relative pos in the voxel
        n_coords = w_w_xyz + local_xyz
        n_coords = torch.cat([n_coords, colors], dim=1)
        (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            w2m_indices,
        ) = sparse_self_attention(w_w_id, nempty_num, protocol="v2")
        return (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            n2n_indices,
            w2m_indices,
            n_coords,
        )

    def get_shifted_sp(self, sp):
        """
        Get the shifted sparse tensor for the shifted-window pass.
        """
        stride_in = sp.coordinate_map_key.get_tensor_stride()
        shift_size = self.shift_size * stride_in[0]
        shifted_C = sp.C.clone()
        shifted_C[:, 1:] += shift_size
        shifted_sp = SparseTensor(
            features=sp.F,
            coordinates=shifted_C,
            device=self.device,
            tensor_stride=stride_in,
        )
        return shifted_sp

    def get_window_pos(self, sp):
        stride_in = sp.coordinate_map_key.get_tensor_stride()
        return (sp.C[:, 1:] / stride_in[0]) % self.window_size

    def forward(self, sp, coords_sp):
        """
        xyz: position of the point inside the voxel
        colors: other signals for cRSE, including colors and normals
        local_xyz: relative position of the point inside the voxel
            (used for the finer cRSE table)
        """
        colors = coords_sp.F[:, 4:]
        xyz = coords_sp.F[:, :4]
        local_xyz = (xyz - coords_sp.C)[
            :, 1:
        ] / coords_sp.coordinate_map_key.get_tensor_stride()[0]
        self.device = sp.device
        sp_shift = self.get_shifted_sp(sp)

        attn_args = self.get_index01(sp, local_xyz, colors)
        attn_args_shift = self.get_index01(sp_shift, local_xyz, colors)

        feats = sp.F
        for i, blk in enumerate(self.blocks):
            attn_args_blk = attn_args if i % 2 == 0 else attn_args_shift
            feats = blk(feats, attn_args_blk)  # [N, C]

        sp = assign_feats(sp, feats)
        if self.downsample is not None:
            sp_down, coords_sp = self.downsample(sp, coords_sp)
            return sp, sp_down, coords_sp
        else:
            return sp, sp, coords_sp

    def extra_repr(self) -> str:
        return f"window_size={self.window_size}, depth={self.depth}, channel={self.dim}, num_heads={self.num_heads}, quant_size={self.quant_size}, cRSE={self.cRSE}, fp16_mode={self.fp16_mode}"
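
# End-to-end usage sketch (hypothetical stage configuration, not executed):
# one encoder stage consumes the sparse feature tensor `sp` and the cRSE
# coordinate tensor `coords_sp` produced by the previous stage:
#
#   layer = BasicLayer(
#       dim=96, depth=2, num_heads=6, window_size=5, quant_size=4,
#       downsample=GridDownsample, down_stride=2,
#       cRSE="XYZ_RGB", fp16_mode=0,
#   )
#   sp_out, sp_down, coords_sp_down = layer(sp, coords_sp)
#   # sp_out: features at the current resolution (kept for decoder skip connections)
#   # sp_down, coords_sp_down: inputs to the next, coarser stage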