import math
import random
import torch
import torch.nn as nn
from .base import LycorisBaseModule
from ..utils import product


class DyLoraModule(LycorisBaseModule):
    """DyLoRA module.

    The LoRA rank is split into ``block_count`` blocks of ``block_size``.
    Each training step uses a randomly chosen number of blocks, with all but
    the last selected block detached, so the trained weights stay usable when
    truncated to any multiple of ``block_size``.
    """

    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        block_size=4,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        train_on_input=False,
        **kwargs,
    ):
        """If alpha is 0 or None, alpha falls back to lora_dim (no scaling)."""
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in DyLoRA algo.")
        assert lora_dim % block_size == 0, "lora_dim must be a multiple of block_size"
        self.block_count = lora_dim // block_size
        self.block_size = block_size
        shape = (
            self.shape[0],
            product(self.shape[1:]),
        )
        self.lora_dim = lora_dim
        # Each block holds `block_size` ranks; the full up/down weights are the
        # concatenation of all blocks.
        self.up_list = nn.ParameterList(
            [
                nn.Parameter(torch.empty(shape[0], self.block_size))
                for _ in range(self.block_count)
            ]
        )
        self.down_list = nn.ParameterList(
            [
                nn.Parameter(torch.empty(self.block_size, shape[1]))
                for _ in range(self.block_count)
            ]
        )

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # Need more experiments on the init method
        for v in self.down_list:
            torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
        for v in self.up_list:
            torch.nn.init.zeros_(v)
    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        # No-op override: the weights live in up_list/down_list, and the merged
        # lora_up/lora_down format produced by custom_state_dict() is not
        # loaded back here.
        return

    def custom_state_dict(self):
        destination = {}
        destination["alpha"] = self.alpha
        destination["lora_up.weight"] = nn.Parameter(
            torch.concat(list(self.up_list), dim=1)
        )
        destination["lora_down.weight"] = nn.Parameter(
            torch.concat(list(self.down_list)).reshape(
                self.lora_dim, -1, *self.shape[2:]
            )
        )
        return destination
    def get_weight(self, rank):
        b = math.ceil(rank / self.block_size)
        # Blocks below b contribute with gradients detached (.data); only the
        # b-th block is trained for this rank.
        down = torch.concat(
            list(i.data for i in self.down_list[:b]) + list(self.down_list[b : (b + 1)])
        )
        up = torch.concat(
            list(i.data for i in self.up_list[:b]) + list(self.up_list[b : (b + 1)]),
            dim=1,
        )
        return down, up, self.alpha / (b + 1)

    def get_random_rank_weight(self):
        b = random.randint(0, self.block_count - 1)
        return self.get_weight(b * self.block_size)
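
    # Illustrative example of the rank sampling above (added comment, not part
    # of the original code): with lora_dim=8 and block_size=4, block_count is 2
    # and get_random_rank_weight() samples b from {0, 1}.
    #   b=0 -> get_weight(0): only block 0 is used and trained, scale = alpha / 1
    #   b=1 -> get_weight(4): block 0 (detached) + block 1 (trained), scale = alpha / 2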
    def get_diff_weight(self, multiplier=1, shape=None, device=None, rank=None):
        if rank is None:
            down, up, scale = self.get_random_rank_weight()
        else:
            down, up, scale = self.get_weight(rank)
        w = up @ (down * (scale * multiplier))
        if device is not None:
            w = w.to(device)
        if shape is not None:
            w = w.view(shape)
        else:
            w = w.view(self.shape)
        return w, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None, rank=None):
        diff, _ = self.get_diff_weight(multiplier, shape, device, rank)
        return diff + self.org_weight, None
    def bypass_forward_diff(self, x, scale=1, rank=None):
        if rank is None:
            down, up, gamma = self.get_random_rank_weight()
        else:
            down, up, gamma = self.get_weight(rank)
        # The selected rank may be smaller than lora_dim, so reshape using the
        # actual number of rows/columns in the concatenated blocks.
        rank_dim = down.size(0)
        down = down.view(rank_dim, -1, *self.shape[2:])
        up = up.view(-1, rank_dim, *(1 for _ in self.shape[2:]))
        scale = scale * gamma
        return self.op(self.op(x, down, **self.kw_dict), up) * scale

    def bypass_forward(self, x, scale=1, rank=None):
        return self.org_forward(x) + self.bypass_forward_diff(x, scale, rank)
    def forward(self, x, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            weight = self.get_merged_weight(multiplier=self.multiplier)[0]
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
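
# A minimal usage sketch (added illustration, not part of the original module):
# it shows how DyLoraModule might wrap a layer, assuming the LycorisBaseModule
# wiring (org_forward / op / apply_to) used elsewhere in this repository.
# `base_linear` and the lora_name below are hypothetical.
#
#     base_linear = nn.Linear(320, 640)
#     dylora = DyLoraModule(
#         "lora_example",          # hypothetical module name
#         base_linear,
#         multiplier=1.0,
#         lora_dim=8,
#         alpha=4,
#         block_size=4,            # 8 // 4 = 2 rank blocks
#     )
#     dylora.apply_to()            # assumed base-class hook that swaps in forward()
#     y = dylora(torch.randn(2, 320))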