ghost233lism's picture
upload models
7f0f123 verified
import numpy as np
import logging
import os
def count_params(model):
    """Return the total number of parameters of *model*, in millions.

    *model* is any object whose ``parameters()`` yields items exposing
    ``numel()`` (presumably a ``torch.nn.Module`` — confirm with callers).
    The result is a float, e.g. ``1.5`` for 1,500,000 parameters.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1e6
def color_map(dataset='pascal'):
    """Build a 256-entry RGB palette used to colour label maps.

    Args:
        dataset: ``'pascal'`` or ``'coco'`` produce the bit-interleaved
            palette computed below; ``'cityscapes'`` fills the first 19
            entries from a fixed table. Any other value leaves every entry
            black (all zeros).

    Returns:
        A ``(256, 3)`` ``uint8`` array whose row ``i`` is the RGB colour
        for label ``i``.
    """
    palette = np.zeros((256, 3), dtype='uint8')

    if dataset in ('pascal', 'coco'):
        labels = np.arange(256)
        channels = np.zeros((256, 3), dtype=np.int64)
        # Consecutive groups of three label bits feed one bit each of R, G
        # and B, filling each colour byte from its most significant bit down.
        for step in range(8):
            shifted = labels >> (3 * step)
            for ch in range(3):
                channels[:, ch] |= ((shifted >> ch) & 1) << (7 - step)
        palette[:] = channels
    elif dataset == 'cityscapes':
        cityscapes_rgb = [
            (128, 64, 128),
            (244, 35, 232),
            (70, 70, 70),
            (102, 102, 156),
            (190, 153, 153),
            (153, 153, 153),
            (250, 170, 30),
            (220, 220, 0),
            (107, 142, 35),
            (152, 251, 152),
            (70, 130, 180),
            (220, 20, 60),
            (255, 0, 0),
            (0, 0, 142),
            (0, 0, 70),
            (0, 60, 100),
            (0, 80, 100),
            (0, 0, 230),
            (119, 11, 32),
        ]
        palette[:len(cityscapes_rgb)] = np.array(cityscapes_rgb)

    return palette
class AverageMeter(object):
    """Computes and stores the average and current value.

    The mode is selected by ``length``:

    * ``length > 0`` (windowed): keeps the last ``length`` values in
      ``history`` and reports their mean as ``avg``.
    * ``length <= 0`` (cumulative, the default): keeps a weighted running
      ``sum`` and ``count`` over every update and reports ``sum / count``.

    In both modes ``val`` is the most recently recorded value, and ``val``
    and ``avg`` are ``0.0`` until the first update.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Discard all recorded values and zero the reported statistics."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Set in both modes. Previously windowed mode never set these, so
        # reading them raised AttributeError before the first update and
        # returned stale values after reset().
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        """Record ``val``.

        Args:
            val: the new value.
            num: weight of ``val`` in cumulative mode (e.g. a batch size);
                windowed mode only accepts ``num == 1``.
        """
        if self.length > 0:
            # currently assert num==1 to avoid bad usage, refine when there are some explicit requirements
            assert num == 1
            self.history.append(val)
            # Drop the oldest entry so the window never exceeds ``length``.
            if len(self.history) > self.length:
                del self.history[0]
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            # Avoid ZeroDivisionError if only zero-weight updates were seen.
            if self.count:
                self.avg = self.sum / self.count
# (name, level) pairs already configured by init_log; used so repeated
# calls do not attach duplicate handlers to the same logger.
logs = set()


def init_log(name, level=logging.INFO):
    """Configure (once) and return a console logger called *name*.

    On the first call for a given ``(name, level)`` pair, the logger's level
    is set, a ``StreamHandler`` with the format
    ``[time][LEVEL] message`` is attached, and, when ``SLURM_PROCID`` is
    set in the environment, a filter is added that drops every record
    unless that value is ``0``. Later calls with the same pair return
    the same logger unchanged.

    Args:
        name: logger name passed to ``logging.getLogger``.
        level: logging level for both the logger and its handler.

    Returns:
        The ``logging.Logger`` for *name*. Before this fix, repeat calls
        returned ``None``.
    """
    logger = logging.getLogger(name)
    if (name, level) in logs:
        # Already configured: return the logger (not None) so several modules
        # can each do ``logger = init_log(...)``.
        return logger
    logs.add((name, level))
    logger.setLevel(level)
    ch = logging.StreamHandler()
    ch.setLevel(level)
    if "SLURM_PROCID" in os.environ:
        # Presumably the SLURM task rank; only rank 0 lets records through,
        # so multi-process jobs print each message once.
        rank = int(os.environ["SLURM_PROCID"])
        logger.addFilter(lambda record: rank == 0)
    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
    formatter = logging.Formatter(format_str)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger