# Spaces: Running on Zero
from functools import partial
import math
from typing import Union

from addict import Dict
import torch
import torch.nn as nn
import spconv.pytorch as spconv
import torch_scatter
from timm.models.layers import DropPath
from einops import rearrange

try:
    import flash_attn
except ImportError:
    flash_attn = None

from .utils.misc import offset2bincount
from .utils.structure import Point
from .modules import PointModule, PointSequential
class RPE(torch.nn.Module):
    """Learned relative positional bias for attention inside point patches.

    Each of the three spatial axes owns a contiguous slice of ``rpe_table``
    with ``rpe_num`` rows. Per-axis offsets are clamped to
    ``[-pos_bnd, pos_bnd]``, looked up in that axis' slice, and the three
    per-head biases are summed.

    Args:
        patch_size: number of points per attention patch; sets the clamp bound.
        num_heads: number of attention heads (columns of ``rpe_table``).
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        # Clamp bound scales with the cube root of the patch size.
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        # Offsets -pos_bnd..pos_bnd inclusive.
        self.rpe_num = 2 * self.pos_bnd + 1
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        """Turn relative grid offsets into a per-head attention bias.

        Args:
            coord: integer offsets of shape (N, K, K, 3). They are used as
                table indices, so they must be an integer tensor.

        Returns:
            Bias of shape (N, H, K, K).
        """
        bound = self.pos_bnd
        # Row offset of each axis' slice in the shared table: 0, rpe_num, 2*rpe_num.
        axis_base = torch.arange(3, device=coord.device) * self.rpe_num
        # Shift the clamped offsets to non-negative indices, then into each axis' slice.
        table_idx = coord.clamp(-bound, bound) + bound + axis_base
        rows = self.rpe_table.index_select(0, table_idx.reshape(-1))
        # (N, K, K, 3, H) -> sum over the three axes -> (N, K, K, H)
        per_pair = rows.view(table_idx.shape + (-1,)).sum(3)
        return per_pair.permute(0, 3, 1, 2)
class QueryKeyNorm(nn.Module):
    """Per-head LayerNorm on the query and key parts of a packed qkv tensor.

    The value part passes through unchanged. The LayerNorm has no learnable
    affine parameters.

    Args:
        channels: model width C (q, k and v each have C features).
        num_heads: number of heads H; each head is normalized over C // H.
    """

    def __init__(self, channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.norm = nn.LayerNorm(channels // num_heads, elementwise_affine=False)

    def forward(self, qkv):
        """Normalize q and k inside ``qkv``.

        Args:
            qkv: tensor of shape [N, 3 * C], laid out as q | k | v along the
                last dim, each split into H heads of C // H channels.

        Returns:
            Tensor with the same [N, 3 * C] layout.
        """
        num_points, packed_width = qkv.shape
        head_dim = packed_width // (3 * self.num_heads)
        # [N, 3*C] -> [3, N, H, C // H]
        split = qkv.reshape(num_points, 3, self.num_heads, head_dim).permute(1, 0, 2, 3)
        query, key, value = split.unbind(dim=0)
        packed = torch.stack([self.norm(query), self.norm(key), value])
        # [3, N, H, C // H] -> [N, 3*C]
        return packed.permute(1, 0, 2, 3).reshape(num_points, packed_width)
class SerializedAttention(PointModule):
    """Multi-head self-attention inside fixed-size patches of serialized points.

    Points are reordered by the serialization selected with ``order_index``.
    Each sample is padded up to a multiple of the patch size (see
    ``get_padding_and_inverse``), and attention is computed independently
    within each patch. There are two backends:

    * ``enable_flash=True``: ``flash_attn.flash_attn_varlen_qkvpacked_func``
      on fp16 (``.half()``) inputs. RPE and upcasting must be disabled.
    * ``enable_flash=False``: plain PyTorch attention. The patch size is
      recomputed on every forward as ``min(smallest sample size, patch_size)``.
      An optional relative positional bias (RPE) is added to the logits.

    Args:
        channels: feature width C; must be divisible by ``num_heads``.
        num_heads: number of attention heads H.
        patch_size: points per attention patch (an upper bound when flash is off).
        qkv_bias: whether the fused qkv projection has a bias.
        qk_scale: logit scale; falls back to ``(C // H) ** -0.5`` when falsy.
        attn_drop: dropout rate on attention probabilities.
        proj_drop: dropout rate after the output projection.
        order_index: which entry of ``point.serialized_order`` to use.
        enable_rpe: add a learned relative positional bias (non-flash only).
        enable_flash: use the flash_attn kernels.
        upcast_attention: compute q @ k^T in float32 (non-flash only).
        upcast_softmax: compute the softmax in float32 (non-flash only).
        enable_qknorm: apply ``QueryKeyNorm`` to q and k before attention.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
        enable_qknorm=False,
    ):
        super().__init__()
        assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
        self.channels = channels
        self.num_heads = num_heads
        # Because of `or`, qk_scale=0 (or any falsy value) also falls back to 1/sqrt(head_dim).
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        self.enable_qknorm = enable_qknorm
        if enable_qknorm:
            self.qknorm = QueryKeyNorm(channels, num_heads)
        else:
            print("WARNING: enable_qknorm is False in PTv3Object and training may be fragile")
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enable Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enable Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enable Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            # flash_attn takes the dropout probability as a float, not as a module.
            self.attn_drop = attn_drop
        else:
            # Without flash attention we still avoid attention masks: forward()
            # sets the patch size to min(patch_size_max, point count of the
            # smallest sample in the batch).
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)
        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    def get_rel_pos(self, point, order):
        """Return pairwise grid-coordinate offsets inside each patch.

        ``point.grid_coord`` is gathered by ``order`` and grouped into patches
        of ``self.patch_size`` points. The entry ``[n, i, j]`` of the result is
        ``grid[n, i] - grid[n, j]``, giving shape (num_patches, K, K, 3).
        The result is cached on ``point`` under ``rel_pos_{order_index}``.
        """
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    def get_padding_and_inverse(self, point):
        """Build (and cache on ``point``) the index maps used for patch padding.

        Samples are delimited by ``point.offset``, which is treated as
        cumulative end indices. Padding rules per sample:

        * Point count greater than ``self.patch_size``: round up to the next
          multiple of the patch size. Each padding slot reuses the index stored
          exactly ``patch_size`` slots earlier, which is a real point of the
          same sample, so no attention mask is needed.
        * Point count at or below the patch size: not padded; the sample
          forms a single (possibly shorter) patch.

        Returns:
            pad: for every slot of the padded sequence, the index into the
                unpadded sequence it reads from.
            unpad: for every unpadded position, its slot in the padded sequence.
            cu_seqlens: int32 start slot of every patch, followed by the total
                padded length. This is passed as ``cu_seqlens`` to flash_attn.

        NOTE(review): the cache keys do not include ``patch_size``. Every
        attention layer that sees the same ``point`` is therefore assumed to
        use the same patch size. Confirm this holds for all callers.
        """
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            # Round every sample's point count up to a multiple of patch_size.
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad point when num of points larger than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            # Exclusive prefix sums: start of each sample, unpadded and padded.
            _offset = nn.functional.pad(offset, (1, 0))
            _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                # Move sample i's positions to where its block starts in the padded sequence.
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    # Fill the trailing padding slots of the sample's last patch
                    # with the slots exactly one patch_size earlier (real points
                    # of the previous patch of the same sample).
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                # Convert sample i's padded slot positions back to unpadded indices.
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                # One entry per patch start of sample i.
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            # Append the total padded length as the final boundary.
            point[cu_seqlens_key] = nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        """Run patch attention on ``point.feat`` and write the result back.

        Returns the same ``point`` object with ``point.feat`` replaced by the
        projected attention output (shape [N, C]).
        """
        if not self.enable_flash:
            # No masking support: shrink patches so the smallest sample fills at least one.
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )
        H = self.num_heads
        K = self.patch_size
        C = self.channels
        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
        # Assuming serialized_order / serialized_inverse are a permutation and its
        # inverse: `order` maps padded slot -> point index, and `inverse` maps
        # point index -> padded slot.
        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]
        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]
        if self.enable_qknorm:
            qkv = self.qknorm(qkv)
        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            # cu_seqlens marks patch boundaries; max_seqlen is the patch size.
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        # Back to the original point order; padding duplicates are not selected.
        feat = feat[inverse]
        # output projection
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point
class MLP(nn.Module):
    """Two-layer feed-forward network.

    Computes ``Dropout(fc2(Dropout(act(fc1(x)))))``; the same dropout module
    is used after both layers.

    Args:
        in_channels: input width.
        hidden_channels: hidden width; falls back to ``in_channels`` when falsy.
        out_channels: output width; falls back to ``in_channels`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        width_out = out_channels if out_channels else in_channels
        width_hidden = hidden_channels if hidden_channels else in_channels
        self.fc1 = nn.Linear(in_channels, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Block(PointModule):
    """PTv3 transformer block operating on a ``Point``.

    Three residual stages run in order:

    1. CPE: ``spconv.SubMConv3d`` (kernel 3) -> Linear -> norm, added to the input.
    2. ``SerializedAttention``, normalized by ``norm1``.
    3. ``MLP``, normalized by ``norm2``.

    With ``pre_norm=True`` the norm runs before stages 2 and 3; otherwise it
    runs after the residual sum. The outputs of stages 2 and 3 pass through
    ``DropPath(drop_path)``, or an identity when ``drop_path`` is 0.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
        enable_qknorm=False,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm
        position_conv = spconv.SubMConv3d(
            channels,
            channels,
            kernel_size=3,
            bias=True,
            indice_key=cpe_indice_key,
        )
        self.cpe = PointSequential(
            position_conv,
            nn.Linear(channels, channels),
            norm_layer(channels),
        )
        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
            enable_qknorm=enable_qknorm,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        hidden_width = int(channels * mlp_ratio)
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=hidden_width,
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        stochastic_depth = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.drop_path = PointSequential(stochastic_depth)

    def _residual(self, point, branch, norm):
        """Apply ``branch`` as a drop-path residual, placing ``norm`` per ``pre_norm``."""
        identity = point.feat
        if self.pre_norm:
            point = norm(point)
        point = self.drop_path(branch(point))
        point.feat = identity + point.feat
        if not self.pre_norm:
            point = norm(point)
        return point

    def forward(self, point: Point):
        identity = point.feat
        point = self.cpe(point)
        point.feat = identity + point.feat
        point = self._residual(point, self.attn, self.norm1)
        point = self._residual(point, self.mlp, self.norm2)
        # Write the updated features back into the sparse tensor (presumably
        # consumed by the next sparse convolution).
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point
class SerializedPooling(PointModule):
    """Downsample a serialized point cloud by merging points whose codes agree
    once the finest ``log2(stride)`` levels (3 bits each) are dropped.

    Features are projected to ``out_channels`` and reduced per cluster with
    ``reduce``, and coordinates are averaged. Grid coordinates, codes and batch
    ids come from one representative (head) point per cluster.

    Args:
        in_channels: input feature width.
        out_channels: output feature width.
        stride: power-of-two downsampling factor (2, 4, 8, ...).
        norm_layer: optional norm factory applied after pooling.
        act_layer: optional activation factory applied after pooling.
        reduce: feature reduction; one of "sum", "mean", "min", "max".
        shuffle_orders: randomly permute the serialization orders of the output.
        traceable: store ``pooling_inverse`` (point -> cluster) and
            ``pooling_parent`` on the output, for unpooling.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable
        self.proj = nn.Linear(in_channels, out_channels)
        # Always define norm/act. forward() tests them against None, and
        # leaving them unset raised AttributeError when the layers were disabled.
        self.norm = None
        self.act = None
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        if act_layer is not None:
            self.act = PointSequential(act_layer())

    def forward(self, point: Point):
        # Check the required keys before reading any of them, so the message is reachable.
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() point cloud before SerializedPooling"
        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            # Cannot pool deeper than the serialization; fall back to no merging.
            pooling_depth = 0
        # Each depth level is 3 bits of the code (one per axis).
        code = point.serialized_code >> pooling_depth * 3
        code_, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )
        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]
        # collect information
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )
        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context
        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point
class SerializedUnpooling(PointModule):
    """Map a pooled point cloud back onto its parent (pre-pooling) cloud.

    Coarse features are projected and gathered to the parent points through
    ``pooling_inverse``. They are then added to the parent's own projected
    (skip) features.

    Args:
        in_channels: feature width of the pooled (coarse) cloud.
        skip_channels: feature width of the parent cloud.
        out_channels: output feature width.
        norm_layer: optional norm factory appended to both projections.
        act_layer: optional activation factory appended to both projections.
        traceable: keep the coarse point on the result as ``unpooling_parent``.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
        if norm_layer is not None:
            for branch in (self.proj, self.proj_skip):
                branch.add(norm_layer(out_channels))
        if act_layer is not None:
            for branch in (self.proj, self.proj_skip):
                branch.add(act_layer())
        self.traceable = traceable

    def forward(self, point):
        for required in ("pooling_parent", "pooling_inverse"):
            assert required in point.keys()
        parent = point.pop("pooling_parent")
        cluster_of = point.pop("pooling_inverse")
        coarse = self.proj(point)
        parent = self.proj_skip(parent)
        # Each parent point receives the feature of the cluster it was pooled into.
        parent.feat = parent.feat + coarse.feat[cluster_of]
        if self.traceable:
            parent["unpooling_parent"] = coarse
        return parent
