File size: 1,138 Bytes
cc69848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import torch.nn as nn
from .base import LycorisBaseModule
from .locon import LoConModule
from .loha import LohaModule
from .lokr import LokrModule
from .full import FullModule
from .norms import NormModule
from .diag_oft import DiagOFTModule
from .boft import ButterflyOFTModule
from .glora import GLoRAModule
from .dylora import DyLoraModule
from .ia3 import IA3Module
from ..functional.general import factorization
# Candidate module classes, in the order get_module() tries them.
# get_module() returns the first class whose algo_check() accepts the layer,
# so the order of this list is significant.
MODULE_LIST = [
    LoConModule,
    LohaModule,
    IA3Module,
    LokrModule,
    FullModule,
    NormModule,
    DiagOFTModule,
    ButterflyOFTModule,
    GLoRAModule,
    DyLoraModule,
]
def get_module(lyco_state_dict, lora_name):
    """Find the module class in MODULE_LIST that handles ``lora_name``.

    The classes in ``MODULE_LIST`` are probed in list order via their
    ``algo_check(lyco_state_dict, lora_name)`` hook; the first one that
    returns a truthy value is selected.

    Args:
        lyco_state_dict: State dict handed unchanged to ``algo_check`` and
            ``extract_state_dict``.
        lora_name: Name of the layer to look up in the state dict.

    Returns:
        ``(module_class, params)`` where ``params`` is the result of
        ``module_class.extract_state_dict(lyco_state_dict, lora_name)``
        converted to a tuple, or ``(None, None)`` when no class matches.
    """
    for candidate in MODULE_LIST:
        if not candidate.algo_check(lyco_state_dict, lora_name):
            continue
        extracted = candidate.extract_state_dict(lyco_state_dict, lora_name)
        return candidate, tuple(extracted)
    # No registered class recognised this layer.
    return None, None
@torch.no_grad()
def make_module(lyco_type: LycorisBaseModule, params, lora_name, orig_module):
    """Build a module of ``lyco_type`` from previously extracted parameters.

    Runs with gradient tracking disabled (``torch.no_grad``).

    Args:
        lyco_type: Module class providing ``make_module_from_state_dict``.
        params: Iterable of extracted values, unpacked as the trailing
            positional arguments of ``make_module_from_state_dict``.
        lora_name: Name passed through as the first argument.
        orig_module: Original module passed through as the second argument.

    Returns:
        Whatever ``lyco_type.make_module_from_state_dict`` returns, or
        ``None`` if that call raises ``NotImplementedError``.
    """
    try:
        return lyco_type.make_module_from_state_dict(lora_name, orig_module, *params)
    except NotImplementedError:
        # The class does not implement construction from a state dict;
        # report that with None instead of propagating the error.
        return None
|