import os
import torch
from accelerate import PartialState
from huggingface_hub import HfApi
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoConfig, BitsAndBytesConfig, PaliGemmaForConditionalGeneration
from autotrain import logger
from autotrain.trainers.common import (
ALLOW_REMOTE_CODE,
LossLoggingCallback,
TrainStartCallback,
UploadLogs,
pause_space,
remove_autotrain_data,
save_training_params,
)
TARGET_MODULES = {}  # optional per-model defaults for LoRA target modules
SUPPORTED_MODELS = [
"PaliGemmaForConditionalGeneration",
# "Florence2ForConditionalGeneration", support later
]
MODEL_CARD = """
---
tags:
- autotrain
- text-generation-inference
- image-text-to-text
- text-generation{peft}
library_name: transformers{base_model}
license: other{dataset_tag}
---
# Model Trained Using AutoTrain
This model was trained using AutoTrain. For more information, please visit [AutoTrain](https://hf.co/docs/autotrain).
# Usage
```python
# you will need to adjust this code if you didn't use PEFT
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import requests
from peft import PeftModel
base_model_id = BASE_MODEL_ID  # replace with the base model repo id
peft_model_id = THIS_MODEL_ID  # replace with this adapter repo id
max_new_tokens = 100
text = "Whats on the flower?"
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/bee.JPG?download=true"
image = Image.open(requests.get(img_url, stream=True).raw)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_id)
processor = PaliGemmaProcessor.from_pretrained(base_model_id)
model = PeftModel.from_pretrained(base_model, peft_model_id)
model = model.merge_and_unload()
model = model.eval().to(device)
inputs = processor(text=text, images=image, return_tensors="pt").to(device)
with torch.inference_mode():
generated_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
)
result = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(result)
```
"""
def get_target_modules(config):
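    """
    Resolve the LoRA target modules from the training config.

    Falls back to the per-model defaults in TARGET_MODULES when the config
    does not specify any, passes the special "all-linear" keyword through
    unchanged, and otherwise splits the comma-separated list, e.g.
    "q_proj,v_proj" -> ["q_proj", "v_proj"].
    """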
if config.target_modules is None:
return TARGET_MODULES.get(config.model)
if config.target_modules.strip() == "":
return TARGET_MODULES.get(config.model)
if config.target_modules.strip().lower() == "all-linear":
return "all-linear"
return config.target_modules.split(",")
def create_model_card(config):
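    """
    Render the MODEL_CARD template for this run, filling in the peft tag,
    dataset tag, and base_model metadata. Local model paths and the
    auto-generated autotrain dataset are left out of the card metadata.
    """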
if config.peft:
peft = "\n- peft"
else:
peft = ""
if config.data_path == f"{config.project_name}/autotrain-data" or os.path.isdir(config.data_path):
dataset_tag = ""
else:
dataset_tag = f"\ndatasets:\n- {config.data_path}"
if os.path.isdir(config.model):
base_model = ""
else:
base_model = f"\nbase_model: {config.model}"
model_card = MODEL_CARD.format(
dataset_tag=dataset_tag,
peft=peft,
base_model=base_model,
)
return model_card.strip()
def check_model_support(config):
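    """
    Return True if the model's architecture, as reported by the Hugging
    Face Hub, appears in SUPPORTED_MODELS; otherwise return False.
    """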
api = HfApi(token=config.token)
model_info = api.model_info(config.model)
    architectures = (model_info.config or {}).get("architectures", [])  # config may be missing on the Hub
for arch in architectures:
if arch in SUPPORTED_MODELS:
return True
return False
def configure_logging_steps(config, train_data, valid_data):
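    """
    Determine the trainer's logging interval.

    When config.logging_steps is -1, use roughly 20% of one epoch's steps
    (based on the validation split if present), clamped to [1, 25]; for
    example, 1000 samples at batch size 8 give int(0.2 * 1000 / 8) = 25.
    Otherwise the configured value is used as-is.
    """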
logger.info("configuring logging steps")
if config.logging_steps == -1:
if config.valid_split is not None:
logging_steps = int(0.2 * len(valid_data) / config.batch_size)
else:
logging_steps = int(0.2 * len(train_data) / config.batch_size)
if logging_steps == 0:
logging_steps = 1
if logging_steps > 25:
logging_steps = 25
config.logging_steps = logging_steps
else:
logging_steps = config.logging_steps
logger.info(f"Logging steps: {logging_steps}")
return logging_steps
def configure_training_args(config, logging_steps):
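    """
    Assemble the keyword arguments for the Hugging Face TrainingArguments.

    Evaluation and checkpoint saving follow config.eval_strategy only when
    a validation split exists; gradient checkpointing and fp16/bf16 mixed
    precision are toggled from the config.
    """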
logger.info("configuring training args")
training_args = dict(
output_dir=config.project_name,
per_device_train_batch_size=config.batch_size,
per_device_eval_batch_size=config.batch_size,
learning_rate=config.lr,
num_train_epochs=config.epochs,
eval_strategy=config.eval_strategy if config.valid_split is not None else "no",
logging_steps=logging_steps,
save_total_limit=config.save_total_limit,
save_strategy=config.eval_strategy if config.valid_split is not None else "no",
gradient_accumulation_steps=config.gradient_accumulation,
report_to=config.log,
auto_find_batch_size=config.auto_find_batch_size,
lr_scheduler_type=config.scheduler,
optim=config.optimizer,
warmup_ratio=config.warmup_ratio,
weight_decay=config.weight_decay,
max_grad_norm=config.max_grad_norm,
push_to_hub=False,
load_best_model_at_end=True if config.valid_split is not None else False,
ddp_find_unused_parameters=False,
gradient_checkpointing=not config.disable_gradient_checkpointing,
remove_unused_columns=False,
)
if not config.disable_gradient_checkpointing:
        if config.peft and config.quantization in ("int4", "int8"):
            # quantized PEFT training uses reentrant gradient checkpointing
            training_args["gradient_checkpointing_kwargs"] = {"use_reentrant": True}
        else:
            training_args["gradient_checkpointing_kwargs"] = {"use_reentrant": False}
if config.mixed_precision == "fp16":
training_args["fp16"] = True
if config.mixed_precision == "bf16":
training_args["bf16"] = True
return training_args
def get_callbacks(config):
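    """Return the default trainer callbacks: log upload, loss logging, and training-start notification."""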
callbacks = [UploadLogs(config=config), LossLoggingCallback(), TrainStartCallback()]
return callbacks
def get_model(config):
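    """
    Load PaliGemmaForConditionalGeneration for training.

    With config.peft, the model is optionally quantized via bitsandbytes
    (int4/int8), prepared for k-bit training, and wrapped with a LoRA
    adapter; the vision tower and multi-modal projector stay frozen.
    """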
logger.info("loading model config...")
model_config = AutoConfig.from_pretrained(
config.model,
token=config.token,
trust_remote_code=ALLOW_REMOTE_CODE,
        use_cache=config.disable_gradient_checkpointing,  # KV cache is incompatible with gradient checkpointing
)
logger.info("loading model...")
if config.peft:
if config.quantization == "int4":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=False,
)
elif config.quantization == "int8":
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
else:
bnb_config = None
model = PaliGemmaForConditionalGeneration.from_pretrained(
config.model,
config=model_config,
token=config.token,
quantization_config=bnb_config,
trust_remote_code=ALLOW_REMOTE_CODE,
)
else:
model = PaliGemmaForConditionalGeneration.from_pretrained(
config.model,
config=model_config,
token=config.token,
trust_remote_code=ALLOW_REMOTE_CODE,
)
logger.info(f"model dtype: {model.dtype}")
if config.peft:
logger.info("preparing peft model...")
if config.quantization is not None:
gradient_checkpointing_kwargs = {}
if not config.disable_gradient_checkpointing:
if config.quantization in ("int4", "int8"):
gradient_checkpointing_kwargs = {"use_reentrant": True}
else:
gradient_checkpointing_kwargs = {"use_reentrant": False}
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=not config.disable_gradient_checkpointing,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
)
else:
model.enable_input_require_grads()
peft_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=get_target_modules(config),
)
model = get_peft_model(model, peft_config)
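        # freeze the vision encoder and projector so only the LoRA layers are trained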
for param in model.vision_tower.parameters():
param.requires_grad = False
for param in model.multi_modal_projector.parameters():
param.requires_grad = False
return model
def merge_adapter(base_model_path, target_model_path, adapter_path):
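    """
    Merge trained LoRA adapter weights from adapter_path into the base
    model loaded from base_model_path (in float16) and save the merged
    model to target_model_path.
    """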
logger.info("Loading adapter...")
model = PaliGemmaForConditionalGeneration.from_pretrained(
base_model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
trust_remote_code=ALLOW_REMOTE_CODE,
)
model = PeftModel.from_pretrained(model, adapter_path)
model = model.merge_and_unload()
logger.info("Saving target model...")
model.save_pretrained(target_model_path)
def post_training_steps(config, trainer):
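    """
    Save the trained model and its model card, optionally merge the LoRA
    adapter into the base weights, and, when config.push_to_hub is set,
    upload everything to a private Hub repo before pausing the Space.
    """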
logger.info("Finished training, saving model...")
trainer.model.config.use_cache = True
trainer.save_model(config.project_name)
model_card = create_model_card(config)
# save model card to output directory as README.md
with open(f"{config.project_name}/README.md", "w", encoding="utf-8") as f:
f.write(model_card)
if config.peft and config.merge_adapter:
logger.info("Merging adapter weights...")
try:
del trainer
torch.cuda.empty_cache()
merge_adapter(
base_model_path=config.model,
target_model_path=config.project_name,
adapter_path=config.project_name,
)
# remove adapter weights: adapter_*
for file in os.listdir(config.project_name):
if file.startswith("adapter_"):
os.remove(f"{config.project_name}/{file}")
except Exception as e:
logger.warning(f"Failed to merge adapter weights: {e}")
logger.warning("Skipping adapter merge. Only adapter weights will be saved.")
if config.push_to_hub:
if PartialState().process_index == 0:
# remove data folder
remove_autotrain_data(config)
logger.info("Pushing model to hub...")
save_training_params(config)
api = HfApi(token=config.token)
api.create_repo(
repo_id=f"{config.username}/{config.project_name}", repo_type="model", private=True, exist_ok=True
)
api.upload_folder(
folder_path=config.project_name,
repo_id=f"{config.username}/{config.project_name}",
repo_type="model",
)
if PartialState().process_index == 0:
pause_space(config)