from peft import LoraConfig
from transformers.trainer_callback import PrinterCallback
from trl import SFTConfig, SFTTrainer

from autotrain import logger
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.params import LLMTrainingParams


def train(config):
    logger.info("Starting SFT training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)

    # Load the dataset and tokenizer, then apply the chat template to the text column.
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    # Build the TRL SFTConfig from the shared CLM training arguments.
    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)
    training_args["dataset_text_field"] = config.text_column
    training_args["max_seq_length"] = config.block_size
    training_args["packing"] = True
    args = SFTConfig(**training_args)

    model = utils.get_model(config, tokenizer)

    # Optional LoRA adapter configuration when PEFT is enabled.
    if config.peft:
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=utils.get_target_modules(config),
        )

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)
    trainer_args = dict(
        args=args,
        model=model,
        callbacks=callbacks,
    )
    trainer = SFTTrainer(
        **trainer_args,
        train_dataset=train_data,
        eval_dataset=valid_data if config.valid_split is not None else None,
        peft_config=peft_config if config.peft else None,
        processing_class=tokenizer,
    )

    # Drop the default console printer callback, train, and run post-training steps
    # (e.g. saving/merging the model).
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)
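

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of calling `train` with a plain dict, which the function
# converts to `LLMTrainingParams` on entry. The field names used here
# ("model", "project_name", "data_path", "text_column", "lora_*", ...) are
# assumptions about the params schema; check `LLMTrainingParams` for the
# authoritative field names and defaults before running.
if __name__ == "__main__":
    example_config = {
        "model": "meta-llama/Llama-3.2-1B",  # assumed field: base model id
        "project_name": "sft-demo",  # assumed field: output/project directory
        "data_path": "timdettmers/openassistant-guanaco",  # assumed field: dataset path
        "text_column": "text",  # assumed field: column holding the training text
        "peft": True,  # exercises the LoraConfig branch above
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
    }
    train(example_config)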