multipack for gemma (#1313)
Browse files* multipack for gemma
* chore: lint
* handle cache_position kwarg in updated llama modeling
* add position_ids to rotary embed call for updated llama modeling
examples/gemma/qlora.yml
CHANGED
|
@@ -1,49 +1,49 @@
|
|
| 1 |
# use google/gemma-7b if you have access
|
| 2 |
-
base_model: mhenrichsen/gemma-7b
|
| 3 |
model_type: AutoModelForCausalLM
|
| 4 |
tokenizer_type: AutoTokenizer
|
| 5 |
-
|
| 6 |
load_in_8bit: false
|
| 7 |
load_in_4bit: true
|
| 8 |
strict: false
|
| 9 |
-
|
| 10 |
# huggingface repo
|
| 11 |
datasets:
|
| 12 |
- path: mhenrichsen/alpaca_2k_test
|
| 13 |
type: alpaca
|
| 14 |
val_set_size: 0.1
|
| 15 |
output_dir: ./out
|
| 16 |
-
|
| 17 |
adapter: qlora
|
| 18 |
lora_r: 32
|
| 19 |
lora_alpha: 16
|
| 20 |
lora_dropout: 0.05
|
| 21 |
lora_target_linear: true
|
| 22 |
-
|
| 23 |
sequence_len: 4096
|
| 24 |
sample_packing: false
|
| 25 |
pad_to_sequence_len: false
|
| 26 |
-
|
| 27 |
wandb_project:
|
| 28 |
wandb_entity:
|
| 29 |
wandb_watch:
|
| 30 |
wandb_name:
|
| 31 |
wandb_log_model:
|
| 32 |
-
|
| 33 |
-
|
| 34 |
gradient_accumulation_steps: 3
|
| 35 |
micro_batch_size: 2
|
| 36 |
num_epochs: 4
|
| 37 |
optimizer: adamw_bnb_8bit
|
| 38 |
lr_scheduler: cosine
|
| 39 |
learning_rate: 0.0002
|
| 40 |
-
|
| 41 |
train_on_inputs: false
|
| 42 |
group_by_length: false
|
| 43 |
bf16: auto
|
| 44 |
fp16:
|
| 45 |
tf32: false
|
| 46 |
-
|
| 47 |
gradient_checkpointing: true
|
| 48 |
early_stopping_patience:
|
| 49 |
resume_from_checkpoint:
|
|
@@ -51,7 +51,7 @@ local_rank:
|
|
| 51 |
logging_steps: 1
|
| 52 |
xformers_attention:
|
| 53 |
flash_attention: true
|
| 54 |
-
|
| 55 |
warmup_ratio: 0.1
|
| 56 |
evals_per_epoch: 4
|
| 57 |
eval_table_size:
|
|
|
|
| 1 |
# use google/gemma-7b if you have access
|
| 2 |
+
base_model: mhenrichsen/gemma-7b
|
| 3 |
model_type: AutoModelForCausalLM
|
| 4 |
tokenizer_type: AutoTokenizer
|
| 5 |
+
|
| 6 |
load_in_8bit: false
|
| 7 |
load_in_4bit: true
|
| 8 |
strict: false
|
| 9 |
+
|
| 10 |
# huggingface repo
|
| 11 |
datasets:
|
| 12 |
- path: mhenrichsen/alpaca_2k_test
|
| 13 |
type: alpaca
|
| 14 |
val_set_size: 0.1
|
| 15 |
output_dir: ./out
|
| 16 |
+
|
| 17 |
adapter: qlora
|
| 18 |
lora_r: 32
|
| 19 |
lora_alpha: 16
|
| 20 |
lora_dropout: 0.05
|
| 21 |
lora_target_linear: true
|
| 22 |
+
|
| 23 |
sequence_len: 4096
|
| 24 |
sample_packing: false
|
| 25 |
pad_to_sequence_len: false
|
| 26 |
+
|
| 27 |
wandb_project:
|
| 28 |
wandb_entity:
|
| 29 |
wandb_watch:
|
| 30 |
wandb_name:
|
| 31 |
wandb_log_model:
|
| 32 |
+
|
| 33 |
+
|
| 34 |
gradient_accumulation_steps: 3
|
| 35 |
micro_batch_size: 2
|
| 36 |
num_epochs: 4
|
| 37 |
optimizer: adamw_bnb_8bit
|
| 38 |
lr_scheduler: cosine
|
| 39 |
learning_rate: 0.0002
|
| 40 |
+
|
| 41 |
train_on_inputs: false
|
| 42 |
group_by_length: false
|
| 43 |
bf16: auto
|
| 44 |
fp16:
|
| 45 |
tf32: false
|
| 46 |
+
|
| 47 |
gradient_checkpointing: true
|
| 48 |
early_stopping_patience:
|
| 49 |
resume_from_checkpoint:
|
|
|
|
| 51 |
logging_steps: 1
|
| 52 |
xformers_attention:
|
| 53 |
flash_attention: true
|
| 54 |
+
|
| 55 |
warmup_ratio: 0.1
|
| 56 |
evals_per_epoch: 4
|
| 57 |
eval_table_size:
|
requirements.txt
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
| 2 |
packaging==23.2
|
| 3 |
peft @ git+https://github.com/huggingface/peft.git
|
| 4 |
-
transformers @ git+https://github.com/huggingface/transformers.git@
|
| 5 |
tokenizers==0.15.0
|
| 6 |
bitsandbytes>=0.41.1
|
| 7 |
accelerate==0.26.1
|
|
|
|
| 1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
| 2 |
packaging==23.2
|
| 3 |
peft @ git+https://github.com/huggingface/peft.git
|
| 4 |
+
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
|
| 5 |
tokenizers==0.15.0
|
| 6 |
bitsandbytes>=0.41.1
|
| 7 |
accelerate==0.26.1
|
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
|
@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
|
|
| 275 |
kv_seq_len = key_states.shape[-2]
|
| 276 |
if past_key_value is not None:
|
| 277 |
kv_seq_len += past_key_value[0].shape[-2]
|
| 278 |
-
cos, sin = self.rotary_emb(
|
|
|
|
|
|
|
| 279 |
query_states, key_states = apply_rotary_pos_emb(
|
| 280 |
query_states, key_states, cos, sin, position_ids
|
| 281 |
)
|
|
@@ -425,7 +427,9 @@ def flashattn_forward(
|
|
| 425 |
if past_key_value is not None:
|
| 426 |
kv_seq_len += past_key_value[0].shape[-2]
|
| 427 |
|
| 428 |
-
cos, sin = self.rotary_emb(
|
|
|
|
|
|
|
| 429 |
query_states, key_states = apply_rotary_pos_emb(
|
| 430 |
query_states, key_states, cos, sin, position_ids
|
| 431 |
)
|
|
@@ -688,6 +692,9 @@ def llama_model_forward(
|
|
| 688 |
output_attentions: Optional[bool] = None,
|
| 689 |
output_hidden_states: Optional[bool] = None,
|
| 690 |
return_dict: Optional[bool] = None,
|
|
|
|
|
|
|
|
|
|
| 691 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 692 |
output_attentions = (
|
| 693 |
output_attentions
|
|
|
|
| 275 |
kv_seq_len = key_states.shape[-2]
|
| 276 |
if past_key_value is not None:
|
| 277 |
kv_seq_len += past_key_value[0].shape[-2]
|
| 278 |
+
cos, sin = self.rotary_emb(
|
| 279 |
+
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
| 280 |
+
)
|
| 281 |
query_states, key_states = apply_rotary_pos_emb(
|
| 282 |
query_states, key_states, cos, sin, position_ids
|
| 283 |
)
|
|
|
|
| 427 |
if past_key_value is not None:
|
| 428 |
kv_seq_len += past_key_value[0].shape[-2]
|
| 429 |
|
| 430 |
+
cos, sin = self.rotary_emb(
|
| 431 |
+
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
| 432 |
+
)
|
| 433 |
query_states, key_states = apply_rotary_pos_emb(
|
| 434 |
query_states, key_states, cos, sin, position_ids
|
| 435 |
)
|
|
|
|
| 692 |
output_attentions: Optional[bool] = None,
|
| 693 |
output_hidden_states: Optional[bool] = None,
|
| 694 |
return_dict: Optional[bool] = None,
|
| 695 |
+
cache_position: Optional[ # pylint: disable=unused-argument
|
| 696 |
+
torch.LongTensor
|
| 697 |
+
] = None,
|
| 698 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 699 |
output_attentions = (
|
| 700 |
output_attentions
|
src/axolotl/monkeypatch/multipack.py
CHANGED
|
@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
|
| 6 |
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
| 7 |
from axolotl.monkeypatch.utils import get_unpad_data
|
| 8 |
|
| 9 |
-
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
| 10 |
|
| 11 |
|
| 12 |
def patch_for_multipack(model_type):
|
|
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
|
|
| 28 |
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
| 29 |
get_unpad_data
|
| 30 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
| 7 |
from axolotl.monkeypatch.utils import get_unpad_data
|
| 8 |
|
| 9 |
+
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
|
| 10 |
|
| 11 |
|
| 12 |
def patch_for_multipack(model_type):
|
|
|
|
| 28 |
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
| 29 |
get_unpad_data
|
| 30 |
)
|
| 31 |
+
elif model_type == "gemma":
|
| 32 |
+
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
| 33 |
+
get_unpad_data
|
| 34 |
+
)
|