import os

import torch
from accelerate import PartialState
from huggingface_hub import HfApi
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoConfig, BitsAndBytesConfig, PaliGemmaForConditionalGeneration

from autotrain import logger
from autotrain.trainers.common import (
    ALLOW_REMOTE_CODE,
    LossLoggingCallback,
    TrainStartCallback,
    UploadLogs,
    pause_space,
    remove_autotrain_data,
    save_training_params,
)

TARGET_MODULES = {}

SUPPORTED_MODELS = [
    "PaliGemmaForConditionalGeneration",
    # "Florence2ForConditionalGeneration",  # support to be added later
]

MODEL_CARD = """
---
tags:
- autotrain
- text-generation-inference
- image-text-to-text
- text-generation{peft}
library_name: transformers{base_model}
license: other{dataset_tag}
---

# Model Trained Using AutoTrain

This model was trained using AutoTrain. For more information, please visit [AutoTrain](https://hf.co/docs/autotrain).

# Usage

```python
# you will need to adjust the code below if you did not use PEFT
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import requests
from peft import PeftModel

base_model_id = BASE_MODEL_ID
peft_model_id = THIS_MODEL_ID
max_new_tokens = 100
text = "What's on the flower?"
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/bee.JPG?download=true"
image = Image.open(requests.get(img_url, stream=True).raw)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_id)
processor = PaliGemmaProcessor.from_pretrained(base_model_id)

model = PeftModel.from_pretrained(base_model, peft_model_id)
model = model.merge_and_unload()  # merge LoRA weights into the base model
model = model.eval().to(device)

inputs = processor(text=text, images=image, return_tensors="pt").to(device)
with torch.inference_mode():
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
result = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(result)
```
"""


def get_target_modules(config):
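    """
    Resolve the LoRA target modules from the training config. Falls back to
    the per-model defaults in TARGET_MODULES when config.target_modules is
    unset or empty, passes the special value "all-linear" through unchanged,
    and otherwise splits the comma-separated string into a list of names.
    """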
    if config.target_modules is None:
        return TARGET_MODULES.get(config.model)
    if config.target_modules.strip() == "":
        return TARGET_MODULES.get(config.model)
    if config.target_modules.strip().lower() == "all-linear":
        return "all-linear"
    return config.target_modules.split(",")


def create_model_card(config):
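    """
    Build the README model card from the MODEL_CARD template. The PEFT tag
    is added when adapters were trained, and the dataset and base-model
    fields are included only when they reference Hub repos rather than
    local paths or AutoTrain's internal data folder.
    """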
    if config.peft:
        peft = "\n- peft"
    else:
        peft = ""

    if config.data_path == f"{config.project_name}/autotrain-data" or os.path.isdir(config.data_path):
        dataset_tag = ""
    else:
        dataset_tag = f"\ndatasets:\n- {config.data_path}"

    if os.path.isdir(config.model):
        base_model = ""
    else:
        base_model = f"\nbase_model: {config.model}"

    model_card = MODEL_CARD.format(
        dataset_tag=dataset_tag,
        peft=peft,
        base_model=base_model,
    )
    return model_card.strip()


def check_model_support(config):
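    """
    Return True if the model's architecture, as reported by the Hugging Face
    Hub, is one of the architectures listed in SUPPORTED_MODELS.
    """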
    api = HfApi(token=config.token)
    model_info = api.model_info(config.model)
    architectures = model_info.config.get("architectures", [])
    for arch in architectures:
        if arch in SUPPORTED_MODELS:
            return True
    return False


def configure_logging_steps(config, train_data, valid_data):
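    """
    Resolve the logging interval. When config.logging_steps is -1, derive it
    as roughly 20% of an epoch's steps (from the validation split if one
    exists, otherwise the training split), clamped to the range [1, 25];
    otherwise use the configured value as-is.
    """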
logger.info("configuring logging steps") | |
if config.logging_steps == -1: | |
if config.valid_split is not None: | |
logging_steps = int(0.2 * len(valid_data) / config.batch_size) | |
else: | |
logging_steps = int(0.2 * len(train_data) / config.batch_size) | |
if logging_steps == 0: | |
logging_steps = 1 | |
if logging_steps > 25: | |
logging_steps = 25 | |
config.logging_steps = logging_steps | |
else: | |
logging_steps = config.logging_steps | |
logger.info(f"Logging steps: {logging_steps}") | |
return logging_steps | |


def configure_training_args(config, logging_steps):
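    """
    Assemble the keyword arguments for the Hugging Face Trainer from the
    AutoTrain config: evaluation and save strategies, gradient accumulation
    and checkpointing, scheduler/optimizer choices, and mixed precision.
    """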
logger.info("configuring training args") | |
training_args = dict( | |
output_dir=config.project_name, | |
per_device_train_batch_size=config.batch_size, | |
per_device_eval_batch_size=config.batch_size, | |
learning_rate=config.lr, | |
num_train_epochs=config.epochs, | |
eval_strategy=config.eval_strategy if config.valid_split is not None else "no", | |
logging_steps=logging_steps, | |
save_total_limit=config.save_total_limit, | |
save_strategy=config.eval_strategy if config.valid_split is not None else "no", | |
gradient_accumulation_steps=config.gradient_accumulation, | |
report_to=config.log, | |
auto_find_batch_size=config.auto_find_batch_size, | |
lr_scheduler_type=config.scheduler, | |
optim=config.optimizer, | |
warmup_ratio=config.warmup_ratio, | |
weight_decay=config.weight_decay, | |
max_grad_norm=config.max_grad_norm, | |
push_to_hub=False, | |
load_best_model_at_end=True if config.valid_split is not None else False, | |
ddp_find_unused_parameters=False, | |
gradient_checkpointing=not config.disable_gradient_checkpointing, | |
remove_unused_columns=False, | |
) | |
if not config.disable_gradient_checkpointing: | |
if config.peft and config.quantization in ("int4", "int8"): | |
training_args["gradient_checkpointing_kwargs"] = {"use_reentrant": True} | |
else: | |
training_args["gradient_checkpointing_kwargs"] = {"use_reentrant": False} | |
if config.mixed_precision == "fp16": | |
training_args["fp16"] = True | |
if config.mixed_precision == "bf16": | |
training_args["bf16"] = True | |
return training_args | |


def get_callbacks(config):
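    """
    Return the default trainer callbacks: UploadLogs, LossLoggingCallback,
    and TrainStartCallback.
    """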
    callbacks = [UploadLogs(config=config), LossLoggingCallback(), TrainStartCallback()]
    return callbacks


def get_model(config):
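    """
    Load the PaliGemma model described by the config. When PEFT is enabled,
    optionally quantize to int4/int8 via bitsandbytes, prepare the model for
    k-bit training, and wrap it with a LoRA adapter. The vision tower and
    multi-modal projector are frozen so only the language-model weights (or
    their LoRA adapters) are updated.
    """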
logger.info("loading model config...") | |
model_config = AutoConfig.from_pretrained( | |
config.model, | |
token=config.token, | |
trust_remote_code=ALLOW_REMOTE_CODE, | |
use_cache=config.disable_gradient_checkpointing, | |
) | |
logger.info("loading model...") | |
if config.peft: | |
if config.quantization == "int4": | |
bnb_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_quant_type="nf4", | |
bnb_4bit_compute_dtype=torch.float16, | |
bnb_4bit_use_double_quant=False, | |
) | |
elif config.quantization == "int8": | |
bnb_config = BitsAndBytesConfig(load_in_8bit=True) | |
else: | |
bnb_config = None | |
model = PaliGemmaForConditionalGeneration.from_pretrained( | |
config.model, | |
config=model_config, | |
token=config.token, | |
quantization_config=bnb_config, | |
trust_remote_code=ALLOW_REMOTE_CODE, | |
) | |
else: | |
model = PaliGemmaForConditionalGeneration.from_pretrained( | |
config.model, | |
config=model_config, | |
token=config.token, | |
trust_remote_code=ALLOW_REMOTE_CODE, | |
) | |
logger.info(f"model dtype: {model.dtype}") | |
if config.peft: | |
logger.info("preparing peft model...") | |
if config.quantization is not None: | |
gradient_checkpointing_kwargs = {} | |
if not config.disable_gradient_checkpointing: | |
if config.quantization in ("int4", "int8"): | |
gradient_checkpointing_kwargs = {"use_reentrant": True} | |
else: | |
gradient_checkpointing_kwargs = {"use_reentrant": False} | |
model = prepare_model_for_kbit_training( | |
model, | |
use_gradient_checkpointing=not config.disable_gradient_checkpointing, | |
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, | |
) | |
else: | |
model.enable_input_require_grads() | |
peft_config = LoraConfig( | |
r=config.lora_r, | |
lora_alpha=config.lora_alpha, | |
lora_dropout=config.lora_dropout, | |
bias="none", | |
task_type="CAUSAL_LM", | |
target_modules=get_target_modules(config), | |
) | |
model = get_peft_model(model, peft_config) | |
for param in model.vision_tower.parameters(): | |
param.requires_grad = False | |
for param in model.multi_modal_projector.parameters(): | |
param.requires_grad = False | |
return model | |


def merge_adapter(base_model_path, target_model_path, adapter_path):
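    """
    Merge LoRA adapter weights from adapter_path into the base model loaded
    from base_model_path and save the merged model to target_model_path.
    """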
logger.info("Loading adapter...") | |
model = PaliGemmaForConditionalGeneration.from_pretrained( | |
base_model_path, | |
torch_dtype=torch.float16, | |
low_cpu_mem_usage=True, | |
trust_remote_code=ALLOW_REMOTE_CODE, | |
) | |
model = PeftModel.from_pretrained(model, adapter_path) | |
model = model.merge_and_unload() | |
logger.info("Saving target model...") | |
model.save_pretrained(target_model_path) | |


def post_training_steps(config, trainer):
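    """
    Finalize training: save the model and its README model card, optionally
    merge the LoRA adapter into the base weights, push the result to the
    Hugging Face Hub when configured, and pause the running Space if
    applicable.
    """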
logger.info("Finished training, saving model...") | |
trainer.model.config.use_cache = True | |
trainer.save_model(config.project_name) | |
model_card = create_model_card(config) | |
# save model card to output directory as README.md | |
with open(f"{config.project_name}/README.md", "w", encoding="utf-8") as f: | |
f.write(model_card) | |
if config.peft and config.merge_adapter: | |
logger.info("Merging adapter weights...") | |
try: | |
del trainer | |
torch.cuda.empty_cache() | |
merge_adapter( | |
base_model_path=config.model, | |
target_model_path=config.project_name, | |
adapter_path=config.project_name, | |
) | |
# remove adapter weights: adapter_* | |
for file in os.listdir(config.project_name): | |
if file.startswith("adapter_"): | |
os.remove(f"{config.project_name}/{file}") | |
except Exception as e: | |
logger.warning(f"Failed to merge adapter weights: {e}") | |
logger.warning("Skipping adapter merge. Only adapter weights will be saved.") | |
if config.push_to_hub: | |
if PartialState().process_index == 0: | |
# remove data folder | |
remove_autotrain_data(config) | |
logger.info("Pushing model to hub...") | |
save_training_params(config) | |
api = HfApi(token=config.token) | |
api.create_repo( | |
repo_id=f"{config.username}/{config.project_name}", repo_type="model", private=True, exist_ok=True | |
) | |
api.upload_folder( | |
folder_path=config.project_name, | |
repo_id=f"{config.username}/{config.project_name}", | |
repo_type="model", | |
) | |
if PartialState().process_index == 0: | |
pause_space(config) | |