import re
import hashlib
from io import BytesIO
from typing import Dict, Tuple, Union

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg

import safetensors.torch

from tqdm import tqdm
from .general import *


def load_bytes_in_safetensors(tensors):
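    """Serialize `tensors` with safetensors and return only the raw tensor
    payload, skipping the JSON header (whose length is stored in the first
    8 bytes of the buffer)."""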
    bytes_data = safetensors.torch.save(tensors)
    b = BytesIO(bytes_data)

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")  # length of the JSON header

    offset = n + 8
    b.seek(offset)

    return b.read()


def precalculate_safetensors_hashes(state_dict):
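    """Return a model hash ("0x<sha256>") computed over the raw bytes of every
    tensor in `state_dict`, hashing one single-tensor safetensors payload per
    entry."""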
    hash_sha256 = hashlib.sha256()
    for tensor in state_dict.values():
        single_tensor_sd = {"tensor": tensor}
        bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
        hash_sha256.update(bytes_for_tensor)

    return f"0x{hash_sha256.hexdigest()}"


def str_bool(val):
    # Anything other than the literal string "false" (case-insensitive) counts as True.
    return str(val).lower() != "false"


def default(val, d):
    return val if val is not None else d


def make_sparse(t: torch.Tensor, sparsity=0.95):
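    """Zero every entry of `t` whose magnitude falls below the `sparsity`
    quantile of |t|, keeping roughly the largest (1 - sparsity) fraction of
    values."""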
    abs_t = torch.abs(t)
    np_array = abs_t.detach().cpu().numpy()
    quan = float(np.quantile(np_array, sparsity))
    sparse_t = t.masked_fill(abs_t < quan, 0)
    return sparse_t


def extract_conv(
    weight: Union[torch.Tensor, nn.Parameter],
    mode="fixed",
    mode_param=0,
    device="cpu",
    is_cp=False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], str]:
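    """SVD-decompose a Conv2d weight (flattened to 2D) into LoRA factors.

    Returns ``((down, up, residual), "low rank")`` when a useful rank is found,
    or ``(weight, "full")`` when no decomposition is requested or the selected
    rank would not save anything. The rank is chosen by `mode`: "fixed"
    (rank = mode_param), "threshold" (singular values above mode_param),
    "ratio" (relative to the largest singular value), or
    "quantile"/"percentile" (cumulative-sum fraction of singular values).
    """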
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    if mode == "full":
        return weight, "full"

    U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))

    if mode == "fixed":
        lora_rank = mode_param
    elif mode == "threshold":
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == "ratio":
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == "quantile" or mode == "percentile":
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio" or "quantile"'
        )
    lora_rank = int(max(1, lora_rank))
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2 and not is_cp:
        # A rank this large would not be smaller than storing the full difference.
        return weight, "full"

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S).to(device)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return (extract_weight_A, extract_weight_B, diff), "low rank"


def extract_linear(
    weight: Union[torch.Tensor, nn.Parameter],
    mode="fixed",
    mode_param=0,
    device="cpu",
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], str]:
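    """Linear counterpart of `extract_conv`: SVD-decompose a 2D weight into
    LoRA down/up factors with the same rank-selection modes, falling back to
    ``(weight, "full")`` when the selected rank is not worth storing."""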
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    if mode == "full":
        return weight, "full"

    U, S, Vh = linalg.svd(weight)

    if mode == "fixed":
        lora_rank = mode_param
    elif mode == "threshold":
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == "ratio":
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == "quantile" or mode == "percentile":
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio" or "quantile"'
        )
    lora_rank = int(max(1, lora_rank))
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        return weight, "full"

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S).to(device)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return (extract_weight_A, extract_weight_B, diff), "low rank"


@torch.no_grad()
def extract_diff(
    base_tes,
    db_tes,
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    small_conv=True,
):
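    """Extract a LoRA/LyCORIS state dict from the difference between a base
    model and a fine-tuned ("db") model.

    `base_tes`/`db_tes` are matching lists of text encoders and
    `base_unet`/`db_unet` the corresponding UNets. For every eligible layer,
    the weight difference is either stored in full or SVD-decomposed into
    `lora_down`/`lora_up` (plus `lora_mid` when `small_conv` splits 3x3
    convolutions), optionally with a sparse residual controlled by `use_bias`
    and `sparsity`.
    """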
    UNET_TARGET_REPLACE_MODULE = [
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = [
        "Embedding",
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
    ]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
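        # root_module holds the base weights, target_module the fine-tuned
        # ones; only modules whose class name is in target_replace_modules and
        # whose weights actually differ are extracted.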
        loras = {}
        temp = {}

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module

        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]
            lora_name = prefix + "." + name
            lora_name = lora_name.replace(".", "_")
            layer = module.__class__.__name__

            if layer in {
                "Linear",
                "Conv2d",
                "LayerNorm",
                "GroupNorm",
                "GroupNorm32",
                "Embedding",
            }:
                root_weight = module.weight
                if torch.allclose(root_weight, weights.weight):
                    continue
            else:
                continue
            module = module.to(extract_device)
            weights = weights.to(extract_device)

if mode == "full": |
|
decompose_mode = "full" |
|
elif layer == "Linear": |
|
weight, decompose_mode = extract_linear( |
|
(root_weight - weights.weight), |
|
mode, |
|
linear_mode_param, |
|
device=extract_device, |
|
) |
|
if decompose_mode == "low rank": |
|
extract_a, extract_b, diff = weight |
|
elif layer == "Conv2d": |
|
is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1 |
|
weight, decompose_mode = extract_conv( |
|
(root_weight - weights.weight), |
|
mode, |
|
linear_mode_param if is_linear else conv_mode_param, |
|
device=extract_device, |
|
) |
|
if decompose_mode == "low rank": |
|
extract_a, extract_b, diff = weight |
|
if small_conv and not is_linear and decompose_mode == "low rank": |
|
dim = extract_a.size(0) |
|
(extract_c, extract_a, _), _ = extract_conv( |
|
extract_a.transpose(0, 1), |
|
"fixed", |
|
dim, |
|
extract_device, |
|
True, |
|
) |
|
extract_a = extract_a.transpose(0, 1) |
|
extract_c = extract_c.transpose(0, 1) |
|
loras[f"{lora_name}.lora_mid.weight"] = ( |
|
extract_c.detach().cpu().contiguous().half() |
|
) |
|
diff = ( |
|
( |
|
root_weight |
|
- torch.einsum( |
|
"i j k l, j r, p i -> p r k l", |
|
extract_c, |
|
extract_a.flatten(1, -1), |
|
extract_b.flatten(1, -1), |
|
) |
|
) |
|
.detach() |
|
.cpu() |
|
.contiguous() |
|
) |
|
del extract_c |
|
else: |
|
module = module.to("cpu") |
|
weights = weights.to("cpu") |
|
continue |
|
|
|
            if decompose_mode == "low rank":
                loras[f"{lora_name}.lora_down.weight"] = (
                    extract_a.detach().cpu().contiguous().half()
                )
                loras[f"{lora_name}.lora_up.weight"] = (
                    extract_b.detach().cpu().contiguous().half()
                )
                loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
                if use_bias:
                    # Store the residual missed by the low-rank factors as a
                    # sparse "bias" term (top values of the residual only).
                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                    indices = sparse_diff.indices().to(torch.int16)
                    values = sparse_diff.values().half()
                    loras[f"{lora_name}.bias_indices"] = indices
                    loras[f"{lora_name}.bias_values"] = values
                    loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
                        torch.int16
                    )
                del extract_a, extract_b, diff
            elif decompose_mode == "full":
                if "Norm" in layer:
                    w_key = "w_norm"
                    b_key = "b_norm"
                else:
                    w_key = "diff"
                    b_key = "diff_b"
                weight_diff = module.weight - weights.weight
                loras[f"{lora_name}.{w_key}"] = (
                    weight_diff.detach().cpu().contiguous().half()
                )
                if getattr(weights, "bias", None) is not None:
                    bias_diff = module.bias - weights.bias
                    loras[f"{lora_name}.{b_key}"] = (
                        bias_diff.detach().cpu().contiguous().half()
                    )
            else:
                raise NotImplementedError
            module = module.to("cpu")
            weights = weights.to("cpu")
        return loras

    all_loras = {}

    all_loras |= make_state_dict(
        LORA_PREFIX_UNET,
        base_unet,
        db_unet,
        UNET_TARGET_REPLACE_MODULE,
    )
    del base_unet, db_unet
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    for idx, (te1, te2) in enumerate(zip(base_tes, db_tes)):
        if len(base_tes) > 1:
            prefix = f"{LORA_PREFIX_TEXT_ENCODER}{idx+1}"
        else:
            prefix = LORA_PREFIX_TEXT_ENCODER
        all_loras |= make_state_dict(
            prefix,
            te1,
            te2,
            TEXT_ENCODER_TARGET_REPLACE_MODULE,
        )
        del te1, te2

    all_lora_name = set()
    for k in all_loras:
        lora_name, _ = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(f"Extracted {len(all_lora_name)} modules")
    return all_loras


re_digits = re.compile(r"\d+")
re_compiled = {}

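# Map diffusers-style resnet submodule names to their LDM/CompVis
# counterparts; attention submodule names are left unchanged.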
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    },
}


def convert_diffusers_name_to_compvis(key):
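    """Translate a diffusers-style `lora_unet_*` key into the corresponding
    LDM/CompVis naming; keys that match no known pattern are returned
    unchanged."""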
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f"lora_unet_input_blocks_0_0{m[0]}"

    if match(m, r"lora_unet_conv_out(.*)"):
        return f"lora_unet_out_2{m[0]}"

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"lora_unet_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"lora_unet_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return (
            f"lora_unet_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
        )

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"lora_unet_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"lora_unet_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"lora_unet_output_blocks_{2 + m[0] * 3}_2_conv"
    return key


@torch.no_grad()
def merge(tes, unet, lyco_state_dict, scale: float = 1.0, device="cpu"):
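    """Merge `lyco_state_dict` into the given text encoders (`tes`) and `unet`
    in place, scaling each module's contribution by `scale`.

    When more than one text encoder is passed, diffusers-style keys are first
    renamed to their CompVis equivalents; each matching module is then rebuilt
    via `get_module`/`make_module` and merged with `merge_to`, and any unused
    keys are reported at the end."""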
    from ..modules import make_module, get_module

    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    merged = 0

    def merge_state_dict(
        prefix,
        root_module: torch.nn.Module,
        lyco_state_dict: Dict[str, torch.Tensor],
    ):
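        # Walk every submodule of root_module, rebuild the matching LyCORIS
        # module from the state dict (if any), and merge it into the weights.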
        nonlocal merged
        for child_name, child_module in tqdm(
            list(root_module.named_modules()), desc=f"Merging {prefix}"
        ):
            lora_name = prefix + "." + child_name
            lora_name = lora_name.replace(".", "_")

            lyco_type, params = get_module(lyco_state_dict, lora_name)
            if lyco_type is None:
                continue
            module = make_module(lyco_type, params, lora_name, child_module)
            if module is None:
                continue
            module.to(device)
            module.merge_to(scale)
            key_dict.pop(convert_diffusers_name_to_compvis(lora_name), None)
            key_dict.pop(lora_name, None)
            merged += 1

    key_dict = {}
    for k, v in tqdm(list(lyco_state_dict.items()), desc="Converting Dtype and Device"):
        module, weight_key = k.split(".", 1)
        convert_key = convert_diffusers_name_to_compvis(module)
        if convert_key != module and len(tes) > 1:
            # Rename diffusers-style keys to CompVis naming when the model has
            # more than one text encoder.
            del lyco_state_dict[k]
            key_dict[convert_key] = key_dict.get(convert_key, []) + [k]
            k = f"{convert_key}.{weight_key}"
        else:
            key_dict[module] = key_dict.get(module, []) + [k]
        lyco_state_dict[k] = v.float().cpu()

    for idx, te in enumerate(tes):
        if len(tes) > 1:
            prefix = LORA_PREFIX_TEXT_ENCODER + str(idx + 1)
        else:
            prefix = LORA_PREFIX_TEXT_ENCODER
        merge_state_dict(
            prefix,
            te,
            lyco_state_dict,
        )
        torch.cuda.empty_cache()
    merge_state_dict(
        LORA_PREFIX_UNET,
        unet,
        lyco_state_dict,
    )
    torch.cuda.empty_cache()
    print(f"Unused state dict keys: {key_dict}")
    print(f"{merged} modules merged")