import torch
import torch.optim as optim
import numpy as np
import logging

# Configure logging for loss monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class Azure(optim.Optimizer):
    def __init__(self, params, lr=0.0007518383921113902, T0=2.2723218904585964, sigma=0.17181058166567398,
                 betas=(0.9, 0.999), eps=1e-8, sa_steps=5, sa_momentum=0.6612913488540948, clip_grad_norm=1.0):
        """

        Azure Sky Optimizer: A hybrid optimizer combining Simulated Annealing (SA) and Adam.



        Args:

            params (iterable): Iterable of parameters or dicts defining parameter groups.

            lr (float): Learning rate for Adam phase (default: 0.0007518383921113902).

            T0 (float): Initial temperature for SA (default: 2.2723218904585964).

            sigma (float): Perturbation strength for SA (default: 0.17181058166567398).

            betas (tuple): Adam's exponential decay rates (default: (0.9, 0.999)).

            eps (float): Adam's epsilon for numerical stability (default: 1e-8).

            sa_steps (int): Number of steps for SA phase (default: 5).

            sa_momentum (float): Momentum for SA updates (default: 0.6612913488540948).

            clip_grad_norm (float): Max norm for gradient clipping (default: 1.0).

        """
        # Process params to handle various input formats
        if isinstance(params, (list, tuple)) and isinstance(params[0], dict):
            # Handle parameter groups (e.g., [{'params': ..., 'lr': ...}, ...])
            param_groups = []
            for group in params:
                group_dict = group.copy()
                if 'params' not in group_dict:
                    raise ValueError("Each parameter group must contain a 'params' key")
                # Convert named_parameters() to a list of parameters if necessary
                if isinstance(group_dict['params'], (list, tuple)) and isinstance(group_dict['params'][0], tuple):
                    group_dict['params'] = [p for _, p in group_dict['params']]
                param_groups.append(group_dict)
            params = param_groups
        else:
            # Handle direct parameter lists or named_parameters()
            if isinstance(params, (list, tuple)) and isinstance(params[0], tuple):
                params = [p for _, p in params]  # Convert named_parameters() to parameter list
            params = [{'params': params}]

        # Set defaults for each parameter group
        defaults = dict(lr=lr, T0=T0, sigma=sigma, betas=betas, eps=eps, sa_steps=sa_steps,
                        sa_momentum=sa_momentum, clip_grad_norm=clip_grad_norm)
        super().__init__(params, defaults)
        self.step_count = 0
        self.sa_active = True
        self.losses = []
        self.loss_window = 5
        self.loss_spike_threshold = 10.0

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Loss spike monitoring
        if loss is not None:
            self._monitor_loss(loss.item())

        for group in self.param_groups:
            # Gradient clipping
            if group['clip_grad_norm'] is not None:
                torch.nn.utils.clip_grad_norm_(group['params'], group['clip_grad_norm'])

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                # Dynamic Temperature Scaling
                T = self._compute_temperature(group)
                # Exploration-Exploitation Fusion
                alpha = self._compute_alpha(group)

                if self.sa_active:
                    # SA exploration: Gaussian perturbation scaled by sigma and the current temperature.
                    # Note: sa_momentum is stored in the defaults but is not applied in this update.
                    noise = torch.randn_like(p.data) * group['sigma'] * T
                    sa_update = noise
                else:
                    sa_update = torch.zeros_like(p.data)

                # Adam update
                state = self.state[p]
                if 'm' not in state:
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['step'] = 0
                m, v = state['m'], state['v']
                beta1, beta2 = group['betas']
                state['step'] += 1
                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                m_hat = m / (1 - beta1 ** state['step'])
                v_hat = v / (1 - beta2 ** state['step'])
                # Use group-specific learning rate if provided
                lr = group.get('lr', self.defaults['lr'])
                adam_update = -lr * m_hat / (v_hat.sqrt() + group['eps'])

                # Combined update: blend the Adam step and the SA perturbation via alpha
                update = alpha * adam_update + (1 - alpha) * sa_update
                p.data.add_(update)

        self.step_count += 1
        if self.step_count >= self.param_groups[0]['sa_steps']:
            self.sa_active = False
        return loss

    def _compute_temperature(self, group):
        """Dynamic Temperature Scaling based on step progress."""
        epoch_decay = 0.05  # Adjustable decay rate
        return group['T0'] * (1.0 / (1.0 + epoch_decay * self.step_count))

    def _compute_alpha(self, group):
        """Exploration-Exploitation Fusion Schedule using sigmoid."""
        midpoint = group['sa_steps'] / 2
        return 1 / (1 + np.exp(-(self.step_count - midpoint) / (midpoint / 5)))

    def _monitor_loss(self, loss):
        """Monitors for loss spikes and logs warnings."""
        self.losses.append(loss)
        if len(self.losses) > self.loss_window:
            self.losses.pop(0)
            avg_loss = sum(self.losses[:-1]) / (len(self.losses) - 1)
            current_loss = self.losses[-1]
            if current_loss > avg_loss * self.loss_spike_threshold:
                logger.warning(
                    f"Loss spike detected: {current_loss:.4f} > {avg_loss:.4f} * {self.loss_spike_threshold}")