Add KTO support (#1640)
Browse files* add kto support
* test cleanup
* fix outdated comment
* fix llama3 ultra
* chore: lint
* update to use rl_beta instead of dpo_beta
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/core/trainer_builder.py +29 -2
- src/axolotl/prompt_strategies/kto/__init__.py +9 -0
- src/axolotl/prompt_strategies/kto/chatml.py +105 -0
- src/axolotl/prompt_strategies/kto/llama3.py +105 -0
- src/axolotl/prompt_strategies/kto/user_defined.py +39 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +42 -2
- src/axolotl/utils/data/rl.py +28 -12
- src/axolotl/utils/models.py +5 -1
- src/axolotl/utils/trainer.py +1 -1
- tests/e2e/test_dpo.py +63 -0
- tests/test_validation.py +9 -0
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -30,7 +30,7 @@ from transformers import (
|
|
| 30 |
)
|
| 31 |
from transformers.trainer_utils import seed_worker
|
| 32 |
from transformers.utils import is_sagemaker_mp_enabled
|
| 33 |
-
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
| 34 |
from trl.trainer.utils import pad_to_length
|
| 35 |
|
| 36 |
from axolotl.loraplus import create_loraplus_optimizer
|
|
@@ -826,6 +826,14 @@ class AxolotlORPOTrainer(ORPOTrainer):
|
|
| 826 |
tag_names = ["axolotl", "orpo"]
|
| 827 |
|
| 828 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
class TrainerBuilderBase(abc.ABC):
|
| 830 |
"""
|
| 831 |
Base class for trainer builder
|
|
@@ -1532,6 +1540,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1532 |
if self.cfg.max_prompt_len:
|
| 1533 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
| 1534 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1535 |
training_args = training_args_cls(
|
| 1536 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
| 1537 |
max_steps=self.cfg.max_steps or total_num_steps,
|
|
@@ -1567,7 +1591,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1567 |
] = self.cfg.precompute_ref_log_probs
|
| 1568 |
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
| 1569 |
trainer_cls = AxolotlDPOTrainer
|
| 1570 |
-
dpo_trainer_kwargs["beta"] = self.cfg.
|
| 1571 |
trainer_cls_args = [self.model, self.model_ref]
|
| 1572 |
|
| 1573 |
# these aren't used for the ORPO trainer
|
|
@@ -1580,6 +1604,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1580 |
elif self.cfg.rl == "orpo":
|
| 1581 |
trainer_cls = AxolotlORPOTrainer
|
| 1582 |
trainer_cls_args = [self.model]
|
|
|
|
|
|
|
|
|
|
| 1583 |
else:
|
| 1584 |
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
| 1585 |
dpo_trainer = trainer_cls(
|
|
|
|
| 30 |
)
|
| 31 |
from transformers.trainer_utils import seed_worker
|
| 32 |
from transformers.utils import is_sagemaker_mp_enabled
|
| 33 |
+
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
|
| 34 |
from trl.trainer.utils import pad_to_length
|
| 35 |
|
| 36 |
from axolotl.loraplus import create_loraplus_optimizer
|
|
|
|
| 826 |
tag_names = ["axolotl", "orpo"]
|
| 827 |
|
| 828 |
|
| 829 |
+
class AxolotlKTOTrainer(KTOTrainer):
|
| 830 |
+
"""
|
| 831 |
+
Extend the base KTOTrainer for axolotl helpers
|
| 832 |
+
"""
|
| 833 |
+
|
| 834 |
+
tag_names = ["axolotl", "kto"]
|
| 835 |
+
|
| 836 |
+
|
| 837 |
class TrainerBuilderBase(abc.ABC):
|
| 838 |
"""
|
| 839 |
Base class for trainer builder
|
|
|
|
| 1540 |
if self.cfg.max_prompt_len:
|
| 1541 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
| 1542 |
|
| 1543 |
+
if self.cfg.rl == "kto":
|
| 1544 |
+
training_args_cls = KTOConfig
|
| 1545 |
+
|
| 1546 |
+
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
| 1547 |
+
training_args_kwargs["desirable_weight"] = (
|
| 1548 |
+
self.cfg.kto_desirable_weight or 1.0
|
| 1549 |
+
)
|
| 1550 |
+
training_args_kwargs["undesirable_weight"] = (
|
| 1551 |
+
self.cfg.kto_undesirable_weight or 1.0
|
| 1552 |
+
)
|
| 1553 |
+
|
| 1554 |
+
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
| 1555 |
+
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
| 1556 |
+
if self.cfg.max_prompt_len:
|
| 1557 |
+
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
| 1558 |
+
|
| 1559 |
training_args = training_args_cls(
|
| 1560 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
| 1561 |
max_steps=self.cfg.max_steps or total_num_steps,
|
|
|
|
| 1591 |
] = self.cfg.precompute_ref_log_probs
|
| 1592 |
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
| 1593 |
trainer_cls = AxolotlDPOTrainer
|
| 1594 |
+
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
| 1595 |
trainer_cls_args = [self.model, self.model_ref]
|
| 1596 |
|
| 1597 |
# these aren't used for the ORPO trainer
|
|
|
|
| 1604 |
elif self.cfg.rl == "orpo":
|
| 1605 |
trainer_cls = AxolotlORPOTrainer
|
| 1606 |
trainer_cls_args = [self.model]
|
| 1607 |
+
elif self.cfg.rl == "kto":
|
| 1608 |
+
trainer_cls = AxolotlKTOTrainer
|
| 1609 |
+
trainer_cls_args = [self.model]
|
| 1610 |
else:
|
| 1611 |
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
| 1612 |
dpo_trainer = trainer_cls(
|
src/axolotl/prompt_strategies/kto/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
module for KTO style dataset transform strategies
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
from ..base import load as load_base
|
| 8 |
+
|
| 9 |
+
load = partial(load_base, module_base="axolotl.prompt_strategies.kto")
|
src/axolotl/prompt_strategies/kto/chatml.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
KTO strategies for chatml
|
| 3 |
+
"""
|
| 4 |
+
# pylint: disable=duplicate-code
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def argilla(
|
| 8 |
+
cfg,
|
| 9 |
+
**kwargs,
|
| 10 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
| 11 |
+
def transform_fn(sample):
|
| 12 |
+
if "system" in sample and sample["system"]:
|
| 13 |
+
sample["prompt"] = (
|
| 14 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
| 15 |
+
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
| 16 |
+
)
|
| 17 |
+
else:
|
| 18 |
+
sample[
|
| 19 |
+
"prompt"
|
| 20 |
+
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
| 21 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
| 22 |
+
return sample
|
| 23 |
+
|
| 24 |
+
return transform_fn
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def argilla_chat(
|
| 28 |
+
cfg,
|
| 29 |
+
**kwargs,
|
| 30 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
| 31 |
+
"""
|
| 32 |
+
for argilla/kto-mix-15k conversations
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def transform_fn(sample):
|
| 36 |
+
sample[
|
| 37 |
+
"prompt"
|
| 38 |
+
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
| 39 |
+
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
|
| 40 |
+
return sample
|
| 41 |
+
|
| 42 |
+
return transform_fn
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
| 46 |
+
"""
|
| 47 |
+
For Intel Orca KTO
|
| 48 |
+
ex: argilla/distilabel-intel-orca-kto
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def transform_fn(sample):
|
| 52 |
+
if "system" in sample and sample["system"]:
|
| 53 |
+
sample["prompt"] = (
|
| 54 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
| 55 |
+
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
| 56 |
+
)
|
| 57 |
+
else:
|
| 58 |
+
sample[
|
| 59 |
+
"prompt"
|
| 60 |
+
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
| 61 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
| 62 |
+
return sample
|
| 63 |
+
|
| 64 |
+
return transform_fn
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def prompt_pairs(
|
| 68 |
+
cfg, **kwargs
|
| 69 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
| 70 |
+
def transform_fn(sample):
|
| 71 |
+
if "system" in sample and sample["system"]:
|
| 72 |
+
sample["prompt"] = (
|
| 73 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
| 74 |
+
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
sample[
|
| 78 |
+
"prompt"
|
| 79 |
+
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
| 80 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
| 81 |
+
return sample
|
| 82 |
+
|
| 83 |
+
return transform_fn
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
| 87 |
+
"""
|
| 88 |
+
for ultrafeedback binarized conversations
|
| 89 |
+
ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def transform_fn(sample):
|
| 93 |
+
if "system" in sample and sample["system"]:
|
| 94 |
+
sample["prompt"] = (
|
| 95 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
| 96 |
+
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
sample[
|
| 100 |
+
"prompt"
|
| 101 |
+
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
| 102 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
| 103 |
+
return sample
|
| 104 |
+
|
| 105 |
+
return transform_fn
|
src/axolotl/prompt_strategies/kto/llama3.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
KTO strategies for llama-3 chat template
|
| 3 |
+
"""
|
| 4 |
+
# pylint: disable=duplicate-code
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def argilla(
|
| 8 |
+
cfg,
|
| 9 |
+
**kwargs,
|
| 10 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
| 11 |
+
def transform_fn(sample):
|
| 12 |
+
if "system" in sample and sample["system"]:
|
| 13 |
+
sample["prompt"] = (
|
| 14 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
| 15 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 16 |
+
)
|
| 17 |
+
else:
|
| 18 |
+
sample[
|
| 19 |
+
"prompt"
|
| 20 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 21 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
| 22 |
+
return sample
|
| 23 |
+
|
| 24 |
+
return transform_fn
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def argilla_chat(
|
| 28 |
+
cfg,
|
| 29 |
+
**kwargs,
|
| 30 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
| 31 |
+
"""
|
| 32 |
+
for argilla/kto-mix-15k conversations
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def transform_fn(sample):
|
| 36 |
+
sample[
|
| 37 |
+
"prompt"
|
| 38 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 39 |
+
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
|
| 40 |
+
return sample
|
| 41 |
+
|
| 42 |
+
return transform_fn
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
| 46 |
+
"""
|
| 47 |
+
For Intel Orca KTO
|
| 48 |
+
ex: argilla/distilabel-intel-orca-kto
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def transform_fn(sample):
|
| 52 |
+
if "system" in sample and sample["system"]:
|
| 53 |
+
sample["prompt"] = (
|
| 54 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
| 55 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 56 |
+
)
|
| 57 |
+
else:
|
| 58 |
+
sample[
|
| 59 |
+
"prompt"
|
| 60 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 61 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
| 62 |
+
return sample
|
| 63 |
+
|
| 64 |
+
return transform_fn
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def prompt_pairs(
|
| 68 |
+
cfg, **kwargs
|
| 69 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
| 70 |
+
def transform_fn(sample):
|
| 71 |
+
if "system" in sample and sample["system"]:
|
| 72 |
+
sample["prompt"] = (
|
| 73 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
| 74 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
sample[
|
| 78 |
+
"prompt"
|
| 79 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 80 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
| 81 |
+
return sample
|
| 82 |
+
|
| 83 |
+
return transform_fn
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
| 87 |
+
"""
|
| 88 |
+
for ultrafeedback binarized conversations
|
| 89 |
+
ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def transform_fn(sample):
|
| 93 |
+
if "system" in sample and sample["system"]:
|
| 94 |
+
sample["prompt"] = (
|
| 95 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
| 96 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
sample[
|
| 100 |
+
"prompt"
|
| 101 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 102 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
| 103 |
+
return sample
|
| 104 |
+
|
| 105 |
+
return transform_fn
|
src/axolotl/prompt_strategies/kto/user_defined.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
User-defined KTO strategies
|
| 3 |
+
"""
|
| 4 |
+
# pylint: disable=duplicate-code
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
| 8 |
+
ds_cfg = cfg["datasets"][dataset_idx]["type"]
|
| 9 |
+
if not isinstance(ds_cfg, dict):
|
| 10 |
+
raise ValueError(
|
| 11 |
+
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
|
| 12 |
+
)
|
| 13 |
+
field_prompt = ds_cfg.get("field_prompt", "prompt")
|
| 14 |
+
field_system = ds_cfg.get("field_system", "system")
|
| 15 |
+
field_completion = ds_cfg.get("field_completion", "completion")
|
| 16 |
+
field_label = ds_cfg.get("field_label", "label")
|
| 17 |
+
prompt_format = ds_cfg.get("prompt_format")
|
| 18 |
+
if not prompt_format:
|
| 19 |
+
prompt_format = "{" + field_prompt + "}"
|
| 20 |
+
completion_format = ds_cfg.get("completion_format")
|
| 21 |
+
if not completion_format:
|
| 22 |
+
chosen_format = "{" + field_completion + "}"
|
| 23 |
+
|
| 24 |
+
def transform_fn(sample):
|
| 25 |
+
if (
|
| 26 |
+
"{" + field_system + "}" in prompt_format
|
| 27 |
+
and field_system in sample
|
| 28 |
+
and sample[field_system]
|
| 29 |
+
):
|
| 30 |
+
sample["prompt"] = prompt_format.format(
|
| 31 |
+
system=sample[field_system], prompt=sample[field_prompt]
|
| 32 |
+
)
|
| 33 |
+
else:
|
| 34 |
+
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
| 35 |
+
sample["completion"] = chosen_format.format(chosen=sample[field_completion])
|
| 36 |
+
sample["label"] = sample[field_label]
|
| 37 |
+
return sample
|
| 38 |
+
|
| 39 |
+
return transform_fn
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
|
@@ -24,6 +24,7 @@ class DeprecatedParameters(BaseModel):
|
|
| 24 |
max_packed_sequence_len: Optional[int] = None
|
| 25 |
rope_scaling: Optional[Any] = None
|
| 26 |
noisy_embedding_alpha: Optional[float] = None
|
|
|
|
| 27 |
|
| 28 |
@field_validator("max_packed_sequence_len")
|
| 29 |
@classmethod
|
|
@@ -48,6 +49,13 @@ class DeprecatedParameters(BaseModel):
|
|
| 48 |
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
| 49 |
return noisy_embedding_alpha
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
class RemappedParameters(BaseModel):
|
| 53 |
"""parameters that have been remapped to other names"""
|
|
@@ -126,6 +134,26 @@ class DPODataset(BaseModel):
|
|
| 126 |
data_files: Optional[List[str]] = None
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
class RLType(str, Enum):
|
| 130 |
"""RL trainer type configuration subset"""
|
| 131 |
|
|
@@ -133,6 +161,7 @@ class RLType(str, Enum):
|
|
| 133 |
ipo = "ipo" # pylint: disable=invalid-name
|
| 134 |
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
| 135 |
orpo = "orpo" # pylint: disable=invalid-name
|
|
|
|
| 136 |
|
| 137 |
|
| 138 |
class ChatTemplate(str, Enum):
|
|
@@ -450,8 +479,8 @@ class AxolotlInputConfig(
|
|
| 450 |
|
| 451 |
rl: Optional[RLType] = None
|
| 452 |
|
| 453 |
-
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
| 454 |
-
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
| 455 |
shuffle_merged_datasets: Optional[bool] = True
|
| 456 |
dataset_prepared_path: Optional[str] = None
|
| 457 |
dataset_shard_num: Optional[int] = None
|
|
@@ -585,6 +614,10 @@ class AxolotlInputConfig(
|
|
| 585 |
|
| 586 |
orpo_alpha: Optional[float] = None
|
| 587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
max_memory: Optional[
|
| 589 |
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
| 590 |
] = None
|
|
@@ -884,6 +917,13 @@ class AxolotlInputConfig(
|
|
| 884 |
raise ValueError("neftune_noise_alpha must be > 0.0")
|
| 885 |
return neftune_noise_alpha
|
| 886 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 887 |
@model_validator(mode="before")
|
| 888 |
@classmethod
|
| 889 |
def check_frozen(cls, data):
|
|
|
|
| 24 |
max_packed_sequence_len: Optional[int] = None
|
| 25 |
rope_scaling: Optional[Any] = None
|
| 26 |
noisy_embedding_alpha: Optional[float] = None
|
| 27 |
+
dpo_beta: Optional[float] = None
|
| 28 |
|
| 29 |
@field_validator("max_packed_sequence_len")
|
| 30 |
@classmethod
|
|
|
|
| 49 |
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
| 50 |
return noisy_embedding_alpha
|
| 51 |
|
| 52 |
+
@field_validator("dpo_beta")
|
| 53 |
+
@classmethod
|
| 54 |
+
def validate_dpo_beta(cls, dpo_beta):
|
| 55 |
+
if dpo_beta is not None:
|
| 56 |
+
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
| 57 |
+
return dpo_beta
|
| 58 |
+
|
| 59 |
|
| 60 |
class RemappedParameters(BaseModel):
|
| 61 |
"""parameters that have been remapped to other names"""
|
|
|
|
| 134 |
data_files: Optional[List[str]] = None
|
| 135 |
|
| 136 |
|
| 137 |
+
class UserDefinedKTOType(BaseModel):
|
| 138 |
+
"""User defined typing for KTO"""
|
| 139 |
+
|
| 140 |
+
field_system: Optional[str] = None
|
| 141 |
+
field_prompt: Optional[str] = None
|
| 142 |
+
field_completion: Optional[str] = None
|
| 143 |
+
field_label: Optional[bool] = None
|
| 144 |
+
prompt_format: Optional[str] = None
|
| 145 |
+
completion_format: Optional[str] = None
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class KTODataset(BaseModel):
|
| 149 |
+
"""KTO configuration subset"""
|
| 150 |
+
|
| 151 |
+
path: Optional[str] = None
|
| 152 |
+
split: Optional[str] = None
|
| 153 |
+
type: Optional[Union[UserDefinedKTOType, str]] = None
|
| 154 |
+
data_files: Optional[List[str]] = None
|
| 155 |
+
|
| 156 |
+
|
| 157 |
class RLType(str, Enum):
|
| 158 |
"""RL trainer type configuration subset"""
|
| 159 |
|
|
|
|
| 161 |
ipo = "ipo" # pylint: disable=invalid-name
|
| 162 |
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
| 163 |
orpo = "orpo" # pylint: disable=invalid-name
|
| 164 |
+
kto = "kto" # pylint: disable=invalid-name
|
| 165 |
|
| 166 |
|
| 167 |
class ChatTemplate(str, Enum):
|
|
|
|
| 479 |
|
| 480 |
rl: Optional[RLType] = None
|
| 481 |
|
| 482 |
+
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
| 483 |
+
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
| 484 |
shuffle_merged_datasets: Optional[bool] = True
|
| 485 |
dataset_prepared_path: Optional[str] = None
|
| 486 |
dataset_shard_num: Optional[int] = None
|
|
|
|
| 614 |
|
| 615 |
orpo_alpha: Optional[float] = None
|
| 616 |
|
| 617 |
+
kto_desirable_weight: Optional[float] = None
|
| 618 |
+
kto_undesirable_weight: Optional[float] = None
|
| 619 |
+
rl_beta: Optional[float] = None
|
| 620 |
+
|
| 621 |
max_memory: Optional[
|
| 622 |
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
| 623 |
] = None
|
|
|
|
| 917 |
raise ValueError("neftune_noise_alpha must be > 0.0")
|
| 918 |
return neftune_noise_alpha
|
| 919 |
|
| 920 |
+
@model_validator(mode="after")
|
| 921 |
+
def check(self):
|
| 922 |
+
if self.dpo_beta and not self.rl_beta:
|
| 923 |
+
self.rl_beta = self.dpo_beta
|
| 924 |
+
del self.dpo_beta
|
| 925 |
+
return self
|
| 926 |
+
|
| 927 |
@model_validator(mode="before")
|
| 928 |
@classmethod
|
| 929 |
def check_frozen(cls, data):
|
src/axolotl/utils/data/rl.py
CHANGED
|
@@ -10,6 +10,7 @@ from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_
|
|
| 10 |
|
| 11 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
| 12 |
from axolotl.prompt_strategies.dpo import load as load_dpo
|
|
|
|
| 13 |
from axolotl.prompt_strategies.orpo import load as load_orpo
|
| 14 |
from axolotl.utils.data.utils import md5
|
| 15 |
from axolotl.utils.dict import DictDefault
|
|
@@ -55,6 +56,22 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
|
| 55 |
dataset.save_to_disk(str(prepared_ds_path))
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def load_prepare_dpo_datasets(cfg):
|
| 59 |
def load_split(dataset_cfgs, _cfg):
|
| 60 |
split_datasets: List[Any] = []
|
|
@@ -76,6 +93,7 @@ def load_prepare_dpo_datasets(cfg):
|
|
| 76 |
split_datasets.insert(i, ds)
|
| 77 |
|
| 78 |
tokenizer = None
|
|
|
|
| 79 |
for i, data_set in enumerate(split_datasets):
|
| 80 |
_type = dataset_cfgs[i]["type"]
|
| 81 |
if _type:
|
|
@@ -83,21 +101,19 @@ def load_prepare_dpo_datasets(cfg):
|
|
| 83 |
_type = "user_defined.default"
|
| 84 |
if _cfg.rl == "orpo":
|
| 85 |
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
|
|
|
|
|
|
| 86 |
else:
|
| 87 |
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
ds_transform_fn,
|
| 96 |
-
desc="Mapping RL Dataset",
|
| 97 |
)
|
| 98 |
-
if isinstance(data_set, DatasetDict):
|
| 99 |
-
data_set = data_set["train"]
|
| 100 |
-
split_datasets[i] = data_set
|
| 101 |
else:
|
| 102 |
# If no `type` is provided, assume the dataset is already in the expected format with
|
| 103 |
# "prompt", "chosen" and "rejected" already preprocessed
|
|
|
|
| 10 |
|
| 11 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
| 12 |
from axolotl.prompt_strategies.dpo import load as load_dpo
|
| 13 |
+
from axolotl.prompt_strategies.kto import load as load_kto
|
| 14 |
from axolotl.prompt_strategies.orpo import load as load_orpo
|
| 15 |
from axolotl.utils.data.utils import md5
|
| 16 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 56 |
dataset.save_to_disk(str(prepared_ds_path))
|
| 57 |
|
| 58 |
|
| 59 |
+
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
| 60 |
+
sig = inspect.signature(ds_transform_fn)
|
| 61 |
+
if "tokenizer" in sig.parameters:
|
| 62 |
+
if not tokenizer:
|
| 63 |
+
tokenizer = load_tokenizer(cfg)
|
| 64 |
+
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
| 65 |
+
|
| 66 |
+
data_set = data_set.map(
|
| 67 |
+
ds_transform_fn,
|
| 68 |
+
desc="Mapping RL Dataset",
|
| 69 |
+
)
|
| 70 |
+
if isinstance(data_set, DatasetDict):
|
| 71 |
+
data_set = data_set["train"]
|
| 72 |
+
return data_set
|
| 73 |
+
|
| 74 |
+
|
| 75 |
def load_prepare_dpo_datasets(cfg):
|
| 76 |
def load_split(dataset_cfgs, _cfg):
|
| 77 |
split_datasets: List[Any] = []
|
|
|
|
| 93 |
split_datasets.insert(i, ds)
|
| 94 |
|
| 95 |
tokenizer = None
|
| 96 |
+
|
| 97 |
for i, data_set in enumerate(split_datasets):
|
| 98 |
_type = dataset_cfgs[i]["type"]
|
| 99 |
if _type:
|
|
|
|
| 101 |
_type = "user_defined.default"
|
| 102 |
if _cfg.rl == "orpo":
|
| 103 |
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
| 104 |
+
elif _cfg.rl == "kto":
|
| 105 |
+
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
| 106 |
else:
|
| 107 |
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
| 108 |
+
|
| 109 |
+
split_datasets[i] = map_dataset(
|
| 110 |
+
cfg, data_set, ds_transform_fn, tokenizer
|
| 111 |
+
)
|
| 112 |
+
elif _cfg.rl == "kto":
|
| 113 |
+
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
| 114 |
+
split_datasets[i] = map_dataset(
|
| 115 |
+
cfg, data_set, ds_transform_fn, tokenizer
|
|
|
|
| 116 |
)
|
|
|
|
|
|
|
|
|
|
| 117 |
else:
|
| 118 |
# If no `type` is provided, assume the dataset is already in the expected format with
|
| 119 |
# "prompt", "chosen" and "rejected" already preprocessed
|
src/axolotl/utils/models.py
CHANGED
|
@@ -803,7 +803,11 @@ def load_model(
|
|
| 803 |
if not reference_model or cfg.lora_model_dir:
|
| 804 |
# if we're not loading the reference model, then we're loading the model for training
|
| 805 |
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
| 806 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
| 808 |
else:
|
| 809 |
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
|
|
|
| 803 |
if not reference_model or cfg.lora_model_dir:
|
| 804 |
# if we're not loading the reference model, then we're loading the model for training
|
| 805 |
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
| 806 |
+
if (
|
| 807 |
+
cfg.adapter
|
| 808 |
+
and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
|
| 809 |
+
and not cfg.merge_lora
|
| 810 |
+
):
|
| 811 |
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
| 812 |
else:
|
| 813 |
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -428,7 +428,7 @@ def prepare_optim_env(cfg):
|
|
| 428 |
|
| 429 |
|
| 430 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
| 431 |
-
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
| 432 |
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
| 433 |
trainer_builder.model_ref = model[1]
|
| 434 |
trainer_builder.peft_config = model[2]
|
|
|
|
| 428 |
|
| 429 |
|
| 430 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
| 431 |
+
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]:
|
| 432 |
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
| 433 |
trainer_builder.model_ref = model[1]
|
| 434 |
trainer_builder.peft_config = model[2]
|
tests/e2e/test_dpo.py
CHANGED
|
@@ -205,3 +205,66 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|
| 205 |
|
| 206 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 207 |
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 207 |
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
| 208 |
+
|
| 209 |
+
@with_temp_dir
|
| 210 |
+
def test_kto_lora(self, temp_dir):
|
| 211 |
+
# pylint: disable=duplicate-code
|
| 212 |
+
cfg = DictDefault(
|
| 213 |
+
{
|
| 214 |
+
"base_model": "JackFram/llama-68m",
|
| 215 |
+
"tokenizer_type": "LlamaTokenizer",
|
| 216 |
+
"sequence_len": 1024,
|
| 217 |
+
"load_in_8bit": True,
|
| 218 |
+
"adapter": "lora",
|
| 219 |
+
"lora_r": 64,
|
| 220 |
+
"lora_alpha": 32,
|
| 221 |
+
"lora_dropout": 0.1,
|
| 222 |
+
"lora_target_linear": True,
|
| 223 |
+
"special_tokens": {},
|
| 224 |
+
"rl": "kto",
|
| 225 |
+
"rl_beta": 0.5,
|
| 226 |
+
"kto_desirable_weight": 1.0,
|
| 227 |
+
"kto_undesirable_weight": 1.0,
|
| 228 |
+
"remove_unused_columns": False,
|
| 229 |
+
"datasets": [
|
| 230 |
+
# {
|
| 231 |
+
# "path": "argilla/kto-mix-15k",
|
| 232 |
+
# "type": "chatml.argilla_chat",
|
| 233 |
+
# "split": "train",
|
| 234 |
+
# },
|
| 235 |
+
{
|
| 236 |
+
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
|
| 237 |
+
"type": "chatml.ultra",
|
| 238 |
+
"split": "train",
|
| 239 |
+
},
|
| 240 |
+
# {
|
| 241 |
+
# "path": "argilla/kto-mix-15k",
|
| 242 |
+
# "type": "llama3.argilla_chat",
|
| 243 |
+
# "split": "train",
|
| 244 |
+
# },
|
| 245 |
+
{
|
| 246 |
+
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
|
| 247 |
+
"type": "llama3.ultra",
|
| 248 |
+
"split": "train",
|
| 249 |
+
},
|
| 250 |
+
],
|
| 251 |
+
"num_epochs": 1,
|
| 252 |
+
"micro_batch_size": 4,
|
| 253 |
+
"gradient_accumulation_steps": 1,
|
| 254 |
+
"output_dir": temp_dir,
|
| 255 |
+
"learning_rate": 0.00001,
|
| 256 |
+
"optimizer": "paged_adamw_8bit",
|
| 257 |
+
"lr_scheduler": "cosine",
|
| 258 |
+
"max_steps": 20,
|
| 259 |
+
"save_steps": 10,
|
| 260 |
+
"warmup_steps": 5,
|
| 261 |
+
"gradient_checkpointing": True,
|
| 262 |
+
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
| 263 |
+
}
|
| 264 |
+
)
|
| 265 |
+
normalize_config(cfg)
|
| 266 |
+
cli_args = TrainerCliArgs()
|
| 267 |
+
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
| 268 |
+
|
| 269 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 270 |
+
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
tests/test_validation.py
CHANGED
|
@@ -1117,6 +1117,15 @@ class TestValidation(BaseValidation):
|
|
| 1117 |
validate_config(cfg)
|
| 1118 |
assert len(self._caplog.records) == 0
|
| 1119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1120 |
|
| 1121 |
class TestValidationCheckModelConfig(BaseValidation):
|
| 1122 |
"""
|
|
|
|
| 1117 |
validate_config(cfg)
|
| 1118 |
assert len(self._caplog.records) == 0
|
| 1119 |
|
| 1120 |
+
def test_dpo_beta_deprecation(self, minimal_cfg):
|
| 1121 |
+
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
|
| 1122 |
+
|
| 1123 |
+
with self._caplog.at_level(logging.WARNING):
|
| 1124 |
+
new_cfg = validate_config(cfg)
|
| 1125 |
+
assert new_cfg["rl_beta"] == 0.2
|
| 1126 |
+
assert new_cfg["dpo_beta"] is None
|
| 1127 |
+
assert len(self._caplog.records) == 1
|
| 1128 |
+
|
| 1129 |
|
| 1130 |
class TestValidationCheckModelConfig(BaseValidation):
|
| 1131 |
"""
|