import torch
import torch.optim as optim
import numpy as np
import logging

# Configure logging for loss monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class Azure(optim.Optimizer):
    def __init__(self, params, lr=0.0007518383921113902, T0=2.2723218904585964, sigma=0.17181058166567398,
                 betas=(0.9, 0.999), eps=1e-8, sa_steps=5, sa_momentum=0.6612913488540948, clip_grad_norm=1.0):
        """
        Azure Sky Optimizer: a hybrid optimizer combining Simulated Annealing (SA) and Adam.

        Args:
            params (iterable): Iterable of parameters or dicts defining parameter groups.
            lr (float): Learning rate for the Adam phase (default: 0.0007518383921113902).
            T0 (float): Initial temperature for SA (default: 2.2723218904585964).
            sigma (float): Perturbation strength for SA (default: 0.17181058166567398).
            betas (tuple): Adam's exponential decay rates (default: (0.9, 0.999)).
            eps (float): Adam's epsilon for numerical stability (default: 1e-8).
            sa_steps (int): Number of steps for the SA phase (default: 5).
            sa_momentum (float): Momentum for SA updates (default: 0.6612913488540948).
            clip_grad_norm (float): Max norm for gradient clipping (default: 1.0).
        """
        # Process params to handle various input formats
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            # Handle parameter groups (e.g., [{'params': ..., 'lr': ...}, ...])
            param_groups = []
            for group in params:
                group_dict = group.copy()
                if 'params' not in group_dict:
                    raise ValueError("Each parameter group must contain a 'params' key")
                # Convert named_parameters() output to a list of parameters if necessary
                if isinstance(group_dict['params'], (list, tuple)) and isinstance(group_dict['params'][0], tuple):
                    group_dict['params'] = [p for _, p in group_dict['params']]
                param_groups.append(group_dict)
            params = param_groups
        else:
            # Handle direct parameter lists or named_parameters()
            if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], tuple):
                params = [p for _, p in params]  # Convert named_parameters() to a parameter list
            params = [{'params': params}]
        # Set defaults for each parameter group
        defaults = dict(lr=lr, T0=T0, sigma=sigma, betas=betas, eps=eps, sa_steps=sa_steps,
                        sa_momentum=sa_momentum, clip_grad_norm=clip_grad_norm)
        super().__init__(params, defaults)
        self.step_count = 0
        self.sa_active = True
        # Rolling window of recent losses for spike detection
        self.losses = []
        self.loss_window = 5
        self.loss_spike_threshold = 10.0

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Loss spike monitoring
        if loss is not None:
            self._monitor_loss(loss.item())
        for group in self.param_groups:
            # Gradient clipping
            if group['clip_grad_norm'] is not None:
                torch.nn.utils.clip_grad_norm_(group['params'], group['clip_grad_norm'])
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                # Dynamic Temperature Scaling
                T = self._compute_temperature(group)
                # Exploration-Exploitation Fusion
                alpha = self._compute_alpha(group)

                # SA phase: random perturbation scaled by sigma and the current temperature.
                # (Note: sa_momentum is stored in the defaults but not used in this update.)
                if self.sa_active:
                    sa_update = torch.randn_like(p.data) * group['sigma'] * T
                else:
                    sa_update = torch.zeros_like(p.data)

                # Adam update
                state = self.state[p]
                if 'm' not in state:
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['step'] = 0
                m, v = state['m'], state['v']
                beta1, beta2 = group['betas']
                state['step'] += 1
                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # Bias-corrected first and second moment estimates
                m_hat = m / (1 - beta1 ** state['step'])
                v_hat = v / (1 - beta2 ** state['step'])
                # Use the group-specific learning rate if provided
                lr = group.get('lr', self.defaults['lr'])
                adam_update = -lr * m_hat / (v_hat.sqrt() + group['eps'])

                # Combined update: alpha blends the Adam step with the SA perturbation
                update = alpha * adam_update + (1 - alpha) * sa_update
                p.data.add_(update)

        self.step_count += 1
        # Disable the SA phase once sa_steps optimization steps have elapsed
        if self.step_count >= self.param_groups[0]['sa_steps']:
            self.sa_active = False
        return loss
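
    # In short, step() blends exploration and exploitation: while sa_active, each
    # parameter receives (1 - alpha) of a tempered random perturbation plus alpha
    # of a bias-corrected Adam step; after sa_steps the noise term is zero and
    # alpha tends to 1, so updates approach a pure Adam step.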

    def _compute_temperature(self, group):
        """Dynamic Temperature Scaling based on step progress."""
        epoch_decay = 0.05  # Adjustable decay rate
        return group['T0'] * (1.0 / (1.0 + epoch_decay * self.step_count))
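
    # Illustrative decay (with the default T0 ≈ 2.27 and epoch_decay = 0.05):
    # step 0 -> T ≈ 2.27; step 20 -> T ≈ 1.14; step 100 -> T ≈ 0.38.
    # The SA perturbation therefore shrinks smoothly as training progresses.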

    def _compute_alpha(self, group):
        """Exploration-Exploitation Fusion schedule: a sigmoid over the step count."""
        midpoint = group['sa_steps'] / 2
        return 1 / (1 + np.exp(-(self.step_count - midpoint) / (midpoint / 5)))
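
    # Illustrative schedule (with the default sa_steps=5, so midpoint=2.5):
    # step 0 -> alpha ≈ 0.007 (update dominated by SA noise); step 2 -> alpha ≈ 0.27;
    # step 5 -> alpha ≈ 0.99 (update dominated by Adam).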

    def _monitor_loss(self, loss):
        """Monitors for loss spikes and logs warnings."""
        self.losses.append(loss)
        if len(self.losses) > self.loss_window:
            self.losses.pop(0)
        # Need at least two entries to compare the current loss against an average
        if len(self.losses) < 2:
            return
        avg_loss = sum(self.losses[:-1]) / (len(self.losses) - 1)
        current_loss = self.losses[-1]
        if current_loss > avg_loss * self.loss_spike_threshold:
            logger.warning(
                f"Loss spike detected: {current_loss:.4f} > {avg_loss:.4f} * {self.loss_spike_threshold}")