more dpo fixes for dataset loading and docs (#1185) [skip ci]

* more dpo fixes for dataset loading and docs
* preprocess dpo datasets

Files changed:
- docs/rlhf.md (+10 -0)
- src/axolotl/cli/preprocess.py (+6 -1)
- src/axolotl/core/trainer_builder.py (+6 -1)
- src/axolotl/utils/data.py (+51 -2)
docs/rlhf.md

@@ -34,6 +34,16 @@ datasets:
 rl: ipo
 ```
 
+#### Using local dataset files
+```yaml
+datasets:
+  - ds_type: json
+    data_files:
+      - orca_rlhf.jsonl
+    split: train
+    type: chatml.intel
+```
+
 #### Trl autounwrap for peft
 
 Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
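Note: the `chatml.intel` type referenced above expects records shaped like Intel's orca_dpo_pairs data. The sketch below is only an illustration of what a single line of a local `orca_rlhf.jsonl` might look like under that assumption; the field names (`system`, `question`, `chosen`, `rejected`) are inferred, not taken from this diff, so check the chatml.intel prompt strategy in the axolotl source for the authoritative schema.

```python
# Illustrative only: write one DPO record in the layout the chatml.intel
# strategy is assumed to expect (system/question/chosen/rejected).
import json

record = {
    "system": "You are a helpful assistant.",
    "question": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "France does not have a capital city.",
}

with open("orca_rlhf.jsonl", "w", encoding="utf-8") as fout:
    fout.write(json.dumps(record) + "\n")
```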
src/axolotl/cli/preprocess.py

@@ -13,6 +13,7 @@ from axolotl.cli import (
     check_user_token,
     load_cfg,
     load_datasets,
+    load_rl_datasets,
     print_axolotl_text_art,
 )
 from axolotl.common.cli import PreprocessCliArgs

@@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cfg.rl:
+        load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
     LOG.info(
         Fore.GREEN
         + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
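With this change, the preprocess entrypoint routes any config that sets `rl:` (e.g. `rl: dpo` or `rl: ipo`) through `load_rl_datasets` instead of the SFT `load_datasets` path, so DPO data can be prepared ahead of training. A minimal usage sketch, assuming axolotl is installed and invoked via the usual `python -m` entrypoint; `dpo.yml` is a placeholder config name.

```python
# Hypothetical invocation: preprocess a DPO config so the prepared dataset is
# written to dataset_prepared_path before the training run. The config file
# name is a placeholder.
import subprocess

subprocess.run(
    ["python", "-m", "axolotl.cli.preprocess", "dpo.yml"],
    check=True,  # raise if preprocessing fails
)
```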
src/axolotl/core/trainer_builder.py

@@ -996,6 +996,12 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
         training_args_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
+        if self.cfg.remove_unused_columns is not None:
+            training_args_kwargs[
+                "remove_unused_columns"
+            ] = self.cfg.remove_unused_columns
+        else:
+            training_args_kwargs["remove_unused_columns"] = False
 
         if self.cfg.dataloader_pin_memory is not None:
             training_args_kwargs[

@@ -1013,7 +1019,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
-            remove_unused_columns=False,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
             save_strategy="steps",
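The DPO trainer builder previously hard-coded `remove_unused_columns=False` in `TrainingArguments`; after this change an explicit `remove_unused_columns` value from the config takes precedence, with `False` still the fallback (presumably so the raw prompt/chosen/rejected columns are not dropped before TRL's DPOTrainer sees them). A minimal sketch of that precedence, using a hypothetical helper name:

```python
# Sketch of the fallback the patch implements; resolve_remove_unused_columns
# is a hypothetical name, not an axolotl function.
from typing import Optional


def resolve_remove_unused_columns(cfg_value: Optional[bool]) -> bool:
    # An explicit config value wins; otherwise keep the DPO-safe default.
    return cfg_value if cfg_value is not None else False


assert resolve_remove_unused_columns(None) is False
assert resolve_remove_unused_columns(True) is True
```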
src/axolotl/utils/data.py

@@ -7,6 +7,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+import yaml
 from datasets import (
     Dataset,
     DatasetDict,

@@ -853,6 +854,41 @@ def encode_packed_pretraining(
     return chunked_data
 
 
+def _get_path(ds_hash, cfg):
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    )
+
+    return prepared_ds_path
+
+
+def _load_preprocessed_ds(cfg, sub_cfg):
+    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+    prepared_ds_path = _get_path(ds_hash, cfg)
+    dataset = None
+
+    if (
+        cfg.dataset_prepared_path
+        and any(prepared_ds_path.glob("*"))
+        and not cfg.is_preprocess
+    ):
+        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        dataset = load_from_disk(str(prepared_ds_path))
+
+    return dataset
+
+
+def _save_preprocessed_ds(cfg, sub_cfg, dataset):
+    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+    prepared_ds_path = _get_path(ds_hash, cfg)
+
+    if cfg.is_preprocess and is_main_process():
+        LOG.info(f"Saving preprocessed dataset to disk at {prepared_ds_path}...")
+        dataset.save_to_disk(str(prepared_ds_path))
+
+
 def load_prepare_dpo_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []

@@ -889,12 +925,25 @@ def load_prepare_dpo_datasets(cfg):
         return concatenate_datasets(split_datasets)
 
     with zero_first(is_main_process()):
-        train_dataset = load_split(cfg.datasets, cfg)
+        train_is_preprocessed = False
+        eval_is_preprocessed = False
+        if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
+            train_is_preprocessed = True
+        else:
+            train_dataset = load_split(cfg.datasets, cfg)
 
         eval_dataset = None
         if cfg.test_datasets:
-            eval_dataset = load_split(cfg.test_datasets, cfg)
+            if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
+                eval_is_preprocessed = True
+            else:
+                eval_dataset = load_split(cfg.test_datasets, cfg)
         if not eval_dataset:
             eval_dataset = None
 
+        if not train_is_preprocessed:
+            _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
+        if eval_dataset and not eval_is_preprocessed:
+            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
+
     return train_dataset, eval_dataset
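For context on the caching scheme the new helpers implement: the prepared DPO dataset is keyed on an md5 of the YAML-serialized dataset sub-config, so editing `datasets:` or `test_datasets:` changes the hash and the old cache is simply never looked up again. A rough standalone sketch of that path computation, using hashlib directly (data.py uses its own `md5` helper) and an assumed default directory name:

```python
# Illustrative sketch of the cache-path scheme: hash the YAML dump of the
# dataset sub-config and use the digest as a directory under
# dataset_prepared_path. Names here are placeholders, not axolotl APIs.
import hashlib
from pathlib import Path

import yaml

DEFAULT_PREPARED_PATH = "last_run_prepared"  # assumed default location


def prepared_path_for(dataset_cfg, dataset_prepared_path=None) -> Path:
    ds_hash = hashlib.md5(
        yaml.dump(dataset_cfg, Dumper=yaml.Dumper).encode("utf-8")
    ).hexdigest()
    return Path(dataset_prepared_path or DEFAULT_PREPARED_PATH) / ds_hash


# Any change to the dataset config yields a different hash and hence a
# different prepared-dataset directory.
cfg_a = [{"path": "orca_rlhf.jsonl", "type": "chatml.intel", "split": "train"}]
cfg_b = [{"path": "orca_rlhf.jsonl", "type": "chatml.intel", "split": "test"}]
assert prepared_path_for(cfg_a) != prepared_path_for(cfg_b)
```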