more dpo fixes for dataset loading and docs (#1185) [skip ci]
Browse files* more dpo fixes for dataset loading and docs
* preprocess dpo datasets
- docs/rlhf.md +10 -0
- src/axolotl/cli/preprocess.py +6 -1
- src/axolotl/core/trainer_builder.py +6 -1
- src/axolotl/utils/data.py +51 -2
    	
        docs/rlhf.md
    CHANGED
    
    | @@ -34,6 +34,16 @@ datasets: | |
| 34 | 
             
            rl: ipo
         | 
| 35 | 
             
            ```
         | 
| 36 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 37 | 
             
            #### Trl autounwrap for peft
         | 
| 38 |  | 
| 39 | 
             
            Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
         | 
|  | |
| 34 | 
             
            rl: ipo
         | 
| 35 | 
             
            ```
         | 
| 36 |  | 
| 37 | 
            +
            #### Using local dataset files
         | 
| 38 | 
            +
            ```yaml
         | 
| 39 | 
            +
            datasets:
         | 
| 40 | 
            +
              - ds_type: json
         | 
| 41 | 
            +
                data_files:
         | 
| 42 | 
            +
                  - orca_rlhf.jsonl
         | 
| 43 | 
            +
                split: train
         | 
| 44 | 
            +
                type: chatml.intel
         | 
| 45 | 
            +
            ```
         | 
| 46 | 
            +
             | 
| 47 | 
             
            #### Trl autounwrap for peft
         | 
| 48 |  | 
| 49 | 
             
            Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
         | 
    	
        src/axolotl/cli/preprocess.py
    CHANGED
    
    | @@ -13,6 +13,7 @@ from axolotl.cli import ( | |
| 13 | 
             
                check_user_token,
         | 
| 14 | 
             
                load_cfg,
         | 
| 15 | 
             
                load_datasets,
         | 
|  | |
| 16 | 
             
                print_axolotl_text_art,
         | 
| 17 | 
             
            )
         | 
| 18 | 
             
            from axolotl.common.cli import PreprocessCliArgs
         | 
| @@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs): | |
| 43 | 
             
                    LOG.warning(msg)
         | 
| 44 | 
             
                    parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
         | 
| 45 |  | 
| 46 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
| 47 | 
             
                LOG.info(
         | 
| 48 | 
             
                    Fore.GREEN
         | 
| 49 | 
             
                    + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
         | 
|  | |
| 13 | 
             
                check_user_token,
         | 
| 14 | 
             
                load_cfg,
         | 
| 15 | 
             
                load_datasets,
         | 
| 16 | 
            +
                load_rl_datasets,
         | 
| 17 | 
             
                print_axolotl_text_art,
         | 
| 18 | 
             
            )
         | 
| 19 | 
             
            from axolotl.common.cli import PreprocessCliArgs
         | 
|  | |
| 44 | 
             
                    LOG.warning(msg)
         | 
| 45 | 
             
                    parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
         | 
| 46 |  | 
| 47 | 
            +
                if parsed_cfg.rl:
         | 
| 48 | 
            +
                    load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
         | 
| 49 | 
            +
                else:
         | 
| 50 | 
            +
                    load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
         | 
| 51 | 
            +
             | 
| 52 | 
             
                LOG.info(
         | 
| 53 | 
             
                    Fore.GREEN
         | 
| 54 | 
             
                    + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
         | 
    	
        src/axolotl/core/trainer_builder.py
    CHANGED
    
    | @@ -996,6 +996,12 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): | |
| 996 | 
             
                    training_args_kwargs["lr_scheduler_kwargs"] = (
         | 
| 997 | 
             
                        self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         | 
| 998 | 
             
                    )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 999 |  | 
| 1000 | 
             
                    if self.cfg.dataloader_pin_memory is not None:
         | 
