Spaces:
Running
Running
File size: 1,175 Bytes
bc75bfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
# RMSprop optimizer implementation
from abc import ABC
from .base import BaseOptimizer
class RMSpropOptimizer(BaseOptimizer, ABC):
    """
    RMSprop optimizer implementation.

    Keeps, for every parameter, a moving average of squared gradients and
    divides each gradient by the square root of that average before
    applying the update:

        mean_square = alpha * mean_square + (1 - alpha) * grad ** 2
        data       -= lr * grad / (mean_square ** 0.5 + eps)

    Parameters must expose mutable ``data`` and ``grad`` attributes.
    NOTE(review): the ``.data``/``.grad`` usage suggests ``torch`` tensors;
    confirm against callers. Parameters whose ``grad`` is ``None`` are
    skipped by ``step()``.

    Args:
        params: Iterable of parameters to optimize. It is materialized into
            a list, so one-shot iterables such as generators are safe.
        lr: Learning rate (step size).
        alpha: Smoothing factor of the squared-gradient moving average.
        eps: Added to the denominator to avoid division by zero.
    """

    def __init__(self, params, lr=0.001, alpha=0.99, eps=1e-8):
        # Materialize once. Building ``self.state`` below iterates ``params``;
        # if ``params`` were a generator, that would exhaust it and leave
        # ``self.params`` empty for step() and zero_grad().
        self.params = list(params)
        self.lr = lr
        self.alpha = alpha
        self.eps = eps
        # Per-parameter running average of squared gradients, starting at 0.
        self.state = {p: {'mean_square': 0} for p in self.params}

    def step(self):
        """
        Apply one RMSprop update to every parameter that has a gradient.

        Parameters whose ``grad`` is ``None`` are left untouched, and their
        moving average is not decayed.
        """
        for p in self.params:
            if p.grad is None:
                continue
            state = self.state[p]
            state['mean_square'] = (
                self.alpha * state['mean_square']
                + (1 - self.alpha) * (p.grad ** 2)
            )
            # ``** 0.5`` instead of ``.sqrt()``: it works for tensors, NumPy
            # arrays and plain Python numbers alike. A scalar average (e.g.
            # after a scalar zero gradient) has no ``.sqrt()`` method.
            p.data -= self.lr * p.grad / (state['mean_square'] ** 0.5 + self.eps)

    def zero_grad(self):
        """
        Clear every parameter's gradient by setting it to ``None``.

        ``None`` is what step() treats as "no gradient". The previous value
        ``0`` was not a valid ``grad`` for tensor parameters and made step()
        fall through to an update with a scalar zero gradient.
        """
        for p in self.params:
            p.grad = None

    def __repr__(self):
        """Return a summary of the optimizer's hyperparameters."""
        return f"RMSpropOptimizer(lr={self.lr}, alpha={self.alpha}, eps={self.eps})"

    def state_dict(self):
        """
        Return the per-parameter optimizer state.

        Returns:
            A dict mapping each parameter object itself (not an index) to a
            fresh ``{'mean_square': ...}`` dict holding its current running
            average of squared gradients.
        """
        return {p: {'mean_square': state['mean_square']} for p, state in self.state.items()}
|