import torch
import torch.optim as optim
import numpy as np
import logging

# Configure logging for loss monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class Azure(optim.Optimizer):
    def __init__(self, params, lr=0.0007518383921113902, T0=2.2723218904585964, sigma=0.17181058166567398,
                 betas=(0.9, 0.999), eps=1e-8, sa_steps=5, sa_momentum=0.6612913488540948, clip_grad_norm=1.0):
        """
        Azure Sky Optimizer: a hybrid optimizer combining Simulated Annealing (SA) and Adam.

        Args:
            params (iterable): Iterable of parameters or dicts defining parameter groups.
            lr (float): Learning rate for the Adam phase (default: 0.0007518383921113902).
            T0 (float): Initial temperature for SA (default: 2.2723218904585964).
            sigma (float): Perturbation strength for SA (default: 0.17181058166567398).
            betas (tuple): Adam's exponential decay rates (default: (0.9, 0.999)).
            eps (float): Adam's epsilon for numerical stability (default: 1e-8).
            sa_steps (int): Number of steps for the SA phase (default: 5).
            sa_momentum (float): Momentum for SA updates (default: 0.6612913488540948).
            clip_grad_norm (float): Max norm for gradient clipping (default: 1.0).
        """
        # Process params to handle various input formats
        if isinstance(params, (list, tuple)) and isinstance(params[0], dict):
            # Handle parameter groups (e.g., [{'params': ..., 'lr': ...}, ...])
            param_groups = []
            for group in params:
                group_dict = group.copy()
                if 'params' not in group_dict:
                    raise ValueError("Each parameter group must contain a 'params' key")
                # Convert named_parameters() output to a list of parameters if necessary
                if isinstance(group_dict['params'], (list, tuple)) and isinstance(group_dict['params'][0], tuple):
                    group_dict['params'] = [p for _, p in group_dict['params']]
                param_groups.append(group_dict)
            params = param_groups
        else:
            # Handle direct parameter lists or named_parameters()
            if isinstance(params, (list, tuple)) and isinstance(params[0], tuple):
                params = [p for _, p in params]  # Convert (name, param) tuples to a parameter list
            params = [{'params': params}]
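
        # Accepted call forms (illustrative, assuming a torch.nn.Module `model`; not part of the original code):
        #   Azure(model.parameters())
        #   Azure(list(model.named_parameters()))                       # (name, param) tuples are unpacked above
        #   Azure([{'params': list(model.parameters()), 'lr': 1e-3}])   # explicit parameter groups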

        # Set defaults for each parameter group
        defaults = dict(lr=lr, T0=T0, sigma=sigma, betas=betas, eps=eps, sa_steps=sa_steps,
                        sa_momentum=sa_momentum, clip_grad_norm=clip_grad_norm)
        super().__init__(params, defaults)

        self.step_count = 0
        self.sa_active = True
        self.losses = []
        self.loss_window = 5
        self.loss_spike_threshold = 10.0

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Loss spike monitoring
        if loss is not None:
            self._monitor_loss(loss.item())

        for group in self.param_groups:
            # Gradient clipping
            if group['clip_grad_norm'] is not None:
                torch.nn.utils.clip_grad_norm_(group['params'], group['clip_grad_norm'])

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                # Dynamic Temperature Scaling
                T = self._compute_temperature(group)
                # Exploration-Exploitation Fusion
                alpha = self._compute_alpha(group)

                if self.sa_active:
                    noise = torch.randn_like(p.data) * group['sigma'] * T
                    sa_update = noise
                else:
                    sa_update = torch.zeros_like(p.data)

                # Adam update
                state = self.state[p]
                if 'm' not in state:
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['step'] = 0
                m, v = state['m'], state['v']
                beta1, beta2 = group['betas']
                state['step'] += 1
                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                m_hat = m / (1 - beta1 ** state['step'])
                v_hat = v / (1 - beta2 ** state['step'])
                # Use the group-specific learning rate if provided
                lr = group.get('lr', self.defaults['lr'])
                adam_update = -lr * m_hat / (v_hat.sqrt() + group['eps'])

                # Combined update
                update = alpha * adam_update + (1 - alpha) * sa_update
                p.data.add_(update)

        self.step_count += 1
        if self.step_count >= self.param_groups[0]['sa_steps']:
            self.sa_active = False
        return loss
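
    # Restating the combined update implemented above (comment added for intuition):
    #   p <- p + alpha * (-lr * m_hat / (sqrt(v_hat) + eps)) + (1 - alpha) * sigma * T * z,  z ~ N(0, I)
    # where alpha follows the sigmoid schedule in _compute_alpha and T the decay in
    # _compute_temperature, so early steps are dominated by SA noise and later steps by Adam.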

    def _compute_temperature(self, group):
        """Dynamic Temperature Scaling based on step progress."""
        epoch_decay = 0.05  # Adjustable decay rate
        return group['T0'] * (1.0 / (1.0 + epoch_decay * self.step_count))
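
    # Illustrative values (not part of the original code): with the defaults
    # T0 ~= 2.27 and epoch_decay = 0.05, T = T0 / (1 + 0.05 * step_count) gives
    # T = T0 at step 0 and T ~= T0 / 1.25 at step 5, by which point the SA phase
    # has already ended under the default sa_steps = 5.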

    def _compute_alpha(self, group):
        """Exploration-Exploitation Fusion schedule using a sigmoid."""
        midpoint = group['sa_steps'] / 2
        return 1 / (1 + np.exp(-(self.step_count - midpoint) / (midpoint / 5)))
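
    # Illustrative values (not part of the original code): with the default
    # sa_steps = 5 the midpoint is 2.5 and the sigmoid width is midpoint / 5 = 0.5,
    # so alpha ~= 0.007 at step 0 (almost pure SA noise), crosses 0.5 at step 2.5,
    # and reaches ~0.99 by step 5 (almost pure Adam).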

    def _monitor_loss(self, loss):
        """Monitors for loss spikes and logs warnings."""
        self.losses.append(loss)
        # Only start checking once the rolling window is full, so the average is
        # taken over a meaningful history (and the division below is safe).
        if len(self.losses) > self.loss_window:
            self.losses.pop(0)
            avg_loss = sum(self.losses[:-1]) / (len(self.losses) - 1)
            current_loss = self.losses[-1]
            if current_loss > avg_loss * self.loss_spike_threshold:
                logger.warning(
                    f"Loss spike detected: {current_loss:.4f} > {avg_loss:.4f} * {self.loss_spike_threshold}")