from functools import partial
from addict import Dict
import math
import torch
import torch.nn as nn
import spconv.pytorch as spconv
import torch_scatter
from timm.models.layers import DropPath
from typing import Union
from einops import rearrange

try:
    import flash_attn
except ImportError:
    flash_attn = None

from .utils.misc import offset2bincount
from .utils.structure import Point
from .modules import PointModule, PointSequential


class RPE(torch.nn.Module):
    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        idx = (
            coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd
            + self.pos_bnd  # relative position to positive index
            + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride
        )
        out = self.rpe_table.index_select(0, idx.reshape(-1))
        out = out.view(idx.shape + (-1,)).sum(3)
        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
        return out


class QueryKeyNorm(nn.Module):
    def __init__(self, channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.norm = nn.LayerNorm(channels // num_heads, elementwise_affine=False)

    def forward(self, qkv):
        H = self.num_heads
        # qkv = qkv.reshape(-1, 3, H, qkv.shape[1] // H).permute(1, 0, 2, 3)
        qkv = rearrange(qkv, "N (S H Ch) -> S N H Ch", H=H, S=3)
        q, k, v = qkv.unbind(dim=0)
        # q, k, v: [N, H, C // H]
        q_norm = self.norm(q)
        k_norm = self.norm(k)
        # qkv_norm: [3, N, H, C // H]
        qkv_norm = torch.stack([q_norm, k_norm, v])
        qkv_norm = rearrange(qkv_norm, "S N H Ch -> N (S H Ch)")
        return qkv_norm
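
# Usage sketch for QueryKeyNorm (illustrative only; the shapes below are
# hypothetical): q and k are LayerNorm-ed per head inside the packed qkv
# tensor while v passes through untouched.
#
#   qknorm = QueryKeyNorm(channels=8, num_heads=2)
#   qkv = torch.randn(5, 3 * 8)   # 5 points, packed (q | k | v)
#   out = qknorm(qkv)             # same shape (5, 24); head dim = 8 // 2 = 4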

class SerializedAttention(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
        enable_qknorm=False,
    ):
        super().__init__()
        assert (
            channels % num_heads == 0
        ), f"channels {channels} must be divisible by num_heads {num_heads}"
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        self.enable_qknorm = enable_qknorm
        if enable_qknorm:
            self.qknorm = QueryKeyNorm(channels, num_heads)
        else:
            print(
                "WARNING: enable_qknorm is False in PTv3Object and training may be fragile"
            )
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enabling Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enabling Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enabling Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            self.attn_drop = attn_drop
        else:
            # when flash attention is disabled, we still don't want to use a mask;
            # consequently, patch size is auto-set to the min of patch_size_max
            # and the number of points
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)

        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    @torch.no_grad()
    def get_rel_pos(self, point, order):
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    @torch.no_grad()
    def get_padding_and_inverse(self, point):
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad point clouds with more points than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            _offset = nn.functional.pad(offset, (1, 0))
            _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            point[cu_seqlens_key] = nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        if not self.enable_flash:
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # pad and reshape feat and batch for serialized point patches
        qkv = self.qkv(point.feat)[order]
        if self.enable_qknorm:
            qkv = self.qknorm(qkv)

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        feat = feat[inverse]

        # ffn
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point
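
# Worked example for get_padding_and_inverse (illustrative numbers, not taken
# from a real run): with patch_size=4 and two point clouds of 6 and 3 points,
# i.e. offset=[6, 9]:
#
#   bincount     = [6, 3]
#   bincount_pad = [8, 3]          # only clouds with more than patch_size
#                                  # points are padded to a multiple of it
#   cu_seqlens   = [0, 4, 8, 11]   # variable-length patches of 4, 4, and 3
#
# `pad` fills the first cloud's incomplete last patch by duplicating indices
# from the preceding patch, and `unpad` maps attention output back to the
# original, unpadded ordering. cu_seqlens is exactly the cumulative-sequence
# format flash_attn_varlen_qkvpacked_func expects.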

class MLP(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
        enable_qknorm=False,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
            enable_qknorm=enable_qknorm,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        self.drop_path = PointSequential(
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )

    def forward(self, point: Point):
        shortcut = point.feat
        point = self.cpe(point)
        point.feat = shortcut + point.feat
        shortcut = point.feat
        if self.pre_norm:
            point = self.norm1(point)
        point = self.drop_path(self.attn(point))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm1(point)

        shortcut = point.feat
        if self.pre_norm:
            point = self.norm2(point)
        point = self.drop_path(self.mlp(point))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm2(point)
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point
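
# Data flow of Block with pre_norm=True, written out as a sketch; `x` stands
# for point.feat and the names match the attributes above:
#
#   x = x + cpe(x)                      # conditional positional encoding
#   x = x + drop_path(attn(norm1(x)))   # serialized attention
#   x = x + drop_path(mlp(norm2(x)))    # feed-forward
#
# With pre_norm=False the norms instead follow the residual additions
# (post-norm), matching the `if not self.pre_norm` branches in forward().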

class SerializedPooling(PointModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support for grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = nn.Linear(in_channels, out_channels)
        # keep the attributes defined even when unused, so forward() can test them
        self.norm = (
            PointSequential(norm_layer(out_channels)) if norm_layer is not None else None
        )
        self.act = PointSequential(act_layer()) if act_layer is not None else None

    def forward(self, point: Point):
        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() on the point cloud before SerializedPooling"

        code = point.serialized_code >> pooling_depth * 3
        code_, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of points sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted points, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reducing attrs e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        # collect information
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )

        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point


class SerializedUnpooling(PointModule):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))

        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))

        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")
        point = self.proj(point)
        parent = self.proj_skip(parent)
        parent.feat = parent.feat + point.feat[inverse]

        if self.traceable:
            parent["unpooling_parent"] = point
        return parent
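
# How the pooling trace is consumed (sketch): SerializedPooling with
# traceable=True stashes the parent Point and the cluster assignment, and
# SerializedUnpooling pops them to scatter coarse features (after proj /
# proj_skip) back onto the parent resolution:
#
#   point = pooling(point)     # sets point.pooling_parent / point.pooling_inverse
#   ...                        # blocks at the coarse resolution
#   point = unpooling(point)   # parent.feat += point.feat[pooling_inverse]
#
# Note that the encoder-only PointTransformerV3Object below never instantiates
# either module; they are kept for U-Net style variants.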

class Embedding(PointModule):
    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
        res_linear=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        # TODO: check removing spconv
        self.stem = PointSequential(
            conv=spconv.SubMConv3d(
                in_channels,
                embed_channels,
                kernel_size=5,
                padding=1,
                bias=False,
                indice_key="stem",
            )
        )
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")
        if res_linear:
            self.res_linear = nn.Linear(in_channels, embed_channels)
        else:
            self.res_linear = None

    def forward(self, point: Point):
        if self.res_linear is not None:
            res_feature = self.res_linear(point.feat)
        point = self.stem(point)
        if self.res_linear is not None:
            point.feat = point.feat + res_feature
            point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point


class PointTransformerV3Object(PointModule):
    def __init__(
        self,
        in_channels=9,
        order=("z", "z-trans", "hilbert", "hilbert-trans"),
        stride=(),
        enc_depths=(3, 3, 3, 6, 16),
        enc_channels=(32, 64, 128, 256, 384),
        enc_num_head=(2, 4, 8, 16, 24),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        enable_qknorm=False,
        layer_norm=False,
        res_linear=True,
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders

        # norm layers
        if layer_norm:
            bn_layer = nn.LayerNorm
        else:
            print("WARNING: using BatchNorm in ptv3obj !!!")
            bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        ln_layer = nn.LayerNorm
        # activation layers
        act_layer = nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=bn_layer,
            act_layer=act_layer,
            res_linear=res_linear,
        )

        # encoder
        enc_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
        ]
        self.enc = PointSequential()
        for s in range(self.num_stages):
            enc_drop_path_ = enc_drop_path[
                sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
            ]
            enc = PointSequential()
            if s > 0:
                enc.add(nn.Linear(enc_channels[s - 1], enc_channels[s]))
            for i in range(enc_depths[s]):
                enc.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=enc_drop_path_[i],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                        enable_qknorm=enable_qknorm,
                    ),
                    name=f"block{i}",
                )
            if len(enc) != 0:
                self.enc.add(module=enc, name=f"enc{s}")

    def forward(self, data_dict, min_coord=None):
        point = Point(data_dict)
        point.serialization(
            order=self.order,
            shuffle_orders=self.shuffle_orders,
            min_coord=min_coord,
        )
        point.sparsify()

        point = self.embedding(point)
        point = self.enc(point)
        return point


def get_encoder(
    pretrained_path: Union[str, None] = None,
    freeze_encoder: bool = False,
    **kwargs,
) -> PointTransformerV3Object:
    point_encoder = PointTransformerV3Object(**kwargs)
    if pretrained_path is not None:
        checkpoint = torch.load(pretrained_path)
        state_dict = checkpoint["state_dict"]
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        point_encoder.load_state_dict(state_dict, strict=False)
    if freeze_encoder:
        for name, param in point_encoder.named_parameters():
            if "res_linear" not in name and "qknorm" not in name:
                param.requires_grad = False
    return point_encoder
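

# Usage sketch for get_encoder (the checkpoint path and settings below are
# hypothetical placeholders, not shipped defaults):
#
#   encoder = get_encoder(
#       pretrained_path="checkpoints/ptv3_object.pth",  # hypothetical path
#       freeze_encoder=True,   # freezes all but res_linear / qknorm params
#       in_channels=9,
#       enable_flash=True,     # requires the flash_attn package
#   )
#   # data_dict should carry the fields Point expects (coord, grid_coord,
#   # feat, and offset or batch); see .utils.structure.
#   point = encoder(data_dict)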