import numpy as np
import logging
import os


def count_params(model):
    """Return the number of parameters of ``model``, in millions.

    ``model`` must expose ``parameters()`` yielding objects that have a
    ``numel()`` method; the element counts are summed and divided by 1e6.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / 1e6


# Cityscapes palette: row ``k`` is the RGB colour written to ``cmap[k]``.
_CITYSCAPES_COLORS = (
    (128, 64, 128),
    (244, 35, 232),
    (70, 70, 70),
    (102, 102, 156),
    (190, 153, 153),
    (153, 153, 153),
    (250, 170, 30),
    (220, 220, 0),
    (107, 142, 35),
    (152, 251, 152),
    (70, 130, 180),
    (220, 20, 60),
    (255, 0, 0),
    (0, 0, 142),
    (0, 0, 70),
    (0, 60, 100),
    (0, 80, 100),
    (0, 0, 230),
    (119, 11, 32),
)


def color_map(dataset='pascal'):
    """Build a ``(256, 3)`` ``uint8`` colour lookup table.

    Args:
        dataset: ``'pascal'`` or ``'coco'`` generate the table by spreading
            the bits of each index over the R, G and B channels;
            ``'cityscapes'`` fills the first 19 rows from a fixed palette.
            Any other value leaves the table all zeros.

    Returns:
        The ``numpy`` array mapping an index (row) to its RGB colour.
    """
    cmap = np.zeros((256, 3), dtype='uint8')

    if dataset in ('pascal', 'coco'):
        for index in range(256):
            red = green = blue = 0
            bits = index
            for step in range(8):
                # Each step consumes the three lowest bits (one per channel)
                # and stores them from the most significant bit downwards.
                shift = 7 - step
                red |= (bits & 1) << shift
                green |= ((bits >> 1) & 1) << shift
                blue |= ((bits >> 2) & 1) << shift
                bits >>= 3
            cmap[index] = (red, green, blue)
    elif dataset == 'cityscapes':
        cmap[:len(_CITYSCAPES_COLORS)] = _CITYSCAPES_COLORS

    return cmap


class AverageMeter(object):
    """Computes and stores the average and current value.

    With ``length > 0`` the meter averages over the last ``length`` values
    passed to ``update``; otherwise it keeps a running weighted average over
    everything seen since the last ``reset``.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Clear all accumulated state."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Reset in both modes so ``val`` / ``avg`` are always readable, even
        # before the first ``update`` call.
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        """Record ``val``; ``num`` is its weight in running-average mode.

        In windowed mode (``length > 0``) only ``num == 1`` is accepted.
        """
        if self.length > 0:
            # currently assert num==1 to avoid bad usage, refine when there
            # are some explicit requirements
            assert num == 1
            self.history.append(val)
            if len(self.history) > self.length:
                # Drop the oldest entry to keep the window at ``length``.
                del self.history[0]

            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            self.avg = self.sum / self.count


# (name, level) pairs that init_log has already configured.
logs = set()


def init_log(name, level=logging.INFO):
    """Configure and return the logger called ``name``.

    The first call for a given ``(name, level)`` sets the logger level and
    attaches a ``StreamHandler`` with a timestamped format. When the
    ``SLURM_PROCID`` environment variable is set, a filter is added that only
    lets records through if its integer value is 0.

    Repeated calls with the same ``(name, level)`` add nothing and return the
    already-configured logger (previously they returned ``None``).
    """
    if (name, level) in logs:
        # Already configured: hand back the same logger instead of None so
        # callers can always use the return value.
        return logging.getLogger(name)
    logs.add((name, level))

    logger = logging.getLogger(name)
    logger.setLevel(level)

    if "SLURM_PROCID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        logger.addFilter(lambda record: rank == 0)

    ch = logging.StreamHandler()
    ch.setLevel(level)
    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
    ch.setFormatter(logging.Formatter(format_str))
    logger.addHandler(ch)
    return logger