|
r""" Logging during training/testing """ |
|
import datetime |
|
import logging |
|
import os |
|
|
|
from tensorboardX import SummaryWriter |
|
import torch |
|
|
|
|
|
class AverageMeter:
    r""" Stores loss, evaluation results """

    # Number of classes allocated in the buffers for each supported benchmark.
    NCLASS_BY_BENCHMARK = {'pascal': 20, 'coco': 80, 'fss': 1000}

    def __init__(self, dataset):
        r""" Allocates zeroed intersection/union buffers for `dataset`.

        Args:
            dataset: object exposing `benchmark` (one of 'pascal', 'coco',
                'fss') and `class_ids` (indices of the classes to evaluate).

        Raises:
            ValueError: if `dataset.benchmark` is not a supported benchmark.
        """
        self.benchmark = dataset.benchmark
        if self.benchmark not in self.NCLASS_BY_BENCHMARK:
            raise ValueError(
                'Unsupported benchmark %r (expected one of %s)'
                % (self.benchmark, sorted(self.NCLASS_BY_BENCHMARK)))
        self.nclass = self.NCLASS_BY_BENCHMARK[self.benchmark]

        # Use the GPU when there is one; otherwise fall back to the CPU so the
        # meter stays usable on machines without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.class_ids_interest = torch.as_tensor(dataset.class_ids, device=device)

        # Both buffers are [2, nclass]; compute_iou() reads row 1 for mIoU and
        # averages both rows for FB-IoU.
        self.intersection_buf = torch.zeros([2, self.nclass], dtype=torch.float32, device=device)
        self.union_buf = torch.zeros([2, self.nclass], dtype=torch.float32, device=device)
        # Not read by this class any more; kept because it is a public attribute.
        self.ones = torch.ones_like(self.union_buf)
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        r""" Accumulates one batch of intersection/union areas and its loss.

        Args:
            inter_b: intersection areas, added along dim 1 of the [2, nclass]
                buffer, i.e. shaped [2, len(class_id)].
            union_b: union areas, same layout as `inter_b`.
            class_id: integer tensor with the class index of every sample.
            loss: loss of the batch, or None (recorded as 0.0).
        """
        device = self.intersection_buf.device
        class_id = class_id.to(device)
        self.intersection_buf.index_add_(1, class_id, inter_b.float().to(device))
        self.union_buf.index_add_(1, class_id, union_b.float().to(device))

        if loss is None:
            loss = torch.zeros((), device=device)
        else:
            # Detach so the buffer never keeps an autograd graph alive, and
            # keep every entry on one device so torch.stack cannot fail.
            loss = torch.as_tensor(loss, device=device).detach()
        self.loss_buf.append(loss)

    def compute_iou(self):
        r""" Returns (mIoU, FB-IoU) in percent over the classes of interest. """
        class_ids = self.class_ids_interest
        intersection = self.intersection_buf.float()

        # Clamp the denominator to 1 so classes that were never seen give an
        # IoU of 0 rather than NaN.
        iou = intersection / self.union_buf.clamp(min=1)
        iou = iou.index_select(1, class_ids)
        miou = iou[1].mean() * 100

        inter_sum = intersection.index_select(1, class_ids).sum(dim=1)
        union_sum = self.union_buf.index_select(1, class_ids).sum(dim=1)
        # Same guard as above: an empty meter reports 0 instead of NaN.
        fb_iou = (inter_sum / union_sum.clamp(min=1)).mean() * 100

        return miou, fb_iou

    def _loss_stats(self):
        r""" Returns (last loss, mean loss) as floats; NaN when nothing is recorded. """
        if not self.loss_buf:
            return float('nan'), float('nan')
        losses = torch.stack(self.loss_buf)
        return float(losses[-1]), float(losses.mean())

    def write_result(self, split, epoch):
        r""" Logs the average loss, mIoU and FB-IoU accumulated so far.

        Args:
            split: label of the data split, printed at the start of the line.
            epoch: epoch number printed in the message.
        """
        iou, fb_iou = self.compute_iou()
        _, avg_loss = self._loss_stats()

        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f ' % avg_loss
        msg += 'mIoU: %5.2f ' % iou
        msg += 'FB-IoU: %5.2f ' % fb_iou

        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        r""" Logs running statistics every `write_batch_idx` batches.

        Args:
            batch_idx: zero-based index of the current batch.
            datalen: total number of batches, used only in the message.
            epoch: current epoch; -1 omits the epoch and loss fields.
            write_batch_idx: logging period in batches.

        Returns:
            (mIoU, FB-IoU) on batches that are logged, None on all others.
        """
        if batch_idx % write_batch_idx != 0:
            return None

        msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
        msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
        iou, fb_iou = self.compute_iou()
        if epoch != -1:
            last_loss, avg_loss = self._loss_stats()
            msg += 'L: %6.5f ' % last_loss
            msg += 'Avg L: %6.5f ' % avg_loss
        msg += 'mIoU: %5.2f | ' % iou
        msg += 'FB-IoU: %5.2f' % fb_iou
        Logger.info(msg)
        return iou, fb_iou
|
|
|
|
|
class Logger:
    r""" Writes evaluation results of training/testing """

    @classmethod
    def initialize(cls, args, training):
        r""" Creates the log directory and sets up file, console and tensorboard logging.

        Args:
            args: namespace providing `logpath`, `load` and `benchmark`; all
                of its attributes are dumped to the log.
            training: if True the directory is named after `args.logpath`,
                otherwise after the parent directory of `args.load` plus a
                timestamp.
        """
        logtime = datetime.datetime.now().strftime('_%m%d_%H%M%S')
        if training:
            logpath = args.logpath
        else:
            # Name test runs after the directory holding the checkpoint; fall
            # back to an empty name when the path has no directory component.
            path_parts = args.load.split('/')
            ckpt_dir = path_parts[-2] if len(path_parts) > 1 else ''
            logpath = '_TEST_' + ckpt_dir.split('.')[0] + logtime
        if logpath == '':
            logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(cls.logpath, exist_ok=True)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter('%(message)s'))
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:=========== Few-shot Seg. with HSNet ===========')
        for arg_key, arg_value in vars(args).items():
            logging.info('| %20s: %-24s', arg_key, str(arg_value))
        logging.info(':================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes log message to log.txt """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        r""" Saves `model.state_dict()` as best_model.pt in the log directory and logs it.

        Args:
            model: module whose state dict is saved.
            epoch: epoch number reported in the log message.
            val_miou: validation mIoU reported in the log message.
        """
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))

    @classmethod
    def log_params(cls, model):
        r""" Logs the number of backbone, learnable and total state-dict elements.

        Entries whose first key component is exactly 'backbone' count as
        backbone, except 'backbone.classifier.*' and 'backbone.fc.*', which
        are skipped entirely. Every other entry counts as learnable.

        Args:
            model: module whose `state_dict()` entries are counted.
        """
        backbone_param = 0
        learner_param = 0
        # Build the state dict once instead of once per key.
        for key, tensor in model.state_dict().items():
            n_param = tensor.numel()
            key_parts = key.split('.')
            # Exact match: a substring test would also accept e.g. 'back' or 'bone'.
            if key_parts[0] == 'backbone':
                if len(key_parts) > 1 and key_parts[1] in ('classifier', 'fc'):
                    continue
                backbone_param += n_param
            else:
                learner_param += n_param
        cls.info('Backbone # param.: %d' % backbone_param)
        cls.info('Learnable # param.: %d' % learner_param)
        cls.info('Total # param.: %d' % (backbone_param + learner_param))
|
|
|
|