# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn

from copy import deepcopy

try:
    # Package-relative imports (normal case, when imported as part of the repo).
    from .QAct import QAct_FPin, QAct_FPout
    from .Qconfig import qconfig
    from .QFunction import *
    from .utils import *
except ImportError:
    # Fallback for running this file directly as a script, where the
    # relative imports above fail.
    from QAct import QAct_FPin, QAct_FPout
    from Qconfig import qconfig
    from QFunction import *
    from utils import *


class QLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, args=None, layer_type=""):
        super().__init__()
        self.args = deepcopy(args)
        self.layer_type = layer_type
        assert layer_type != "", "layer_type is not defined"
        assert layer_type in qconfig.qlayernorm_config, f"{layer_type} not in qlayernorm_config"

        # Quantize this layer only if its type is selected in args.qchoice.
        self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qlayernorm_config[layer_type])

        # Forward (activation) and backward (gradient) bit-widths; fall back
        # to the default bit-width when not set explicitly.
        self.fbit = self.args.fabit if self.args.fabit else self.Ubit
        self.bbit = self.args.babit if self.args.babit else self.Ubit

        # Log the quantization configuration for this layer.
        quantize_flag = format_string_with_condition(
            layer_type,
            {"apply": self.apply_quantize},
            self.args.symm,
            self.fbit,
            self.bbit,
            {"row": self.args.row_blocksize, "col": self.args.col_blocksize},
        )
        print(quantize_flag)

        # Dequantize the incoming activation, apply LayerNorm in full
        # precision, then re-quantize the output.
        self.ln_in = QAct_FPout(args, layer_type=layer_type + "_in")
        self.layer_norm = nn.LayerNorm(normalized_shape, eps=eps)
        self.ln_out = QAct_FPin(args, layer_type=layer_type + "_out")
    def forward(self, Qinput, Iscale):
        # Input shape is (batch size, sequence length, hidden size).
        input = self.ln_in(Qinput, Iscale)      # dequantize to full precision
        output_fp = self.layer_norm(input)      # standard LayerNorm
        output, scale = self.ln_out(output_fp)  # re-quantize the result
        return output, scale

if __name__ == "__main__":
    # Standalone debug entry point: load a dumped tensor that produced NaNs.
    Sum = torch.load("tensor/QAct_nan_epoch16.pt")
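    # Minimal usage sketch. The args fields below are assumptions inferred
    # from the attributes this module reads (qchoice, fabit, babit, symm,
    # row_blocksize, col_blocksize), and "layernorm" stands in for a key
    # that must exist in qconfig.qlayernorm_config.
    from types import SimpleNamespace

    args = SimpleNamespace(
        qchoice=["layernorm"],  # which layers to quantize (assumed format)
        fabit=8,                # forward activation bit-width (assumed)
        babit=8,                # backward gradient bit-width (assumed)
        symm=True,              # symmetric quantization (assumed)
        row_blocksize=32,       # block-quantization granularity (assumed)
        col_blocksize=32,
    )
    ln = QLayerNorm(1024, args=args, layer_type="layernorm")
    Qinput = torch.randn(2, 16, 1024)  # (batch, sequence, hidden)
    Iscale = torch.ones(1)             # placeholder input scale
    Qoutput, Oscale = ln(Qinput, Iscale)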