joaogante (HF Staff) committed
Commit ffa8b5a
1 Parent(s): f84586a

add sanity checks

Files changed (2):
  1. README.md +2 -2
  2. custom_generate/generate.py +51 -31
README.md CHANGED
@@ -21,8 +21,8 @@ This implementation should match the `SinkCache` class present in `transformers`
 
 
 ## Additional Arguments
-- `window_length` (`int`, defaults to `256`): The length of the context window.
-- `num_sink_tokens` (`int`, defaults to `4`): The number of sink tokens. See the original paper for more information.
+- `window_length` (`int`, *optional*, defaults to 256): The length of the context window.
+- `num_sink_tokens` (`int`, *optional*, defaults to 4): The number of sink tokens. See the original paper for more information.
 
 
 ## Output Type changes
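
For reference, these two arguments are forwarded to the custom `generate` function defined in `custom_generate/generate.py`. A minimal usage sketch, reusing the checkpoint from the docstring example removed below; `"<this_repo_id>"` is a placeholder for this Hub repository, not a real identifier:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint borrowed from the docstring example removed in this commit.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# `window_length` and `num_sink_tokens` are the additional arguments documented above;
# they are passed through to the custom generate function loaded from the Hub.
outputs = model.generate(
    **inputs,
    custom_generate="<this_repo_id>",  # placeholder for this Hub repository
    trust_remote_code=True,
    window_length=256,
    num_sink_tokens=4,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```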
custom_generate/generate.py CHANGED
@@ -1,11 +1,18 @@
 import torch
 from typing import Any, Dict, List, Optional, Tuple
-from transformers.utils import logging
-from transformers.cache_utils import Cache
 
-logger = logging.get_logger(__name__)
+from transformers import Cache, GenerationConfig
 
 
+UNSUPPORTED_GENERATION_ARGS = [
+    "cache_implementation",  # cache-related arguments, here we always use SinkCache
+    "cache_config",
+    "return_legacy_cache",
+    "num_beams",  # beam search (and cousin techniques) are not supported
+    "compile_config",  # SinkCache doesn't support torch.compile
+    "assistant_model",  # it also doesn't support speculative decoding
+]
+
 class SinkCache(Cache):
     """
     A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
@@ -15,28 +22,13 @@ class SinkCache(Cache):
     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
     `[batch_size, num_heads, seq_len, head_dim]`.
 
+    This class was copied from transformers 4.52.0, with minor modifications.
+
     Parameters:
         window_length (`int`):
             The length of the context window.
         num_sink_tokens (`int`):
             The number of sink tokens. See the original paper for more information.
-
-    Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        SinkCache()
-        ```
     """
 
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
@@ -48,7 +40,6 @@ class SinkCache(Cache):
         self.cos_sin_rerotation_cache = {}
         self._cos_cache = None
         self._sin_cache = None
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
     def _rotate_half(x):
@@ -86,8 +77,6 @@ class SinkCache(Cache):
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]
@@ -130,10 +119,6 @@ class SinkCache(Cache):
         partial_rotation_size = cache_kwargs.get("partial_rotation_size")
         using_rope = cos is not None and sin is not None
 
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
-
         # Update the sin/cos cache, which holds sin/cos values for all possible positions
         if using_rope and layer_idx == 0:
             # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
@@ -194,17 +179,52 @@ class SinkCache(Cache):
 
 
 def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
-    # compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result in an
-    # infinite loop. This is solved in transformers 4.53.
+    """Custom generate function for SinkCache.
+
+    Args:
+        model (`PreTrainedModel`):
+            The model to generate from.
+        window_length (`int`, *optional*, defaults to 256):
+            The length of the context window.
+        num_sink_tokens (`int`, *optional*, defaults to 4):
+            The number of sink tokens. See the original paper for more information.
+    """
+    # 1. General sanity checks
+    # 1.a. A few arguments are not allowed, especially arguments that control caches.
+    generation_config = kwargs.get("generation_config")
+    default_global_generation_config = GenerationConfig()
+    default_model_generation_config = model.generation_config
+    for arg in UNSUPPORTED_GENERATION_ARGS:
+        has_custom_gen_config_arg = (
+            generation_config is not None
+            # = and not (match global default or match model-specific default)
+            and not (
+                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
+                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
+            )
+        )
+        if arg in kwargs or has_custom_gen_config_arg:
+            raise ValueError(
+                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
+                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
+            )
+
+    # 1.b. The model must be decoder-only
+    if model.config.is_encoder_decoder:
+        raise ValueError("This custom generate function only works with decoder-only models")
+
+    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
+    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
     kwargs.pop("custom_generate", None)
 
-    # prepare the cache, it is was not passed.
+    # 2. Generate with SinkCache
+    # 2.a. prepare the cache, if it was not passed.
     past_key_values = kwargs.pop("past_key_values", None)
     if past_key_values is None:
         past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
     elif not isinstance(past_key_values, SinkCache):
         raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
 
-    # generate with the cache
+    # 2.b. generate with the cache
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
     return generation_outputs
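
To illustrate the sanity checks added in this commit: any argument listed in `UNSUPPORTED_GENERATION_ARGS`, passed either directly or through a `generation_config` that deviates from both the global and the model defaults, now raises a `ValueError` before generation starts, and so does calling the function with an encoder-decoder model. A quick sketch, again using a placeholder `"<this_repo_id>"` for this Hub repository:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# `num_beams` is in UNSUPPORTED_GENERATION_ARGS, so the new check rejects it up front.
try:
    model.generate(
        **inputs,
        custom_generate="<this_repo_id>",  # placeholder for this Hub repository
        trust_remote_code=True,
        num_beams=4,
    )
except ValueError as err:
    print(err)
```

Both checks run before `model.generate` is called, so misconfigured calls fail fast instead of silently generating with an unexpected cache setup; an explicitly supplied `past_key_values` must still be a `SinkCache` instance, as enforced by the existing `isinstance` check.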