# NOTE(review): the three lines that preceded the imports ("Spaces: / Sleeping /
# Sleeping") were web-scrape UI residue, not source code; commented out here.
from peft import LoraConfig
from transformers.trainer_callback import PrinterCallback
from trl import ORPOConfig, ORPOTrainer

from autotrain import logger
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.params import LLMTrainingParams
def train(config):
    """Run ORPO (Odds Ratio Preference Optimization) fine-tuning.

    Args:
        config: An ``LLMTrainingParams`` instance, or a plain dict of its
            fields (converted to ``LLMTrainingParams`` before use).

    Side effects:
        Trains the model and runs ``utils.post_training_steps`` (saving /
        export is handled there); returns ``None``.
    """
    logger.info("Starting ORPO training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)

    # Load and tokenize the datasets; chat-template processing rewrites the
    # prompt/chosen/rejected columns into the format ORPOTrainer expects.
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)

    # ORPO-specific sequence-length limits on top of the generic HF args.
    training_args["max_length"] = config.block_size
    training_args["max_prompt_length"] = config.max_prompt_length
    training_args["max_completion_length"] = config.max_completion_length
    args = ORPOConfig(**training_args)

    model = utils.get_model(config, tokenizer)

    # Bind peft_config unconditionally so the ORPOTrainer call below never
    # depends on the branch having run; None disables PEFT in the trainer.
    peft_config = None
    if config.peft:
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=utils.get_target_modules(config),
        )

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)
    trainer = ORPOTrainer(
        model=model,
        args=args,
        train_dataset=train_data,
        eval_dataset=valid_data if config.valid_split is not None else None,
        processing_class=tokenizer,
        peft_config=peft_config,
        callbacks=callbacks,
    )
    # Drop transformers' default console printer callback.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)