joaogante (HF Staff) committed 7f36015 (1 parent: 7bbd1e8)

add sinkcache class

Files changed (2)
  1. README.md +1 -1
  2. custom_generate/generate.py +199 -0
README.md CHANGED
@@ -13,7 +13,7 @@ This implementation should match the `SinkCache` class present in `transformers<

 ![Sink Cache diagram from the original paper](https://arxiv.org/html/2309.17453v4/x1.png)

-## Base model:
+## Base model


 ## Model compatibility
custom_generate/generate.py ADDED
@@ -0,0 +1,199 @@
+import torch
+from typing import Any, Dict, List, Optional, Tuple
+from transformers.utils import logging
+from transformers.cache_utils import Cache
+
+logger = logging.get_logger(__name__)
+
+
+class SinkCache(Cache):
+    """
+    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), that allows the model to
+    generate beyond the length of its context window without losing fluency in the conversation. As it discards past
+    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+
+    Parameters:
+        window_length (`int`):
+            The length of the context window.
+        num_sink_tokens (`int`):
+            The number of sink tokens. See the original paper for more information.
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+    >>> # Prepare a cache class and pass it to the model's forward
+    >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
+    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    >>> outputs.past_key_values  # access cache filled with key/values from generation
+    SinkCache()
+    ```
+    """
+
+    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        super().__init__()
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self.window_length = window_length
+        self.num_sink_tokens = num_sink_tokens
+        self.cos_sin_rerotation_cache = {}
+        self._cos_cache = None
+        self._sin_cache = None
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    @staticmethod
+    def _rotate_half(x):
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def _apply_key_rotary_pos_emb(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> torch.Tensor:
+        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
+        return rotated_key_states
+
+    def _get_rerotation_cos_sin(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
+            # Upcast to float32 temporarily for better accuracy
+            cos = cos.to(torch.float32)
+            sin = sin.to(torch.float32)
+
+            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence.
+            # Via the angle-difference identities, these are the cos/sin of the offset between a key's original
+            # position and its shifted position inside the window.
+            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
+            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
+            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
+            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
+
+            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
+                rerotation_cos.to(key_states.dtype).unsqueeze(0),
+                rerotation_sin.to(key_states.dtype).unsqueeze(0),
+            )
+        return self.cos_sin_rerotation_cache[key_states.shape[-2]]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_cache_shape(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cache object. For `SinkCache`, it is the window length."""
+        return self.window_length
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
+                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
+                rotation as the tokens are shifted.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
+        # with partially rotated position embeddings, like Phi or Persimmon.
+        if cache_kwargs is None:
+            cache_kwargs = {}
+        sin = cache_kwargs.get("sin")
+        cos = cache_kwargs.get("cos")
+        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
+        using_rope = cos is not None and sin is not None
+
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the sin/cos cache, which holds sin/cos values for all possible positions
+        if using_rope and layer_idx == 0:
+            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
+            # after all RoPE models have a llama-like cache utilization.
+            if cos.dim() == 2:
+                self._cos_cache = cos
+                self._sin_cache = sin
+            else:
+                if self._cos_cache is None:
+                    self._cos_cache = cos[0, ...]
+                    self._sin_cache = sin[0, ...]
+                elif self._cos_cache.shape[0] < self.window_length:
+                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
+                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
+
+        # [bsz, num_heads, seq_len, head_dim]
+        if len(self.key_cache) <= layer_idx:
+            # Empty cache
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+
+        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+            # Growing cache
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        else:
+            # Shifting cache
+            keys_to_keep = self.key_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
+            ]
+
+            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
+            if using_rope:
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
+                )
+                if partial_rotation_size is not None:
+                    keys_to_keep, keys_pass = (
+                        keys_to_keep[..., :partial_rotation_size],
+                        keys_to_keep[..., partial_rotation_size:],
+                    )
+                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
+                if partial_rotation_size is not None:
+                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
+
+            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
+            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
+            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+
+            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
+            values_to_keep = self.value_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
+            ]
+            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
+def generate(model, **kwargs):
+    # Custom generation entry point: runs standard `generate` with a fixed SinkCache
+    # (256-token window, 4 sink tokens).
+    past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
+    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
+    return generation_outputs
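
For context, here is a minimal sketch of how this entry point would be invoked from user code, assuming the Hub custom-generation loading path (`custom_generate=...` plus `trust_remote_code=True`) available in recent `transformers` releases; the repository id below is a placeholder for wherever this commit is hosted:

```python
# Minimal usage sketch. Assumes a transformers version that supports Hub-loaded custom
# generation; "<namespace>/sink_cache" is a placeholder for the repo containing this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

inputs = tokenizer("Tell me a story about a brave knight.", return_tensors="pt")

# `generate` fetches custom_generate/generate.py from the given repo and calls its
# `generate(model, **kwargs)`, which attaches a SinkCache(window_length=256, num_sink_tokens=4).
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    custom_generate="<namespace>/sink_cache",  # placeholder repo id
    trust_remote_code=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Note that `generate.py` hardcodes `window_length=256` and `num_sink_tokens=4`, so callers tune those values by editing the file rather than passing keyword arguments.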