File size: 9,126 Bytes
174ae06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import re

import torch

from .FloatPointQuantizeTorch import *
from .FloatPointQuantizeTriton import *


def block_cut(input, row_block, column_block, pad_block=False):
    """Split a matrix into tiles of shape ``(row_block, column_block)``.

    An input with more than two dimensions is first flattened with
    ``reshape(-1, input.shape[2])``; a 2-D input is used unchanged.

    Args:
        input: Tensor with at least two dimensions.
        row_block: Tile height; ``-1`` selects the full (flattened) row count.
        column_block: Tile width; ``-1`` selects the full column count.
        pad_block: When true, zero-pad the bottom and right edges so both
            dimensions become multiples of the tile size.

    Returns:
        Tensor of shape ``(row_num * column_num, row_block, column_block)``
        whose tiles are ordered row-major over the tile grid.

    Raises:
        ValueError: If ``input`` has fewer than two dimensions.
        AssertionError: If a dimension is not an exact multiple of its tile
            size (only possible when ``pad_block`` is false).
    """
    original_shape = input.shape
    rank = len(original_shape)
    if rank < 2:
        raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
    if rank > 2:
        # Collapse everything into a 2-D matrix keyed on dimension 2.
        input = input.reshape(-1, input.shape[2])

    M, N = input.shape

    # -1 is the sentinel for "one tile spans the whole axis".
    row_block = M if row_block == -1 else row_block
    column_block = N if column_block == -1 else column_block

    if pad_block:
        row_left_over = M % row_block
        col_left_over = N % column_block
        pad_bottom = row_block - row_left_over if row_left_over else 0
        pad_right = column_block - col_left_over if col_left_over else 0

        # F.pad lists padding from the last dimension backwards:
        # (left, right, top, bottom) for a 2-D tensor.
        input = torch.nn.functional.pad(input, (0, pad_right, 0, pad_bottom), "constant", 0)
        M, N = input.shape

    row_num = M // row_block
    column_num = N // column_block

    assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}"
    assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}"

    # (row tile, row in tile, col tile, col in tile) -> (row tile, col tile, row in tile, col in tile)
    tiled = input.reshape(row_num, row_block, column_num, column_block)
    tiled = tiled.permute(0, 2, 1, 3)
    return tiled.reshape(row_num * column_num, row_block, column_block)


def block_reshape(input, origin_input, row_block, column_block, pad_block=False):
    """Reassemble tiles into a tensor shaped like ``origin_input``.

    This inverts a tiling of the form
    ``(row_num * column_num, row_block, column_block)`` with tiles ordered
    row-major over the tile grid. ``origin_input`` is used only for its shape:
    tensors with more than two dimensions are treated as the 2-D matrix
    ``reshape(-1, origin_input.shape[2])``.

    Args:
        input: Tiled tensor holding ``row_num * column_num`` blocks.
        origin_input: Tensor whose shape the result must have.
        row_block: Tile height; ``-1`` selects the full (flattened) row count.
        column_block: Tile width; ``-1`` selects the full column count.
        pad_block: Whether the tiles were produced from a zero-padded matrix.
            The padding is cropped off again here.

    Returns:
        Tensor with the same shape as ``origin_input``.

    Raises:
        ValueError: If ``origin_input`` has fewer than two dimensions.
    """
    if len(origin_input.shape) > 2:
        flatten_input = origin_input.reshape(-1, origin_input.shape[2])
    elif len(origin_input.shape) == 2:
        flatten_input = origin_input
    else:
        # Report the tensor that actually failed the rank check.
        raise ValueError(f"input shape {origin_input.shape} does not match for block reshape")

    M, N = flatten_input.shape[0], flatten_input.shape[1]

    if row_block == -1:
        row_block = M
    if column_block == -1:
        column_block = N

    # Size of the (possibly padded) 2-D matrix the tiles were cut from.
    # The padding must be derived from the *flattened* shape: padding the
    # N-D origin tensor and reading shape[0]/shape[1] gives the wrong grid
    # for inputs with more than two dimensions. Only the sizes are needed,
    # so no padded copy is allocated.
    padded_M, padded_N = M, N
    if pad_block:
        row_remainder, col_remainder = M % row_block, N % column_block
        if row_remainder:
            padded_M = M + (row_block - row_remainder)
        if col_remainder:
            padded_N = N + (column_block - col_remainder)

    row_num, column_num = padded_M // row_block, padded_N // column_block

    # (row tile, col tile, row in tile, col in tile) -> (row tile, row in tile, col tile, col in tile)
    input = (
        input.reshape(row_num, column_num, row_block, column_block)
        .permute(0, 2, 1, 3)
        .reshape(row_num * row_block, column_num * column_block)
    )

    # Drop any rows/columns that only exist because of padding.
    input = input[:M, :N]

    if len(origin_input.shape) > 2:
        input = input.reshape(origin_input.shape)

    return input


