from peft import LoraConfig
from transformers.trainer_callback import PrinterCallback
from trl import SFTConfig, SFTTrainer

from autotrain import logger
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.params import LLMTrainingParams

def train(config):
    logger.info("Starting SFT training...")
    # Accept either a plain dict or an LLMTrainingParams instance.
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)

    # Load the raw datasets, then apply the tokenizer's chat template.
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    # Build the TRL SFTConfig: logging cadence, text field, sequence length, and packing.
    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)
    training_args["dataset_text_field"] = config.text_column
    training_args["max_seq_length"] = config.block_size
    training_args["packing"] = True
    args = SFTConfig(**training_args)

    model = utils.get_model(config, tokenizer)

    # When PEFT is enabled, train LoRA adapters instead of the full model.
    if config.peft:
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=utils.get_target_modules(config),
        )

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)
    trainer_args = dict(
        args=args,
        model=model,
        callbacks=callbacks,
    )
    trainer = SFTTrainer(
        **trainer_args,
        train_dataset=train_data,
        eval_dataset=valid_data if config.valid_split is not None else None,
        peft_config=peft_config if config.peft else None,
        processing_class=tokenizer,
    )

    # Drop the default stdout printer; progress is reported via the configured callbacks.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)
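
# Usage sketch (illustrative, not part of the upstream module): `train` accepts a
# plain dict, which it coerces into LLMTrainingParams above. The keys shown here
# are assumptions limited to the attributes this file actually reads
# (text_column, valid_split, peft, lora_r, lora_alpha, lora_dropout); see
# LLMTrainingParams for the full, authoritative schema.
#
#     train(
#         {
#             "text_column": "text",
#             "valid_split": None,   # no eval_dataset is passed when this is None
#             "peft": True,          # enables the LoraConfig branch above
#             "lora_r": 16,
#             "lora_alpha": 32,
#             "lora_dropout": 0.05,
#             # ...model, dataset, and optimizer fields omitted...
#         }
#     )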