import torch
import torch.optim as optim
import numpy as np
import logging

# Configure logging for loss monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class Azure(optim.Optimizer):
    def __init__(self, params, lr=0.0007518383921113902, T0=2.2723218904585964,
                 sigma=0.17181058166567398, betas=(0.9, 0.999), eps=1e-8,
                 sa_steps=5, sa_momentum=0.6612913488540948, clip_grad_norm=1.0):
        """
        Azure Sky Optimizer: A hybrid optimizer combining Simulated Annealing (SA) and Adam.

        Args:
            params (iterable): Iterable of parameters or dicts defining parameter groups.
            lr (float): Learning rate for the Adam phase (default: 0.0007518383921113902).
            T0 (float): Initial temperature for SA (default: 2.2723218904585964).
            sigma (float): Perturbation strength for SA (default: 0.17181058166567398).
            betas (tuple): Adam's exponential decay rates (default: (0.9, 0.999)).
            eps (float): Adam's epsilon for numerical stability (default: 1e-8).
            sa_steps (int): Number of steps for the SA phase (default: 5).
            sa_momentum (float): Momentum for SA updates (default: 0.6612913488540948).
            clip_grad_norm (float): Max norm for gradient clipping (default: 1.0).
        """
        # Process params to handle various input formats
        if isinstance(params, (list, tuple)) and isinstance(params[0], dict):
            # Handle parameter groups (e.g., [{'params': ..., 'lr': ...}, ...])
            param_groups = []
            for group in params:
                group_dict = group.copy()
                if 'params' not in group_dict:
                    raise ValueError("Each parameter group must contain a 'params' key")
                # Convert named_parameters() to a list of parameters if necessary
                if isinstance(group_dict['params'], (list, tuple)) and isinstance(group_dict['params'][0], tuple):
                    group_dict['params'] = [p for _, p in group_dict['params']]
                param_groups.append(group_dict)
            params = param_groups
        else:
            # Handle direct parameter lists or named_parameters()
            if isinstance(params, (list, tuple)) and isinstance(params[0], tuple):
                # Convert named_parameters() to a parameter list
                params = [p for _, p in params]
            params = [{'params': params}]

        # Set defaults for each parameter group
        defaults = dict(lr=lr, T0=T0, sigma=sigma, betas=betas, eps=eps,
                        sa_steps=sa_steps, sa_momentum=sa_momentum,
                        clip_grad_norm=clip_grad_norm)
        super().__init__(params, defaults)

        self.step_count = 0
        self.sa_active = True
        self.losses = []
        self.loss_window = 5
        self.loss_spike_threshold = 10.0

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Loss spike monitoring
        if loss is not None:
            self._monitor_loss(loss.item())

        for group in self.param_groups:
            # Gradient clipping
            if group['clip_grad_norm'] is not None:
                torch.nn.utils.clip_grad_norm_(group['params'], group['clip_grad_norm'])

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                # Dynamic Temperature Scaling
                T = self._compute_temperature(group)

                # Exploration-Exploitation Fusion coefficient
                alpha = self._compute_alpha(group)

                # SA perturbation: temperature-scaled Gaussian noise while SA is active
                if self.sa_active:
                    noise = torch.randn_like(p.data) * group['sigma'] * T
                    sa_update = noise
                else:
                    sa_update = torch.zeros_like(p.data)

                # Adam update
                state = self.state[p]
                if 'm' not in state:
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['step'] = 0

                m, v = state['m'], state['v']
                beta1, beta2 = group['betas']
                state['step'] += 1

                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias-corrected first and second moment estimates
                m_hat = m / (1 - beta1 ** state['step'])
                v_hat = v / (1 - beta2 ** state['step'])

                # Use the group-specific learning rate if provided
                lr = group.get('lr', self.defaults['lr'])
                adam_update = -lr * m_hat / (v_hat.sqrt() + group['eps'])

                # Combined update: blend Adam (exploitation) with SA noise (exploration)
                update = alpha * adam_update + (1 - alpha) * sa_update
                p.data.add_(update)

        self.step_count += 1
        if self.step_count >= self.param_groups[0]['sa_steps']:
            self.sa_active = False

        return loss

    def _compute_temperature(self, group):
        """Dynamic Temperature Scaling based on step progress."""
        epoch_decay = 0.05  # Adjustable decay rate
        return group['T0'] * (1.0 / (1.0 + epoch_decay * self.step_count))

    def _compute_alpha(self, group):
        """Exploration-Exploitation Fusion schedule using a sigmoid."""
        midpoint = group['sa_steps'] / 2
        return 1 / (1 + np.exp(-(self.step_count - midpoint) / (midpoint / 5)))

    def _monitor_loss(self, loss):
        """Monitors for loss spikes and logs warnings."""
        self.losses.append(loss)
        if len(self.losses) > self.loss_window:
            self.losses.pop(0)
        if len(self.losses) < 2:
            # Not enough history to compare against yet
            return
        avg_loss = sum(self.losses[:-1]) / (len(self.losses) - 1)
        current_loss = self.losses[-1]
        if current_loss > avg_loss * self.loss_spike_threshold:
            logger.warning(
                f"Loss spike detected: {current_loss:.4f} > {avg_loss:.4f} * {self.loss_spike_threshold}")
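# Minimal usage sketch (not part of the optimizer itself): the model, batch, and
# training loop below are hypothetical placeholders assuming a standard PyTorch
# setup. Azure is driven like any torch.optim optimizer; passing a closure to
# step() lets the built-in loss-spike monitor see each step's loss.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 1)               # hypothetical model
    criterion = nn.MSELoss()
    optimizer = Azure(model.parameters())  # uses the tuned defaults above

    inputs = torch.randn(32, 10)           # hypothetical batch
    targets = torch.randn(32, 1)

    for _ in range(20):
        def closure():
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            return loss

        # SA-blended updates for the first sa_steps steps, then pure Adam
        loss = optimizer.step(closure)
        logger.info(f"loss: {loss.item():.4f}")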