jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import argparse
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Union
from finetrainers.utils import ArgsConfigMixin
if TYPE_CHECKING:
from finetrainers.args import BaseArgs
class ControlType(str, Enum):
r"""
Enum class for the control types.
"""
CANNY = "canny"
CUSTOM = "custom"
NONE = "none"
class FrameConditioningType(str, Enum):
r"""
Enum class for the frame conditioning types.
"""
INDEX = "index"
PREFIX = "prefix"
RANDOM = "random"
FIRST_AND_LAST = "first_and_last"
FULL = "full"
class ControlLowRankConfig(ArgsConfigMixin):
r"""
Configuration class for SFT channel-concatenated Control low rank training.
Args:
control_type (`str`, defaults to `"canny"`):
Control type for the low rank approximation matrices. Can be "canny", "custom".
rank (int, defaults to `64`):
Rank of the low rank approximation matrix.
lora_alpha (int, defaults to `64`):
The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
target_modules (`str` or `List[str]`, defaults to `"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"`):
Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings.
train_qk_norm (`bool`, defaults to `False`):
Whether to train the QK normalization layers.
frame_conditioning_type (`str`, defaults to `"full"`):
Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full".
frame_conditioning_index (int, defaults to `0`):
Index of the frame conditioning. Only used if `frame_conditioning_type` is "index".
frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
Whether to concatenate the frame mask with the latents across channel dim.
"""
control_type: str = ControlType.CANNY
rank: int = 64
lora_alpha: int = 64
target_modules: Union[str, List[str]] = (
"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
)
train_qk_norm: bool = False
# Specific to video models
frame_conditioning_type: str = FrameConditioningType.FULL
frame_conditioning_index: int = 0
frame_conditioning_concatenate_mask: bool = False
def add_args(self, parser: argparse.ArgumentParser):
parser.add_argument(
"--control_type",
type=str,
default=ControlType.CANNY.value,
choices=[x.value for x in ControlType.__members__.values()],
)
parser.add_argument("--rank", type=int, default=64)
parser.add_argument("--lora_alpha", type=int, default=64)
parser.add_argument(
"--target_modules",
type=str,
nargs="+",
default=[
"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
],
)
parser.add_argument("--train_qk_norm", action="store_true")
parser.add_argument(
"--frame_conditioning_type",
type=str,
default=FrameConditioningType.INDEX.value,
choices=[x.value for x in FrameConditioningType.__members__.values()],
)
parser.add_argument("--frame_conditioning_index", type=int, default=0)
parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")
def validate_args(self, args: "BaseArgs"):
assert self.rank > 0, "Rank must be a positive integer."
assert self.lora_alpha > 0, "lora_alpha must be a positive integer."
def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
mapped_args.control_type = argparse_args.control_type
mapped_args.rank = argparse_args.rank
mapped_args.lora_alpha = argparse_args.lora_alpha
mapped_args.target_modules = (
argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules
)
mapped_args.train_qk_norm = argparse_args.train_qk_norm
mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type
mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index
mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask
def to_dict(self) -> Dict[str, Any]:
return {
"control_type": self.control_type,
"rank": self.rank,
"lora_alpha": self.lora_alpha,
"target_modules": self.target_modules,
"train_qk_norm": self.train_qk_norm,
"frame_conditioning_type": self.frame_conditioning_type,
"frame_conditioning_index": self.frame_conditioning_index,
"frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask,
}
class ControlFullRankConfig(ArgsConfigMixin):
r"""
Configuration class for SFT channel-concatenated Control full rank training.
Args:
control_type (`str`, defaults to `"canny"`):
Control type for the low rank approximation matrices. Can be "canny", "custom".
train_qk_norm (`bool`, defaults to `False`):
Whether to train the QK normalization layers.
frame_conditioning_type (`str`, defaults to `"index"`):
Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full".
frame_conditioning_index (int, defaults to `0`):
Index of the frame conditioning. Only used if `frame_conditioning_type` is "index".
frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
Whether to concatenate the frame mask with the latents across channel dim.
"""
control_type: str = ControlType.CANNY
train_qk_norm: bool = False
# Specific to video models
frame_conditioning_type: str = FrameConditioningType.INDEX
frame_conditioning_index: int = 0
frame_conditioning_concatenate_mask: bool = False
def add_args(self, parser: argparse.ArgumentParser):
parser.add_argument(
"--control_type",
type=str,
default=ControlType.CANNY.value,
choices=[x.value for x in ControlType.__members__.values()],
)
parser.add_argument("--train_qk_norm", action="store_true")
parser.add_argument(
"--frame_conditioning_type",
type=str,
default=FrameConditioningType.INDEX.value,
choices=[x.value for x in FrameConditioningType.__members__.values()],
)
parser.add_argument("--frame_conditioning_index", type=int, default=0)
parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")
def validate_args(self, args: "BaseArgs"):
pass
def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
mapped_args.control_type = argparse_args.control_type
mapped_args.train_qk_norm = argparse_args.train_qk_norm
mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type
mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index
mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask
def to_dict(self) -> Dict[str, Any]:
return {
"control_type": self.control_type,
"train_qk_norm": self.train_qk_norm,
"frame_conditioning_type": self.frame_conditioning_type,
"frame_conditioning_index": self.frame_conditioning_index,
"frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask,
}