Refactor landmark attention patch
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
```diff
@@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id):
     ret.extend(x[prev_idx:])
     # drop attention_mask
     return {"input_ids": ret}
+
+
+def patch_llama_with_landmark_attn():
+    import transformers
+
+    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+    transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
+    transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
```
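For context, a minimal sketch (not part of this PR) of how the new helper is meant to be exercised, assuming the landmark classes are importable from `axolotl.monkeypatch.llama_landmark_attn` as the diff above suggests: the patch rebinds the class attributes on `transformers.models.llama.modeling_llama`, so anything that builds a Llama model through that module afterwards picks up the landmark-attention variants.

```python
# Sketch only -- assumes the landmark classes are exported by the patched module.
import transformers

from axolotl.monkeypatch.llama_landmark_attn import (
    LlamaAttention,
    patch_llama_with_landmark_attn,
)

# Rebind the llama classes inside transformers to the landmark variants.
patch_llama_with_landmark_attn()

# Any model constructed via transformers' llama module now uses landmark attention.
assert transformers.models.llama.modeling_llama.LlamaAttention is LlamaAttention
```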
src/axolotl/utils/models.py
CHANGED
```diff
@@ -19,15 +19,6 @@ from transformers import ( # noqa: F401
     LlamaConfig,
 )
 
-try:
-    from transformers import (  # pylint: disable=unused-import # noqa: F401
-        LlamaForCausalLM,
-    )
-except ImportError:
-    logging.warning(
-        "This version of transformers does not support Llama. Consider upgrading."
-    )
-
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
@@ -118,14 +109,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (
+        from axolotl.monkeypatch.llama_landmark_attn import (
             MEM_TOKEN,
-
+            patch_llama_with_landmark_attn,
         )
 
         logging.info("patching with landmark attention")
+        patch_llama_with_landmark_attn()
 
-        #
+        # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -211,6 +203,13 @@ def load_model(
             )
             load_in_8bit = False
     elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
+        try:
+            from transformers import LlamaForCausalLM
+        except ImportError:
+            logging.warning(
+                "This version of transformers does not support Llama. Consider upgrading."
+            )
+
         config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
```
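The new inline note about `additional_special_tokens` flags a real gotcha: passing `additional_special_tokens` to `tokenizer.add_special_tokens` can replace any extras registered earlier rather than appending to them. A hedged sketch of a merge-style alternative (illustrative only; the helper name is made up and this is not what the PR does):

```python
# Illustrative sketch -- the PR adds MEM_TOKEN directly and accepts the
# possible overwrite. Merging with the existing extras avoids clobbering
# tokens registered by earlier patches.
def add_mem_token_preserving_extras(tokenizer, mem_token):
    existing = list(tokenizer.additional_special_tokens or [])
    if mem_token not in existing:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": existing + [mem_token]}
        )
    return tokenizer
```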