import math
import torch
import itertools as it
from torch.optim import Optimizer
from collections import defaultdict


# https://github.com/lonePatient/lookahead_pytorch/blob/1055128057408fe8533ffa30654551a317f07f0a/optimizer.py
class Lookahead(Optimizer):
    '''
    PyTorch implementation of the Lookahead wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    '''

    def __init__(self, optimizer, alpha=0.5, k=6, pullback_momentum="none"):
        '''
        :param optimizer: inner optimizer
        :param alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        :param k (int): number of lookahead steps
        :param pullback_momentum (str): change to inner optimizer momentum on interpolation update
            ("none", "pullback", or "reset")
        '''
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.optimizer = optimizer
        self.alpha = alpha
        self.k = k
        self.step_counter = 0
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum
        self.defaults = optimizer.defaults
        # We bypass Optimizer.__init__ here, but the step() wrapper that newer
        # PyTorch releases install on Optimizer subclasses reads these hook
        # registries, so create them explicitly.
        self._optimizer_step_pre_hooks = {}
        self._optimizer_step_post_hooks = {}
        self.reset()

    def reset(self):
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        # Cache the current (fast) parameters as the slow weights
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    # Also cache a slow momentum buffer; without it the
                    # "pullback" branch in step() would hit a missing key
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'step_counter': self.step_counter,
            'k': self.k,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
        # Re-cache the current model parameters as the slow weights
        self.reset()

    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)
        """
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self.step_counter += 1
        if self.step_counter >= self.k:
            self.step_counter = 0
            # Lookahead update: interpolate the fast weights toward the slow
            # (cached) weights, then cache the result as the new slow weights
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    p.data.mul_(self.alpha).add_(param_state['cached_params'], alpha=1.0 - self.alpha)  # crucial line
                    param_state['cached_params'].copy_(p.data)
                    if self.pullback_momentum == "pullback":
                        # NOTE: "momentum_buffer" assumes an inner optimizer such as
                        # SGD with momentum; Adam-style optimizers use different state keys.
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_(
                            param_state["cached_mom"], alpha=1.0 - self.alpha)
                        # clone() so the cached momentum is a snapshot rather than an
                        # alias of the live buffer that the inner optimizer mutates in place
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"].clone()
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)
        return loss
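

# ---------------------------------------------------------------------------
# Usage sketch (not part of the referenced implementation): wrap an SGD
# optimizer with momentum in Lookahead, train briefly on synthetic data, and
# evaluate on the slow (cached) weights. The model, data, and hyperparameters
# below are illustrative assumptions, not values from the original repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    torch.manual_seed(0)
    model = nn.Linear(10, 2)

    # SGD with momentum > 0 populates the "momentum_buffer" state entries that
    # the "pullback" and "reset" modes rely on.
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = Lookahead(base_optimizer, alpha=0.5, k=6, pullback_momentum="pullback")

    inputs = torch.randn(64, 10)
    targets = torch.randint(0, 2, (64,))

    for _ in range(18):  # 18 inner steps -> three slow-weight updates with k=6
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()

    # Evaluate on the slow weights, then restore the fast weights for training.
    optimizer._backup_and_load_cache()
    with torch.no_grad():
        slow_loss = nn.functional.cross_entropy(model(inputs), targets)
    print(f"loss on slow weights: {slow_loss.item():.4f}")
    optimizer._clear_and_load_backup()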