class Embedding(PointModule):
    """Sparse-conv stem that lifts raw point features to ``embed_channels``.

    The stem is a spconv ``SubMConv3d`` (kernel 5, ``indice_key="stem"``),
    optionally followed by a norm and an activation. With ``res_linear=True``
    a linear projection of the input features is added to the stem output.

    Args:
        in_channels: raw input feature width.
        embed_channels: output feature width.
        norm_layer: optional norm factory appended to the stem.
        act_layer: optional activation factory appended to the stem.
        res_linear: add a linear residual branch from the input features.
    """

    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
        res_linear=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        # TODO: check remove spconv
        self.stem = PointSequential(
            conv=spconv.SubMConv3d(
                in_channels,
                embed_channels,
                kernel_size=5,
                padding=1,
                bias=False,
                indice_key="stem",
            )
        )
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")
        if res_linear:
            self.res_linear = nn.Linear(in_channels, embed_channels)
        else:
            self.res_linear = None

    def forward(self, point: Point):
        # Compare against None explicitly. The truthiness of an nn.Module
        # depends on whether it defines __len__ (an empty nn.Sequential is falsy).
        residual = None
        if self.res_linear is not None:
            # Computed from the raw input features, before the stem runs.
            residual = self.res_linear(point.feat)
        point = self.stem(point)
        if residual is not None:
            point.feat = point.feat + residual
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point
class PointTransformerV3Object(PointModule):
    """Encoder-only Point Transformer V3 backbone.

    Processing steps:

    * Serialize the input with ``order`` and sparsify it.
    * Embed it with a sparse-conv stem.
    * Run ``len(enc_depths)`` stages of ``Block`` modules.

    Consecutive stages are joined only by an ``nn.Linear`` channel change; no
    pooling happens, so ``stride`` is accepted but not used by this class.
    ``cls_mode`` is stored but not read here.

    The stem norm is BatchNorm unless ``layer_norm=True``; blocks always use
    LayerNorm. Stochastic depth grows linearly from 0 to ``drop_path`` over
    all blocks.
    """

    def __init__(
        self,
        in_channels=9,
        order=("z", "z-trans", "hilbert", "hilbert-trans"),
        stride=(),
        enc_depths=(3, 3, 3, 6, 16),
        enc_channels=(32, 64, 128, 256, 384),
        enc_num_head=(2, 4, 8, 16, 24),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        enable_qknorm=False,
        layer_norm=False,
        res_linear=True,
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders

        if not layer_norm:
            print("WARNING: use BatchNorm in ptv3obj !!!")
            stem_norm = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        else:
            stem_norm = partial(nn.LayerNorm)
        block_norm = nn.LayerNorm
        activation = nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=stem_norm,
            act_layer=activation,
            res_linear=res_linear,
        )

        # One drop-path rate per block, over all stages.
        rates = torch.linspace(0, drop_path, sum(enc_depths)).tolist()
        self.enc = PointSequential()
        first = 0
        for stage, depth in enumerate(enc_depths):
            stage_rates = rates[first : first + depth]
            first += depth
            stage_seq = PointSequential()
            if stage > 0:
                # Channel change between stages (no spatial downsampling).
                stage_seq.add(nn.Linear(enc_channels[stage - 1], enc_channels[stage]))
            for i, rate in enumerate(stage_rates):
                stage_seq.add(
                    Block(
                        channels=enc_channels[stage],
                        num_heads=enc_num_head[stage],
                        patch_size=enc_patch_size[stage],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=rate,
                        norm_layer=block_norm,
                        act_layer=activation,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{stage}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                        enable_qknorm=enable_qknorm,
                    ),
                    name=f"block{i}",
                )
            if len(stage_seq) > 0:
                self.enc.add(module=stage_seq, name=f"enc{stage}")

    def forward(self, data_dict, min_coord=None):
        """Serialize, embed and encode ``data_dict``; return the resulting ``Point``."""
        point = Point(data_dict)
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders, min_coord=min_coord)
        point.sparsify()
        return self.enc(self.embedding(point))
