File size: 1,997 Bytes
174ae06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

from dataclasses import dataclass

from transformers import PretrainedConfig


@dataclass
class QuantizationConfig:
    """Settings controlling FP8 quantization of a model.

    Fields:
        quantize_model: "true"/"false" string flag enabling quantization.
        symm: whether quantization is symmetric.
        epsilon: small constant guarding against division by zero in scaling.
        fabit: forward-activation FP8 format (e.g. "E4M3").
        fwbit: forward-weight FP8 format (e.g. "E4M3").
        bobit: backward/output FP8 format (e.g. "E5M2").
        row_blocksize: per-block quantization row size; -1 means whole-tensor.
        col_blocksize: per-block quantization column size; -1 means whole-tensor.
        qchoice: which parts of the model to quantize; "none" disables.
        pad_to_multiple_of: pad dimensions to this multiple; 0 disables padding.
    """

    quantize_model: str = "false"
    symm: bool = True
    epsilon: float = 1e-10
    fabit: str = "E4M3"
    fwbit: str = "E4M3"
    bobit: str = "E5M2"
    row_blocksize: int = -1
    col_blocksize: int = -1
    qchoice: str = "none"
    pad_to_multiple_of: int = 0

    # NOTE: defining __init__ by hand suppresses the dataclass-generated one,
    # so defaults must be repeated here. Without them (as in the original),
    # every argument was required and the field defaults above were dead.
    # Defaults mirror the commented-out reference implementation below.
    def __init__(
        self,
        quantize_model="false",
        symm=True,
        epsilon=1e-10,
        fabit="E4M3",
        fwbit="E4M3",
        bobit="E5M2",
        row_blocksize=-1,
        col_blocksize=-1,
        qchoice="none",
        pad_to_multiple_of=0,
        **kwargs,  # absorbs unknown keys from serialized configs
    ):
        super().__init__()
        self.quantize_model = quantize_model
        self.symm = symm
        self.epsilon = epsilon
        self.fabit = fabit
        self.fwbit = fwbit
        self.bobit = bobit
        self.row_blocksize = row_blocksize
        self.col_blocksize = col_blocksize
        self.qchoice = qchoice
        self.pad_to_multiple_of = pad_to_multiple_of


# class QuantizationConfig(PretrainedConfig):
#     def __init__(
#         self,
#         quantize_model="false",
#         symm=True,
#         epsilon=1e-10,
#         fabit="E4M3",
#         fwbit="E4M3",
#         bobit="E5M2",
#         row_blocksize=-1,
#         col_blocksize=-1,
#         qchoice="none",
#         pad_to_multiple_of=0,
#         **kwargs,
#     ):
#         super().__init__()
#         self.quantize_model = quantize_model
#         self.symm = symm
#         self.epsilon = epsilon
#         self.fabit = fabit
#         self.fwbit = fwbit
#         self.bobit = bobit
#         self.row_blocksize = row_blocksize
#         self.col_blocksize = col_blocksize
#         self.qchoice = qchoice
#         self.pad_to_multiple_of = pad_to_multiple_of