fix evals (#447)
Browse files
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
|
@@ -169,7 +169,7 @@ def flashattn_forward(
|
|
| 169 |
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
| 170 |
|
| 171 |
output = flash_attn_varlen_qkvpacked_func(
|
| 172 |
-
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=
|
| 173 |
)
|
| 174 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
| 175 |
elif query_states.shape == key_states.shape:
|
|
|
|
| 169 |
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
| 170 |
|
| 171 |
output = flash_attn_varlen_qkvpacked_func(
|
| 172 |
+
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
| 173 |
)
|
| 174 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
| 175 |
elif query_states.shape == key_states.shape:
|
src/axolotl/utils/models.py
CHANGED
|
@@ -438,7 +438,7 @@ def load_llama_adapter(model, cfg):
|
|
| 438 |
)
|
| 439 |
|
| 440 |
if cfg.lora_model_dir:
|
| 441 |
-
LOG.
|
| 442 |
model = PeftModel.from_pretrained(
|
| 443 |
model,
|
| 444 |
cfg.lora_model_dir,
|
|
@@ -500,6 +500,7 @@ def load_lora(model, cfg):
|
|
| 500 |
)
|
| 501 |
|
| 502 |
if cfg.lora_model_dir:
|
|
|
|
| 503 |
model = PeftModel.from_pretrained(
|
| 504 |
model,
|
| 505 |
cfg.lora_model_dir,
|
|
|
|
| 438 |
)
|
| 439 |
|
| 440 |
if cfg.lora_model_dir:
|
| 441 |
+
LOG.debug("Loading pretained PEFT - llama_adapter")
|
| 442 |
model = PeftModel.from_pretrained(
|
| 443 |
model,
|
| 444 |
cfg.lora_model_dir,
|
|
|
|
| 500 |
)
|
| 501 |
|
| 502 |
if cfg.lora_model_dir:
|
| 503 |
+
LOG.debug("Loading pretained PEFT - LoRA")
|
| 504 |
model = PeftModel.from_pretrained(
|
| 505 |
model,
|
| 506 |
cfg.lora_model_dir,
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -14,12 +14,15 @@ import bitsandbytes as bnb
|
|
| 14 |
import numpy as np
|
| 15 |
import torch.cuda
|
| 16 |
import transformers
|
| 17 |
-
from datasets import set_caching_enabled
|
| 18 |
from torch import nn
|
| 19 |
from torch.optim.lr_scheduler import OneCycleLR
|
| 20 |
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
| 21 |
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
| 22 |
-
from transformers.trainer_pt_utils import
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
from axolotl.utils.callbacks import (
|
| 25 |
GPUStatsCallback,
|
|
@@ -171,6 +174,18 @@ class AxolotlTrainer(Trainer):
|
|
| 171 |
)
|
| 172 |
return super()._get_train_sampler()
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
| 175 |
if self.args.sample_packing:
|
| 176 |
train_sampler = self._get_train_sampler()
|
|
@@ -188,27 +203,28 @@ class AxolotlTrainer(Trainer):
|
|
| 188 |
)
|
| 189 |
return super().get_train_dataloader()
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
| 212 |
|
| 213 |
def compute_loss(self, model, inputs, return_outputs=False):
|
| 214 |
# use one's weighted cross entropy loss calc
|
|
|
|
| 14 |
import numpy as np
|
| 15 |
import torch.cuda
|
| 16 |
import transformers
|
| 17 |
+
from datasets import Dataset, set_caching_enabled
|
| 18 |
from torch import nn
|
| 19 |
from torch.optim.lr_scheduler import OneCycleLR
|
| 20 |
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
| 21 |
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
| 22 |
+
from transformers.trainer_pt_utils import (
|
| 23 |
+
SequentialDistributedSampler,
|
| 24 |
+
get_parameter_names,
|
| 25 |
+
)
|
| 26 |
|
| 27 |
from axolotl.utils.callbacks import (
|
| 28 |
GPUStatsCallback,
|
|
|
|
| 174 |
)
|
| 175 |
return super()._get_train_sampler()
|
| 176 |
|
| 177 |
+
def _get_eval_sampler(
    self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
    """Return the sampler used for evaluation.

    When sample packing is enabled under multi-GPU training, evaluation
    uses a ``SequentialDistributedSampler`` so each rank receives a
    contiguous, deterministic shard of the eval set (packing relies on a
    stable, non-shuffled order). Otherwise defer to the stock HF
    ``Trainer`` behavior.

    Args:
        eval_dataset: the dataset being evaluated; forwarded to the
            sampler (and to the parent implementation).

    Returns:
        A sampler instance, or whatever the parent ``Trainer`` returns
        (possibly ``None``).
    """
    if self.args.world_size > 1 and self.args.sample_packing:
        return SequentialDistributedSampler(
            eval_dataset,
            num_replicas=self.args.world_size,
            rank=self.args.process_index,
            batch_size=self.args.per_device_eval_batch_size,
        )
    # Fix: Trainer._get_eval_sampler requires the dataset argument;
    # calling it with no args raises TypeError on the non-packing path.
    return super()._get_eval_sampler(eval_dataset)
|
| 188 |
+
|
| 189 |
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
| 190 |
if self.args.sample_packing:
|
| 191 |
train_sampler = self._get_train_sampler()
|
|
|
|
| 203 |
)
|
| 204 |
return super().get_train_dataloader()
|
| 205 |
|
| 206 |
+
def get_eval_dataloader(
    self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
    """Build the evaluation dataloader.

    Without sample packing this is exactly the stock HF ``Trainer``
    behavior. With sample packing enabled, a
    ``MultipackDistributedDataloader`` is constructed over the eval set
    (falling back to ``self.eval_dataset`` when no dataset is passed)
    and prepared by the accelerator.
    """
    # No packing: hand off to the parent implementation unchanged.
    if not self.args.sample_packing:
        return super().get_eval_dataloader(eval_dataset)

    dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    sampler = self._get_eval_sampler(dataset)

    packed_loader = MultipackDistributedDataloader(
        dataset,
        batch_size=self.args.eval_batch_size,
        seq_max_length=self.args.max_seq_length,
        collate_fn=self.data_collator,
        sampler=sampler,
        packing_efficiency_estimate=self.args.sample_packing_efficiency,
        sample_packing_seq_len_multiplier=self.args.eval_batch_size,
        # WORLD_SIZE env var drives the device count; defaults to 1.
        device_count=int(os.environ.get("WORLD_SIZE", 1)),
    )
    return self.accelerator.prepare(packed_loader)
|
| 228 |
|
| 229 |
def compute_loss(self, model, inputs, return_outputs=False):
|
| 230 |
# use one's weighted cross entropy loss calc
|