# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from torch.autograd.function import Function, InplaceFunction

try:
    from .QAct import QAct_FPin, QAct_FPout
    from .Qconfig import qconfig
    from .QFunction import *
    from .utils import *
except ImportError:
    # Fall back to absolute imports when the file is run as a standalone script.
    from QAct import QAct_FPin, QAct_FPout
    from Qconfig import qconfig
    from QFunction import *
    from utils import *

import os
from copy import deepcopy
import matplotlib.pyplot as plt


class QAdd(nn.Module):
    """Elementwise add over two quantized inputs: each operand is dequantized
    via QAct_FPout and the sum is returned in full precision."""

    def __init__(self, args=None, layer_type=""):
        super().__init__()
        self.args = deepcopy(args)
        self.layer_type = layer_type
        assert layer_type != "", "layer_type is not defined"
        assert layer_type in qconfig.qadd_config, f"{layer_type} not in qadd_config"

        self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qadd_config[layer_type])

        # Forward/backward activation bit-widths (fabit/babit), with Ubit as the fallback.
        self.fbit = self.args.fabit if self.args.fabit else self.Ubit
        self.bbit = self.args.babit if self.args.babit else self.Ubit
        quantize_flag = format_string_with_condition(
            layer_type,
            {"apply": self.apply_quantize},
            self.args.symm,
            self.fbit,
            self.bbit,
            {"row": self.args.row_blocksize, "col": self.args.col_blocksize},
        )
        print(quantize_flag)

        # Per-operand input handlers that map (quantized tensor, scale) back to floating point.
        self.Add_in_re = QAct_FPout(args, layer_type=layer_type + "_in_re")
        self.Add_in_fx = QAct_FPout(args, layer_type=layer_type + "_in_fx")

    def forward(self, Qinput_re, Qinput_fx, Iscale_re, Iscale_fx):
        # Input shape is (Batch Size, Sequence Length, Hidden Size).
        input1 = self.Add_in_re(Qinput_re, Iscale_re)
        input2 = self.Add_in_fx(Qinput_fx, Iscale_fx)
        output_fp = input1 + input2
        return output_fp


if __name__ == "__main__":
    Sum = torch.load("tensor/QAct_nan_epoch16.pt")
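    # Illustrative usage sketch (not from the original file): the Namespace
    # fields below only mirror the attributes read in __init__ above, and the
    # layer_type value is a hypothetical key assumed to exist in
    # qconfig.qadd_config.
    #
    #     from argparse import Namespace
    #     args = Namespace(qchoice="all", fabit=8, babit=8, symm=True,
    #                      row_blocksize=1, col_blocksize=16)
    #     qadd = QAdd(args, layer_type="resadd")
    #     out = qadd(Qinput_re, Qinput_fx, Iscale_re, Iscale_fx)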