# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
'''
Code adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py
'''
import warnings

import torch
from torch.optim import Optimizer
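
# The EMA wrapper below maintains an exponential moving average of every
# parameter that receives a gradient, stored alongside the wrapped optimizer's
# own state:
#     ema <- ema_decay * ema + (1 - ema_decay) * param
# Parameters are bucketed by shape and stacked so the update runs as one
# batched in-place op per shape (or one op per tensor when memory_efficient=True).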


class EMA(Optimizer):
    def __init__(self, opt, ema_decay, memory_efficient=False):
        # Wrap an existing optimizer and keep an EMA copy of every parameter it
        # updates; ema_decay <= 0 disables the EMA entirely.
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        self.optimizer = opt
        # Share state and param_groups with the wrapped optimizer so that the
        # base Optimizer methods (zero_grad, state_dict, ...) operate on them.
        self.state = opt.state
        self.param_groups = opt.param_groups
        self.defaults = {}
        self.memory_efficient = memory_efficient

    def step(self, *args, **kwargs):
        # Optional workaround (left commented out): backfill param_group keys
        # expected by newer PyTorch optimizers when loading older checkpoints.
        # for group in self.optimizer.param_groups:
        #     group.setdefault('amsgrad', False)
        #     group.setdefault('maximize', False)
        #     group.setdefault('foreach', None)
        #     group.setdefault('capturable', False)
        #     group.setdefault('differentiable', False)
        #     group.setdefault('fused', False)
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        # Group the current parameters and their EMA buffers by tensor shape so
        # the EMA update can be applied to stacked tensors.
        ema, params = {}, {}
        for group in self.optimizer.param_groups:
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                state = self.optimizer.state[p]
                # State initialization: start the EMA at the current parameter value.
                if 'ema' not in state:
                    state['ema'] = p.data.clone()

                if p.shape not in params:
                    params[p.shape] = {'idx': 0, 'data': []}
                    ema[p.shape] = []

                params[p.shape]['data'].append(p.data)
                ema[p.shape].append(state['ema'])

            # def stack(d, dim=0):
            #     return torch.stack([di.cpu() for di in d], dim=dim).cuda()

            # EMA update: ema <- ema_decay * ema + (1 - ema_decay) * param
            for i in params:
                if self.memory_efficient:
                    # Update each EMA tensor in place, then stack for the
                    # write-back below.
                    for j in range(len(params[i]['data'])):
                        ema[i][j].mul_(self.ema_decay).add_(params[i]['data'][j], alpha=1. - self.ema_decay)
                    ema[i] = torch.stack(ema[i], dim=0)
                else:
                    # Stack all tensors of this shape and update them in one
                    # batched in-place op.
                    params[i]['data'] = torch.stack(params[i]['data'], dim=0)
                    ema[i] = torch.stack(ema[i], dim=0)
                    ema[i].mul_(self.ema_decay).add_(params[i]['data'], alpha=1. - self.ema_decay)

            # Write the updated EMA slices back into the optimizer state, one
            # slice per parameter of each shape.
            for p in group['params']:
                if p.grad is None:
                    continue

                idx = params[p.shape]['idx']
                self.optimizer.state[p]['ema'] = ema[p.shape][idx, :]
                params[p.shape]['idx'] += 1

        return retval

    def load_state_dict(self, state_dict):
        super(EMA, self).load_state_dict(state_dict)
        # load_state_dict loads the data to self.state and self.param_groups. We need
        # to pass this data to the underlying optimizer too.
        self.optimizer.state = self.state
        self.optimizer.param_groups = self.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """ This function swaps the parameters with their EMA values. It records the
        original parameters in the EMA buffers if store_params_in_ema is True."""
        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there are no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for i, p in enumerate(group['params']):
                if not p.requires_grad:
                    continue

                ema = self.optimizer.state[p]['ema']
                if store_params_in_ema:
                    # Swap: load the EMA weights into the model and keep the raw
                    # weights in the EMA slot so a second call restores them.
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    self.optimizer.state[p]['ema'] = tmp
                else:
                    # Overwrite the parameters with their EMA values (not reversible).
                    p.data = ema.detach()
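

# ---------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file): wrap an
# ordinary optimizer in EMA, train as usual, and temporarily swap in the EMA
# weights for evaluation. The tiny model and random data are placeholders.
# ---------------------------------------------------------------
if __name__ == '__main__':
    model = torch.nn.Linear(10, 1)
    optimizer = EMA(torch.optim.Adam(model.parameters(), lr=1e-3), ema_decay=0.999)

    for _ in range(5):
        x, y = torch.randn(4, 10), torch.randn(4, 1)
        model.zero_grad()
        loss = (model(x) - y).pow(2).mean()
        loss.backward()
        optimizer.step()  # Adam step followed by the EMA update

    # Load the EMA weights for evaluation; a second call with the same flag
    # swaps the original training weights back in.
    optimizer.swap_parameters_with_ema(store_params_in_ema=True)
    with torch.no_grad():
        print('output with EMA weights:', model(torch.randn(1, 10)))
    optimizer.swap_parameters_with_ema(store_params_in_ema=True)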