Find3D / Pointcept /pointcept /models /octformer /octformer_v1m1_base.py
ziqima's picture
initial commit
4893ce0
"""
Octree Transformer
Modified from https://github.com/octree-nn/octformer
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
from typing import Optional, List, Dict
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
try:
import ocnn
from ocnn.octree import Octree, Points
except ImportError:
from pointcept.utils.misc import DummyClass
ocnn = None
Octree = DummyClass
Points = DummyClass
try:
import dwconv
except ImportError:
dwconv = None
from pointcept.models.builder import MODELS
from pointcept.models.utils import offset2batch
class OctreeT(Octree):
    """Octree extended with the per-depth lookup tables used by patch attention.

    All attributes of an already-built octree are copied into this object.
    For every depth in ``[start_depth, max_depth]`` the constructor then
    precomputes the padded batch index, the attention masks and the pairwise
    coordinate differences, each for the plain and for the dilated patch
    layout.

    Args:
        octree: Source octree whose ``__dict__`` is copied into this instance.
        patch_size: Number of nodes grouped into one attention patch.
        dilation: Dilation factor used by the dilated patch layout.
        nempty: Selects ``nnum_nempty`` (True) or ``nnum`` (False) as the
            per-depth node count and is forwarded to ``batch_id`` / ``key``.
        max_depth: Deepest depth to build tables for. A falsy value falls back
            to the depth of ``octree``.
        start_depth: Shallowest depth to build tables for. A falsy value falls
            back to the full depth of ``octree``. Must be greater than 1.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        octree: Octree,
        patch_size: int = 24,
        dilation: int = 4,
        nempty: bool = True,
        max_depth: Optional[int] = None,
        start_depth: Optional[int] = None,
        **kwargs
    ):
        super().__init__(octree.depth, octree.full_depth)
        # Take over the complete state of the source octree.
        self.__dict__.update(octree.__dict__)

        self.patch_size = patch_size
        self.dilation = dilation
        self.nempty = nempty
        self.max_depth = max_depth or self.depth
        self.start_depth = start_depth or self.full_depth
        # Additive mask value for pairs that must not attend to each other.
        self.invalid_mask_value = -1e3
        assert self.start_depth > 1

        # Node counts are padded up to a multiple of patch_size * dilation so
        # that both patch layouts divide the sequence evenly.
        self.block_num = patch_size * dilation
        self.nnum_t = self.nnum_nempty if nempty else self.nnum
        block_count = (self.nnum_t / self.block_num).ceil()
        self.nnum_a = (block_count * self.block_num).int()

        # One slot per depth (indexed by depth); filled by build_t().
        slots = self.max_depth + 1
        for table_name in (
            "batch_idx",
            "patch_mask",
            "dilate_mask",
            "rel_pos",
            "dilate_pos",
        ):
            setattr(self, table_name, [None] * slots)

        self.build_t()

    def build_t(self):
        """Populate every lookup table for depths start_depth..max_depth."""
        for level in range(self.start_depth, self.max_depth + 1):
            self.build_batch_idx(level)
            self.build_attn_mask(level)
            self.build_rel_pos(level)

    def build_batch_idx(self, depth: int):
        """Store the batch index of ``depth``, padded with ``batch_size``."""
        node_batch = self.batch_id(depth, self.nempty)
        # Padding rows get the value batch_size.
        self.batch_idx[depth] = self.patch_partition(
            node_batch, depth, self.batch_size
        )

    def build_attn_mask(self, depth: int):
        """Build the plain and the dilated attention masks for ``depth``."""
        padded_batch = self.batch_idx[depth]

        plain_layout = padded_batch.view(-1, self.patch_size)
        self.patch_mask[depth] = self._calc_attn_mask(plain_layout)

        # Dilated layout: split into (patch_size, dilation) groups and swap
        # the two axes so each patch takes every dilation-th node of a block.
        dilated_layout = (
            padded_batch.view(-1, self.patch_size, self.dilation)
            .transpose(1, 2)
            .reshape(-1, self.patch_size)
        )
        self.dilate_mask[depth] = self._calc_attn_mask(dilated_layout)

    def _calc_attn_mask(self, mask: torch.Tensor):
        """Return a pairwise mask: 0 where entries match, ``invalid_mask_value`` otherwise."""
        pairwise_diff = mask.unsqueeze(2) - mask.unsqueeze(1)
        return pairwise_diff.masked_fill(pairwise_diff != 0, self.invalid_mask_value)

    def build_rel_pos(self, depth: int):
        """Store pairwise coordinate differences for both patch layouts."""
        padded_key = self.patch_partition(self.key(depth, self.nempty), depth)
        x, y, z, _ = ocnn.octree.key2xyz(padded_key, depth)
        coords = torch.stack([x, y, z], dim=1)

        plain = coords.view(-1, self.patch_size, 3)
        self.rel_pos[depth] = plain.unsqueeze(2) - plain.unsqueeze(1)

        dilated = (
            plain.view(-1, self.patch_size, self.dilation, 3)
            .transpose(1, 2)
            .reshape(-1, self.patch_size, 3)
        )
        self.dilate_pos[depth] = dilated.unsqueeze(2) - dilated.unsqueeze(1)

    def patch_partition(self, data: torch.Tensor, depth: int, fill_value=0):
        """Pad ``data`` along dim 0 from ``nnum_t[depth]`` to ``nnum_a[depth]`` rows.

        Args:
            data: Tensor whose first dimension enumerates the nodes at ``depth``.
            depth: Octree depth the data belongs to.
            fill_value: Value written into the appended rows.

        Returns:
            ``data`` with the padding rows concatenated at the end.
        """
        pad_rows = self.nnum_a[depth] - self.nnum_t[depth]
        padding = data.new_full((pad_rows,) + data.shape[1:], fill_value)
        return torch.cat([data, padding], dim=0)

    def patch_reverse(self, data: torch.Tensor, depth: int):
        """Drop the padding rows again, keeping the first ``nnum_t[depth]`` rows."""
        valid_rows = self.nnum_t[depth]
        return data[:valid_rows]
