swap the data collator for evals if not using sample packing (#1076)
* swap the data collator for evals if not using sample packing
* drop last from dataloader to help with issues with evals
src/axolotl/core/trainer_builder.py CHANGED
@@ -1,3 +1,4 @@
+# pylint: disable=too-many-lines
 """
 Builder for the training args and trainer
 """
@@ -137,10 +138,19 @@ class AxolotlTrainer(Trainer):
     args = None  # type: AxolotlTrainingArguments
     tag_names = ["axolotl"]
 
-    def __init__(self, *_args, num_epochs=1, bench_data_collator=None, **kwargs):
+    def __init__(
+        self,
+        *_args,
+        num_epochs=1,
+        bench_data_collator=None,
+        eval_data_collator=None,
+        **kwargs
+    ):
         self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
-        super().__init__(*_args, **kwargs)
+        self.eval_data_collator = eval_data_collator
+        super().__init__(*_args, **kwargs)
+        self.train_data_collator = self.data_collator
 
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -239,6 +249,16 @@ class AxolotlTrainer(Trainer):
         return super().get_train_dataloader()
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+        if self.args.sample_packing and self.args.eval_sample_packing is False:
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.eval_data_collator
+            )
+            dataloader = super().get_eval_dataloader(eval_dataset)
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.train_data_collator
+            )
+            return dataloader
+
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
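A minimal sketch of the swap-and-restore pattern the hunk above adds (the class names here are illustrative, not axolotl code): transformers' Trainer builds its eval DataLoader from self.data_collator, so the subclass temporarily points that attribute at the eval collator, lets the parent method do the work, then puts the training collator back.

    # Illustrative sketch, not axolotl code: swap the attribute the base
    # class reads, call the base method, then restore it.
    class BaseTrainer:
        def __init__(self, data_collator):
            self.data_collator = data_collator

        def get_eval_dataloader(self):
            # stands in for transformers.Trainer building a DataLoader
            return f"dataloader(collate_fn={self.data_collator})"

    class PackingTrainer(BaseTrainer):
        def __init__(self, data_collator, eval_data_collator):
            super().__init__(data_collator)
            self.train_data_collator = data_collator
            self.eval_data_collator = eval_data_collator

        def get_eval_dataloader(self):
            self.data_collator = self.eval_data_collator
            try:
                return super().get_eval_dataloader()
            finally:
                # restore so later training steps keep the packing collator
                self.data_collator = self.train_data_collator

    print(PackingTrainer("packed", "padded").get_eval_dataloader())
    # -> dataloader(collate_fn=padded)

The commit restores the attribute unconditionally after the super() call; the try/finally in the sketch is a slight strengthening so the training collator comes back even if the parent method raises.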
@@ -269,6 +289,7 @@ class AxolotlTrainer(Trainer):
             return self.accelerator.prepare_data_loader(
                 DataLoader(eval_dataset, **dataloader_params)
             )
+
         return super().get_eval_dataloader(eval_dataset)
 
     def _get_bench_sampler(
@@ -651,6 +672,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
+        if self.cfg.dataloader_drop_last is not None:
+            training_arguments_kwargs[
+                "dataloader_drop_last"
+            ] = self.cfg.dataloader_drop_last
+        elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
+            training_arguments_kwargs["dataloader_drop_last"] = True
 
         if self.cfg.val_set_size == 0:
            # no eval set, so don't eval
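For context, a standalone illustration (assuming only PyTorch, not from this commit) of what the new dataloader_drop_last plumbing controls: with drop_last=True the short final batch is discarded rather than collated, which is presumably the "issues with evals" the commit message refers to when training packs samples but eval does not.

    # Standalone drop_last illustration; requires only torch.
    from torch.utils.data import DataLoader

    data = list(range(10))  # 10 samples with batch_size=4 -> batches of 4, 4, 2
    print([len(b) for b in DataLoader(data, batch_size=4, drop_last=False)])  # [4, 4, 2]
    print([len(b) for b in DataLoader(data, batch_size=4, drop_last=True)])   # [4, 4]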
@@ -831,6 +858,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             eval_dataset=self.eval_dataset,
             args=training_args,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
+            eval_data_collator=self.build_collator(
+                training_args, is_eval=True, **data_collator_kwargs
+            ),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -851,14 +881,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         return trainer
 
-    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+    def build_collator(
+        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
+    ):
         if training_args.pretraining:
             return None
 
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
 
-        if training_args.sample_packing:
+        use_batch_sampler_collator = False
+        if is_eval is False and training_args.sample_packing:
+            use_batch_sampler_collator = True
+        if is_eval and training_args.eval_sample_packing:
+            use_batch_sampler_collator = True
+
+        if use_batch_sampler_collator:
             return BatchSamplerDataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
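Condensed, the selection build_collator now performs looks like the sketch below (pick_collator is a made-up name, and the tri-state eval_sample_packing is treated as a plain bool for brevity). The hunk is truncated at the packed-batch return; the non-packing path presumably keeps returning the plain padding collator, which is what the eval-time swap in get_eval_dataloader relies on.

    # Illustrative summary of the branch logic, not axolotl code.
    def pick_collator(is_eval: bool, sample_packing: bool, eval_sample_packing: bool) -> str:
        # training uses the packing collator iff sample_packing is on;
        # eval uses it iff eval_sample_packing is on
        packing_active = eval_sample_packing if is_eval else sample_packing
        return "BatchSamplerDataCollatorForSeq2Seq" if packing_active else "plain seq2seq collator"

    # the case this commit targets: packed training, unpacked evals
    assert pick_collator(False, True, False) == "BatchSamplerDataCollatorForSeq2Seq"
    assert pick_collator(True, True, False) == "plain seq2seq collator"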