| 1001 | 
             
                        training_args_kwargs[
         | 
| @@ -1013,7 +1019,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): | |
| 1013 | 
             
                    training_args = TrainingArguments(
         | 
| 1014 | 
             
                        per_device_train_batch_size=self.cfg.micro_batch_size,
         | 
| 1015 | 
             
                        max_steps=self.cfg.max_steps or total_num_steps,
         | 
| 1016 | 
            -
                        remove_unused_columns=False,
         | 
| 1017 | 
             
                        gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
         | 
| 1018 | 
             
                        learning_rate=self.cfg.learning_rate,
         | 
| 1019 | 
             
                        save_strategy="steps",
         | 
|  | |
| 996 | 
             
                    training_args_kwargs["lr_scheduler_kwargs"] = (
         | 
| 997 | 
             
                        self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         | 
| 998 | 
             
                    )
         | 
| 999 | 
            +
                    if self.cfg.remove_unused_columns is not None:
         | 
| 1000 | 
            +
                        training_args_kwargs[
         | 
| 1001 | 
            +
                            "remove_unused_columns"
         | 
| 1002 | 
            +
                        ] = self.cfg.remove_unused_columns
         | 
| 1003 | 
            +
                    else:
         | 
| 1004 | 
            +
                        training_args_kwargs["remove_unused_columns"] = False
         | 
| 1005 |  | 
| 1006 | 
             
                    if self.cfg.dataloader_pin_memory is not None:
         | 
| 1007 | 
             
                        training_args_kwargs[
         | 
|  | |
| 1019 | 
             
                    training_args = TrainingArguments(
         | 
| 1020 | 
             
                        per_device_train_batch_size=self.cfg.micro_batch_size,
         | 
| 1021 | 
             
                        max_steps=self.cfg.max_steps or total_num_steps,
         | 
|  | |
| 1022 | 
             
                        gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
         | 
| 1023 | 
             
                        learning_rate=self.cfg.learning_rate,
         | 
| 1024 | 
             
                        save_strategy="steps",
         | 
    	
        src/axolotl/utils/data.py
    CHANGED
    
    | @@ -7,6 +7,7 @@ from pathlib import Path | |
| 7 | 
             
            from typing import Any, Dict, List, Optional, Tuple, Union
         | 
| 8 |  | 
| 9 | 
             
            import torch
         | 
|  | |
| 10 | 
             
            from datasets import (
         | 
| 11 | 
             
                Dataset,
         | 
| 12 | 
             
                DatasetDict,
         | 
| @@ -853,6 +854,41 @@ def encode_packed_pretraining( | |
| 853 | 
             
                return chunked_data
         | 
| 854 |  | 
| 855 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 856 | 
             
            def load_prepare_dpo_datasets(cfg):
         | 
| 857 | 
             
                def load_split(dataset_cfgs, _cfg):
         | 
| 858 | 
             
                    split_datasets: List[Any] = []
         | 
| @@ -889,12 +925,25 @@ def load_prepare_dpo_datasets(cfg): | |
| 889 | 
             
                    return concatenate_datasets(split_datasets)
         | 
| 890 |  | 
| 891 | 
             
                with zero_first(is_main_process()):
         | 
| 892 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 893 |  | 
| 894 | 
             
                    eval_dataset = None
         | 
| 895 | 
             
                    if cfg.test_datasets:
         | 
| 896 | 
            -
                        eval_dataset  | 
|  | |
|  | |
|  | |
| 897 | 
             
                    if not eval_dataset:
         | 
| 898 | 
             
                        eval_dataset = None
         | 
| 899 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 900 | 
             
                return train_dataset, eval_dataset
         | 
|  | |
| 7 | 
             
            from typing import Any, Dict, List, Optional, Tuple, Union
         | 
| 8 |  | 
| 9 | 
             
            import torch
         | 
| 10 | 
            +
            import yaml
         | 
| 11 | 
             
            from datasets import (
         | 
| 12 | 
             
                Dataset,
         | 
| 13 | 
             
                DatasetDict,
         | 
|  | |
| 854 | 
             
                return chunked_data
         | 
| 855 |  | 
| 856 |  | 
| 857 | 
            +
            def _get_path(ds_hash, cfg):
         | 
| 858 | 
            +
                prepared_ds_path = (
         | 
| 859 | 
            +
                    Path(cfg.dataset_prepared_path) / ds_hash
         | 
| 860 | 
            +
                    if cfg.dataset_prepared_path
         | 
| 861 | 
            +
                    else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
         | 
| 862 | 
            +
                )
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                return prepared_ds_path
         | 
| 865 | 
            +
             | 
| 866 | 
            +
             | 
| 867 | 
            +
            def _load_preprocessed_ds(cfg, sub_cfg):
         | 
| 868 | 
            +
                ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
         | 
| 869 | 
            +
                prepared_ds_path = _get_path(ds_hash, cfg)
         | 
| 870 | 
            +
                dataset = None
         | 
| 871 | 
            +
             | 
| 872 | 
            +
                if (
         | 
| 873 | 
            +
                    cfg.dataset_prepared_path
         | 
| 874 | 
            +
                    and any(prepared_ds_path.glob("*"))
         | 
| 875 | 
            +
                    and not cfg.is_preprocess
         | 
| 876 | 
            +
                ):
         | 
| 877 | 
            +
                    LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         | 
| 878 | 
            +
                    dataset = load_from_disk(str(prepared_ds_path))
         | 
| 879 | 
            +
             | 
| 880 | 
            +
                return dataset
         | 
| 881 | 
            +
             | 
| 882 | 
            +
             | 
| 883 | 
            +
            def _save_preprocessed_ds(cfg, sub_cfg, dataset):
         | 
| 884 | 
            +
                ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
         | 
| 885 | 
            +
                prepared_ds_path = _get_path(ds_hash, cfg)
         | 
| 886 | 
            +
             | 
| 887 | 
            +
                if cfg.is_preprocess and is_main_process():
         | 
| 888 | 
            +
                    LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         | 
| 889 | 
            +
                    dataset.save_to_disk(str(prepared_ds_path))
         | 
| 890 | 
            +
             | 
| 891 | 
            +
             | 
| 892 | 
             
            def load_prepare_dpo_datasets(cfg):
         | 
| 893 | 
             
                def load_split(dataset_cfgs, _cfg):
         | 
| 894 | 
             
                    split_datasets: List[Any] = []
         | 
|  | |
| 925 | 
             
                    return concatenate_datasets(split_datasets)
         | 
| 926 |  | 
| 927 | 
             
                with zero_first(is_main_process()):
         | 
| 928 | 
            +
                    train_is_preprocessed = False
         | 
| 929 | 
            +
                    eval_is_preprocessed = False
         | 
| 930 | 
            +
                    if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
         | 
| 931 | 
            +
                        train_is_preprocessed = True
         | 
| 932 | 
            +
                    else:
         | 
| 933 | 
            +
                        train_dataset = load_split(cfg.datasets, cfg)
         | 
| 934 |  | 
| 935 | 
             
                    eval_dataset = None
         | 
| 936 | 
             
                    if cfg.test_datasets:
         | 
| 937 | 
            +
                        if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
         | 
| 938 | 
            +
                            eval_is_preprocessed = True
         | 
| 939 | 
            +
                        else:
         | 
| 940 | 
            +
                            eval_dataset = load_split(cfg.test_datasets, cfg)
         | 
| 941 | 
             
                    if not eval_dataset:
         | 
| 942 | 
             
                        eval_dataset = None
         | 
| 943 |  | 
| 944 | 
            +
                    if not train_is_preprocessed:
         | 
| 945 | 
            +
                        _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
         | 
| 946 | 
            +
                    if eval_dataset and not eval_is_preprocessed:
         | 
| 947 | 
            +
                        _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
         | 
| 948 | 
            +
             | 
| 949 | 
             
                return train_dataset, eval_dataset
         | 
