Spaces:
Running
Running
VideoModelStudio
/
docs
/finetrainers-src-codebase
/finetrainers
/trainer
/control_trainer
/config.py
import argparse | |
from enum import Enum | |
from typing import TYPE_CHECKING, Any, Dict, List, Union | |
from finetrainers.utils import ArgsConfigMixin | |
if TYPE_CHECKING: | |
from finetrainers.args import BaseArgs | |
class ControlType(str, Enum):
    r"""
    Kinds of control signal that Control training can be configured with.

    Because the enum also derives from `str`, every member compares equal to its
    plain string value (for example `ControlType.CANNY == "canny"`), which lets the
    raw CLI string and the enum member be used interchangeably.
    """

    CANNY = "canny"
    CUSTOM = "custom"
    NONE = "none"
class FrameConditioningType(str, Enum):
    r"""
    Frame conditioning strategies selectable through `--frame_conditioning_type`.

    Members derive from `str`, so each one compares equal to its plain string value
    (for example `FrameConditioningType.FULL == "full"`).
    """

    INDEX = "index"
    PREFIX = "prefix"
    RANDOM = "random"
    FIRST_AND_LAST = "first_and_last"
    FULL = "full"
class ControlLowRankConfig(ArgsConfigMixin):
    r"""
    Configuration class for SFT channel-concatenated Control low rank training.

    Args:
        control_type (`str`, defaults to `"canny"`):
            Control type for the low rank approximation matrices. Can be "canny", "custom" or "none".
        rank (int, defaults to `64`):
            Rank of the low rank approximation matrix. Must be positive.
        lora_alpha (int, defaults to `64`):
            The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
            Must be positive.
        target_modules (`str` or `List[str]`, defaults to `"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"`):
            Target modules for the low rank approximation matrices. Can be a regex string or a list of regex
            strings. On the command line one or more patterns may be passed; a single pattern is stored as a
            plain string and several patterns are stored as a list.
        train_qk_norm (`bool`, defaults to `False`):
            Whether to train the QK normalization layers.
        frame_conditioning_type (`str`, defaults to `"index"`):
            Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full".
        frame_conditioning_index (int, defaults to `0`):
            Index of the frame conditioning. Only used if `frame_conditioning_type` is "index".
        frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
            Whether to concatenate the frame mask with the latents across channel dim.
    """

    # Single source of truth for the default target-module regex, shared by the class-level
    # default and the `--target_modules` CLI default so the two cannot drift apart.
    _DEFAULT_TARGET_MODULES = (
        "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
    )

    control_type: str = ControlType.CANNY
    rank: int = 64
    lora_alpha: int = 64
    target_modules: Union[str, List[str]] = _DEFAULT_TARGET_MODULES
    train_qk_norm: bool = False
    # Specific to video models
    # Kept equal to the `--frame_conditioning_type` CLI default below (previously this said FULL
    # while the CLI default was INDEX, so the documented and effective defaults disagreed).
    frame_conditioning_type: str = FrameConditioningType.INDEX
    frame_conditioning_index: int = 0
    frame_conditioning_concatenate_mask: bool = False

    def add_args(self, parser: argparse.ArgumentParser):
        """Register the low-rank Control command-line flags on `parser`."""
        parser.add_argument(
            "--control_type",
            type=str,
            default=ControlType.CANNY.value,
            choices=[x.value for x in ControlType.__members__.values()],
        )
        parser.add_argument("--rank", type=int, default=64)
        parser.add_argument("--lora_alpha", type=int, default=64)
        parser.add_argument(
            "--target_modules",
            type=str,
            nargs="+",
            default=[self._DEFAULT_TARGET_MODULES],
        )
        parser.add_argument("--train_qk_norm", action="store_true")
        parser.add_argument(
            "--frame_conditioning_type",
            type=str,
            default=FrameConditioningType.INDEX.value,
            choices=[x.value for x in FrameConditioningType.__members__.values()],
        )
        parser.add_argument("--frame_conditioning_index", type=int, default=0)
        parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")

    def validate_args(self, args: "BaseArgs"):
        """
        Check that the low-rank hyperparameters are positive.

        `map_args` writes the parsed `rank` / `lora_alpha` onto the args object rather than onto
        this config, so those values are validated when present. If `args` does not carry them,
        this config's own values are checked instead.

        Raises:
            AssertionError: if `rank` or `lora_alpha` is not positive.
        """
        rank = getattr(args, "rank", self.rank)
        lora_alpha = getattr(args, "lora_alpha", self.lora_alpha)
        assert rank > 0, "Rank must be a positive integer."
        assert lora_alpha > 0, "lora_alpha must be a positive integer."

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        """Copy the parsed low-rank Control flags from `argparse_args` onto `mapped_args`."""
        mapped_args.control_type = argparse_args.control_type
        mapped_args.rank = argparse_args.rank
        mapped_args.lora_alpha = argparse_args.lora_alpha
        # `--target_modules` uses nargs="+", so it always parses to a non-empty list; collapse a
        # single pattern to a plain string to match the `str` form of the attribute.
        mapped_args.target_modules = (
            argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules
        )
        mapped_args.train_qk_norm = argparse_args.train_qk_norm
        mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type
        mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index
        mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask

    def to_dict(self) -> Dict[str, Any]:
        """
        Return this config's settings as a plain dict keyed by attribute name.

        NOTE(review): this reads attributes from `self`, whereas `map_args` writes parsed values onto
        `mapped_args`; confirm the caller keeps the two in sync.
        """
        return {
            "control_type": self.control_type,
            "rank": self.rank,
            "lora_alpha": self.lora_alpha,
            "target_modules": self.target_modules,
            "train_qk_norm": self.train_qk_norm,
            "frame_conditioning_type": self.frame_conditioning_type,
            "frame_conditioning_index": self.frame_conditioning_index,
            "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask,
        }