class MLP(torch.nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    The same (in-place) dropout module is applied after the activation and
    after the second linear layer.

    Args:
        in_features: Size of the input feature dimension.
        hidden_features: Width of the hidden layer; a falsy value falls back
            to ``in_features``.
        out_features: Size of the output feature dimension; a falsy value
            falls back to ``in_features``.
        activation: Callable returning the activation module.
        drop: Dropout probability.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        activation=torch.nn.GELU,
        drop: float = 0.0,
        **kwargs
    ):
        super().__init__()
        width_hidden = hidden_features or in_features
        width_out = out_features or in_features

        self.in_features = in_features
        self.out_features = width_out
        self.hidden_features = width_hidden

        self.fc1 = torch.nn.Linear(in_features, width_hidden)
        self.act = activation()
        self.fc2 = torch.nn.Linear(width_hidden, width_out)
        self.drop = torch.nn.Dropout(drop, inplace=True)

    def forward(self, data: torch.Tensor):
        """Apply the feed-forward network to ``data`` (last dim = in_features)."""
        hidden = self.drop(self.act(self.fc1(data)))
        return self.drop(self.fc2(hidden))
class OctreeDWConvBn(torch.nn.Module):
    """Octree depth-wise convolution followed by 1D batch normalisation.

    Args:
        in_channels: Number of feature channels (unchanged by the layer).
        kernel_size: Kernel size passed to ``dwconv.OctreeDWConv``. ``None``
            selects ``[3]``. A fresh list is created for every instance so
            that neither a shared default nor the caller's list can be
            modified by the convolution layer.
        stride: Unused; kept for interface compatibility.
        nempty: Forwarded to ``dwconv.OctreeDWConv``.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: Optional[List[int]] = None,
        stride: int = 1,
        nempty: bool = False,
    ):
        super().__init__()
        # Avoid a mutable default argument: build a private list per instance.
        kernel = [3] if kernel_size is None else list(kernel_size)
        self.conv = dwconv.OctreeDWConv(in_channels, kernel, nempty, use_bias=False)
        self.bn = torch.nn.BatchNorm1d(in_channels)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        """Convolve ``data`` on ``octree`` at ``depth`` and batch-normalise it."""
        out = self.conv(data, octree, depth)
        out = self.bn(out)
        return out
class RPE(torch.nn.Module):
    """Learned relative position bias, looked up separately per axis and summed.

    A table with ``3 * rpe_num`` rows (one segment of ``rpe_num`` rows for
    each of the three coordinate axes) and ``num_heads`` columns is indexed
    with clamped integer coordinate differences.

    Args:
        patch_size: Patch size used to derive the clamping bound.
        num_heads: Number of attention heads (table columns).
        dilation: Dilation factor; enlarges the clamping bound by its square root.
    """

    def __init__(self, patch_size: int, num_heads: int, dilation: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dilation = dilation
        # get_pos_bnd reads self.dilation, so it must be set beforehand.
        self.pos_bnd = self.get_pos_bnd(patch_size)
        self.rpe_num = 2 * self.pos_bnd + 1
        table = torch.zeros(3 * self.rpe_num, num_heads)
        self.rpe_table = torch.nn.Parameter(table)
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def get_pos_bnd(self, patch_size: int):
        """Return the largest absolute coordinate difference that is kept unclamped."""
        bound = 0.8 * patch_size * self.dilation**0.5
        return int(bound)

    def xyz2idx(self, xyz: torch.Tensor):
        """Map coordinate differences (last dim of size 3) to rows of ``rpe_table``."""
        # Each axis owns its own contiguous segment of rpe_num rows.
        axis_offset = torch.arange(3, device=xyz.device) * self.rpe_num
        clamped = xyz.clamp(-self.pos_bnd, self.pos_bnd)
        return clamped + (self.pos_bnd + axis_offset)

    def forward(self, xyz):
        """Return the position bias for ``xyz`` of shape (N, K, K, 3) as (N, H, K, K)."""
        idx = self.xyz2idx(xyz)
        rows = self.rpe_table.index_select(0, idx.reshape(-1))
        # (N, K, K, 3, H) -> sum over the three axes -> (N, K, K, H)
        bias = rows.view(*idx.shape, -1).sum(3)
        return bias.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)

    def extra_repr(self) -> str:
        template = "num_heads={}, pos_bnd={}, dilation={}"
        return template.format(self.num_heads, self.pos_bnd, self.dilation)  # noqa
class OctreeAttention(torch.nn.Module):
    """Multi-head self-attention restricted to fixed-size patches of octree nodes.

    Args:
        dim: Feature dimension of the input and output.
        patch_size: Number of nodes per attention patch.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused qkv projection has a bias.
        qk_scale: Scale applied to the queries; a falsy value falls back to
            ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        dilation: Values greater than 1 select the dilated patch layout.
        use_rpe: Whether to add a learned relative position bias.
    """

    def __init__(
        self,
        dim: int,
        patch_size: int,
        num_heads: int,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        dilation: int = 1,
        use_rpe: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dilation = dilation
        self.use_rpe = use_rpe
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Relative position encoding
        self.rpe = RPE(patch_size, num_heads, dilation) if use_rpe else None

    def forward(self, data: torch.Tensor, octree: OctreeT, depth: int):
        """Run patch attention on the node features ``data`` at ``depth``.

        Args:
            data: Node features; reshaped below with ``dim`` as the last axis.
            octree: Octree holding the precomputed masks and relative positions.
            depth: Octree depth the features belong to.

        Returns:
            Features with the padding removed again and projected by ``proj``.
        """
        heads = self.num_heads
        patch = self.patch_size
        channels = self.dim
        dil = self.dilation
        dilated = dil > 1

        # Pad the node sequence, then group it into patches.
        data = octree.patch_partition(data, depth)
        if dilated:
            rel_pos = octree.dilate_pos[depth]
            mask = octree.dilate_mask[depth]
            regrouped = data.view(-1, patch, dil, channels).transpose(1, 2)
            data = regrouped.reshape(-1, channels)
        else:
            rel_pos = octree.rel_pos[depth]
            mask = octree.patch_mask[depth]
        data = data.view(-1, patch, channels)

        # Fused projection, split into q, k, v of shape (N, H, K, C').
        fused = self.qkv(data).reshape(-1, patch, 3, heads, channels // heads)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product scores plus position bias and mask: (N, H, K, K).
        scores = (q * self.scale) @ k.transpose(-2, -1)
        scores = self.apply_rpe(scores, rel_pos)
        scores = scores + mask.unsqueeze(1)
        weights = self.attn_drop(self.softmax(scores))
        data = (weights @ v).transpose(1, 2).reshape(-1, channels)

        # Undo the dilated regrouping and strip the padding.
        if dilated:
            restored = data.view(-1, dil, patch, channels).transpose(1, 2)
            data = restored.reshape(-1, channels)
        data = octree.patch_reverse(data, depth)

        # Output projection
        return self.proj_drop(self.proj(data))

    def apply_rpe(self, attn, rel_pos):
        """Add the relative position bias to ``attn`` when RPE is enabled."""
        if not self.use_rpe:
            return attn
        return attn + self.rpe(rel_pos)

    def extra_repr(self) -> str:
        template = "dim={}, patch_size={}, num_heads={}, dilation={}"
        return template.format(
            self.dim, self.patch_size, self.num_heads, self.dilation
        )  # noqa
class OctFormerBlock(torch.nn.Module):
    """Transformer block: conditional position encoding, attention and MLP.

    Each of the attention and MLP branches is pre-normalised with LayerNorm,
    passed through octree drop-path and added back residually.

    Args:
        dim: Feature dimension.
        num_heads: Number of attention heads.
        patch_size: Number of nodes per attention patch.
        dilation: Dilation factor forwarded to the attention layer.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the qkv projection has a bias.
        qk_scale: Optional override of the query scale.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the attention projection and
            inside the MLP.
        drop_path: Drop-path rate passed to ``ocnn.nn.OctreeDropPath``.
        nempty: Forwarded to the drop-path and depth-wise convolution layers.
        activation: Callable returning the MLP activation module.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        patch_size: int = 32,
        dilation: int = 0,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float = 0.0,
        nempty: bool = True,
        activation: torch.nn.Module = torch.nn.GELU,
        **kwargs
    ):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(dim)
        self.attention = OctreeAttention(
            dim=dim,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            dilation=dilation,
        )
        self.norm2 = torch.nn.LayerNorm(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            out_features=dim,
            activation=activation,
            drop=proj_drop,
        )
        self.drop_path = ocnn.nn.OctreeDropPath(drop_path, nempty)
        self.cpe = OctreeDWConvBn(dim, nempty=nempty)

    def forward(self, data: torch.Tensor, octree: OctreeT, depth: int):
        """Apply the block to the node features ``data`` at ``depth``."""
        # Conditional position encoding, added residually.
        data = self.cpe(data, octree, depth) + data

        # Attention branch
        attn_out = self.attention(self.norm1(data), octree, depth)
        data = data + self.drop_path(attn_out, octree, depth)

        # Feed-forward branch
        ffn_out = self.mlp(self.norm2(data))
        return data + self.drop_path(ffn_out, octree, depth)
