import torch
import torch.nn as nn

from .base import LycorisBaseModule


class IA3Module(LycorisBaseModule):
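    """Per-channel scaling module in the style of (IA)^3.

    Instead of a low-rank update, a single vector is trained and multiplied
    into the original weight, along the input channels when ``train_on_input``
    is set and along the output channels otherwise.
    """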
name = "ia3"
support_module = {
"linear",
"conv1d",
"conv2d",
"conv3d",
}
weight_list = ["weight", "on_input"]
weight_list_det = ["on_input"]
def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        train_on_input=False,
        **kwargs,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in IA^3 algo.")
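
        # One learned scale per channel: over the input channels when
        # train_on_input is set, otherwise over the output channels.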
        if self.module_type.startswith("conv"):
            self.isconv = True
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            if train_on_input:
                train_dim = in_dim
            else:
                train_dim = out_dim
            self.weight = nn.Parameter(
                torch.empty(1, train_dim, *(1 for _ in self.shape[2:]))
            )
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            if train_on_input:
                train_dim = in_dim
            else:
                train_dim = out_dim
            self.weight = nn.Parameter(torch.empty(train_dim))

        # The init method still needs experimentation; zeros make the module
        # start as an identity (the effective scale is 1 + weight).
        torch.nn.init.constant_(self.weight, 0)
        self.train_input = train_on_input
        self.register_buffer("on_input", torch.tensor(int(train_on_input)))

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, weight):
        module = cls(
            lora_name,
            orig_module,
            1,
        )
        module.weight.data.copy_(weight)
        return module

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward
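
    # The vector is stored as a residual around 1: the effective per-channel
    # scale is 1 + multiplier * weight, so diff=True returns only the delta
    # and diff=False returns the full merged weight.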
    def make_weight(self, multiplier=1, shape=None, device=None, diff=False):
        weight = self.weight * multiplier + int(not diff)
        if self.train_input:
            diff = self.org_weight * weight
        else:
            diff = self.org_weight.transpose(0, 1) * weight
            diff = diff.transpose(0, 1)
        if shape is not None:
            diff = diff.view(shape)
        if device is not None:
            diff = diff.to(device)
        return diff

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        diff = self.make_weight(
            multiplier=multiplier, shape=shape, device=device, diff=True
        )
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        diff = self.make_weight(multiplier=multiplier, shape=shape, device=device)
        return diff, None
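
    # Bypass path: instead of materializing a merged weight, scale the input
    # (train_on_input) or the output activations of the original forward.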
    def _bypass_forward(self, x, scale=1, diff=False):
        weight = self.weight * scale + int(not diff)
        if self.train_input:
            x = x * weight
        out = self.org_forward(x)
        if not self.train_input:
            out = out * weight
        return out

    def bypass_forward_diff(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=True)

    def bypass_forward(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=False)
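
    # Full forward: optionally skip the module during training (module_dropout),
    # then either take the bypass path or rebuild the merged weight and call the
    # original functional op.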
    def forward(self, x, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            weight = self.get_merged_weight(multiplier=self.multiplier)[0]
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
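

# Minimal usage sketch (illustrative only; assumes LycorisBaseModule wires up
# `org_module`, `org_weight`, `op` and `kw_dict` as in the rest of LyCORIS;
# the module name is made up):
#
#   linear = nn.Linear(320, 640)
#   ia3 = IA3Module("lora_unet_proj", linear, multiplier=1.0, train_on_input=False)
#   ia3.apply_to()                      # routes linear.forward through IA3Module
#   y = linear(torch.randn(2, 320))     # outputs rescaled by 1 + ia3.weight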