File size: 7,011 Bytes
cc69848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
from functools import cache
import torch
import torch.nn as nn
from .base import LycorisBaseModule
from ..logging import logger
@cache
def log_bypass_override():
    """Warn that an auto-detected bypass mode is being overridden for algo=full.

    The function takes no arguments and is wrapped in ``functools.cache``, so
    the warning is emitted on the first call only; later calls return the
    cached result of that first ``logger.warning`` call.
    """
    message = (
        "Automatic Bypass-Mode detected in algo=full, "
        "override with bypass_mode=False since algo=full not support bypass mode. "
        "If you are using quantized model which require bypass mode, please don't use algo=full. "
    )
    return logger.warning(message)
class FullModule(LycorisBaseModule):
    """Full-weight ("full" algo) module: learns a complete weight/bias update.

    The module owns ``weight`` (and optionally ``bias``) parameters shaped like
    those of the wrapped layer. ``is_diff`` selects how they are interpreted:

    * ``is_diff=True`` (state after ``__init__``): the parameters hold the
      *difference* from the original tensors.
    * ``is_diff=False`` (state after ``apply_to``): the parameters hold the
      *merged* values, i.e. original + difference.

    CPU snapshots of the original tensors are kept in ``_org_weight`` and
    ``org_bias``; each is a one-element list, and ``org_bias`` is ``None`` when
    the wrapped layer has no bias.

    Attributes such as ``org_module``, ``module_type``, ``is_quant``,
    ``bypass_mode``, ``multiplier``, ``rank_dropout``, ``module_dropout``,
    ``op`` and ``kw_dict`` are read here but not defined in this class; they
    are presumably set up by ``LycorisBaseModule.__init__`` (TODO confirm).
    """

    name = "full"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    weight_list = ["diff", "diff_b"]
    weight_list_det = ["diff"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        bypass_mode=None,
        **kwargs,
    ):
        """Create zero-initialised diff parameters for ``org_module``.

        ``lora_dim``, ``alpha``, ``use_tucker``, ``use_scalar`` and ``**kwargs``
        are accepted but not used by this class.

        Raises:
            ValueError: if the module type is not in ``support_module``, if the
                wrapped layer is quantized, or if bypass mode was explicitly
                requested.
        """
        requested_bypass = bypass_mode
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        # Bypass mode is on although the caller did not ask for it: it was
        # enabled automatically, so switch it off and warn instead of failing.
        # (The check must look at self.bypass_mode; the local argument is the
        # caller's value and can never be both truthy and None.)
        if self.bypass_mode and requested_bypass is None:
            self.bypass_mode = False
            log_bypass_override()

        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in Full algo.")
        if self.is_quant:
            raise ValueError(
                "Quant Linear is not supported and meaningless in Full algo."
            )
        if self.bypass_mode:
            raise ValueError("bypass mode is not supported in Full algo.")

        # Start from a zero difference, i.e. no change to the wrapped layer.
        self.weight = nn.Parameter(torch.zeros_like(org_module.weight))
        if org_module.bias is not None:
            self.bias = nn.Parameter(torch.zeros_like(org_module.bias))
        else:
            self.bias = None
        self.is_diff = True

        # CPU snapshots of the original tensors, each wrapped in a list.
        self._org_weight = [self.org_module[0].weight.data.cpu().clone()]
        if self.org_module[0].bias is not None:
            self.org_bias = [self.org_module[0].bias.data.cpu().clone()]
        else:
            self.org_bias = None

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, diff, diff_b):
        """Build a module for ``orig_module`` pre-loaded with ``diff``/``diff_b``.

        The returned module is in diff state (``is_diff=True``). ``diff_b`` may
        be ``None``; if the wrapped layer has no bias but ``diff_b`` is given, a
        new bias parameter is created from it.
        """
        module = cls(
            lora_name,
            orig_module,
            1,
        )
        # In-place copies into leaf parameters that require grad are rejected
        # by autograd unless grad mode is disabled.
        with torch.no_grad():
            module.weight.copy_(diff)
            if diff_b is not None and orig_module.bias is not None:
                module.bias.copy_(diff_b)
        if diff_b is not None and orig_module.bias is None:
            module.bias = nn.Parameter(diff_b)
        module.is_diff = True
        return module

    @property
    def org_weight(self):
        """CPU snapshot of the wrapped layer's original weight."""
        return self._org_weight[0]

    @org_weight.setter
    def org_weight(self, value):
        # NOTE(review): this writes into the wrapped layer's live weight, not
        # into the snapshot returned by the getter, and the live weight no
        # longer exists after apply_to() -- confirm the intended semantics.
        self.org_module[0].weight.data.copy_(value)

    def apply_to(self, **kwargs):
        """Hook this module into the wrapped layer and switch to merged state.

        Replaces the wrapped layer's ``forward`` with ``self.forward``, adds
        the original weight/bias into the trainable parameters, refreshes the
        CPU snapshots, and removes ``weight``/``bias`` from the wrapped layer.
        """
        wrapped = self.org_module[0]
        self.org_forward = wrapped.forward
        wrapped.forward = self.forward

        self.weight.data.add_(wrapped.weight.data)
        self._org_weight = [wrapped.weight.data.cpu().clone()]
        delattr(wrapped, "weight")

        if wrapped.bias is not None:
            self.bias.data.add_(wrapped.bias.data)
            self.org_bias = [wrapped.bias.data.cpu().clone()]
            delattr(wrapped, "bias")
        else:
            self.org_bias = None
        self.is_diff = False

    def restore(self):
        """Undo ``apply_to``: restore the wrapped layer's forward and tensors."""
        wrapped = self.org_module[0]
        wrapped.forward = self.org_forward
        # The snapshots live on CPU; put the restored parameters back on the
        # device this module's weight is on so the layer stays usable there.
        target = self.weight.device
        wrapped.weight = nn.Parameter(self._org_weight[0].to(target))
        if self.org_bias is not None:
            wrapped.bias = nn.Parameter(self.org_bias[0].to(target))

    def custom_state_dict(self):
        """Return ``{"diff": ..., "diff_b": ...}`` as CPU tensors.

        ``diff_b`` is present only when this module has a bias. In diff state
        the parameters are exported as they are; in merged state the original
        snapshot is subtracted first.
        """
        weight = self.weight.data.cpu()
        if self.is_diff:
            sd = {"diff": weight.clone()}
        else:
            sd = {"diff": weight - self._org_weight[0]}

        if self.bias is not None:
            bias = self.bias.data.cpu()
            if self.is_diff or self.org_bias is None:
                # Already a difference, or there is nothing to subtract.
                sd["diff_b"] = bias.clone()
            else:
                sd["diff_b"] = bias - self.org_bias[0]
        return sd

    def load_weight_prehook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Rewrite ``diff``/``diff_b`` entries into ``weight``/``bias`` entries.

        Each incoming difference is added to the current parameter value and
        stored under the parameter's own key. The signature has the shape of a
        ``load_state_dict`` pre-hook; where it is registered is not visible
        here.

        Raises:
            KeyError: if ``f"{prefix}diff"`` is missing from ``state_dict``.
        """
        diff_weight = state_dict.pop(f"{prefix}diff")
        state_dict[f"{prefix}weight"] = diff_weight + self.weight.data.to(diff_weight)
        if f"{prefix}diff_b" in state_dict:
            diff_bias = state_dict.pop(f"{prefix}diff_b")
            state_dict[f"{prefix}bias"] = diff_bias + self.bias.data.to(diff_bias)

    def make_weight(self, scale=1, device=None):
        """Return the effective ``(weight, bias)`` for the forward op.

        Fast path: in merged state with ``scale == 1`` and no active rank
        dropout the parameters are returned unchanged. Otherwise the result is
        rebuilt as ``original + scale * diff``, with rank dropout applied to
        the diff.
        """
        use_rank_dropout = bool(self.rank_dropout) and self.training
        if not use_rank_dropout and scale == 1 and not self.is_diff:
            return self.weight, self.bias

        diff_w, diff_b = self.get_diff_weight(scale, device=device)

        if use_rank_dropout:
            # One keep/drop decision per slice along the leading axis, which is
            # the output-feature/channel axis for linear and conv weights.
            keep = (
                torch.rand(diff_w.shape[0], device=diff_w.device)
                > self.rank_dropout
            )
            diff_w = diff_w * keep.view(-1, *([1] * (diff_w.dim() - 1)))
            if diff_b is not None:
                diff_b = diff_b * keep

        # The snapshot is on CPU; match the diff's device and dtype.
        weight = self.org_weight.to(diff_w) + diff_w

        org_bias = None if self.org_bias is None else self.org_bias[0]
        if org_bias is None:
            # No original bias: the learned bias (if any) is the whole bias,
            # matching what the fast path returns.
            bias = diff_b
        elif diff_b is None:
            bias = org_bias.to(weight.device)
        else:
            bias = org_bias.to(diff_b) + diff_b
        return weight, bias

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        """Return ``(diff_weight, diff_bias)`` scaled by ``multiplier``.

        Args:
            multiplier: scalar applied once to both differences.
            shape: if truthy, the weight difference is viewed to this shape.
            device: if not ``None``, both results are moved to this device.

        ``diff_bias`` is ``None`` when this module has no bias.
        """
        if self.is_diff:
            diff = self.weight
            diff_b = self.bias
        else:
            # Merged state: subtract the stored originals. The wrapped layer's
            # own weight/bias were removed by apply_to() and cannot be read.
            diff = self.weight - self.org_weight.to(self.weight)
            diff_b = self.bias
            if diff_b is not None and self.org_bias is not None:
                diff_b = diff_b - self.org_bias[0].to(diff_b)

        if device is not None:
            diff = diff.to(device)
            if diff_b is not None:
                diff_b = diff_b.to(device)
        if shape:
            diff = diff.view(shape)

        # Scale exactly once.
        diff = diff * multiplier
        if diff_b is not None:
            diff_b = diff_b * multiplier
        return diff, diff_b

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        """Return the effective ``(weight, bias)``, optionally viewed to ``shape``.

        When ``shape`` is given the weight is viewed to it and the bias (if
        any) to ``shape[0]``.
        """
        weight, bias = self.make_weight(multiplier, device)
        if shape is not None:
            weight = weight.view(shape)
            if bias is not None:
                bias = bias.view(shape[0])
        return weight, bias

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the layer op on ``x`` with the effective weight and bias.

        With module dropout active (training only) the call falls back, with
        probability ``module_dropout``, to the original tensors. Extra
        positional/keyword arguments are accepted and ignored.
        """
        if (
            self.module_dropout
            and self.training
            and torch.rand(1) < self.module_dropout
        ):
            # The wrapped layer's own forward cannot be used here: apply_to()
            # removed its weight/bias, so evaluate the op on the snapshots.
            weight = self.org_weight.to(x.device)
            bias = None
            if self.org_bias is not None:
                bias = self.org_bias[0].to(x.device)
        else:
            weight, bias = self.make_weight(self.multiplier, x.device)

        kw_dict = self.kw_dict | {"weight": weight, "bias": bias}
        return self.op(x, **kw_dict)
|