jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import argparse
from typing import TYPE_CHECKING, Any, Dict, List, Union
from finetrainers.utils import ArgsConfigMixin
if TYPE_CHECKING:
from finetrainers.args import BaseArgs
class SFTLowRankConfig(ArgsConfigMixin):
    r"""
    Configuration class for SFT low rank training.

    Args:
        rank (int):
            Rank of the low rank approximation matrix. Must be a positive integer.
        lora_alpha (int):
            The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
            Must be a positive integer.
        target_modules (`str` or `List[str]`):
            Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings.
    """

    rank: int = 64
    lora_alpha: int = 64
    target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"

    def add_args(self, parser: argparse.ArgumentParser):
        """Register the ``--rank``, ``--lora_alpha`` and ``--target_modules`` flags on ``parser``.

        Defaults are taken from the attributes above, so the CLI defaults and the
        programmatic defaults cannot drift apart.
        """
        # ``--target_modules`` uses nargs="+", so its default must be a list even
        # though the attribute default is a single regex string.
        if isinstance(self.target_modules, str):
            default_targets = [self.target_modules]
        else:
            default_targets = list(self.target_modules)

        parser.add_argument(
            "--rank",
            type=int,
            default=self.rank,
            help="Rank of the low rank approximation matrices (positive integer).",
        )
        parser.add_argument(
            "--lora_alpha",
            type=int,
            default=self.lora_alpha,
            help="lora_alpha used to compute the scaling factor (lora_alpha / rank).",
        )
        parser.add_argument(
            "--target_modules",
            type=str,
            nargs="+",
            default=default_targets,
            help="One or more regexes selecting the modules that receive low rank matrices.",
        )

    def validate_args(self, args: "BaseArgs"):
        """Check that ``rank`` and ``lora_alpha`` are positive integers.

        The values checked are the ones carried by ``args`` (where ``map_args`` stores
        the parsed CLI values). This config's own attributes are used only as a fallback
        when ``args`` does not carry them.

        Raises:
            ValueError: If ``rank`` or ``lora_alpha`` is not a positive integer.
        """
        # ``map_args`` writes onto ``mapped_args``, never onto ``self``, so ``self.rank``
        # would always be the class default; validate the values the user actually supplied.
        rank = getattr(args, "rank", self.rank)
        lora_alpha = getattr(args, "lora_alpha", self.lora_alpha)

        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if not isinstance(rank, int) or rank <= 0:
            raise ValueError("Rank must be a positive integer.")
        if not isinstance(lora_alpha, int) or lora_alpha <= 0:
            raise ValueError("lora_alpha must be a positive integer.")

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        """Copy the parsed low rank CLI values from ``argparse_args`` onto ``mapped_args``.

        A single ``--target_modules`` value is unwrapped to a plain string (one regex).
        Multiple values are kept as a list of regexes.
        """
        mapped_args.rank = argparse_args.rank
        mapped_args.lora_alpha = argparse_args.lora_alpha
        targets = argparse_args.target_modules
        mapped_args.target_modules = targets[0] if len(targets) == 1 else targets

    def to_dict(self) -> Dict[str, Any]:
        """Return this config's ``rank``, ``lora_alpha`` and ``target_modules`` as a dict."""
        # NOTE(review): this reads the attributes of this config instance, while
        # ``map_args`` stores parsed CLI values on ``mapped_args``. Confirm callers
        # expect the instance values here rather than the mapped ones.
        return {"rank": self.rank, "lora_alpha": self.lora_alpha, "target_modules": self.target_modules}
class SFTFullRankConfig(ArgsConfigMixin):
    r"""
    Configuration class for SFT full rank training.

    Full rank training defines no additional command-line options, so each of the
    argument hooks below deliberately does nothing and returns ``None``.
    """

    def add_args(self, parser: argparse.ArgumentParser):
        """No extra flags are registered on ``parser``."""
        return None

    def validate_args(self, args: "BaseArgs"):
        """There are no full-rank-specific arguments to validate."""
        return None

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        """There are no full-rank-specific arguments to copy onto ``mapped_args``."""
        return None