File size: 7,818 Bytes
179036e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from torch import autograd
import math
class TopKBinarizer(autograd.Function):
    """
    Top-k Binarizer.
    Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
    is among the k% highest values of S.
    Implementation is inspired from:
        https://github.com/yaozhewei/MLPruning
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, threshold, sigmoid: bool):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input matrix from which the binarizer computes the binary mask. It may be
                non-contiguous (e.g. a transposed weight).
            threshold (`float` or single-element `torch.FloatTensor`)
                The fraction of weights to keep (the rest is pruned), between 0 and 1.
                When `sigmoid` is True this must be a single-element tensor; `torch.sigmoid(threshold)`
                is used as the fraction, and backward returns a gradient for it.
            sigmoid (`bool`)
                Whether to apply a sigmoid on the threshold
        Returns:
            mask (`torch.FloatTensor`)
                Contiguous binary tensor with the same shape and dtype as `inputs` acting as a mask
                (1 - the associated weight is retained, 0 - the associated weight is pruned).
        """
        if sigmoid:
            threshold = torch.sigmoid(threshold).item()
        ctx.sigmoid = sigmoid

        # Rank all elements in row-major (logical) order; flatten() copies if it has to, which is
        # fine here because we only read from it.
        _, idx = inputs.flatten().sort(descending=True)
        # Number of retained elements.
        j = math.ceil(threshold * inputs.numel())

        # Write into a freshly allocated *contiguous* mask. `view(-1)` on a contiguous tensor always
        # aliases its storage in row-major order, matching `idx`. (Writing through
        # `inputs.clone().flatten()` silently fails for non-contiguous inputs: clone() keeps the
        # strides and flatten() then returns a copy, so the writes never reach the mask.)
        mask = torch.zeros_like(inputs, memory_format=torch.contiguous_format)
        mask.view(-1)[idx[:j]] = 1.

        ctx.save_for_backward(mask)
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        mask, = ctx.saved_tensors
        # `inputs` receives the output gradient unchanged (straight-through).
        if ctx.sigmoid:
            # The threshold's gradient is the sum of the output gradient over kept (mask == 1)
            # positions, returned as a 1-element tensor.
            return gradOutput.clone(), ((gradOutput * mask).sum()).view(-1), None
        else:
            return gradOutput.clone(), None, None