class OctFormerStage(torch.nn.Module):
    """Sequence of transformer blocks operating at a single octree depth.

    Blocks at even positions use dilation 1; blocks at odd positions use the
    given ``dilation``.

    Args:
        dim: Feature dimension.
        num_heads: Number of attention heads.
        patch_size: Number of nodes per attention patch.
        dilation: Dilation factor of the odd-indexed blocks.
        mlp_ratio: Hidden width of the block MLPs as a multiple of ``dim``.
        qkv_bias: Whether the qkv projections have a bias.
        qk_scale: Optional override of the query scale.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after projections.
        drop_path: Either one rate shared by all blocks or a ``list`` holding
            one rate per block.
        nempty: Forwarded to every block.
        activation: Callable returning the MLP activation module.
        interval: Normalisation interval; only used to derive ``num_norms``
            because the intermediate norms are currently disabled.
        use_checkpoint: Use activation checkpointing while training.
        num_blocks: Number of blocks in the stage.
        octformer_block: Block class to instantiate.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        patch_size: int = 32,
        dilation: int = 0,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float = 0.0,
        nempty: bool = True,
        activation: torch.nn.Module = torch.nn.GELU,
        interval: int = 6,
        use_checkpoint: bool = True,
        num_blocks: int = 2,
        octformer_block=OctFormerBlock,
        **kwargs
    ):
        super().__init__()
        self.num_blocks = num_blocks
        self.use_checkpoint = use_checkpoint
        self.interval = interval  # normalization interval
        self.num_norms = (num_blocks - 1) // self.interval

        rate_per_block = isinstance(drop_path, list)
        stage_blocks = []
        for index in range(num_blocks):
            stage_blocks.append(
                octformer_block(
                    dim=dim,
                    num_heads=num_heads,
                    patch_size=patch_size,
                    # Alternate between plain (even) and dilated (odd) blocks.
                    dilation=dilation if index % 2 else 1,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    drop_path=drop_path[index] if rate_per_block else drop_path,
                    nempty=nempty,
                    activation=activation,
                )
            )
        self.blocks = torch.nn.ModuleList(stage_blocks)
        # NOTE: the periodic BatchNorm1d layers (one every `interval` blocks)
        # are intentionally not created at the moment.

    def forward(self, data: torch.Tensor, octree: OctreeT, depth: int):
        """Run all blocks of the stage on ``data`` at ``depth``."""
        recompute = self.use_checkpoint and self.training
        for block in self.blocks:
            if recompute:
                # Trade compute for memory: activations are recomputed in backward.
                data = checkpoint(block, data, octree, depth)
            else:
                data = block(data, octree, depth)
        return data
class OctFormerDecoder(torch.nn.Module):
    """Feature-pyramid style decoder that fuses the per-depth stage features.

    Args:
        channels: Channel count of every encoder stage, in encoder order.
        fpn_channel: Channel count used throughout the decoder.
        nempty: Forwarded to the ocnn layers.
        head_up: Number of stride-2 deconvolutions applied to the fused output.
    """

    def __init__(
        self, channels: List[int], fpn_channel: int, nempty: bool, head_up: int = 1
    ):
        super().__init__()
        self.head_up = head_up
        self.num_stages = len(channels)

        # Lateral projections are stored in reverse channel order because
        # forward() walks from the smallest depth to the largest one.
        self.conv1x1 = torch.nn.ModuleList(
            [torch.nn.Linear(width, fpn_channel) for width in reversed(channels)]
        )
        self.upsample = ocnn.nn.OctreeUpsample("nearest", nempty)

        smoothing = []
        for _ in range(self.num_stages):
            smoothing.append(
                ocnn.modules.OctreeConvBnRelu(
                    fpn_channel, fpn_channel, kernel_size=[3], stride=1, nempty=nempty
                )
            )
        self.conv3x3 = torch.nn.ModuleList(smoothing)

        deconvs = []
        for _ in range(self.head_up):
            deconvs.append(
                ocnn.modules.OctreeDeconvBnRelu(
                    fpn_channel, fpn_channel, kernel_size=[3], stride=2, nempty=nempty
                )
            )
        self.up_conv = torch.nn.ModuleList(deconvs)

    def forward(self, features: Dict[int, torch.Tensor], octree: Octree):
        """Fuse ``features`` (keyed by octree depth) into one feature tensor.

        Args:
            features: Mapping from octree depth to the features of that depth;
                must contain exactly ``num_stages`` consecutive depths.
            octree: Octree the features live on.

        Returns:
            The summed pyramid output after the ``head_up`` deconvolutions.
        """
        depth = min(features)
        depth_max = max(features)
        assert self.num_stages == len(features)

        # Start at the smallest depth.
        feature = self.conv1x1[0](features[depth])
        smoothed = self.conv3x3[0](feature, octree, depth)
        out = self.upsample(smoothed, octree, depth, depth_max)

        # Top-down pathway: upsample one depth, add the lateral projection.
        for step in range(1, self.num_stages):
            depth_i = depth + step
            carried = self.upsample(feature, octree, depth_i - 1)
            feature = self.conv1x1[step](features[depth_i]) + carried
            smoothed = self.conv3x3[step](feature, octree, depth_i)
            out = out + self.upsample(smoothed, octree, depth_i, depth_max)

        # Each deconvolution is applied one depth further down than the last.
        for step, deconv in enumerate(self.up_conv):
            out = deconv(out, octree, depth_max + step)
        return out
class PatchEmbed(torch.nn.Module):
    """Convolutional stem that embeds input features and reduces the octree depth.

    The stem applies ``num_down`` pairs of (3-kernel conv, stride-2 conv) and
    a final 3-kernel projection to ``dim`` channels. Channel widths double
    with every downsampling step and end at ``dim``.

    Args:
        in_channels: Channel count of the input features.
        dim: Channel count of the produced embedding.
        num_down: Number of stride-2 downsampling steps.
        nempty: Forwarded to the ocnn layers.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        in_channels: int = 3,
        dim: int = 96,
        num_down: int = 2,
        nempty: bool = True,
        **kwargs
    ):
        super().__init__()
        self.num_stages = num_down
        self.delta_depth = -num_down
        # dim / 2**num_down, ..., dim / 2, dim
        channels = [int(dim * 2**i) for i in range(-self.num_stages, 1)]

        self.convs = torch.nn.ModuleList(
            [
                ocnn.modules.OctreeConvBnRelu(
                    in_channels if i == 0 else channels[i],
                    channels[i],
                    kernel_size=[3],
                    stride=1,
                    nempty=nempty,
                )
                for i in range(self.num_stages)
            ]
        )
        self.downsamples = torch.nn.ModuleList(
            [
                ocnn.modules.OctreeConvBnRelu(
                    channels[i],
                    channels[i + 1],
                    kernel_size=[2],
                    stride=2,
                    nempty=nempty,
                )
                for i in range(self.num_stages)
            ]
        )
        self.proj = ocnn.modules.OctreeConvBnRelu(
            channels[-1], dim, kernel_size=[3], stride=1, nempty=nempty
        )

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        """Embed ``data`` given at ``depth``; the result lives at ``depth - num_down``.

        Args:
            data: Input features of the nodes at ``depth``.
            octree: Octree the features live on.
            depth: Octree depth of ``data``.

        Returns:
            Embedded features with ``dim`` channels.
        """
        # TODO: reduce to single input
        # Track the current depth explicitly rather than relying on a loop
        # variable after the loop, which is unbound when num_down == 0.
        # NOTE: with num_down == 0 only `proj` runs, and it expects `dim`
        # input channels.
        for conv, downsample in zip(self.convs, self.downsamples):
            data = conv(data, octree, depth)
            data = downsample(data, octree, depth)
            depth -= 1
        return self.proj(data, octree, depth)
