from typing import Optional

from pydantic import Field

from autotrain.trainers.common import AutoTrainParams


class VLMTrainingParams(AutoTrainParams):
""" | |
VLMTrainingParams | |
Attributes: | |
model (str): Model name. Default is "google/paligemma-3b-pt-224". | |
project_name (str): Output directory. Default is "project-name". | |
data_path (str): Data path. Default is "data". | |
train_split (str): Train data config. Default is "train". | |
valid_split (Optional[str]): Validation data config. Default is None. | |
trainer (str): Trainer type (captioning, vqa, segmentation, detection). Default is "vqa". | |
        log (str): Experiment tracking backend used for logging. Default is "none".
        disable_gradient_checkpointing (bool): Whether to disable gradient checkpointing. Default is False.
        logging_steps (int): Logging steps. Default is -1.
        eval_strategy (str): Evaluation strategy. Default is "epoch".
        save_total_limit (int): Save total limit. Default is 1.
        auto_find_batch_size (bool): Auto find batch size. Default is False.
        mixed_precision (Optional[str]): Mixed precision (fp16, bf16, or None). Default is None.
        lr (float): Learning rate. Default is 3e-5.
        epochs (int): Number of training epochs. Default is 1.
        batch_size (int): Training batch size. Default is 2.
        warmup_ratio (float): Warmup proportion. Default is 0.1.
        gradient_accumulation (int): Gradient accumulation steps. Default is 4.
        optimizer (str): Optimizer. Default is "adamw_torch".
        scheduler (str): Scheduler. Default is "linear".
        weight_decay (float): Weight decay. Default is 0.0.
        max_grad_norm (float): Max gradient norm. Default is 1.0.
        seed (int): Seed. Default is 42.
        quantization (Optional[str]): Quantization (int4, int8, or None). Default is "int4".
        target_modules (Optional[str]): Target modules. Default is "all-linear".
        merge_adapter (bool): Merge adapter. Default is False.
        peft (bool): Use PEFT. Default is False.
        lora_r (int): LoRA r (rank). Default is 16.
        lora_alpha (int): LoRA alpha. Default is 32.
        lora_dropout (float): LoRA dropout. Default is 0.05.
        image_column (Optional[str]): Image column. Default is "image".
        text_column (str): Text (answer) column. Default is "text".
        prompt_text_column (Optional[str]): Prompt (prefix) column. Default is "prompt".
        push_to_hub (bool): Push to hub. Default is False.
        username (Optional[str]): Hugging Face username. Default is None.
        token (Optional[str]): Hugging Face token. Default is None.
    """
    model: str = Field("google/paligemma-3b-pt-224", title="Model name")
    project_name: str = Field("project-name", title="Output directory")

    # data params
    data_path: str = Field("data", title="Data path")
    train_split: str = Field("train", title="Train data config")
    valid_split: Optional[str] = Field(None, title="Validation data config")

    # trainer params
    trainer: str = Field("vqa", title="Trainer type")  # captioning, vqa, segmentation, detection
    log: str = Field("none", title="Logging using experiment tracking")
    disable_gradient_checkpointing: bool = Field(False, title="Disable gradient checkpointing")
    logging_steps: int = Field(-1, title="Logging steps")
    eval_strategy: str = Field("epoch", title="Evaluation strategy")
    save_total_limit: int = Field(1, title="Save total limit")
    auto_find_batch_size: bool = Field(False, title="Auto find batch size")
    mixed_precision: Optional[str] = Field(None, title="fp16, bf16, or None")
    lr: float = Field(3e-5, title="Learning rate")
    epochs: int = Field(1, title="Number of training epochs")
    batch_size: int = Field(2, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation: int = Field(4, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")

    # peft
    quantization: Optional[str] = Field("int4", title="int4, int8, or None")
    target_modules: Optional[str] = Field("all-linear", title="Target modules")
    merge_adapter: bool = Field(False, title="Merge adapter")
    peft: bool = Field(False, title="Use PEFT")
    lora_r: int = Field(16, title="LoRA r")
    lora_alpha: int = Field(32, title="LoRA alpha")
    lora_dropout: float = Field(0.05, title="LoRA dropout")

    # column mappings
    image_column: Optional[str] = Field("image", title="Image column")
    text_column: str = Field("text", title="Text (answer) column")
    prompt_text_column: Optional[str] = Field("prompt", title="Prompt (prefix) column")

    # push to hub
    push_to_hub: bool = Field(False, title="Push to hub")
    username: Optional[str] = Field(None, title="Hugging Face username")
    token: Optional[str] = Field(None, title="Hugging Face token")
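

# Usage sketch (not part of the class definition above): a minimal example of
# constructing VLMTrainingParams for a LoRA fine-tuning run. The field names
# come from the class itself; the specific values, the "captioning" task
# choice, and the project name are illustrative assumptions, not defaults.
if __name__ == "__main__":
    params = VLMTrainingParams(
        model="google/paligemma-3b-pt-224",
        project_name="paligemma-captioning",  # output directory (hypothetical name)
        data_path="data",
        train_split="train",
        trainer="captioning",  # one of: captioning, vqa, segmentation, detection
        epochs=3,
        batch_size=2,
        gradient_accumulation=4,
        lr=3e-5,
        mixed_precision="bf16",
        peft=True,  # train LoRA adapters instead of the full model
        quantization="int4",  # 4-bit base weights, QLoRA-style
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,
    )
    print(params)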