""" | |
Tester | |
Author: Xiaoyang Wu ([email protected]) | |
Please cite our work if the code is helpful to you. | |
""" | |
import os | |
import time | |
import numpy as np | |
from collections import OrderedDict | |
import torch | |
import torch.distributed as dist | |
import torch.nn.functional as F | |
import torch.utils.data | |
from .defaults import create_ddp_model | |
import pointcept.utils.comm as comm | |
from pointcept.datasets import build_dataset, collate_fn | |
from pointcept.models import build_model | |
from pointcept.utils.logger import get_root_logger | |
from pointcept.utils.registry import Registry | |
from pointcept.utils.misc import ( | |
AverageMeter, | |
intersection_and_union, | |
intersection_and_union_gpu, | |
make_dirs, | |
) | |
TESTERS = Registry("testers") | |
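
# Usage sketch (hedged): testers are looked up in this registry by name.
# Assuming the standard Pointcept Registry API, construction looks roughly
# like:
#
#   tester = TESTERS.build(dict(type="SemSegTester", cfg=cfg))
#   tester.test()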


class TesterBase:
    def __init__(self, cfg, model=None, test_loader=None, verbose=False) -> None:
        torch.multiprocessing.set_sharing_strategy("file_system")
        self.logger = get_root_logger(
            log_file=os.path.join(cfg.save_path, "test.log"),
            file_mode="a" if cfg.resume else "w",
        )
        self.logger.info("=> Loading config ...")
        self.cfg = cfg
        self.verbose = verbose
        if self.verbose:
            self.logger.info(f"Save path: {cfg.save_path}")
            self.logger.info(f"Config:\n{cfg.pretty_text}")
        if model is None:
            self.logger.info("=> Building model ...")
            self.model = self.build_model()
        else:
            self.model = model
        if test_loader is None:
            self.logger.info("=> Building test dataset & dataloader ...")
            self.test_loader = self.build_test_loader()
        else:
            self.test_loader = test_loader

    def build_model(self):
        model = build_model(self.cfg.model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        if os.path.isfile(self.cfg.weight):
            self.logger.info(f"Loading weight at: {self.cfg.weight}")
            checkpoint = torch.load(self.cfg.weight)
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if key.startswith("module."):
                    if comm.get_world_size() == 1:
                        key = key[7:]  # module.xxx.xxx -> xxx.xxx
                else:
                    if comm.get_world_size() > 1:
                        key = "module." + key  # xxx.xxx -> module.xxx.xxx
                weight[key] = value
            model.load_state_dict(weight, strict=True)
            self.logger.info(
                "=> Loaded weight '{}' (epoch {})".format(
                    self.cfg.weight, checkpoint["epoch"]
                )
            )
        else:
            raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
        return model
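
    # Note on the key remap in build_model: it reconciles checkpoints saved
    # with and without DDP wrapping. A "module."-prefixed key is stripped for
    # single-process runs, and the prefix is added when a plain checkpoint is
    # loaded into a DDP-wrapped model, so load_state_dict can stay strict.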

    def build_test_loader(self):
        test_dataset = build_dataset(self.cfg.data.test)
        if comm.get_world_size() > 1:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
        else:
            test_sampler = None
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.cfg.batch_size_test_per_gpu,
            shuffle=False,
            num_workers=self.cfg.batch_size_test_per_gpu,
            pin_memory=True,
            sampler=test_sampler,
            collate_fn=self.__class__.collate_fn,
        )
        return test_loader
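
    # Note: with world_size > 1, the DistributedSampler shards the test set
    # across ranks, so each process evaluates only its own slice; per-sample
    # records are merged back on rank 0 via comm.gather in the testers below.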

    def test(self):
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch):
        # Delegate to the module-level collate_fn imported from
        # pointcept.datasets (the class attribute does not shadow it here).
        return collate_fn(batch)
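
# A minimal sketch of adding a new tester (hypothetical subclass, shown only
# to illustrate the registry + collate_fn contract used by the classes below):
#
#   @TESTERS.register_module()
#   class MyTester(TesterBase):
#       def test(self):
#           ...  # iterate self.test_loader, run self.model, log metrics
#
#       @staticmethod
#       def collate_fn(batch):
#           return batch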


@TESTERS.register_module()
class SemSegTester(TesterBase):
    def test(self):
        assert self.test_loader.batch_size == 1
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")

        batch_time = AverageMeter()
        intersection_meter = AverageMeter()
        union_meter = AverageMeter()
        target_meter = AverageMeter()
        self.model.eval()

        save_path = os.path.join(self.cfg.save_path, "result")
        make_dirs(save_path)
        # create submit folder only on main process
        if (
            self.cfg.data.test.type == "ScanNetDataset"
            or self.cfg.data.test.type == "ScanNet200Dataset"
            or self.cfg.data.test.type == "ScanNetPPDataset"
        ) and comm.is_main_process():
            make_dirs(os.path.join(save_path, "submit"))
        elif (
            self.cfg.data.test.type == "SemanticKITTIDataset" and comm.is_main_process()
        ):
            make_dirs(os.path.join(save_path, "submit"))
        elif self.cfg.data.test.type == "NuScenesDataset" and comm.is_main_process():
            import json

            make_dirs(os.path.join(save_path, "submit", "lidarseg", "test"))
            make_dirs(os.path.join(save_path, "submit", "test"))
            submission = dict(
                meta=dict(
                    use_camera=False,
                    use_lidar=True,
                    use_radar=False,
                    use_map=False,
                    use_external=False,
                )
            )
            with open(
                os.path.join(save_path, "submit", "test", "submission.json"), "w"
            ) as f:
                json.dump(submission, f, indent=4)
        comm.synchronize()
        record = {}
        # fragment inference
        for idx, data_dict in enumerate(self.test_loader):
            end = time.time()
            data_dict = data_dict[0]  # currently assumes batch size is 1
            fragment_list = data_dict.pop("fragment_list")
            segment = data_dict.pop("segment")
            data_name = data_dict.pop("name")
            pred_save_path = os.path.join(save_path, "{}_pred.npy".format(data_name))
            if os.path.isfile(pred_save_path):
                logger.info(
                    "{}/{}: {}, loaded pred and label.".format(
                        idx + 1, len(self.test_loader), data_name
                    )
                )
                pred = np.load(pred_save_path)
                if "origin_segment" in data_dict.keys():
                    segment = data_dict["origin_segment"]
            else:
                pred = torch.zeros((segment.size, self.cfg.data.num_classes)).cuda()
                for i in range(len(fragment_list)):
                    fragment_batch_size = 1
                    s_i, e_i = i * fragment_batch_size, min(
                        (i + 1) * fragment_batch_size, len(fragment_list)
                    )
                    input_dict = collate_fn(fragment_list[s_i:e_i])
                    for key in input_dict.keys():
                        if isinstance(input_dict[key], torch.Tensor):
                            input_dict[key] = input_dict[key].cuda(non_blocking=True)
                    idx_part = input_dict["index"]
                    with torch.no_grad():
                        pred_part = self.model(input_dict)["seg_logits"]  # (n, k)
                        pred_part = F.softmax(pred_part, -1)
                        if self.cfg.empty_cache:
                            torch.cuda.empty_cache()
                        bs = 0
                        for be in input_dict["offset"]:
                            pred[idx_part[bs:be], :] += pred_part[bs:be]
                            bs = be
                    logger.info(
                        "Test: {}/{}-{data_name}, Batch: {batch_idx}/{batch_num}".format(
                            idx + 1,
                            len(self.test_loader),
                            data_name=data_name,
                            batch_idx=i,
                            batch_num=len(fragment_list),
                        )
                    )
                if self.cfg.data.test.type == "ScanNetPPDataset":
                    pred = pred.topk(3, dim=1)[1].data.cpu().numpy()
                else:
                    pred = pred.max(1)[1].data.cpu().numpy()
                if "origin_segment" in data_dict.keys():
                    assert "inverse" in data_dict.keys()
                    pred = pred[data_dict["inverse"]]
                    segment = data_dict["origin_segment"]
                np.save(pred_save_path, pred)
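            # Predictions are cached above, so re-running the evaluation skips
            # inference for any scene whose *_pred.npy already exists.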
            if (
                self.cfg.data.test.type == "ScanNetDataset"
                or self.cfg.data.test.type == "ScanNet200Dataset"
            ):
                np.savetxt(
                    os.path.join(save_path, "submit", "{}.txt".format(data_name)),
                    self.test_loader.dataset.class2id[pred].reshape([-1, 1]),
                    fmt="%d",
                )
            elif self.cfg.data.test.type == "ScanNetPPDataset":
                np.savetxt(
                    os.path.join(save_path, "submit", "{}.txt".format(data_name)),
                    pred.astype(np.int32),
                    delimiter=",",
                    fmt="%d",
                )
                pred = pred[:, 0]  # for mIoU, TODO: support top3 mIoU
            elif self.cfg.data.test.type == "SemanticKITTIDataset":
                # 00_000000 -> 00, 000000
                sequence_name, frame_name = data_name.split("_")
                os.makedirs(
                    os.path.join(
                        save_path, "submit", "sequences", sequence_name, "predictions"
                    ),
                    exist_ok=True,
                )
                submit = pred.astype(np.uint32)
                submit = np.vectorize(
                    self.test_loader.dataset.learning_map_inv.__getitem__
                )(submit).astype(np.uint32)
                submit.tofile(
                    os.path.join(
                        save_path,
                        "submit",
                        "sequences",
                        sequence_name,
                        "predictions",
                        f"{frame_name}.label",
                    )
                )
            elif self.cfg.data.test.type == "NuScenesDataset":
                np.array(pred + 1).astype(np.uint8).tofile(
                    os.path.join(
                        save_path,
                        "submit",
                        "lidarseg",
                        "test",
                        "{}_lidarseg.bin".format(data_name),
                    )
                )

            intersection, union, target = intersection_and_union(
                pred, segment, self.cfg.data.num_classes, self.cfg.data.ignore_index
            )
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)
            record[data_name] = dict(
                intersection=intersection, union=union, target=target
            )

            mask = union != 0
            iou_class = intersection / (union + 1e-10)
            iou = np.mean(iou_class[mask])
            acc = sum(intersection) / (sum(target) + 1e-10)

            m_iou = np.mean(intersection_meter.sum / (union_meter.sum + 1e-10))
            m_acc = np.mean(intersection_meter.sum / (target_meter.sum + 1e-10))

            batch_time.update(time.time() - end)
            logger.info(
                "Test: {} [{}/{}]-{} "
                "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                "Accuracy {acc:.4f} ({m_acc:.4f}) "
                "mIoU {iou:.4f} ({m_iou:.4f})".format(
                    data_name,
                    idx + 1,
                    len(self.test_loader),
                    segment.size,
                    batch_time=batch_time,
                    acc=acc,
                    m_acc=m_acc,
                    iou=iou,
                    m_iou=m_iou,
                )
            )
logger.info("Syncing ...") | |
comm.synchronize() | |
record_sync = comm.gather(record, dst=0) | |
if comm.is_main_process(): | |
record = {} | |
for _ in range(len(record_sync)): | |
r = record_sync.pop() | |
record.update(r) | |
del r | |
intersection = np.sum( | |
[meters["intersection"] for _, meters in record.items()], axis=0 | |
) | |
union = np.sum([meters["union"] for _, meters in record.items()], axis=0) | |
target = np.sum([meters["target"] for _, meters in record.items()], axis=0) | |
if self.cfg.data.test.type == "S3DISDataset": | |
torch.save( | |
dict(intersection=intersection, union=union, target=target), | |
os.path.join(save_path, f"{self.test_loader.dataset.split}.pth"), | |
) | |
iou_class = intersection / (union + 1e-10) | |
accuracy_class = intersection / (target + 1e-10) | |
mIoU = np.mean(iou_class) | |
mAcc = np.mean(accuracy_class) | |
allAcc = sum(intersection) / (sum(target) + 1e-10) | |
logger.info( | |
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}".format( | |
mIoU, mAcc, allAcc | |
) | |
) | |
for i in range(self.cfg.data.num_classes): | |
logger.info( | |
"Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( | |
idx=i, | |
name=self.cfg.data.names[i], | |
iou=iou_class[i], | |
accuracy=accuracy_class[i], | |
) | |
) | |
logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") | |

    @staticmethod
    def collate_fn(batch):
        # Fragments are collated manually inside test(); the loader itself
        # passes batches through unchanged.
        return batch
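
# The fragment loop in SemSegTester.test implements softmax-vote test-time
# aggregation: a point can appear in several fragments, and its per-class
# scores are accumulated before the final argmax. A standalone sketch of the
# idea (hypothetical shapes, not tied to the loader above):
#
#   scores = torch.zeros(num_points, num_classes)
#   for frag_index, frag_logits in fragments:  # (n,), (n, num_classes)
#       scores[frag_index] += F.softmax(frag_logits, dim=-1)
#   labels = scores.argmax(dim=1)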


@TESTERS.register_module()
class ClsTester(TesterBase):
    def test(self):
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        batch_time = AverageMeter()
        intersection_meter = AverageMeter()
        union_meter = AverageMeter()
        target_meter = AverageMeter()
        self.model.eval()

        for i, input_dict in enumerate(self.test_loader):
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            end = time.time()
            with torch.no_grad():
                output_dict = self.model(input_dict)
            output = output_dict["cls_logits"]
            pred = output.max(1)[1]
            label = input_dict["category"]
            intersection, union, target = intersection_and_union_gpu(
                pred, label, self.cfg.data.num_classes, self.cfg.data.ignore_index
            )
            if comm.get_world_size() > 1:
                # Sum the per-class counts across ranks so every process
                # logs identical running metrics.
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)
            intersection = intersection.cpu().numpy()
            union = union.cpu().numpy()
            target = target.cpu().numpy()
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)

            accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
            batch_time.update(time.time() - end)
            logger.info(
                "Test: [{}/{}] "
                "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                "Accuracy {accuracy:.4f} ".format(
                    i + 1,
                    len(self.test_loader),
                    batch_time=batch_time,
                    accuracy=accuracy,
                )
            )

        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        mIoU = np.mean(iou_class)
        mAcc = np.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
        logger.info(
            "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
                mIoU, mAcc, allAcc
            )
        )

        for i in range(self.cfg.data.num_classes):
            logger.info(
                "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
                    idx=i,
                    name=self.cfg.data.names[i],
                    iou=iou_class[i],
                    accuracy=accuracy_class[i],
                )
            )
        logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")

    @staticmethod
    def collate_fn(batch):
        # Calls the module-level collate_fn from pointcept.datasets; the
        # class attribute does not shadow it inside this function body.
        return collate_fn(batch)


@TESTERS.register_module()
class ClsVotingTester(TesterBase):
    def __init__(
        self,
        num_repeat=100,
        metric="allAcc",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_repeat = num_repeat
        self.metric = metric
        self.best_idx = 0
        self.best_record = None
        self.best_metric = 0

    def test(self):
        for i in range(self.num_repeat):
            logger = get_root_logger()
            logger.info(f">>>>>>>>>>>>>>>> Start Evaluation {i + 1} >>>>>>>>>>>>>>>>")
            record = self.test_once()
            if comm.is_main_process():
                if record[self.metric] > self.best_metric:
                    self.best_record = record
                    self.best_idx = i
                    self.best_metric = record[self.metric]
                info = f"Current best record is Evaluation {self.best_idx + 1}: "
                for m in self.best_record.keys():
                    info += f"{m}: {self.best_record[m]:.4f} "
                logger.info(info)

    def test_once(self):
        logger = get_root_logger()
        batch_time = AverageMeter()
        intersection_meter = AverageMeter()
        target_meter = AverageMeter()
        record = {}
        self.model.eval()

        for idx, data_dict in enumerate(self.test_loader):
            end = time.time()
            data_dict = data_dict[0]  # currently assumes batch size is 1
            voting_list = data_dict.pop("voting_list")
            category = data_dict.pop("category")
            data_name = data_dict.pop("name")
            # Collate all voting views into one batch and sum their softmax
            # scores, instead of running each view separately:
            # pred = torch.zeros([1, self.cfg.data.num_classes]).cuda()
            # for i in range(len(voting_list)):
            #     input_dict = voting_list[i]
            #     for key in input_dict.keys():
            #         if isinstance(input_dict[key], torch.Tensor):
            #             input_dict[key] = input_dict[key].cuda(non_blocking=True)
            #     with torch.no_grad():
            #         pred += F.softmax(self.model(input_dict)["cls_logits"], -1)
            input_dict = collate_fn(voting_list)
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                pred = F.softmax(self.model(input_dict)["cls_logits"], -1).sum(
                    0, keepdim=True
                )
            pred = pred.max(1)[1].cpu().numpy()
            intersection, union, target = intersection_and_union(
                pred, category, self.cfg.data.num_classes, self.cfg.data.ignore_index
            )
            intersection_meter.update(intersection)
            target_meter.update(target)
            record[data_name] = dict(intersection=intersection, target=target)

            acc = sum(intersection) / (sum(target) + 1e-10)
            m_acc = np.mean(intersection_meter.sum / (target_meter.sum + 1e-10))
            batch_time.update(time.time() - end)
            logger.info(
                "Test: {} [{}/{}] "
                "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                "Accuracy {acc:.4f} ({m_acc:.4f}) ".format(
                    data_name,
                    idx + 1,
                    len(self.test_loader),
                    batch_time=batch_time,
                    acc=acc,
                    m_acc=m_acc,
                )
            )
logger.info("Syncing ...") | |
comm.synchronize() | |
record_sync = comm.gather(record, dst=0) | |
if comm.is_main_process(): | |
record = {} | |
for _ in range(len(record_sync)): | |
r = record_sync.pop() | |
record.update(r) | |
del r | |
intersection = np.sum( | |
[meters["intersection"] for _, meters in record.items()], axis=0 | |
) | |
target = np.sum([meters["target"] for _, meters in record.items()], axis=0) | |
accuracy_class = intersection / (target + 1e-10) | |
mAcc = np.mean(accuracy_class) | |
allAcc = sum(intersection) / (sum(target) + 1e-10) | |
logger.info("Val result: mAcc/allAcc {:.4f}/{:.4f}".format(mAcc, allAcc)) | |
for i in range(self.cfg.data.num_classes): | |
logger.info( | |
"Class_{idx} - {name} Result: iou/accuracy {accuracy:.4f}".format( | |
idx=i, | |
name=self.cfg.data.names[i], | |
accuracy=accuracy_class[i], | |
) | |
) | |
return dict(mAcc=mAcc, allAcc=allAcc) | |

    @staticmethod
    def collate_fn(batch):
        # Voting views are collated manually inside test_once(); the loader
        # passes batches through unchanged.
        return batch
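
# Voting here is test-time augmentation for classification: all augmented
# views of a sample are batched together, softmax scores are summed over the
# view axis, and the argmax of the pooled scores is the prediction. A sketch
# (hypothetical shapes):
#
#   logits = model(views)["cls_logits"]            # (num_views, num_classes)
#   pred = F.softmax(logits, -1).sum(0).argmax()   # pooled class index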


@TESTERS.register_module()
class PartSegTester(TesterBase):
    def test(self):
        test_dataset = self.test_loader.dataset
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")

        batch_time = AverageMeter()

        num_categories = len(self.test_loader.dataset.categories)
        iou_category, iou_count = np.zeros(num_categories), np.zeros(num_categories)
        self.model.eval()

        save_path = os.path.join(
            self.cfg.save_path, "result", "test_epoch{}".format(self.cfg.test_epoch)
        )
        make_dirs(save_path)

        for idx in range(len(test_dataset)):
            end = time.time()
            data_name = test_dataset.get_data_name(idx)

            data_dict_list, label = test_dataset[idx]
            pred = torch.zeros((label.size, self.cfg.data.num_classes)).cuda()
            batch_num = int(np.ceil(len(data_dict_list) / self.cfg.batch_size_test))
            for i in range(batch_num):
                s_i, e_i = i * self.cfg.batch_size_test, min(
                    (i + 1) * self.cfg.batch_size_test, len(data_dict_list)
                )
                input_dict = collate_fn(data_dict_list[s_i:e_i])
                for key in input_dict.keys():
                    if isinstance(input_dict[key], torch.Tensor):
                        input_dict[key] = input_dict[key].cuda(non_blocking=True)
                with torch.no_grad():
                    pred_part = self.model(input_dict)["cls_logits"]
                    pred_part = F.softmax(pred_part, -1)
                if self.cfg.empty_cache:
                    torch.cuda.empty_cache()
                pred_part = pred_part.reshape(-1, label.size, self.cfg.data.num_classes)
                pred = pred + pred_part.sum(dim=0)  # sum votes over augmented views
                logger.info(
                    "Test: {} {}/{}, Batch: {batch_idx}/{batch_num}".format(
                        data_name,
                        idx + 1,
                        len(test_dataset),
                        batch_idx=i,
                        batch_num=batch_num,
                    )
                )
            pred = pred.max(1)[1].data.cpu().numpy()

            category_index = data_dict_list[0]["cls_token"]
            category = self.test_loader.dataset.categories[category_index]
            parts_idx = self.test_loader.dataset.category2part[category]
            parts_iou = np.zeros(len(parts_idx))
            for j, part in enumerate(parts_idx):
                if (np.sum(label == part) == 0) and (np.sum(pred == part) == 0):
                    parts_iou[j] = 1.0
                else:
                    i = (label == part) & (pred == part)
                    u = (label == part) | (pred == part)
                    parts_iou[j] = np.sum(i) / (np.sum(u) + 1e-10)
            iou_category[category_index] += parts_iou.mean()
            iou_count[category_index] += 1

            batch_time.update(time.time() - end)
            logger.info(
                "Test: {} [{}/{}] "
                "Batch {batch_time.val:.3f} "
                "({batch_time.avg:.3f}) ".format(
                    data_name, idx + 1, len(test_dataset), batch_time=batch_time
                )
            )

        ins_mIoU = iou_category.sum() / (iou_count.sum() + 1e-10)
        cat_mIoU = (iou_category / (iou_count + 1e-10)).mean()
        logger.info(
            "Val result: ins.mIoU/cat.mIoU {:.4f}/{:.4f}.".format(ins_mIoU, cat_mIoU)
        )

        for i in range(num_categories):
            logger.info(
                "Class_{idx}-{name} Result: iou_cat/num_sample {iou_cat:.4f}/{iou_count:d}".format(
                    idx=i,
                    name=self.test_loader.dataset.categories[i],
                    iou_cat=iou_category[i] / (iou_count[i] + 1e-10),
                    iou_count=int(iou_count[i]),
                )
            )
        logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")

    @staticmethod
    def collate_fn(batch):
        # Delegates to the module-level collate_fn from pointcept.datasets.
        return collate_fn(batch)
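
# End-to-end usage sketch (hedged; field names follow the cfg attributes
# referenced above, e.g. cfg.save_path, cfg.weight, cfg.data.test.type):
#
#   cfg = ...  # a parsed Pointcept config
#   tester = TESTERS.build(dict(type="SemSegTester", cfg=cfg))
#   tester.test()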