File size: 9,435 Bytes
74c6a32 ab9857f 74c6a32 ab9857f 74c6a32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
from tensorflow.python.keras.optimizers import Optimizer
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python import ops, math_ops, state_ops, control_flow_ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import backend_config
import tensorflow as tf
class AccumOptimizer(Optimizer):
    """Wrap a Keras optimizer so that it works on accumulated gradients.

    The wrapped optimizer is modified in place: its `lr` is replaced by a
    switch that yields 0 unless `iterations % steps_per_update == 0`, the
    attributes `momentum`, `rho`, `beta_1` and `beta_2` (when present) are
    replaced by a switch that yields `1 - 1e-7` on those same off steps, and
    its `get_gradients` is replaced so that it returns the accumulation
    buffers divided by `steps_per_update`.

    # Arguments
        optimizer: an instance of a keras optimizer; it is patched in place.
        steps_per_update: the number of steps over which gradients are
            accumulated.

    # Returns
        a new keras optimizer.
    """

    def __init__(self, optimizer, steps_per_update=1, **kwargs):
        # NOTE(review): some Keras v1 `Optimizer.__init__` versions only
        # accept `clipnorm`/`clipvalue`; confirm `name=` is accepted by the
        # base class that the import at the top of this file resolves to.
        super().__init__(name='AccumOptimizer', **kwargs)
        self.optimizer = optimizer
        with K.name_scope(type(self).__name__):
            self.steps_per_update = steps_per_update
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            # True on the steps where the wrapped optimizer is allowed to act.
            self.cond = K.equal(self.iterations % self.steps_per_update, 0)

            # Keep the real learning rate on the wrapper; the wrapped
            # optimizer gets a version that is 0 on the other steps.
            self.lr = self.optimizer.lr
            self.optimizer.lr = K.switch(self.cond, self.optimizer.lr, 0.)

            # Same treatment for the decay-style hyper-parameters: the
            # original value is mirrored on the wrapper, and the wrapped
            # optimizer sees `1 - 1e-7` whenever `cond` is false.
            for attr in ('momentum', 'rho', 'beta_1', 'beta_2'):
                if not hasattr(self.optimizer, attr):
                    continue
                original = getattr(self.optimizer, attr)
                setattr(self, attr, original)
                setattr(self.optimizer, attr,
                        K.switch(self.cond, original, 1 - 1e-7))

            # Mirror every remaining config entry that the wrapper does not
            # already expose under the same name.
            for attr in self.optimizer.get_config():
                if hasattr(self, attr):
                    continue
                setattr(self, attr, getattr(self.optimizer, attr))

            # Cover the wrapped optimizer's get_gradients so that it returns
            # the accumulated gradients (scaled by 1 / steps_per_update)
            # instead of differentiating the loss itself.
            def _averaged_accum_grads(loss, params):
                return [buf / self.steps_per_update for buf in self.accum_grads]

            self.optimizer.get_gradients = _averaged_accum_grads

    def get_updates(self, loss, params):
        """Build the update ops for one step.

        # Arguments
            loss: the loss passed on to `get_gradients` and to the wrapped
                optimizer's `get_updates`.
            params: the parameters to optimize.

        # Returns
            the list stored in `self.updates`: the two step-counter updates,
            one accumulation update per parameter, then the wrapped
            optimizer's updates except its first entry.
        """
        step_updates = [
            K.update_add(self.iterations, 1),
            # The wrapped optimizer's counter only advances when `cond` holds.
            K.update_add(self.optimizer.iterations, K.cast(self.cond, 'int64')),
        ]
        self.updates = step_updates

        # One zero-initialised accumulation buffer per parameter.
        self.accum_grads = [
            K.zeros(K.int_shape(param), dtype=K.dtype(param)) for param in params
        ]
        grads = self.get_gradients(loss, params)
        # When `cond` holds the buffer restarts from the current gradient,
        # otherwise the current gradient is added to it.
        self.updates.extend(
            K.update(buf, K.switch(self.cond, grad, buf + grad))
            for grad, buf in zip(grads, self.accum_grads)
        )

        # Inherit the wrapped optimizer's updates. Its first entry is dropped,
        # presumably because it is that optimizer's own iteration increment,
        # which the conditional increment above stands in for.
        inner_updates = self.optimizer.get_updates(loss, params)
        self.updates.extend(inner_updates[1:])
        self.weights.extend(self.optimizer.weights)
        return self.updates

    def get_config(self):
        """Return the wrapped optimizer's config.

        The wrapper's iteration counter is set to 0 while the config is read
        and restored afterwards; presumably this makes `cond` true so that
        the patched hyper-parameters evaluate to their original values.
        """
        saved_step = K.eval(self.iterations)
        K.set_value(self.iterations, 0)
        config = self.optimizer.get_config()
        K.set_value(self.iterations, saved_step)
        return config
