File size: 15,625 Bytes
cc69848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 |
import re
import hashlib
from io import BytesIO
from typing import Dict, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg
import safetensors.torch
from tqdm import tqdm
from .general import *
def load_bytes_in_safetensors(tensors):
bytes = safetensors.torch.save(tensors)
b = BytesIO(bytes)
b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")
offset = n + 8
b.seek(offset)
return b.read()
def precalculate_safetensors_hashes(state_dict):
# calculate each tensor one by one to reduce memory usage
hash_sha256 = hashlib.sha256()
for tensor in state_dict.values():
single_tensor_sd = {"tensor": tensor}
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
hash_sha256.update(bytes_for_tensor)
return f"0x{hash_sha256.hexdigest()}"
def str_bool(val):
return str(val).lower() != "false"
def default(val, d):
return val if val is not None else d
def make_sparse(t: torch.Tensor, sparsity=0.95):
abs_t = torch.abs(t)
np_array = abs_t.detach().cpu().numpy()
quan = float(np.quantile(np_array, sparsity))
sparse_t = t.masked_fill(abs_t < quan, 0)
return sparse_t
def extract_conv(
weight: Union[torch.Tensor, nn.Parameter],
mode="fixed",
mode_param=0,
device="cpu",
is_cp=False,
) -> Tuple[nn.Parameter, nn.Parameter]:
weight = weight.to(device)
out_ch, in_ch, kernel_size, _ = weight.shape
U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
if mode == "full":
return weight, "full"
elif mode == "fixed":
lora_rank = mode_param
elif mode == "threshold":
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param)
elif mode == "ratio":
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s)
elif mode == "quantile" or mode == "percentile":
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum)
else:
raise NotImplementedError(
'Extract mode should be "fixed", "threshold", "ratio" or "quantile"'
)
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2 and not is_cp:
return weight, "full"
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S).to(device)
Vh = Vh[:lora_rank, :]
diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
del U, S, Vh, weight
return (extract_weight_A, extract_weight_B, diff), "low rank"
def extract_linear(
weight: Union[torch.Tensor, nn.Parameter],
mode="fixed",
mode_param=0,
device="cpu",
) -> Tuple[nn.Parameter, nn.Parameter]:
weight = weight.to(device)
out_ch, in_ch = weight.shape
U, S, Vh = linalg.svd(weight)
if mode == "full":
return weight, "full"
elif mode == "fixed":
lora_rank = mode_param
elif mode == "threshold":
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param)
elif mode == "ratio":
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s)
elif mode == "quantile" or mode == "percentile":
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum)
else:
raise NotImplementedError(
'Extract mode should be "fixed", "threshold", "ratio" or "quantile"'
)
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2:
return weight, "full"
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S).to(device)
Vh = Vh[:lora_rank, :]
diff = (weight - U @ Vh).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
del U, S, Vh, weight
return (extract_weight_A, extract_weight_B, diff), "low rank"
@torch.no_grad()
def extract_diff(
base_tes,
db_tes,
base_unet,
db_unet,
mode="fixed",
linear_mode_param=0,
conv_mode_param=0,
extract_device="cpu",
use_bias=False,
sparsity=0.98,
small_conv=True,
):
UNET_TARGET_REPLACE_MODULE = [
"Linear",
"Conv2d",
"LayerNorm",
"GroupNorm",
"GroupNorm32",
]
TEXT_ENCODER_TARGET_REPLACE_MODULE = [
"Embedding",
"Linear",
"Conv2d",
"LayerNorm",
"GroupNorm",
"GroupNorm32",
]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
def make_state_dict(
prefix,
root_module: torch.nn.Module,
target_module: torch.nn.Module,
target_replace_modules,
):
loras = {}
temp = {}
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
temp[name] = module
for name, module in tqdm(
list((n, m) for n, m in target_module.named_modules() if n in temp)
):
weights = temp[name]
lora_name = prefix + "." + name
lora_name = lora_name.replace(".", "_")
layer = module.__class__.__name__
if layer in {
"Linear",
"Conv2d",
"LayerNorm",
"GroupNorm",
"GroupNorm32",
"Embedding",
}:
root_weight = module.weight
if torch.allclose(root_weight, weights.weight):
continue
else:
continue
module = module.to(extract_device)
weights = weights.to(extract_device)
if mode == "full":
decompose_mode = "full"
elif layer == "Linear":
weight, decompose_mode = extract_linear(
(root_weight - weights.weight),
mode,
linear_mode_param,
device=extract_device,
)
if decompose_mode == "low rank":
extract_a, extract_b, diff = weight
elif layer == "Conv2d":
is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
weight, decompose_mode = extract_conv(
(root_weight - weights.weight),
mode,
linear_mode_param if is_linear else conv_mode_param,
device=extract_device,
)
if decompose_mode == "low rank":
extract_a, extract_b, diff = weight
if small_conv and not is_linear and decompose_mode == "low rank":
dim = extract_a.size(0)
(extract_c, extract_a, _), _ = extract_conv(
extract_a.transpose(0, 1),
"fixed",
dim,
extract_device,
True,
)
extract_a = extract_a.transpose(0, 1)
extract_c = extract_c.transpose(0, 1)
loras[f"{lora_name}.lora_mid.weight"] = (
extract_c.detach().cpu().contiguous().half()
)
diff = (
(
root_weight
- torch.einsum(
"i j k l, j r, p i -> p r k l",
extract_c,
extract_a.flatten(1, -1),
extract_b.flatten(1, -1),
)
)
.detach()
.cpu()
.contiguous()
)
del extract_c
else:
module = module.to("cpu")
weights = weights.to("cpu")
continue
if decompose_mode == "low rank":
loras[f"{lora_name}.lora_down.weight"] = (
extract_a.detach().cpu().contiguous().half()
)
loras[f"{lora_name}.lora_up.weight"] = (
extract_b.detach().cpu().contiguous().half()
)
loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
if use_bias:
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
indices = sparse_diff.indices().to(torch.int16)
values = sparse_diff.values().half()
loras[f"{lora_name}.bias_indices"] = indices
loras[f"{lora_name}.bias_values"] = values
loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
torch.int16
)
del extract_a, extract_b, diff
elif decompose_mode == "full":
if "Norm" in layer:
w_key = "w_norm"
b_key = "b_norm"
else:
w_key = "diff"
b_key = "diff_b"
weight_diff = module.weight - weights.weight
loras[f"{lora_name}.{w_key}"] = (
weight_diff.detach().cpu().contiguous().half()
)
if getattr(weights, "bias", None) is not None:
bias_diff = module.bias - weights.bias
loras[f"{lora_name}.{b_key}"] = (
bias_diff.detach().cpu().contiguous().half()
)
else:
raise NotImplementedError
module = module.to("cpu")
weights = weights.to("cpu")
return loras
all_loras = {}
all_loras |= make_state_dict(
LORA_PREFIX_UNET,
base_unet,
db_unet,
UNET_TARGET_REPLACE_MODULE,
)
del base_unet, db_unet
if torch.cuda.is_available():
torch.cuda.empty_cache()
for idx, (te1, te2) in enumerate(zip(base_tes, db_tes)):
if len(base_tes) > 1:
prefix = f"{LORA_PREFIX_TEXT_ENCODER}{idx+1}"
else:
prefix = LORA_PREFIX_TEXT_ENCODER
all_loras |= make_state_dict(
prefix,
te1,
te2,
TEXT_ENCODER_TARGET_REPLACE_MODULE,
)
del te1, te2
all_lora_name = set()
for k in all_loras:
lora_name, weight = k.rsplit(".", 1)
all_lora_name.add(lora_name)
print(len(all_lora_name))
return all_loras
re_digits = re.compile(r"\d+")
re_compiled = {}
suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"norm1": "in_layers_0",
"norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
},
}
def convert_diffusers_name_to_compvis(key):
def match(match_list, regex_text):
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex
r = re.match(regex, key)
if not r:
return False
match_list.clear()
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
return True
m = []
if match(m, r"lora_unet_conv_in(.*)"):
return f"lora_unet_input_blocks_0_0{m[0]}"
if match(m, r"lora_unet_conv_out(.*)"):
return f"lora_unet_out_2{m[0]}"
if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
return f"lora_unet_time_embed_{m[0] * 2 - 2}{m[1]}"
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"lora_unet_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
return (
f"lora_unet_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
)
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"lora_unet_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
return f"lora_unet_input_blocks_{3 + m[0] * 3}_0_op"
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
return f"lora_unet_output_blocks_{2 + m[0] * 3}_2_conv"
return key
@torch.no_grad()
def merge(tes, unet, lyco_state_dict, scale: float = 1.0, device="cpu"):
from ..modules import make_module, get_module
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
merged = 0
def merge_state_dict(
prefix,
root_module: torch.nn.Module,
lyco_state_dict: Dict[str, torch.Tensor],
):
nonlocal merged
for child_name, child_module in tqdm(
list(root_module.named_modules()), desc=f"Merging {prefix}"
):
lora_name = prefix + "." + child_name
lora_name = lora_name.replace(".", "_")
lyco_type, params = get_module(lyco_state_dict, lora_name)
if lyco_type is None:
continue
module = make_module(lyco_type, params, lora_name, child_module)
if module is None:
continue
module.to(device)
module.merge_to(scale)
key_dict.pop(convert_diffusers_name_to_compvis(lora_name), None)
key_dict.pop(lora_name, None)
merged += 1
key_dict = {}
for k, v in tqdm(list(lyco_state_dict.items()), desc="Converting Dtype and Device"):
module, weight_key = k.split(".", 1)
convert_key = convert_diffusers_name_to_compvis(module)
if convert_key != module and len(tes) > 1:
# kohya's format for sdxl is as same as SGM, not diffusers
del lyco_state_dict[k]
key_dict[convert_key] = key_dict.get(convert_key, []) + [k]
k = f"{convert_key}.{weight_key}"
else:
key_dict[module] = key_dict.get(module, []) + [k]
lyco_state_dict[k] = v.float().cpu()
for idx, te in enumerate(tes):
if len(tes) > 1:
prefix = LORA_PREFIX_TEXT_ENCODER + str(idx + 1)
else:
prefix = LORA_PREFIX_TEXT_ENCODER
merge_state_dict(
prefix,
te,
lyco_state_dict,
)
torch.cuda.empty_cache()
merge_state_dict(
LORA_PREFIX_UNET,
unet,
lyco_state_dict,
)
torch.cuda.empty_cache()
print(f"Unused state dict key: {key_dict}")
print(f"{merged} Modules been merged")
|