""" Tester Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) Please cite our work if the code is helpful to you. """ import os import time import numpy as np from collections import OrderedDict import torch import torch.distributed as dist import torch.nn.functional as F import torch.utils.data from .defaults import create_ddp_model import pointcept.utils.comm as comm from pointcept.datasets import build_dataset, collate_fn from pointcept.models import build_model from pointcept.utils.logger import get_root_logger from pointcept.utils.registry import Registry from pointcept.utils.misc import ( AverageMeter, intersection_and_union, intersection_and_union_gpu, make_dirs, ) TESTERS = Registry("testers") class TesterBase: def __init__(self, cfg, model=None, test_loader=None, verbose=False) -> None: torch.multiprocessing.set_sharing_strategy("file_system") self.logger = get_root_logger( log_file=os.path.join(cfg.save_path, "test.log"), file_mode="a" if cfg.resume else "w", ) self.logger.info("=> Loading config ...") self.cfg = cfg self.verbose = verbose if self.verbose: self.logger.info(f"Save path: {cfg.save_path}") self.logger.info(f"Config:\n{cfg.pretty_text}") if model is None: self.logger.info("=> Building model ...") self.model = self.build_model() else: self.model = model if test_loader is None: self.logger.info("=> Building test dataset & dataloader ...") self.test_loader = self.build_test_loader() else: self.test_loader = test_loader def build_model(self): model = build_model(self.cfg.model) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) self.logger.info(f"Num params: {n_parameters}") model = create_ddp_model( model.cuda(), broadcast_buffers=False, find_unused_parameters=self.cfg.find_unused_parameters, ) if os.path.isfile(self.cfg.weight): self.logger.info(f"Loading weight at: {self.cfg.weight}") checkpoint = torch.load(self.cfg.weight) weight = OrderedDict() for key, value in checkpoint["state_dict"].items(): if key.startswith("module."): if comm.get_world_size() == 1: key = key[7:] # module.xxx.xxx -> xxx.xxx else: if comm.get_world_size() > 1: key = "module." 
    def build_test_loader(self):
        test_dataset = build_dataset(self.cfg.data.test)
        if comm.get_world_size() > 1:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset
            )
        else:
            test_sampler = None
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.cfg.batch_size_test_per_gpu,
            shuffle=False,
            num_workers=self.cfg.batch_size_test_per_gpu,
            pin_memory=True,
            sampler=test_sampler,
            collate_fn=self.__class__.collate_fn,
        )
        return test_loader

    def test(self):
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch):
        # Fixed: this previously raised the collated batch instead of
        # returning it.
        return collate_fn(batch)


@TESTERS.register_module()
class SemSegTester(TesterBase):
    def test(self):
        assert self.test_loader.batch_size == 1
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")

        batch_time = AverageMeter()
        intersection_meter = AverageMeter()
        union_meter = AverageMeter()
        target_meter = AverageMeter()
        self.model.eval()

        save_path = os.path.join(self.cfg.save_path, "result")
        make_dirs(save_path)
        # create submit folder only on main process
        if (
            self.cfg.data.test.type == "ScanNetDataset"
            or self.cfg.data.test.type == "ScanNet200Dataset"
            or self.cfg.data.test.type == "ScanNetPPDataset"
        ) and comm.is_main_process():
            make_dirs(os.path.join(save_path, "submit"))
        elif (
            self.cfg.data.test.type == "SemanticKITTIDataset"
            and comm.is_main_process()
        ):
            make_dirs(os.path.join(save_path, "submit"))
        elif self.cfg.data.test.type == "NuScenesDataset" and comm.is_main_process():
            import json

            make_dirs(os.path.join(save_path, "submit", "lidarseg", "test"))
            make_dirs(os.path.join(save_path, "submit", "test"))
            submission = dict(
                meta=dict(
                    use_camera=False,
                    use_lidar=True,
                    use_radar=False,
                    use_map=False,
                    use_external=False,
                )
            )
            with open(
                os.path.join(save_path, "submit", "test", "submission.json"), "w"
            ) as f:
                json.dump(submission, f, indent=4)
        comm.synchronize()
        record = {}
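        # The loop below performs fragment-based voting: each test sample is a
        # list of augmented/cropped "fragments" of one scene, per-fragment
        # softmax probabilities are scatter-added back onto the original
        # points via the "index" mapping, and the accumulated probabilities
        # are argmaxed at the end. A minimal sketch of the accumulation step
        # (hypothetical shapes: two fragments covering three points):
        #
        #   pred = torch.zeros(3, num_classes)
        #   pred[torch.tensor([0, 1]), :] += prob_a  # fragment A covers points 0, 1
        #   pred[torch.tensor([1, 2]), :] += prob_b  # fragment B covers points 1, 2
        #   label = pred.max(1)[1]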
        # fragment inference
        for idx, data_dict in enumerate(self.test_loader):
            end = time.time()
            data_dict = data_dict[0]  # current assume batch size is 1
            fragment_list = data_dict.pop("fragment_list")
            segment = data_dict.pop("segment")
            data_name = data_dict.pop("name")
            pred_save_path = os.path.join(save_path, "{}_pred.npy".format(data_name))
            if os.path.isfile(pred_save_path):
                logger.info(
                    "{}/{}: {}, loaded pred and label.".format(
                        idx + 1, len(self.test_loader), data_name
                    )
                )
                pred = np.load(pred_save_path)
                if "origin_segment" in data_dict.keys():
                    segment = data_dict["origin_segment"]
            else:
                pred = torch.zeros((segment.size, self.cfg.data.num_classes)).cuda()
                for i in range(len(fragment_list)):
                    fragment_batch_size = 1
                    s_i, e_i = i * fragment_batch_size, min(
                        (i + 1) * fragment_batch_size, len(fragment_list)
                    )
                    input_dict = collate_fn(fragment_list[s_i:e_i])
                    for key in input_dict.keys():
                        if isinstance(input_dict[key], torch.Tensor):
                            input_dict[key] = input_dict[key].cuda(non_blocking=True)
                    idx_part = input_dict["index"]
                    with torch.no_grad():
                        pred_part = self.model(input_dict)["seg_logits"]  # (n, k)
                        pred_part = F.softmax(pred_part, -1)
                        if self.cfg.empty_cache:
                            torch.cuda.empty_cache()
                        bs = 0
                        for be in input_dict["offset"]:
                            pred[idx_part[bs:be], :] += pred_part[bs:be]
                            bs = be
                    logger.info(
                        "Test: {}/{}-{data_name}, Batch: {batch_idx}/{batch_num}".format(
                            idx + 1,
                            len(self.test_loader),
                            data_name=data_name,
                            batch_idx=i,
                            batch_num=len(fragment_list),
                        )
                    )
                if self.cfg.data.test.type == "ScanNetPPDataset":
                    pred = pred.topk(3, dim=1)[1].data.cpu().numpy()
                else:
                    pred = pred.max(1)[1].data.cpu().numpy()
                if "origin_segment" in data_dict.keys():
                    assert "inverse" in data_dict.keys()
                    pred = pred[data_dict["inverse"]]
                    segment = data_dict["origin_segment"]
                np.save(pred_save_path, pred)
            if (
                self.cfg.data.test.type == "ScanNetDataset"
                or self.cfg.data.test.type == "ScanNet200Dataset"
            ):
                np.savetxt(
                    os.path.join(save_path, "submit", "{}.txt".format(data_name)),
                    self.test_loader.dataset.class2id[pred].reshape([-1, 1]),
                    fmt="%d",
                )
            elif self.cfg.data.test.type == "ScanNetPPDataset":
                np.savetxt(
                    os.path.join(save_path, "submit", "{}.txt".format(data_name)),
                    pred.astype(np.int32),
                    delimiter=",",
                    fmt="%d",
                )
                pred = pred[:, 0]  # for mIoU, TODO: support top3 mIoU
            elif self.cfg.data.test.type == "SemanticKITTIDataset":
                # 00_000000 -> 00, 000000
                sequence_name, frame_name = data_name.split("_")
                os.makedirs(
                    os.path.join(
                        save_path, "submit", "sequences", sequence_name, "predictions"
                    ),
                    exist_ok=True,
                )
                submit = pred.astype(np.uint32)
                submit = np.vectorize(
                    self.test_loader.dataset.learning_map_inv.__getitem__
                )(submit).astype(np.uint32)
                submit.tofile(
                    os.path.join(
                        save_path,
                        "submit",
                        "sequences",
                        sequence_name,
                        "predictions",
                        f"{frame_name}.label",
                    )
                )
            elif self.cfg.data.test.type == "NuScenesDataset":
                np.array(pred + 1).astype(np.uint8).tofile(
                    os.path.join(
                        save_path,
                        "submit",
                        "lidarseg",
                        "test",
                        "{}_lidarseg.bin".format(data_name),
                    )
                )

            intersection, union, target = intersection_and_union(
                pred, segment, self.cfg.data.num_classes, self.cfg.data.ignore_index
            )
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)
            record[data_name] = dict(
                intersection=intersection, union=union, target=target
            )

            mask = union != 0
            iou_class = intersection / (union + 1e-10)
            iou = np.mean(iou_class[mask])
            acc = sum(intersection) / (sum(target) + 1e-10)

            m_iou = np.mean(intersection_meter.sum / (union_meter.sum + 1e-10))
            m_acc = np.mean(intersection_meter.sum / (target_meter.sum + 1e-10))

            batch_time.update(time.time() - end)
            logger.info(
                "Test: {} [{}/{}]-{} "
                "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                "Accuracy {acc:.4f} ({m_acc:.4f}) "
                "mIoU {iou:.4f} ({m_iou:.4f})".format(
                    data_name,
                    idx + 1,
                    len(self.test_loader),
                    segment.size,
                    batch_time=batch_time,
                    acc=acc,
                    m_acc=m_acc,
                    iou=iou,
                    m_iou=m_iou,
                )
            )

        logger.info("Syncing ...")
        comm.synchronize()
        record_sync = comm.gather(record, dst=0)

        if comm.is_main_process():
            record = {}
            for _ in range(len(record_sync)):
                r = record_sync.pop()
                record.update(r)
                del r
            intersection = np.sum(
                [meters["intersection"] for _, meters in record.items()], axis=0
            )
            union = np.sum([meters["union"] for _, meters in record.items()], axis=0)
            target = np.sum([meters["target"] for _, meters in record.items()], axis=0)

            if self.cfg.data.test.type == "S3DISDataset":
                torch.save(
                    dict(intersection=intersection, union=union, target=target),
                    os.path.join(save_path, f"{self.test_loader.dataset.split}.pth"),
                )

            iou_class = intersection / (union + 1e-10)
            accuracy_class = intersection / (target + 1e-10)
            mIoU = np.mean(iou_class)
            mAcc = np.mean(accuracy_class)
            allAcc = sum(intersection) / (sum(target) + 1e-10)

            logger.info(
                "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}".format(
                    mIoU, mAcc, allAcc
                )
            )
            for i in range(self.cfg.data.num_classes):
                logger.info(
                    "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
                        idx=i,
                        name=self.cfg.data.names[i],
                        iou=iou_class[i],
                        accuracy=accuracy_class[i],
                    )
                )
            logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")

    @staticmethod
    def collate_fn(batch):
        return batch

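# How the gathered counts above turn into the reported metrics (a sketch of
# the arithmetic only; the array values are hypothetical):
#
#   intersection = np.array([8, 0, 3])   # per-class true positives
#   union = np.array([10, 0, 4])         # per-class TP + FP + FN
#   target = np.array([9, 0, 4])         # per-class ground-truth counts
#   iou_class = intersection / (union + 1e-10)   # [0.8, 0.0, 0.75]
#   mIoU = np.mean(iou_class)
#   allAcc = intersection.sum() / (target.sum() + 1e-10)
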
@TESTERS.register_module()
class ClsTester(TesterBase):
    def test(self):
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        batch_time = AverageMeter()
        intersection_meter = AverageMeter()
        union_meter = AverageMeter()
        target_meter = AverageMeter()
        self.model.eval()

        for i, input_dict in enumerate(self.test_loader):
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            end = time.time()
            with torch.no_grad():
                output_dict = self.model(input_dict)
            output = output_dict["cls_logits"]
            pred = output.max(1)[1]
            label = input_dict["category"]
            intersection, union, target = intersection_and_union_gpu(
                pred, label, self.cfg.data.num_classes, self.cfg.data.ignore_index
            )
            if comm.get_world_size() > 1:
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)
            intersection, union, target = (
                intersection.cpu().numpy(),
                union.cpu().numpy(),
                target.cpu().numpy(),
            )
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)

            accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
            batch_time.update(time.time() - end)
            logger.info(
                "Test: [{}/{}] "
                "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                "Accuracy {accuracy:.4f} ".format(
                    i + 1,
                    len(self.test_loader),
                    batch_time=batch_time,
                    accuracy=accuracy,
                )
            )

        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        mIoU = np.mean(iou_class)
        mAcc = np.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
        logger.info(
            "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
                mIoU, mAcc, allAcc
            )
        )

        for i in range(self.cfg.data.num_classes):
            logger.info(
                "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
                    idx=i,
                    name=self.cfg.data.names[i],
                    iou=iou_class[i],
                    accuracy=accuracy_class[i],
                )
            )
        logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")

    @staticmethod
    def collate_fn(batch):
        return collate_fn(batch)

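# In the multi-GPU branch above, dist.all_reduce sums each metric tensor in
# place across ranks, so every process logs identical running statistics.
# A sketch with hypothetical two-rank values:
#
#   # rank 0 holds intersection = tensor([3, 1]); rank 1 holds tensor([2, 2])
#   dist.all_reduce(intersection)  # afterwards both ranks hold tensor([5, 3])
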
@TESTERS.register_module()
class ClsVotingTester(TesterBase):
    def __init__(
        self,
        num_repeat=100,
        metric="allAcc",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_repeat = num_repeat
        self.metric = metric
        self.best_idx = 0
        self.best_record = None
        self.best_metric = 0

    def test(self):
        for i in range(self.num_repeat):
            logger = get_root_logger()
            logger.info(f">>>>>>>>>>>>>>>> Start Evaluation {i + 1} >>>>>>>>>>>>>>>>")
            record = self.test_once()
            if comm.is_main_process():
                if record[self.metric] > self.best_metric:
                    self.best_record = record
                    self.best_idx = i
                    self.best_metric = record[self.metric]
                # Report the run that actually holds the best record, not the
                # current iteration.
                info = f"Current best record is Evaluation {self.best_idx + 1}: "
                for m in self.best_record.keys():
                    info += f"{m}: {self.best_record[m]:.4f} "
                logger.info(info)

    def test_once(self):
        logger = get_root_logger()
        batch_time = AverageMeter()
        intersection_meter = AverageMeter()
        target_meter = AverageMeter()
        record = {}
        self.model.eval()

        for idx, data_dict in enumerate(self.test_loader):
            end = time.time()
            data_dict = data_dict[0]  # current assume batch size is 1
            voting_list = data_dict.pop("voting_list")
            category = data_dict.pop("category")
            data_name = data_dict.pop("name")
            # Sequential alternative to the batched voting below:
            # pred = torch.zeros([1, self.cfg.data.num_classes]).cuda()
            # for i in range(len(voting_list)):
            #     input_dict = voting_list[i]
            #     for key in input_dict.keys():
            #         if isinstance(input_dict[key], torch.Tensor):
            #             input_dict[key] = input_dict[key].cuda(non_blocking=True)
            #     with torch.no_grad():
            #         pred += F.softmax(self.model(input_dict)["cls_logits"], -1)
            input_dict = collate_fn(voting_list)
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                pred = F.softmax(self.model(input_dict)["cls_logits"], -1).sum(
                    0, keepdim=True
                )
            pred = pred.max(1)[1].cpu().numpy()
            intersection, union, target = intersection_and_union(
                pred, category, self.cfg.data.num_classes, self.cfg.data.ignore_index
            )
            intersection_meter.update(intersection)
            target_meter.update(target)
            record[data_name] = dict(intersection=intersection, target=target)

            acc = sum(intersection) / (sum(target) + 1e-10)
            m_acc = np.mean(intersection_meter.sum / (target_meter.sum + 1e-10))
            batch_time.update(time.time() - end)
            logger.info(
                "Test: {} [{}/{}] "
                "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                "Accuracy {acc:.4f} ({m_acc:.4f}) ".format(
                    data_name,
                    idx + 1,
                    len(self.test_loader),
                    batch_time=batch_time,
                    acc=acc,
                    m_acc=m_acc,
                )
            )

        logger.info("Syncing ...")
        comm.synchronize()
        record_sync = comm.gather(record, dst=0)

        if comm.is_main_process():
            record = {}
            for _ in range(len(record_sync)):
                r = record_sync.pop()
                record.update(r)
                del r
            intersection = np.sum(
                [meters["intersection"] for _, meters in record.items()], axis=0
            )
            target = np.sum([meters["target"] for _, meters in record.items()], axis=0)
            accuracy_class = intersection / (target + 1e-10)
            mAcc = np.mean(accuracy_class)
            allAcc = sum(intersection) / (sum(target) + 1e-10)
            logger.info("Val result: mAcc/allAcc {:.4f}/{:.4f}".format(mAcc, allAcc))

            for i in range(self.cfg.data.num_classes):
                logger.info(
                    "Class_{idx} - {name} Result: accuracy {accuracy:.4f}".format(
                        idx=i,
                        name=self.cfg.data.names[i],
                        accuracy=accuracy_class[i],
                    )
                )
            # Non-main ranks implicitly return None; test() only reads the
            # record on the main process.
            return dict(mAcc=mAcc, allAcc=allAcc)

    @staticmethod
    def collate_fn(batch):
        return batch

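# ClsVotingTester sums softmax probabilities over all augmented views of an
# object before taking the argmax, so confident views carry more weight than
# a majority vote over hard labels would give them. A sketch with
# hypothetical logits for two views and three classes:
#
#   logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
#   pred = F.softmax(logits, -1).sum(0, keepdim=True)  # (1, 3)
#   label = pred.max(1)[1]
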
@TESTERS.register_module()
class PartSegTester(TesterBase):
    def test(self):
        test_dataset = self.test_loader.dataset
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")

        batch_time = AverageMeter()

        num_categories = len(self.test_loader.dataset.categories)
        iou_category, iou_count = np.zeros(num_categories), np.zeros(num_categories)
        self.model.eval()

        save_path = os.path.join(
            self.cfg.save_path, "result", "test_epoch{}".format(self.cfg.test_epoch)
        )
        make_dirs(save_path)

        for idx in range(len(test_dataset)):
            end = time.time()
            data_name = test_dataset.get_data_name(idx)

            data_dict_list, label = test_dataset[idx]
            pred = torch.zeros((label.size, self.cfg.data.num_classes)).cuda()
            batch_num = int(np.ceil(len(data_dict_list) / self.cfg.batch_size_test))
            for i in range(batch_num):
                s_i, e_i = i * self.cfg.batch_size_test, min(
                    (i + 1) * self.cfg.batch_size_test, len(data_dict_list)
                )
                input_dict = collate_fn(data_dict_list[s_i:e_i])
                for key in input_dict.keys():
                    if isinstance(input_dict[key], torch.Tensor):
                        input_dict[key] = input_dict[key].cuda(non_blocking=True)
                with torch.no_grad():
                    pred_part = self.model(input_dict)["cls_logits"]
                    pred_part = F.softmax(pred_part, -1)
                if self.cfg.empty_cache:
                    torch.cuda.empty_cache()
                pred_part = pred_part.reshape(-1, label.size, self.cfg.data.num_classes)
                # Fixed: tensors have no .total() method; sum the votes over
                # augmented views instead.
                pred = pred + pred_part.sum(dim=0)
                logger.info(
                    "Test: {} {}/{}, Batch: {batch_idx}/{batch_num}".format(
                        data_name,
                        idx + 1,
                        len(test_dataset),
                        batch_idx=i,
                        batch_num=batch_num,
                    )
                )
            pred = pred.max(1)[1].data.cpu().numpy()

            category_index = data_dict_list[0]["cls_token"]
            category = self.test_loader.dataset.categories[category_index]
            parts_idx = self.test_loader.dataset.category2part[category]
            parts_iou = np.zeros(len(parts_idx))
            for j, part in enumerate(parts_idx):
                if (np.sum(label == part) == 0) and (np.sum(pred == part) == 0):
                    parts_iou[j] = 1.0
                else:
                    i = (label == part) & (pred == part)
                    u = (label == part) | (pred == part)
                    parts_iou[j] = np.sum(i) / (np.sum(u) + 1e-10)
            iou_category[category_index] += parts_iou.mean()
            iou_count[category_index] += 1

            batch_time.update(time.time() - end)
            logger.info(
                "Test: {} [{}/{}] "
                "Batch {batch_time.val:.3f} "
                "({batch_time.avg:.3f}) ".format(
                    data_name, idx + 1, len(self.test_loader), batch_time=batch_time
                )
            )

        ins_mIoU = iou_category.sum() / (iou_count.sum() + 1e-10)
        cat_mIoU = (iou_category / (iou_count + 1e-10)).mean()
        logger.info(
            "Val result: ins.mIoU/cat.mIoU {:.4f}/{:.4f}.".format(ins_mIoU, cat_mIoU)
        )
        for i in range(num_categories):
            logger.info(
                "Class_{idx}-{name} Result: iou_cat/num_sample {iou_cat:.4f}/{iou_count:d}".format(
                    idx=i,
                    name=self.test_loader.dataset.categories[i],
                    iou_cat=iou_category[i] / (iou_count[i] + 1e-10),
                    iou_count=int(iou_count[i]),
                )
            )
        logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")

    @staticmethod
    def collate_fn(batch):
        return collate_fn(batch)
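
# Typical entry point (a sketch, assuming the mmcv-style Registry.build
# interface of pointcept.utils.registry and a fully populated config `cfg`;
# the actual launch script and module path may differ):
#
#   from pointcept.engines.test import TESTERS
#   tester = TESTERS.build(dict(type="SemSegTester", cfg=cfg))
#   tester.test()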