class SymQuantizer(torch.autograd.Function):
    """
    Symmetric quantization.

    Values are rounded onto a signed integer grid of ``2**num_bits`` levels symmetric around zero
    and mapped back to floating point. The backward pass returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int, >=4)
                Number of bits to use for quantization
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization. Either both are given (only allowed with
                ``num_groups == 1``) or both are ``None`` (range computed per group from ``input``).
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input with the same shape as ``input``. A group whose range is zero
                (e.g. all zeros) is quantized to zeros rather than NaN.
        """
        assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
                                                              and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
            # One row per group; each group's bound is its largest magnitude.
            input = input.reshape(num_groups, -1)
            max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1)
        else:
            max_input = torch.max(min_value.abs(), max_value).view(-1)

        scale = 2 * max_input / q_range
        # A zero range gives scale == 0 and 0 / 0 == NaN. Divide by 1 there instead; multiplying
        # back by the true (zero) scale then maps the whole group to 0. Elsewhere safe_scale == scale.
        safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        output = (input / safe_scale).round().clamp(-q_range // 2, q_range // 2 - 1) * scale
        output = output.reshape(input_shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient passes straight through to `input`; the other arguments get no gradient.
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
class AsymQuantizer(torch.autograd.Function):
    """
    Asymmetric quantization.

    Values are rounded onto ``2**num_bits`` levels spanning ``[min_value, max_value]`` (offset by a
    zero point) and mapped back to floating point. The backward pass returns the incoming gradient
    unchanged.
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int, >=4)
                Number of bits to use for quantization
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization. Either both are given (only allowed with
                ``num_groups == 1``) or both are ``None`` (range computed per group from ``input``).
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input with the same shape as ``input``. Where the range is zero
                (``max_value == min_value``, e.g. a constant group) the output is ``min_value``
                rather than NaN.
        """
        assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
                                                              and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
            # One row per group; per-group min/max keep a trailing dim so they broadcast per row.
            input = input.reshape(num_groups, -1)
            min_value = input.amin(dim=-1, keepdim=True)
            max_value = input.amax(dim=-1, keepdim=True)

        scale = (max_value - min_value) / q_range
        # A zero-width range gives scale == 0 and NaN from the divisions below; use a unit scale
        # there so the arithmetic stays finite (elsewhere safe_scale == scale exactly).
        degenerate = scale == 0
        safe_scale = torch.where(degenerate, torch.ones_like(scale), scale)
        # Zero point: min_value snapped to the quantization grid.
        zero_point = (min_value / safe_scale).round() * safe_scale
        output = ((input - zero_point) / safe_scale).round().clamp(0, q_range - 1) * safe_scale + zero_point
        # A zero-width range has exactly one representable value: min_value.
        output = torch.where(degenerate, min_value.to(output.dtype), output)
        output = output.reshape(input_shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient passes straight through to `input`; the other arguments get no gradient.
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
class TernaryQuantizer(torch.autograd.Function):
    """
    Ternary quantization.

    Within each group, elements are mapped to ``{-alpha, 0, +alpha}``: elements whose magnitude
    exceeds ``0.7 * mean(|x|)`` keep their sign and get magnitude ``alpha`` (the mean magnitude of
    those elements); all others become 0. The backward pass returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int)
                Dummy variable
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization; for now they are dummy variable
                and must be ``None``.
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input with the same shape as ``input``. An all-zero group is
                quantized to zeros rather than NaN.
        """
        assert (min_value is None and max_value is None)
        input_flat = input.reshape(num_groups, -1)
        n = input_flat.shape[1]
        # Per-group mean absolute value and the 0.7 * mean threshold, shaped (num_groups, 1).
        m = input_flat.norm(p=1, dim=1).div(n)
        thres = (0.7 * m).view(-1, 1)
        pos = (input_flat > thres).type(input.type())
        neg = (input_flat < -thres).type(input.type())
        mask = (input_flat.abs() > thres).type(input.type())
        # For an all-zero group the threshold is 0, no element passes, and the count is 0, which would
        # give alpha = 0 / 0 = NaN. Clamping the count to 1 yields alpha = 0 there and is a no-op
        # whenever at least one element passes.
        alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1).clamp(min=1)).view(-1, 1)
        output = alpha * pos - alpha * neg
        output = output.reshape(input.shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient passes straight through to `input`; the other arguments get no gradient.
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
class BinaryQuantizer(torch.autograd.Function):
    """
    Binary quantization.

    Every element is replaced by ``sign(x) * mean(|x|)``, with the mean taken over the element's
    group. Elements that are exactly zero stay zero because ``sign(0) == 0``. The backward pass
    returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                Tensor to binarize; it is viewed as ``num_groups`` rows via ``reshape(num_groups, -1)``.
            num_bits (int)
                Unused (kept for signature parity with the other quantizers).
            min_value/max_value (torch.FloatTensor)
                Not supported; both must be ``None``.
            num_groups (int)
                Number of groups; each group gets its own scale.
        Returns:
            quantized_input (`torch.FloatTensor`)
                Contiguous tensor with the same shape as ``input``.
        """
        assert min_value is None and max_value is None
        groups = input.reshape(num_groups, -1)
        # Mean absolute value of each row, kept as a column so it broadcasts across the row.
        per_group_scale = groups.norm(p=1, dim=1, keepdim=True) / groups.shape[1]
        binarized = groups.sign() * per_group_scale
        return binarized.reshape(input.shape).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient passes straight through to `input`; the other arguments get no gradient.
        return grad_output.clone(), None, None, None, None
|