Spaces:
Sleeping
Sleeping
File size: 4,370 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from functools import partial
import torch
from peft import LoraConfig
from transformers import AutoConfig, AutoModelForSequenceClassification, BitsAndBytesConfig
from transformers.trainer_callback import PrinterCallback
from trl import RewardConfig, RewardTrainer
from autotrain import logger
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.common import ALLOW_REMOTE_CODE
def _fits_block_size(example, block_size):
    """Return True when both the chosen and rejected token sequences fit in ``block_size``."""
    return len(example["input_ids_chosen"]) <= block_size and len(example["input_ids_rejected"]) <= block_size


def _build_quantization_config(config):
    """Return the bitsandbytes quantization config for the model, or None.

    Quantization is only applied when LoRA (``config.peft``) is enabled:
    quantized base weights are not trained directly here, only the adapters are.
    ``config.quantization`` values other than ``"int4"``/``"int8"`` mean no quantization.
    """
    if not config.peft:
        return None
    if config.quantization == "int4":
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
    if config.quantization == "int8":
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def _tokenize_reward_split(dataset, reward_proc, block_size, desc):
    """Tokenize a preference split with ``reward_proc`` and drop over-long pairs.

    A pair is kept only if BOTH its chosen and rejected sequences are at most
    ``block_size`` tokens long.
    """
    dataset = dataset.map(
        reward_proc,
        batched=True,
        num_proc=4,
        desc=desc,
    )
    return dataset.filter(partial(_fits_block_size, block_size=block_size))


def train(config):
    """Train a reward model (single-score sequence classifier) with TRL's RewardTrainer.

    Args:
        config: an ``LLMTrainingParams`` instance, or a dict of its fields
            (converted with ``LLMTrainingParams(**config)``).

    Steps: load and chat-template the data, load a sequence-classification
    model with ``num_labels = 1`` (optionally quantized and wrapped with LoRA
    when ``config.peft`` is set), tokenize chosen/rejected pairs and drop those
    longer than ``config.block_size``, then train and run
    ``utils.post_training_steps``.

    Returns:
        None.
    """
    logger.info("Starting Reward training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)

    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)
    training_args["max_length"] = config.block_size
    args = RewardConfig(**training_args)

    logger.info("loading model config...")
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
        # The KV cache is only enabled when gradient checkpointing is disabled.
        use_cache=config.disable_gradient_checkpointing,
    )
    # A reward model emits a single scalar score per sequence.
    model_config.num_labels = 1
    model_config.pad_token_id = tokenizer.pad_token_id
    model_config.pad_token = tokenizer.pad_token

    logger.info("loading model...")
    # quantization_config=None is equivalent to not passing it, so one call covers
    # both the quantized (peft) and full-precision paths.
    # `attn_implementation` replaces the deprecated `use_flash_attention_2` kwarg;
    # None keeps the library's default attention implementation.
    model = AutoModelForSequenceClassification.from_pretrained(
        config.model,
        config=model_config,
        token=config.token,
        quantization_config=_build_quantization_config(config),
        trust_remote_code=ALLOW_REMOTE_CODE,
        attn_implementation="flash_attention_2" if config.use_flash_attention_2 else None,
    )
    logger.info(f"model dtype: {model.dtype}")
    # Presumably the tokenizer may have gained tokens (e.g. a pad token) in
    # utils.get_tokenizer; keep the embedding matrix in sync — TODO confirm.
    model.resize_token_embeddings(len(tokenizer))

    peft_config = None
    if config.peft:
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="SEQ_CLS",
            target_modules=utils.get_target_modules(config),
        )

    reward_proc = partial(utils.preprocess_reward, tokenizer=tokenizer)
    train_data = _tokenize_reward_split(
        train_data,
        reward_proc,
        config.block_size,
        "Running tokenizer on train dataset",
    )
    eval_data = None
    if config.valid_split is not None:
        eval_data = _tokenize_reward_split(
            valid_data,
            reward_proc,
            config.block_size,
            "Running tokenizer on validation dataset",
        )

    logger.info("creating trainer")
    trainer = RewardTrainer(
        args=args,
        model=model,
        callbacks=utils.get_callbacks(config),
        train_dataset=train_data,
        eval_dataset=eval_data,
        peft_config=peft_config,
        processing_class=tokenizer,
    )
    # Drop the default stdout printer callback.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)
|