Spaces:
Running
Running
File size: 7,927 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import argparse
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Union
from finetrainers.utils import ArgsConfigMixin
if TYPE_CHECKING:
from finetrainers.args import BaseArgs
class ControlType(str, Enum):
    r"""
    Kinds of control signal that control training can be configured with.

    Because the class mixes in `str`, every member compares equal to its plain
    string value (e.g. ``ControlType.CANNY == "canny"``), so values parsed from
    the command line can be compared directly against members.
    """

    # NOTE: declaration order is significant: iterating the enum yields the
    # members in this order (e.g. when building a list of valid choices).
    CANNY = "canny"
    CUSTOM = "custom"
    NONE = "none"
class FrameConditioningType(str, Enum):
    r"""
    Strategies for selecting which frames receive conditioning.

    As a `str`-mixin enum, each member is equal to its string value, so raw
    strings (for example from argparse) can be compared against members directly.
    """

    # NOTE: keep this order stable; iteration order is observable to callers
    # that enumerate the members (e.g. to build a list of valid choices).
    INDEX = "index"
    PREFIX = "prefix"
    RANDOM = "random"
    FIRST_AND_LAST = "first_and_last"
    FULL = "full"
class ControlLowRankConfig(ArgsConfigMixin):
    r"""
    Configuration class for SFT channel-concatenated Control low rank training.

    Args:
        control_type (`str`, defaults to `"canny"`):
            Control type for the low rank approximation matrices. Can be "canny", "custom", or "none".
        rank (`int`, defaults to `64`):
            Rank of the low rank approximation matrix.
        lora_alpha (`int`, defaults to `64`):
            The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
        target_modules (`str` or `List[str]`, defaults to `"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"`):
            Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings.
        train_qk_norm (`bool`, defaults to `False`):
            Whether to train the QK normalization layers.
        frame_conditioning_type (`str`, defaults to `"full"`):
            Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full".
        frame_conditioning_index (`int`, defaults to `0`):
            Index of the frame conditioning. Only used if `frame_conditioning_type` is "index".
        frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
            Whether to concatenate the frame mask with the latents across channel dim.
    """

    control_type: str = ControlType.CANNY
    rank: int = 64
    lora_alpha: int = 64
    target_modules: Union[str, List[str]] = (
        "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
    )
    train_qk_norm: bool = False

    # Specific to video models
    frame_conditioning_type: str = FrameConditioningType.FULL
    frame_conditioning_index: int = 0
    frame_conditioning_concatenate_mask: bool = False

    def add_args(self, parser: argparse.ArgumentParser):
        r"""
        Register the low-rank control training flags on `parser`.

        Every CLI default mirrors the class-level default declared above, so a run
        that omits a flag behaves the same as the documented configuration.
        """
        parser.add_argument(
            "--control_type",
            type=str,
            default=ControlType.CANNY.value,
            choices=[x.value for x in ControlType.__members__.values()],
        )
        parser.add_argument("--rank", type=int, default=64)
        parser.add_argument("--lora_alpha", type=int, default=64)
        parser.add_argument(
            "--target_modules",
            type=str,
            nargs="+",
            default=[
                "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
            ],
        )
        parser.add_argument("--train_qk_norm", action="store_true")
        parser.add_argument(
            "--frame_conditioning_type",
            type=str,
            # Must agree with the class attribute / docstring default ("full"); a
            # mismatch here means CLI runs silently use a different conditioning mode.
            default=FrameConditioningType.FULL.value,
            choices=[x.value for x in FrameConditioningType.__members__.values()],
        )
        parser.add_argument("--frame_conditioning_index", type=int, default=0)
        parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")

    def validate_args(self, args: "BaseArgs"):
        r"""
        Sanity-check the low-rank hyperparameters stored on this config.

        Raises:
            AssertionError: if `rank` or `lora_alpha` is not strictly positive.
        """
        assert self.rank > 0, "Rank must be a positive integer."
        assert self.lora_alpha > 0, "lora_alpha must be a positive integer."

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        r"""
        Copy the parsed CLI values registered by `add_args` onto `mapped_args`.
        """
        mapped_args.control_type = argparse_args.control_type
        mapped_args.rank = argparse_args.rank
        mapped_args.lora_alpha = argparse_args.lora_alpha
        # A single `--target_modules` value is unwrapped to a plain regex string;
        # multiple values are kept as a list of regex strings.
        mapped_args.target_modules = (
            argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules
        )
        mapped_args.train_qk_norm = argparse_args.train_qk_norm
        mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type
        mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index
        mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask

    def to_dict(self) -> Dict[str, Any]:
        r"""
        Return the configuration as a flat dictionary keyed by field name.
        """
        return {
            "control_type": self.control_type,
            "rank": self.rank,
            "lora_alpha": self.lora_alpha,
            "target_modules": self.target_modules,
            "train_qk_norm": self.train_qk_norm,
            "frame_conditioning_type": self.frame_conditioning_type,
            "frame_conditioning_index": self.frame_conditioning_index,
            "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask,
        }
class ControlFullRankConfig(ArgsConfigMixin):
    r"""
    Configuration class for SFT channel-concatenated Control full rank training.

    Args:
        control_type (`str`, defaults to `"canny"`):
            Type of control signal used for conditioning. Can be "canny", "custom", or "none".
        train_qk_norm (`bool`, defaults to `False`):
            Whether to train the QK normalization layers.
        frame_conditioning_type (`str`, defaults to `"index"`):
            Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full".
        frame_conditioning_index (`int`, defaults to `0`):
            Index of the frame conditioning. Only used if `frame_conditioning_type` is "index".
        frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
            Whether to concatenate the frame mask with the latents across channel dim.
    """

    control_type: str = ControlType.CANNY
    train_qk_norm: bool = False

    # Specific to video models
    frame_conditioning_type: str = FrameConditioningType.INDEX
    frame_conditioning_index: int = 0
    frame_conditioning_concatenate_mask: bool = False

    def add_args(self, parser: argparse.ArgumentParser):
        r"""
        Register the full-rank control training flags on `parser`.

        Every CLI default mirrors the class-level default declared above.
        """
        parser.add_argument(
            "--control_type",
            type=str,
            default=ControlType.CANNY.value,
            choices=[x.value for x in ControlType.__members__.values()],
        )
        parser.add_argument("--train_qk_norm", action="store_true")
        parser.add_argument(
            "--frame_conditioning_type",
            type=str,
            default=FrameConditioningType.INDEX.value,
            choices=[x.value for x in FrameConditioningType.__members__.values()],
        )
        parser.add_argument("--frame_conditioning_index", type=int, default=0)
        parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")

    def validate_args(self, args: "BaseArgs"):
        r"""
        No extra validation is performed for full-rank control training.
        """
        pass

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        r"""
        Copy the parsed CLI values registered by `add_args` onto `mapped_args`.
        """
        mapped_args.control_type = argparse_args.control_type
        mapped_args.train_qk_norm = argparse_args.train_qk_norm
        mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type
        mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index
        mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask

    def to_dict(self) -> Dict[str, Any]:
        r"""
        Return the configuration as a flat dictionary keyed by field name.
        """
        return {
            "control_type": self.control_type,
            "train_qk_norm": self.train_qk_norm,
            "frame_conditioning_type": self.frame_conditioning_type,
            "frame_conditioning_index": self.frame_conditioning_index,
            "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask,
        }
|