from functools import cache

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..logging import logger
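

# Cached so the warning is emitted only once per process, no matter how many
# modules trigger the override.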
@cache
def log_bypass_override():
    return logger.warning(
        "Automatic bypass mode detected in algo=full, "
        "overriding with bypass_mode=False since algo=full does not support bypass mode. "
        "If you are using a quantized model which requires bypass mode, please don't use algo=full."
    )
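

# Full-weight "diff" algorithm: the module owns a complete trainable copy of the
# wrapped layer's weight and bias, and checkpoints store their difference from
# the original values ("diff" / "diff_b").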
class FullModule(LycorisBaseModule):
    name = "full"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    weight_list = ["diff", "diff_b"]
    weight_list_det = ["diff"]
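
    # lora_dim, alpha, use_tucker and use_scalar are accepted for signature
    # compatibility with the other algorithms but are not used by algo=full.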
    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        bypass_mode=None,
        **kwargs,
    ):
        org_bypass = bypass_mode
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
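
        # If bypass mode was enabled automatically (the caller passed
        # bypass_mode=None), turn it back off and warn once: algo=full cannot
        # run in bypass mode.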
        if self.bypass_mode and org_bypass is None:
            self.bypass_mode = False
            log_bypass_override()

        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in Full algo.")

        if self.is_quant:
            raise ValueError(
                "Quant Linear is not supported and meaningless in Full algo."
            )

        if self.bypass_mode:
            raise ValueError("bypass mode is not supported in Full algo.")
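
        # Trainable parameters start as zero tensors shaped like the original
        # weight/bias; the original values are cached on CPU for later diffs.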
        self.weight = nn.Parameter(torch.zeros_like(org_module.weight))
        if org_module.bias is not None:
            self.bias = nn.Parameter(torch.zeros_like(org_module.bias))
        else:
            self.bias = None
        self.is_diff = True
        self._org_weight = [self.org_module[0].weight.data.cpu().clone()]
        if self.org_module[0].bias is not None:
            self.org_bias = [self.org_module[0].bias.data.cpu().clone()]
        else:
            self.org_bias = None
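
    # Rebuild a module from checkpoint tensors: weight/bias hold the stored diffs
    # directly and the module stays in diff mode (is_diff=True).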
    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, diff, diff_b):
        module = cls(
            lora_name,
            orig_module,
            1,
        )
        module.weight.copy_(diff)
        if diff_b is not None:
            if orig_module.bias is not None:
                module.bias.copy_(diff_b)
            else:
                module.bias = nn.Parameter(diff_b)
        module.is_diff = True
        return module
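
    # Note the asymmetry: the getter returns the cached CPU copy of the original
    # weight, while the setter writes into the wrapped module's live weight.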
    @property
    def org_weight(self):
        return self._org_weight[0]

    @org_weight.setter
    def org_weight(self, value):
        self.org_module[0].weight.data.copy_(value)
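
    # Take over the wrapped module: fold its weight/bias into this module's
    # parameters, cache the originals, strip them from the wrapped module, and
    # route its forward through FullModule.forward.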
    def apply_to(self, **kwargs):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward
        self.weight.data.add_(self.org_module[0].weight.data)
        self._org_weight = [self.org_module[0].weight.data.cpu().clone()]
        delattr(self.org_module[0], "weight")
        if self.org_module[0].bias is not None:
            self.bias.data.add_(self.org_module[0].bias.data)
            self.org_bias = [self.org_module[0].bias.data.cpu().clone()]
            delattr(self.org_module[0], "bias")
        else:
            self.org_bias = None
        self.is_diff = False
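
    # Undo apply_to: restore the original forward and re-create the wrapped
    # module's weight/bias from the cached copies.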
    def restore(self):
        self.org_module[0].forward = self.org_forward
        self.org_module[0].weight = nn.Parameter(self._org_weight[0])
        if self.org_bias is not None:
            self.org_module[0].bias = nn.Parameter(self.org_bias[0])
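
    # Checkpoints store only the difference from the original layer, matching
    # weight_list = ["diff", "diff_b"].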
    def custom_state_dict(self):
        sd = {"diff": self.weight.data.cpu() - self._org_weight[0]}
        if self.bias is not None:
            sd["diff_b"] = self.bias.data.cpu() - self.org_bias[0]
        return sd
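
    # When loading a state dict, convert the stored "diff"/"diff_b" entries into
    # "weight"/"bias" by adding them onto the current parameter values.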
    def load_weight_prehook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        diff_weight = state_dict.pop(f"{prefix}diff")
        state_dict[f"{prefix}weight"] = diff_weight + self.weight.data.to(diff_weight)
        if f"{prefix}diff_b" in state_dict:
            diff_bias = state_dict.pop(f"{prefix}diff_b")
            state_dict[f"{prefix}bias"] = diff_bias + self.bias.data.to(diff_bias)
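
    # Produce the effective weight/bias: original values plus the scaled diff,
    # with rank dropout applied to the diff during training.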
    def make_weight(self, scale=1, device=None):
        drop = (
            torch.rand(self.dim, device=device) > self.rank_dropout
            if self.rank_dropout and self.training
            else 1
        )
        # rank dropout yields a mask tensor; a mask, a non-unit scale, or diff
        # state all require recomputing the weight from original + diff
        if isinstance(drop, torch.Tensor) or scale != 1 or self.is_diff:
            diff_w, diff_b = self.get_diff_weight(scale, device=device)
            weight = self.org_weight + diff_w * drop
            if self.org_bias is not None:
                bias = self.org_bias[0] + diff_b * drop
            else:
                bias = None
        else:
            weight = self.weight
            bias = self.bias
        return weight, bias
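
    # Before apply_to the parameters already are the diff; afterwards the diff is
    # recomputed against the cached original weights.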
    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        if self.is_diff:
            diff_b = None
            if self.bias is not None:
                diff_b = self.bias * multiplier
            return self.weight * multiplier, diff_b
        org_weight = self.org_weight.to(device, dtype=self.weight.dtype)
        diff = self.weight.to(device) - org_weight
        diff_b = None
        if shape:
            diff = diff.view(shape)
        if self.bias is not None:
            org_bias = self.org_bias[0].to(device, dtype=self.bias.dtype)
            diff_b = self.bias.to(device) - org_bias
        if device is not None:
            diff = diff.to(device)
            if self.bias is not None:
                diff_b = diff_b.to(device)
        if multiplier != 1:
            diff = diff * multiplier
            if diff_b is not None:
                diff_b = diff_b * multiplier
        return diff, diff_b
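
    # Merged weight = original weight plus the scaled diff, optionally reshaped
    # to `shape` (bias reshaped to shape[0]).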
    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        weight, bias = self.make_weight(multiplier, device)
        if shape is not None:
            weight = weight.view(shape)
            if bias is not None:
                bias = bias.view(shape[0])
        return weight, bias
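
    # module_dropout occasionally skips the trained weights during training and
    # falls back to the original forward; otherwise the merged weight/bias are
    # passed to the functional op prepared by the base module.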
    def forward(self, x: torch.Tensor, *args, **kwargs):
        if (
            self.module_dropout
            and self.training
            and torch.rand(1) < self.module_dropout
        ):
            return self.org_forward(x)
        scale = self.multiplier
        weight, bias = self.make_weight(scale, x.device)
        kw_dict = self.kw_dict | {"weight": weight, "bias": bias}
        return self.op(x, **kw_dict)