# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import torch.nn as nn

from ..real_quantization import fp8_division_transpose


class FP8CacheWeightModule(nn.Module):
    """Base module that quantizes weights to FP8 and caches the result across microbatches."""

    def __init__(self, config, qargs, layer_id):
        super().__init__()
        self.config = config
        self.qargs = qargs
        self.layer_id = layer_id

    def prepare_weight(self, weight, weight_name, is_first_microbatch):
        """Quantize `weight` to FP8 on the first microbatch and reuse the cached result afterwards.

        On the first microbatch the weight is quantized with fp8_division_transpose and the
        results are cached as attributes named after `weight_name`. In weight-memory-efficient
        mode only the scale is cached, and subsequent microbatches get back just that scale;
        otherwise the FP8 weight, its transpose, and the scale are all cached and returned.
        """
        if is_first_microbatch:
            if self.qargs.weight_memory_efficient:
                # Quantize now, but cache only the scale to save memory.
                # self.fwobits is expected to be provided by the subclass.
                weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
                    weight, self.qargs.group_size, self.fwobits["fwbit"]
                )
                setattr(self, f"{weight_name}_fp8_scale", weight_s)
                return weight_fp8, weight_fp8_t, weight_s
            else:
                # Cache the FP8 weight, its transpose, and the scale for later microbatches.
                weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
                    weight, self.qargs.group_size, self.fwobits["fwbit"]
                )
                setattr(self, f"{weight_name}_fp8", weight_fp8)
                setattr(self, f"{weight_name}_fp8_t", weight_fp8_t)
                setattr(self, f"{weight_name}_fp8_scale", weight_s)
                return weight_fp8, weight_fp8_t, weight_s
        else:
            if self.qargs.weight_memory_efficient:
                return getattr(self, f"{weight_name}_fp8_scale")
            else:
                return (
                    getattr(self, f"{weight_name}_fp8"),
                    getattr(self, f"{weight_name}_fp8_t"),
                    getattr(self, f"{weight_name}_fp8_scale"),
                )

    def forward(self, x):
        pass
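

# --- Usage sketch (illustrative, not part of the original file) --------------------
# A hedged example of how a subclass might use the caching above. `FP8Linear`, the
# "e4m3" format string, and the qargs fields shown here are assumptions for
# illustration only; the real subclasses in this repository may differ. The sketch
# also assumes `import torch` in addition to the imports above.
#
# class FP8Linear(FP8CacheWeightModule):
#     def __init__(self, config, qargs, layer_id, in_features, out_features):
#         super().__init__(config, qargs, layer_id)
#         self.fwobits = {"fwbit": "e4m3"}  # assumed: forward-pass weight FP8 format
#         self.weight = nn.Parameter(torch.empty(out_features, in_features))
#
#     def forward(self, x, is_first_microbatch=True):
#         if self.qargs.weight_memory_efficient:
#             if is_first_microbatch:
#                 # Returns (fp8 weight, fp8 weight^T, scale); only the scale is cached.
#                 weight_fp8, weight_fp8_t, weight_s = self.prepare_weight(
#                     self.weight, "weight", is_first_microbatch
#                 )
#             else:
#                 # Later microbatches get back only the cached scale and must
#                 # re-quantize the weight themselves.
#                 weight_s = self.prepare_weight(self.weight, "weight", is_first_microbatch)
#         else:
#             # All three tensors are cached on the first microbatch and reused afterwards.
#             weight_fp8, weight_fp8_t, weight_s = self.prepare_weight(
#                 self.weight, "weight", is_first_microbatch
#             )
#         # ... perform the FP8 matmul with the prepared weight tensors ...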