from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

from ..utils.quant import QuantLinears, log_bypass, log_suspect


class ModuleCustomSD(nn.Module):
    """Base module that lets subclasses provide custom state-dict serialization
    and load hooks."""

    def __init__(self):
        super().__init__()
        self._register_load_state_dict_pre_hook(self.load_weight_prehook)
        self.register_load_state_dict_post_hook(self.load_weight_hook)

    def load_weight_prehook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        pass

    def load_weight_hook(self, module, incompatible_keys):
        pass

    def custom_state_dict(self):
        return None

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # TODO: Remove `args` and the parsing logic when BC allows.
        if len(args) > 0:
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == "":
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]
            # DeprecationWarning is ignored by default

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        if (custom_sd := self.custom_state_dict()) is not None:
            for k, v in custom_sd.items():
                destination[f"{prefix}{k}"] = v
            return destination
        else:
            return super().state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )
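# --- Hedged usage sketch (illustrative only, not part of the library) -------
# How `custom_state_dict` interacts with the `state_dict` override above: when
# a subclass returns a dict, exactly those keys (with the caller's prefix) are
# emitted instead of the default recursive parameter walk.  The class below
# and its `scale` / `unexported` parameters are invented purely for illustration.
class _ExampleCustomSD(ModuleCustomSD):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.unexported = nn.Parameter(torch.zeros(1))

    def custom_state_dict(self):
        # state_dict() will now contain only "scale", not "unexported".
        return OrderedDict(scale=self.scale)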
class LycorisBaseModule(ModuleCustomSD):
    name: str
    dtype_tensor: torch.Tensor
    support_module = {}
    weight_list = []
    weight_list_det = []

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        rank_dropout_scale=False,
        bypass_mode=None,
        **kwargs,
    ):
        """If alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.not_supported = False
        self.module = type(org_module)
        if isinstance(org_module, nn.Linear):
            self.module_type = "linear"
            self.shape = (org_module.out_features, org_module.in_features)
            self.op = F.linear
            self.dim = org_module.out_features
            self.kw_dict = {}
        elif isinstance(org_module, nn.Conv1d):
            self.module_type = "conv1d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv1d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.Conv2d):
            self.module_type = "conv2d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv2d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.Conv3d):
            self.module_type = "conv3d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv3d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.LayerNorm):
            self.module_type = "layernorm"
            self.shape = tuple(org_module.normalized_shape)
            self.op = F.layer_norm
            self.dim = org_module.normalized_shape[0]
            self.kw_dict = {
                "normalized_shape": org_module.normalized_shape,
                "eps": org_module.eps,
            }
        elif isinstance(org_module, nn.GroupNorm):
            self.module_type = "groupnorm"
            self.shape = (org_module.num_channels,)
            self.op = F.group_norm
            self.group_num = org_module.num_groups
            self.dim = org_module.num_channels
            self.kw_dict = {"num_groups": org_module.num_groups, "eps": org_module.eps}
        else:
            self.not_supported = True
            self.module_type = "unknown"

        self.register_buffer("dtype_tensor", torch.tensor(0.0), persistent=False)

        self.is_quant = False
        if isinstance(org_module, QuantLinears):
            if not bypass_mode:
                log_bypass()
            self.is_quant = True
            bypass_mode = True
        if (
            isinstance(org_module, nn.Linear)
            and org_module.__class__.__name__ != "Linear"
        ):
            if bypass_mode is None:
                log_suspect()
                bypass_mode = True
        if bypass_mode:
            self.is_quant = True
        self.bypass_mode = bypass_mode

        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.rank_dropout_scale = rank_dropout_scale
        self.module_dropout = module_dropout
        ## Dropout things
        # LoKr/LoHa/OFT/BOFT cannot easily follow kohya's rank_dropout definition,
        # so we redefine the dropout procedure here:
        # g(x) = Wx + drop(B @ rank_drop(A @ x))  for LoCon (LoRA), bypass mode
        # g(x) = Wx + drop(ΔW @ x)                for any algo except LoCon (LoRA), bypass mode
        # g(x) = (W + B @ rank_drop(A)) x         for LoCon (LoRA), rebuild mode
        # g(x) = (W + rank_drop(ΔW)) x            for any algo except LoCon (LoRA), rebuild mode
        self.drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.rank_drop = (
            nn.Identity() if rank_dropout == 0 else nn.Dropout(rank_dropout)
        )

        self.multiplier = multiplier
        self.org_forward = org_module.forward
        self.org_module = [org_module]

    @classmethod
    def parametrize(cls, org_module, attr, *args, **kwargs):
        from .full import FullModule

        if cls is FullModule:
            raise RuntimeError("FullModule cannot be used for parametrize.")
        target_param = getattr(org_module, attr)
        kwargs["bypass_mode"] = False
        if target_param.dim() == 2:
            proxy_module = nn.Linear(
                target_param.shape[0], target_param.shape[1], bias=False
            )
            proxy_module.weight = target_param
        elif target_param.dim() > 2:
            module_type = [
                None,
                None,
                None,
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                None,
                None,
            ][target_param.dim()]
            proxy_module = module_type(
                target_param.shape[0],
                target_param.shape[1],
                *target_param.shape[2:],
                bias=False,
            )
            proxy_module.weight = target_param
        module_obj = cls("", proxy_module, *args, **kwargs)
        module_obj.forward = module_obj.parametrize_forward
        module_obj.to(target_param)
        parametrize.register_parametrization(org_module, attr, module_obj)
        return module_obj

    @classmethod
    def algo_check(cls, state_dict, lora_name):
        return any(f"{lora_name}.{k}" in state_dict for k in cls.weight_list_det)

    @classmethod
    def extract_state_dict(cls, state_dict, lora_name):
        return [state_dict.get(f"{lora_name}.{k}", None) for k in cls.weight_list]

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, *weights):
        raise NotImplementedError

    @property
    def dtype(self):
        return self.dtype_tensor.dtype

    @property
    def device(self):
        return self.dtype_tensor.device

    @property
    def org_weight(self):
        return self.org_module[0].weight

    @org_weight.setter
    def org_weight(self, value):
        self.org_module[0].weight.data.copy_(value)

    def apply_to(self, **kwargs):
        if self.not_supported:
            return
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def restore(self):
        if self.not_supported:
            return
        self.org_module[0].forward = self.org_forward
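    # Weight-merge contract (relied on by merge_to and parametrize_forward):
    # subclasses implement get_diff_weight -> (ΔW, Δbias) and
    # get_merged_weight -> (W + multiplier * ΔW, bias).  merge_to bakes the
    # merged weight back into the wrapped module in place, so restore() only
    # undoes the forward hook, not a previous merge.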
    def merge_to(self, multiplier=1.0):
        if self.not_supported:
            return
        self_device = next(self.parameters()).device
        self_dtype = next(self.parameters()).dtype
        self.to(self.org_weight)
        weight, bias = self.get_merged_weight(
            multiplier, self.org_weight.shape, self.org_weight.device
        )
        self.org_weight = weight.to(self.org_weight)
        if bias is not None:
            bias = bias.to(self.org_weight)
            if self.org_module[0].bias is not None:
                self.org_module[0].bias.data.copy_(bias)
            else:
                self.org_module[0].bias = nn.Parameter(bias)
        self.to(self_device, self_dtype)

    def get_diff_weight(self, multiplier=1.0, shape=None, device=None):
        raise NotImplementedError

    def get_merged_weight(self, multiplier=1.0, shape=None, device=None):
        raise NotImplementedError

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        return None, None

    def bypass_forward_diff(self, x, scale=1):
        raise NotImplementedError

    def bypass_forward(self, x, scale=1):
        raise NotImplementedError

    def parametrize_forward(self, x: torch.Tensor, *args, **kwargs):
        return self.get_merged_weight(
            multiplier=self.multiplier, shape=x.shape, device=x.device
        )[0].to(x.dtype)

    def forward(self, *args, **kwargs):
        raise NotImplementedError
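# --- Hedged usage sketch (illustrative only, not part of the library) -------
# A minimal, hypothetical subclass showing how the base-class contract is
# meant to be filled in: a plain low-rank (down/up) adapter for nn.Linear.
# Real implementations (LoCon, LoHa, LoKr, ...) live in their own modules;
# the names `_ToyLoRA`, `toy_down`, `toy_up`, and `rank` are invented here.
class _ToyLoRA(LycorisBaseModule):
    weight_list = ["toy_down", "toy_up"]
    weight_list_det = ["toy_down"]

    def __init__(self, lora_name, org_module, rank=4, **kwargs):
        super().__init__(lora_name, org_module, **kwargs)
        out_f, in_f = self.shape
        self.toy_down = nn.Parameter(torch.randn(rank, in_f) * 0.01)
        self.toy_up = nn.Parameter(torch.zeros(out_f, rank))

    def get_diff_weight(self, multiplier=1.0, shape=None, device=None):
        diff = (self.toy_up @ self.toy_down) * multiplier
        if device is not None:
            diff = diff.to(device)
        return diff, None  # (ΔW, Δbias)

    def get_merged_weight(self, multiplier=1.0, shape=None, device=None):
        diff, _ = self.get_diff_weight(multiplier, shape, device)
        return self.org_weight.to(diff) + diff, None

    def forward(self, x):
        # g(x) = Wx + drop(up @ rank_drop(down @ x)) * multiplier
        down = self.rank_drop(F.linear(x, self.toy_down))
        up = self.drop(F.linear(down, self.toy_up))
        return self.org_forward(x) + up * self.multiplier


# Example (in user code, assuming this package is importable):
#     base = nn.Linear(8, 8)
#     toy = _ToyLoRA("toy", base, rank=2, multiplier=1.0)
#     toy.apply_to()        # base.forward now routes through toy.forward
#     y = base(torch.randn(3, 8))
#     toy.restore()         # undo the forward hook
#     toy.merge_to(1.0)     # bake up @ down into base.weight in place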