from tensorflow.keras.optimizers import Optimizer
from tensorflow.keras import backend as K


class AccumOptimizer(Optimizer):
    """Gradient-accumulation wrapper for Keras optimizers.

    Wraps an existing optimizer so that gradients are accumulated over
    `steps_per_update` batches and the inner optimizer applies a real
    update only once per accumulation cycle, emulating a batch size of
    `steps_per_update * batch_size` without the extra memory cost.

    Note: this targets the legacy Keras optimizer API (standalone Keras 2.x,
    or tf.keras under TF 1.x), where `Optimizer` exposes `get_updates`,
    `self.updates`, and `self.weights`; it is not compatible with the
    `OptimizerV2` API of TF 2.x.

    # Arguments
        optimizer: an instance of a Keras optimizer (all optimizers shipped
            with Keras are supported);
        steps_per_update: number of batches over which gradients are
            accumulated before each real update.

    # Returns
        A new Keras optimizer.
    """
    def __init__(self, optimizer, steps_per_update=1, **kwargs):
        super(AccumOptimizer, self).__init__(**kwargs)
        self.optimizer = optimizer
        with K.name_scope(self.__class__.__name__):
            self.steps_per_update = steps_per_update
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            # True only on the batches where a real parameter update happens.
            self.cond = K.equal(self.iterations % self.steps_per_update, 0)
            self.lr = self.optimizer.lr
            # Zero the inner learning rate on accumulation-only batches so the
            # wrapped optimizer leaves the weights untouched.
            self.optimizer.lr = K.switch(self.cond, self.optimizer.lr, 0.)
            for attr in ['momentum', 'rho', 'beta_1', 'beta_2']:
                if hasattr(self.optimizer, attr):
                    value = getattr(self.optimizer, attr)
                    setattr(self, attr, value)
                    # Push decay rates to ~1 between updates so the inner
                    # optimizer's moving averages also stay frozen.
                    setattr(self.optimizer, attr, K.switch(self.cond, value, 1 - 1e-7))
            for attr in self.optimizer.get_config():
                if not hasattr(self, attr):
                    value = getattr(self.optimizer, attr)
                    setattr(self, attr, value)

            # Override the inner optimizer's get_gradients so it consumes the
            # averaged accumulated gradients instead of the per-batch ones.
            def get_gradients(loss, params):
                return [ag / self.steps_per_update for ag in self.accum_grads]
            self.optimizer.get_gradients = get_gradients

    def get_updates(self, loss, params):
        self.updates = [
            K.update_add(self.iterations, 1),
            # Advance the inner optimizer's step counter only on real updates.
            K.update_add(self.optimizer.iterations, K.cast(self.cond, 'int64')),
        ]
        # Gradient accumulation: reset each accumulator on update steps,
        # otherwise keep adding the current batch's gradient.
        self.accum_grads = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        grads = self.get_gradients(loss, params)
        for g, ag in zip(grads, self.accum_grads):
            self.updates.append(K.update(ag, K.switch(self.cond, g, ag + g)))
        # Inherit the updates of the original optimizer, skipping its own
        # iterations increment (handled conditionally above).
        self.updates.extend(self.optimizer.get_updates(loss, params)[1:])
        self.weights.extend(self.optimizer.weights)
        return self.updates

    def get_config(self):
        # Temporarily reset iterations so the conditional expressions
        # (lr, momentum, ...) serialize with their true update-step values.
        iterations = K.eval(self.iterations)
        K.set_value(self.iterations, 0)
        config = self.optimizer.get_config()
        K.set_value(self.iterations, iterations)
        return config
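

# A minimal usage sketch (an assumption, not part of the original file):
# wrapping Adam so that 16 batches of 32 samples behave like one batch of
# 512. The model, data shapes, and hyperparameters below are illustrative
# only, and this runs only in the legacy Keras environments noted above.
if __name__ == '__main__':
    import numpy as np
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.optimizers import Adam

    model = Sequential([Dense(1, input_shape=(8,))])
    # Real updates happen every 16 batches; between them, gradients are
    # accumulated and the weights are left unchanged.
    model.compile(loss='mse',
                  optimizer=AccumOptimizer(Adam(), steps_per_update=16))

    x = np.random.randn(1024, 8)
    y = np.random.randn(1024, 1)
    model.fit(x, y, batch_size=32, epochs=1)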