File size: 5,122 Bytes
57746f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
r""" Logging during training/testing """
import datetime
import logging
import os
from tensorboardX import SummaryWriter
import torch
class AverageMeter:
    r""" Accumulates intersection/union counts and losses to report mIoU, FB-IoU and average loss.

    Buffers are laid out as [2, nclass]: row 0 is background, row 1 is foreground,
    one column per class of the benchmark.
    """

    # Foreground class count for each supported benchmark.
    _NCLASS = {'pascal': 20, 'coco': 80, 'fss': 1000}

    def __init__(self, dataset):
        r""" Args:
                dataset: dataset object exposing `benchmark` (str) and `class_ids` (iterable of int)
        """
        self.benchmark = dataset.benchmark
        self.class_ids_interest = torch.tensor(dataset.class_ids).cuda()

        try:
            self.nclass = self._NCLASS[self.benchmark]
        except KeyError:
            # The original if/elif chain silently left `nclass` unset for an
            # unknown benchmark, causing a confusing AttributeError later.
            raise ValueError('Unsupported benchmark: %s' % self.benchmark)

        self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
        self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
        self.ones = torch.ones_like(self.union_buf)  # floor for union to avoid division by zero
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        r""" Adds a batch of per-class intersection/union counts and the batch loss.

        Args:
            inter_b:  [2, B] intersection counts (bg/fg rows), one column per sample
            union_b:  [2, B] union counts, same layout
            class_id: [B] class index of each sample (column in the buffers)
            loss:     scalar tensor, or None (recorded as 0.0, e.g. during testing)
        """
        self.intersection_buf.index_add_(1, class_id, inter_b.float())
        self.union_buf.index_add_(1, class_id, union_b.float())
        if loss is None:
            loss = torch.tensor(0.0)
        self.loss_buf.append(loss)

    def compute_iou(self):
        r""" Returns (mIoU, FB-IoU), both in percent, over the classes of interest. """
        # Clamp union to >= 1 so classes never encountered contribute 0 instead of NaN.
        iou = self.intersection_buf.float() / \
              torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
        iou = iou.index_select(1, self.class_ids_interest)
        miou = iou[1].mean() * 100  # row 1 = foreground IoU

        # FB-IoU: mean of background and foreground IoU, pooled over all classes of interest.
        fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
                  self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100

        return miou, fb_iou

    def write_result(self, split, epoch):
        r""" Logs a summary line (avg loss, mIoU, FB-IoU) for the given split/epoch. """
        iou, fb_iou = self.compute_iou()

        loss_buf = torch.stack(self.loss_buf)
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f  ' % loss_buf.mean()
        msg += 'mIoU: %5.2f   ' % iou
        msg += 'FB-IoU: %5.2f   ' % fb_iou

        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        r""" Logs running metrics every `write_batch_idx` batches; returns (mIoU, FB-IoU).

        `epoch == -1` marks test mode: the epoch tag and loss terms are omitted.
        """
        if batch_idx % write_batch_idx == 0:
            msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
            msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
            iou, fb_iou = self.compute_iou()
            if epoch != -1:
                loss_buf = torch.stack(self.loss_buf)
                msg += 'L: %6.5f  ' % loss_buf[-1]
                msg += 'Avg L: %6.5f  ' % loss_buf.mean()
            msg += 'mIoU: %5.2f  |  ' % iou
            msg += 'FB-IoU: %5.2f' % fb_iou
            Logger.info(msg)
        return iou, fb_iou
class Logger:
    r""" Writes evaluation results of training/testing to a log file, console and TensorBoard. """

    @classmethod
    def initialize(cls, args, training):
        r""" Sets up the log directory, file/console logging and the TensorBoard writer.

        Args:
            args:     parsed arguments; must expose `logpath`, `benchmark` and (for testing) `load`
            training: True for training runs; False derives the log name from the loaded checkpoint
        """
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime
        if logpath == '': logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        if not os.path.exists(cls.logpath):
            os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config: mirror file output on stdout/stderr.
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:=========== Few-shot Seg. with HSNet ===========')
        for arg_key in args.__dict__:
            logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
        logging.info(':================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes log message to log.txt """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        r""" Saves the model weights as `best_model.pt` and logs epoch/mIoU. """
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))

    @staticmethod
    def _count_params(model):
        r""" Returns (backbone_param, learner_param): parameter counts split by key prefix.

        Backbone classifier/fc parameters are excluded, as those layers are not used in HSNet.
        """
        backbone_param = 0
        learner_param = 0
        state_dict = model.state_dict()  # hoisted: original rebuilt the state dict on every key
        for k, param in state_dict.items():
            # .numel() also handles non-contiguous tensors, unlike view(-1).size(0).
            n_param = param.numel()
            # Fix: the original used `k.split('.')[0] in 'backbone'` — a substring test
            # that wrongly matched prefixes like 'b' or 'back' as backbone parameters.
            if k.split('.')[0] == 'backbone':
                if k.split('.')[1] in ['classifier', 'fc']:  # as fc layers are not used in HSNet
                    continue
                backbone_param += n_param
            else:
                learner_param += n_param
        return backbone_param, learner_param

    @classmethod
    def log_params(cls, model):
        r""" Logs the number of backbone, learnable and total parameters of `model`. """
        backbone_param, learner_param = cls._count_params(model)
        Logger.info('Backbone # param.: %d' % backbone_param)
        Logger.info('Learnable # param.: %d' % learner_param)
        Logger.info('Total # param.: %d' % (backbone_param + learner_param))
|