from collections import namedtuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG

from pointnet2.models.pointnet2_ssg_sem import PointNet2SemSegSSG


class PointNet2SemSegMSG(PointNet2SemSegSSG):
    """Multi-scale-grouping (MSG) variant of the PointNet++ segmentation model.

    Only the layer construction (``_build_model``) is overridden here; all other
    behaviour is inherited unchanged from ``PointNet2SemSegSSG``.
    """

    def _build_model(self):
        """Create the abstraction, propagation and per-point head layers.

        Sets three attributes, in this order:

        * ``self.SA_modules`` -- four ``PointnetSAModuleMSG`` levels, each
          grouping at two radii.
        * ``self.FP_modules`` -- four ``PointnetFPModule`` stages.
        * ``self.fc_lyaer`` -- a 1x1-conv head producing 13 output channels
          (presumably one per segmentation class -- TODO confirm).

        Reads ``self.hparams["model.use_xyz"]`` (raises ``KeyError`` if absent).
        """
        # One row per abstraction level:
        #   (npoint, radii, nsamples, per-scale MLP widths following the input width)
        sa_specs = (
            (1024, [0.05, 0.1], [16, 32], [[16, 16, 32], [32, 32, 64]]),
            (256, [0.1, 0.2], [16, 32], [[64, 64, 128], [64, 96, 128]]),
            (64, [0.2, 0.4], [16, 32], [[128, 196, 256], [128, 196, 256]]),
            (16, [0.4, 0.8], [16, 32], [[256, 256, 512], [256, 384, 512]]),
        )

        self.SA_modules = nn.ModuleList()
        use_xyz = self.hparams["model.use_xyz"]

        # level_channels[i] is the feature width at resolution i: index 0 is the
        # raw per-point feature width (6), index k+1 is SA level k's output.
        level_channels = [6]
        for npoint, radii, nsamples, scale_widths in sa_specs:
            width_in = level_channels[-1]
            self.SA_modules.append(
                PointnetSAModuleMSG(
                    npoint=npoint,
                    radii=radii,
                    nsamples=nsamples,
                    # A fresh list per scale: the module may adjust mlps[i][0]
                    # in place (e.g. when use_xyz adds coordinate channels).
                    mlps=[[width_in, *widths] for widths in scale_widths],
                    use_xyz=use_xyz,
                )
            )
            # Scales are concatenated, so the level's output width is the sum of
            # each scale's final MLP width (96, 256, 512, 1024).
            level_channels.append(sum(widths[-1] for widths in scale_widths))

        # Hidden/output widths of each propagation stage, finest stage first.
        fp_widths = ([128, 128], [256, 256], [512, 512], [512, 512])
        last_stage = len(fp_widths) - 1

        self.FP_modules = nn.ModuleList()
        for stage, widths in enumerate(fp_widths):
            # Stage k's input pairs the output of stage k+1 (the deepest SA
            # level's output for the last stage) with the width saved at
            # resolution k. Presumably the inherited forward runs these stages
            # deepest-first -- TODO confirm against PointNet2SemSegSSG.
            if stage == last_stage:
                from_coarser = level_channels[-1]
            else:
                from_coarser = fp_widths[stage + 1][-1]
            self.FP_modules.append(
                PointnetFPModule(mlp=[from_coarser + level_channels[stage], *widths])
            )

        head_width = fp_widths[0][-1]
        # NOTE: attribute name "fc_lyaer" (sic) is kept as-is; the inherited
        # code presumably looks it up by this exact name.
        self.fc_lyaer = nn.Sequential(
            nn.Conv1d(head_width, head_width, kernel_size=1, bias=False),
            nn.BatchNorm1d(head_width),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Conv1d(head_width, 13, kernel_size=1),
        )