# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

import torch.nn as nn

from ..real_quantization import fp8_division_transpose


class FP8CacheWeightModule(nn.Module):
    """Base module that quantizes a weight to FP8 on the first microbatch and caches
    the result (or just its scale) for reuse on subsequent microbatches."""

    def __init__(self, config, qargs, layer_id):
        super().__init__()
        self.config = config
        self.qargs = qargs
        self.layer_id = layer_id

    def prepare_weight(self, weight, weight_name, is_first_microbatch):
        if is_first_microbatch:
            if self.qargs.weight_memory_efficient:
                # print(f"{weight_name} uses first microbatch")
                # Memory-efficient mode: quantize now but cache only the scale;
                # the FP8 weight and its transpose are not stored on the module.
                # Note: `self.fwobits` is expected to be provided by the subclass.
                weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
                    weight, self.qargs.group_size, self.fwobits["fwbit"]
                )
                setattr(self, f"{weight_name}_fp8_scale", weight_s)
                return weight_fp8, weight_fp8_t, weight_s
            else:
                # print(f"{weight_name} uses first microbatch")
                # Cache the FP8 weight, its transpose, and the scale for reuse.
                weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
                    weight, self.qargs.group_size, self.fwobits["fwbit"]
                )
                setattr(self, f"{weight_name}_fp8", weight_fp8)
                setattr(self, f"{weight_name}_fp8_t", weight_fp8_t)
                setattr(self, f"{weight_name}_fp8_scale", weight_s)
                return weight_fp8, weight_fp8_t, weight_s
        else:
            if self.qargs.weight_memory_efficient:
                # Only the cached scale is available in memory-efficient mode.
                return getattr(self, f"{weight_name}_fp8_scale")
            else:
                # Return the cached FP8 weight, its transpose, and the scale.
                return (
                    getattr(self, f"{weight_name}_fp8"),
                    getattr(self, f"{weight_name}_fp8_t"),
                    getattr(self, f"{weight_name}_fp8_scale"),
                )

    def forward(self, x):
        # Subclasses implement the actual forward pass.
        pass
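

# A minimal usage sketch (hypothetical; actual subclasses in this repo may differ).
# It assumes the subclass defines `self.fwobits` and a `weight` parameter, and that
# `qargs` carries `group_size` and `weight_memory_efficient`:
#
#     # first microbatch of a step: quantize and populate the cache
#     weight_fp8, weight_fp8_t, weight_s = self.prepare_weight(
#         self.weight, "weight", is_first_microbatch=True
#     )
#
#     # later microbatches: reuse the cache; note the return shape differs by mode
#     if self.qargs.weight_memory_efficient:
#         weight_s = self.prepare_weight(self.weight, "weight", is_first_microbatch=False)
#     else:
#         weight_fp8, weight_fp8_t, weight_s = self.prepare_weight(
#             self.weight, "weight", is_first_microbatch=False
#         )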