# (web-scrape residue removed: "Spaces: Running on A100" banner text)
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from ..real_quantization import fp8_division_transpose
class FP8CacheWeightModule(nn.Module):
    """Base module that quantizes a weight to FP8 once and caches it across microbatches.

    On the first microbatch the weight is quantized (and transposed) via
    ``fp8_division_transpose``; the results are stored as attributes on this
    module so later microbatches reuse them instead of re-quantizing.

    NOTE(review): ``self.fwobits`` is read in ``prepare_weight`` but never
    assigned in this class — presumably a subclass sets it; confirm before use.
    """

    def __init__(self, config, qargs, layer_id):
        super().__init__()
        self.config = config
        self.qargs = qargs
        self.layer_id = layer_id

    def prepare_weight(self, weight, weight_name, is_first_microbatch):
        """Return the FP8 weight data for this microbatch, caching on the first one.

        Args:
            weight: full-precision weight tensor to quantize.
            weight_name: attribute-name prefix used to cache results on ``self``.
            is_first_microbatch: if True, quantize now and populate the cache;
                otherwise return previously cached values.

        Returns:
            First microbatch: ``(weight_fp8, weight_fp8_t, weight_s)``.
            Later microbatches: the cached scale alone when
            ``qargs.weight_memory_efficient`` is set, else the cached
            ``(weight_fp8, weight_fp8_t, weight_s)`` triple.
        """
        if is_first_microbatch:
            # fp8_division_transpose returns (fp8 weight, scale, transposed fp8 weight).
            weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
                weight, self.qargs.group_size, self.fwobits["fwbit"]
            )
            if not self.qargs.weight_memory_efficient:
                # Cache the full quantized tensors for reuse on later
                # microbatches (costs memory, saves recomputation).
                setattr(self, f"{weight_name}_fp8", weight_fp8)
                setattr(self, f"{weight_name}_fp8_t", weight_fp8_t)
            # The scale is cached in both modes.
            setattr(self, f"{weight_name}_fp8_scale", weight_s)
            return weight_fp8, weight_fp8_t, weight_s

        if self.qargs.weight_memory_efficient:
            # Memory-efficient mode cached only the scale.
            return getattr(self, f"{weight_name}_fp8_scale")
        return (
            getattr(self, f"{weight_name}_fp8"),
            getattr(self, f"{weight_name}_fp8_t"),
            getattr(self, f"{weight_name}_fp8_scale"),
        )

    def forward(self, x):
        # Placeholder: subclasses are expected to implement the actual forward pass.
        pass