import math
import random
import torch
import torch.nn as nn
from .base import LycorisBaseModule
from ..utils import product
class DyLoraModule(LycorisBaseModule):
support_module = {
"linear",
"conv1d",
"conv2d",
"conv3d",
}
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.0,
rank_dropout=0.0,
module_dropout=0.0,
use_tucker=False,
block_size=4,
use_scalar=False,
rank_dropout_scale=False,
weight_decompose=False,
bypass_mode=None,
rs_lora=False,
train_on_input=False,
**kwargs,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__(
lora_name,
org_module,
multiplier,
dropout,
rank_dropout,
module_dropout,
rank_dropout_scale,
bypass_mode,
)
if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in DyLoRA algo.")
assert lora_dim % block_size == 0, "lora_dim must be a multiple of block_size"
self.block_count = lora_dim // block_size
self.block_size = block_size
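        # Flatten the original weight shape to (out_channels, everything else)
        # so each DyLoRA block is a plain 2D matrix.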
shape = (
self.shape[0],
product(self.shape[1:]),
)
self.lora_dim = lora_dim
        # Each block is stored as a trainable nn.Parameter; nn.ParameterList only
        # registers Parameter entries as learnable weights.
        self.up_list = nn.ParameterList(
            [
                nn.Parameter(torch.empty(shape[0], self.block_size))
                for _ in range(self.block_count)
            ]
        )
        self.down_list = nn.ParameterList(
            [
                nn.Parameter(torch.empty(self.block_size, shape[1]))
                for _ in range(self.block_count)
            ]
        )
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # treated as a constant
        # The init scheme may still need more experimentation.
for v in self.down_list:
torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
for v in self.up_list:
torch.nn.init.zeros_(v)
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
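        # No-op: DyLoRA keeps its weights in up_list/down_list and exports a
        # merged view via custom_state_dict() instead of loading one here.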
return
def custom_state_dict(self):
destination = {}
destination["alpha"] = self.alpha
destination["lora_up.weight"] = nn.Parameter(
torch.concat(list(self.up_list), dim=1)
)
destination["lora_down.weight"] = nn.Parameter(
torch.concat(list(self.down_list)).reshape(
self.lora_dim, -1, *self.shape[2:]
)
)
return destination
def get_weight(self, rank):
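        # DyLoRA update rule: the first `b` blocks are concatenated via .data
        # (detached, so they receive no gradient); only block `b` itself stays
        # trainable for this step.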
b = math.ceil(rank / self.block_size)
down = torch.concat(
list(i.data for i in self.down_list[:b]) + list(self.down_list[b : (b + 1)])
)
up = torch.concat(
list(i.data for i in self.up_list[:b]) + list(self.up_list[b : (b + 1)]),
dim=1,
)
return down, up, self.alpha / (b + 1)
def get_random_rank_weight(self):
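        # Sample a block index uniformly and build the weight up to that block
        # (DyLoRA's random-rank training).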
b = random.randint(0, self.block_count - 1)
return self.get_weight(b * self.block_size)
def get_diff_weight(self, multiplier=1, shape=None, device=None, rank=None):
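        # Low-rank delta: up @ down, scaled by alpha / (b + 1) and `multiplier`,
        # using a randomly sampled rank when none is given.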
if rank is None:
down, up, scale = self.get_random_rank_weight()
else:
down, up, scale = self.get_weight(rank)
w = up @ (down * (scale * multiplier))
if device is not None:
w = w.to(device)
if shape is not None:
w = w.view(shape)
else:
w = w.view(self.shape)
return w, None
def get_merged_weight(self, multiplier=1, shape=None, device=None, rank=None):
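        # Original weight plus the scaled low-rank delta.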
diff, _ = self.get_diff_weight(multiplier, shape, device, rank)
return diff + self.org_weight, None
    def bypass_forward_diff(self, x, scale=1, rank=None):
        if rank is None:
            down, up, gamma = self.get_random_rank_weight()
        else:
            down, up, gamma = self.get_weight(rank)
        # Reshape to the effective rank that was actually sampled, which can be
        # smaller than self.lora_dim.
        cur_rank = down.size(0)
        down = down.view(cur_rank, -1, *self.shape[2:])
        up = up.view(-1, cur_rank, *(1 for _ in self.shape[2:]))
        scale = scale * gamma
        return self.op(self.op(x, down, **self.kw_dict), up) * scale
def bypass_forward(self, x, scale=1, rank=None):
return self.org_forward(x) + self.bypass_forward_diff(x, scale, rank)
def forward(self, x, *args, **kwargs):
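        # Randomly skip this module during training (module dropout); otherwise
        # either run the bypass path (two low-rank ops) or rebuild and apply the
        # merged weight.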
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.org_forward(x)
if self.bypass_mode:
return self.bypass_forward(x, self.multiplier)
else:
weight = self.get_merged_weight(multiplier=self.multiplier)[0]
bias = (
None
if self.org_module[0].bias is None
else self.org_module[0].bias.data
)
return self.op(x, weight, bias, **self.kw_dict)
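

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. It assumes that
    # LycorisBaseModule.__init__ fills in self.shape / self.module_type / self.op
    # for a wrapped nn.Linear the way the methods above rely on; only methods
    # defined in this file are exercised. Run it as a module so the relative
    # imports above resolve.
    layer = nn.Linear(64, 32)
    dylora = DyLoraModule("sketch", layer, lora_dim=8, alpha=4.0, block_size=4)

    # Low-rank delta for an explicit rank; expected shape: (32, 64).
    delta, _ = dylora.get_diff_weight(rank=4)
    print(delta.shape)

    # Exported keys: alpha, lora_up.weight, lora_down.weight.
    print(sorted(dylora.custom_state_dict().keys()))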