import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .general import FUNC_LIST
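
# This module builds a Hadamard-product low-rank (LoHa-style) weight delta:
#     ΔW = (w1u @ w1d) * (w2u @ w2d) * scale
# i.e. the element-wise product of two low-rank reconstructions. The custom
# autograd Functions below save only the small factors in forward and recompute
# the rebuilt weights during backward, trading a little compute for memory.
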
class HadaWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w1d, w1u, w2d, w2u, scale=torch.tensor(1)):
        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        temp = grad_out * (w2u @ w2d)
        grad_w1u = temp @ w1d.T
        grad_w1d = w1u.T @ temp

        temp = grad_out * (w1u @ w1d)
        grad_w2u = temp @ w2d.T
        grad_w2d = w2u.T @ temp

        del temp
        return grad_w1d, grad_w1u, grad_w2d, grad_w2u, None
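
# Usage sketch (illustrative only, hypothetical shapes): materialize the delta
# for a 16x32 linear weight from two rank-4 factor pairs.
#
#   w1d, w2d = torch.randn(4, 32), torch.randn(4, 32)
#   w1u, w2u = torch.randn(16, 4), torch.randn(16, 4)
#   delta = HadaWeight.apply(w1d, w1u, w2d, w2u, torch.tensor(1.0))
#   assert delta.shape == (16, 32)
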
class HadaWeightTucker(torch.autograd.Function):
    @staticmethod
    def forward(ctx, t1, w1d, w1u, t2, w2d, w2u, scale=torch.tensor(1)):
        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
        del grad_w, temp

        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
        del grad_temp

        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
        del grad_w, temp

        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
        del grad_temp
        return grad_t1, grad_w1d, grad_w1u, grad_t2, grad_w2d, grad_w2u, None
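
# Usage sketch (illustrative only, hypothetical shapes): the Tucker variant
# rebuilds a conv-style delta of shape (out, in, *kernel) from a small core of
# shape (rank, rank, *kernel) plus (rank, in) / (rank, out) factor matrices.
#
#   t1, t2 = torch.randn(4, 4, 3, 3), torch.randn(4, 4, 3, 3)
#   w1d, w2d = torch.randn(4, 32), torch.randn(4, 32)
#   w1u, w2u = torch.randn(4, 16), torch.randn(4, 16)
#   delta = HadaWeightTucker.apply(t1, w1d, w1u, t2, w2d, w2u, torch.tensor(1.0))
#   assert delta.shape == (16, 32, 3, 3)
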
def make_weight(w1d, w1u, w2d, w2u, scale):
    return HadaWeight.apply(w1d, w1u, w2d, w2u, scale)


def make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, scale):
    return HadaWeightTucker.apply(t1, w1d, w1u, t2, w2d, w2u, scale)
def weight_gen(org_weight, rank, tucker=True):
    """### weight_gen

    Args:
        org_weight (torch.Tensor): the weight tensor
        rank (int): low rank
        tucker (bool, optional): use tucker decomposition for conv kernels

    Returns:
        torch.Tensor: w1d, w1u, w2d, w2u[, t1, t2]
    """
    out_dim, in_dim, *k = org_weight.shape
    if k and tucker:
        w1d = torch.empty(rank, in_dim)
        w1u = torch.empty(rank, out_dim)
        t1 = torch.empty(rank, rank, *k)
        w2d = torch.empty(rank, in_dim)
        w2u = torch.empty(rank, out_dim)
        t2 = torch.empty(rank, rank, *k)
        nn.init.normal_(t1, std=0.1)
        nn.init.normal_(t2, std=0.1)
    else:
        w1d = torch.empty(rank, in_dim)
        w1u = torch.empty(out_dim, rank)
        w2d = torch.empty(rank, in_dim)
        w2u = torch.empty(out_dim, rank)
        t1 = t2 = None
    nn.init.normal_(w1d, std=1)
    nn.init.constant_(w1u, 0)
    nn.init.normal_(w2d, std=1)
    nn.init.normal_(w2u, std=0.1)
    return w1d, w1u, w2d, w2u, t1, t2
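
# Parameter-creation sketch (hypothetical shapes): for a Conv2d weight of shape
# (64, 32, 3, 3) with rank 8 and tucker=True this yields (8, 32) down factors,
# (8, 64) up factors and (8, 8, 3, 3) cores; since w1u is zero-initialized, the
# initial ΔW is exactly zero.
#
#   w1d, w1u, w2d, w2u, t1, t2 = weight_gen(torch.empty(64, 32, 3, 3), rank=8)
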
def diff_weight(*weights, gamma=1.0):
    """### diff_weight
    Get ΔW = (w1u @ w1d) * (w2u @ w2d), the Hadamard product of two low rank
    decompositions

    Args:
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u[, t1, t2])
        gamma (float, optional): scale factor, normally alpha/rank here

    Returns:
        torch.Tensor: ΔW
    """
    w1d, w1u, w2d, w2u, t1, t2 = weights
    if t1 is not None and t2 is not None:
        R, I = w1d.shape
        R, O = w1u.shape
        R, R, *k = t1.shape
        result = make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, gamma)
    else:
        R, I, *k = w1d.shape
        O, R, *_ = w1u.shape
        w1d = w1d.reshape(w1d.size(0), -1)
        w1u = w1u.reshape(-1, w1u.size(1))
        w2d = w2d.reshape(w2d.size(0), -1)
        w2u = w2u.reshape(-1, w2u.size(1))
        result = make_weight(w1d, w1u, w2d, w2u, gamma)
    result = result.reshape(O, I, *k)
    return result
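
# Merge sketch (illustrative only): fold the learned delta into a frozen base
# weight, scaling by gamma = alpha / rank.
#
#   base = torch.randn(64, 32, 3, 3)
#   params = weight_gen(base, rank=8)
#   merged = base + diff_weight(*params, gamma=1.0 / 8)
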
def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args={}):
    """### bypass_forward_diff

    Args:
        x (torch.Tensor): input tensor
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u[, t1, t2])
        gamma (float, optional): scale factor, normally alpha/rank here
        extra_args (dict, optional): extra args for forward func, \
            e.g. padding, stride for Conv1/2/3d

    Returns:
        torch.Tensor: output tensor
    """
    w1d, w1u, w2d, w2u, t1, t2 = weights
    # gamma must be passed as a keyword: diff_weight takes *weights positionally
    diff_w = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma=gamma)
    return FUNC_LIST[w1d.dim() if t1 is None else t1.dim()](x, diff_w, **extra_args)
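
if __name__ == "__main__":
    # Optional self-check (not part of the library API): verify the handwritten
    # backward of HadaWeight against numerical gradients in double precision.
    torch.manual_seed(0)
    factors = [
        torch.randn(3, 7, dtype=torch.double, requires_grad=True),  # w1d
        torch.randn(5, 3, dtype=torch.double, requires_grad=True),  # w1u
        torch.randn(3, 7, dtype=torch.double, requires_grad=True),  # w2d
        torch.randn(5, 3, dtype=torch.double, requires_grad=True),  # w2u
    ]
    assert torch.autograd.gradcheck(
        lambda a, b, c, d: HadaWeight.apply(a, b, c, d, torch.tensor(1.0)),
        factors,
    )
    print("HadaWeight gradcheck passed")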