Spaces:
Running
on
A100
Running
on
A100
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import math | |
import torch | |
def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
    """Simulate rounding ``x`` onto a float grid with ``e_bit`` exponent and ``m_bit`` mantissa bits.

    Each magnitude is split into a power-of-two exponent (clamped to
    ``[-(2**(e_bit-1)) + 2, 2**(e_bit-1)]``) and a mantissa. The fractional
    part of the mantissa is rounded to a multiple of ``1 / 2**m_bit``.

    Args:
        x: Input tensor.
        e_bit: Number of exponent bits; determines the exponent clamp range.
        m_bit: Number of mantissa bits; determines the fractional resolution.
        stochastic: If true, uniform noise in ``[-0.5, 0.5)`` (in units of one
            mantissa step) is added before rounding.

    Returns:
        A tensor of quantized values converted back to the dtype/device of ``x``.
    """
    min_exp = -(2 ** (e_bit - 1)) + 2
    max_exp = 2 ** (e_bit - 1)
    steps = 2**m_bit

    signs = x.sign()
    magnitude = x.abs()

    # log2(0) is -inf; the clamp pins it to min_exp, so zeros yield a zero mantissa.
    exponent = torch.clamp(torch.floor(torch.log2(magnitude)), min=min_exp, max=max_exp)

    scaled = magnitude / torch.exp2(exponent)
    whole = torch.floor(scaled)
    frac_steps = (scaled - whole) * steps  # fractional part in units of one mantissa step

    if stochastic:
        jitter = frac_steps.new_empty(frac_steps.shape).uniform_(-0.5, 0.5)
        frac_steps.add_(jitter)
    frac_steps = torch.round(frac_steps)

    quantized = signs * (2**exponent) * (whole + frac_steps / steps)
    return quantized.to(x)
def floatExM0_quantize_torch(x, e_bit, stochastic):
    """Quantize ``x`` to signed powers of two (a float format with no mantissa bits).

    The base-2 logarithm of each magnitude is rounded to the nearest integer
    exponent after being clamped to ``[Elow - 1, Ehigh]``, where
    ``Elow = -(2**(e_bit-1)) + 1`` and ``Ehigh = 2**(e_bit-1)``. Values whose
    rounded exponent is not greater than ``Elow`` underflow and are set to 0.

    Args:
        x: Input tensor.
        e_bit: Number of exponent bits; determines ``Elow`` and ``Ehigh``.
        stochastic: If true, uniform noise in ``[-0.5, 0.5)`` plus the constant
            offset ``log2(4/3) - 1/2`` is added to the log-domain exponent
            before rounding.

    Returns:
        A tensor of quantized values converted back to the dtype/device of ``x``.
    """
    sign, x_abs = x.sign(), x.abs()
    Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1)

    expo = torch.log2(x_abs)
    if stochastic:
        noise = expo.new_empty(expo.shape).uniform_(-0.5, 0.5)
        # NOTE(review): this offset presumably compensates the upward bias of
        # stochastic rounding in the log domain -- confirm against the reference.
        log_bias = math.log2(4 / 3) - 1 / 2
        # Tensor.add is out-of-place: the result must be assigned back, otherwise
        # the noise and bias are silently dropped and rounding stays deterministic.
        expo = expo + noise + log_bias

    expo = torch.clamp(expo, min=Elow - 1, max=Ehigh)
    expo = torch.round(expo)

    # When underflow (rounded exponent <= Elow), set the value to 0.
    y = sign * (2**expo) * (expo > Elow)
    return y.to(x)
def Dynamic_quantize_torch(x, bit, stochastic):
    """Quantize ``x`` with a base-10 exponent and a variable-width mantissa, flushing tiny values to 0.

    The decimal exponent is ``ceil(log10(|x|))`` clamped from below at
    ``2 - bit``. Within its decade the magnitude is mapped linearly to a
    0-1 mantissa, which is rounded onto ``2 ** (bit - 2 - |exponent|)`` steps,
    so larger ``|exponent|`` leaves fewer mantissa steps. Results whose
    magnitude does not exceed ``1.01 * 10 ** (1 - bit)`` are set to 0.

    NOTE(review): the use of ``|exponent|`` suggests inputs are expected to
    satisfy ``|x| <= 1`` -- confirm with callers.

    Args:
        x: Input tensor.
        bit: Total bit width of the format.
        stochastic: Must be falsy; stochastic rounding is not supported.

    Returns:
        A tensor of quantized values converted back to the dtype/device of ``x``.

    Raises:
        NotImplementedError: If ``stochastic`` is truthy.
    """
    if stochastic:
        raise NotImplementedError("Dynamic Tree quantization does not support stochastic")

    signs = x.sign()
    magnitude = x.abs()

    exponent = torch.clamp(torch.ceil(torch.log10(magnitude)), min=2 - bit)
    levels = 2 ** (bit - 2 - exponent.abs())  # number of mantissa steps for each element

    # Map the position inside the decade onto a 0-1 mantissa.
    unit_mant = (10 * magnitude / torch.pow(10, exponent) - 1) / 9
    rounded = torch.round(unit_mant * levels)
    decade_mant = rounded / levels * 9 + 1

    quantized = signs * (10**exponent) * decade_mant / 10

    # Keep only results above the threshold; everything else becomes 0.
    keep = quantized.abs() > 1.01 * 10 ** (1 - bit)
    quantized = quantized * keep
    return quantized.to(x)
def ZeroDynamic_quantize_torch(x, bit, stochastic):
    """Quantize ``x`` with a base-10 exponent and a variable-width mantissa, without flushing small values.

    The decimal exponent is ``ceil(log10(|x|))`` clamped from below at
    ``2 - bit``. Within its decade the magnitude is mapped linearly to a
    0-1 mantissa, which is rounded onto ``2 ** (bit - 2 - |exponent|)`` steps.
    No small-magnitude mask is applied to the result: only inputs whose sign
    is 0 produce 0.

    NOTE(review): the use of ``|exponent|`` suggests inputs are expected to
    satisfy ``|x| <= 1`` -- confirm with callers.

    Args:
        x: Input tensor.
        bit: Total bit width of the format.
        stochastic: Must be falsy; stochastic rounding is not supported.

    Returns:
        A tensor of quantized values converted back to the dtype/device of ``x``.

    Raises:
        NotImplementedError: If ``stochastic`` is truthy.
    """
    if stochastic:
        raise NotImplementedError("Dynamic Tree quantization does not support stochastic")

    signs = x.sign()
    magnitude = x.abs()

    exponent = torch.clamp(torch.ceil(torch.log10(magnitude)), min=2 - bit)
    levels = 2 ** (bit - 2 - exponent.abs())  # number of mantissa steps for each element

    # Map the position inside the decade onto a 0-1 mantissa.
    unit_mant = (10 * magnitude / torch.pow(10, exponent) - 1) / 9
    rounded = torch.round(unit_mant * levels)
    decade_mant = rounded / levels * 9 + 1

    quantized = signs * (10**exponent) * decade_mant / 10
    return quantized.to(x)