# lycoris/utils/__init__.py
import re
import hashlib
from io import BytesIO
from typing import Dict, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg
import safetensors.torch
from tqdm import tqdm
from .general import *
def load_bytes_in_safetensors(tensors):
    # Serialize with safetensors and return only the raw tensor payload,
    # skipping the 8-byte length field and the JSON header.
    tensor_bytes = safetensors.torch.save(tensors)
    b = BytesIO(tensor_bytes)
    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")  # header size, u64 little-endian
    offset = n + 8
    b.seek(offset)
    return b.read()
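# The safetensors layout is an 8-byte little-endian header length n, an
# n-byte JSON header, then the raw tensor payload. A minimal sketch of what
# the helper above returns (illustrative, not part of the original module):
#
#     payload = load_bytes_in_safetensors({"t": torch.zeros(2)})
#     assert payload == bytes(8)  # two float32 zeros -> 8 zero bytes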
def precalculate_safetensors_hashes(state_dict):
# calculate each tensor one by one to reduce memory usage
hash_sha256 = hashlib.sha256()
for tensor in state_dict.values():
single_tensor_sd = {"tensor": tensor}
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
hash_sha256.update(bytes_for_tensor)
return f"0x{hash_sha256.hexdigest()}"
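# Hedged usage sketch ("model.safetensors" is a hypothetical path): the hash
# covers tensor data only, so it is stable across header/metadata changes.
#
#     sd = safetensors.torch.load_file("model.safetensors")
#     print(precalculate_safetensors_hashes(sd))  # e.g. "0x3fa2..."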
def str_bool(val):
    # Only the literal string "false" (case-insensitive) is falsy here;
    # anything else, including "0" and "", is treated as True.
    return str(val).lower() != "false"
def default(val, d):
return val if val is not None else d
def make_sparse(t: torch.Tensor, sparsity=0.95):
    # Zero every entry whose magnitude is below the `sparsity` quantile of
    # |t|, keeping only the largest (1 - sparsity) fraction of entries.
abs_t = torch.abs(t)
np_array = abs_t.detach().cpu().numpy()
quan = float(np.quantile(np_array, sparsity))
sparse_t = t.masked_fill(abs_t < quan, 0)
return sparse_t
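# Illustrative sketch: with sparsity=0.75 on a 4x4 tensor, roughly the 12
# smallest-magnitude entries are zeroed and only the top quarter survives.
#
#     t = torch.randn(4, 4)
#     s = make_sparse(t, sparsity=0.75)  # s keeps ~4 of 16 entries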
def extract_conv(
    weight: Union[torch.Tensor, nn.Parameter],
    mode="fixed",
    mode_param=0,
    device="cpu",
    is_cp=False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], str]:
weight = weight.to(device)
out_ch, in_ch, kernel_size, _ = weight.shape
    if mode == "full":
        # No decomposition needed, so skip the SVD entirely.
        return weight, "full"
    U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
    if mode == "fixed":
lora_rank = mode_param
elif mode == "threshold":
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param)
elif mode == "ratio":
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s)
elif mode == "quantile" or mode == "percentile":
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError(
            'Extract mode should be "full", "fixed", "threshold", "ratio" or "quantile"'
        )
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2 and not is_cp:
return weight, "full"
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S).to(device)
Vh = Vh[:lora_rank, :]
diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
del U, S, Vh, weight
return (extract_weight_A, extract_weight_B, diff), "low rank"
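# Hedged usage sketch (hypothetical 3x3 conv delta, fixed rank 8):
#
#     w = torch.randn(64, 32, 3, 3)
#     (a, b, diff), kind = extract_conv(w, mode="fixed", mode_param=8)
#     # kind == "low rank"; a: (8, 32, 3, 3) down, b: (64, 8, 1, 1) up,
#     # and diff is the residual of the rank-8 reconstruction.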
def extract_linear(
    weight: Union[torch.Tensor, nn.Parameter],
    mode="fixed",
    mode_param=0,
    device="cpu",
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], str]:
weight = weight.to(device)
out_ch, in_ch = weight.shape
    if mode == "full":
        # No decomposition needed, so skip the SVD entirely.
        return weight, "full"
    U, S, Vh = linalg.svd(weight)
    if mode == "fixed":
lora_rank = mode_param
elif mode == "threshold":
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param)
elif mode == "ratio":
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s)
elif mode == "quantile" or mode == "percentile":
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError(
            'Extract mode should be "full", "fixed", "threshold", "ratio" or "quantile"'
        )
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2:
return weight, "full"
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S).to(device)
Vh = Vh[:lora_rank, :]
diff = (weight - U @ Vh).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
del U, S, Vh, weight
return (extract_weight_A, extract_weight_B, diff), "low rank"
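# Hedged usage sketch (hypothetical linear delta): rank 16 of a 320x768
# matrix, so the low-rank branch is taken (16 < out_ch / 2).
#
#     w = torch.randn(320, 768)
#     (a, b, diff), kind = extract_linear(w, mode="fixed", mode_param=16)
#     # a: (16, 768), b: (320, 16); w == b @ a + diff (up to fp error)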
@torch.no_grad()
def extract_diff(
    base_tes,
    db_tes,
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    small_conv=True,
):
    """Build a LoRA-style state dict from the weight differences between a
    base model (base_tes, base_unet) and a fine-tuned one (db_tes, db_unet)."""
UNET_TARGET_REPLACE_MODULE = [
"Linear",
"Conv2d",
"LayerNorm",
"GroupNorm",
"GroupNorm32",
]
TEXT_ENCODER_TARGET_REPLACE_MODULE = [
"Embedding",
"Linear",
"Conv2d",
"LayerNorm",
"GroupNorm",
"GroupNorm32",
]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
def make_state_dict(
prefix,
root_module: torch.nn.Module,
target_module: torch.nn.Module,
target_replace_modules,
):
        loras = {}
        temp = {}
        # Collect the target layers of the base model, keyed by module path.
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module
        # Walk the fine-tuned model and diff each layer against its base.
        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]
            lora_name = prefix + "." + name
            lora_name = lora_name.replace(".", "_")
layer = module.__class__.__name__
if layer in {
"Linear",
"Conv2d",
"LayerNorm",
"GroupNorm",
"GroupNorm32",
"Embedding",
}:
                # NB: despite its name, `root_weight` is the fine-tuned
                # module's weight; `weights` is the matching base module.
                root_weight = module.weight
                if torch.allclose(root_weight, weights.weight):
                    continue
else:
continue
module = module.to(extract_device)
weights = weights.to(extract_device)
if mode == "full":
decompose_mode = "full"
elif layer == "Linear":
weight, decompose_mode = extract_linear(
(root_weight - weights.weight),
mode,
linear_mode_param,
device=extract_device,
)
if decompose_mode == "low rank":
extract_a, extract_b, diff = weight
elif layer == "Conv2d":
is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
weight, decompose_mode = extract_conv(
(root_weight - weights.weight),
mode,
linear_mode_param if is_linear else conv_mode_param,
device=extract_device,
)
if decompose_mode == "low rank":
extract_a, extract_b, diff = weight
                if small_conv and not is_linear and decompose_mode == "low rank":
                    # CP-style split: factor the (rank, in_ch, k, k) "down"
                    # weight into a 1x1 down conv plus a k x k mid conv.
                    dim = extract_a.size(0)
                    (extract_c, extract_a, _), _ = extract_conv(
                        extract_a.transpose(0, 1),
                        "fixed",
                        dim,
                        extract_device,
                        True,
                    )
                    extract_a = extract_a.transpose(0, 1)
                    extract_c = extract_c.transpose(0, 1)
                    loras[f"{lora_name}.lora_mid.weight"] = (
                        extract_c.detach().cpu().contiguous().half()
                    )
                    # Residual of the weight *delta* after reconstructing it
                    # from the up (p i), mid (i j k l) and down (j r) factors;
                    # subtracting the base weight keeps this consistent with
                    # the non-CP path, where diff is also delta minus the
                    # low-rank reconstruction.
                    diff = (
                        (
                            root_weight
                            - weights.weight
                            - torch.einsum(
                                "i j k l, j r, p i -> p r k l",
                                extract_c,
                                extract_a.flatten(1, -1),
                                extract_b.flatten(1, -1),
                            )
                        )
                        .detach()
                        .cpu()
                        .contiguous()
                    )
del extract_c
else:
module = module.to("cpu")
weights = weights.to("cpu")
continue
if decompose_mode == "low rank":
loras[f"{lora_name}.lora_down.weight"] = (
extract_a.detach().cpu().contiguous().half()
)
loras[f"{lora_name}.lora_up.weight"] = (
extract_b.detach().cpu().contiguous().half()
)
loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
if use_bias:
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
indices = sparse_diff.indices().to(torch.int16)
values = sparse_diff.values().half()
loras[f"{lora_name}.bias_indices"] = indices
loras[f"{lora_name}.bias_values"] = values
loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
torch.int16
)
del extract_a, extract_b, diff
elif decompose_mode == "full":
if "Norm" in layer:
w_key = "w_norm"
b_key = "b_norm"
else:
w_key = "diff"
b_key = "diff_b"
weight_diff = module.weight - weights.weight
loras[f"{lora_name}.{w_key}"] = (
weight_diff.detach().cpu().contiguous().half()
)
if getattr(weights, "bias", None) is not None:
bias_diff = module.bias - weights.bias
loras[f"{lora_name}.{b_key}"] = (
bias_diff.detach().cpu().contiguous().half()
)
else:
raise NotImplementedError
module = module.to("cpu")
weights = weights.to("cpu")
return loras
all_loras = {}
all_loras |= make_state_dict(
LORA_PREFIX_UNET,
base_unet,
db_unet,
UNET_TARGET_REPLACE_MODULE,
)
del base_unet, db_unet
if torch.cuda.is_available():
torch.cuda.empty_cache()
for idx, (te1, te2) in enumerate(zip(base_tes, db_tes)):
if len(base_tes) > 1:
prefix = f"{LORA_PREFIX_TEXT_ENCODER}{idx+1}"
else:
prefix = LORA_PREFIX_TEXT_ENCODER
all_loras |= make_state_dict(
prefix,
te1,
te2,
TEXT_ENCODER_TARGET_REPLACE_MODULE,
)
del te1, te2
    all_lora_name = set()
    for k in all_loras:
        lora_name, _ = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(f"{len(all_lora_name)} modules extracted")
return all_loras
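# Hedged call sketch (assumes matching base and fine-tuned diffusers
# pipelines are already loaded; `base_pipe` and `ft_pipe` are hypothetical):
#
#     loras = extract_diff(
#         [base_pipe.text_encoder], [ft_pipe.text_encoder],
#         base_pipe.unet, ft_pipe.unet,
#         mode="fixed", linear_mode_param=64, conv_mode_param=32,
#     )
#     safetensors.torch.save_file(loras, "extracted_lora.safetensors")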
re_digits = re.compile(r"\d+")
re_compiled = {}
suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"norm1": "in_layers_0",
"norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
},
}
def convert_diffusers_name_to_compvis(key):
    # Translate a diffusers-style LoRA key into the CompVis/SGM naming used
    # inside SD checkpoints. `match` caches compiled regexes and fills
    # `match_list` with the captured groups (digit groups converted to int).
    def match(match_list, regex_text):
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex
r = re.match(regex, key)
if not r:
return False
match_list.clear()
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
return True
m = []
if match(m, r"lora_unet_conv_in(.*)"):
return f"lora_unet_input_blocks_0_0{m[0]}"
if match(m, r"lora_unet_conv_out(.*)"):
return f"lora_unet_out_2{m[0]}"
if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
return f"lora_unet_time_embed_{m[0] * 2 - 2}{m[1]}"
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"lora_unet_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
return (
f"lora_unet_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
)
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"lora_unet_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
return f"lora_unet_input_blocks_{3 + m[0] * 3}_0_op"
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
return f"lora_unet_output_blocks_{2 + m[0] * 3}_2_conv"
return key
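# Example of the mapping above (SD1.x down block, attention sub-layer):
#
#     convert_diffusers_name_to_compvis(
#         "lora_unet_down_blocks_0_attentions_1_proj_in"
#     )  # -> "lora_unet_input_blocks_2_1_proj_in"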
@torch.no_grad()
def merge(tes, unet, lyco_state_dict, scale: float = 1.0, device="cpu"):
    """Merge a LyCORIS state dict into the given text encoder(s) and UNet,
    with the weight delta scaled by `scale`."""
from ..modules import make_module, get_module
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
merged = 0
def merge_state_dict(
prefix,
root_module: torch.nn.Module,
lyco_state_dict: Dict[str, torch.Tensor],
):
nonlocal merged
for child_name, child_module in tqdm(
list(root_module.named_modules()), desc=f"Merging {prefix}"
):
lora_name = prefix + "." + child_name
lora_name = lora_name.replace(".", "_")
lyco_type, params = get_module(lyco_state_dict, lora_name)
if lyco_type is None:
continue
module = make_module(lyco_type, params, lora_name, child_module)
if module is None:
continue
module.to(device)
module.merge_to(scale)
key_dict.pop(convert_diffusers_name_to_compvis(lora_name), None)
key_dict.pop(lora_name, None)
merged += 1
key_dict = {}
for k, v in tqdm(list(lyco_state_dict.items()), desc="Converting Dtype and Device"):
module, weight_key = k.split(".", 1)
convert_key = convert_diffusers_name_to_compvis(module)
        if convert_key != module and len(tes) > 1:
            # kohya's SDXL format uses SGM naming, not diffusers naming
del lyco_state_dict[k]
key_dict[convert_key] = key_dict.get(convert_key, []) + [k]
k = f"{convert_key}.{weight_key}"
else:
key_dict[module] = key_dict.get(module, []) + [k]
lyco_state_dict[k] = v.float().cpu()
for idx, te in enumerate(tes):
if len(tes) > 1:
prefix = LORA_PREFIX_TEXT_ENCODER + str(idx + 1)
else:
prefix = LORA_PREFIX_TEXT_ENCODER
merge_state_dict(
prefix,
te,
lyco_state_dict,
)
torch.cuda.empty_cache()
merge_state_dict(
LORA_PREFIX_UNET,
unet,
lyco_state_dict,
)
torch.cuda.empty_cache()
    print(f"Unused state dict keys: {key_dict}")
    print(f"{merged} modules merged")
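# Hedged call sketch (`pipe` is a hypothetical diffusers pipeline and
# "lyco.safetensors" a hypothetical LyCORIS file); merging mutates the
# modules in place:
#
#     lyco_sd = safetensors.torch.load_file("lyco.safetensors")
#     merge([pipe.text_encoder], pipe.unet, lyco_sd, scale=0.8)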