import torch
import torch.nn as nn
class TWNLinear(nn.Linear):
    """Drop-in nn.Linear whose weight is ternarized on the fly with TWN."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # autograd.Function.apply must be called on the class, not an instance.
        w = TwnQuantizer.apply(self.weight)
        return torch.nn.functional.linear(input, w.to(input.dtype), self.bias)
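
# Illustrative sketch (an assumption, not part of the original file): one way
# to swap the nn.Linear modules of an existing model for TWNLinear while
# reusing the pretrained parameters. The helper name is hypothetical.
def replace_linear_with_twn(module: nn.Module) -> None:
    """Recursively replace nn.Linear children with TWNLinear equivalents."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, TWNLinear):
            twn = TWNLinear(child.in_features, child.out_features,
                            bias=child.bias is not None)
            twn.weight = child.weight  # share the pretrained weight tensor
            if child.bias is not None:
                twn.bias = child.bias
            setattr(module, name, twn)
        else:
            replace_linear_with_twn(child)
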
class TwnQuantizer(torch.autograd.Function):
    """Ternary Weight Networks (TWN)
    Ref: https://arxiv.org/abs/1605.04711
    """

    @staticmethod
    def forward(ctx, input, max_scale=0.7, clip=None, group_size=-1, per_tensor=False):
        """
        :param input: tensor to be ternarized
        :param max_scale: threshold factor; thres = max_scale * mean(|w|)
        :param clip: optional factor for symmetric clipping before ternarization
        :param group_size: quantization group size; -1 means per output channel
        :param per_tensor: use a single threshold/scale for the whole tensor
        :return: quantized tensor
        """
        ctx.save_for_backward(input)
        org_w_shape = input.shape
        q_group_size = group_size
        if q_group_size > 0:
            assert org_w_shape[-1] % q_group_size == 0
            input = input.reshape(-1, q_group_size)
        else:
            input = input.reshape(-1, input.shape[-1])
        if per_tensor:
            assert q_group_size == -1, "Conflict between per-tensor and per-group quantization!"
        if clip is not None:
            if per_tensor:
                # Clip range from the mean absolute value of the whole tensor.
                m = input.norm(p=1).div(input.nelement())
                clip_alpha = m * clip
            else:
                # Per-row mean absolute value, broadcast to the input shape.
                m = input.norm(p=1, dim=1).div(input[0].nelement())
                m = m.expand(input.shape[1], -1).transpose(0, 1)
                clip_alpha = m * clip
            input = torch.where(input <= clip_alpha, input, clip_alpha)
            input = torch.where(input >= -clip_alpha, input, -clip_alpha)
        if per_tensor:
            # Per-tensor quantization: one threshold and one scale alpha.
            m = input.abs().mean()
            thres = max_scale * m
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            alpha = (mask * input).abs().sum() / mask.sum()
            result = alpha * pos - alpha * neg
        else:
            # Per-channel/per-group quantization: one threshold and one scale
            # per row; assumes every row has at least one weight above thres.
            n = input[0].nelement()
            m = input.data.norm(p=1, dim=1).div(n)
            thres = (max_scale * m).view(-1, 1).expand_as(input)
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            alpha = ((mask * input).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
            result = alpha * pos - alpha * neg
        result = result.reshape(org_w_shape)  # undo the per-group reshape
        return result
    @staticmethod
    def backward(ctx, grad_output):
        """
        Straight-through estimator: the gradient passes through unchanged.
        :param ctx: context holding the saved full-precision tensor
        :param grad_output: gradient wrt the quantized tensor
        :return: estimated gradient wrt the full-precision tensor, plus one
                 None per non-tensor argument of forward
        """
        (input,) = ctx.saved_tensors  # unclipped full-precision input
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
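
# Illustrative usage sketch (an assumption, not part of the original file):
# ternarize a small weight matrix and check that the straight-through
# estimator passes the gradient through unchanged.
if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(4, 8, requires_grad=True)
    q = TwnQuantizer.apply(w)  # per-channel defaults (group_size=-1)
    # Each row of q holds only the three values {-alpha, 0, +alpha}.
    print("levels per row:", [row.unique().tolist() for row in q])
    q.sum().backward()
    # STE: d(sum(q))/dw is all ones, passed through by backward().
    assert torch.all(w.grad == 1.0)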