import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .general import factorization, FUNC_LIST


def get_r(oft_blocks, I=None, constraint=0):
    if I is None:
        I = torch.eye(oft_blocks.shape[-1], device=oft_blocks.device)
    if I.ndim < oft_blocks.ndim:
        # Broadcast the identity over the leading block dimensions.
        for _ in range(oft_blocks.ndim - I.ndim):
            I = I.unsqueeze(0)

    # Project the learned blocks onto skew-symmetric matrices (Q^T = -Q).
    q = oft_blocks - oft_blocks.transpose(-1, -2)
    normed_q = q
    if constraint is not None and constraint > 0:
        # Constrained OFT (COFT): clamp the norm of Q to the constraint radius.
        q_norm = torch.norm(q) + 1e-8
        if q_norm > constraint:
            normed_q = q * constraint / q_norm
    # Cayley transform: maps a skew-symmetric Q to an orthogonal R.
    r = (I + normed_q) @ (I - normed_q).float().inverse()
    return r
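

# A minimal sanity-check sketch (illustrative only, not part of the module
# API): the Cayley transform in get_r should always produce orthogonal blocks,
# so a hypothetical test helper like this can assert r @ r^T ≈ I.
def _check_r_orthogonal(oft_blocks, constraint=None, atol=1e-5):
    r = get_r(oft_blocks, constraint=constraint)
    I = torch.eye(r.shape[-1], device=r.device).expand_as(r)
    # R is orthogonal iff R @ R^T equals I (up to float error from the inverse).
    return torch.allclose(r @ r.transpose(-1, -2), I, atol=atol)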
def weight_gen(org_weight, max_block_size=-1, rescale=False):
    """### weight_gen

    Args:
        org_weight (torch.Tensor): the weight tensor
        max_block_size (int): max block size
        rescale (bool, optional): whether to rescale the weight. Defaults to False.

    Returns:
        tuple[torch.Tensor, torch.Tensor | None]: (oft_blocks[, rescale_weight])
    """
    out_dim, *rest = org_weight.shape
    # Factorize the output dimension into block_num blocks of block_size.
    block_size, block_num = factorization(out_dim, max_block_size)
    oft_blocks = torch.zeros(block_num, block_size, block_size)
    if rescale:
        # One scale per output channel, broadcastable over the remaining dims.
        return oft_blocks, torch.ones(out_dim, *[1] * len(rest))
    else:
        return oft_blocks, None
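

# Usage sketch (shapes assumed for illustration): for a Linear weight of shape
# (out_dim, in_dim), weight_gen factorizes out_dim into block_num * block_size
# and returns zero-initialized blocks:
#
#   w = torch.randn(768, 320)
#   blocks, scale = weight_gen(w, max_block_size=64, rescale=True)
#   # blocks: (block_num, block_size, block_size) with
#   # block_num * block_size == 768; scale: (768, 1)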
def diff_weight(org_weight, *weights, constraint=None):
    """### diff_weight

    Args:
        org_weight (torch.Tensor): the weight tensor of the original model
        weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
        constraint (float, optional): constraint for oft

    Returns:
        torch.Tensor: ΔW
    """
    oft_blocks, rescale = weights
    I = torch.eye(oft_blocks.shape[1], device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)

    block_num, block_size, _ = oft_blocks.shape
    _, *shape = org_weight.shape
    org_weight = org_weight.to(dtype=r.dtype)
    org_weight = org_weight.view(block_num, block_size, *shape)

    # oft_blocks is zero-initialized, so R = I and ΔW = 0 before training.
    weight = torch.einsum(
        "k n m, k n ... -> k m ...",
        r - I,
        org_weight,
    ).view(-1, *shape)
    if rescale is not None:
        # W' = rescale * (W + ΔW), hence ΔW' = rescale * ΔW + (rescale - 1) * W.
        # org_weight must be flattened back to (out_dim, *shape) so it
        # broadcasts against rescale, which has shape (out_dim, 1, ...).
        weight = rescale * weight
        weight = weight + (rescale - 1) * org_weight.view(-1, *shape)
    return weight
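

# Usage sketch (assumed setup): merging OFT into a layer amounts to adding the
# ΔW from diff_weight onto the original weight:
#
#   blocks, scale = weight_gen(linear.weight, max_block_size=64)
#   delta = diff_weight(linear.weight, blocks, scale, constraint=1e-3)
#   merged = linear.weight + delta  # the block-wise rotated weight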
def bypass_forward_diff(x, org_out, *weights, constraint=None, need_transpose=False):
    """### bypass_forward_diff

    Args:
        x (torch.Tensor): the input tensor for the original model
            (not used by this implementation)
        org_out (torch.Tensor): the output tensor from the original model
        weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
        constraint (float, optional): constraint for oft
        need_transpose (bool, optional):
            whether to transpose the input and output;
            set to `True` if the original model does not keep "dim"
            in the last axis. For example: convolution layers

    Returns:
        torch.Tensor: output tensor
    """
    oft_blocks, rescale = weights
    block_num, block_size, _ = oft_blocks.shape
    I = torch.eye(block_size, device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)
    if need_transpose:
        # Move the channel axis to the end so the blocks act on the last dim.
        org_out = org_out.transpose(1, -1)
    org_out = org_out.to(dtype=r.dtype)
    *shape, _ = org_out.shape
    # Apply (R - I) to the original output block-wise.
    oft_out = torch.einsum(
        "k n m, ... k n -> ... k m",
        r - I,
        org_out.view(*shape, block_num, block_size),
    )
    out = oft_out.view(*shape, -1)
    if rescale is not None:
        # rescale has shape (out_dim, 1, ...); move out_dim to the last axis.
        out = rescale.transpose(-1, 0) * out
        out = out + (rescale - 1).transpose(-1, 0) * org_out
    if need_transpose:
        out = out.transpose(1, -1)
    return out
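

if __name__ == "__main__":
    # Self-test sketch (assumptions: a bias-free Linear layer, float32).
    # Applying ΔW from diff_weight to the input should match rotating the
    # original output with bypass_forward_diff.
    torch.manual_seed(0)
    w = torch.randn(64, 32)
    x = torch.randn(4, 32)
    blocks, rescale = weight_gen(w, max_block_size=16)
    # Perturb the blocks slightly so R differs from the identity.
    blocks = blocks + torch.randn_like(blocks) * 0.01

    delta = diff_weight(w, blocks, rescale, constraint=None)
    org_out = x @ w.t()
    diff_out = x @ delta.t()
    bypass_out = bypass_forward_diff(x, org_out, blocks, rescale, constraint=None)
    print(torch.allclose(diff_out, bypass_out, atol=1e-5))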