File size: 3,748 Bytes
cc69848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .general import factorization, FUNC_LIST


def get_r(oft_blocks, I=None, constraint=0):
    """### get_r

    Build one rotation matrix per block with the Cayley transform
    R = (I + Q) (I - Q)^-1, where Q = B - B^T is the skew-symmetric part
    of the (unconstrained) parameter blocks B.

    Args:
        oft_blocks (torch.Tensor): parameter blocks; the last two axes are
            treated as square matrices (e.g. (block_num, block_size, block_size)).
        I (torch.Tensor, optional): identity matrix to use. Built with
            `torch.eye` on `oft_blocks.device` when None. Leading singleton
            axes are added until it has as many dims as `oft_blocks`.
        constraint (float, optional): if > 0, Q is rescaled so that its
            norm (a single norm over ALL blocks together) is at most
            `constraint`. None or <= 0 disables the constraint.

    Returns:
        torch.Tensor: the rotation matrices, same shape as `oft_blocks`, in
        at least float32 precision (float64 inputs stay float64).
    """
    if I is None:
        I = torch.eye(oft_blocks.shape[-1], device=oft_blocks.device)
    if I.ndim < oft_blocks.ndim:
        for _ in range(oft_blocks.ndim - I.ndim):
            I = I.unsqueeze(0)
    # Q = B - B^T satisfies Q^T = -Q, which makes the Cayley transform orthogonal
    q = oft_blocks - oft_blocks.transpose(-1, -2)
    normed_q = q
    if constraint is not None and constraint > 0:
        # one global norm across every block, not per-block
        q_norm = torch.norm(q) + 1e-8
        if q_norm > constraint:
            normed_q = q * constraint / q_norm
    lhs = I + normed_q
    rhs = I - normed_q
    # inverse() has no half/bfloat16 kernels, so compute in at least float32.
    # Both factors are cast to the SAME dtype: matmul rejects mixed dtypes, which
    # previously broke float64 blocks (float64 @ float32).
    dtype = torch.promote_types(lhs.dtype, torch.float32)
    r = lhs.to(dtype) @ rhs.to(dtype).inverse()
    return r


def weight_gen(org_weight, max_block_size=-1, rescale=False):
    """### weight_gen

    Create freshly initialised (all-zero) OFT parameters for a weight.

    Args:
        org_weight (torch.Tensor): the weight tensor; only its shape is read,
            with axis 0 taken as the output dimension.
        max_block_size (int): max block size, forwarded to `factorization`
            to split the output dimension into (block_size, block_num).
        rescale (bool, optional): whether to also create a rescale weight.
            Defaults to False.

    Returns:
        tuple[torch.Tensor, torch.Tensor | None]: always a 2-tuple
        `(oft_blocks, rescale_weight)`. `oft_blocks` has shape
        (block_num, block_size, block_size); `rescale_weight` is a tensor of
        ones with shape (out_dim, 1, ..., 1) matching `org_weight`'s rank
        when `rescale` is True, otherwise None.
    """
    dims = org_weight.shape
    out_dim = dims[0]
    block_size, block_num = factorization(out_dim, max_block_size)
    # zero blocks => Q = 0 => R = I, so the adapted layer starts as the original
    oft_blocks = torch.zeros(block_num, block_size, block_size)
    rescale_weight = None
    if rescale:
        trailing_ones = [1] * (len(dims) - 1)
        rescale_weight = torch.ones(out_dim, *trailing_ones)
    return oft_blocks, rescale_weight


def diff_weight(org_weight, *weights, constraint=None):
    """### diff_weight

    Compute the weight delta ΔW so that `org_weight + ΔW` equals the
    OFT-adapted weight (optionally rescaled per output channel).

    Args:
        org_weight (torch.Tensor): the weight tensor of original model; axis 0
            is split into (block_num, block_size) and must therefore equal
            block_num * block_size.
        weights (tuple[torch.Tensor]): (oft_blocks, rescale_weight), where
            rescale_weight may be None. When given, it is expected to
            broadcast against `org_weight` (e.g. shape (out_dim, 1, ..., 1)).
        constraint (float, optional): constraint for oft (see `get_r`)

    Returns:
        torch.Tensor: ΔW with the same shape as `org_weight`, in the dtype
        returned by `get_r`.
    """
    oft_blocks, rescale = weights
    block_num, block_size, _ = oft_blocks.shape
    I = torch.eye(block_size, device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)

    _, *shape = org_weight.shape
    # flat (out_dim, *shape) copy, kept for the rescale term below
    org_weight = org_weight.to(dtype=r.dtype)
    blocked_weight = org_weight.reshape(block_num, block_size, *shape)
    # Init R=0, so R - I (not R) is applied: at step 0 the delta is exactly zero
    weight = torch.einsum(
        "k n m, k n ... -> k m ...",
        r - I,
        blocked_weight,
    ).reshape(-1, *shape)  # einsum output need not be contiguous, so no .view
    if rescale is not None:
        # W' = s * (R W) = s * ΔR W + s * W  =>  ΔW = s * ΔR W + (s - 1) * W.
        # Uses the flat org_weight: the blocked one would mis-broadcast with s.
        weight = rescale * weight
        weight = weight + (rescale - 1) * org_weight
    return weight


def bypass_forward_diff(x, org_out, *weights, constraint=None, need_transpose=False):
    """### bypass_forward_diff

    Compute the OFT correction to a layer's output without materialising the
    adapted weight: the returned tensor is the delta to ADD to `org_out`.

    Args:
        x (torch.Tensor): the input tensor for original model (not used by
            this function; kept for a uniform bypass interface)
        org_out (torch.Tensor): the output tensor from original model
        weights (tuple[torch.Tensor]): (oft_blocks, rescale_weight), where
            rescale_weight may be None; when given it is expected to have
            shape (out_dim, 1, ..., 1) as produced by `weight_gen`.
        constraint (float, optional): constraint for oft (see `get_r`)
        need_transpose (bool, optional):
            whether to transpose the input and output,
            set to `True` if the original model have "dim" not in the last axis.
            For example: Convolution layers

    Returns:
        torch.Tensor: the output delta, same shape as `org_out`, in the dtype
        returned by `get_r`.
    """
    oft_blocks, rescale = weights
    block_num, block_size, _ = oft_blocks.shape
    I = torch.eye(block_size, device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)
    if need_transpose:
        # bring the channel axis (dim 1) to the end so it can be split into blocks
        org_out = org_out.transpose(1, -1)
    org_out = org_out.to(dtype=r.dtype)
    *shape, _ = org_out.shape
    oft_out = torch.einsum(
        "k n m, ... k n -> ... k m",
        r - I,
        org_out.reshape(*shape, block_num, block_size),
    )
    # reshape rather than view: the einsum result is not guaranteed contiguous
    out = oft_out.reshape(*shape, -1)
    if rescale is not None:
        # swap first/last axes so out_dim lines up with the last (channel) axis
        out = rescale.transpose(-1, 0) * out
        out = out + (rescale - 1).transpose(-1, 0) * org_out
    if need_transpose:
        out = out.transpose(1, -1)
    return out