add sinkcache class

Files changed:
- README.md (+1, -1)
- custom_generate/generate.py (+199, -0)
README.md
CHANGED
@@ -13,7 +13,7 @@ This implementation should match the `SinkCache` class present in `transformers`
(one line changed; unchanged context includes the "## Base model" and "## Model compatibility" headings)
custom_generate/generate.py
ADDED
@@ -0,0 +1,199 @@
import torch
from typing import Any, Dict, List, Optional, Tuple
from transformers.utils import logging
from transformers.cache_utils import Cache

logger = logging.get_logger(__name__)


class SinkCache(Cache):
    """
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        SinkCache()
        ```
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_rerotation_cache = {}
        self._cos_cache = None
        self._sin_cache = None
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
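        # Rotates half of the hidden dims of the input, (x1, x2) -> (-x2, x1); the standard RoPE helper.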
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
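            # The two products above are the angle-difference identities
            #   cos(a - b) = cos(a) * cos(b) + sin(a) * sin(b)
            #   sin(b - a) = cos(a) * sin(b) - sin(a) * cos(b)
            # with a = original position angle and b = shifted position angle, so applying
            # (rerotation_cos, rerotation_sin) via `_apply_key_rotary_pos_emb` moves a cached key
            # from its original RoPE position to the earlier position it occupies after the shift.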

            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
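        # Illustrative call site (hypothetical; names like `past_key_value` and `self.layer_idx` are
        # assumptions about the calling attention layer, not part of this file):
        #   key_states, value_states = past_key_value.update(
        #       key_states, value_states, self.layer_idx, {"sin": sin, "cos": cos}
        #   )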
        if cache_kwargs is None:
            cache_kwargs = {}
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the sin/cos cache, which holds sin/cos values for all possible positions
        if using_rope and layer_idx == 0:
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
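        # Three regimes follow: (1) the first call for this layer appends the states as-is, (2) while
        # shorter than `window_length` the cache simply grows, (3) once full, the oldest non-sink entries
        # are evicted and, on RoPE models, the kept keys are re-rotated to their shifted positions.
        # Layout after eviction: [sink tokens | kept window | new tokens].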
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]


def generate(model, **kwargs):
    past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs
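
The `generate(model, **kwargs)` function above is the entry point `transformers` looks for when loading a custom decoding method from a repo's `custom_generate/` folder. A minimal sketch of how a repo with this layout is typically invoked, assuming a recent `transformers` release with `custom_generate` support; the repository id is a placeholder:

```python
# Sketch only: asks `generate` to load custom_generate/generate.py from a Hub repository and run it
# in place of the default decoding loop. The repo id below is a placeholder, not a real repository.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

inputs = tokenizer("My name is Qwen2", return_tensors="pt")
outputs = model.generate(
    **inputs,
    custom_generate="<user>/<this-repo>",  # placeholder: Hub repo containing custom_generate/generate.py
    trust_remote_code=True,
    max_new_tokens=64,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

With this layout, the `window_length=256` and `num_sink_tokens=4` hard-coded in `generate` above are what bound the KV cache, regardless of how many tokens are generated.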