""" | |
PointGroup for instance segmentation | |
Author: Xiaoyang Wu ([email protected]), Chengyao Wang | |
Please cite our work if the code is helpful to you. | |
""" | |
from functools import partial | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
try: | |
from pointgroup_ops import ballquery_batch_p, bfs_cluster | |
except ImportError: | |
ballquery_batch_p, bfs_cluster = None, None | |
from pointcept.models.utils import offset2batch, batch2offset | |
from pointcept.models.builder import MODELS, build_model | |


class PointGroup(nn.Module):
    def __init__(
        self,
        backbone,
        backbone_out_channels=64,
        semantic_num_classes=20,
        semantic_ignore_index=-1,
        segment_ignore_index=(-1, 0, 1),
        instance_ignore_index=-1,
        cluster_thresh=1.5,
        cluster_closed_points=300,
        cluster_propose_points=100,
        cluster_min_points=50,
        voxel_size=0.02,
    ):
        super().__init__()
        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        self.semantic_num_classes = semantic_num_classes
        self.segment_ignore_index = segment_ignore_index
        self.semantic_ignore_index = semantic_ignore_index
        self.instance_ignore_index = instance_ignore_index
        self.cluster_thresh = cluster_thresh
        self.cluster_closed_points = cluster_closed_points
        self.cluster_propose_points = cluster_propose_points
        self.cluster_min_points = cluster_min_points
        self.voxel_size = voxel_size
        self.backbone = build_model(backbone)
        # offset head: predicts a per-point 3D vector pointing toward the
        # point's instance centroid
        self.bias_head = nn.Sequential(
            nn.Linear(backbone_out_channels, backbone_out_channels),
            norm_fn(backbone_out_channels),
            nn.ReLU(),
            nn.Linear(backbone_out_channels, 3),
        )
        # per-point semantic classification head
        self.seg_head = nn.Linear(backbone_out_channels, semantic_num_classes)
        self.ce_criteria = torch.nn.CrossEntropyLoss(ignore_index=semantic_ignore_index)

    def forward(self, data_dict):
        coord = data_dict["coord"]
        segment = data_dict["segment"]
        instance = data_dict["instance"]
        instance_centroid = data_dict["instance_centroid"]
        offset = data_dict["offset"]
        feat = self.backbone(data_dict)
        bias_pred = self.bias_head(feat)
        logit_pred = self.seg_head(feat)
        # compute loss
        seg_loss = self.ce_criteria(logit_pred, segment)
        # only points with a valid instance label contribute to the offset losses
        mask = (instance != self.instance_ignore_index).float()
        bias_gt = instance_centroid - coord
        bias_dist = torch.sum(torch.abs(bias_pred - bias_gt), dim=-1)
        bias_l1_loss = torch.sum(bias_dist * mask) / (torch.sum(mask) + 1e-8)
        # direction loss: negative cosine similarity between predicted and
        # ground-truth offset directions
        bias_pred_norm = bias_pred / (
            torch.norm(bias_pred, p=2, dim=1, keepdim=True) + 1e-8
        )
        bias_gt_norm = bias_gt / (torch.norm(bias_gt, p=2, dim=1, keepdim=True) + 1e-8)
        cosine_similarity = -(bias_pred_norm * bias_gt_norm).sum(-1)
        bias_cosine_loss = torch.sum(cosine_similarity * mask) / (
            torch.sum(mask) + 1e-8
        )
        loss = seg_loss + bias_l1_loss + bias_cosine_loss
        return_dict = dict(
            loss=loss,
            seg_loss=seg_loss,
            bias_l1_loss=bias_l1_loss,
            bias_cosine_loss=bias_cosine_loss,
        )
        if not self.training:
            # shift points by the predicted offsets and express them in voxel units
            center_pred = coord + bias_pred
            center_pred /= self.voxel_size
            logit_pred = F.softmax(logit_pred, dim=-1)
            segment_pred = torch.max(logit_pred, 1)[1]  # [n]
            # cluster: keep only points whose predicted semantic class is not
            # in segment_ignore_index
            mask = (
                ~torch.concat(
                    [
                        (segment_pred == index).unsqueeze(-1)
                        for index in self.segment_ignore_index
                    ],
                    dim=1,
                )
                .sum(-1)
                .bool()
            )
            if mask.sum() == 0:
                proposals_idx = torch.zeros(0).int()
                proposals_offset = torch.zeros(1).int()
            else:
                center_pred_ = center_pred[mask]
                segment_pred_ = segment_pred[mask]
                batch_ = offset2batch(offset)[mask]
                offset_ = nn.ConstantPad1d((1, 0), 0)(batch2offset(batch_))
                # ball query around each shifted point, then BFS clustering of
                # same-class neighbors (pointgroup_ops CUDA kernels)
                idx, start_len = ballquery_batch_p(
                    center_pred_,
                    batch_.int(),
                    offset_.int(),
                    self.cluster_thresh,
                    self.cluster_closed_points,
                )
                proposals_idx, proposals_offset = bfs_cluster(
                    segment_pred_.int().cpu(),
                    idx.cpu(),
                    start_len.cpu(),
                    self.cluster_min_points,
                )
                # map masked (local) point indices back to global point indices
                proposals_idx[:, 1] = (
                    mask.nonzero().view(-1)[proposals_idx[:, 1].long()].int()
                )
            # get proposal: build a binary [num_proposals, num_points] mask
            proposals_pred = torch.zeros(
                (proposals_offset.shape[0] - 1, center_pred.shape[0]), dtype=torch.int
            )
            proposals_pred[proposals_idx[:, 0].long(), proposals_idx[:, 1].long()] = 1
            # a proposal's class is the predicted class of its first point
            # (all points in a cluster share one predicted class)
            instance_pred = segment_pred[
                proposals_idx[:, 1][proposals_offset[:-1].long()].long()
            ]
            # drop proposals with too few points
            proposals_point_num = proposals_pred.sum(1)
            proposals_mask = proposals_point_num > self.cluster_propose_points
            proposals_pred = proposals_pred[proposals_mask]
            instance_pred = instance_pred[proposals_mask]

            pred_scores = []
            pred_classes = []
            pred_masks = proposals_pred.detach().cpu()
            for proposal_id in range(len(proposals_pred)):
                segment_ = proposals_pred[proposal_id]
                # score a proposal by the mean softmax confidence of its points
                # for the proposal's predicted class
                confidence_ = logit_pred[
                    segment_.bool(), instance_pred[proposal_id]
                ].mean()
                object_ = instance_pred[proposal_id]
                pred_scores.append(confidence_)
                pred_classes.append(object_)
            if len(pred_scores) > 0:
                pred_scores = torch.stack(pred_scores).cpu()
                pred_classes = torch.stack(pred_classes).cpu()
            else:
                pred_scores = torch.tensor([])
                pred_classes = torch.tensor([])

            return_dict["pred_scores"] = pred_scores
            return_dict["pred_masks"] = pred_masks
            return_dict["pred_classes"] = pred_classes
        return return_dict
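

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the upstream file).
# A minimal example of how the model might be instantiated directly. The
# backbone config below is an assumption modeled on Pointcept's SpUNet-based
# PointGroup configs and may need adjusting to match the backbones registered
# in your installation; inference additionally requires the optional
# `pointgroup_ops` CUDA extension.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = PointGroup(
        backbone=dict(type="SpUNet-v1m1", in_channels=6, num_classes=0),  # assumed config
        backbone_out_channels=96,
        semantic_num_classes=20,
    )
    # The forward pass expects a data_dict of per-point tensors: "coord",
    # "segment", "instance", "instance_centroid", the batch "offset", plus
    # whatever inputs the chosen backbone consumes (e.g. "feat").
    # return_dict = model(data_dict)
    # return_dict["loss"] during training; "pred_masks", "pred_scores",
    # "pred_classes" are added in eval mode.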