File size: 10,735 Bytes
cc69848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 |
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from ..utils.quant import QuantLinears, log_bypass, log_suspect
class ModuleCustomSD(nn.Module):
    """``nn.Module`` base whose ``state_dict`` can be replaced wholesale.

    Subclasses may override :meth:`custom_state_dict` to return a mapping of
    ``name -> value``. When it returns something other than ``None``, that
    mapping (with ``prefix`` prepended to every key) is written into the
    destination instead of running the default ``nn.Module.state_dict``
    traversal. The two load hooks are registered here as no-op extension
    points that subclasses can override.
    """

    def __init__(self):
        super().__init__()
        # Bound methods are registered, so subclasses only need to override
        # load_weight_prehook / load_weight_hook to customise loading.
        self._register_load_state_dict_pre_hook(self.load_weight_prehook)
        self.register_load_state_dict_post_hook(self.load_weight_hook)

    def load_weight_prehook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Run before this module loads from ``state_dict``; no-op by default."""

    def load_weight_hook(self, module, incompatible_keys):
        """Run after ``load_state_dict`` completes; no-op by default."""

    def custom_state_dict(self):
        """Return a replacement state dict, or ``None`` to use the default one."""
        return None

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Return this module's state dict, honouring :meth:`custom_state_dict`.

        Accepts the same legacy positional form as ``nn.Module.state_dict``,
        ``(destination, prefix, keep_vars)``. Explicit keywords take
        precedence over the positional values.

        Returns:
            The ``destination`` mapping (a fresh ``OrderedDict`` carrying a
            ``_metadata`` attribute when none was supplied).
        """
        # TODO: Remove `args` and the parsing logic when BC allows.
        if len(args) > 0:
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == "":
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        custom_sd = self.custom_state_dict()
        if custom_sd is None:
            # The positional args are already folded into the keywords above.
            # Forwarding them again would make torch re-parse them and emit
            # its positional-args deprecation warning from inside this
            # override.
            return super().state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        # NOTE(review): `keep_vars` is not applied to custom entries; values
        # are stored exactly as custom_state_dict() returned them.
        for key, value in custom_sd.items():
            destination[f"{prefix}{key}"] = value
        return destination
