# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd.function import Function, InplaceFunction

try:
    from .Qconfig import qconfig
    from .QFunction import *
    from .utils import *
except ImportError:  # fall back to absolute imports when run as a standalone script
    from Qconfig import qconfig
    from QFunction import *
    from utils import *

class QAct_FPout(nn.Identity):
    def __init__(self, args, normalize_before=False, layer_type=""):
        super().__init__()
        self.args = deepcopy(args)
        self.normalize_before = normalize_before
        self.layer_type = layer_type
        # layer_name is normally refined later by the model builder; default to
        # the layer type so forward() never sees an undefined attribute.
        self.layer_name = layer_type
        assert layer_type != "", "layer_type is not defined"
        assert layer_type in qconfig.qact_config, f"{layer_type} not in qact_config"

        self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qact_config[layer_type])
        self.apply_quantize_f, self.apply_quantize_b = self.apply_quantize, self.apply_quantize
        self.refine_rowcol_blocksize()

        # fall back to Ubit (expected to be provided elsewhere) when a bit width is unset
        self.fbit = self.args.fabit if self.args.fabit else self.Ubit
        self.bbit = self.args.babit if self.args.babit else self.Ubit
        quantize_flag = format_string_with_condition(
            layer_type,
            {"apply-f": self.apply_quantize_f, "apply-b": self.apply_quantize_b},
            self.args.symm,
            self.fbit,
            self.bbit,
            {
                "row-f": self.args.row_blocksize_f,
                "col-f": self.args.col_blocksize_f,
                "row-b": self.args.row_blocksize_b,
                "col-b": self.args.col_blocksize_b,
            },
        )
        if quant_get_local_rank() == 0:
            print(quantize_flag)

    def refine_rowcol_blocksize(self):
        # start from the global default block sizes
        self.args.row_blocksize_f, self.args.col_blocksize_f = self.args.row_blocksize, self.args.col_blocksize
        self.args.row_blocksize_b, self.args.col_blocksize_b = self.args.row_blocksize, self.args.col_blocksize

        tile_block = (self.args.refine_row_blocksize, self.args.refine_col_blocksize)
        pertoken_block = (1, self.args.refine_row_blocksize * self.args.refine_col_blocksize)

        if self.args.refine_residual_fp:
            # keep the residual branch in full precision
            if self.layer_type in ["add_attn_in_re", "add_mlp_in_re"]:
                self.apply_quantize_f, self.apply_quantize_b = False, False

        if self.args.refine_ln_blocksize:
            if self.layer_type in ["ln_attn_in", "ln_mlp_in"]:
                block = pertoken_block if self.args.refine_ln_pertoken else tile_block
                self.args.row_blocksize_f, self.args.col_blocksize_f = block
                self.args.row_blocksize_b, self.args.col_blocksize_b = block
                # forward-only and backward-only refinement cannot be enabled together
                assert not (
                    self.args.refine_ln_blocksize_but_only_forward
                    and self.args.refine_ln_blocksize_but_only_backward
                )
                if self.args.refine_ln_blocksize_but_only_forward:
                    self.apply_quantize_f, self.apply_quantize_b = True, False
                if self.args.refine_ln_blocksize_but_only_backward:
                    self.apply_quantize_f, self.apply_quantize_b = False, True

        if self.args.refine_attn_blocksize:
            if self.layer_type in ["ln_attn_in"]:
                block = pertoken_block if self.args.refine_ln_pertoken else tile_block
                self.args.row_blocksize_f, self.args.col_blocksize_f = block
                self.args.row_blocksize_b, self.args.col_blocksize_b = block
            if self.layer_type in ["attn_qkv_sum"]:
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block
            if self.layer_type in ["add_attn_in_fx"]:
                self.args.row_blocksize_f, self.args.col_blocksize_f = tile_block

        if self.args.refine_mlp_blocksize:
            if self.layer_type in ["ln_mlp_in"]:
                block = pertoken_block if self.args.refine_ln_pertoken else tile_block
                self.args.row_blocksize_f, self.args.col_blocksize_f = block
                self.args.row_blocksize_b, self.args.col_blocksize_b = block
            if self.layer_type in ["mlp_act_sum"]:
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block
            if self.layer_type in ["mlp_act_in"]:
                self.args.row_blocksize_f, self.args.col_blocksize_f = tile_block
            if self.layer_type in ["mul_act_in1"]:
                self.args.row_blocksize_f, self.args.col_blocksize_f = tile_block
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block
            if self.layer_type in ["mul_act_in2"]:
                self.args.row_blocksize_f, self.args.col_blocksize_f = tile_block
            if self.layer_type in ["add_mlp_in_fx"]:
                self.args.row_blocksize_f, self.args.col_blocksize_f = tile_block

    def forward(self, Qinput, Iscale):
        # input shape is (Batch Size, Sequence Length, Hidden Size)
        if self.training:
            return QuantAct_FPout.apply(
                Qinput, Iscale, self.args, self.layer_name, self.apply_quantize_f, self.apply_quantize_b
            )
        else:
            return Qinput
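
# A hedged illustration of the per-token refinement above (the numbers are
# made-up assumptions, not defaults from this repo): with refine_ln_pertoken
# set, each quantization block spans a single row (token) and
# refine_row_blocksize * refine_col_blocksize columns, so LayerNorm inputs get
# one scale per token group instead of one per 2-D tile.
def _example_pertoken_blocksize(refine_row_blocksize=8, refine_col_blocksize=16, hidden_size=4096):
    row_bs = 1  # one token per block
    col_bs = refine_row_blocksize * refine_col_blocksize  # 8 * 16 = 128 columns per block
    scales_per_token = hidden_size // col_bs  # 4096 // 128 = 32 scales per token
    return row_bs, col_bs, scales_per_token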

class QuantAct_FPout(Function):
    @staticmethod
    def forward(ctx, Qinput, Iscale, args, layer_name, apply_quantize_f=True, apply_quantize_b=True):
        ctx.saved = args, layer_name, apply_quantize_f, apply_quantize_b

        # slice off the zero padding on Iscale to recover the scales actually produced
        ideal_scale_num = Qinput.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(Qinput, args.row_blocksize_f, args.col_blocksize_f)
        # actual_scale_num = Qinput.numel() / (args.row_blocksize_f * args.col_blocksize_f)
        assert Iscale.shape[0] == ideal_scale_num
        Iscale = Iscale[: int(actual_scale_num), :, :]

        # dequantize: per-block scales times low-bit values
        Binput = block_cut(Qinput, args.row_blocksize_f, args.col_blocksize_f)
        input = Binput * Iscale
        input = block_reshape(input, Qinput, args.row_blocksize_f, args.col_blocksize_f)

        if args.draw_distribution_forward:
            save_tensor(input, None, None, fb="forward", aw="Activation", layer_name=layer_name)

        return input

    @staticmethod
    def backward(ctx, grad_output):
        args, layer_name, apply_quantize_f, apply_quantize_b = ctx.saved

        Bgrad_output = block_cut(grad_output, args.row_blocksize_b, args.col_blocksize_b)
        RQgrad_output, Qgrad_output, Gscale = block_quant(
            Bgrad_output,
            args.symm,
            args.babit,
            stochastic=True,
            epsilon=args.epsilon,
            apply_quantize=apply_quantize_b,
            layer_name=layer_name,
        )
        Qgrad_output = block_reshape(Qgrad_output, grad_output, args.row_blocksize_b, args.col_blocksize_b)

        if args.draw_distribution_backward:
            save_tensor(grad_output, RQgrad_output, Qgrad_output, fb="backward", aw="Activation", layer_name=layer_name)

        # pad Gscale so the gradient w.r.t. Iscale matches its (padded) forward size
        ideal_scale_num = grad_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(grad_output, args.row_blocksize_b, args.col_blocksize_b)
        # actual_scale_num = grad_output.numel() / (args.row_blocksize_b * args.col_blocksize_b)
        assert Gscale.shape[0] == actual_scale_num
        Gscale = torch.nn.functional.pad(Gscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))

        return Qgrad_output, Gscale, None, None, None, None
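
# A minimal sketch of the scale-tensor bookkeeping above, with assumed toy
# shapes (it is not part of the module API). autograd requires backward() to
# return a gradient of the same shape as each forward() input, so the
# per-block scale tensor is always padded with zeros up to the "ideal" count
# implied by the minimal block unit, then sliced back down before use.
def _example_scale_bookkeeping():
    x = torch.randn(2, 16, 64)  # assumed (batch, seq, hidden) toy activation
    min_row, min_col = 1, 16    # stand-ins for args.min_blockunit_row / _col
    row_bs, col_bs = 1, 64      # stand-ins for args.row_blocksize_f / col_blocksize_f
    ideal = x.numel() // (min_row * min_col)  # 128 scales at the finest granularity
    actual = x.numel() // (row_bs * col_bs)   # 32 scales actually produced
    scales = torch.ones(actual, 1, 1)
    padded = torch.nn.functional.pad(scales, (0, 0, 0, 0, 0, ideal - actual))
    assert padded.shape[0] == ideal
    return padded[:actual]  # slicing recovers the real scales unchanged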

class QAct_FPin(nn.Identity):
    def __init__(self, args, normalize_before=False, layer_type=""):
        super().__init__()
        self.args = deepcopy(args)
        self.normalize_before = normalize_before
        self.layer_type = layer_type
        # layer_name is normally refined later by the model builder; default to
        # the layer type so forward() never sees an undefined attribute.
        self.layer_name = layer_type
        assert layer_type != "", "layer_type is not defined"
        assert layer_type in qconfig.qact_config, f"{layer_type} not in qact_config"

        self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qact_config[layer_type])
        self.apply_quantize_f, self.apply_quantize_b = self.apply_quantize, self.apply_quantize
        self.refine_rowcol_blocksize()

        # fall back to Ubit (expected to be provided elsewhere) when a bit width is unset
        self.fbit = self.args.fabit if self.args.fabit else self.Ubit
        self.bbit = self.args.babit if self.args.babit else self.Ubit
        quantize_flag = format_string_with_condition(
            layer_type,
            {"apply-f": self.apply_quantize_f, "apply-b": self.apply_quantize_b},
            self.args.symm,
            self.fbit,
            self.bbit,
            {
                "row-f": self.args.row_blocksize_f,
                "col-f": self.args.col_blocksize_f,
                "row-b": self.args.row_blocksize_b,
                "col-b": self.args.col_blocksize_b,
            },
        )
        if quant_get_local_rank() == 0:
            print(quantize_flag)

    def refine_rowcol_blocksize(self):
        # start from the global default block sizes
        self.args.row_blocksize_f, self.args.col_blocksize_f = self.args.row_blocksize, self.args.col_blocksize
        self.args.row_blocksize_b, self.args.col_blocksize_b = self.args.row_blocksize, self.args.col_blocksize

        tile_block = (self.args.refine_row_blocksize, self.args.refine_col_blocksize)
        pertoken_block = (1, self.args.refine_row_blocksize * self.args.refine_col_blocksize)

        if self.args.refine_residual_fp:
            # keep the residual branch in full precision
            if self.layer_type in ["re_attn_out_re", "re_mlp_out_re"]:
                self.apply_quantize_f, self.apply_quantize_b = False, False

        if self.args.refine_ln_blocksize:
            if self.layer_type in ["re_attn_out_fx", "re_mlp_out_fx"]:
                block = pertoken_block if self.args.refine_ln_pertoken else tile_block
                self.args.row_blocksize_f, self.args.col_blocksize_f = block
                self.args.row_blocksize_b, self.args.col_blocksize_b = block
                # forward-only and backward-only refinement cannot be enabled together
                assert not (
                    self.args.refine_ln_blocksize_but_only_forward
                    and self.args.refine_ln_blocksize_but_only_backward
                )
                if self.args.refine_ln_blocksize_but_only_forward:
                    self.apply_quantize_f, self.apply_quantize_b = True, False
                if self.args.refine_ln_blocksize_but_only_backward:
                    self.apply_quantize_f, self.apply_quantize_b = False, True

        if self.args.refine_attn_blocksize:
            if self.layer_type in ["re_attn_out_fx"]:
                block = pertoken_block if self.args.refine_ln_pertoken else tile_block
                self.args.row_blocksize_f, self.args.col_blocksize_f = block
                self.args.row_blocksize_b, self.args.col_blocksize_b = block
            if self.layer_type in ["ln_attn_out"]:
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block
            if self.layer_type in ["attn_q_in", "attn_k_in", "attn_v_in"]:
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block

        if self.args.refine_mlp_blocksize:
            if self.layer_type in ["re_mlp_out_fx"]:
                block = pertoken_block if self.args.refine_ln_pertoken else tile_block
                self.args.row_blocksize_f, self.args.col_blocksize_f = block
                self.args.row_blocksize_b, self.args.col_blocksize_b = block
            if self.layer_type in ["ln_mlp_out"]:
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block
            if self.layer_type in ["mlp_act_gate", "mlp_act_up", "mul_act_out"]:
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block
            if self.layer_type in ["mlp_act_out"]:
                self.args.row_blocksize_f, self.args.col_blocksize_f = tile_block
                self.args.row_blocksize_b, self.args.col_blocksize_b = tile_block

    def forward(self, input):
        # input shape is (Batch Size, Sequence Length, Hidden Size)
        if self.training:
            return QuantAct_FPin.apply(input, self.args, self.layer_name, self.apply_quantize_f, self.apply_quantize_b)
        else:
            return input, None

class QuantAct_FPin(Function):
    @staticmethod
    def forward(ctx, input, args, layer_name, apply_quantize_f=True, apply_quantize_b=True):
        ctx.saved = args, layer_name, apply_quantize_f, apply_quantize_b

        Binput = block_cut(input, args.row_blocksize_f, args.col_blocksize_f)
        RQinput, Qinput, Iscale = block_quant(
            Binput,
            args.symm,
            args.fabit,
            stochastic=False,
            epsilon=args.epsilon,
            apply_quantize=apply_quantize_f,
            layer_name=layer_name,
        )
        Qinput = block_reshape(Qinput, input, args.row_blocksize_f, args.col_blocksize_f)
        RQinput = block_reshape(RQinput, input, args.row_blocksize_f, args.col_blocksize_f)

        if args.draw_distribution_forward:
            save_tensor(input, RQinput, Qinput, fb="forward", aw="Activation", layer_name=layer_name)

        # pad Iscale up to the blocksize-independent "ideal" count so its size
        # always matches what the paired FPout layer expects
        ideal_scale_num = input.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(input, args.row_blocksize_f, args.col_blocksize_f)
        # actual_scale_num = input.numel() / (args.row_blocksize_f * args.col_blocksize_f)
        assert Iscale.shape[0] == actual_scale_num
        Iscale = torch.nn.functional.pad(Iscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))

        return Qinput, Iscale

    @staticmethod
    def backward(ctx, Qgrad_output, Gscale):
        args, layer_name, apply_quantize_f, apply_quantize_b = ctx.saved

        # slice off the zero padding on Gscale to recover the scales actually used
        ideal_scale_num = Qgrad_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(Qgrad_output, args.row_blocksize_b, args.col_blocksize_b)
        # actual_scale_num = Qgrad_output.numel() / (args.row_blocksize_b * args.col_blocksize_b)
        assert Gscale.shape[0] == ideal_scale_num
        Gscale = Gscale[: int(actual_scale_num), :, :]

        # dequantize the incoming gradient: per-block scales times low-bit values
        Bgrad_output = block_cut(Qgrad_output, args.row_blocksize_b, args.col_blocksize_b)
        grad_output = Bgrad_output * Gscale
        grad_output = block_reshape(grad_output, Qgrad_output, args.row_blocksize_b, args.col_blocksize_b)

        if args.draw_distribution_backward:
            save_tensor(grad_output, None, None, fb="backward", aw="Activation", layer_name=layer_name)

        return grad_output, None, None, None, None
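
# A hedged usage sketch: QAct_FPin and QAct_FPout are meant to bracket a
# full-precision operator during training. FPin fake-quantizes the activation
# into (Qinput, Iscale); FPout multiplies them back together on the other
# side. `args` must carry the fields consumed above (qchoice, fabit/babit,
# symm, block sizes, refine_* flags, epsilon, ...) and is supplied by the real
# training scripts; this function only documents the calling convention.
def _example_fpin_fpout_pair(args, x):
    q_in = QAct_FPin(args, layer_type="ln_attn_in")
    q_out = QAct_FPout(args, layer_type="ln_attn_in")
    Qx, Iscale = q_in(x)      # low-bit activation plus per-block scales
    x_fp = q_out(Qx, Iscale)  # dequantized back to floating point
    return torch.nn.functional.gelu(x_fp)  # stand-in for the wrapped FP op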

if __name__ == "__main__":
    # Standalone debugging entry: inspect a dumped activation that produced NaNs.
    Sum = torch.load("tensor/QAct_nan_epoch16.pt")
    Qinput, Binput, input, args, layer_type, name = (
        Sum["Qinput"],
        Sum["Binput"],
        Sum["input"],
        Sum["args"],
        Sum["layer_type"],
        Sum["name"],
    )
    if_nan, if_inf = check_nan_inf(input, True, False)
    print(if_nan)
    Q = block_quant(Binput, True, 8, stochastic=False, epsilon=1e-8)