# Nexa/Backend/optimizers/azure_optim.py
import torch
import torch.optim as optim
import numpy as np
import logging
# Configure logging for loss monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class Azure(optim.Optimizer):
def __init__(self, params, lr=0.0007518383921113902, T0=2.2723218904585964, sigma=0.17181058166567398,
betas=(0.9, 0.999), eps=1e-8, sa_steps=5, sa_momentum=0.6612913488540948, clip_grad_norm=1.0):
"""
Azure Sky Optimizer: A hybrid optimizer combining Simulated Annealing (SA) and Adam.
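        The combined update at each step is
            update = alpha * adam_update + (1 - alpha) * sa_update
        where alpha follows a sigmoid schedule from ~0 (pure SA exploration)
        to ~1 (pure Adam exploitation), and the SA perturbation is Gaussian
        noise scaled by sigma and a decaying temperature T.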
Args:
params (iterable): Iterable of parameters or dicts defining parameter groups.
lr (float): Learning rate for Adam phase (default: 0.0007518383921113902).
T0 (float): Initial temperature for SA (default: 2.2723218904585964).
sigma (float): Perturbation strength for SA (default: 0.17181058166567398).
betas (tuple): Adam's exponential decay rates (default: (0.9, 0.999)).
eps (float): Adam's epsilon for numerical stability (default: 1e-8).
sa_steps (int): Number of steps for SA phase (default: 5).
sa_momentum (float): Momentum for SA updates (default: 0.6612913488540948).
clip_grad_norm (float): Max norm for gradient clipping (default: 1.0).
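        Example (illustrative; any torch.nn.Module works here):
            >>> model = torch.nn.Linear(10, 1)
            >>> optimizer = Azure(model.parameters(), lr=1e-3)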
"""
# Process params to handle various input formats
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            # Handle parameter groups (e.g., [{'params': ..., 'lr': ...}, ...])
            param_groups = []
            for group in params:
                group_dict = group.copy()
                if 'params' not in group_dict:
                    raise ValueError("Each parameter group must contain a 'params' key")
                # Convert named_parameters() output to a list of parameters if necessary
                if isinstance(group_dict['params'], (list, tuple)) and isinstance(group_dict['params'][0], tuple):
                    group_dict['params'] = [p for _, p in group_dict['params']]
                param_groups.append(group_dict)
            params = param_groups
        else:
            # Handle direct parameter lists or named_parameters()
            if isinstance(params, (list, tuple)) and isinstance(params[0], tuple):
                params = [p for _, p in params]  # Convert named_parameters() to a parameter list
            params = [{'params': params}]
# Set defaults for each parameter group
        defaults = dict(lr=lr, T0=T0, sigma=sigma, betas=betas, eps=eps, sa_steps=sa_steps,
                        sa_momentum=sa_momentum, clip_grad_norm=clip_grad_norm)
        super().__init__(params, defaults)
        self.step_count = 0
        self.sa_active = True             # SA phase runs for the first `sa_steps` steps
        self.losses = []                  # Rolling window of recent loss values
        self.loss_window = 5              # Number of recent losses kept for spike detection
        self.loss_spike_threshold = 10.0  # Warn when loss exceeds this multiple of the window average
def step(self, closure=None):
"""Performs a single optimization step."""
loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Loss spike monitoring
        if loss is not None:
            self._monitor_loss(loss.item())
        for group in self.param_groups:
            # Gradient clipping
            if group['clip_grad_norm'] is not None:
                torch.nn.utils.clip_grad_norm_(group['params'], group['clip_grad_norm'])
            # Temperature and the exploration-exploitation weight depend only
            # on the group and the step count, so compute them once per group.
            T = self._compute_temperature(group)
            alpha = self._compute_alpha(group)
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                # SA update: Gaussian perturbation scaled by temperature
                if self.sa_active:
                    sa_update = torch.randn_like(p.data) * group['sigma'] * T
                else:
                    sa_update = torch.zeros_like(p.data)
                # Adam update
                state = self.state[p]
                if 'm' not in state:
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['step'] = 0
                m, v = state['m'], state['v']
                beta1, beta2 = group['betas']
                state['step'] += 1
                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                m_hat = m / (1 - beta1 ** state['step'])
                v_hat = v / (1 - beta2 ** state['step'])
                # Use the group-specific learning rate if provided
                lr = group.get('lr', self.defaults['lr'])
                adam_update = -lr * m_hat / (v_hat.sqrt() + group['eps'])
                # Blend the two updates: alpha -> 1 shifts weight from SA to Adam
                update = alpha * adam_update + (1 - alpha) * sa_update
                p.data.add_(update)
        self.step_count += 1
        # Disable the SA phase once its budget of steps is exhausted
        if self.step_count >= self.param_groups[0]['sa_steps']:
            self.sa_active = False
        return loss
def _compute_temperature(self, group):
"""Dynamic Temperature Scaling based on step progress."""
        epoch_decay = 0.05  # Adjustable decay rate
        return group['T0'] * (1.0 / (1.0 + epoch_decay * self.step_count))
def _compute_alpha(self, group):
"""Exploration-Exploitation Fusion Schedule using sigmoid."""
        midpoint = group['sa_steps'] / 2
        return 1 / (1 + np.exp(-(self.step_count - midpoint) / (midpoint / 5)))
def _monitor_loss(self, loss):
"""Monitors for loss spikes and logs warnings."""
        self.losses.append(loss)
        if len(self.losses) > self.loss_window:
            self.losses.pop(0)
        # Need at least two recorded losses to form a window average
        if len(self.losses) < 2:
            return
        avg_loss = sum(self.losses[:-1]) / (len(self.losses) - 1)
        current_loss = self.losses[-1]
        if current_loss > avg_loss * self.loss_spike_threshold:
            logger.warning(
                f"Loss spike detected: {current_loss:.4f} > {avg_loss:.4f} * {self.loss_spike_threshold}")