# hardiktiwari's picture
# Upload 244 files
# 33d4721 verified
import os
import torch
from peft import set_peft_model_state_dict
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SavePeftModelCallback(TrainerCallback):
    """Save the PEFT adapter into each checkpoint folder the Trainer creates.

    On every ``on_save`` event the model's ``save_pretrained`` output is written to
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``. An empty state dict is
    also written there as ``pytorch_model.bin``.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Write the adapter (and an empty ``pytorch_model.bin``) for the current step.

        Args:
            args: Training arguments. ``output_dir`` locates the checkpoint folder and
                ``should_save`` selects which process(es) write to disk.
            state: Trainer state. ``global_step`` names the checkpoint folder.
            control: Returned unchanged.
            **kwargs: Must contain ``"model"``, the model whose ``save_pretrained``
                is called.

        Returns:
            ``control``, unmodified.
        """
        # The Trainer fires on_save on every process. Only write from the
        # process(es) the Trainer itself writes checkpoints from; otherwise several
        # ranks race to truncate and rewrite the same files.
        if not args.should_save:
            return control

        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        kwargs["model"].save_pretrained(checkpoint_folder)

        # Placeholder full-model weights file containing an empty state dict.
        # NOTE(review): presumably kept so checkpoint-resume logic that expects a
        # weights file finds one without storing the base model - confirm.
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        torch.save({}, pytorch_model_path)
        return control
class LoadBestPeftModelCallback(TrainerCallback):
    """Reload the best checkpoint's adapter weights into the model when training ends."""

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Load ``adapter_model.bin`` from ``state.best_model_checkpoint`` into the model.

        Args:
            args: Training arguments (unused).
            state: Trainer state. ``best_model_checkpoint`` gives the folder to load
                from and ``best_metric`` is reported in the log line.
            control: Returned unchanged.
            **kwargs: Must contain ``"model"``, the model passed to
                ``set_peft_model_state_dict``.

        Returns:
            ``control``, unmodified.
        """
        best_checkpoint = state.best_model_checkpoint
        if best_checkpoint is None:
            # No best checkpoint was recorded. This presumably happens when no
            # evaluation/best-metric tracking was configured. Joining a path onto
            # None would raise TypeError, so keep the current weights instead.
            print("No best model checkpoint recorded; keeping the current peft model weights.")
            return control

        print(f"Loading best peft model from {best_checkpoint} (score: {state.best_metric}).")
        best_model_path = os.path.join(best_checkpoint, "adapter_model.bin")
        # Without map_location, torch.load restores tensors to the device they were
        # saved from, which may be unavailable or waste GPU memory. Load to CPU and
        # let set_peft_model_state_dict place the values into the model.
        adapters_weights = torch.load(best_model_path, map_location="cpu")
        model = kwargs["model"]
        set_peft_model_state_dict(model, adapters_weights)
        return control
class SaveDeepSpeedPeftModelCallback(TrainerCallback):
    """Periodically save the model from a DeepSpeed-wrapped trainer into ``output_dir``."""

    def __init__(self, trainer, save_steps=500):
        """Store the trainer and the save interval.

        Args:
            trainer: Trainer whose ``accelerator`` and ``deepspeed`` attributes are
                used to gather, unwrap and save the model.
            save_steps: Save after every ``save_steps`` completed training steps.
        """
        self.trainer = trainer
        self.save_steps = save_steps

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Save the model to ``args.output_dir`` every ``save_steps`` steps.

        Args:
            args: Training arguments. ``output_dir`` is the save destination and is
                overwritten on each save.
            state: Trainer state. ``global_step`` drives the save schedule.
            control: Returned unchanged.
            **kwargs: Unused.

        Returns:
            ``control``, unmodified.
        """
        # global_step counts completed update steps when on_step_end fires. Test it
        # directly, matching the Trainer's own save schedule; the previous
        # "global_step + 1" test saved one step early (499, 999, ...).
        if state.global_step % self.save_steps != 0:
            return control

        accelerator = self.trainer.accelerator
        accelerator.wait_for_everyone()
        # State-dict gathering and unwrapping run on every rank. This is presumably
        # required for sharded (e.g. ZeRO-3) setups - confirm. Only the main
        # process writes to disk.
        state_dict = accelerator.get_state_dict(self.trainer.deepspeed)
        unwrapped_model = accelerator.unwrap_model(self.trainer.deepspeed)
        if accelerator.is_main_process:
            unwrapped_model.save_pretrained(args.output_dir, state_dict=state_dict)
        accelerator.wait_for_everyone()
        return control