#!/usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import logging

logger = logging.getLogger()


class Optimizer(object):
    '''SGD wrapped with an exponential warmup followed by polynomial ("poly")
    lr decay. ``model.get_params()`` must return four parameter groups:
    weights with weight decay, weights without weight decay, and the same
    two splits for parameters trained with a 10x lr multiplier.
    ``warmup_steps`` must be at least 1 (it is used as a divisor below).
    '''

    def __init__(self,
            model,
            lr0,
            momentum,
            wd,
            warmup_steps,
            warmup_start_lr,
            max_iter,
            power,
            *args, **kwargs):
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.lr0 = lr0
        self.lr = self.lr0
        self.max_iter = float(max_iter)  # float so the poly division below is exact
        self.power = power
        self.it = 0
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
        param_list = [
                {'params': wd_params},
                {'params': nowd_params, 'weight_decay': 0},
                {'params': lr_mul_wd_params, 'lr_mul': True},
                {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}]
        self.optim = torch.optim.SGD(
                param_list,
                lr=lr0,
                momentum=momentum,
                weight_decay=wd)
        # per-step multiplier that grows the lr geometrically from
        # warmup_start_lr to lr0 over warmup_steps iterations
        self.warmup_factor = (self.lr0 / self.warmup_start_lr) ** (1. / self.warmup_steps)

    def get_lr(self):
        if self.it <= self.warmup_steps:
            # exponential warmup: warmup_start_lr * warmup_factor^it
            lr = self.warmup_start_lr * (self.warmup_factor ** self.it)
        else:
            # poly decay: lr0 * (1 - t/T)^power, with t counted from the end of warmup
            factor = (1 - (self.it - self.warmup_steps) / (self.max_iter - self.warmup_steps)) ** self.power
            lr = self.lr0 * factor
        return lr
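    # Example shape of the schedule (hypothetical numbers, not from the
    # original file): with warmup_start_lr=1e-5, lr0=1e-2, warmup_steps=1000,
    # max_iter=80000, power=0.9, the lr climbs geometrically from 1e-5 to 1e-2
    # over the first 1000 iters, then decays smoothly to 0 at iter 80000.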

    def step(self):
        self.lr = self.get_lr()
        # push the scheduled lr into every param group; groups flagged with
        # 'lr_mul' (e.g. freshly initialized heads) train at 10x the base lr
        for pg in self.optim.param_groups:
            if pg.get('lr_mul', False):
                pg['lr'] = self.lr * 10
            else:
                pg['lr'] = self.lr
        if self.optim.defaults.get('lr_mul', False):
            self.optim.defaults['lr'] = self.lr * 10
        else:
            self.optim.defaults['lr'] = self.lr
        self.it += 1
        self.optim.step()
        if self.it == self.warmup_steps + 2:
            logger.info('==> warmup done, switching to the poly lr schedule')

    def zero_grad(self):
        self.optim.zero_grad()
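

# Minimal usage sketch (an addition, not part of the original file). ToyModel
# is a hypothetical stand-in that satisfies the get_params() contract assumed
# by Optimizer; all hyperparameter values here are illustrative only.
if __name__ == '__main__':
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.conv = nn.Conv2d(3, 8, 3, padding=1)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

        def get_params(self):
            # conv weights get weight decay; bias and batch-norm params do
            # not; this toy model has no 10x-lr groups, so those are empty
            wd_params = [self.conv.weight]
            nowd_params = [self.conv.bias] + list(self.bn.parameters())
            return wd_params, nowd_params, [], []

    model = ToyModel()
    optim = Optimizer(
            model=model,
            lr0=1e-2,
            momentum=0.9,
            wd=5e-4,
            warmup_steps=1000,
            warmup_start_lr=1e-5,
            max_iter=80000,
            power=0.9)

    x = torch.randn(2, 3, 32, 32)
    loss = model(x).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()  # computes the scheduled lr, applies it, then steps SGD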