import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

from .general import power2factorization, FUNC_LIST
from .diag_oft import get_r


def weight_gen(org_weight, max_block_size, boft_m=-1, rescale=False):
"""### boft_weight_gen
Args:
org_weight (torch.Tensor): the weight tensor
max_block_size (int): max block size
rescale (bool, optional): whether to rescale the weight. Defaults to False.
Returns:
torch.Tensor: oft_blocks[, rescale_weight]
"""
    out_dim, *rest = org_weight.shape
    block_size, block_num = power2factorization(out_dim, max_block_size)
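    # Maximum number of butterfly factors: popcount(block_num - 1) + 1,
    # which equals log2(block_num) + 1 when block_num is a power of two.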
    max_boft_m = sum(int(i) for i in f"{block_num-1:b}") + 1
    if boft_m == -1:
        boft_m = max_boft_m
    boft_m = min(boft_m, max_boft_m)
    oft_blocks = torch.zeros(boft_m, block_num, block_size, block_size)
    if rescale:
        return oft_blocks, torch.ones(out_dim, *[1] * len(rest))
    else:
        return oft_blocks, None
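
# A minimal usage sketch (illustrative shapes; assumes power2factorization
# factors out_dim into block_size * block_num):
#
#   blocks, rescale_w = weight_gen(torch.empty(512, 512), max_block_size=64)
#   # blocks: (boft_m, block_num, block_size, block_size) of zeros, so the
#   # initial transform is the identity; rescale_w is None since rescale=False.
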
def diff_weight(org_weight, *weights, constraint=None):
    """### boft_diff_weight

    Args:
        org_weight (torch.Tensor): the weight tensor of the original model
        weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
        constraint (float, optional): constraint for oft

    Returns:
        torch.Tensor: ΔW
    """
    oft_blocks, rescale = weights
    m, num, b, _ = oft_blocks.shape
    r_b = b // 2
    I = torch.eye(b, device=oft_blocks.device)
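    # get_r (shared with diag_oft) maps each parameter block to an orthogonal
    # matrix, optionally constrained by `constraint`.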
    r = get_r(oft_blocks, I, constraint)
    inp = org = org_weight.to(dtype=r.dtype)
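    # Each iteration applies one butterfly factor to the output dimension:
    # a stride permutation regroups the rows, the einsum rotates each group
    # of b rows by its orthogonal block, and the inverse permutation restores
    # the original order.  Composing all m factors yields a dense orthogonal
    # R over out_dim, so the result below is ΔW = R @ W - W.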
    for i in range(m):
        bi = r[i]  # b_num, b_size, b_size
        g = 2
        k = 2**i * r_b
        # (c g k) ... -> (c k g) ...
        # (d b) ... -> d b ...
        inp = (
            inp.unflatten(0, (-1, g, k))
            .transpose(1, 2)
            .flatten(0, 2)
            .unflatten(0, (-1, b))
        )
        inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
        # d b ... -> (d b) ...
        # (c k g) ... -> (c g k) ...
        inp = inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
    if rescale is not None:
        inp = inp * rescale

    return inp - org


def bypass_forward_diff(org_out, *weights, constraint=None, need_transpose=False):
"""### boft_bypass_forward_diff
Args:
x (torch.Tensor): the input tensor for original model
org_out (torch.Tensor): the output tensor from original model
weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
constraint (float, optional): constraint for oft
need_transpose (bool, optional):
whether to transpose the input and output,
set to `True` if the original model have "dim" not in the last axis.
For example: Convolution layers
Returns:
torch.Tensor: output tensor
"""
    oft_blocks, rescale = weights
    m, num, b, _ = oft_blocks.shape
    r_b = b // 2
    I = torch.eye(b, device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)
    inp = org = org_out.to(dtype=r.dtype)
    if need_transpose:
        inp = org = inp.transpose(1, -1)
    for i in range(m):
        bi = r[i]  # b_num, b_size, b_size
        g = 2
        k = 2**i * r_b
        # ... (c g k) -> ... (c k g)
        # ... (d b) -> ... d b
        inp = (
            inp.unflatten(-1, (-1, g, k))
            .transpose(-2, -1)
            .flatten(-3)
            .unflatten(-1, (-1, b))
        )
        inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
        # ... d b -> ... (d b)
        # ... (c k g) -> ... (c g k)
        inp = inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
    if rescale is not None:
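        # rescale has shape (out_dim, 1, ..., 1); move out_dim to the last
        # axis so it broadcasts over the channel-last activations.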
        inp = inp * rescale.transpose(0, -1)

    inp = inp - org
    if need_transpose:
        inp = inp.transpose(1, -1)
    return inp
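

if __name__ == "__main__":
    # Sanity-check sketch (illustration only, not part of the original module):
    # for a Linear layer y = x @ W.T, the bypass diff applied to y should
    # equal applying the weight-space diff: x @ diff_weight(W, ...).T.
    torch.manual_seed(0)
    w = torch.randn(8, 4)
    blocks, rescale_w = weight_gen(w, max_block_size=4, rescale=False)
    blocks = blocks + torch.randn_like(blocks) * 0.01  # move away from identity
    x = torch.randn(3, 4)
    dw = diff_weight(w, blocks, rescale_w)
    dy = bypass_forward_diff(x @ w.T, blocks, rescale_w)
    assert torch.allclose(dy, x @ dw.T, atol=1e-5)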