import math
import random

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..utils import product


class DyLoraModule(LycorisBaseModule):
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        block_size=4,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        train_on_input=False,
        **kwargs,
    ):
        """If alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in DyLoRA algo.")
        assert lora_dim % block_size == 0, "lora_dim must be a multiple of block_size"
        self.block_count = lora_dim // block_size
        self.block_size = block_size

        shape = (
            self.shape[0],
            product(self.shape[1:]),
        )
        self.lora_dim = lora_dim
        # One (up, down) pair per block; blocks are concatenated up to the
        # sampled rank when building the weight.
        self.up_list = nn.ParameterList(
            [torch.empty(shape[0], self.block_size) for _ in range(self.block_count)]
        )
        self.down_list = nn.ParameterList(
            [torch.empty(self.block_size, shape[1]) for _ in range(self.block_count)]
        )

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # Need more experiences on init method
        for v in self.down_list:
            torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
        for v in self.up_list:
            torch.nn.init.zeros_(v)

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        # Intentionally a no-op: DyLoRA keeps its weights in per-block
        # ParameterLists and exports them through custom_state_dict().
        return

    def custom_state_dict(self):
        destination = {}
        destination["alpha"] = self.alpha
        destination["lora_up.weight"] = nn.Parameter(
            torch.concat(list(self.up_list), dim=1)
        )
        destination["lora_down.weight"] = nn.Parameter(
            torch.concat(list(self.down_list)).reshape(
                self.lora_dim, -1, *self.shape[2:]
            )
        )
        return destination

    def get_weight(self, rank):
        # Blocks below the sampled rank are frozen (accessed via ``.data``);
        # only the block containing the sampled rank receives gradients.
        b = math.ceil(rank / self.block_size)
        down = torch.concat(
            list(i.data for i in self.down_list[:b]) + list(self.down_list[b : (b + 1)])
        )
        up = torch.concat(
            list(i.data for i in self.up_list[:b]) + list(self.up_list[b : (b + 1)]),
            dim=1,
        )
        return down, up, self.alpha / (b + 1)

    def get_random_rank_weight(self):
        b = random.randint(0, self.block_count - 1)
        return self.get_weight(b * self.block_size)

    def get_diff_weight(self, multiplier=1, shape=None, device=None, rank=None):
        if rank is None:
            down, up, scale = self.get_random_rank_weight()
        else:
            down, up, scale = self.get_weight(rank)
        w = up @ (down * (scale * multiplier))
        if device is not None:
            w = w.to(device)
        if shape is not None:
            w = w.view(shape)
        else:
            w = w.view(self.shape)
        return w, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None, rank=None):
        diff, _ = self.get_diff_weight(multiplier, shape, device, rank)
        return diff + self.org_weight, None

    def bypass_forward_diff(self, x, scale=1, rank=None):
        if rank is None:
            down, up, gamma = self.get_random_rank_weight()
        else:
            down, up, gamma = self.get_weight(rank)
        # Reshape to the active (truncated) rank rather than the full lora_dim,
        # since only the blocks up to the sampled rank are returned.
        r = down.size(0)
        down = down.view(r, -1, *self.shape[2:])
        up = up.view(-1, r, *(1 for _ in self.shape[2:]))
        scale = scale * gamma
        return self.op(self.op(x, down, **self.kw_dict), up) * scale

    def bypass_forward(self, x, scale=1, rank=None):
        return self.org_forward(x) + self.bypass_forward_diff(x, scale, rank)
    def forward(self, x, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            weight = self.get_merged_weight(multiplier=self.multiplier)[0]
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
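
# ---------------------------------------------------------------------------
# Minimal usage sketch (for illustration only, not part of the module).
# It assumes LycorisBaseModule wires up `org_forward`, `op`, `kw_dict`,
# `shape`, and `org_module` for a plain nn.Linear; the import path below is
# also an assumption about the package layout.
#
#     import torch
#     import torch.nn as nn
#     from lycoris.modules.dylora import DyLoraModule  # path is an assumption
#
#     base = nn.Linear(64, 128)
#     dylora = DyLoraModule(
#         "example", base, multiplier=1.0, lora_dim=8, alpha=4, block_size=4
#     )
#     x = torch.randn(2, 64)
#     y = dylora(x)                            # samples a random rank per call
#     diff, _ = dylora.get_diff_weight(rank=4)     # LoRA delta truncated to rank 4
#     merged, _ = dylora.get_merged_weight()       # original weight + delta
# ---------------------------------------------------------------------------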