def get_encoder(pretrained_path: Union[str, None] = None, freeze_encoder: bool = False, **kwargs) -> PointTransformerV3Object:
    """Build a ``PointTransformerV3Object``, optionally load weights and freeze it.

    Args:
        pretrained_path: optional checkpoint path for ``torch.load``. Accepted
            forms are a dict with a ``"state_dict"`` entry or a bare state dict.
            A leading ``"module."`` prefix (from DataParallel/DDP wrapping) is
            stripped from each key. Loading uses ``strict=False``, so missing or
            unexpected keys are ignored silently.
        freeze_encoder: if truthy, set ``requires_grad=False`` on every
            parameter whose name contains neither ``"res_linear"`` nor ``"qknorm"``.
        **kwargs: forwarded to ``PointTransformerV3Object``.

    Returns:
        The constructed encoder.
    """
    point_encoder = PointTransformerV3Object(**kwargs)
    if pretrained_path is not None:
        # load_state_dict copies values into the model's own tensors. Loading on
        # CPU therefore avoids needing the device the checkpoint was saved from.
        # NOTE(security): torch.load unpickles the file (torch < 2.6 defaults to
        # weights_only=False); only load trusted checkpoints.
        checkpoint = torch.load(pretrained_path, map_location="cpu")
        # Accept both {"state_dict": ...} training checkpoints and bare state dicts.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        # Strip only a leading "module."; str.replace would also rewrite
        # occurrences elsewhere in a key (e.g. "submodule.").
        prefix = "module."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        point_encoder.load_state_dict(state_dict, strict=False)
    if freeze_encoder:
        # Keep the residual-linear branch (and anything named qknorm) trainable.
        for name, param in point_encoder.named_parameters():
            if "res_linear" not in name and "qknorm" not in name:
                param.requires_grad = False
    return point_encoder