kairunwen's picture
Update Code
57746f1
r""" Logging during training/testing """
import datetime
import logging
import os
from tensorboardX import SummaryWriter
import torch
class AverageMeter:
r""" Stores loss, evaluation results """
def __init__(self, dataset):
self.benchmark = dataset.benchmark
self.class_ids_interest = dataset.class_ids
self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()
if self.benchmark == 'pascal':
self.nclass = 20
elif self.benchmark == 'coco':
self.nclass = 80
elif self.benchmark == 'fss':
self.nclass = 1000
self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
self.ones = torch.ones_like(self.union_buf)
self.loss_buf = []
def update(self, inter_b, union_b, class_id, loss):
self.intersection_buf.index_add_(1, class_id, inter_b.float())
self.union_buf.index_add_(1, class_id, union_b.float())
if loss is None:
loss = torch.tensor(0.0)
self.loss_buf.append(loss)
def compute_iou(self):
iou = self.intersection_buf.float() / \
torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
iou = iou.index_select(1, self.class_ids_interest)
miou = iou[1].mean() * 100
fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100
return miou, fb_iou
def write_result(self, split, epoch):
iou, fb_iou = self.compute_iou()
loss_buf = torch.stack(self.loss_buf)
msg = '\n*** %s ' % split
msg += '[@Epoch %02d] ' % epoch
msg += 'Avg L: %6.5f ' % loss_buf.mean()
msg += 'mIoU: %5.2f ' % iou
msg += 'FB-IoU: %5.2f ' % fb_iou
msg += '***\n'
Logger.info(msg)
def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
if batch_idx % write_batch_idx == 0:
msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
iou, fb_iou = self.compute_iou()
if epoch != -1:
loss_buf = torch.stack(self.loss_buf)
msg += 'L: %6.5f ' % loss_buf[-1]
msg += 'Avg L: %6.5f ' % loss_buf.mean()
msg += 'mIoU: %5.2f | ' % iou
msg += 'FB-IoU: %5.2f' % fb_iou
Logger.info(msg)
return iou, fb_iou
class Logger:
r""" Writes evaluation results of training/testing """
@classmethod
def initialize(cls, args, training):
logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime
if logpath == '': logpath = logtime
cls.logpath = os.path.join('logs', logpath + '.log')
cls.benchmark = args.benchmark
if not os.path.exists(cls.logpath):
os.makedirs(cls.logpath)
logging.basicConfig(filemode='w',
filename=os.path.join(cls.logpath, 'log.txt'),
level=logging.INFO,
format='%(message)s',
datefmt='%m-%d %H:%M:%S')
# Console log config
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)
# Tensorboard writer
cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
# Log arguments
logging.info('\n:=========== Few-shot Seg. with HSNet ===========')
for arg_key in args.__dict__:
logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
logging.info(':================================================\n')
@classmethod
def info(cls, msg):
r""" Writes log message to log.txt """
logging.info(msg)
@classmethod
def save_model_miou(cls, model, epoch, val_miou):
torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))
@classmethod
def log_params(cls, model):
backbone_param = 0
learner_param = 0
for k in model.state_dict().keys():
n_param = model.state_dict()[k].view(-1).size(0)
if k.split('.')[0] in 'backbone':
if k.split('.')[1] in ['classifier', 'fc']: # as fc layers are not used in HSNet
continue
backbone_param += n_param
else:
learner_param += n_param
Logger.info('Backbone # param.: %d' % backbone_param)
Logger.info('Learnable # param.: %d' % learner_param)
Logger.info('Total # param.: %d' % (backbone_param + learner_param))