File size: 3,747 Bytes
cc69848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

from .general import power2factorization, FUNC_LIST
from .diag_oft import get_r


def weight_gen(org_weight, max_block_size, boft_m=-1, rescale=False):
    """### boft_weight_gen

    Initialize zeroed BOFT (butterfly orthogonal fine-tuning) blocks for a
    given weight tensor.

    Args:
        org_weight (torch.Tensor): the weight tensor
        max_block_size (int): max block size
        boft_m (int, optional): number of butterfly factors to use;
            -1 means use the maximum for this factorization. Defaults to -1.
        rescale (bool, optional): whether to rescale the weight. Defaults to False.

    Returns:
        tuple: (oft_blocks, rescale_weight) where rescale_weight is a ones
        tensor broadcastable over org_weight when rescale is True, else None.
    """
    out_dim, *rest = org_weight.shape
    block_size, block_num = power2factorization(out_dim, max_block_size)
    # Maximum number of butterfly stages: popcount(block_num - 1) + 1.
    max_boft_m = sum(int(i) for i in f"{block_num-1:b}") + 1
    if boft_m == -1:
        boft_m = max_boft_m
    # Never exceed the number of stages the factorization supports.
    boft_m = min(boft_m, max_boft_m)
    oft_blocks = torch.zeros(boft_m, block_num, block_size, block_size)
    # BUG FIX: original code tested `rescale is not None`, which is always
    # True for a bool parameter, so a rescale weight was returned even when
    # rescale=False and the None branch was unreachable. Test truthiness.
    if rescale:
        return oft_blocks, torch.ones(out_dim, *[1] * len(rest))
    else:
        return oft_blocks, None


def diff_weight(org_weight, *weights, constraint=None):
    """### boft_diff_weight

    Compute the weight delta produced by applying the BOFT butterfly
    orthogonal transform (plus optional rescale) to the original weight.

    Args:
        org_weight (torch.Tensor): the weight tensor of original model
        weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
        constraint (float, optional): constraint for oft

    Returns:
        torch.Tensor: ΔW
    """
    # oft_blocks: (m stages, block_num, block_size, block_size); rescale may be None.
    oft_blocks, rescale = weights
    m, num, b, _ = oft_blocks.shape
    r_b = b // 2  # half block size; sets the butterfly stride per stage
    I = torch.eye(b, device=oft_blocks.device)
    # get_r builds the per-stage orthogonal rotation matrices from the raw
    # blocks (imported from .diag_oft); r: (m, block_num, b, b).
    r = get_r(oft_blocks, I, constraint)
    # Keep `org` untouched so ΔW = transformed - original at the end.
    inp = org = org_weight.to(dtype=r.dtype)

    for i in range(m):
        bi = r[i]  # b_num, b_size, b_size
        g = 2
        # Butterfly stride doubles each stage: k = 2^i * (b // 2).
        k = 2**i * r_b
        # Interleave: ... (c g k) -> ... (c k g), then regroup into blocks
        # of size b so each block mixes elements k apart.
        inp = (
            inp.unflatten(-1, (-1, g, k))
            .transpose(-2, -1)
            .flatten(-3)
            .unflatten(-1, (-1, b))
        )
        # Apply each stage's block-diagonal rotation along the first axis
        # (NOTE(review): einsum batches over the block index `b` here, with
        # remaining dims folded into `...`).
        inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
        # Undo the interleaving permutation: ... (c k g) -> ... (c g k).
        inp = inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)

    if rescale is not None:
        # rescale has shape (out_dim, 1, ...) so it broadcasts over the rest.
        inp = inp * rescale

    # ΔW: difference between transformed and original weight.
    return inp - org


def bypass_forward_diff(org_out, *weights, constraint=None, need_transpose=False):
    """### boft_bypass_forward_diff

    Apply the BOFT butterfly transform to a module's *output* activations
    (bypass mode) and return the resulting difference from the original output.

    Args:
        x (torch.Tensor): the input tensor for original model
        org_out (torch.Tensor): the output tensor from original model
        weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
        constraint (float, optional): constraint for oft
        need_transpose (bool, optional):
            whether to transpose the input and output,
            set to `True` if the original model have "dim" not in the last axis.
            For example: Convolution layers

    Returns:
        torch.Tensor: output tensor
    """
    # Same setup as diff_weight, but operating on activations instead of weights.
    oft_blocks, rescale = weights
    m, num, b, _ = oft_blocks.shape
    r_b = b // 2  # half block size; butterfly stride base
    I = torch.eye(b, device=oft_blocks.device)
    # Per-stage orthogonal rotations from the raw blocks (see .diag_oft.get_r).
    r = get_r(oft_blocks, I, constraint)
    inp = org = org_out.to(dtype=r.dtype)
    if need_transpose:
        # Move the channel dim to the last axis (e.g. for conv outputs);
        # `org` is rebound too so the final subtraction stays aligned.
        inp = org = inp.transpose(1, -1)

    for i in range(m):
        bi = r[i]  # b_num, b_size, b_size
        g = 2
        # Butterfly stride doubles each stage: k = 2^i * (b // 2).
        k = 2**i * r_b
        # ... (c g k) ->... (c k g)
        # ... (d b) -> ... d b
        inp = (
            inp.unflatten(-1, (-1, g, k))
            .transpose(-2, -1)
            .flatten(-3)
            .unflatten(-1, (-1, b))
        )
        # Rotate each block; here the block axis sits second-to-last
        # (contrast with diff_weight, where it leads).
        inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
        # ... d b -> ... (d b)
        # ... (c k g) -> ... (c g k)
        inp = inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)

    if rescale is not None:
        # Transposed so the (out_dim, 1, ...) rescale aligns with the
        # channel-last activation layout used above.
        inp = inp * rescale.transpose(0, -1)

    # Return only the delta relative to the original output.
    inp = inp - org
    if need_transpose:
        # Restore the caller's original axis order.
        inp = inp.transpose(1, -1)
    return inp