File size: 2,423 Bytes
cc69848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .general import rebuild_tucker, FUNC_LIST
def weight_gen(org_weight, rank, tucker=True):
    """### weight_gen
    Create freshly initialised low-rank factors for ``org_weight``.

    ``up`` is zero-initialised, so the initial ΔW (up composed with down) is
    zero; ``down`` (and ``mid``) use Kaiming-uniform init with ``a=sqrt(5)``.

    Args:
        org_weight (torch.Tensor): the weight tensor, shaped
            ``(out_dim, in_dim, *kernel)``; ``kernel`` is empty for Linear.
        rank (int): low rank
        tucker (bool, optional): when the weight has kernel dims, use a
            Tucker decomposition with a separate ``mid`` core.
            Defaults to True.

    Returns:
        tuple: ``(down, up, mid)``; ``mid`` is ``None`` when no Tucker core
        is used. Shapes:

        * Tucker (kernel dims present and ``tucker=True``):
          down ``(rank, in_dim, 1, ...)``, up ``(out_dim, rank, 1, ...)``,
          mid ``(rank, rank, *kernel)``.
        * otherwise: down ``(rank, in_dim, *kernel)``,
          up ``(out_dim, rank, 1, ...)``.
    """
    out_dim, in_dim, *k = org_weight.shape
    # 1-sized spatial dims, so the pointwise factors keep the weight's rank.
    ones = (1,) * len(k)
    if k and tucker:
        down = torch.empty(rank, in_dim, *ones)
        up = torch.empty(out_dim, rank, *ones)
        mid = torch.empty(rank, rank, *k)
        nn.init.kaiming_uniform_(down, a=math.sqrt(5))
        nn.init.constant_(up, 0)
        nn.init.kaiming_uniform_(mid, a=math.sqrt(5))
        return down, up, mid
    # Non-Tucker: the kernel lives in `down` and `up` is a 1x..x1 projection.
    # Keeping the kernel dims is what lets diff_weight reshape ΔW back to
    # (out_dim, in_dim, *kernel) and lets bypass_forward_diff dispatch on
    # d.dim(). For Linear weights (k empty) the shapes are (rank, in_dim)
    # and (out_dim, rank).
    down = torch.empty(rank, in_dim, *k)
    up = torch.empty(out_dim, rank, *ones)
    nn.init.kaiming_uniform_(down, a=math.sqrt(5))
    nn.init.constant_(up, 0)
    return down, up, None
def diff_weight(*weights: torch.Tensor, gamma=1.0):
    """### diff_weight
    Get ΔW = BA, where BA is low rank decomposition

    Args:
        weights (torch.Tensor): ``down, up`` or ``down, up, mid``; a
            trailing ``mid`` of ``None`` means no Tucker core.
        gamma (float, optional): scale factor, normally alpha/rank here

    Returns:
        torch.Tensor: ΔW shaped ``(out_dim, in_dim, *kernel)``, where
        ``out_dim``/``in_dim`` come from ``up``/``down`` and ``kernel`` from
        ``down`` (no Tucker core) or ``mid`` (Tucker).

    Raises:
        ValueError: if not given exactly 2 or 3 weights.
    """
    if len(weights) == 3:
        d, u, m = weights
    elif len(weights) == 2:
        (d, u), m = weights, None
    else:
        raise ValueError(
            f"expected (down, up[, mid]), got {len(weights)} weights"
        )
    out_dim = u.shape[0]
    in_dim = d.shape[1]
    u = u * gamma
    if m is None:
        # Kernel dims (if any) live in `down`: flatten it to
        # (rank, in_dim * prod(kernel)) and do a plain matmul with up.
        kernel = d.shape[2:]
        result = u.reshape(-1, u.size(1)) @ d.reshape(d.size(0), -1)
    else:
        # Tucker: `up` and `down` are pointwise, so flatten them to 2-D
        # factors and let rebuild_tucker combine them with the `mid` core.
        kernel = m.shape[2:]
        u_flat = u.reshape(u.size(0), -1).transpose(0, 1)
        d_flat = d.reshape(d.size(0), -1)
        result = rebuild_tucker(m, u_flat, d_flat)
    return result.reshape(out_dim, in_dim, *kernel)
def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args=None):
    """### bypass_forward_diff
    Run the low-rank branch directly on ``x`` instead of building ΔW.

    Args:
        x (torch.Tensor): input tensor
        org_out (torch.Tensor): output of the original layer; not used by
            this function (kept for interface compatibility).
        weights (torch.Tensor): ``down, up`` or ``down, up, mid``; a
            trailing ``mid`` of ``None`` means no Tucker core.
        gamma (float, optional): scale factor, normally alpha/rank here
        extra_args (dict, optional): extra args for forward func, \
            e.g. padding, stride for Conv1/2/3d. They are passed only to
            the op that carries the kernel: ``mid`` (Tucker) or ``down``.

    Returns:
        torch.Tensor: the branch output multiplied by ``gamma``
        (``org_out`` is not added here).

    Raises:
        ValueError: if not given exactly 2 or 3 weights.
    """
    if len(weights) == 3:
        d, u, m = weights
    elif len(weights) == 2:
        (d, u), m = weights, None
    else:
        raise ValueError(
            f"expected (down, up[, mid]), got {len(weights)} weights"
        )
    if extra_args is None:
        extra_args = {}
    # FUNC_LIST is indexed by weight rank; presumably it maps 2/3/4/5 to
    # linear/conv1d/conv2d/conv3d -- NOTE(review): confirm in .general.
    func = FUNC_LIST[d.dim()]
    if m is not None:
        down = func(x, d)
        mid = func(down, m, **extra_args)
        up = func(mid, u)
    else:
        down = func(x, d, **extra_args)
        up = func(down, u)
    return up * gamma
|