class ControlFullRankConfig(ArgsConfigMixin):
    r"""
    Settings for SFT channel-concatenated Control training where the model weights are trained at full rank.

    Args:
        control_type (`str`, defaults to `"canny"`):
            Kind of control signal to train with. One of "canny", "custom" or "none".
        train_qk_norm (`bool`, defaults to `False`):
            Whether the QK normalization layers are trained as well.
        frame_conditioning_type (`str`, defaults to `"index"`):
            Frame conditioning strategy. One of "index", "prefix", "random", "first_and_last" or "full".
        frame_conditioning_index (int, defaults to `0`):
            Frame index used for conditioning; only relevant when `frame_conditioning_type` is "index".
        frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
            Whether the frame mask is concatenated to the latents along the channel dimension.
    """

    # Attribute names copied by `map_args` and reported by `to_dict`, in this order.
    _FIELD_NAMES = (
        "control_type",
        "train_qk_norm",
        "frame_conditioning_type",
        "frame_conditioning_index",
        "frame_conditioning_concatenate_mask",
    )

    control_type: str = ControlType.CANNY
    train_qk_norm: bool = False
    # Options that only matter for video models
    frame_conditioning_type: str = FrameConditioningType.INDEX
    frame_conditioning_index: int = 0
    frame_conditioning_concatenate_mask: bool = False

    def add_args(self, parser: argparse.ArgumentParser):
        """Register the full-rank Control command-line flags on `parser`."""
        control_choices = [member.value for member in ControlType.__members__.values()]
        frame_choices = [member.value for member in FrameConditioningType.__members__.values()]

        parser.add_argument("--control_type", type=str, default=ControlType.CANNY.value, choices=control_choices)
        parser.add_argument("--train_qk_norm", action="store_true")
        parser.add_argument(
            "--frame_conditioning_type",
            type=str,
            default=FrameConditioningType.INDEX.value,
            choices=frame_choices,
        )
        parser.add_argument("--frame_conditioning_index", type=int, default=0)
        parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")

    def validate_args(self, args: "BaseArgs"):
        """Full-rank Control training has no extra constraints to check; always succeeds."""

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        """Copy each parsed flag from `argparse_args` onto the same-named attribute of `mapped_args`."""
        for field_name in self._FIELD_NAMES:
            setattr(mapped_args, field_name, getattr(argparse_args, field_name))

    def to_dict(self) -> Dict[str, Any]:
        """Return this config's settings as a plain dict keyed by attribute name."""
        return {field_name: getattr(self, field_name) for field_name in self._FIELD_NAMES}