File size: 11,590 Bytes
cc69848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
import math
import torch
import torch.nn as nn
from .base import LycorisBaseModule
from ..functional.loha import diff_weight as loha_diff_weight
class LohaModule(LycorisBaseModule):
    """LoHa adapter for linear and conv (1d/2d/3d) modules.

    The weight delta is produced by ``loha_diff_weight`` from two pairs of
    low-rank factors (``hada_w1_a``/``hada_w1_b`` and ``hada_w2_a``/
    ``hada_w2_b``), plus Tucker cores (``hada_t1``/``hada_t2``) for conv
    kernels that are not all 1s when ``use_tucker`` is set. ``get_weight``
    passes ``self.scale`` as ``gamma`` to ``loha_diff_weight``; every consumer
    in this class then multiplies that result by ``self.scalar`` and the
    multiplier.

    NOTE(review): ``module_type``, ``shape``, ``op``, ``kw_dict``,
    ``org_module``, ``org_forward``, ``org_weight``, ``dtype``, ``drop``,
    ``dropout``, ``rank_dropout``, ``rank_dropout_scale``, ``module_dropout``,
    ``bypass_mode`` and ``multiplier`` are not assigned here and are assumed
    to be provided by ``LycorisBaseModule`` — confirm against base.py.
    """

    name = "loha"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    weight_list = [
        "hada_w1_a",
        "hada_w1_b",
        "hada_w2_a",
        "hada_w2_b",
        "hada_t1",
        "hada_t2",
        "alpha",
        "dora_scale",
    ]
    weight_list_det = ["hada_w1_a"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        wd_on_out=False,
        bypass_mode=None,
        rs_lora=False,
        **kwargs,
    ):
        """Create the LoHa factors for ``org_module``.

        Args:
            lora_name: Name of this adapter.
            org_module: The wrapped module. Conv modules must expose
                ``in_channels``, ``out_channels`` and ``kernel_size``.
            multiplier: Adapter strength (forwarded to the base class).
            lora_dim: Rank of each factor pair.
            alpha: Scaling numerator; ``None`` or ``0`` means ``lora_dim``.
                May be a tensor (e.g. loaded from a checkpoint).
            dropout, rank_dropout, module_dropout, rank_dropout_scale,
                bypass_mode: Forwarded to ``LycorisBaseModule``.
            use_tucker: Use Tucker cores when some conv kernel side != 1.
            use_scalar: Add a trainable scalar (initialised to 0) that
                multiplies the delta; otherwise a non-persistent buffer of 1.
            weight_decompose: Enable weight decomposition via ``dora_scale``.
            wd_on_out: Take the decomposition norm per output channel
                instead of per input channel.
            rs_lora: Divide alpha by ``sqrt(lora_dim)`` instead of ``lora_dim``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``self.module_type`` is not in ``support_module``.
        """
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in LoHa algo.")
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.tucker = False
        self.rs_lora = rs_lora

        w_shape = self.shape
        if self.module_type.startswith("conv"):
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels
            self.shape = (out_dim, in_dim, *k_size)
            # Tucker cores are only used when some kernel side is not 1.
            self.tucker = use_tucker and any(i != 1 for i in k_size)
            if self.tucker:
                w_shape = (out_dim, in_dim, *k_size)
            else:
                # Without Tucker, the kernel is flattened into the input dim.
                w_shape = (out_dim, in_dim * torch.tensor(k_size).prod().item())

        if self.tucker:
            self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
            self.hada_w1_a = nn.Parameter(
                torch.empty(lora_dim, w_shape[0])
            )  # out_dim, 1-mode
            self.hada_w1_b = nn.Parameter(
                torch.empty(lora_dim, w_shape[1])
            )  # in_dim , 2-mode
            self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
            self.hada_w2_a = nn.Parameter(
                torch.empty(lora_dim, w_shape[0])
            )  # out_dim, 1-mode
            self.hada_w2_b = nn.Parameter(
                torch.empty(lora_dim, w_shape[1])
            )  # in_dim , 2-mode
        else:
            self.hada_w1_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))
            self.hada_w2_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))

        self.wd = weight_decompose
        self.wd_on_out = wd_on_out
        if self.wd:
            org_weight = org_module.weight.cpu().clone().float()
            self.dora_norm_dims = org_weight.dim() - 1
            if self.wd_on_out:
                # Per-output-channel norm, shaped (out, 1, ..., 1).
                norm = torch.norm(
                    org_weight.reshape(org_weight.shape[0], -1),
                    dim=1,
                    keepdim=True,
                ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
            else:
                # Per-input-channel norm, shaped (1, in, 1, ..., 1).
                norm = (
                    torch.norm(
                        org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
                        dim=1,
                        keepdim=True,
                    )
                    .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
                    .transpose(1, 0)
                )
            self.dora_scale = nn.Parameter(norm.float())

        if self.dropout:
            print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")

        if isinstance(alpha, torch.Tensor):
            # without casting, bf16 causes error; .cpu() so CUDA tensors work too
            alpha = alpha.detach().float().cpu().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        r_factor = math.sqrt(lora_dim) if self.rs_lora else lora_dim
        self.scale = alpha / r_factor
        # The stored alpha is rescaled so that alpha_buffer / lora_dim == self.scale,
        # which keeps the plain "alpha / dim" convention valid for rs_lora too.
        self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor)))

        if use_scalar:
            self.scalar = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)

        # Need more experiments on init method.
        # NOTE: the order of these init calls determines RNG consumption.
        if self.tucker:
            torch.nn.init.normal_(self.hada_t1, std=0.1)
            torch.nn.init.normal_(self.hada_t2, std=0.1)
        torch.nn.init.normal_(self.hada_w1_b, std=1)
        torch.nn.init.normal_(self.hada_w1_a, std=0.1)
        torch.nn.init.normal_(self.hada_w2_b, std=1)
        if use_scalar:
            # The trainable scalar starts at 0, so the factor can be random.
            torch.nn.init.normal_(self.hada_w2_a, std=0.1)
        else:
            # Presumably zeroed so the initial delta is zero — depends on how
            # loha_diff_weight combines the factors.
            torch.nn.init.constant_(self.hada_w2_a, 0)

    @classmethod
    def make_module_from_state_dict(
        cls, lora_name, orig_module, w1a, w1b, w2a, w2b, t1, t2, alpha, dora_scale
    ):
        """Build a module around ``orig_module`` and load the given factors.

        Rank is taken from ``w1b.size(0)``. Tucker mode is requested when
        ``t1`` is given, and weight decomposition when ``dora_scale`` is given.
        Returns the new module.
        """
        module = cls(
            lora_name,
            orig_module,
            1,
            w1b.size(0),
            float(alpha),
            use_tucker=t1 is not None,
            weight_decompose=dora_scale is not None,
        )
        # In-place copy_ into a leaf Parameter that requires grad raises a
        # RuntimeError unless autograd recording is disabled.
        with torch.no_grad():
            module.hada_w1_a.copy_(w1a)
            module.hada_w1_b.copy_(w1b)
            module.hada_w2_a.copy_(w2a)
            module.hada_w2_b.copy_(w2b)
            if t1 is not None:
                module.hada_t1.copy_(t1)
                module.hada_t2.copy_(t2)
            if dora_scale is not None:
                module.dora_scale.copy_(dora_scale)
        return module

    def load_weight_hook(self, module: nn.Module, incompatible_keys):
        """Post-load fixup for ``scalar``.

        Removes every missing key containing ``"scalar"`` from
        ``incompatible_keys.missing_keys`` (in place) and resets ``self.scalar``
        to ones. ``custom_state_dict`` bakes ``scalar`` into ``hada_w1_a``,
        which is why 1 is the right value after loading such a dict.
        Presumably registered as a load_state_dict post-hook by the base class.
        """
        missing_keys = incompatible_keys.missing_keys
        # Rebuild in place instead of deleting while iterating, which skipped
        # the element following each deleted key.
        missing_keys[:] = [key for key in missing_keys if "scalar" not in key]

        scalar = getattr(self, "scalar", None)
        if isinstance(scalar, nn.Parameter):
            scalar.data.copy_(torch.ones_like(scalar))
        elif scalar is not None:
            scalar.copy_(torch.ones_like(scalar))
        else:
            # There is no existing tensor to model the buffer on.
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)

    def get_weight(self, shape):
        """Return the LoHa delta from ``loha_diff_weight`` with ``gamma=self.scale``.

        Args:
            shape: If not None, the result is reshaped to this shape.

        In training mode with ``rank_dropout`` set, slices along dim 0 are
        randomly zeroed. With ``rank_dropout_scale``, the survivors are divided
        by the kept fraction. The result is NOT multiplied by ``scalar`` or any
        multiplier.
        """
        scale = torch.tensor(
            self.scale, dtype=self.hada_w1_b.dtype, device=self.hada_w1_b.device
        )
        # Tucker cores only exist as attributes when self.tucker is set.
        t1 = self.hada_t1 if self.tucker else None
        t2 = self.hada_t2 if self.tucker else None
        weight = loha_diff_weight(
            self.hada_w1_b,
            self.hada_w1_a,
            self.hada_w2_b,
            self.hada_w2_a,
            t1,
            t2,
            gamma=scale,
        )
        if shape is not None:
            weight = weight.reshape(shape)
        if self.training and self.rank_dropout:
            drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(weight.dtype)
            if self.rank_dropout_scale:
                # The mask is still on CPU here, so this check costs no device
                # sync. Skip rescaling if every slice was dropped (0/0 -> NaN).
                keep = drop.mean()
                if keep > 0:
                    drop /= keep
            drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
            weight *= drop
        return weight

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        """Return ``(delta, None)`` where delta matches what ``forward`` adds.

        ``get_weight`` already applies ``self.scale`` (as ``gamma``). So, like
        ``forward`` and ``bypass_forward_diff``, the delta is
        ``get_weight() * scalar * multiplier``. Multiplying by ``self.scale``
        again here would double-apply it and ignore ``scalar``.
        """
        diff = self.get_weight(shape) * self.scalar * multiplier
        if device is not None:
            diff = diff.to(device)
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        """Return ``(merged_weight, None)`` with the adapter folded into ``org_weight``.

        Mirrors ``forward``:
        - With weight decomposition, the unscaled delta is added and then
          ``apply_weight_decompose`` applies ``multiplier``.
        - Otherwise ``org_weight + delta * multiplier``.
        """
        diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0]
        weight = self.org_weight
        if self.wd:
            merged = self.apply_weight_decompose(weight + diff, multiplier)
        else:
            merged = weight + diff * multiplier
        return merged, None

    def apply_weight_decompose(self, weight, multiplier=1):
        """Rescale ``weight`` so its norm matches ``dora_scale``.

        The norm is taken per output channel (``wd_on_out``) or per input
        channel. ``eps`` guards against division by zero. When ``multiplier``
        is not 1, the rescale factor is interpolated as ``m * (s - 1) + 1``.
        """
        weight = weight.to(self.dora_scale.dtype)
        if self.wd_on_out:
            weight_norm = (
                weight.reshape(weight.shape[0], -1)
                .norm(dim=1)
                .reshape(weight.shape[0], *[1] * self.dora_norm_dims)
            ) + torch.finfo(weight.dtype).eps
        else:
            weight_norm = (
                weight.transpose(0, 1)
                .reshape(weight.shape[1], -1)
                .norm(dim=1, keepdim=True)
                .reshape(weight.shape[1], *[1] * self.dora_norm_dims)
                .transpose(0, 1)
            ) + torch.finfo(weight.dtype).eps
        scale = self.dora_scale.to(weight.device) / weight_norm
        if multiplier != 1:
            scale = multiplier * (scale - 1) + 1
        return weight * scale

    def custom_state_dict(self):
        """Return the tensors to save; ``scalar`` is folded into ``hada_w1_a``."""
        destination = {}
        destination["alpha"] = self.alpha
        if self.wd:
            destination["dora_scale"] = self.dora_scale
        destination["hada_w1_a"] = self.hada_w1_a * self.scalar
        destination["hada_w1_b"] = self.hada_w1_b
        destination["hada_w2_a"] = self.hada_w2_a
        destination["hada_w2_b"] = self.hada_w2_b
        if self.tucker:
            destination["hada_t1"] = self.hada_t1
            destination["hada_t2"] = self.hada_t2
        return destination

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        """Shrink ``scalar`` so the scaled delta's norm does not exceed ``max_norm``.

        Returns ``(scaled, new_norm)``. ``scaled`` is whether a rescale was
        applied. The norm is floored at ``max_norm / 2`` before comparison,
        as in the original. ``device`` is accepted but unused.
        """
        orig_norm = (self.get_weight(self.shape) * self.scalar).norm()
        norm = torch.clamp(orig_norm, max_norm / 2)
        desired = torch.clamp(norm, max=max_norm)
        ratio = desired.cpu() / norm.cpu()
        scaled = norm != desired
        if scaled:
            self.scalar *= ratio
        return scaled, orig_norm * ratio

    def bypass_forward_diff(self, x, scale=1):
        """Apply only the adapter delta to ``x`` (scaled by ``scalar`` and ``scale``)."""
        diff_weight = self.get_weight(self.shape) * self.scalar * scale
        return self.drop(self.op(x, diff_weight, **self.kw_dict))

    def bypass_forward(self, x, scale=1):
        """Original module output plus the adapter delta applied to ``x``."""
        return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the wrapped op with the LoHa delta merged into its weight.

        With module dropout active in training, the adapter may be skipped
        entirely. In that case the original weight/bias are used with the same
        ``kw_dict`` (stride, padding, ...) as the normal path. Without
        ``kw_dict``, conv modules would silently lose those arguments.
        """
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.op(
                    x,
                    self.org_module[0].weight.data,
                    (
                        None
                        if self.org_module[0].bias is None
                        else self.org_module[0].bias.data
                    ),
                    **self.kw_dict,
                )
        if self.bypass_mode:
            return self.bypass_forward(x, scale=self.multiplier)
        diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar
        weight = self.org_module[0].weight.data.to(self.dtype)
        if self.wd:
            weight = self.apply_weight_decompose(weight + diff_weight, self.multiplier)
        else:
            weight = weight + diff_weight * self.multiplier
        bias = (
            None
            if self.org_module[0].bias is None
            else self.org_module[0].bias.data
        )
        return self.op(x, weight, bias, **self.kw_dict)
|