# Uploaded by tonyshark ("Upload 132 files", commit cc69848, verified).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .general import factorization, FUNC_LIST
def get_r(oft_blocks, I=None, constraint=0):
    """Build block rotation matrices from ``oft_blocks`` via the Cayley transform.

    Q = B - B^T is skew-symmetric, and R = (I + Q)(I - Q)^-1 is then orthogonal.

    Args:
        oft_blocks (torch.Tensor): trainable blocks, last two dims are (n, n).
        I (torch.Tensor, optional): identity matrix to use. When omitted, an
            n x n identity in torch's default float dtype is created on
            ``oft_blocks.device``. Leading singleton dims are added so it
            broadcasts against ``oft_blocks``.
        constraint (float, optional): when > 0, Q is rescaled so that its
            Frobenius norm (taken over all blocks jointly) does not exceed
            this value. ``None`` or <= 0 disables the constraint.

    Returns:
        torch.Tensor: R, computed in the promotion of the input dtype with
        float32 (i.e. at least float32).
    """
    if I is None:
        I = torch.eye(oft_blocks.shape[-1], device=oft_blocks.device)
    if I.ndim < oft_blocks.ndim:
        for _ in range(oft_blocks.ndim - I.ndim):
            I = I.unsqueeze(0)
    # for Q = -Q^T
    q = oft_blocks - oft_blocks.transpose(-1, -2)
    normed_q = q
    if constraint is not None and constraint > 0:
        q_norm = torch.norm(q) + 1e-8
        if q_norm > constraint:
            normed_q = q * constraint / q_norm
    numerator = I + normed_q
    denominator = I - normed_q
    # Matrix inverse is unsupported for half types, so work in >= float32.
    # Both factors must share one dtype: `@` does not promote mixed dtypes.
    compute_dtype = torch.promote_types(numerator.dtype, torch.float32)
    r = numerator.to(compute_dtype) @ denominator.to(compute_dtype).inverse()
    return r
def weight_gen(org_weight, max_block_size=-1, rescale=False):
    """### weight_gen
    Create zero-initialised OFT block parameters (and optionally a rescale
    tensor) sized for ``org_weight``.

    Args:
        org_weight (torch.Tensor): the weight tensor; its first dim is split
            into blocks via ``factorization``.
        max_block_size (int): max block size, forwarded to ``factorization``.
        rescale (bool, optional): whether to also create a rescale tensor.
            Defaults to False.

    Returns:
        tuple: ``(oft_blocks, rescale_weight)``. ``oft_blocks`` is a zeros
        tensor of shape (block_num, block_size, block_size). ``rescale_weight``
        is a ones tensor of shape (out_dim, 1, ..., 1) (one trailing 1 per
        remaining dim of ``org_weight``), or ``None`` when ``rescale`` is falsy.
        Both are created with torch's default device and dtype.
    """
    out_features, *trailing_dims = org_weight.shape
    size, count = factorization(out_features, max_block_size)
    blocks = torch.zeros(count, size, size)
    if not rescale:
        return blocks, None
    singleton_dims = [1 for _ in trailing_dims]
    return blocks, torch.ones(out_features, *singleton_dims)
def diff_weight(org_weight, *weights, constraint=None):
    """### diff_weight
    Compute the weight delta ΔW that the OFT blocks apply to ``org_weight``.

    With R from ``get_r`` and s the optional rescale, the adapted weight is
    ``s * R^T W`` (blockwise over the first dim), and this returns
    ``s * R^T W - W`` (``R^T W - W`` when no rescale is given).

    Args:
        org_weight (torch.Tensor): the weight tensor of original model; its
            first dim must equal block_num * block_size.
        weights (tuple[torch.Tensor]): (oft_blocks, rescale_weight), where
            rescale_weight may be None. When given, it must broadcast against
            ``org_weight`` (e.g. the (out_dim, 1, ..., 1) tensor from
            ``weight_gen``).
        constraint (float, optional): constraint for oft, forwarded to ``get_r``.

    Returns:
        torch.Tensor: ΔW, same shape as ``org_weight``, in ``get_r``'s dtype.
    """
    oft_blocks, rescale = weights
    block_num, block_size, _ = oft_blocks.shape
    I = torch.eye(block_size, device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)

    _, *shape = org_weight.shape
    org_weight = org_weight.to(dtype=r.dtype)
    # Keep the flat (out_dim, *shape) weight: the rescale term below must
    # broadcast against it, not against the blocked view.
    blocked_weight = org_weight.reshape(block_num, block_size, *shape)
    # Init R=0, so r == I and (r - I) == 0: the output of step0 is the original model output
    weight = torch.einsum(
        "k n m, k n ... -> k m ...",
        r - I,
        blocked_weight,
    ).reshape(-1, *shape)  # reshape: einsum output layout is not guaranteed mergeable
    if rescale is not None:
        weight = rescale * weight + (rescale - 1) * org_weight
    return weight
def bypass_forward_diff(x, org_out, *weights, constraint=None, need_transpose=False):
    """### bypass_forward_diff
    Compute the output delta the OFT blocks add to ``org_out``, without
    materialising the adapted weight.

    With R from ``get_r`` and s the optional rescale, each feature vector y of
    ``org_out`` becomes ``s * R^T y``; this returns ``s * R^T y - y``.

    Args:
        x (torch.Tensor): the input tensor for original model (not used by
            this function).
        org_out (torch.Tensor): the output tensor from original model; its
            feature dim must equal block_num * block_size.
        weights (tuple[torch.Tensor]): (oft_blocks, rescale_weight), where
            rescale_weight may be None. When given, its first dim is the
            feature dim (e.g. the (out_dim, 1, ..., 1) tensor from
            ``weight_gen``).
        constraint (float, optional): constraint for oft, forwarded to ``get_r``.
        need_transpose (bool, optional):
            whether to transpose the input and output,
            set to `True` if the original model have "dim" not in the last axis.
            For example: Convolution layers

    Returns:
        torch.Tensor: output delta, same shape as ``org_out``, in ``get_r``'s dtype.
    """
    oft_blocks, rescale = weights
    block_num, block_size, _ = oft_blocks.shape
    I = torch.eye(block_size, device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)

    if need_transpose:
        # move the feature axis (dim 1) to the last position
        org_out = org_out.transpose(1, -1)
    org_out = org_out.to(dtype=r.dtype)
    *shape, _ = org_out.shape
    # reshape (not view): einsum output / transposed inputs need not have a
    # memory layout that allows a view; reshape is identical when one does.
    oft_out = torch.einsum(
        "k n m, ... k n -> ... k m",
        r - I,
        org_out.reshape(*shape, block_num, block_size),
    )
    out = oft_out.reshape(*shape, -1)
    if rescale is not None:
        # move rescale's feature dim (dim 0) last so it broadcasts against `out`
        scale = rescale.transpose(-1, 0)
        out = scale * out + (scale - 1) * org_out
    if need_transpose:
        out = out.transpose(1, -1)
    return out