import torch
import torch.nn as nn

class TWNLinear(nn.Linear):
    """nn.Linear drop-in whose weights are ternarized on every forward pass."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Ternarize the full-precision weights; gradients flow back through the
        # straight-through estimator defined in TwnQuantizer.backward.
        w = TwnQuantizer.apply(self.weight)
        return torch.nn.functional.linear(input, w.to(input.dtype), self.bias)


class TwnQuantizer(torch.autograd.Function):
    """Ternary Weight Networks (TWN)
    Ref: https://arxiv.org/abs/1605.04711
    """

    @staticmethod
    def forward(ctx, input, max_scale=0.7, clip=None, group_size=-1, per_tensor=False, max_scale_dummy=0.7):
        """
        :param input: tensor to be ternarized
        :param max_scale: threshold factor; the ternarization threshold is max_scale * mean(|w|)
        :param clip: optional factor; weights are first clamped to +/- clip * mean(|w|)
        :param group_size: quantization group size along the last dim (-1 = per output channel)
        :param per_tensor: use a single threshold and scale for the whole tensor
        :param max_scale_dummy: unused
        :return: ternarized tensor
        """
        ctx.save_for_backward(input)
        
        org_w_shape = input.shape

        # Flatten to 2D: one row per quantization group (or per output channel).
        if group_size > 0:
            assert org_w_shape[-1] % group_size == 0
            input = input.reshape(-1, group_size)
        else:
            input = input.reshape(-1, input.shape[-1])

        if per_tensor:
            assert group_size == -1, "Per-tensor and per-group quantization are mutually exclusive."
        
        # Optionally clamp the weights to +/- (clip * mean(|w|)) before ternarization.
        if clip is not None:
            if per_tensor:
                m = input.norm(p=1).div(input.nelement())
                clip_alpha = m * clip
            else:
                m = input.norm(p=1, dim=1).div(input[0].nelement())
                m = m.expand(input.shape[1], -1).transpose(0, 1)
                clip_alpha = m * clip
            input = torch.where(input <= clip_alpha, input, clip_alpha)
            input = torch.where(input >= -clip_alpha, input, -clip_alpha)

        if per_tensor:
            # Per-tensor quantization: one threshold and scale for the whole tensor.
            m = input.abs().mean()
            thres = max_scale * m
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            alpha = (mask * input).abs().sum() / mask.sum()
            result = alpha * pos - alpha * neg
        else:
            # Per-channel / per-group quantization: one threshold and scale per row.
            n = input[0].nelement()
            m = input.norm(p=1, dim=1).div(n)
            thres = (max_scale * m).view(-1, 1).expand_as(input)
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            alpha = ((mask * input).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
            result = alpha * pos - alpha * neg

        result = result.reshape(org_w_shape)  # restore the original weight shape (undoes the per-group flattening)

        return result

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param ctx: saved non-clipped full-precision tensor and clip_val
        :param grad_output: gradient ert the quantized tensor
        :return: estimated gradient wrt the full-precision tensor
        """
        # input, clip_val = ctx.saved_tensors  # unclipped input
        input = ctx.saved_tensors  # unclipped input
        grad_input = grad_output.clone()
        # grad_input[input.ge(clip_val[1])] = 0
        # grad_input[input.le(clip_val[0])] = 0
        return grad_input, None, None, None, None
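

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): swap the
# nn.Linear layers of an arbitrary model for TWNLinear and run one forward /
# backward pass to confirm gradients reach the full-precision weights through
# the straight-through estimator. The helper name `replace_linear_with_twn`
# and the toy model below are assumptions for demonstration only.
def replace_linear_with_twn(module: nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, TWNLinear):
            twn = TWNLinear(child.in_features, child.out_features, bias=child.bias is not None)
            twn.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                twn.bias.data.copy_(child.bias.data)
            setattr(module, name, twn)
        else:
            replace_linear_with_twn(child)


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    replace_linear_with_twn(model)
    out = model(torch.randn(8, 16))
    out.sum().backward()  # gradients flow to the full-precision weights via the STE
    print({n: round(p.grad.abs().mean().item(), 6)
           for n, p in model.named_parameters() if p.grad is not None})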