from functools import cache

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..logging import logger


@cache
def log_bypass_override():
    return logger.warning(
        "Automatic bypass_mode detected in algo=full, "
        "overriding with bypass_mode=False since algo=full does not support bypass mode. "
        "If you are using a quantized model that requires bypass mode, please don't use algo=full."
    )
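

# FullModule implements the "full" algorithm: instead of a low-rank adapter it
# stores a full-rank difference ("diff", plus "diff_b" for the bias) of the
# original weights, initialized to zero. apply_to() merges the original
# weights into these parameters and takes over the wrapped module's forward.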
class FullModule(LycorisBaseModule):
    name = "full"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    weight_list = ["diff", "diff_b"]
    weight_list_det = ["diff"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        bypass_mode=None,
        **kwargs,
    ):
        org_bypass = bypass_mode
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        # If the base class auto-enabled bypass mode but the user did not ask
        # for it explicitly, force it off: algo=full cannot run in bypass mode.
        if self.bypass_mode and org_bypass is None:
            self.bypass_mode = False
            log_bypass_override()
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in Full algo.")
        if self.is_quant:
            raise ValueError(
                "Quant Linear is not supported and meaningless in Full algo."
            )
        if self.bypass_mode:
            raise ValueError("bypass mode is not supported in Full algo.")
        # The trainable parameters start as zero-initialized diffs of the
        # original weight/bias; apply_to() later merges the originals in.
        self.weight = nn.Parameter(torch.zeros_like(org_module.weight))
        if org_module.bias is not None:
            self.bias = nn.Parameter(torch.zeros_like(org_module.bias))
        else:
            self.bias = None
        self.is_diff = True
        # Keep CPU copies of the original weight/bias so the diff can be
        # recovered after apply_to() removes them from the wrapped module.
        self._org_weight = [self.org_module[0].weight.data.cpu().clone()]
        if self.org_module[0].bias is not None:
            self.org_bias = [self.org_module[0].bias.data.cpu().clone()]
        else:
            self.org_bias = None
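
    # Rebuilds a FullModule from tensors saved in a LyCORIS state dict
    # ("diff" and optional "diff_b"), leaving it in diff form (is_diff=True).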
    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, diff, diff_b):
        module = cls(
            lora_name,
            orig_module,
            1,
        )
        module.weight.copy_(diff)
        if diff_b is not None:
            if orig_module.bias is not None:
                module.bias.copy_(diff_b)
            else:
                module.bias = nn.Parameter(diff_b)
        module.is_diff = True
        return module

    @property
    def org_weight(self):
        return self._org_weight[0]

    @org_weight.setter
    def org_weight(self, value):
        self.org_module[0].weight.data.copy_(value)

    def apply_to(self, **kwargs):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward
        # Merge the original weights into the trainable parameters and remove
        # them from the wrapped module; from here on self.weight/self.bias
        # hold absolute values rather than diffs.
        self.weight.data.add_(self.org_module[0].weight.data)
        self._org_weight = [self.org_module[0].weight.data.cpu().clone()]
        delattr(self.org_module[0], "weight")
        if self.org_module[0].bias is not None:
            self.bias.data.add_(self.org_module[0].bias.data)
            self.org_bias = [self.org_module[0].bias.data.cpu().clone()]
            delattr(self.org_module[0], "bias")
        else:
            self.org_bias = None
        self.is_diff = False

    def restore(self):
        self.org_module[0].forward = self.org_forward
        self.org_module[0].weight = nn.Parameter(self._org_weight[0])
        if self.org_bias is not None:
            self.org_module[0].bias = nn.Parameter(self.org_bias[0])

    def custom_state_dict(self):
        sd = {"diff": self.weight.data.cpu() - self._org_weight[0]}
        if self.bias is not None:
            sd["diff_b"] = self.bias.data.cpu() - self.org_bias[0]
        return sd
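
    # When a state dict is loaded, the saved "diff"/"diff_b" entries are
    # converted back into absolute "weight"/"bias" tensors by adding them to
    # the module's current values.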
    def load_weight_prehook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        diff_weight = state_dict.pop(f"{prefix}diff")
        state_dict[f"{prefix}weight"] = diff_weight + self.weight.data.to(diff_weight)
        if f"{prefix}diff_b" in state_dict:
            diff_bias = state_dict.pop(f"{prefix}diff_b")
            state_dict[f"{prefix}bias"] = diff_bias + self.bias.data.to(diff_bias)

    def make_weight(self, scale=1, device=None):
        drop = (
            torch.rand(self.dim, device=device) > self.rank_dropout
            if self.rank_dropout and self.training
            else 1
        )
        if drop != 1 or scale != 1 or self.is_diff:
            diff_w, diff_b = self.get_diff_weight(scale, device=device)
            weight = self.org_weight + diff_w * drop
            if self.org_bias is not None:
                bias = self.org_bias + diff_b * drop
            else:
                bias = None
        else:
            weight = self.weight
            bias = self.bias
        return weight, bias

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        if self.is_diff:
            diff_b = None
            if self.bias is not None:
                diff_b = self.bias * multiplier
            return self.weight * multiplier, diff_b
        org_weight = self.org_module[0].weight.to(device, dtype=self.weight.dtype)
        diff = self.weight.to(device) - org_weight
        diff_b = None
        if shape:
            diff = diff.view(shape)
        if self.bias is not None:
            org_bias = self.org_module[0].bias.to(device, dtype=self.bias.dtype)
            diff_b = self.bias.to(device) - org_bias
        if device is not None:
            diff = diff.to(device)
            if self.bias is not None:
                diff_b = diff_b.to(device)
        if multiplier != 1:
            diff = diff * multiplier
            if diff_b is not None:
                diff_b = diff_b * multiplier
        return diff, diff_b

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        weight, bias = self.make_weight(multiplier, device)
        if shape is not None:
            weight = weight.view(shape)
            if bias is not None:
                bias = bias.view(shape[0])
        return weight, bias

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Module dropout: during training, occasionally fall back to the
        # original forward pass.
        if (
            self.module_dropout
            and self.training
            and torch.rand(1) < self.module_dropout
        ):
            return self.org_forward(x)
        scale = self.multiplier
        weight, bias = self.make_weight(scale, x.device)
        kw_dict = self.kw_dict | {"weight": weight, "bias": bias}
        return self.op(x, **kw_dict)
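

# Minimal usage sketch (commented out; assumes the surrounding lycoris package
# provides the LycorisBaseModule plumbing such as self.op and self.kw_dict;
# the layer sizes and module name below are illustrative only):
#
#   layer = nn.Linear(8, 8)
#   full = FullModule("lycoris_demo_linear", layer, multiplier=1.0)
#   full.apply_to()                  # merge original weights, hook layer.forward
#   y = layer(torch.randn(2, 8))     # now routed through FullModule.forward
#   diff = full.custom_state_dict()["diff"]  # full-rank weight difference
#   full.restore()                   # put the original weight/bias back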