import logging
from typing import Optional

import torch
import comfy.model_management
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    pad_tensor_to_shape,
    tucker_weight_from_conv,
)


class LoraDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
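        # weights tuple: (up matrix, down matrix, alpha, optional mid/Tucker factor
        # used by conv LoRAs, optional DoRA scale, optional reshape spec)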
        mat1, mat2, alpha, mid, dora_scale, reshape = weights
        out_dim, rank = mat1.shape[0], mat1.shape[1]
        rank, in_dim = mat2.shape[0], mat2.shape[1]
        if mid is not None:
            convdim = mid.ndim - 2
            layer = (
                torch.nn.Conv1d,
                torch.nn.Conv2d,
                torch.nn.Conv3d
            )[convdim]
        else:
            layer = torch.nn.Linear
        self.lora_up = layer(rank, out_dim, bias=False)
        self.lora_down = layer(in_dim, rank, bias=False)
        self.lora_up.weight.data.copy_(mat1)
        self.lora_down.weight.data.copy_(mat2)
        if mid is not None:
            self.lora_mid = layer(mid, rank, bias=False)
            self.lora_mid.weight.data.copy_(mid)
        else:
            self.lora_mid = None
        self.rank = rank
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
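        # Apply the low-rank update: w + (alpha / rank) * (up @ down), using the
        # Tucker reconstruction when a mid weight is present.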
        org_dtype = w.dtype
        if self.lora_mid is None:
            diff = self.lora_up.weight @ self.lora_down.weight
        else:
            diff = tucker_weight_from_conv(
                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
            )
        scale = self.alpha / self.rank
        weight = w + scale * diff.reshape(w.shape)
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoRAAdapter(WeightAdapterBase):
    name = "lora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
        torch.nn.init.constant_(mat2, 0.0)
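        # mat2 (the down projection) starts at zero, so the initial delta
        # mat1 @ mat2 is zero and training begins from the unmodified base weight.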
        return LoraDiff(
            (mat1, mat2, alpha, None, None, None)
        )
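
    # A minimal usage sketch (hypothetical shapes, assuming a plain Linear weight):
    #
    #     base = torch.nn.Linear(768, 768)
    #     diff = LoRAAdapter.create_train(base.weight, rank=8, alpha=8.0)
    #     new_w = diff(base.weight)  # equals base.weight at init, since mat2 is zero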

    def to_train(self):
        return LoraDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["LoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
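
        # Several key layouts are recognized below: the regular lora_up/lora_down
        # names, diffusers up/down variants, PEFT-style lora_A/lora_B (including the
        # ".default" adapter naming), Mochi-style lora_A/lora_B without a .weight
        # suffix, and transformers lora_linear_layer keys.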
        reshape_name = "{}.reshape_weight".format(x)
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        qwen_default_lora = "{}.lora_B.default.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None
        elif qwen_default_lora in lora.keys():
            A_name = qwen_default_lora
            B_name = "{}.lora_A.default.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            reshape = None
            if reshape_name in lora.keys():
                try:
                    reshape = lora[reshape_name].tolist()
                    loaded_keys.add(reshape_name)
                except Exception:
                    pass
            weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
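        # Compute the low-rank delta in intermediate_dtype on weight's device,
        # scale it by strength and alpha, and apply it to the given weight tensor.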
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)
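
        # v[2] holds the raw LoRA alpha; dividing by the rank (mat2.shape[0]) gives
        # the conventional alpha / rank scaling.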
        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
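            # v[3] is the mid (Tucker) factor from conv LoRAs: fold it into mat2 so
            # that mat1 @ mat2 reshapes to the full conv weight delta.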
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )
        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
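            # With a DoRA scale the update is delegated to weight_decompose (from
            # .base); otherwise the scaled low-rank delta is added directly.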
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight