import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import LycorisBaseModule
from ..functional import tucker_weight_from_conv


class GLoRAModule(LycorisBaseModule):
    name = "glora"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    weight_list = [
        "a1.weight",
        "a2.weight",
        "b1.weight",
        "b2.weight",
        "bm.weight",
        "alpha",
    ]
    weight_list_det = ["a1.weight"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        **kwargs,
    ):
        """
        f(x) = WX + WAX + BX, where A and B are low-rank matrices
        bypass_forward(x) = W(X + A(X)) + B(X)
        bypass_forward_diff(x) = W(A(X)) + B(X)
        get_merged_weight() = W + WA + B
        get_diff_weight() = WA + B
        """
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in GLoRA algo.")
        self.lora_dim = lora_dim
        self.tucker = False
        self.rs_lora = rs_lora

        if self.module_type.startswith("conv"):
            self.isconv = True
            # For general LoCon
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            out_dim = org_module.out_channels
            # Tucker decomposition is only useful for kernels larger than 1x1.
            use_tucker = use_tucker and any(i != 1 for i in k_size)
            self.down_op = self.op
            self.up_op = self.op
            # A
            self.a2 = self.module(in_dim, lora_dim, 1, bias=False)
            self.a1 = self.module(lora_dim, in_dim, 1, bias=False)
            # B
            if use_tucker:
                self.b2 = self.module(in_dim, lora_dim, 1, bias=False)
                self.bm = self.module(
                    lora_dim, lora_dim, k_size, stride, padding, bias=False
                )
                self.tucker = True
            else:
                self.b2 = self.module(
                    in_dim, lora_dim, k_size, stride, padding, bias=False
                )
            self.b1 = self.module(lora_dim, out_dim, 1, bias=False)
        else:
            self.isconv = False
            self.down_op = F.linear
            self.up_op = F.linear
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.a2 = nn.Linear(in_dim, lora_dim, bias=False)
            self.a1 = nn.Linear(lora_dim, in_dim, bias=False)
            self.b2 = nn.Linear(in_dim, lora_dim, bias=False)
            self.b1 = nn.Linear(lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        r_factor = lora_dim
        if self.rs_lora:
            r_factor = math.sqrt(r_factor)
        self.scale = alpha / r_factor
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        if use_scalar:
            self.scalar = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.a1.weight, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.b1.weight, a=math.sqrt(5))
        if use_scalar:
            torch.nn.init.kaiming_uniform_(self.a2.weight, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.b2.weight, a=math.sqrt(5))
        else:
            torch.nn.init.zeros_(self.a2.weight)
            torch.nn.init.zeros_(self.b2.weight)

    @classmethod
    def make_module_from_state_dict(
        cls, lora_name, orig_module, a1, a2, b1, b2, bm, alpha
    ):
        module = cls(
            lora_name,
            orig_module,
            1,
            a2.size(0),
            float(alpha),
            use_tucker=bm is not None,
        )
        module.a1.weight.data.copy_(a1)
        module.a2.weight.data.copy_(a2)
        module.b1.weight.data.copy_(b1)
        module.b2.weight.data.copy_(b2)
        if bm is not None:
            module.bm.weight.data.copy_(bm)
        return module
    def custom_state_dict(self):
        destination = {}
        destination["alpha"] = self.alpha
        destination["a1.weight"] = self.a1.weight
        destination["a2.weight"] = self.a2.weight * self.scalar
        destination["b1.weight"] = self.b1.weight
        destination["b2.weight"] = self.b2.weight * self.scalar
        if self.tucker:
            destination["bm.weight"] = self.bm.weight
        return destination

    def load_weight_hook(self, module: nn.Module, incompatible_keys):
        missing_keys = incompatible_keys.missing_keys
        for key in missing_keys:
            if "scalar" in key:
                del missing_keys[missing_keys.index(key)]
        if isinstance(self.scalar, nn.Parameter):
            self.scalar.data.copy_(torch.ones_like(self.scalar))
        elif getattr(self, "scalar", None) is not None:
            self.scalar.copy_(torch.ones_like(self.scalar))
        else:
            self.register_buffer(
                "scalar", torch.ones_like(self.scalar), persistent=False
            )

    def make_weight(self, device=None):
        wa1 = self.a1.weight.view(self.a1.weight.size(0), -1)
        wa2 = self.a2.weight.view(self.a2.weight.size(0), -1)
        orig = self.org_weight
        if self.tucker:
            wb = tucker_weight_from_conv(self.b1.weight, self.b2.weight, self.bm.weight)
        else:
            wb1 = self.b1.weight.view(self.b1.weight.size(0), -1)
            wb2 = self.b2.weight.view(self.b2.weight.size(0), -1)
            wb = wb1 @ wb2
        wb = wb.view(*orig.shape)
        if orig.dim() > 2:
            w_wa1 = torch.einsum("o i ..., i j -> o j ...", orig, wa1)
            w_wa2 = torch.einsum("o i ..., i j -> o j ...", w_wa1, wa2)
        else:
            w_wa2 = (orig @ wa1) @ wa2
        return (wb + w_wa2) * self.scale * self.scalar

    def get_diff_weight(self, multiplier=1.0, shape=None, device=None):
        weight = self.make_weight(device) * multiplier
        if shape is not None:
            weight = weight.view(shape)
        return weight, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        diff_w, _ = self.get_diff_weight(multiplier, shape, device)
        return self.org_weight + diff_w, None

    def _bypass_forward(self, x, scale=1, diff=False):
        # `scale` is the external multiplier; fold the alpha/rank scale in once here.
        scale = self.scale * scale
        ax_mid = self.a2(x) * scale
        bx_mid = self.b2(x) * scale
        if self.rank_dropout and self.training:
            # Keep each rank with probability (1 - rank_dropout).
            drop_a = (
                torch.rand(self.lora_dim, device=ax_mid.device) > self.rank_dropout
            ).to(ax_mid.dtype)
            drop_b = (
                torch.rand(self.lora_dim, device=bx_mid.device) > self.rank_dropout
            ).to(bx_mid.dtype)
            if self.rank_dropout_scale:
                drop_a /= drop_a.mean()
                drop_b /= drop_b.mean()
            if (dims := len(x.shape)) == 4:
                drop_a = drop_a.view(1, -1, 1, 1)
                drop_b = drop_b.view(1, -1, 1, 1)
            else:
                drop_a = drop_a.view(*[1] * (dims - 1), -1)
                drop_b = drop_b.view(*[1] * (dims - 1), -1)
            ax_mid = ax_mid * drop_a
            bx_mid = bx_mid * drop_b
        return (
            self.org_forward((0 if diff else x) + self.drop(self.a1(ax_mid)))
            + self.drop(self.b1(bx_mid))
        )

    def bypass_forward_diff(self, x, scale=1):
        return self._bypass_forward(x, scale=scale, diff=True)

    def bypass_forward(self, x, scale=1):
        return self._bypass_forward(x, scale=scale, diff=False)

    def forward(self, x, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            weight = (
                self.org_module[0].weight.data.to(self.dtype)
                + self.get_diff_weight(multiplier=self.multiplier)[0]
            )
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
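

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the library API. It assumes this
    # file is run as part of the package (e.g. `python -m <package>.modules.glora`,
    # path hypothetical) so the relative imports resolve, and that the base class
    # exposes org_weight as used in make_weight() above. For a Linear layer it
    # compares get_diff_weight() against the closed form from the docstring:
    # (W @ A1 @ A2 + B1 @ B2) * alpha / rank.
    torch.manual_seed(0)
    base = nn.Linear(16, 8)
    glora = GLoRAModule("test_glora", base, multiplier=1.0, lora_dim=4, alpha=4)
    # a2/b2 are zero-initialized by default, so randomize them to get a
    # non-trivial diff weight.
    nn.init.normal_(glora.a2.weight, std=0.02)
    nn.init.normal_(glora.b2.weight, std=0.02)
    with torch.no_grad():
        diff, _ = glora.get_diff_weight()
        expected = (
            base.weight @ glora.a1.weight @ glora.a2.weight
            + glora.b1.weight @ glora.b2.weight
        ) * glora.scale
        # Should print a value near zero if the low-rank reconstruction matches.
        print("max abs error:", (diff - expected).abs().max().item())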