Spaces:
Sleeping
Sleeping
from functools import partial | |
import torch | |
from peft import LoraConfig | |
from transformers import AutoConfig, AutoModelForSequenceClassification, BitsAndBytesConfig | |
from transformers.trainer_callback import PrinterCallback | |
from trl import RewardConfig, RewardTrainer | |
from autotrain import logger | |
from autotrain.trainers.clm import utils | |
from autotrain.trainers.clm.params import LLMTrainingParams | |
from autotrain.trainers.common import ALLOW_REMOTE_CODE | |
def _build_quantization_config(config):
    """Return the bitsandbytes config requested by ``config.quantization``.

    ``"int4"`` selects NF4 4-bit weights with float16 compute and no double
    quantization; ``"int8"`` selects 8-bit weights. Any other value returns
    ``None``, meaning the model is loaded unquantized.
    """
    if config.quantization == "int4":
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
    if config.quantization == "int8":
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def _load_reward_model(config, tokenizer):
    """Load ``config.model`` as a single-output sequence classifier and size it to ``tokenizer``.

    Returns:
        The ``AutoModelForSequenceClassification`` instance, with its token
        embeddings resized to ``len(tokenizer)``.
    """
    logger.info("loading model config...")
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
        # use_cache is enabled only when gradient checkpointing is disabled.
        use_cache=config.disable_gradient_checkpointing,
    )
    # One output logit: the model produces a single scalar score per sequence.
    model_config.num_labels = 1
    model_config.pad_token_id = tokenizer.pad_token_id
    model_config.pad_token = tokenizer.pad_token

    logger.info("loading model...")
    model_kwargs = {
        "config": model_config,
        "token": config.token,
        "trust_remote_code": ALLOW_REMOTE_CODE,
    }
    # Quantization is only honoured when training LoRA adapters (PEFT);
    # without PEFT the weights are always loaded unquantized.
    if config.peft:
        bnb_config = _build_quantization_config(config)
        if bnb_config is not None:
            model_kwargs["quantization_config"] = bnb_config
    if config.use_flash_attention_2:
        # `use_flash_attention_2=True` is deprecated in transformers in favour
        # of `attn_implementation="flash_attention_2"`; the effect is the same.
        model_kwargs["attn_implementation"] = "flash_attention_2"

    model = AutoModelForSequenceClassification.from_pretrained(config.model, **model_kwargs)
    logger.info(f"model dtype: {model.dtype}")
    model.resize_token_embeddings(len(tokenizer))
    return model


def _fits_block_size(example, block_size):
    """Return True when both the chosen and rejected token id lists are at most ``block_size`` long."""
    return len(example["input_ids_chosen"]) <= block_size and len(example["input_ids_rejected"]) <= block_size


def _tokenize_pairs(dataset, reward_proc, block_size, desc):
    """Tokenize a preference dataset with ``reward_proc`` and drop over-long pairs.

    Args:
        dataset: dataset exposing ``.map`` / ``.filter`` (presumably a
            ``datasets.Dataset`` — confirm against ``utils.process_input_data``).
        reward_proc: batched preprocessing callable producing
            ``input_ids_chosen`` / ``input_ids_rejected`` columns.
        block_size: maximum allowed token length for either side of a pair.
        desc: progress-bar description for the tokenization pass.
    """
    dataset = dataset.map(
        reward_proc,
        batched=True,
        num_proc=4,
        desc=desc,
    )
    # A pair is discarded if either its chosen or rejected side exceeds block_size.
    return dataset.filter(partial(_fits_block_size, block_size=block_size))


def train(config):
    """Train a reward model on chosen/rejected preference pairs with TRL's ``RewardTrainer``.

    Args:
        config: an ``LLMTrainingParams`` instance, or a dict of its fields.
    """
    logger.info("Starting Reward training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)

    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)
    training_args["max_length"] = config.block_size
    args = RewardConfig(**training_args)

    model = _load_reward_model(config, tokenizer)

    peft_config = None
    if config.peft:
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="SEQ_CLS",
            target_modules=utils.get_target_modules(config),
        )

    reward_proc = partial(utils.preprocess_reward, tokenizer=tokenizer)
    train_data = _tokenize_pairs(
        train_data,
        reward_proc,
        config.block_size,
        "Running tokenizer on train dataset",
    )
    eval_data = None
    if config.valid_split is not None:
        eval_data = _tokenize_pairs(
            valid_data,
            reward_proc,
            config.block_size,
            "Running tokenizer on validation dataset",
        )

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)
    trainer = RewardTrainer(
        args=args,
        model=model,
        callbacks=callbacks,
        train_dataset=train_data,
        eval_dataset=eval_data,
        peft_config=peft_config,
        processing_class=tokenizer,
    )
    # Remove transformers' default PrinterCallback so logs are not printed to stdout.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)