from tensorflow.keras.optimizers import Optimizer
from tensorflow.keras import backend as K


class AccumOptimizer(Optimizer):
    """Gradient-accumulation optimizer.

    Wraps an existing Keras optimizer so that gradients are accumulated over
    several batches and the wrapped optimizer only applies a weight update
    every `steps_per_update` steps, emulating a larger effective batch size.

    Note: this relies on the legacy `get_updates`/`get_gradients` Optimizer
    API (standalone Keras / TF 1.x-style `tf.keras`); it is not compatible
    with the TF 2.x `OptimizerV2` interface as written.

    # Arguments
        optimizer: an instance of a Keras optimizer (any of the currently
            available Keras optimizers is supported);
        steps_per_update: number of steps over which gradients are
            accumulated before an update is applied.

    # Returns
        a new Keras optimizer.
    """

    def __init__(self, optimizer, steps_per_update=1, **kwargs):
        super(AccumOptimizer, self).__init__(**kwargs)
        self.optimizer = optimizer
        with K.name_scope(self.__class__.__name__):
            self.steps_per_update = steps_per_update
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            # True only on the steps where the wrapped optimizer should
            # actually apply an update.
            self.cond = K.equal(self.iterations % self.steps_per_update, 0)
            self.lr = self.optimizer.lr
            # Zero the inner learning rate on accumulation-only steps so the
            # wrapped optimizer leaves the weights untouched.
            self.optimizer.lr = K.switch(self.cond, self.optimizer.lr, 0.)
            # On accumulation-only steps, push decay coefficients towards 1
            # so the wrapped optimizer's moving averages stay (almost) frozen.
            for attr in ['momentum', 'rho', 'beta_1', 'beta_2']:
                if hasattr(self.optimizer, attr):
                    value = getattr(self.optimizer, attr)
                    setattr(self, attr, value)
                    setattr(self.optimizer, attr, K.switch(self.cond, value, 1 - 1e-7))
            # Mirror the remaining config entries of the wrapped optimizer.
            for attr in self.optimizer.get_config():
                if not hasattr(self, attr):
                    value = getattr(self.optimizer, attr)
                    setattr(self, attr, value)

            # Override the wrapped optimizer's get_gradients so it consumes
            # the (averaged) accumulated gradients instead of recomputing them.
            def get_gradients(loss, params):
                return [ag / self.steps_per_update for ag in self.accum_grads]

            self.optimizer.get_gradients = get_gradients

    def get_updates(self, loss, params):
        self.updates = [
            K.update_add(self.iterations, 1),
            # Advance the wrapped optimizer's step counter only on real updates.
            K.update_add(self.optimizer.iterations, K.cast(self.cond, 'int64')),
        ]

        # Gradient accumulators, one per parameter.
        self.accum_grads = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        # Fresh gradients for the current batch (base-class implementation).
        grads = self.get_gradients(loss, params)
        for g, ag in zip(grads, self.accum_grads):
            # On update steps restart the accumulator with the current
            # gradient; otherwise keep summing.
            self.updates.append(K.update(ag, K.switch(self.cond, g, ag + g)))

        # Inherit the wrapped optimizer's updates, skipping its own
        # iterations update (handled conditionally above).
        self.updates.extend(self.optimizer.get_updates(loss, params)[1:])
        self.weights.extend(self.optimizer.weights)
        return self.updates

    def get_config(self):
        # Temporarily reset `iterations` to 0 so that the switched
        # hyper-parameters (lr, momentum, ...) evaluate to their real values
        # while the wrapped optimizer serialises its config.
        iterations = K.eval(self.iterations)
        K.set_value(self.iterations, 0)
        config = self.optimizer.get_config()
        K.set_value(self.iterations, iterations)
        return config
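

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original snippet): assumes a TF 1.x-style,
# graph-mode Keras environment where the legacy `get_updates` Optimizer API
# used above is available. The tiny model and random data are placeholders
# purely to show how the wrapper is meant to be plugged into `compile`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import SGD

    # Accumulate over 8 batches of 32 samples: roughly the effect of
    # training with a batch size of 256 while only holding 32 in memory.
    opt = AccumOptimizer(SGD(lr=0.01, momentum=0.9), steps_per_update=8)

    model = Sequential([Dense(1, input_dim=4)])
    model.compile(loss='mse', optimizer=opt)

    x = np.random.randn(256, 4).astype('float32')
    y = np.random.randn(256, 1).astype('float32')
    model.fit(x, y, batch_size=32, epochs=1)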