Commit ce92feb · Parent(s): fc8ab35
Update lora.py
lora.py
CHANGED
@@ -5,12 +5,16 @@
 
 import math
 import os
-from typing import List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
 import numpy as np
 import torch
 import re
 
 
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
 
 
@@ -400,7 +404,16 @@ def parse_block_lr_kwargs(nw_kwargs):
     return down_lr_weight, mid_lr_weight, up_lr_weight
 
 
-def create_network(
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    vae: AutoencoderKL,
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
+    unet,
+    neuron_dropout: Optional[float] = None,
+    **kwargs,
+):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
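Note: create_network now carries explicit type hints (AutoencoderKL for the VAE, a CLIPTextModel or a list of them for the text encoder) and still falls back to defaults when None is passed. A minimal call-site sketch, assuming `vae`, `unet`, and `text_encoders` were loaded elsewhere by the caller (not part of this commit):

    # Assumes `vae`, `unet`, and `text_encoders` (a CLIPTextModel, or a list of
    # two for SDXL) have already been loaded by the training script.
    network = create_network(
        multiplier=1.0,
        network_dim=None,    # None -> defaults to 4, as in the hunk above
        network_alpha=None,  # None -> a default alpha is applied
        vae=vae,
        text_encoder=text_encoders,
        unet=unet,
        neuron_dropout=None,
    )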
@@ -719,33 +732,36 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
 class LoRANetwork(torch.nn.Module):
     NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
 
-
-    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
 
+    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
     def __init__(
         self,
-        text_encoder,
+        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
         unet,
-        multiplier=1.0,
-        lora_dim=4,
-        alpha=1,
-        dropout=None,
-        rank_dropout=None,
-        module_dropout=None,
-        conv_lora_dim=None,
-        conv_alpha=None,
-        block_dims=None,
-        block_alphas=None,
-        conv_block_dims=None,
-        conv_block_alphas=None,
-        modules_dim=None,
-        modules_alpha=None,
-        module_class=LoRAModule,
-        varbose=False,
+        multiplier: float = 1.0,
+        lora_dim: int = 4,
+        alpha: float = 1,
+        dropout: Optional[float] = None,
+        rank_dropout: Optional[float] = None,
+        module_dropout: Optional[float] = None,
+        conv_lora_dim: Optional[int] = None,
+        conv_alpha: Optional[float] = None,
+        block_dims: Optional[List[int]] = None,
+        block_alphas: Optional[List[float]] = None,
+        conv_block_dims: Optional[List[int]] = None,
+        conv_block_alphas: Optional[List[float]] = None,
+        modules_dim: Optional[Dict[str, int]] = None,
+        modules_alpha: Optional[Dict[str, int]] = None,
+        module_class: Type[object] = LoRAModule,
+        varbose: Optional[bool] = False,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows:
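Note: the new class constants give each SDXL text encoder its own state-dict key prefix while keeping the original lora_te prefix for single-encoder models; as the added comment says, both new prefixes must start with LORA_PREFIX_TEXT_ENCODER. For illustration only (the module path after the prefix is a made-up example, not from this commit):

    # Prefix convention for saved LoRA keys:
    #   SD 1.x/2.x : lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
    #   SDXL TE 1  : lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
    #   SDXL TE 2  : lora_te2_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
    assert "lora_te1".startswith("lora_te") and "lora_te2".startswith("lora_te")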
@@ -783,8 +799,21 @@ class LoRANetwork(torch.nn.Module):
             print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
         # create module instances
-        def create_modules(
-
+        def create_modules(
+            is_unet: bool,
+            text_encoder_idx: Optional[int],  # None, 1, 2
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[LoRAModule]:
+            prefix = (
+                self.LORA_PREFIX_UNET
+                if is_unet
+                else (
+                    self.LORA_PREFIX_TEXT_ENCODER
+                    if text_encoder_idx is None
+                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
+                )
+            )
             loras = []
             skipped = []
             for name, module in root_module.named_modules():
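Note: the nested conditional expression above selects the key prefix per module tree. As a readability aid only (this helper is not part of the commit), the same logic can be written as a standalone function:

    from typing import Optional

    # Equivalent to the prefix expression in create_modules (illustration only).
    def select_prefix(is_unet: bool, text_encoder_idx: Optional[int]) -> str:
        if is_unet:
            return "lora_unet"   # LoRANetwork.LORA_PREFIX_UNET
        if text_encoder_idx is None:
            return "lora_te"     # single text encoder (SD 1.x/2.x)
        return "lora_te1" if text_encoder_idx == 1 else "lora_te2"  # SDXL

    assert select_prefix(True, None) == "lora_unet"
    assert select_prefix(False, None) == "lora_te"
    assert select_prefix(False, 2) == "lora_te2"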
@@ -800,11 +829,14 @@ class LoRANetwork(torch.nn.Module):
 
                             dim = None
                             alpha = None
+
                             if modules_dim is not None:
+                                # per-module dims specified
                                 if lora_name in modules_dim:
                                     dim = modules_dim[lora_name]
                                     alpha = modules_alpha[lora_name]
                             elif is_unet and block_dims is not None:
+                                # block_dims specified for the U-Net
                                 block_idx = get_block_index(lora_name)
                                 if is_linear or is_conv2d_1x1:
                                     dim = block_dims[block_idx]
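Note: the dim/alpha resolution above follows a fixed precedence: an explicit per-module entry in modules_dim wins, then per-block values for U-Net modules, then the network-wide defaults. A simplified sketch of that precedence (illustration only; the commit's code additionally distinguishes Linear/Conv2d-1x1 from Conv2d-3x3 at each step):

    from typing import Dict, List, Optional, Tuple

    def resolve_dim_alpha(
        lora_name: str,
        is_unet: bool,
        block_idx: int,
        modules_dim: Optional[Dict[str, int]],
        modules_alpha: Optional[Dict[str, float]],
        block_dims: Optional[List[int]],
        block_alphas: Optional[List[float]],
        default_dim: int,
        default_alpha: float,
    ) -> Tuple[Optional[int], Optional[float]]:
        if modules_dim is not None:
            if lora_name in modules_dim:
                return modules_dim[lora_name], modules_alpha[lora_name]
            return None, None  # not listed -> the module is skipped
        if is_unet and block_dims is not None:
            return block_dims[block_idx], block_alphas[block_idx]
        return default_dim, default_alpha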
@@ -813,6 +845,7 @@ class LoRANetwork(torch.nn.Module):
                                     dim = conv_block_dims[block_idx]
                                     alpha = conv_block_alphas[block_idx]
                             else:
+                                # normal case: all modules are targeted
                                 if is_linear or is_conv2d_1x1:
                                     dim = self.lora_dim
                                     alpha = self.alpha
@@ -821,6 +854,7 @@ class LoRANetwork(torch.nn.Module):
                                     alpha = self.conv_alpha
 
                             if dim is None or dim == 0:
+                                # report modules that are skipped
                                 if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
                                     skipped.append(lora_name)
                                     continue
@@ -838,7 +872,24 @@ class LoRANetwork(torch.nn.Module):
                             loras.append(lora)
             return loras, skipped
 
-
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+        print(text_encoders)
+        # create LoRA for text encoder
+        # creating all modules every time is wasteful and should be reconsidered
+        self.text_encoder_loras = []
+        skipped_te = []
+        for i, text_encoder in enumerate(text_encoders):
+            if len(text_encoders) > 1:
+                index = i + 1
+                print(f"create LoRA for Text Encoder {index}:")
+            else:
+                index = None
+                print(f"create LoRA for Text Encoder:")
+
+            print(text_encoder)
+            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            self.text_encoder_loras.extend(text_encoder_loras)
+            skipped_te += skipped
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
         # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
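Note: with this change the constructor accepts either a single text encoder or a list, and builds a separate set of text-encoder LoRA modules for each entry, using index 1/2 in the key prefix when more than one is given. A hedged usage sketch (variable names are assumptions, not from this commit):

    # Assumes `unet`, `text_encoder1`, `text_encoder2` are already-loaded SDXL modules.
    network = LoRANetwork(
        text_encoder=[text_encoder1, text_encoder2],  # list -> prefixes lora_te1 / lora_te2
        unet=unet,
        multiplier=1.0,
        lora_dim=8,
        alpha=4,
    )
    # For SD 1.x/2.x a single module still works:
    # network = LoRANetwork(text_encoder=text_encoder, unet=unet)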
@@ -846,7 +897,7 @@ class LoRANetwork(torch.nn.Module):
         if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
             target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
 
-        self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
+        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
         print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
 
         skipped = skipped_te + skipped_un
@@ -880,7 +931,6 @@ class LoRANetwork(torch.nn.Module):
             weights_sd = load_file(file)
         else:
             weights_sd = torch.load(file, map_location="cpu")
-
         info = self.load_state_dict(weights_sd, False)
         return info
 
@@ -961,6 +1011,7 @@ class LoRANetwork(torch.nn.Module):
 
         return lr_weight
 
+    # it might be good to allow setting separate learning rates for the two Text Encoders
     def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
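Note: prepare_optimizer_params still takes a single text_encoder_lr; the added comment only flags that the two SDXL text encoders could eventually get separate rates. A sketch of how the current interface is typically wired into an optimizer (illustrative values, not from this commit):

    import torch

    # Both text encoders currently share text_encoder_lr; unet_lr applies to the U-Net LoRAs.
    params = network.prepare_optimizer_params(text_encoder_lr=1e-5, unet_lr=1e-4, default_lr=1e-4)
    optimizer = torch.optim.AdamW(params)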