class LycorisBaseModule(ModuleCustomSD):
    """Common base for LyCORIS adapter modules wrapping a single ``org_module``.

    Records the wrapped module's kind, weight shape, functional op and op
    keyword arguments; configures dropout and bypass settings; and provides
    the apply / restore / merge / parametrize plumbing shared by the concrete
    algorithm subclasses. Algorithm-specific maths (``get_diff_weight``,
    ``get_merged_weight``, ``bypass_forward*``, ``forward``) is left to
    subclasses.
    """

    name: str
    dtype_tensor: torch.Tensor
    # Class-level defaults, presumably overridden by each algorithm subclass.
    # They are shared mutable objects: replace them, never mutate in place.
    support_module = {}
    # Key suffixes read by extract_state_dict(), one per returned weight.
    weight_list = []
    # Key suffixes whose presence in a state dict marks this algorithm
    # (see algo_check).
    weight_list_det = []

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        rank_dropout_scale=False,
        bypass_mode=None,
        **kwargs,
    ):
        """Wrap ``org_module`` and record the metadata needed to rebuild or bypass it.

        Args:
            lora_name: Identifier stored as ``self.lora_name``.
            org_module: Module being adapted. ``nn.Linear``, ``nn.Conv1d/2d/3d``,
                ``nn.LayerNorm`` and ``nn.GroupNorm`` are recognised; any other
                type sets ``self.not_supported`` and leaves ``shape``, ``op``,
                ``dim`` and ``kw_dict`` unset.
            multiplier: Scale stored as ``self.multiplier``.
            dropout: Probability for ``self.drop`` (``nn.Identity`` when 0).
            rank_dropout: Probability for ``self.rank_drop`` (``nn.Identity`` when 0).
            module_dropout: Stored as ``self.module_dropout``.
            rank_dropout_scale: Stored as ``self.rank_dropout_scale``.
            bypass_mode: ``None`` lets the module decide. It is forced to
                ``True`` for ``QuantLinears`` and for ``nn.Linear`` subclasses
                whose class is not named ``Linear``.
            **kwargs: Accepted and ignored by this base class.
        """
        super().__init__()
        self.lora_name = lora_name
        self.not_supported = False

        self.module = type(org_module)
        if isinstance(org_module, nn.Linear):
            self.module_type = "linear"
            self.shape = (org_module.out_features, org_module.in_features)
            self.op = F.linear
            self.dim = org_module.out_features
            self.kw_dict = {}
        elif isinstance(org_module, nn.Conv1d):
            self.module_type = "conv1d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv1d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.Conv2d):
            self.module_type = "conv2d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv2d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.Conv3d):
            self.module_type = "conv3d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv3d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.LayerNorm):
            self.module_type = "layernorm"
            self.shape = tuple(org_module.normalized_shape)
            self.op = F.layer_norm
            self.dim = org_module.normalized_shape[0]
            self.kw_dict = {
                "normalized_shape": org_module.normalized_shape,
                "eps": org_module.eps,
            }
        elif isinstance(org_module, nn.GroupNorm):
            self.module_type = "groupnorm"
            self.shape = (org_module.num_channels,)
            self.op = F.group_norm
            self.group_num = org_module.num_groups
            self.dim = org_module.num_channels
            self.kw_dict = {"num_groups": org_module.num_groups, "eps": org_module.eps}
        else:
            # shape/op/dim/kw_dict stay unset; callers must check not_supported.
            self.not_supported = True
            self.module_type = "unknown"

        # Non-persistent 0-d buffer that follows .to() calls; the dtype and
        # device properties read it.
        self.register_buffer("dtype_tensor", torch.tensor(0.0), persistent=False)

        self.is_quant = False
        if isinstance(org_module, QuantLinears):
            # Quantised linears always run in bypass mode; log_bypass() is
            # called when the caller did not already request it.
            if not bypass_mode:
                log_bypass()
            self.is_quant = True
            bypass_mode = True
        if (
            isinstance(org_module, nn.Linear)
            and org_module.__class__.__name__ != "Linear"
        ):
            # A Linear subclass with a different class name may hold a
            # non-standard weight (presumably some quantised wrapper); default
            # it to bypass mode.
            if bypass_mode is None:
                log_suspect()
            bypass_mode = True
        if bypass_mode == True:
            self.is_quant = True
        self.bypass_mode = bypass_mode
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.rank_dropout_scale = rank_dropout_scale
        self.module_dropout = module_dropout

        ## Dropout things
        # LoKr/LoHa/OFT/BOFT cannot easily follow kohya's rank_dropout
        # definition, so the dropout procedure is redefined here:
        # g(x) = WX + drop(Brank_drop(AX)) for LoCon(lora), bypass
        # g(x) = WX + drop(ΔWX) for any algo except LoCon(lora), bypass
        # g(x) = (W + Brank_drop(A))X for LoCon(lora), rebuild
        # g(x) = (W + rank_drop(ΔW))X for any algo except LoCon(lora), rebuild
        self.drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.rank_drop = (
            nn.Identity() if rank_dropout == 0 else nn.Dropout(rank_dropout)
        )

        self.multiplier = multiplier
        self.org_forward = org_module.forward
        # Stored inside a list so nn.Module does not register org_module as a
        # submodule (keeps it out of parameters()/state_dict()).
        self.org_module = [org_module]

    @classmethod
    def parametrize(cls, org_module, attr, *args, **kwargs):
        """Attach a new ``cls`` instance to ``org_module.<attr>`` as a parametrization.

        A bias-free proxy ``nn.Linear`` / ``nn.ConvNd`` that shares the target
        parameter is built, so ``__init__`` can read the parameter's layout.
        The new module always has ``bypass_mode=False``, and its ``forward``
        is replaced by :meth:`parametrize_forward`.

        Args:
            org_module: Module owning the parameter.
            attr: Name of the parameter attribute on ``org_module``.
            *args, **kwargs: Forwarded to ``cls.__init__`` (``bypass_mode``
                is overwritten with ``False``).

        Returns:
            The registered ``cls`` instance.

        Raises:
            RuntimeError: If ``cls`` is ``FullModule``.
            ValueError: If the parameter is not 2-, 3-, 4- or 5-dimensional.
        """
        from .full import FullModule

        if cls is FullModule:
            raise RuntimeError("FullModule cannot be used for parametrize.")
        target_param = getattr(org_module, attr)
        kwargs["bypass_mode"] = False

        ndim = target_param.dim()
        proxy_types = {2: nn.Linear, 3: nn.Conv1d, 4: nn.Conv2d, 5: nn.Conv3d}
        if ndim not in proxy_types:
            raise ValueError(
                f"Cannot parametrize '{attr}': expected a 2- to 5-dimensional "
                f"parameter, got {ndim} dimension(s)."
            )
        # Weights are laid out (out, in, *kernel), while the Linear/ConvNd
        # constructors take (in, out, kernel_size, stride, ...). Swap the first
        # two sizes, and pass the kernel as ONE tuple: splatting it would spill
        # the extra kernel dims into stride/padding.
        out_size, in_size = target_param.shape[:2]
        if ndim == 2:
            proxy_module = nn.Linear(in_size, out_size, bias=False)
        else:
            proxy_module = proxy_types[ndim](
                in_size, out_size, tuple(target_param.shape[2:]), bias=False
            )
        proxy_module.weight = target_param

        module_obj = cls("", proxy_module, *args, **kwargs)
        module_obj.forward = module_obj.parametrize_forward
        module_obj.to(target_param)
        parametrize.register_parametrization(org_module, attr, module_obj)
        return module_obj

    @classmethod
    def algo_check(cls, state_dict, lora_name):
        """Return True if any ``weight_list_det`` key for ``lora_name`` is in ``state_dict``."""
        return any(f"{lora_name}.{k}" in state_dict for k in cls.weight_list_det)

    @classmethod
    def extract_state_dict(cls, state_dict, lora_name):
        """Return ``state_dict`` values for each ``weight_list`` key (``None`` if absent)."""
        return [state_dict.get(f"{lora_name}.{k}", None) for k in cls.weight_list]

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, *weights):
        """Build a module from extracted weights; implemented by subclasses."""
        raise NotImplementedError

    @property
    def dtype(self):
        """Dtype of the ``dtype_tensor`` buffer."""
        return self.dtype_tensor.dtype

    @property
    def device(self):
        """Device of the ``dtype_tensor`` buffer."""
        return self.dtype_tensor.device

    @property
    def org_weight(self):
        """The wrapped module's ``weight``."""
        return self.org_module[0].weight

    @org_weight.setter
    def org_weight(self, value):
        # Copy in place so the existing Parameter object (and any references
        # to it) is kept.
        self.org_module[0].weight.data.copy_(value)

    def apply_to(self, **kwargs):
        """Route the wrapped module's ``forward`` through this module."""
        if self.not_supported:
            return
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def restore(self):
        """Reinstall the ``forward`` last saved in ``self.org_forward``."""
        if self.not_supported:
            return
        self.org_module[0].forward = self.org_forward

    def merge_to(self, multiplier=1.0):
        """Fold this module's weight (and optional bias) into ``org_module`` in place.

        Temporarily moves this module to the wrapped weight's device/dtype,
        writes ``get_merged_weight(...)`` into the wrapped weight (and bias,
        creating it if missing), then moves back to the device/dtype of this
        module's first parameter.
        """
        if self.not_supported:
            return
        self_device = next(self.parameters()).device
        self_dtype = next(self.parameters()).dtype
        self.to(self.org_weight)
        weight, bias = self.get_merged_weight(
            multiplier, self.org_weight.shape, self.org_weight.device
        )
        self.org_weight = weight.to(self.org_weight)
        if bias is not None:
            bias = bias.to(self.org_weight)
            if self.org_module[0].bias is not None:
                self.org_module[0].bias.data.copy_(bias)
            else:
                self.org_module[0].bias = nn.Parameter(bias)
        self.to(self_device, self_dtype)

    def get_diff_weight(self, multiplier=1.0, shape=None, device=None):
        """Return the weight delta; implemented by subclasses."""
        raise NotImplementedError

    def get_merged_weight(self, multiplier=1.0, shape=None, device=None):
        """Return ``(merged_weight, bias_or_None)``; implemented by subclasses."""
        raise NotImplementedError

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        """Apply max-norm regularisation; the base class does nothing and returns ``(None, None)``."""
        return None, None

    def bypass_forward_diff(self, x, scale=1):
        """Bypass-mode delta output; implemented by subclasses."""
        raise NotImplementedError

    def bypass_forward(self, x, scale=1):
        """Bypass-mode forward; implemented by subclasses."""
        raise NotImplementedError

    def parametrize_forward(self, x: torch.Tensor, *args, **kwargs):
        """Map the original parameter ``x`` to the merged weight (cast to ``x.dtype``)."""
        return self.get_merged_weight(
            multiplier=self.multiplier, shape=x.shape, device=x.device
        )[0].to(x.dtype)

    def forward(self, *args, **kwargs):
        """Adapted forward pass; implemented by subclasses."""
        raise NotImplementedError
|