def block_verify_int8(input, row_block, column_block, layer_type, necessary=True):
    """Check that every tile of ``input`` holds at most 256 distinct values.

    The tensor is tiled with ``block_cut`` (no padding) and cast to float32
    before counting unique values per tile.

    Args:
        input: Tensor to inspect.
        row_block: Tile height passed to ``block_cut``.
        column_block: Tile width passed to ``block_cut``.
        layer_type: Label used in the error message.
        necessary: When true a violation raises; otherwise it returns False.

    Returns:
        True when every tile passes, False on the first failing tile if
        ``necessary`` is false.

    Raises:
        ValueError: On the first failing tile when ``necessary`` is true.
    """
    tiles = block_cut(input, row_block, column_block).to(torch.float32)

    for tile in tiles:
        distinct_count = len(torch.unique(tile))
        if distinct_count <= 256:
            continue
        if not necessary:
            return False
        raise ValueError(f"{layer_type} contains more than 256 unique values.")

    return True


def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name):
    """Apply ``SymmQuantizer`` to ``input`` and return whatever it returns.

    All arguments are forwarded positionally, in order, to
    ``SymmQuantizer.apply``.
    """
    quantizer_args = (input, symm, bits, stochastic, epsilon, apply_quantize, layer_name)
    return SymmQuantizer.apply(*quantizer_args)


def extract_bit(string):
    """Parse a data-format name into ``(kind, bit1, bit2)``.

    Patterns are matched against the start of ``string`` in this order:

    * ``INT<n>``    -> ``("integer", n, None)``
    * ``E<e>M<m>``  -> ``("integer", m + 1, None)`` when ``e == 1``,
      ``("floatExM0", e, 0)`` when ``m == 0``, otherwise
      ``("floatExMy", e, m)``
    * ``DE<n>``     -> ``("Dynamic", n, None)``
    * ``ZeroD<n>``  -> ``("ZeroDynamic", n, None)``

    Raises:
        ValueError: If none of the patterns match.
    """
    integer_match = re.match(r"INT(\d+)", string)  # e.g. INT8
    if integer_match is not None:
        return "integer", int(integer_match.group(1)), None

    float_match = re.match(r"E(\d+)M(\d+)", string)  # e.g. E4M3 / E5M2
    if float_match is not None:
        exponent_bits, mantissa_bits = (int(group) for group in float_match.groups())
        if exponent_bits == 1:
            return "integer", mantissa_bits + 1, None
        if mantissa_bits == 0:
            return "floatExM0", exponent_bits, 0
        return "floatExMy", exponent_bits, mantissa_bits

    # The remaining formats carry a single bit count and no second field.
    for pattern, kind in ((r"DE(\d+)", "Dynamic"), (r"ZeroD(\d+)", "ZeroDynamic")):
        found = re.match(pattern, string)
        if found is not None:
            return kind, int(found.group(1)), None

    raise ValueError(f"{string} data format is not supported")