class Downsample(torch.nn.Module):
    """Stride-2 octree convolution followed by 1D batch normalisation.

    Args:
        in_channels: Channel count of the input features.
        out_channels: Channel count of the output features.
        kernel_size: Kernel size of the convolution. Any sequence of ints is
            accepted and converted to a fresh list before it is handed to
            ``ocnn.nn.OctreeConv``.
        nempty: Forwarded to ``ocnn.nn.OctreeConv``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: List[int] = (2,),
        nempty: bool = True,
    ):
        super().__init__()
        self.norm = torch.nn.BatchNorm1d(out_channels)
        # The default is a tuple while every other call site in this file
        # passes a list; normalise so the layer always receives a list.
        # NOTE(review): ocnn presumably requires (and may extend in place) a
        # list here - confirm against the installed ocnn version.
        self.conv = ocnn.nn.OctreeConv(
            in_channels,
            out_channels,
            list(kernel_size),
            stride=2,
            nempty=nempty,
            use_bias=True,
        )

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        """Downsample ``data`` from ``depth`` with the conv, then normalise it."""
        data = self.conv(data, octree, depth)
        data = self.norm(data)
        return data
@MODELS.register_module("OctFormer-v1m1")
class OctFormer(torch.nn.Module):
    """OctFormer backbone with an FPN-style decoder and a per-point head.

    Args:
        in_channels: Channel count of the octree input features.
        num_classes: Number of output classes; values <= 0 replace the
            segmentation head with an identity.
        fpn_channels: Channel count inside the decoder and the head.
        channels: Feature dimension of every encoder stage.
        num_blocks: Number of transformer blocks of every encoder stage.
        num_heads: Number of attention heads of every encoder stage.
        patch_size: Number of nodes per attention patch.
        stem_down: Number of downsampling steps in the patch embedding.
        head_up: Number of upsampling deconvolutions in the decoder.
        dilation: Dilation factor of the dilated attention blocks.
        drop_path: Maximum drop-path rate; rates grow linearly from 0 over
            all blocks.
        nempty: Forwarded to all octree layers.
        octree_scale_factor: Divisor applied to the point coordinates before
            the octree is built.
        octree_depth: Depth of the octree.
        octree_full_depth: Full depth of the octree.
    """

    def __init__(
        self,
        in_channels,
        num_classes,
        fpn_channels=168,
        channels=(96, 192, 384, 384),
        num_blocks=(2, 2, 18, 2),
        num_heads=(6, 12, 24, 24),
        patch_size=26,
        stem_down=2,
        head_up=2,
        dilation=4,
        drop_path=0.5,
        nempty=True,
        octree_scale_factor=10.24,
        octree_depth=11,
        octree_full_depth=2,
    ):
        super().__init__()
        assert ocnn is not None, "Please follow `README.md` to install ocnn.`"
        assert dwconv is not None, "Please follow `README.md` to install dwconv.`"

        self.patch_size = patch_size
        self.dilation = dilation
        self.nempty = nempty
        self.num_stages = len(num_blocks)
        self.stem_down = stem_down
        self.octree_scale_factor = octree_scale_factor
        self.octree_depth = octree_depth
        self.octree_full_depth = octree_full_depth

        # One drop-path rate per block, increasing linearly up to `drop_path`.
        drop_ratio = torch.linspace(0, drop_path, sum(num_blocks)).tolist()

        self.patch_embed = PatchEmbed(in_channels, channels[0], stem_down, nempty)

        stages = []
        first_block = 0
        for stage_idx in range(self.num_stages):
            # Each stage receives its own contiguous slice of the rates.
            next_first = first_block + num_blocks[stage_idx]
            stages.append(
                OctFormerStage(
                    dim=channels[stage_idx],
                    num_heads=num_heads[stage_idx],
                    patch_size=patch_size,
                    drop_path=drop_ratio[first_block:next_first],
                    dilation=dilation,
                    nempty=nempty,
                    num_blocks=num_blocks[stage_idx],
                )
            )
            first_block = next_first
        self.layers = torch.nn.ModuleList(stages)

        transitions = []
        for stage_idx in range(self.num_stages - 1):
            transitions.append(
                Downsample(
                    channels[stage_idx],
                    channels[stage_idx + 1],
                    kernel_size=[2],
                    nempty=nempty,
                )
            )
        self.downsamples = torch.nn.ModuleList(transitions)

        self.decoder = OctFormerDecoder(
            channels=channels, fpn_channel=fpn_channels, nempty=nempty, head_up=head_up
        )
        self.interp = ocnn.nn.OctreeInterp("nearest", nempty)

        if num_classes > 0:
            self.seg_head = nn.Sequential(
                nn.Linear(fpn_channels, fpn_channels),
                torch.nn.BatchNorm1d(fpn_channels),
                nn.ReLU(inplace=True),
                nn.Linear(fpn_channels, num_classes),
            )
        else:
            self.seg_head = nn.Identity()

    def points2octree(self, points):
        """Build an octree with the configured depths from ``points``."""
        tree = ocnn.octree.Octree(self.octree_depth, self.octree_full_depth)
        tree.build_octree(points)
        return tree

    def forward(self, data_dict):
        """Run the network on a batch of point clouds.

        Args:
            data_dict: Mapping with the keys ``"coord"``, ``"normal"``,
                ``"feat"`` and ``"offset"``. ``len(offset)`` is used as the
                batch size and ``offset`` is converted to a per-point batch
                index with ``offset2batch``.

        Returns:
            The output of the segmentation head for every input point.
        """
        coord = data_dict["coord"]
        normal = data_dict["normal"]
        feat = data_dict["feat"]
        offset = data_dict["offset"]
        batch = offset2batch(offset)
        batch_size = len(offset)

        point = Points(
            points=coord / self.octree_scale_factor,
            normals=normal,
            features=feat,
            batch_id=batch.unsqueeze(-1),
            batch_size=batch_size,
        )
        octree = ocnn.octree.Octree(
            depth=self.octree_depth,
            full_depth=self.octree_full_depth,
            batch_size=batch_size,
            device=coord.device,
        )
        octree.build_octree(point)
        octree.construct_all_neigh()

        # Stem: embed the features of the deepest octree level.
        finest = octree.depth
        feat = self.patch_embed(octree.features[finest], octree, finest)
        depth = finest - self.stem_down  # current octree depth

        octree = OctreeT(
            octree,
            self.patch_size,
            self.dilation,
            self.nempty,
            max_depth=depth,
            start_depth=depth - self.num_stages + 1,
        )

        # Encoder: one stage per depth, moving one depth up between stages.
        features = {}
        final_stage = self.num_stages - 1
        for stage_idx, stage in enumerate(self.layers):
            stage_depth = depth - stage_idx
            feat = stage(feat, octree, stage_depth)
            features[stage_depth] = feat
            if stage_idx < final_stage:
                feat = self.downsamples[stage_idx](feat, octree, stage_depth)

        out = self.decoder(features, octree)

        # interp representation to points before Octreeization
        query_pts = torch.cat([point.points, point.batch_id], dim=1).contiguous()
        out = self.interp(out, octree, octree.depth, query_pts)
        return self.seg_head(out)