import torch
import torch.nn as nn
import torch.nn.functional as F


# Functional op matching a weight tensor's dimensionality, indexed by
# weight.dim(): 2 -> F.linear, 3 -> F.conv1d, 4 -> F.conv2d, 5 -> F.conv3d.
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
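

# Hedged sketch (not original to this module): a Linear weight is 2-D and a
# ConvNd weight is (N+2)-D, so the matching functional op can be picked
# straight from weight.dim(). The helper name below is illustrative only.
def _sketch_dispatch_by_weight_dim(x, weight, bias=None):
    op = FUNC_LIST[weight.dim()]  # e.g. a 4-D weight selects F.conv2d
    return op(x, weight, bias)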


def rebuild_tucker(t, wa, wb):
    # Contract Tucker core t (rank_a, rank_b, ...) with factors wa (rank_a, p)
    # and wb (rank_b, r) to rebuild a full (p, r, ...) weight.
    return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)
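

# Shape sketch (illustrative, not part of the module): rebuilding a 3x3 conv
# delta from a 4x4 Tucker core and two rank-projection factors.
def _sketch_rebuild_tucker_shapes():
    t = torch.randn(4, 4, 3, 3)  # core: (rank_a, rank_b, kh, kw)
    wa = torch.randn(4, 16)      # maps rank_a -> 16 output channels
    wb = torch.randn(4, 8)       # maps rank_b -> 8 input channels
    assert rebuild_tucker(t, wa, wb).shape == (16, 8, 3, 3)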


def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """
    Return a tuple of two values whose product is `dimension`, decomposed by
    the divisor closest to `factor`. The second value is always greater than
    or equal to the first.

    In LoRA with Kronecker product, the first value is used for the weight
    scale and the second value for the weight itself.

    Because the Kronecker product is non-commutative, A⊗B ≠ B⊗A, so the two
    matrices play slightly different roles.

    examples)
    factor
        -1               2                4                8                16              ...
    127  -> 1, 127   127  -> 1, 127   127  -> 1, 127   127  -> 1, 127   127  -> 1, 127
    128  -> 8, 16    128  -> 2, 64    128  -> 4, 32    128  -> 8, 16    128  -> 8, 16
    250  -> 10, 25   250  -> 2, 125   250  -> 2, 125   250  -> 5, 50    250  -> 10, 25
    360  -> 18, 20   360  -> 2, 180   360  -> 4, 90    360  -> 8, 45    360  -> 15, 24
    512  -> 16, 32   512  -> 2, 256   512  -> 4, 128   512  -> 8, 64    512  -> 16, 32
    1024 -> 32, 32   1024 -> 2, 512   1024 -> 4, 256   1024 -> 8, 128   1024 -> 16, 64
    """

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    # Walk the divisor pairs of `dimension` toward the most balanced pair,
    # stopping once the smaller divisor would exceed `factor`.
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
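

# Sketch of the intended LoKr-style wiring (assumed, not the module's API):
# factor both weight dimensions, then hold two small Kronecker factors whose
# product reproduces the full weight shape.
def _sketch_lokr_shapes(out_dim=360, in_dim=512, factor=8):
    out_m, out_n = factorization(out_dim, factor)  # 360 -> (8, 45)
    in_m, in_n = factorization(in_dim, factor)     # 512 -> (8, 64)
    w1 = torch.randn(out_m, in_m)  # small "scale" factor A: (8, 8)
    w2 = torch.randn(out_n, in_n)  # large "weight" factor B: (45, 64)
    assert torch.kron(w1, w2).shape == (out_dim, in_dim)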


def power2factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """
    Factor `dimension` into m * n such that m is even (m = 2k) and n is a
    power of two (n = 2**p), choosing the largest even m with m <= factor.
    Returns (None, 0) when no such factorization exists.
    """
    if factor == -1:
        factor = dimension

    m = n = 0
    while m <= factor:
        m += 2
        while dimension % m != 0 and m < dimension:
            m += 2
        if m > factor:
            break
        # dimension // m is a power of two iff its binary form has a single 1.
        if sum(int(i) for i in f"{dimension//m:b}") == 1:
            n = dimension // m

    if n == 0:
        return None, n
    return dimension // n, n
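

# Illustrative check (not part of the module): for dimension 48 with factor 8,
# the largest even m <= 8 whose cofactor is a power of two is 6, since
# 48 // 6 == 8 == 2**3. An odd prime has no even factor at all.
def _sketch_power2factorization():
    assert power2factorization(48, 8) == (6, 8)
    assert power2factorization(127) == (None, 0)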


def tucker_weight_from_conv(up, down, mid):
    # Drop the trailing 1x1 spatial dims of the up/down conv weights, then
    # contract both with the spatial core `mid`.
    up = up.reshape(up.size(0), up.size(1))
    down = down.reshape(down.size(0), down.size(1))
    return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
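

# Shape sketch (illustrative): `up` and `down` are 1x1 conv weights wrapped
# around a small spatial core, as produced by a Tucker-decomposed conv layer.
def _sketch_tucker_weight_from_conv_shapes():
    up = torch.randn(16, 4, 1, 1)   # (out_channels, rank, 1, 1)
    down = torch.randn(4, 8, 1, 1)  # (rank, in_channels, 1, 1)
    mid = torch.randn(4, 4, 3, 3)   # (rank, rank, kh, kw)
    assert tucker_weight_from_conv(up, down, mid).shape == (16, 8, 3, 3)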


def tucker_weight(wa, wb, t):
    # Same contraction as rebuild_tucker, applied one factor at a time.
    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
    return torch.einsum("i j ..., i r -> r j ...", temp, wa)
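

# Equivalence check (illustrative): the two-step contraction above matches the
# single einsum in rebuild_tucker.
def _sketch_tucker_weight_equivalence():
    t = torch.randn(4, 4, 3, 3)
    wa = torch.randn(4, 16)
    wb = torch.randn(4, 8)
    assert torch.allclose(
        tucker_weight(wa, wb, t), rebuild_tucker(t, wa, wb), atol=1e-5
    )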


def apply_dora_scale(org_weight, rebuild, dora_scale, scale):
    dora_norm_dims = org_weight.dim() - 1
    weight = org_weight + rebuild
    weight = weight.to(dora_scale.dtype)
    # Norm over every dim except dim 1, shaped to broadcast against `weight`.
    weight_norm = (
        weight.transpose(0, 1)
        .reshape(weight.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight.shape[1], *[1] * dora_norm_dims)
        .transpose(0, 1)
    )
    # Renormalize the merged weight to the learned DoRA magnitude, then blend
    # the resulting difference back into the base weight at `scale`.
    merged_scale1 = weight / weight_norm * dora_scale
    diff_weight = merged_scale1 - org_weight
    return org_weight + diff_weight * scale
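

# Usage sketch (assumed shapes, not the module's API): for a Linear weight of
# shape (out, in) the norm is taken per input column, so `dora_scale` should
# broadcast as (1, in).
def _sketch_apply_dora_scale():
    org_weight = torch.randn(16, 8)      # frozen base weight
    rebuild = 0.01 * torch.randn(16, 8)  # e.g. a rebuilt low-rank delta
    dora_scale = torch.ones(1, 8)        # learned per-column magnitude
    merged = apply_dora_scale(org_weight, rebuild, dora_scale, scale=1.0)
    assert merged.shape == org_weight.shape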