import math

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..functional.loha import diff_weight as loha_diff_weight


class LohaModule(LycorisBaseModule):
    name = "loha"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    weight_list = [
        "hada_w1_a",
        "hada_w1_b",
        "hada_w2_a",
        "hada_w2_b",
        "hada_t1",
        "hada_t2",
        "alpha",
        "dora_scale",
    ]
    weight_list_det = ["hada_w1_a"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        wd_on_out=False,
        bypass_mode=None,
        rs_lora=False,
        **kwargs,
    ):
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in LoHa algo.")
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.tucker = False
        self.rs_lora = rs_lora

        w_shape = self.shape
        if self.module_type.startswith("conv"):
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels
            self.shape = (out_dim, in_dim, *k_size)
            self.tucker = use_tucker and any(i != 1 for i in k_size)
            if self.tucker:
                w_shape = (out_dim, in_dim, *k_size)
            else:
                w_shape = (out_dim, in_dim * torch.tensor(k_size).prod().item())

        if self.tucker:
            self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
            self.hada_w1_a = nn.Parameter(
                torch.empty(lora_dim, w_shape[0])
            )  # out_dim, 1-mode
            self.hada_w1_b = nn.Parameter(
                torch.empty(lora_dim, w_shape[1])
            )  # in_dim , 2-mode

            self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
            self.hada_w2_a = nn.Parameter(
                torch.empty(lora_dim, w_shape[0])
            )  # out_dim, 1-mode
            self.hada_w2_b = nn.Parameter(
                torch.empty(lora_dim, w_shape[1])
            )  # in_dim , 2-mode
        else:
            self.hada_w1_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))

            self.hada_w2_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))

        self.wd = weight_decompose
        self.wd_on_out = wd_on_out
        if self.wd:
            org_weight = org_module.weight.cpu().clone().float()
            self.dora_norm_dims = org_weight.dim() - 1
            if self.wd_on_out:
                self.dora_scale = nn.Parameter(
                    torch.norm(
                        org_weight.reshape(org_weight.shape[0], -1),
                        dim=1,
                        keepdim=True,
                    ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
                ).float()
            else:
                self.dora_scale = nn.Parameter(
                    torch.norm(
                        org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
                        dim=1,
                        keepdim=True,
                    )
                    .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
                    .transpose(1, 0)
                ).float()

        if self.dropout:
            print("[WARN] LoHa/LoKr haven't implemented normal dropout yet.")

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha

        r_factor = lora_dim
        if self.rs_lora:
            r_factor = math.sqrt(r_factor)

        self.scale = alpha / r_factor

        self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor)))

        if use_scalar:
            self.scalar = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)

        # Need more experiments on init method
        if self.tucker:
            torch.nn.init.normal_(self.hada_t1, std=0.1)
            torch.nn.init.normal_(self.hada_t2, std=0.1)
        torch.nn.init.normal_(self.hada_w1_b, std=1)
        torch.nn.init.normal_(self.hada_w1_a, std=0.1)
        torch.nn.init.normal_(self.hada_w2_b, std=1)
        if use_scalar:
            torch.nn.init.normal_(self.hada_w2_a, std=0.1)
        else:
            torch.nn.init.constant_(self.hada_w2_a, 0)
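    # The LoHa delta weight is the Hadamard product of two low-rank factors,
    # diff_w = (hada_w1_a @ hada_w1_b) * (hada_w2_a @ hada_w2_b), optionally
    # built with the Tucker cores hada_t1 / hada_t2 for convolutions, and
    # scaled by alpha / lora_dim (alpha / sqrt(lora_dim) with rs_lora).
    # See get_weight / loha_diff_weight below.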
    @classmethod
    def make_module_from_state_dict(
        cls, lora_name, orig_module, w1a, w1b, w2a, w2b, t1, t2, alpha, dora_scale
    ):
        module = cls(
            lora_name,
            orig_module,
            1,
            w1b.size(0),
            float(alpha),
            use_tucker=t1 is not None,
            weight_decompose=dora_scale is not None,
        )
        module.hada_w1_a.copy_(w1a)
        module.hada_w1_b.copy_(w1b)
        module.hada_w2_a.copy_(w2a)
        module.hada_w2_b.copy_(w2b)
        if t1 is not None:
            module.hada_t1.copy_(t1)
            module.hada_t2.copy_(t2)
        if dora_scale is not None:
            module.dora_scale.copy_(dora_scale)
        return module

    def load_weight_hook(self, module: nn.Module, incompatible_keys):
        missing_keys = incompatible_keys.missing_keys
        for key in missing_keys:
            if "scalar" in key:
                del missing_keys[missing_keys.index(key)]
        if isinstance(self.scalar, nn.Parameter):
            self.scalar.data.copy_(torch.ones_like(self.scalar))
        elif getattr(self, "scalar", None) is not None:
            self.scalar.copy_(torch.ones_like(self.scalar))
        else:
            self.register_buffer(
                "scalar", torch.ones_like(self.scalar), persistent=False
            )

    def get_weight(self, shape):
        scale = torch.tensor(
            self.scale, dtype=self.hada_w1_b.dtype, device=self.hada_w1_b.device
        )
        if self.tucker:
            weight = loha_diff_weight(
                self.hada_w1_b,
                self.hada_w1_a,
                self.hada_w2_b,
                self.hada_w2_a,
                self.hada_t1,
                self.hada_t2,
                gamma=scale,
            )
        else:
            weight = loha_diff_weight(
                self.hada_w1_b,
                self.hada_w1_a,
                self.hada_w2_b,
                self.hada_w2_a,
                None,
                None,
                gamma=scale,
            )
        if shape is not None:
            weight = weight.reshape(shape)
        if self.training and self.rank_dropout:
            drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(weight.dtype)
            drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
            if self.rank_dropout_scale:
                drop /= drop.mean()
            weight *= drop
        return weight

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        scale = self.scale * multiplier
        diff = self.get_weight(shape) * scale
        if device is not None:
            diff = diff.to(device)
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0]
        weight = self.org_weight
        if self.wd:
            merged = self.apply_weight_decompose(weight + diff, multiplier)
        else:
            merged = weight + diff * multiplier
        return merged, None
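    # Weight decomposition (DoRA-style): the merged weight is rescaled so that
    # its per-output-channel (wd_on_out) or per-input-channel norm is replaced
    # by the learned dora_scale, i.e. W' = W * dora_scale / ||W||.
    # A multiplier != 1 interpolates this rescaling toward the identity.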
    def apply_weight_decompose(self, weight, multiplier=1):
        weight = weight.to(self.dora_scale.dtype)
        if self.wd_on_out:
            weight_norm = (
                weight.reshape(weight.shape[0], -1)
                .norm(dim=1)
                .reshape(weight.shape[0], *[1] * self.dora_norm_dims)
            ) + torch.finfo(weight.dtype).eps
        else:
            weight_norm = (
                weight.transpose(0, 1)
                .reshape(weight.shape[1], -1)
                .norm(dim=1, keepdim=True)
                .reshape(weight.shape[1], *[1] * self.dora_norm_dims)
                .transpose(0, 1)
            ) + torch.finfo(weight.dtype).eps

        scale = self.dora_scale.to(weight.device) / weight_norm
        if multiplier != 1:
            scale = multiplier * (scale - 1) + 1

        return weight * scale

    def custom_state_dict(self):
        destination = {}
        destination["alpha"] = self.alpha
        if self.wd:
            destination["dora_scale"] = self.dora_scale
        destination["hada_w1_a"] = self.hada_w1_a * self.scalar
        destination["hada_w1_b"] = self.hada_w1_b
        destination["hada_w2_a"] = self.hada_w2_a
        destination["hada_w2_b"] = self.hada_w2_b
        if self.tucker:
            destination["hada_t1"] = self.hada_t1
            destination["hada_t2"] = self.hada_t2
        return destination

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        orig_norm = (self.get_weight(self.shape) * self.scalar).norm()
        norm = torch.clamp(orig_norm, max_norm / 2)
        desired = torch.clamp(norm, max=max_norm)
        ratio = desired.cpu() / norm.cpu()

        scaled = norm != desired
        if scaled:
            self.scalar *= ratio

        return scaled, orig_norm * ratio

    def bypass_forward_diff(self, x, scale=1):
        diff_weight = self.get_weight(self.shape) * self.scalar * scale
        return self.drop(self.op(x, diff_weight, **self.kw_dict))

    def bypass_forward(self, x, scale=1):
        return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.op(
                    x,
                    self.org_module[0].weight.data,
                    (
                        None
                        if self.org_module[0].bias is None
                        else self.org_module[0].bias.data
                    ),
                )
        if self.bypass_mode:
            return self.bypass_forward(x, scale=self.multiplier)
        else:
            diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar
            weight = self.org_module[0].weight.data.to(self.dtype)
            if self.wd:
                weight = self.apply_weight_decompose(
                    weight + diff_weight, self.multiplier
                )
            else:
                weight = weight + diff_weight * self.multiplier
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
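

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API).
    # It assumes this file is executed where the relative imports above
    # resolve (e.g. as part of the lycoris package) and that LycorisBaseModule
    # accepts a plain nn.Linear as org_module. Layer sizes and hyperparameters
    # below are arbitrary example values.
    base = nn.Linear(64, 32)
    loha = LohaModule("demo_loha", base, multiplier=1.0, lora_dim=4, alpha=1)
    x = torch.randn(2, 64)
    # With use_scalar=False, hada_w2_a is initialized to zero, so the initial
    # LoHa delta is zero and this output matches base(x).
    y = loha(x)
    print(y.shape)  # torch.Size([2, 32])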