Update README.md
# SepCache - Native Sparse Attention Cache

## Table of Contents

- [1. Abstract](#1-abstract)
- [2. Usage](#2-usage)
  - [2.1 Sample Base Model](#21-sample-base-model)
  - [2.2 Quick Start](#22-quick-start)
    - [2.2.1 Environment Setup](#221-environment-setup)
    - [2.2.2 A Simple Example](#222-a-simple-example)
    - [2.2.3 Frequently-Used Parameters](#223-frequently-used-parameters)
    - [2.2.4 Update Function](#224-update-function)
    - [2.2.5 Monkey Patch Demo](#225-monkey-patch-demo)
    - [2.2.6 Downstream Task Evaluation](#226-downstream-task-evaluation)
    - [2.2.7 The Detailed Signature of `generate` Function](#227-the-detailed-signature-of-generate-function)
- [3. Adaptation for Other Models](#3-adaptation-for-other-models)
  - [3.1 Method 1 - Monkey Patching](#31-method-1---monkey-patching)
  - [3.2 Method 2 - Direct Code Modification](#32-method-2---direct-code-modification-recommended-for-simplicity)
  - [3.3 Important Note](#33-important-note)
- [4. Other Advanced Usage](#4-other-advanced-usage)

---
## 1. Abstract

`SepCache` is a simple yet effective, native sparse attention `Cache` class proposed in the [`SepLLM paper - ICML 2025`](https://icml.cc/virtual/2025/poster/45536), which most closely aligns with the semantic distribution of natural language. In the training phase, `SepLLM` condenses the segment information into the KV of the separator that divides the segment. In the inference phase, the corresponding `SepCache` only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.

Notably, `SepCache` also delivers strong performance across many tasks in training-free scenarios. Moreover, `SepLLM` (or simply `SepCache`) is the **most suitable baseline method for sparse attention mechanisms and KV compression/management**, as it is the natively sparse attention mechanism that best aligns with the natural semantic distribution of language.

See more details and advanced usage in https://github.com/HKUDS/SepLLM



## 2. Usage

### 2.1 Sample Base Model

We recommend using models from the **Llama 3 series**. Our example model is based on `meta-llama/Meta-Llama-3-8B-Instruct`, for which we have already prepared a targeted `monkey patch`.

For other models, using `SepCache` requires minor modifications to the corresponding `modeling_xxx.py` file or writing a **custom monkey patch**. These changes are **very simple** -- you only need to pass arguments like `input_ids` to the `update` function of `SepCache` when calling it.

We will provide a detailed guide later on how to modify your `modeling_xxx.py` file or `monkey patch` file to adapt `SepCache` to any model.

### 2.2 Quick Start

#### 2.2.1 Environment Setup

You need to install `transformers>=4.53`, and we recommend using `lm_eval>=0.4.9` for running evaluations. We suggest managing your Python environment with `conda` for better dependency control.

```bash
conda create -n sepcache python=3.10
conda activate sepcache
pip install transformers==4.53
pip install lm_eval==0.4.9
```
#### 2.2.2 A Simple Example

You can use `SepCache` by specifying `custom_generate="transformers-community/sep_cache"` or `custom_generate="Gausson/sep_cache"` when calling the `generate` function. In our demo, we have already prepared sample monkey patching for the `Llama 3 series` models and provided some common parameters for initializing `SepCache`.

```python
# requires `transformers>=4.53.0`
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")

messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True,  # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache",  ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256,
    cache_size = 512,
    USE_MAX_SEP_CACHE = True,
    model_type = 'llama'
)

print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sepcache" in str(type(gen_out.past_key_values)).lower()
```
It is worth noting that you normally must specify the `separator_token_ids: List[int]` and `PADDING_ID: int` parameters when initializing `SepCache`. We did not do so in the example above because, for convenience, we specified `model_type = "llama"`, in which case `separator_token_ids` and `PADDING_ID` are filled automatically.

However, when you use the tokenizer of a non-Llama 3 series model, you need to set `separator_token_ids` and `PADDING_ID` explicitly, based on the tokenizer you are using. The following example uses the values obtained from a Llama 3 series tokenizer.
```python
# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True,  # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache",  ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256,
    cache_size = 512,
    USE_MAX_SEP_CACHE = True,
    separator_token_ids = [128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256, 262],
    PADDING_ID = 128009
)
```
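If you are unsure which ids to use for your own tokenizer, a minimal sketch like the following can collect candidate separator ids and a padding id before calling `generate`. This snippet is not part of the SepCache codebase; the checkpoint name and the separator list are illustrative placeholders.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/your-model")  # hypothetical checkpoint

# Encode common separators/punctuation with your tokenizer and deduplicate the resulting ids.
candidate_separators = [".", ",", "?", "!", ";", ":", "\n", " .", " ,"]
separator_token_ids = sorted({
    tok_id
    for sep in candidate_separators
    for tok_id in tokenizer.encode(sep, add_special_tokens=False)
})

# Fall back to the EOS id if the tokenizer defines no dedicated padding token.
PADDING_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

print(separator_token_ids, PADDING_ID)
```

Which tokens count as separators is a modeling choice; see the SepLLM repository for the separator sets used in the paper.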
#### 2.2.3 Frequently-Used Parameters

Below, we provide explanations and examples for the most commonly used parameters when initializing `SepCache`. These parameters can be passed through the `generate` function.

```
`SepCache` stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.

Frequently-Used Parameters:

`init_cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for initial tokens.
    In the paper, the hyperparameter `a` is an abbreviated alias for `init_cache_size`.

`sep_cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for separator tokens.
    In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.

`local_size: Union[int, List]`:
    The maximum number of KVs to be stored for local tokens (i.e., the sliding window).
    In the paper, the hyperparameter `w` is an abbreviated alias for `local_size`.

`cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for all the tokens, i.e., the size of the whole KV cache.
    In the paper, the hyperparameter `c` is an abbreviated alias for `cache_size`.

Concerning the four parameters above:
    When a list is passed (its length must be `layer_num`), it specifies a different value for each layer.
    When an integer is passed, the same setting is used for all layers.

`USE_MAX_SEP_CACHE: bool`:
    If True, the cache keeps at most `sep_cache_size` separators' KVs.
    If the number exceeds this limit, older separators' KVs are discarded, keeping only the most recent `sep_cache_size` KVs.
    In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.

`separator_token_ids: List[int]`:
    The token ids of the separator tokens for the current model's tokenizer.
    For some models, such as the Llama 3 series, setting `model_type='llama'` allows you
    to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).

`PADDING_ID: int`:
    The token id of the padding token. You can just set `PADDING_ID` to the id of the "<|endoftext|>" token of the tokenizer for the pretrained model.
```
Important Note:
- When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers), and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
- You must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`. Here, `left_padding_offset` denotes the number of padding tokens in the record with the most left-padding within a runtime batch. `left_padding_offset` can only be determined at runtime.
- To guarantee that the above inequality always holds at runtime, you can intentionally leave a sufficient margin between the two sides of the inequality
  `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size`, i.e., `a` + `s` + `w` < `c` in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094), to leave room for `left_padding_offset`.

**More Important Note: In practice, there is no need to do positional encoding (PE) shifting as in [StreamingLLM](https://github.com/mit-han-lab/streaming-llm/) if the actual sequence length does not exceed the pretrained maximum PE length (which is the case for most downstream tasks). So, for most basic usage, just set `APPLY_PE_SHIFT=False` (`False` is also the default) and `APPLY_PES_INSIDE=False` for initialization.**
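For example, the following sketch (with illustrative, untuned numbers, reusing the `model` and `model_inputs` from the Quick Start example) sets per-layer capacities as lists and keeps a comfortable margin so that `a` + `s` + `w` stays well below `c`:

```python
# Illustrative sketch: per-layer settings (one entry per decoder layer) with a safety margin.
num_layers = model.config.num_hidden_layers

gen_out = model.generate(
    **model_inputs,
    max_new_tokens=100,
    custom_generate="transformers-community/sep_cache",
    trust_remote_code=True,
    init_cache_size=[4] * num_layers,    # `a`
    sep_cache_size=[64] * num_layers,    # `s`
    local_size=[256] * num_layers,       # `w`
    cache_size=[512] * num_layers,       # `c`; 4 + 64 + 256 = 324 < 512 leaves room for `left_padding_offset`
    USE_MAX_SEP_CACHE=True,
    model_type="llama",
)
```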
#### 2.2.4 Update Function

After initialization, another key point to note is that when using the `update` function of `SepCache` to update the **keys/values** and the **past token IDs** (which is necessary in SepCache), the current `input_ids` must also be provided.

```python
key_states, value_states = past_key_values.update(
    key_states = key_states,
    value_states = value_states,
    input_ids = input_ids,  ## required
    layer_idx = layer_idx,
    PREFILLING_FLAG = q_len > 1,  ## `q_len` is the sequence length of the current `query_states`
)
```
#### 2.2.5 Monkey Patch Demo

A monkey patch is needed to accommodate the `update` function of `SepCache` described in [`2.2.4 Update Function`](#224-update-function), i.e., to pass the current `input_ids` as a parameter to the `update` function. It is worth noting that during the prefilling stage, the shape of the `input_ids` tensor is `[batch_size, seq_len]`, while during the decoding stage of auto-regressive models, the shape of the `input_ids` tensor should be `[batch_size, 1]`.

In our `custom_generate/generate.py` file, we provide the `monkey_patching` function, which works by replacing the `forward` function of all related instances of the `XXXAttention` class (for the Llama 3 series models, this is `LlamaAttention`) with our customized forward function (specified by the `model_atten_forward` parameter of the `monkey_patching` function).
```python
def monkey_patching(model_obj,
                    model_atten_forward,  ## The `forward` function used for patching.
                    possible_inner_model_names: List[str] = ["model", "transformer", "gpt_neox"],  # In the `XXXForCausalLM` class, the possible names of the internal attribute holding the model, e.g., "model", "transformer", "gpt_neox", etc.
                    possible_layers_names: List[str] = ["layers", "h"],  # In the `XXXModel` class, the possible names of the internal attribute holding the decoder layers, e.g., "layers", "h", etc.
                    atten_attr_name_pattern_list: List[str] = ["attention", "self_attn"],  # In the `XXXDecoderLayer` class, the possible name patterns of the internal attribute for self-attention, e.g., "attention", "self_attn", etc.
                    atten_attr_name_pattern_exclude: List[str] = ["norm", "layer"],  # In the `XXXDecoderLayer` class, the name patterns to be excluded when searching for the self-attention module. Sometimes there are attributes like "post_attention_norm" whose `forward` function we do not want to modify - we want to modify the `forward` function of `XXXAttention` - so patterns like "norm" must be excluded to accurately find the correct `forward` function to replace.
                    verbose = True):

    """
    This `monkey_patching` function is to
        - find the `forward` function of the `XXXAttention` class.
        - replace all the related `forward` functions of the instances of the `XXXAttention` class with `model_atten_forward`.
    """

    ## To avoid the argument check failure, i.e., let "sepllm_kwargs" pass the check.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    ## Get the inner model obj
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)

    ## Get the decoder layers (`nn.ModuleList`) obj
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)

    ## Replace all the related `forward` functions of the XXXAttention class's instances.
    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
```
The `monkey_patching` function primarily does three things:
- Precisely locate the `forward` function of all instances of the `XXXAttention` class.
- Replace the `forward` function with the `model_atten_forward` function you provide.
- Return the decoder-layer container found during the process, typically of type `nn.ModuleList`. This return value (`model_layers`) is only used later to determine the number of layers in the current model (obtained via `len(model_layers)`).

In addition, the `monkey_patching` function replaces `transformers.generation.GenerationMixin._validate_model_kwargs` with our `_validate_model_kwargs` to bypass some parameter checks, as we provide an additional `sepllm_kwargs` parameter to wrap the `input_ids` for eventual transmission to the `update` function of `SepCache`.

**Please ensure that the `monkey_patching` function accurately locates and replaces the `forward` function of the `XXXAttention` class. The current `monkey_patching` is designed for the Llama 3 series models; for other models, you need to modify `monkey_patching` accordingly so that it targets and replaces the correct functions!** You can monitor the monkey patching process by setting `verbose=True` in the `monkey_patching` function (or `monkey_patch_verbose=True` in the `generate` function).
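As a rough sketch of what such an adaptation might look like, you would call `monkey_patching` with attribute-name patterns matching your architecture and your own attention forward. The patterns below are illustrative guesses for a GPT-NeoX-style model, and `my_neox_atten_forward` is a hypothetical function you would write yourself:

```python
# Hypothetical adaptation sketch for a GPT-NeoX-style model; adjust the patterns to your architecture.
model_layers = monkey_patching(
    model,                                      # your `XXXForCausalLM` instance
    model_atten_forward=my_neox_atten_forward,  # your customized attention forward that calls `SepCache.update(..., input_ids=...)`
    possible_inner_model_names=["gpt_neox", "model", "transformer"],
    possible_layers_names=["layers", "h"],
    atten_attr_name_pattern_list=["attention", "self_attn"],
    atten_attr_name_pattern_exclude=["norm", "layer", "dropout"],
    verbose=True,                               # print which `forward` functions were actually replaced
)
layer_num = len(model_layers)                   # used when initializing SepCache with per-layer settings
```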
```python
def truncate_input_ids_4_autoregression(input_ids, key_states):
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids
```

The `truncate_input_ids_4_autoregression` function in the `custom_generate/generate.py` file is used to shape the `input_ids` tensor to `[batch_size, 1]` during decoding.
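For intuition, here is a small illustrative check (the tensor shapes are made up) of how this helper behaves during a single decoding step, assuming the function above is in scope:

```python
import torch

# During decoding, the tracked `input_ids` may still cover the whole sequence seen so far,
# while the incoming key_states only cover the one newly generated token.
input_ids = torch.randint(0, 1000, (2, 37))    # [batch_size, seq_len]
key_states = torch.randn(2, 8, 1, 128)         # [batch_size, num_heads, 1, head_dim]

print(truncate_input_ids_4_autoregression(input_ids, key_states).shape)  # torch.Size([2, 1])
```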
#### 2.2.6 Downstream Task Evaluation

We recommend using `lm_eval==0.4.9` for downstream task evaluation. You can pass model-related parameters via `--model_args` and generation-related parameters (including those required for initializing `SepCache`) via `--gen_kwargs`. Notably, you typically need to pass a `list` to `separator_token_ids` using a string format like `"id1;id2;id3"` (as shown in the example below).

```bash
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct,attn_implementation=flash_attention_2 \
    --tasks gsm8k_cot \
    --gen_kwargs custom_generate=transformers-community/sep_cache,trust_remote_code=True,monkey_patch_verbose=True,separator_token_ids="128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262",PADDING_ID=128009 \
    --device cuda:0 \
    --batch_size 80 2>&1 | tee log.txt
```

Note: `SepCache` is typically used in combination with `Flash Attention` to maximize generation efficiency.

<img width="1022" height="248" alt="1752618213617" src="https://github.com/user-attachments/assets/87e2e745-9677-4101-895e-dd6fc7b6039d" />
#### 2.2.7 The Detailed Signature of `generate` Function

Here is the detailed signature of our customized `generate` function for `SepCache` in the `custom_generate/generate.py` file:

```python
def generate(model,
             ## For SepCache
             init_cache_size: Union[int, List] = 4,
             sep_cache_size: Union[int, List] = 128,
             local_size: Union[int, List] = 256,
             cache_size: Union[int, List] = 512,
             SEP_ACCUMULATION: bool = True,
             USE_MAX_SEP_CACHE: bool = False,
             SEP_PADDING_IN_BATCH: bool = False,
             separator_token_ids: List[int] = None,  ## required for initialization if `model_type` is not provided.
             PADDING_ID: int = None,  ## required for initialization if `model_type` is not provided.

             ## For inheritance & initialization states
             past_tok_ids: List[torch.Tensor] = None,  ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
             key_cache: List[torch.Tensor] = None,
             value_cache: List[torch.Tensor] = None,

             ## For debugging
             PRINT_KV_RATIO_INSIDE: bool = False,
             print_KV_inside_per_steps: int = 1000,
             _seen_tokens: int = 0,
             _kept_kv_ratio: List[Tuple[int]] = None,

             ### For positional encoding shifting
             APPLY_PE_SHIFT: bool = False,
             APPLY_PES_INSIDE: bool = False,
             _shifted_position_ids: List[torch.Tensor] = None,
             _rope_unsqueeze_dim: int = 1,  ## The unsqueeze_dim when applying RoPE.
             _rope_seq_dim: int = 1,  ## The seq_len dimension for the `cos` or `sin` tensors.
             pe_scaling_factor: float = 1.0,
             pe_dim: int = 128,  ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
             max_position_embeddings: int = 8192,
             base: int = 10000,  ## The base for RoPE.

             ## For basic transformer architecture
             k_seq_dim: int = 2,  ## The dimension for seq_len in key tensors
             v_seq_dim: int = 2,  ## The dimension for seq_len in value tensors
             layer_num: int = None,  ## required for initialization

             model_type: str = 'llama',  ## The model type for running the example. Choose from ['llama', 'pythia', 'falcon'].
             device = None,

             ## For verbosity of monkey patching
             monkey_patch_verbose: bool = False,

             **kwargs
             ):
    ...
```
## 3. Adaptation for Other Models

Adapting `SepCache` to other models is simple; there are two approaches:

### 3.1 Method 1 - Monkey Patching

- Modify the `monkey_patching` function to correctly locate and target the `forward` function of your model's `XXXAttention` class (e.g., `LlamaAttention` for Llama 3).
- Write your custom `model_atten_forward` function and use `monkey_patching` to replace the `forward` function of all `XXXAttention` class instances. The key modification is passing `input_ids` to `SepCache`'s `update` function.
### 3.2 Method 2 - Direct Code Modification (Recommended for Simplicity)

Simply edit your `modeling_xxx.py` file to implement the following (a sketch is given after the list):

- Initialize `past_key_values` as a `SepCache` instance at the appropriate location (e.g., in the `forward` function of the `XXXForCausalLM` or `XXXModel` class).
- Modify the `forward` function of the `XXXAttention` class to pass `input_ids` to `SepCache`'s `update` function.
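The following is only a rough sketch of Method 2, under the assumption that the `SepCache` constructor accepts the same cache-related arguments as the `generate` wrapper listed in Section 2.2.7; the surrounding names follow a typical Llama-style `modeling_xxx.py` and must be adapted to your model:

```python
# Sketch only - assumes SepCache's constructor mirrors the cache arguments listed in Section 2.2.7.

# (1) In `XXXModel.forward` (or `XXXForCausalLM.forward`): create a SepCache instead of the default cache.
if use_cache and not isinstance(past_key_values, SepCache):
    past_key_values = SepCache(
        init_cache_size=4,
        sep_cache_size=128,
        local_size=256,
        cache_size=512,
        USE_MAX_SEP_CACHE=True,
        separator_token_ids=separator_token_ids,  # from your tokenizer (or rely on `model_type`)
        PADDING_ID=padding_id,
        layer_num=self.config.num_hidden_layers,
    )

# (2) In `XXXAttention.forward`: pass `input_ids` when updating the cache.
key_states, value_states = past_key_values.update(
    key_states=key_states,
    value_states=value_states,
    input_ids=input_ids,        # required by SepCache
    layer_idx=self.layer_idx,
    PREFILLING_FLAG=q_len > 1,  # True during prefilling, False during decoding
)
```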
### 3.3 Important Note

The shape of `input_ids` is `[batch_size, seq_len]` during prefilling and `[batch_size, 1]` during generation.

## 4. Other Advanced Usage

Please refer to https://github.com/HKUDS/SepLLM, which provides detailed explanations and examples.