File size: 18,650 Bytes
174ae06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import math
from collections import OrderedDict, defaultdict
from copy import deepcopy
from itertools import chain
from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Optional, Tuple, Union

import qoptim_cuda
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing_extensions import ParamSpec, Self, TypeAlias

StateDict: TypeAlias = Dict[str, Any]

convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}


class CoatAdamW(Optimizer):
    """AdamW variant that keeps its moment estimates in FP8 (COAT).

    Per-parameter state consists of FP8 ``exp_avg`` / ``exp_avg_sq`` tensors
    (formats chosen by ``qargs.first_order_bit`` / ``qargs.second_order_bit``)
    plus one scale value per ``qargs.qgroup_size`` elements. When
    ``qargs.first_order_expansion`` is ``"expansion"`` or ``"true"``, extra
    per-group ``expand_*`` and ``sqrt_minmax_*`` buffers are kept as well.
    The actual update is performed by :func:`Coatadamw`.

    Args:
        qargs: quantization config; read attributes are
            ``first_order_expansion``, ``second_order_expansion``,
            ``first_order_bit``, ``second_order_bit``, ``qgroup_size`` and
            ``expand_min``.
        params: iterable of parameters or param-group dicts.
        lr, betas, eps, weight_decay: standard AdamW hyper-parameters.
        amsgrad: stored in each param group and forwarded to :func:`Coatadamw`.
            NOTE: the FP8 kernel calls in this file do not take
            ``max_exp_avg_sq``.
        fused: stored in each param group and forwarded to :func:`Coatadamw`.

    Raises:
        ValueError: on an out-of-range hyper-parameter.
    """

    def __init__(
        self,
        qargs,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        *,
        fused: Optional[bool] = None,
    ):
        self.qargs = qargs
        # ``step`` derives ``use_expansion`` from the first-order setting only,
        # so both moments must agree on it.
        assert self.qargs.first_order_expansion == self.qargs.second_order_expansion
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            fused=fused,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state and normalize older entries.

        Fills in missing ``amsgrad`` / ``fused`` group options and converts any
        non-tensor ``step`` counter into a float32 tensor.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, {})
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = torch.tensor(step_val, dtype=torch.float32)

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        use_expansion,
        exp_avgs,
        scale_exp_avgs,
        expand_exp_avgs,
        sqrt_minmax_exp_avgs,
        exp_avg_sqs,
        scale_exp_avg_sqs,
        expand_exp_avg_sqs,
        sqrt_minmax_exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        """Gather the tensors of ``group`` that take part in this step.

        Per-parameter state is created lazily on first use. For every parameter
        that has a gradient, the parameter, its gradient and its state tensors
        are appended to the corresponding output lists (mutated in place).

        Raises:
            RuntimeError: if a gradient is sparse.
            KeyError: if ``qargs.first_order_bit`` / ``second_order_bit`` is not
                a key of ``convert_str_to_fp8``.
        """
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # The step counter lives on the default device (normally CPU):
                # launching kernels for a scalar counter is costly on CUDA/XLA.
                state["step"] = torch.tensor(0.0)

                first_order_dtype = convert_str_to_fp8[self.qargs.first_order_bit]
                second_order_dtype = convert_str_to_fp8[self.qargs.second_order_bit]
                # One scale entry per qgroup_size elements (ceiling division).
                scale_shape = (p.numel() + self.qargs.qgroup_size - 1) // self.qargs.qgroup_size

                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p, dtype=first_order_dtype, memory_format=torch.preserve_format)
                state["scale_exp_avg"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
                if use_expansion:
                    state["expand_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
                    state["sqrt_minmax_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p, dtype=second_order_dtype, memory_format=torch.preserve_format)
                state["scale_exp_avg_sq"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
                if use_expansion:
                    state["expand_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
                    state["sqrt_minmax_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
                if amsgrad:
                    # Max of all exp. moving avg. of sq. grad. values.
                    # (torch.zeros expects a size, not a tensor, hence zeros_like.)
                    state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

            exp_avgs.append(state["exp_avg"])
            scale_exp_avgs.append(state["scale_exp_avg"])
            if use_expansion:
                expand_exp_avgs.append(state["expand_exp_avg"])
                sqrt_minmax_exp_avgs.append(state["sqrt_minmax_exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            scale_exp_avg_sqs.append(state["scale_exp_avg_sq"])
            if use_expansion:
                expand_exp_avg_sqs.append(state["expand_exp_avg_sq"])
                sqrt_minmax_exp_avg_sqs.append(state["sqrt_minmax_exp_avg_sq"])

            # Use the same flag that decided whether the buffer was allocated.
            if amsgrad:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            state_steps.append(state["step"])

    @torch._disable_dynamo
    def load_state_dict(self, state_dict: StateDict) -> None:
        r"""Loads the optimizer state.

        Unlike the base class, state tensors keep their saved dtype (so FP8
        moments are not up-cast); they are only moved to the parameter's device.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: if the number or sizes of parameter groups differ.
        """
        # shallow copy, to be consistent with module API
        state_dict = state_dict.copy()

        for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
            hook_result = pre_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result

        # Validate the state_dict
        groups = self.param_groups

        # Deepcopy as we write into saved_groups later to update state
        saved_groups = deepcopy(state_dict["param_groups"])

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of " "parameter groups")
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
            )

        # Map saved parameter ids to the live parameter tensors.
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups)
            )
        )

        def _cast(param, value, param_id=None, param_groups=None, key=None):
            r"""Make a deep copy of value, moving all tensors to the device of param."""
            if isinstance(value, torch.Tensor):
                return CoatAdamW._process_value_according_to_param_policy(param, value, param_id, param_groups, key)
            elif isinstance(value, dict):
                return {
                    k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()
                }
            elif isinstance(value, Iterable):
                return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)  # type: ignore[call-arg]
            else:
                return value

        # Copy state assigned to params (and move tensors to the right device).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({"state": state, "param_groups": param_groups})

        for post_hook in self._optimizer_load_state_dict_post_hooks.values():
            post_hook(self)

    @staticmethod
    def _process_value_according_to_param_policy(
        param: torch.Tensor,
        value: torch.Tensor,
        param_id: int,
        param_groups: List[Dict[Any, Any]],
        key: Hashable = None,
    ) -> torch.Tensor:
        """Place one saved state tensor for ``param`` without changing its dtype.

        ``step`` is moved to float32 on the param's device only for fused or
        capturable groups (see pytorch/pytorch#74424); otherwise it is returned
        untouched. Every other tensor keeps its dtype and is moved to
        ``param.device``.
        """
        fused = False
        capturable = False
        assert param_groups is not None
        for pg in param_groups:
            if param_id in pg["params"]:
                fused = pg["fused"] if "fused" in pg else False
                capturable = pg["capturable"] if "capturable" in pg else False
                break
        if key == "step":
            if capturable or fused:
                return value.to(dtype=torch.float32, device=param.device)
            else:
                return value
        else:
            # FP8 moments, or scale/expand/max buffers that _init_group creates
            # in the parameter's dtype (which need not be float32).
            assert value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2, torch.float32, param.dtype]
            return value.to(device=param.device)  # do not cast optimizer states

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Depends only on qargs, so it is the same for every group.
        use_expansion = self.qargs.first_order_expansion in ["expansion", "true"]

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            scale_exp_avgs = []
            expand_exp_avgs = []
            sqrt_minmax_exp_avgs = []
            exp_avg_sqs = []
            scale_exp_avg_sqs = []
            expand_exp_avg_sqs = []
            sqrt_minmax_exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]

            self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                use_expansion,
                exp_avgs,
                scale_exp_avgs,
                expand_exp_avgs,
                sqrt_minmax_exp_avgs,
                exp_avg_sqs,
                scale_exp_avg_sqs,
                expand_exp_avg_sqs,
                sqrt_minmax_exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            Coatadamw(
                self.qargs,
                params_with_grad,
                grads,
                exp_avgs,
                scale_exp_avgs,
                expand_exp_avgs,
                sqrt_minmax_exp_avgs,
                exp_avg_sqs,
                scale_exp_avg_sqs,
                expand_exp_avg_sqs,
                sqrt_minmax_exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                use_expansion=use_expansion,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                qgroup_size=self.qargs.qgroup_size,
                expand_min=self.qargs.expand_min,
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


def Coatadamw(
    qargs,
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    scale_exp_avgs: List[Tensor],
    expand_exp_avgs: List[Tensor],
    sqrt_minmax_exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    scale_exp_avg_sqs: List[Tensor],
    expand_exp_avg_sqs: List[Tensor],
    sqrt_minmax_exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    amsgrad: bool,
    use_expansion: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    qgroup_size: int,
    expand_min: int,
):
    r"""Functional entry point for the COAT FP8 AdamW update.

    Checks that every entry of ``state_steps`` is a tensor (skipped while
    compiling) and forwards all other arguments to
    :func:`_single_tensor_Coatadamw`. ``fused`` is accepted but not forwarded.

    See :class:`~torch.optim.AdamW` for the underlying algorithm.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-tensor entry.
    """
    if not torch._utils.is_compiling():
        if any(not isinstance(step_t, torch.Tensor) for step_t in state_steps):
            raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    _single_tensor_Coatadamw(
        qargs,
        params,
        grads,
        exp_avgs,
        scale_exp_avgs,
        expand_exp_avgs,
        sqrt_minmax_exp_avgs,
        exp_avg_sqs,
        scale_exp_avg_sqs,
        expand_exp_avg_sqs,
        sqrt_minmax_exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        use_expansion=use_expansion,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        qgroup_size=qgroup_size,
        expand_min=expand_min,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript type inference
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()
    else:
        return sqrt(x)


def _single_tensor_Coatadamw(
    qargs,
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    scale_exp_avgs: List[Tensor],
    expand_exp_avgs: List[Tensor],
    sqrt_minmax_exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    scale_exp_avg_sqs: List[Tensor],
    expand_exp_avg_sqs: List[Tensor],
    sqrt_minmax_exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    use_expansion: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    qgroup_size: int,
    expand_min: int,
):

    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i]
        # First order
        exp_avg = exp_avgs[i]
        scale_exp_avg = scale_exp_avgs[i]
        # Second order
        exp_avg_sq = exp_avg_sqs[i]
        scale_exp_avg_sq = scale_exp_avg_sqs[i]
        step_t = state_steps[i]

        # print(len(exp_avg.unique()), len(exp_avg_sq.unique()))
        # print(f"{param.shape}, {grad.shape}, {exp_avg.shape}, {exp_avg_sq.shape}", file=open('debug.txt', 'a'))

        # update step
        step_t += 1
        step = int(step_t.item())

        # Perform Optimizer Step
        if use_expansion:
            expand_exp_avg = expand_exp_avgs[i]
            sqrt_minmax_exp_avg = sqrt_minmax_exp_avgs[i]
            expand_exp_avg_sq = expand_exp_avg_sqs[i]
            sqrt_minmax_exp_avg_sq = sqrt_minmax_exp_avg_sqs[i]

            qoptim_cuda.fp8_adamw_expand_step(
                param,
                grad,
                exp_avg,
                scale_exp_avg,
                expand_exp_avg,
                sqrt_minmax_exp_avg,
                exp_avg_sq,
                scale_exp_avg_sq,
                expand_exp_avg_sq,
                sqrt_minmax_exp_avg_sq,
                beta1,
                beta2,
                lr,
                weight_decay,
                eps,
                step,
                qgroup_size,
                expand_min,
            )

        else:
            qoptim_cuda.fp8_adamw_step(
                param,
                grad,
                exp_avg,
                scale_exp_avg,
                exp_avg_sq,
                scale_exp_avg_sq,
                beta1,
                beta2,
                lr,
                weight_decay,
                eps,
                step,
                qgroup_size,
            )