import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_sched
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule
from torch.utils.data import DataLoader
from torchvision import transforms
import pointnet2.data.data_utils as d_utils
from pointnet2.data.ModelNet40Loader import ModelNet40Cls
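# Returns a callable suitable for nn.Module.apply() that sets `momentum` on
# every BatchNorm layer it visits; used below to decay BN momentum over training.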
def set_bn_momentum_default(bn_momentum):
def fn(m):
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
m.momentum = bn_momentum
return fn
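# Decays BatchNorm momentum on an epoch schedule. Subclassing lr_sched.LambdaLR
# only borrows the scheduler interface (step/state_dict); no optimizer is
# involved, and note that super().__init__ is never called.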
class BNMomentumScheduler(lr_sched.LambdaLR):
def __init__(self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default):
if not isinstance(model, nn.Module):
raise RuntimeError(
"Class '{}' is not a PyTorch nn Module".format(type(model)._name_)
)
self.model = model
self.setter = setter
self.lmbd = bn_lambda
self.step(last_epoch + 1)
self.last_epoch = last_epoch
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
self.model.apply(self.setter(self.lmbd(epoch)))
def state_dict(self):
return dict(last_epoch=self.last_epoch)
def load_state_dict(self, state):
self.last_epoch = state["last_epoch"]
self.step(self.last_epoch)
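# Floors for the decayed learning rate and BatchNorm momentum, respectively.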
lr_clip = 1e-5
bnm_clip = 1e-2
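# PointNet++ classification network with single-scale grouping (SSG), written
# against the pre-1.0 pytorch_lightning API (direct `self.hparams` assignment,
# `validation_end`, dataset construction inside `prepare_data`).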
class PointNet2ClassificationSSG(pl.LightningModule):
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self._build_model()
def _build_model(self):
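        # Three single-scale-grouping set abstraction (SA) levels: two that
        # downsample to 512 and then 128 centroids, and a final global level
        # (no npoint/radius) that pools everything into one 1024-d vector.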
self.SA_modules = nn.ModuleList()
self.SA_modules.append(
PointnetSAModule(
npoint=512,
radius=0.2,
nsample=64,
mlp=[3, 64, 64, 128],
use_xyz=self.hparams["model.use_xyz"],
)
)
self.SA_modules.append(
PointnetSAModule(
npoint=128,
radius=0.4,
nsample=64,
mlp=[128, 128, 128, 256],
use_xyz=self.hparams["model.use_xyz"],
)
)
self.SA_modules.append(
PointnetSAModule(
mlp=[256, 256, 512, 1024], use_xyz=self.hparams["model.use_xyz"]
)
)
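        # Classification head: two FC+BN+ReLU blocks, dropout, and a final
        # linear layer over the 40 ModelNet40 classes.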
self.fc_layer = nn.Sequential(
nn.Linear(1024, 512, bias=False),
nn.BatchNorm1d(512),
nn.ReLU(True),
nn.Linear(512, 256, bias=False),
nn.BatchNorm1d(256),
nn.ReLU(True),
nn.Dropout(0.5),
nn.Linear(256, 40),
)
def _break_up_pc(self, pc):
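        # Split a (B, N, 3 + C) cloud into xyz coordinates and, if extra
        # channels are present, a (B, C, N) feature tensor as the SA modules expect.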
xyz = pc[..., 0:3].contiguous()
features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None
return xyz, features
def forward(self, pointcloud):
r"""
Forward pass of the network
Parameters
----------
pointcloud: Variable(torch.cuda.FloatTensor)
(B, N, 3 + input_channels) tensor
Point cloud to run predicts on
Each point in the point-cloud MUST
be formated as (x, y, z, features...)
"""
xyz, features = self._break_up_pc(pointcloud)
for module in self.SA_modules:
xyz, features = module(xyz, features)
return self.fc_layer(features.squeeze(-1))
def training_step(self, batch, batch_idx):
pc, labels = batch
logits = self.forward(pc)
loss = F.cross_entropy(logits, labels)
with torch.no_grad():
acc = (torch.argmax(logits, dim=1) == labels).float().mean()
log = dict(train_loss=loss, train_acc=acc)
return dict(loss=loss, log=log, progress_bar=dict(train_acc=acc))
def validation_step(self, batch, batch_idx):
pc, labels = batch
logits = self.forward(pc)
loss = F.cross_entropy(logits, labels)
acc = (torch.argmax(logits, dim=1) == labels).float().mean()
return dict(val_loss=loss, val_acc=acc)
def validation_end(self, outputs):
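        # Gather each metric across validation batches, average it, and expose
        # the reduced dict under both "log" and "progress_bar" as the pre-1.0
        # Lightning API expects.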
reduced_outputs = {}
for k in outputs[0]:
for o in outputs:
reduced_outputs[k] = reduced_outputs.get(k, []) + [o[k]]
for k in reduced_outputs:
reduced_outputs[k] = torch.stack(reduced_outputs[k]).mean()
reduced_outputs.update(
dict(log=reduced_outputs.copy(), progress_bar=reduced_outputs.copy())
)
return reduced_outputs
def configure_optimizers(self):
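        # Both schedules decay exponentially with the number of samples seen
        # (global_step * batch_size), stepping every `optimizer.decay_step`
        # samples, and are floored by lr_clip / bnm_clip respectively.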
lr_lbmd = lambda _: max(
self.hparams["optimizer.lr_decay"]
** (
int(
self.global_step
* self.hparams["batch_size"]
/ self.hparams["optimizer.decay_step"]
)
),
lr_clip / self.hparams["optimizer.lr"],
)
bn_lbmd = lambda _: max(
self.hparams["optimizer.bn_momentum"]
* self.hparams["optimizer.bnm_decay"]
** (
int(
self.global_step
* self.hparams["batch_size"]
/ self.hparams["optimizer.decay_step"]
)
),
bnm_clip,
)
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.hparams["optimizer.lr"],
weight_decay=self.hparams["optimizer.weight_decay"],
)
lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd)
bnm_scheduler = BNMomentumScheduler(self, bn_lambda=bn_lbmd)
return [optimizer], [lr_scheduler, bnm_scheduler]
def prepare_data(self):
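        # Train-time augmentations (scale, rotate, jitter, translate, random
        # input dropout) run on CPU inside the DataLoader workers; the
        # validation set uses no augmentation.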
train_transforms = transforms.Compose(
[
d_utils.PointcloudToTensor(),
d_utils.PointcloudScale(),
d_utils.PointcloudRotate(),
d_utils.PointcloudRotatePerturbation(),
d_utils.PointcloudTranslate(),
d_utils.PointcloudJitter(),
d_utils.PointcloudRandomInputDropout(),
]
)
self.train_dset = ModelNet40Cls(
self.hparams["num_points"], transforms=train_transforms, train=True
)
self.val_dset = ModelNet40Cls(
self.hparams["num_points"], transforms=None, train=False
)
def _build_dataloader(self, dset, mode):
return DataLoader(
dset,
batch_size=self.hparams["batch_size"],
shuffle=mode == "train",
num_workers=4,
pin_memory=True,
drop_last=mode == "train",
)
def train_dataloader(self):
return self._build_dataloader(self.train_dset, mode="train")
def val_dataloader(self):
return self._build_dataloader(self.val_dset, mode="val")
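# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hparams values
# below are illustrative placeholders, not the repo's configured defaults, and
# the Trainer arguments assume the pre-1.0 pytorch_lightning API this file is
# written against.
if __name__ == "__main__":
    hparams = {
        "num_points": 1024,  # illustrative; pick per your data budget
        "batch_size": 32,
        "model.use_xyz": True,
        "optimizer.lr": 1e-3,
        "optimizer.lr_decay": 0.7,
        "optimizer.decay_step": 2e4,
        "optimizer.bn_momentum": 0.5,
        "optimizer.bnm_decay": 0.5,
        "optimizer.weight_decay": 0.0,
    }
    model = PointNet2ClassificationSSG(hparams)
    trainer = pl.Trainer(gpus=1, max_epochs=200)
    trainer.fit(model)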