File size: 8,939 Bytes
174ae06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import os
import time
from copy import deepcopy
from dataclasses import dataclass

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function

from ..utils import quant_get_local_rank
from ._division_transpose import fp8_division_transpose
from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose
from .linear import fp8_linear_backward, fp8_linear_forward


@dataclass
class DefaultArgs:
    """Fallback quantization settings used by ``FP8Linear`` when no ``args`` is supplied.

    ``QuantLinearTE`` hands each field unchanged to the FP8 quantizer:
    ``fabit`` for the layer input, ``fwbit`` for the weight and ``bobit`` for
    the output gradient.
    """

    # NOTE(review): annotated as int, but FP8Linear populates these straight
    # from os.environ, so at runtime they hold strings -- confirm intended type.
    fabit: int  # setting for quantizing the layer input
    fwbit: int  # setting for quantizing the weight
    bobit: int  # setting for quantizing the output gradient


class FP8Linear(nn.Linear):
    """Drop-in ``nn.Linear`` whose training-mode forward/backward run through FP8 kernels.

    While ``self.training`` is True the layer calls ``QuantLinearTE`` (FP8-quantized
    input, weight and output gradient); in eval mode it falls back to a regular
    ``F.linear`` on the unquantized parameters.

    Args:
        in_features, out_features, bias, device: forwarded unchanged to ``nn.Linear``.
        args: object exposing ``fabit`` / ``fwbit`` / ``bobit``, handed to the FP8
            quantizers for the input, weight and output gradient respectively.
            A deep copy is stored on the module. When ``None``, the values are read
            (as raw strings) from the ``FABIT_FP8Linear``, ``FWBIT_FP8Linear`` and
            ``BOBIT_FP8Linear`` environment variables.
        layer_idx: stored as ``self.layer_idx`` and shown in the rank-0 debug print.

    Raises:
        KeyError: if ``args`` is None and any of those environment variables is unset.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, args=None, layer_idx=0):
        super().__init__(in_features, out_features, bias, device)

        if args is None:
            # Read from the environment so model code (e.g. OLMo) that builds
            # these layers does not need a new constructor argument.
            args = self._args_from_env()
        self.args = deepcopy(args)

        if quant_get_local_rank() == 0:
            print(f"[qlinear debug] Apply QLinear, {layer_idx}")

        self.layer_idx = layer_idx
        # Forwarded to QuantLinearTE; stays None unless assigned externally
        # (presumably by the code that installs these layers -- confirm).
        self.layer_name = None

    @staticmethod
    def _args_from_env():
        """Build a ``DefaultArgs`` from the ``*_FP8Linear`` environment variables.

        Raises:
            KeyError: naming every required variable that is not set.
        """
        env_vars = {
            "fabit": "FABIT_FP8Linear",
            "fwbit": "FWBIT_FP8Linear",
            "bobit": "BOBIT_FP8Linear",
        }
        missing = [var for var in env_vars.values() if var not in os.environ]
        if missing:
            # Report all missing variables at once instead of failing on the first.
            raise KeyError(
                "FP8Linear was constructed without `args`, so its quantization settings "
                f"must come from the environment, but these variables are unset: {', '.join(missing)}"
            )
        # NOTE(review): values are passed through as raw strings even though
        # DefaultArgs annotates them as int -- confirm what the kernels expect.
        return DefaultArgs(**{field: os.environ[var] for field, var in env_vars.items()})

    def forward(self, Input):
        """Return ``Input @ weight.T + bias``: FP8-quantized path in training, full precision in eval."""
        if self.training:
            return QuantLinearTE.apply(Input, self.weight, self.bias, self.args, self.layer_name)
        return F.linear(Input, self.weight, self.bias)


# if int(os.environ.get("LOCAL_RANK")) == 0:
#     import IPython
#     IPython.embed()
# else:
#     import time
#     time.sleep(1000)

# class QuantLinearTE(Function):
#     @staticmethod
#     def forward(ctx, input, weight, bias, args, layer_type):
#         ctx.saved = input, weight, bias, args, layer_type
#         return F.linear(input, weight, bias)

#     @staticmethod
#     def backward(ctx, grad_output):
#         input, weight, bias, args, layer_type = ctx.saved

#         C_in = input.shape[-1]
#         C_out = grad_output.shape[-1]

#         grad_output_flatten = grad_output.reshape(-1, C_out)
#         input_flatten = input.reshape(-1, C_in)

#         if grad_output_flatten.dtype == input_flatten.dtype:
#             grad_weight = grad_output_flatten.t().mm(input_flatten)
#         else:
#             grad_weight = grad_output_flatten.float().t().mm(input_flatten)

#         if grad_output_flatten.dtype == weight.dtype:
#             grad_input = grad_output_flatten.mm(weight)
#         else:
#             grad_input = grad_output_flatten.float().mm(weight)

#         if bias is not None:
#             grad_bias = grad_output_flatten.sum(0)
#         else:
#             grad_bias = None

#         grad_input_transform = grad_input.reshape(input.size())

#         return grad_input_transform, grad_weight, grad_bias, None, None


class QuantLinearTE(Function):
    """Autograd function for a linear layer whose matmuls run on FP8-quantized operands.

    Forward quantizes the input (with ``args.fabit``) and the weight (with
    ``args.fwbit``) via ``fp8_quantize_pertensor_transpose`` and multiplies them
    with ``fp8_linear_forward``. Backward quantizes the incoming gradient (with
    ``args.bobit``) and obtains the input/weight gradients from
    ``fp8_linear_backward``; the bias gradient is a plain sum of the gradient.

    Setting the ``TIME_BENCH`` environment variable to any non-empty string
    (including ``"0"``) adds CUDA-event timing of each phase plus a BF16
    reference computation, synchronizes the device, and prints the timings on
    local rank 0. The BF16 results are discarded; returned values are unaffected.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16)
    def forward(ctx, input, weight, bias, args, layer_name):
        """Quantize ``input`` and ``weight`` to FP8 and return their linear product.

        Per the ``torch.amp.custom_fwd`` docs, inside an autocast region the
        floating-point tensor inputs are cast to bfloat16 before this body runs.
        ``args`` must expose ``fabit`` / ``fwbit`` / ``bobit``; ``layer_name`` is
        only stored for backward and is not otherwise used.
        """

        # Any non-empty value enables benchmarking -- including "0", because
        # os.getenv returns the raw string.
        time_bench = os.getenv("TIME_BENCH")

        if time_bench:
            start_1 = torch.cuda.Event(enable_timing=True)
            start_1.record()

        # Part 1: quantize the input with args.fabit. The transposed copy
        # (Qinput_t) is not used in forward; it is saved for backward.
        # NOTE(review): the literal 16 is a kernel parameter (presumably a
        # group/block size) -- confirm against _quantize_pertensor_transpose.
        # Alternative (disabled) kernel:
        # Qinput, Iscale, Qinput_t = fp8_division_transpose(input, 16, args.fabit)
        Qinput, Iscale, Qinput_t = fp8_quantize_pertensor_transpose(input, 16, args.fabit, transpose_output_2d=True)

        if time_bench:
            end_1 = torch.cuda.Event(enable_timing=True)
            end_1.record()
            start_2 = torch.cuda.Event(enable_timing=True)
            start_2.record()

        # Part 2: quantize the weight with args.fwbit; Qweight_t is kept for backward.
        # Alternative (disabled) kernel:
        # Qweight, Wscale, Qweight_t = fp8_division_transpose(weight, 16, args.fwbit)
        Qweight, Wscale, Qweight_t = fp8_quantize_pertensor_transpose(weight, 16, args.fwbit, transpose_output_2d=True)

        if time_bench:
            end_2 = torch.cuda.Event(enable_timing=True)
            end_2.record()
            start_3 = torch.cuda.Event(enable_timing=True)
            start_3.record()

        # Only the transposed FP8 copies and their scales are kept for backward
        # (the full-precision input and weight are not).
        # NOTE(review): stored as a plain ctx attribute rather than through
        # ctx.save_for_backward, so saved-tensor hooks (e.g. activation
        # offloading) and autograd's in-place modification checks do not cover
        # these tensors -- confirm this is intentional.
        ctx.saved = Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name
        # Part 3: FP8 matmul; bias is handed to the kernel.
        # NOTE(review): the positional `False, 0` arguments are kernel options
        # defined in .linear.fp8_linear_forward -- confirm their meaning.
        fc_output = fp8_linear_forward(Qinput, Iscale, Qweight, Wscale, False, 0, bias)

        if time_bench:
            end_3 = torch.cuda.Event(enable_timing=True)
            end_3.record()
            start_4 = torch.cuda.Event(enable_timing=True)
            start_4.record()

            # BF16 reference on the unquantized operands, timed for comparison
            # only; `output` is never returned.
            output = F.linear(input, weight, bias)

            end_4 = torch.cuda.Event(enable_timing=True)
            end_4.record()

            # Wait for every recorded event before reading elapsed times.
            torch.cuda.synchronize()
            if quant_get_local_rank() == 0:
                # "FP8" spans Parts 1-3 (both quantizations plus the matmul).
                print(
                    f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
                    f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {input.shape} | Weight shape: {weight.shape}"
                )

        return fc_output

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias, args, layer_name)``.

        ``grad_output`` is quantized with ``args.bobit`` (``stochastic=False``),
        then ``fp8_linear_backward`` produces ``grad_input`` and ``grad_weight``
        from the FP8 tensors saved in forward. ``grad_bias`` is ``None`` when the
        layer has no bias; ``args`` and ``layer_name`` receive ``None``.
        """
        Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved

        # Same switch as in forward: any non-empty value (even "0") enables timing.
        time_bench = os.getenv("TIME_BENCH")
        if time_bench:
            start_1 = torch.cuda.Event(enable_timing=True)
            start_1.record()

        # Part 1: quantize the incoming gradient with args.bobit; both the plain
        # and the transposed FP8 copies are passed to the backward kernel.
        # Alternative (disabled) kernel:
        # Qgrad_output, Gscale, Qgrad_output_t = fp8_division_transpose(grad_output, 16, args.bobit, stochastic=False)
        Qgrad_output, Gscale, Qgrad_output_t = fp8_quantize_pertensor_transpose(
            grad_output, 16, args.bobit, stochastic=False, transpose_output_2d=True
        )

        if time_bench:
            end_1 = torch.cuda.Event(enable_timing=True)
            end_1.record()
            start_2 = torch.cuda.Event(enable_timing=True)
            start_2.record()

        # Part 2: input and weight gradients from the FP8 operands.
        # NOTE(review): grad_input is returned exactly as produced here (no
        # reshape to the input's shape) -- presumably the kernel already
        # returns that shape; confirm.
        grad_input, grad_weight = fp8_linear_backward(
            Qinput_t,
            Iscale,
            Qgrad_output,
            Gscale,
            Qgrad_output_t,
            Qweight_t,
            Wscale,
            16,
            bias,
            stochastic=False,
            dgrad_quantize=False,
        )

        if time_bench:
            end_2 = torch.cuda.Event(enable_timing=True)
            end_2.record()
            start_3 = torch.cuda.Event(enable_timing=True)
            start_3.record()

        # Part 3: bias gradient = sum of the unquantized grad_output over all
        # leading dimensions.
        if bias is not None:
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
        else:
            grad_bias = None

        if time_bench:
            end_3 = torch.cuda.Event(enable_timing=True)
            end_3.record()

            # ========== BF16 ==========
            # Timing reference only: operands are rebuilt from the saved FP8
            # transposed copies by a plain dtype cast (Iscale/Wscale are not
            # applied), and `_grad_weight` / `_grad_input` are discarded.
            # `weight` in the printed message below is this rebuilt tensor.
            C_in = Qinput_t.shape[0]
            C_out = grad_output.shape[-1]
            grad_output_flatten = grad_output.reshape(-1, C_out)
            input_flatten = Qinput_t.t().reshape(-1, C_in).to(torch.bfloat16)
            weight = Qweight_t.t().to(torch.bfloat16)

            start_4 = torch.cuda.Event(enable_timing=True)
            start_4.record()

            if grad_output_flatten.dtype == input_flatten.dtype:
                _grad_weight = grad_output_flatten.t().mm(input_flatten)
            else:
                _grad_weight = grad_output_flatten.float().t().mm(input_flatten)

            if grad_output_flatten.dtype == weight.dtype:
                _grad_input = grad_output_flatten.mm(weight)
            else:
                _grad_input = grad_output_flatten.float().mm(weight)

            end_4 = torch.cuda.Event(enable_timing=True)
            end_4.record()

            # Wait for every recorded event before reading elapsed times.
            torch.cuda.synchronize()
            if quant_get_local_rank() == 0:
                print(
                    f"[Backward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
                    f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {Qinput_t.shape} | Weight shape: {weight.shape}"
                )

        # One gradient per forward argument: (input, weight, bias, args, layer_name).
        return grad_input, grad_weight, grad_bias, None, None