import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .general import rebuild_tucker, FUNC_LIST


def weight_gen(org_weight, rank, tucker=True):
    """### weight_gen

    Args:
        org_weight (torch.Tensor): the weight tensor
        rank (int): low rank
        tucker (bool, optional): use tucker decomposition for weights
            with kernel dims (Conv1/2/3d). Defaults to True.

    Returns:
        torch.Tensor: down, up[, mid]
    """
    out_dim, in_dim, *k = org_weight.shape
    if k and tucker:
        # tucker: down/up are pointwise (1x1) factors, the core tensor
        # `mid` carries the k-sized kernel
        down = torch.empty(rank, in_dim, *(1 for _ in k))
        up = torch.empty(out_dim, rank, *(1 for _ in k))
        mid = torch.empty(rank, rank, *k)
        nn.init.kaiming_uniform_(down, a=math.sqrt(5))
        # zero-init up so the delta weight starts at zero
        nn.init.constant_(up, 0)
        nn.init.kaiming_uniform_(mid, a=math.sqrt(5))
        return down, up, mid
    elif k:
        # non-tucker conv: down keeps the full kernel, up is pointwise
        down = torch.empty(rank, in_dim, *k)
        up = torch.empty(out_dim, rank, *(1 for _ in k))
        nn.init.kaiming_uniform_(down, a=math.sqrt(5))
        nn.init.constant_(up, 0)
        return down, up, None
    else:
        down = torch.empty(rank, in_dim)
        up = torch.empty(out_dim, rank)
        nn.init.kaiming_uniform_(down, a=math.sqrt(5))
        nn.init.constant_(up, 0)
        return down, up, None
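
# Illustrative sketch: factor shapes for a hypothetical Conv2d weight
# with the default tucker=True (the shapes are the point, not the values).
# >>> w = torch.randn(64, 32, 3, 3)              # (out_dim, in_dim, *k)
# >>> down, up, mid = weight_gen(w, rank=4)
# >>> down.shape, up.shape, mid.shape
# (torch.Size([4, 32, 1, 1]), torch.Size([64, 4, 1, 1]), torch.Size([4, 4, 3, 3]))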


def diff_weight(*weights: torch.Tensor, gamma=1.0):
    """### diff_weight

    Get ΔW = γ·BA, where B=up and A=down form the low rank
    decomposition (tucker factors are rebuilt via rebuild_tucker)

    Args:
        weights (tuple[torch.Tensor]): (down, up[, mid])
        gamma (float, optional): scale factor, normally alpha/rank here

    Returns:
        torch.Tensor: ΔW
    """
    d, u, m = weights
    R, I, *k = d.shape
    O, R, *_ = u.shape
    u = u * gamma

    if m is None:
        # plain low rank: flatten any kernel dims into the matmul
        result = u.reshape(-1, u.size(1)) @ d.reshape(d.size(0), -1)
    else:
        # tucker: the kernel dims live on the core tensor `mid`
        _, _, *k = m.shape
        u = u.reshape(u.size(0), -1).transpose(0, 1)
        d = d.reshape(d.size(0), -1)
        result = rebuild_tucker(m, u, d)
    return result.reshape(O, I, *k)
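
# Illustrative sketch: with freshly generated factors ΔW is all zeros,
# since `up` is zero-initialized, so merging at step 0 is a no-op.
# >>> d, u, m = weight_gen(torch.randn(64, 32, 3, 3), rank=4)
# >>> delta = diff_weight(d, u, m, gamma=1.0)
# >>> delta.shape, bool(delta.any())
# (torch.Size([64, 32, 3, 3]), False)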


def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args=None):
    """### bypass_forward_diff

    Args:
        x (torch.Tensor): input tensor
        org_out (torch.Tensor): output of the original layer \
            (unused here, accepted for a uniform bypass interface)
        weights (tuple[torch.Tensor]): (down, up[, mid])
        gamma (float, optional): scale factor, normally alpha/rank here
        extra_args (dict, optional): extra args for forward func, \
            e.g. padding, stride for Conv1/2/3d

    Returns:
        torch.Tensor: the scaled low-rank branch output \
            (caller adds it to org_out)
    """
    # avoid the shared mutable default dict
    if extra_args is None:
        extra_args = {}
    d, u, m = weights
    # FUNC_LIST (from .general) maps weight.dim() to the matching
    # functional op (linear / conv1d / conv2d / conv3d)
    if m is not None:
        # tucker: stride/padding apply at the k-sized kernel in `mid`
        down = FUNC_LIST[d.dim()](x, d)
        mid = FUNC_LIST[d.dim()](down, m, **extra_args)
        up = FUNC_LIST[d.dim()](mid, u)
    else:
        down = FUNC_LIST[d.dim()](x, d, **extra_args)
        up = FUNC_LIST[d.dim()](down, u)
    return up * gamma
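
# Illustrative sketch: bypass path for a hypothetical Conv2d layer,
# assuming FUNC_LIST maps 4-D weights to F.conv2d; the layer's padding
# is forwarded via extra_args and the caller adds the diff to org_out.
# >>> layer = nn.Conv2d(32, 64, 3, padding=1)
# >>> x = torch.randn(1, 32, 16, 16)
# >>> org_out = layer(x)
# >>> d, u, m = weight_gen(layer.weight, rank=4)
# >>> diff = bypass_forward_diff(x, org_out, d, u, m, gamma=1.0,
# ...                            extra_args={"padding": 1})
# >>> (org_out + diff).shape
# torch.Size([1, 64, 16, 16])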