# Public API of this module: only AdamAccumulated is exported by a star
# import; AccumOptimizer is defined in this file but is not listed here.
__all__ = ['AdamAccumulated']
# SRC (for the class that follows): https://github.com/CyberZHG/keras-gradient-accumulation/blob/master/keras_gradient_accumulation/optimizer_v2.py
class AdamAccumulated(OptimizerV2):
    """Adam optimizer that accumulates gradients over several steps.

    A per-variable slot `g` tracks the gradient across the sub-steps of an
    accumulation window. The moment slots change only on steps where
    `(iterations + 1) % accumulation_steps == 0`; on every other step they
    are left as they are and the step size is replaced by 0.
    """

    def __init__(self,
                 accumulation_steps,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-7,
                 amsgrad=False,
                 name='Adam',
                 **kwargs):
        """Create the optimizer and register its hyper-parameters.

        Args:
          accumulation_steps: An integer. Number of steps per accumulation
            window; an update is applied on the last step of each window.
          learning_rate: The learning rate. Overridden by `kwargs['lr']`
            when that key is present.
          beta_1: Exponential decay rate for the 1st moment estimates.
          beta_2: Exponential decay rate for the 2nd moment estimates.
          epsilon: A small constant added to the denominator of the update.
            A falsy value falls back to `backend_config.epsilon()`.
          amsgrad: boolean. Whether to keep the running maximum of the 2nd
            moment in an extra `vhat` slot and use it in the update.
          name: Name passed to the base optimizer. Defaults to "Adam".
          **kwargs: Forwarded to `OptimizerV2.__init__`. The original
            documentation lists `clipnorm`, `clipvalue`, `lr` and `decay`
            as the accepted keys.
        """
        super().__init__(name, **kwargs)
        # Registered in this fixed order; `lr` in kwargs wins over
        # `learning_rate` for backward compatibility.
        hyper_values = (
            ('accumulation_steps', accumulation_steps),
            ('learning_rate', kwargs.get('lr', learning_rate)),
            ('decay', self._initial_decay),
            ('beta_1', beta_1),
            ('beta_2', beta_2),
        )
        for hyper_name, hyper_value in hyper_values:
            self._set_hyper(hyper_name, hyper_value)
        self.epsilon = epsilon if epsilon else backend_config.epsilon()
        self.amsgrad = amsgrad

    def _create_slots(self, var_list):
        """Create the `g`, `m`, `v` (and, with amsgrad, `vhat`) slots.

        Slots are created slot-name by slot-name: every `g` first, then
        every `m`, and so on.
        """
        slot_names = ['g', 'm', 'v']
        if self.amsgrad:
            slot_names.append('vhat')
        for slot_name in slot_names:
            for var in var_list:
                self.add_slot(var, slot_name)

    def set_weights(self, weights):
        """Set the optimizer weights, trimming one over-long layout.

        When `len(weights)` equals `3 * num_vars + 1`, the list is cut to
        the current number of optimizer weights before it is passed on.

        NOTE(review): `num_vars` assumes two slots per variable, while
        `_create_slots` creates at least three (`g`, `m`, `v`) — confirm
        this compatibility rule is still meaningful for this class.
        """
        current_count = len(self.weights)
        num_vars = int((current_count - 1) / 2)
        if len(weights) == 3 * num_vars + 1:
            weights = weights[:current_count]
        super().set_weights(weights)

    def _resource_apply_dense(self, grad, var):
        """Build the ops that apply one (sub-)step to `var`.

        Args:
          grad: Gradient for `var` at the current step.
          var: The variable to update.

        Returns:
          A grouped op covering the variable update and the `m`/`v`
          (and, with amsgrad, `vhat`) slot assignments.
        """
        dtype = var.dtype.base_dtype
        base_lr = self._decayed_lr(dtype)
        beta_1 = self._get_hyper('beta_1', dtype)
        beta_2 = self._get_hyper('beta_2', dtype)
        accum_steps = self._get_hyper('accumulation_steps', 'int64')
        locking = self._use_locking

        # True on the last sub-step of each accumulation window.
        is_update_step = tf.equal((self.iterations + 1) % accum_steps, 0)
        # 1-based position inside the current window.
        sub_step = self.iterations % accum_steps + 1
        # 1-based index of the window; drives the bias-correction powers.
        window_index = math_ops.cast(self.iterations // accum_steps + 1, dtype)
        beta_1_power = math_ops.pow(beta_1, window_index)
        beta_2_power = math_ops.pow(beta_2, window_index)
        eps = ops.convert_to_tensor(self.epsilon, dtype)
        corrected_lr = base_lr * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)
        # No movement outside update steps.
        step_size = tf.where(is_update_step, corrected_lr, 0.0)

        # Running mean of `grad / accumulation_steps` over the window; it
        # restarts from the scaled gradient on the first sub-step.
        # NOTE(review): the gradient is divided by accumulation_steps and
        # then averaged, so the stored value is the window's mean gradient
        # scaled by 1 / accumulation_steps — confirm this is intended.
        g_slot = self.get_slot(var, 'g')
        scaled_grad = grad / math_ops.cast(accum_steps, dtype)
        g_value = tf.where(
            tf.equal(sub_step, 1),
            scaled_grad,
            g_slot + (scaled_grad - g_slot) / math_ops.cast(sub_step, dtype))
        g_new = state_ops.assign(g_slot, g_value, use_locking=locking)

        # First moment: refreshed on update steps, otherwise kept as is.
        m_slot = self.get_slot(var, 'm')
        m_value = tf.where(
            is_update_step, m_slot * beta_1 + g_new * (1 - beta_1), m_slot)
        m_new = state_ops.assign(m_slot, m_value, use_locking=locking)

        # Second moment: same gating as the first moment.
        v_slot = self.get_slot(var, 'v')
        v_value = tf.where(
            is_update_step,
            v_slot * beta_2 + (g_new * g_new) * (1 - beta_2),
            v_slot)
        v_new = state_ops.assign(v_slot, v_value, use_locking=locking)

        if self.amsgrad:
            # AMSGrad: use the running maximum of the second moment, which
            # is itself only refreshed on update steps.
            vhat_slot = self.get_slot(var, 'vhat')
            vhat_value = tf.where(
                is_update_step, math_ops.maximum(vhat_slot, v_new), vhat_slot)
            with ops.control_dependencies([vhat_value]):
                vhat_new = state_ops.assign(
                    vhat_slot, vhat_value, use_locking=locking)
            denominator = math_ops.sqrt(vhat_new) + eps
            var_update = state_ops.assign_sub(
                var, step_size * m_new / denominator, use_locking=locking)
            return control_flow_ops.group(var_update, m_new, v_new, vhat_new)

        denominator = math_ops.sqrt(v_new) + eps
        var_update = state_ops.assign_sub(
            var, step_size * m_new / denominator, use_locking=locking)
        return control_flow_ops.group(var_update, m_new, v_new)

    def get_config(self):
        """Return the base config extended with this optimizer's settings."""
        config = super().get_config()
        serialized_hypers = (
            'accumulation_steps',
            'learning_rate',
            'decay',
            'beta_1',
            'beta_2',
        )
        for hyper_name in serialized_hypers:
            config[hyper_name] = self._serialize_hyperparameter(hyper_name)
        config['epsilon'] = self.epsilon
        config['amsgrad'] = self.amsgrad
        return config
|