class SymmQuantizer(torch.autograd.function.InplaceFunction):
    """Per-block absmax quantize/dequantize with a straight-through backward.

    ``forward`` scales each block by its absolute maximum, rounds it onto the
    grid of the requested format and returns both the dequantized and the
    quantized tensors together with the per-block scale. ``backward`` passes
    the gradient of the dequantized output through unchanged.
    """

    @staticmethod
    def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None):
        """Quantize ``input`` block-wise.

        Args:
            ctx: Autograd context (unused).
            input: Tensor reduced over dimensions 1 and 2 to get one absmax
                per index of dimension 0 — presumably the
                ``(blocks, rows, cols)`` layout produced by ``block_cut``;
                confirm against callers.
            symm: When false, one extra bit is added to the first bit count
                returned by ``extract_bit``.
            bits: Format name. ``"100"``, ``"FP32"``, ``"FP16"`` and
                ``"BF16"`` bypass quantization; anything else is parsed by
                ``extract_bit``.
            stochastic: Enables stochastic rounding (forwarded to the float
                quantizers; implemented here for the integer path).
            epsilon: Added to the per-block absmax.
            apply_quantize: When false, ``input`` is returned untouched.
            layer_name: Unused in this method.

        Returns:
            Tuple ``(RQinput, Qinput, scale_per_block)``: the dequantized
            tensor, the quantized tensor and the per-block scale. The bypass
            paths return the (optionally cast) input twice and a scale of
            ones.

        Raises:
            NotImplementedError: For formats that parse but have no rounding
                implementation here (``Dynamic`` / ``ZeroDynamic``).
            ValueError: Propagated from ``extract_bit`` for unknown formats.
        """
        with torch.no_grad():
            absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon

            if bits == "100" or not apply_quantize:
                return input, input, torch.ones_like(absmax_per_block)
            elif bits == "FP32":
                return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block)
            elif bits == "FP16":
                return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block)
            elif bits == "BF16":
                return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block)
            else:
                QuantType, bit1, bit2 = extract_bit(bits)
                if not symm:
                    bit1 = bit1 + 1  # pretend to be asymmetric

                # Qn / Qp: smallest and largest representable grid values.
                if QuantType == "integer":
                    Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1
                elif QuantType == "floatExMy":
                    Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * (
                        2 ** (2 ** (bit1 - 1))
                    )
                    # These two formats override the generic formula above.
                    if bit1 == 4 and bit2 == 3:  # E4M3
                        Qn, Qp = -448, 448
                    if bit1 == 5 and bit2 == 2:  # E5M2
                        Qn, Qp = -57344, 57344
                elif QuantType == "floatExM0":
                    Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1))
                elif QuantType == "Dynamic":
                    Qn, Qp = -1, 1
                elif QuantType == "ZeroDynamic":
                    Qn, Qp = -1, 1
                else:
                    raise NotImplementedError(f"{bits} is not supported by quantization")

                # Map [-absmax, absmax] onto [Qn, Qp]; keep the scale in the
                # input's dtype/device so the division does not promote.
                scale_per_block = (2 * absmax_per_block) / (Qp - Qn)
                scale_per_block = scale_per_block.to(input)

                Qinput = input / scale_per_block

                if QuantType == "integer":
                    if stochastic:
                        # Dither before rounding so rounding is stochastic.
                        noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5)
                        Qinput.add_(noise)
                    Qinput.clamp_(Qn, Qp).round_()
                elif QuantType == "floatExMy":
                    # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic)
                    Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic)
                elif QuantType == "floatExM0":
                    Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic)
                else:
                    # Dynamic / ZeroDynamic parse but have no rounding path.
                    raise NotImplementedError(f"{bits} is not supported by quantization")

                RQinput = Qinput * scale_per_block

                if input.dtype != Qinput.dtype:
                    # Previously this appended to debug.txt through an unclosed
                    # file handle and dropped into an IPython shell, which
                    # stalls non-interactive runs. Surface it as a warning.
                    import warnings

                    warnings.warn(
                        f"Input type is {input.dtype}, Qinput type is {Qinput.dtype}, scale_per_block type is {scale_per_block.dtype}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                return RQinput, Qinput, scale_per_block

    @staticmethod
    def backward(ctx, grad_output, *unused_grads):
        """Straight-through gradient for ``input``.

        ``forward`` returns three outputs, so autograd supplies three
        incoming gradients; only the one for the dequantized output is used.
        One gradient slot is returned per ``forward`` argument (seven), with
        ``None`` for every non-tensor argument.
        """
        return grad_output, None, None, None, None, None, None