add sanity checks
Files changed:
- README.md: +2 -2
- custom_generate/generate.py: +51 -31
README.md
CHANGED
@@ -21,8 +21,8 @@ This implementation should match the `SinkCache` class present in `transformers`
 
 
 ## Additional Arguments
-- `window_length` (`int`, defaults to
-- `num_sink_tokens` (`int`, defaults to
+- `window_length` (`int`, *optional*, defaults to 256): The length of the context window.
+- `num_sink_tokens` (`int`, *optional*, defaults to 4): The number of sink tokens. See the original paper for more information.
 
 
 ## Output Type changes
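For reference, a minimal sketch of how these two documented arguments would be forwarded when calling this repository through `transformers`' `custom_generate` mechanism (assuming transformers v4.52+; the checkpoint is reused from the removed docstring example below, and the repository id is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# Extra kwargs such as `window_length` and `num_sink_tokens` are passed through to the
# custom `generate` function defined in custom_generate/generate.py.
outputs = model.generate(
    **inputs,
    custom_generate="<user-or-org>/<this-repo>",  # placeholder repository id
    trust_remote_code=True,
    window_length=256,
    num_sink_tokens=4,
    max_new_tokens=32,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```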
custom_generate/generate.py
CHANGED
@@ -1,11 +1,18 @@
 import torch
 from typing import Any, Dict, List, Optional, Tuple
-from transformers.utils import logging
-from transformers.cache_utils import Cache
 
-
+from transformers import Cache, GenerationConfig
 
 
+UNSUPPORTED_GENERATION_ARGS = [
+    "cache_implementation",  # cache-related arguments, here we always use SinkCache
+    "cache_config",
+    "return_legacy_cache",
+    "num_beams",  # beam search (and cousin techniques) are not supported
+    "compile_config",  # SinkCache doesn't support torch.compile
+    "assistant_model",  # it also doesn't support speculative decoding
+]
+
 class SinkCache(Cache):
     """
     A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
@@ -15,28 +22,13 @@ class SinkCache(Cache):
     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
     `[batch_size, num_heads, seq_len, head_dim]`.
 
+    This class was copied from transformers 4.52.0, with minor modifications.
+
     Parameters:
         window_length (`int`):
             The length of the context window.
         num_sink_tokens (`int`):
             The number of sink tokens. See the original paper for more information.
-
-    Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        SinkCache()
-        ```
     """
 
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
@@ -48,7 +40,6 @@ class SinkCache(Cache):
         self.cos_sin_rerotation_cache = {}
         self._cos_cache = None
         self._sin_cache = None
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
     def _rotate_half(x):
@@ -86,8 +77,6 @@ class SinkCache(Cache):
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]
@@ -130,10 +119,6 @@ class SinkCache(Cache):
         partial_rotation_size = cache_kwargs.get("partial_rotation_size")
         using_rope = cos is not None and sin is not None
 
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
-
         # Update the sin/cos cache, which holds sin/cos values for all possible positions
         if using_rope and layer_idx == 0:
             # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
@@ -194,17 +179,52 @@ class SinkCache(Cache):
 
 
 def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
-
-
+    """Custom generate function for SinkCache.
+
+    Args:
+        model (`PreTrainedModel`):
+            The model to generate from.
+        window_length (`int`, *optional*, defaults to 256):
+            The length of the context window.
+        num_sink_tokens (`int`, *optional*, defaults to 4):
+            The number of sink tokens. See the original paper for more information.
+    """
+    # 1. General sanity checks
+    # 1.a. A few arguments are not allowed, especially arguments that control caches.
+    generation_config = kwargs.get("generation_config")
+    default_global_generation_config = GenerationConfig()
+    default_model_generation_config = model.generation_config
+    for arg in UNSUPPORTED_GENERATION_ARGS:
+        has_custom_gen_config_arg = (
+            generation_config is not None
+            # = and not (match global default or match model-specific default)
+            and not (
+                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
+                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
+            )
+        )
+        if arg in kwargs or has_custom_gen_config_arg:
+            raise ValueError(
+                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
+                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
+            )
+
+    # 1.b. The model must be decoder-only
+    if model.config.is_encoder_decoder:
+        raise ValueError("This custom generate function only works with decoder-only models")
+
+    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
+    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
     kwargs.pop("custom_generate", None)
 
-    #
+    # 2. Generate with SinkCache
+    # 2.a. prepare the cache, if it was not passed.
     past_key_values = kwargs.pop("past_key_values", None)
     if past_key_values is None:
         past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
     elif not isinstance(past_key_values, SinkCache):
         raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
 
-    # generate with the cache
+    # 2.b. generate with the cache
     generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
     return generation_outputs
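To make the effect of the new sanity checks concrete, here is a small, hedged sketch that calls this file's `generate` helper directly (the checkpoint is the one from the removed docstring example; the import path assumes custom_generate/generate.py is importable as a local module):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Assumes custom_generate/generate.py is importable, e.g. as a local package.
from custom_generate.generate import SinkCache, generate

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# OK: a SinkCache is built internally from `window_length` / `num_sink_tokens`.
generate(model, window_length=256, num_sink_tokens=4, **inputs, max_new_tokens=16)

# OK: an explicit SinkCache instance may also be passed.
cache = SinkCache(window_length=256, num_sink_tokens=4)
generate(model, past_key_values=cache, **inputs, max_new_tokens=16)

# Rejected: `num_beams` is listed in UNSUPPORTED_GENERATION_ARGS (check 1.a on the kwargs).
try:
    generate(model, num_beams=2, **inputs, max_new_tokens=16)
except ValueError as err:
    print(err)

# Also rejected: a GenerationConfig whose `num_beams` differs from both the global default
# and the model's own default triggers the same check (1.a on `generation_config`).
try:
    generate(model, generation_config=GenerationConfig(num_beams=2), **inputs, max_new_tokens=16)
except ValueError as err:
    print(err)
```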