update for recent transformers updates (#636)

* update for recent transformers updates
* fix checkpoint forward kwargs
* just pass args into torch checkpoint
src/axolotl/monkeypatch/llama_attn_hijack_flash.py (CHANGED)

@@ -99,6 +99,7 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

@@ -476,6 +477,13 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
+        padding_mask = None
+    else:
+        if 0 in attention_mask:
+            padding_mask = attention_mask
+        else:
+            padding_mask = None
+
     attention_mask = (
         self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
             attention_mask,
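
The new padding_mask is only populated when the incoming 2-D attention_mask actually contains padded positions; the `0 in attention_mask` test relies on PyTorch's element-wise membership check on tensors. A minimal sketch of that check, with illustrative tensors that are not taken from the patch:

```python
import torch

# 2-D attention mask: 1 = real token, 0 = padding (illustrative values)
mask_with_padding = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
mask_without_padding = torch.ones(2, 4, dtype=torch.long)

# `0 in tensor` is True if any element equals 0, i.e. padding is present
print(0 in mask_with_padding)     # True  -> padding_mask = attention_mask
print(0 in mask_without_padding)  # False -> padding_mask = None
```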

@@ -510,7 +518,9 @@ def llama_model_forward(
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(*inputs)
+                    return module(
+                        *inputs,
+                    )

                 return custom_forward

@@ -519,9 +529,10 @@ def llama_model_forward(
                    hidden_states,
                    attention_mask,
                    position_ids,
-                    None,
+                    past_key_value,
                    output_attentions,
                    None,
+                    padding_mask,
                    cu_seqlens,
                    max_seqlen,
                )
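
With this change the per-layer arguments, including the new padding_mask, are passed positionally through torch.utils.checkpoint.checkpoint rather than being closed over inside custom_forward, so the wrapper only needs to splat *inputs. A rough, self-contained sketch of that pattern, using a plain nn.Linear as a stand-in for the decoder layer:

```python
import torch
import torch.utils.checkpoint


def create_custom_forward(module):
    def custom_forward(*inputs):
        # every positional arg given to checkpoint() below is forwarded here
        return module(*inputs)

    return custom_forward


layer = torch.nn.Linear(8, 8)  # stand-in for a decoder layer
hidden_states = torch.randn(2, 8, requires_grad=True)

# args after the wrapped function are passed positionally into custom_forward
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer),
    hidden_states,
)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 8])
```

Reentrant activation checkpointing only forwards positional arguments to the wrapped function, which is presumably why padding_mask is appended to the positional argument list here instead of being passed as a keyword.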

@@ -533,6 +544,7 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )

@@ -579,6 +591,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[

@@ -611,6 +624,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
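
For context on how these patched forwards get used: nothing in the diff changes how the hijack is installed. Assuming the module still exposes a replace_llama_attn_with_flash_attn() helper (it is not shown in this diff, so treat the name as an assumption), applying the patch looks roughly like:

```python
# hypothetical usage sketch; replace_llama_attn_with_flash_attn is assumed to be
# the module's existing entry point and is not part of this diff
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)
from transformers import LlamaForCausalLM

replace_llama_attn_with_flash_attn()  # swap in the patched forwards first
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```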