import math
from functools import cache
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import LycorisBaseModule
from ..functional import factorization, rebuild_tucker
from ..functional.lokr import make_kron
from ..logging import logger
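

# LoKr builds its weight update as a Kronecker product: delta_W = kron(W1, W2)
# (times a scale).  Each factor can be stored in full or further decomposed:
#   * W1 (the smaller factor): lokr_w1, or the rank-`lora_dim` product
#     lokr_w1_a @ lokr_w1_b when decompose_both is set and the rank is small
#     enough.
#   * W2 (the larger factor): lokr_w2, the product lokr_w2_a @ lokr_w2_b, or,
#     for convolutions with use_tucker, a Tucker core lokr_t2 rebuilt together
#     with lokr_w2_a / lokr_w2_b.
# factorization() decides how the in/out channel counts are split between the
# two Kronecker factors.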
@cache
def logging_force_full_matrix(lora_dim, dim, factor):
logger.warning(
f"lora_dim {lora_dim} is too large for"
f" dim={dim} and {factor=}"
", using full matrix mode."
)
class LokrModule(LycorisBaseModule):
name = "kron"
support_module = {
"linear",
"conv1d",
"conv2d",
"conv3d",
}
weight_list = [
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t1",
"lokr_t2",
"alpha",
"dora_scale",
]
weight_list_det = ["lokr_w1", "lokr_w1_a"]
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.0,
rank_dropout=0.0,
module_dropout=0.0,
use_tucker=False,
use_scalar=False,
decompose_both=False,
factor: int = -1, # factorization factor
rank_dropout_scale=False,
weight_decompose=False,
wd_on_out=False,
full_matrix=False,
bypass_mode=None,
rs_lora=False,
unbalanced_factorization=False,
**kwargs,
):
super().__init__(
lora_name,
org_module,
multiplier,
dropout,
rank_dropout,
module_dropout,
rank_dropout_scale,
bypass_mode,
)
if self.module_type not in self.support_module:
raise ValueError(f"{self.module_type} is not supported in LoKr algo.")
factor = int(factor)
self.lora_dim = lora_dim
self.tucker = False
self.use_w1 = False
self.use_w2 = False
self.full_matrix = full_matrix
self.rs_lora = rs_lora
if self.module_type.startswith("conv"):
in_dim = org_module.in_channels
k_size = org_module.kernel_size
out_dim = org_module.out_channels
self.shape = (out_dim, in_dim, *k_size)
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
if unbalanced_factorization:
out_l, out_k = out_k, out_l
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
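            # Tucker decomposition of w2 only makes sense when the kernel has
            # spatial extent; for pure 1x1(x1) kernels it is disabled.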
self.tucker = use_tucker and any(i != 1 for i in k_size)
if (
decompose_both
and lora_dim < max(shape[0][0], shape[1][0]) / 2
and not self.full_matrix
):
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(
torch.empty(shape[0][0], shape[1][0])
) # a*c, 1-mode
if lora_dim >= max(shape[0][1], shape[1][1]) / 2 or self.full_matrix:
if not self.full_matrix:
logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor)
self.use_w2 = True
self.lokr_w2 = nn.Parameter(
torch.empty(shape[0][1], shape[1][1], *k_size)
)
elif self.tucker:
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *shape[2:]))
self.lokr_w2_a = nn.Parameter(
torch.empty(lora_dim, shape[0][1])
) # b, 1-mode
self.lokr_w2_b = nn.Parameter(
torch.empty(lora_dim, shape[1][1])
) # d, 2-mode
else: # Conv2d not tucker
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(
torch.empty(
lora_dim, shape[1][1] * torch.tensor(shape[2:]).prod().item()
)
)
                # w1 ⊗ (w2_a x w2_b) = (a, c) ⊗ ((b, dim) x (dim, d*k1*k2))
                # = (a, c) ⊗ (b, d*k1*k2) -> (a*b, c*d*k1*k2) = (out_dim, in_dim*k1*k2)
else: # Linear
in_dim = org_module.in_features
out_dim = org_module.out_features
self.shape = (out_dim, in_dim)
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
if unbalanced_factorization:
out_l, out_k = out_k, out_l
shape = (
(out_l, out_k),
(in_m, in_n),
        )  # ((a, b), (c, d)), out_dim = a*b, in_dim = c*d
# smaller part. weight scale
if (
decompose_both
and lora_dim < max(shape[0][0], shape[1][0]) / 2
and not self.full_matrix
):
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(
torch.empty(shape[0][0], shape[1][0])
) # a*c, 1-mode
if lora_dim < max(shape[0][1], shape[1][1]) / 2 and not self.full_matrix:
# bigger part. weight and LoRA. [b, dim] x [dim, d]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
                # w1 ⊗ (w2_a x w2_b) = (a, c) ⊗ ((b, dim) x (dim, d))
                # = (a, c) ⊗ (b, d) -> (a*b, c*d) = (out_dim, in_dim)
else:
if not self.full_matrix:
logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor)
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
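
        # DoRA-style weight decomposition: cache the norm of the original
        # weight (per output channel if wd_on_out, otherwise per input
        # channel) as the trainable magnitude vector dora_scale.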
self.wd = weight_decompose
self.wd_on_out = wd_on_out
if self.wd:
org_weight = org_module.weight.cpu().clone().float()
self.dora_norm_dims = org_weight.dim() - 1
if self.wd_on_out:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.reshape(org_weight.shape[0], -1),
dim=1,
keepdim=True,
).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
).float()
else:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
dim=1,
keepdim=True,
)
.reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(1, 0)
).float()
self.dropout = dropout
if dropout:
print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
self.rank_dropout = rank_dropout
self.rank_dropout_scale = rank_dropout_scale
self.module_dropout = module_dropout
if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
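        # With both Kronecker factors stored as full matrices there is no rank
        # to compensate for, so alpha is forced to lora_dim (scale = 1).
        # rs_lora switches the denominator to sqrt(lora_dim), as in
        # rank-stabilized LoRA.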
if self.use_w2 and self.use_w1:
# use scale = 1
alpha = lora_dim
r_factor = lora_dim
if self.rs_lora:
r_factor = math.sqrt(r_factor)
self.scale = alpha / r_factor
self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor)))
if use_scalar:
self.scalar = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("scalar", torch.tensor(1.0), persistent=False)
if self.use_w2:
if use_scalar:
torch.nn.init.kaiming_uniform_(self.lokr_w2, a=math.sqrt(5))
else:
torch.nn.init.constant_(self.lokr_w2, 0)
else:
if self.tucker:
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
if use_scalar:
torch.nn.init.kaiming_uniform_(self.lokr_w2_b, a=math.sqrt(5))
else:
torch.nn.init.constant_(self.lokr_w2_b, 0)
if self.use_w1:
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
else:
torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5))
@classmethod
def make_module_from_state_dict(
cls,
lora_name,
orig_module,
w1,
w1a,
w1b,
w2,
w2a,
w2b,
_,
t2,
alpha,
dora_scale,
):
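        # Rebuild a LokrModule from raw state-dict tensors: lora_dim is read
        # off the low-rank factors (or set to 1 in full-matrix mode) and the
        # factorization `factor` is inferred by checking which split of the
        # full (out_dim, in_dim) shape reproduces the stored factor shapes.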
full_matrix = False
if w1a is not None:
lora_dim = w1a.size(1)
elif w2a is not None:
lora_dim = w2a.size(1)
else:
full_matrix = True
lora_dim = 1
if w1 is None:
out_dim = w1a.size(0)
in_dim = w1b.size(1)
else:
out_dim, in_dim = w1.shape
shape_s = [out_dim, in_dim]
if w2 is None:
out_dim *= w2a.size(0)
in_dim *= w2b.size(1)
else:
out_dim *= w2.size(0)
in_dim *= w2.size(1)
if (
shape_s[0] == factorization(out_dim, -1)[0]
and shape_s[1] == factorization(in_dim, -1)[0]
):
factor = -1
else:
w1_shape = w1.shape if w1 is not None else (w1a.size(0), w1b.size(1))
w2_shape = w2.shape if w2 is not None else (w2a.size(0), w2b.size(1))
shape_group_1 = (w1_shape[0], w2_shape[0])
shape_group_2 = (w1_shape[1], w2_shape[1])
w_shape = (w1_shape[0] * w2_shape[0], w1_shape[1] * w2_shape[1])
factor1 = max(w1.shape) if w1 is not None else max(w1a.size(0), w1b.size(1))
factor2 = max(w2.shape) if w2 is not None else max(w2a.size(0), w2b.size(1))
if (
w_shape[0] % factor1 == 0
and w_shape[1] % factor1 == 0
and factor1 in shape_group_1
and factor1 in shape_group_2
):
factor = factor1
elif (
w_shape[0] % factor2 == 0
and w_shape[1] % factor2 == 0
and factor2 in shape_group_1
and factor2 in shape_group_2
):
factor = factor2
else:
factor = min(factor1, factor2)
module = cls(
lora_name,
orig_module,
1,
lora_dim,
float(alpha),
use_tucker=t2 is not None,
decompose_both=w1 is None and w2 is None,
factor=factor,
weight_decompose=dora_scale is not None,
full_matrix=full_matrix,
)
if w1 is not None:
module.lokr_w1.copy_(w1)
else:
module.lokr_w1_a.copy_(w1a)
module.lokr_w1_b.copy_(w1b)
if w2 is not None:
module.lokr_w2.copy_(w2)
else:
module.lokr_w2_a.copy_(w2a)
module.lokr_w2_b.copy_(w2b)
if t2 is not None:
module.lokr_t2.copy_(t2)
if dora_scale is not None:
module.dora_scale.copy_(dora_scale)
return module
def load_weight_hook(self, module: nn.Module, incompatible_keys):
missing_keys = incompatible_keys.missing_keys
        for key in list(missing_keys):
            if "scalar" in key:
                missing_keys.remove(key)
if isinstance(self.scalar, nn.Parameter):
self.scalar.data.copy_(torch.ones_like(self.scalar))
elif getattr(self, "scalar", None) is not None:
self.scalar.copy_(torch.ones_like(self.scalar))
else:
self.register_buffer(
"scalar", torch.ones_like(self.scalar), persistent=False
)
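
    # Materialize the update weight kron(W1, W2), rebuilding each factor from
    # its low-rank / Tucker parts as needed, optionally reshaped to the target
    # shape, with rank dropout applied during training.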
def get_weight(self, shape):
weight = make_kron(
self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
(
self.lokr_w2
if self.use_w2
else (
rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b)
if self.tucker
else self.lokr_w2_a @ self.lokr_w2_b
)
),
self.scale,
)
dtype = weight.dtype
if shape is not None:
weight = weight.view(shape)
if self.training and self.rank_dropout:
            drop = (
                torch.rand(weight.size(0), device=weight.device)
                > self.rank_dropout
            ).to(dtype)
drop = drop.view(-1, *[1] * len(weight.shape[1:]))
if self.rank_dropout_scale:
drop /= drop.mean()
weight *= drop
return weight
def get_diff_weight(self, multiplier=1, shape=None, device=None):
scale = self.scale * multiplier
diff = self.get_weight(shape) * scale
if device is not None:
diff = diff.to(device)
return diff, None
def get_merged_weight(self, multiplier=1, shape=None, device=None):
diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0]
weight = self.org_weight
if self.wd:
merged = self.apply_weight_decompose(weight + diff, multiplier)
else:
merged = weight + diff * multiplier
return merged, None
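
    # DoRA rescaling: divide the merged weight by its per-row (wd_on_out) or
    # per-column norm and multiply by the learned dora_scale; multiplier
    # interpolates this rescaling toward the identity.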
def apply_weight_decompose(self, weight, multiplier=1):
weight = weight.to(self.dora_scale.dtype)
if self.wd_on_out:
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1)
.reshape(weight.shape[0], *[1] * self.dora_norm_dims)
) + torch.finfo(weight.dtype).eps
else:
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(0, 1)
) + torch.finfo(weight.dtype).eps
scale = self.dora_scale.to(weight.device) / weight_norm
if multiplier != 1:
scale = multiplier * (scale - 1) + 1
return weight * scale
def custom_state_dict(self):
destination = {}
destination["alpha"] = self.alpha
if self.wd:
destination["dora_scale"] = self.dora_scale
if self.use_w1:
destination["lokr_w1"] = self.lokr_w1 * self.scalar
else:
destination["lokr_w1_a"] = self.lokr_w1_a * self.scalar
destination["lokr_w1_b"] = self.lokr_w1_b
if self.use_w2:
destination["lokr_w2"] = self.lokr_w2
else:
destination["lokr_w2_a"] = self.lokr_w2_a
destination["lokr_w2_b"] = self.lokr_w2_b
if self.tucker:
destination["lokr_t2"] = self.lokr_t2
return destination
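
    # Clamp the norm of the full update to max_norm.  The shrink ratio is
    # distributed as an even root across however many tensors multiply
    # together to form the update (w1 or w1_a/w1_b, w2 or w2_a/w2_b, t2).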
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
orig_norm = self.get_weight(self.shape).norm()
norm = torch.clamp(orig_norm, max_norm / 2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu() / norm.cpu()
scaled = norm != desired
if scaled:
modules = 4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.tucker)
if self.use_w1:
self.lokr_w1 *= ratio ** (1 / modules)
else:
self.lokr_w1_a *= ratio ** (1 / modules)
self.lokr_w1_b *= ratio ** (1 / modules)
if self.use_w2:
self.lokr_w2 *= ratio ** (1 / modules)
else:
if self.tucker:
self.lokr_t2 *= ratio ** (1 / modules)
self.lokr_w2_a *= ratio ** (1 / modules)
self.lokr_w2_b *= ratio ** (1 / modules)
return scaled, orig_norm * ratio
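
    # Bypass mode applies the update to activations without building the full
    # kron(W1, W2): input features are regrouped by W1's input size (uq), W2
    # acts within each group through the wrapped op, W1 mixes across groups
    # via F.linear, and the result is folded back to the output features.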
def bypass_forward_diff(self, h, scale=1):
is_conv = self.module_type.startswith("conv")
if self.use_w2:
ba = self.lokr_w2
else:
a = self.lokr_w2_b
b = self.lokr_w2_a
if self.tucker:
t = self.lokr_t2
a = a.view(*a.shape, *[1] * (len(t.shape) - 2))
b = b.view(*b.shape, *[1] * (len(t.shape) - 2))
elif is_conv:
a = a.view(*a.shape, *self.shape[2:])
b = b.view(*b.shape, *[1] * (len(self.shape) - 2))
if self.use_w1:
c = self.lokr_w1
else:
c = self.lokr_w1_a @ self.lokr_w1_b
uq = c.size(1)
if is_conv:
            # (batch, uq), vq, ...
            batch_size, _, *rest = h.shape
            h_in_group = h.reshape(batch_size * uq, -1, *rest)
else:
# b, ..., uq, vq
h_in_group = h.reshape(*h.shape[:-1], uq, -1)
if self.use_w2:
hb = self.op(h_in_group, ba, **self.kw_dict)
else:
if is_conv:
if self.tucker:
ha = self.op(h_in_group, a)
ht = self.op(ha, t, **self.kw_dict)
hb = self.op(ht, b)
else:
ha = self.op(h_in_group, a, **self.kw_dict)
hb = self.op(ha, b)
else:
ha = self.op(h_in_group, a, **self.kw_dict)
hb = self.op(ha, b)
if is_conv:
            # (batch, uq), vp, ..., f
            # -> batch, uq, vp, ..., f
            # -> batch, f, vp, ..., uq
            hb = hb.view(batch_size, -1, *hb.shape[1:])
            h_cross_group = hb.transpose(1, -1)
else:
# b, ..., uq, vq
# -> b, ..., vq, uq
h_cross_group = hb.transpose(-1, -2)
hc = F.linear(h_cross_group, c)
if is_conv:
            # batch, f, vp, ..., up
            # -> batch, up, vp, ..., f
            # -> batch, c, ..., f
            hc = hc.transpose(1, -1)
            h = hc.reshape(batch_size, -1, *hc.shape[3:])
else:
# b, ..., vp, up
# -> b, ..., up, vp
# -> b, ..., c
hc = hc.transpose(-1, -2)
h = hc.reshape(*hc.shape[:-2], -1)
return self.drop(h * scale * self.scalar)
def bypass_forward(self, x, scale=1):
return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale)
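
    # Forward pass: optionally skip the module entirely (module_dropout).  In
    # bypass mode the update is applied to the activations; otherwise the
    # delta weight is materialized, combined with the original weight (with
    # optional DoRA rescaling), and used in a single call to the wrapped op.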
def forward(self, x: torch.Tensor, *args, **kwargs):
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.org_forward(x)
if self.bypass_mode:
return self.bypass_forward(x, self.multiplier)
else:
diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar
weight = self.org_module[0].weight.data.to(self.dtype)
if self.wd:
weight = self.apply_weight_decompose(
weight + diff_weight, self.multiplier
)
elif self.multiplier == 1:
weight = weight + diff_weight
else:
weight = weight + diff_weight * self.multiplier
bias = (
None
if self.org_module[0].bias is None
else self.org_module[0].bias.data
)
return self.op(x, weight, bias, **self.kw_dict)
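

# Quick self-test: wrap a Conv2d, load a random state dict, and check that
# bypass mode reproduces the merged-weight forward output.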
if __name__ == "__main__":
base = nn.Conv2d(128, 128, 3, 1, 1)
net = LokrModule(
"",
base,
multiplier=1,
lora_dim=4,
alpha=1,
weight_decompose=False,
use_tucker=False,
use_scalar=False,
decompose_both=True,
)
net.apply_to()
sd = net.state_dict()
for key in sd:
if key != "alpha":
sd[key] = torch.randn_like(sd[key])
net.load_state_dict(sd)
test_input = torch.randn(1, 128, 16, 16)
test_output = net(test_input)
print(test_output.shape)
net2 = LokrModule(
"",
base,
multiplier=1,
lora_dim=4,
alpha=1,
weight_decompose=False,
use_tucker=False,
use_scalar=False,
bypass_mode=True,
decompose_both=True,
)
net2.apply_to()
net2.load_state_dict(sd)
print(net2)
    test_output2 = net2(test_input)
print(F.mse_loss(test_output, test_output2))