File size: 3,749 Bytes
33d4721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from functools import partial

import torch
from datasets import Dataset
from peft.tuners.lora import LoraLayer
from transformers import Trainer, TrainingArguments, default_data_collator
from transformers.trainer_callback import PrinterCallback

from autotrain import logger
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.params import LLMTrainingParams


def process_data(data, tokenizer, config):
    """Prepare a raw dataset for causal-LM training.

    Converts *data* to a pandas frame, replaces missing values with empty
    strings, keeps only the configured text column, optionally appends the
    tokenizer's EOS token to every row, and converts the result back into a
    ``datasets.Dataset``.

    Args:
        data: A ``datasets.Dataset`` exposing ``to_pandas()``.
        tokenizer: Tokenizer providing ``eos_token``.
        config: Training params with ``text_column`` and ``add_eos_token``.

    Returns:
        A ``datasets.Dataset`` containing only the text column.
    """
    frame = data.to_pandas().fillna("")
    frame = frame[[config.text_column]]
    if config.add_eos_token:
        frame[config.text_column] = frame[config.text_column] + tokenizer.eos_token
    return Dataset.from_pandas(frame)


def train(config):
    """Run the default/generic causal-LM training pipeline.

    Loads and preprocesses the train/validation data, builds the model and
    ``transformers.Trainer``, normalizes parameter dtypes for mixed-precision
    LoRA training, trains, and hands off to ``utils.post_training_steps``.

    Args:
        config: An ``LLMTrainingParams`` instance, or a plain dict of its
            fields (coerced below).
    """
    logger.info("Starting default/generic CLM training...")
    # Accept a raw dict (e.g. deserialized JSON config) as well as the
    # typed params object.
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    # Reduce each split to the text column (+ optional EOS token).
    train_data = process_data(
        data=train_data,
        tokenizer=tokenizer,
        config=config,
    )
    if config.valid_split is not None:
        valid_data = process_data(
            data=valid_data,
            tokenizer=tokenizer,
            config=config,
        )

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    # NOTE: configure_block_size mutates/replaces config; it must run before
    # get_model and before the group_texts mapping below, which both read
    # config.block_size (presumably — confirm against utils).
    config = utils.configure_block_size(config, tokenizer)
    args = TrainingArguments(**training_args)

    model = utils.get_model(config, tokenizer)

    # Bind static args so datasets.map can call these with batches only.
    tokenize_fn = partial(utils.tokenize, tokenizer=tokenizer, config=config)
    group_texts_fn = partial(utils.group_texts, config=config)

    # Tokenization replaces all original columns with token fields.
    train_data = train_data.map(
        tokenize_fn,
        batched=True,
        num_proc=1,
        remove_columns=list(train_data.features),
        desc="Running tokenizer on train dataset",
    )

    if config.valid_split is not None:
        valid_data = valid_data.map(
            tokenize_fn,
            batched=True,
            num_proc=1,
            remove_columns=list(valid_data.features),
            desc="Running tokenizer on validation dataset",
        )

    # Concatenate and re-chunk token streams into block_size-sized examples.
    train_data = train_data.map(
        group_texts_fn,
        batched=True,
        num_proc=4,
        desc=f"Grouping texts in chunks of {config.block_size}",
    )

    if config.valid_split is not None:
        valid_data = valid_data.map(
            group_texts_fn,
            batched=True,
            num_proc=4,
            desc=f"Grouping texts in chunks of {config.block_size}",
        )

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)
    trainer_args = dict(
        args=args,
        model=model,
        callbacks=callbacks,
    )
    trainer = Trainer(
        **trainer_args,
        train_dataset=train_data,
        eval_dataset=valid_data if config.valid_split is not None else None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )
    # Dtype normalization for mixed-precision LoRA fine-tuning:
    #  - LoRA adapter layers -> bfloat16 when training in bf16;
    #  - any module whose name contains "norm" -> float32 for numerical
    #    stability (applied regardless of precision mode);
    #  - output head / embedding modules -> bfloat16 when they are still
    #    float32 under bf16 training.
    # nn.Module.to() converts parameters in place, so rebinding `module`
    # here still affects trainer.model.
    for name, module in trainer.model.named_modules():
        if isinstance(module, LoraLayer):
            if config.mixed_precision == "bf16":
                module = module.to(torch.bfloat16)
        if "norm" in name:
            module = module.to(torch.float32)
        if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
            if hasattr(module, "weight"):
                if config.mixed_precision == "bf16" and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

    # Suppress the default print-to-stdout progress callback; project
    # callbacks from utils.get_callbacks handle reporting.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)