# SepCache - Native Sparse Attention Cache

## Table of Contents

- [1. Abstract](#1-abstract)
- [2. Usage](#2-usage)
  - [2.1 Sample Base Model](#21-sample-base-model)
  - [2.2 Quick Start](#22-quick-start)
    - [2.2.1 Environment Setup](#221-environment-setup)
    - [2.2.2 A Simple Example](#222-a-simple-example)
    - [2.2.3 Frequently-Used Parameters](#223-frequently-used-parameters)
    - [2.2.4 Update Function](#224-update-function)
    - [2.2.5 Monkey Patch Demo](#225-monkey-patch-demo)
    - [2.2.6 Downstream Task Evaluation](#226-downstream-task-evaluation)
    - [2.2.7 The Detailed Signature of `generate` Function](#227-the-detailed-signature-of-generate-function)
- [3. Adaptation for Other Models](#3-adaptation-for-other-models)
  - [3.1 Method 1 - Monkey Patching](#31-method-1---monkey-patching)
  - [3.2 Method 2 - Direct Code Modification](#32-method-2---direct-code-modification)
  - [3.3 Important Note](#33-important-note)
- [4. Other Advanced Usage](#4-other-advanced-usage)

---

## 1. Abstract
`SepCache` is a simple yet effective, native sparse attention `Cache` class proposed in the [`SepLLM paper - ICML 2025`](https://icml.cc/virtual/2025/poster/45536), which most closely aligns with the semantic distribution of natural language. In the training phase, `SepLLM` condenses the segment information into the KV of the separator that divides the segment. In the inference phase, the corresponding `SepCache` only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.

Notably, `SepCache` also delivers strong performance across many tasks in training-free scenarios. Moreover, `SepLLM` (or simply `SepCache`) is the **most suitable baseline method for sparse attention mechanisms and KV compression/management**, as it is the natively sparse attention mechanism that best aligns with the natural semantic distribution of language.

See more details and advanced usage at https://github.com/HKUDS/SepLLM.

![image](https://hackmd.io/_uploads/r1POJoR4yg.png)

## 2. Usage

### 2.1 Sample Base Model

We recommend using models from the **Llama 3 series**. Our example model is based on `meta-llama/Meta-Llama-3-8B-Instruct`, for which we have already prepared a targeted `monkey patch`.

For other models, using `SepCache` requires minor modifications to the corresponding `modeling_xxx.py` file or writing a **custom monkey patch**. These changes are **very simple** -- you only need to pass arguments like `input_ids` to the `update` function of `SepCache` when calling it.

A detailed guide on how to modify your `modeling_xxx.py` file or `monkey patch` file to adapt `SepCache` to any model is given in [3. Adaptation for Other Models](#3-adaptation-for-other-models) below.

### 2.2 Quick Start

#### 2.2.1 Environment Setup
You need to install `transformers>=4.53`, and we recommend `lm_eval>=0.4.9` for running evaluations. We suggest managing your Python environment with `conda` for better dependency control.

```bash
conda create -n sepcache python=3.10
conda activate sepcache
pip install transformers==4.53
pip install lm_eval==0.4.9
```
#### 2.2.2 A Simple Example
You can use `SepCache` by specifying `custom_generate="transformers-community/sep_cache"` or `custom_generate="Gausson/sep_cache"` when calling the `generate` function. In our demo, we have already prepared sample monkey patching for the `Llama 3 series` models and provided some common parameters for initializing `SepCache`.

```python
# requires `transformers>=4.53.0`
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")

messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True, # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache", ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256,
    cache_size = 512,
    USE_MAX_SEP_CACHE = True,
    model_type = 'llama'
)

print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sepcache" in str(type(gen_out.past_key_values)).lower()
```

It is worth noting that you must specify the `separator_token_ids: List[int]` and `PADDING_ID: int` parameters when initializing `SepCache`. In the example above, we did not do so because, for convenience, we specified `model_type = "llama"`, in which case `separator_token_ids` and `PADDING_ID` are filled in automatically.

However, when you use a tokenizer for a non-Llama 3 series model, you need to set `separator_token_ids` and `PADDING_ID` explicitly according to the tokenizer you are using. The following example is based on the values obtained from a Llama 3 series tokenizer (a sketch for deriving such values from an arbitrary tokenizer follows this example).
```python
# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True, # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache", ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256,
    cache_size = 512,
    USE_MAX_SEP_CACHE = True,
    separator_token_ids = [128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256, 262],
    PADDING_ID = 128009
)
```
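
When adapting `SepCache` to a tokenizer we have not prepared defaults for, you need to collect these two values yourself. Below is a hedged sketch (not part of this repo) of one way to do it; the candidate separator list is an assumption and should be adjusted to your tokenizer and data. Running it with the Llama 3 tokenizer lets you compare its output against the list used above.

```python
# A hedged sketch (not from the SepCache repo): collect single-token separator ids
# for an arbitrary tokenizer. The candidate separator list below is an assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

candidate_separators = [".", ",", "?", "!", ";", ":", "\n", " .", " ,", " ?", " !", " ;", " :"]
separator_token_ids = []
for sep in candidate_separators:
    ids = tokenizer.encode(sep, add_special_tokens=False)
    if len(ids) == 1 and ids[0] not in separator_token_ids:  # keep only single-token separators
        separator_token_ids.append(ids[0])

# Fall back to the EOS id if the tokenizer defines no padding token.
PADDING_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

print(separator_token_ids, PADDING_ID)
```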

#### 2.2.3 Frequently-Used Parameters

Below, we provide explanations and examples for the most commonly used parameters when initializing `SepCache`. These parameters can be passed through the `generate` function.

```
`SepCache` stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.

Frequently-Used Parameters:

`init_cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for initial tokens.
    In the paper, the hyperparameter `a` is an abbreviated alias for `init_cache_size`.

`sep_cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for separator tokens.
    In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.

`local_size: Union[int, List]`:
    The maximum number of KVs to be stored for local tokens (i.e., the sliding window).
    In the paper, the hyperparameter `w` is an abbreviated alias for `local_size`.

`cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for all the tokens, i.e., the size of the whole KV cache.
    In the paper, the hyperparameter `c` is an abbreviated alias for `cache_size`.

Concerning these four parameters above:
    When a list is passed (its length must be `layer_num`), it specifies a different value for each layer.
    When an integer is passed, the setting is the same for all layers.

`USE_MAX_SEP_CACHE: bool`:
    If True, we only keep at most `sep_cache_size` separators' KVs.
    If the number exceeds this limit, older separators' KVs will be discarded, keeping only the most recent `sep_cache_size` KVs.
    In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.

`separator_token_ids: List[int]`:
    The token ids of the separator tokens for the current model's tokenizer.
    We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you
    to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).

`PADDING_ID: int`:
    The token id of the padding token. You can just set `PADDING_ID` to the id of the "<|endoftext|>" token of the tokenizer for the pretrained model.
```
Important Note:
- When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers), and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
- You must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`. Here, `left_padding_offset` denotes the number of padding tokens in the record with the most left-padding within a runtime batch; it can only be determined at runtime.
- To guarantee that the above inequality always holds at runtime, leave a sufficient margin between the two sides of `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size` (i.e., `a` + `s` + `w` < `c` in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094)) to make room for `left_padding_offset`.

**More Important Note: In practice, there is no need to do positional encoding (PE) shifting as in [StreamingLLM](https://github.com/mit-han-lab/streaming-llm/) as long as the actual sequence length does not exceed the pretrained maximum PE length (which is the case for most downstream tasks). So, for most basic usage, just set `APPLY_PE_SHIFT=False` (`False` is also the default) and `APPLY_PES_INSIDE=False` at initialization.**
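
As noted above, the four size parameters also accept per-layer lists. Below is a minimal sketch (continuing the example from [2.2.2 A Simple Example](#222-a-simple-example); the specific per-layer numbers are our own arbitrary choices, not recommended settings) showing the list form while keeping `a` + `s` + `w` comfortably below `c` for every layer.

```python
# A minimal sketch, continuing the example from Section 2.2.2.
# The per-layer numbers below are arbitrary; they only illustrate the list form.
num_layers = model.config.num_hidden_layers  # 32 for Meta-Llama-3-8B-Instruct

gen_out = model.generate(
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    custom_generate="transformers-community/sep_cache",
    trust_remote_code=True,

    # Per-layer settings: each list must have length `layer_num` (= num_layers here),
    # and a + s + w stays well below c for every layer.
    init_cache_size=[4] * num_layers,
    sep_cache_size=[64] * (num_layers // 2) + [128] * (num_layers - num_layers // 2),
    local_size=[256] * num_layers,
    cache_size=[512] * num_layers,
    USE_MAX_SEP_CACHE=True,
    model_type="llama",
)
```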

#### 2.2.4 Update Function
After initialization, another key point to note is that when using the `update` function of `SepCache` to update the **keys/values** and the **past token IDs** (which is necessary in SepCache), the current `input_ids` must also be provided.
```python
key_states, value_states = past_key_values.update(
    key_states = key_states,
    value_states = value_states,
    input_ids = input_ids,  ## required
    layer_idx = layer_idx,
    PREFILLING_FLAG = q_len > 1,  ## `q_len` is the sequence length of the current `query_states`
)
```

#### 2.2.5 Monkey Patch Demo
To adapt to the `update` function of `SepCache` described in [2.2.4 Update Function](#224-update-function), the current `input_ids` must be passed as a parameter to the `update` function. It is worth noting that during the prefilling stage, the shape of the `input_ids` tensor is `[batch_size, seq_len]`, while during the decoding stage of auto-regressive models, the shape of the `input_ids` tensor should be `[batch_size, 1]`.

In our `custom_generate/generate.py` file, we provide the `monkey_patching` function, which works by replacing the `forward` function in all the related instances of the `XXXAttention` class (for example, in the Llama 3 series models, it would be `LlamaAttention`) with our customized forward function (specified by the `model_atten_forward` parameter of the `monkey_patching` function).
```python
def monkey_patching(model_obj,
                    model_atten_forward,  ## The `forward` function used for patching.
                    possible_inner_model_names: List[str] = ["model", "transformer", "gpt_neox"],  # In the `XXXForCausalLM` class, the possible attribute names for the inner model, e.g., "model", "transformer", "gpt_neox", etc.
                    possible_layers_names: List[str] = ["layers", "h"],  # In the `XXXModel` class, the possible attribute names for the decoder layers, e.g., "layers", "h", etc.
                    atten_attr_name_pattern_list: List[str] = ["attention", "self_attn"],  # In the `XXXDecoderLayer` class, the possible attribute name patterns for the self-attention module, e.g., "attention", "self_attn", etc.
                    atten_attr_name_pattern_exclude: List[str] = ["norm", "layer"],  # In the `XXXDecoderLayer` class, the attribute name patterns to be excluded. Sometimes there are attributes like "post_attention_norm" whose `forward` we do not want to modify - we only want to modify the `forward` of `XXXAttention`. So we exclude patterns like "norm" to accurately find the correct `forward` function to replace.
                    verbose = True):

    """
    This `monkey_patching` function is to
        - find the `forward` function of the `XXXAttention` class.
        - replace all the related `forward` functions of the instances of the `XXXAttention` class with `model_atten_forward`.
    """

    ## To avoid the argument check failure, i.e., let "sepllm_kwargs" pass the check.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    ## Get the inner model object.
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)

    ## Get the decoder layers (`nn.ModuleList`) object.
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)

    ## Replace all the related `forward` functions of the `XXXAttention` class's instances.
    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
```

The `monkey_patching` function primarily does three things:
- Precisely locate the `forward` function of all instances of the `XXXAttention` class.
- Replace that `forward` function with the `model_atten_forward` function you provide.
- Return the decoder-layers attribute found during the process, typically of type `nn.ModuleList`. This return value (`model_layers`) is only used later to determine the number of layers in the current model (obtained via `len(model_layers)`); see the usage sketch below.
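
For reference, below is a hedged usage sketch of how `monkey_patching` is wired up (assuming `monkey_patching` and a SepCache-aware `model_atten_forward` from our `custom_generate/generate.py` are in scope). Normally you do not call it yourself, since the customized `generate` does this internally.

```python
# Hedged usage sketch: the customized `generate` normally performs this step internally.
# `model_atten_forward` is your SepCache-aware attention forward for the target model.
model_layers = monkey_patching(
    model,                 # the loaded XXXForCausalLM instance
    model_atten_forward,   # the replacement `forward` for every XXXAttention instance
    verbose=True,          # print which `forward` gets replaced for each layer
)
layer_num = len(model_layers)  # later used as the `layer_num` argument when building SepCache
```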

In addition, the `monkey_patching` function replaces `transformers.generation.GenerationMixin._validate_model_kwargs` with our `_validate_model_kwargs` to bypass some parameter checks, as we will provide an additional `sepllm_kwargs` parameter to wrap the `input_ids` for eventual transmission to the `SepCache` `update` function.
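
The override itself can be as simple as a no-op; the snippet below is only an illustration of this kind of bypass, not necessarily our exact implementation.

```python
# Illustration only: a permissive override that skips `generate`'s strict kwargs validation,
# so the extra `sepllm_kwargs` entry is not rejected. The actual demo may do additional checks.
import transformers

def _validate_model_kwargs(self, model_kwargs):
    pass  # accept all model kwargs, including `sepllm_kwargs`

transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs
```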

**Please ensure that the `monkey_patching` function accurately locates and replaces the `forward` function of the `XXXAttention` class. The current `monkey_patching` is designed for the `Llama 3 series` models. For other models, you need to modify `monkey_patching` appropriately to make sure it targets and replaces the correct `forward` function!** You can monitor the monkey patching process by setting `verbose=True` in the `monkey_patching` function (or `monkey_patch_verbose=True` for the `generate` function).

```python
def truncate_input_ids_4_autoregression(input_ids, key_states):
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids
```
The `truncate_input_ids_4_autoregression` function in the `custom_generate/generate.py` file is used to truncate the `input_ids` tensor to shape `[batch_size, 1]` during decoding.
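
As a quick, illustrative sanity check of that behavior (dummy tensors only):

```python
# Illustrative only: during decoding, `key_states` holds a single new token, so `input_ids`
# is truncated from [batch_size, full_len] down to [batch_size, 1].
import torch

input_ids = torch.randint(0, 1000, (2, 37))   # [batch_size, prompt_len + generated_len]
key_states = torch.randn(2, 8, 1, 128)        # [batch_size, num_heads, 1, head_dim] at decode time

print(truncate_input_ids_4_autoregression(input_ids, key_states).shape)  # torch.Size([2, 1])
```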

#### 2.2.6 Downstream Task Evaluation
We recommend using `lm_eval==0.4.9` for downstream task evaluation. You can pass model-related parameters via `--model_args` and generation-related parameters (including those required for initializing `SepCache`) via `--gen_kwargs`. Notably, you typically need to pass a `list` such as `separator_token_ids` in a string format like `"id1;id2;id3"` (as shown in the example below).
```bash
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct,attn_implementation=flash_attention_2 \
    --tasks gsm8k_cot \
    --gen_kwargs custom_generate=transformers-community/sep_cache,trust_remote_code=True,monkey_patch_verbose=True,separator_token_ids="128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262",PADDING_ID=128009 \
    --device cuda:0 \
    --batch_size 80 2>&1 | tee log.txt
```
Note: `SepCache` is typically used in combination with `Flash Attention` to maximize generation efficiency.
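
For clarity, the semicolon-separated string above is simply the CLI form of the Python list used in [2.2.2 A Simple Example](#222-a-simple-example); illustratively:

```python
# Illustrative only: the `--gen_kwargs` string form of `separator_token_ids`
# corresponds to the Python list passed in Section 2.2.2.
sep_str = "128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262"
separator_token_ids = [int(tok) for tok in sep_str.split(";")]
print(separator_token_ids[:4])  # [128000, 13, 11, 30]
```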

<img width="1022" height="248" alt="1752618213617" src="https://github.com/user-attachments/assets/87e2e745-9677-4101-895e-dd6fc7b6039d" />

#### 2.2.7 The Detailed Signature of `generate` Function
Here is the detailed signature of our customized `generate` function for `SepCache` in the `custom_generate/generate.py` file:

```python
def generate(model,
             ## For SepCache
             init_cache_size: Union[int, List] = 4,
             sep_cache_size: Union[int, List] = 128,
             local_size: Union[int, List] = 256,
             cache_size: Union[int, List] = 512,
             SEP_ACCUMULATION: bool = True,
             USE_MAX_SEP_CACHE: bool = False,
             SEP_PADDING_IN_BATCH: bool = False,
             separator_token_ids: List[int] = None,  ## required for initialization if `model_type` is not provided.
             PADDING_ID: int = None,  ## required for initialization if `model_type` is not provided.

             ## For inheritance & initialization states
             past_tok_ids: List[torch.Tensor] = None,  ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
             key_cache: List[torch.Tensor] = None,
             value_cache: List[torch.Tensor] = None,

             ## For debugging
             PRINT_KV_RATIO_INSIDE: bool = False,
             print_KV_inside_per_steps: int = 1000,
             _seen_tokens: int = 0,
             _kept_kv_ratio: List[Tuple[int]] = None,

             ### For positional encoding shifting
             APPLY_PE_SHIFT: bool = False,
             APPLY_PES_INSIDE: bool = False,
             _shifted_position_ids: List[torch.Tensor] = None,
             _rope_unsqueeze_dim: int = 1,  ## The unsqueeze_dim when applying RoPE.
             _rope_seq_dim: int = 1,  ## The seq_len dimension for the `cos` or `sin` tensors.
             pe_scaling_factor: float = 1.0,
             pe_dim: int = 128,  ## The number of dims for positional encoding. Typically, just set this to `head_dim`.
             max_position_embeddings: int = 8192,
             base: int = 10000,  ## The base for RoPE.

             ## For basic transformer architecture
             k_seq_dim: int = 2,  ## The dimension for seq_len in key tensors.
             v_seq_dim: int = 2,  ## The dimension for seq_len in value tensors.
             layer_num: int = None,  ## required for initialization

             model_type: str = 'llama',  ## The model type for running the example. Choose from ['llama', 'pythia', 'falcon'].
             device = None,

             ## For verbosity of monkey patching
             monkey_patch_verbose: bool = False,

             **kwargs
             ):
    ...
```

## 3. Adaptation for Other Models

Adapting `SepCache` to various models is simple; there are two approaches:

### 3.1 Method 1 - Monkey Patching
- Modify the `monkey_patching` function to correctly locate and target the `forward` function of your model's `XXXAttention` class (e.g., `LlamaAttention` for Llama 3).
- Write your custom `model_atten_forward` function and use `monkey_patching` to replace the `forward` function of all `XXXAttention` class instances. The key modification is passing `input_ids` to `SepCache`'s `update` function.

### 3.2 Method 2 - Direct Code Modification (Recommended for Simplicity)
Simply edit your `modeling_xxx.py` file to implement the following (a sketch is given after the list):

- Initialize `past_key_values` as a `SepCache` instance at the appropriate location (e.g., in the `forward` function of the `XXXForCausalLM` or `XXXModel` class).
- Modify the `forward` function of the `XXXAttention` class to pass `input_ids` to `SepCache`'s `update` function.
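
Below is a hedged sketch of these two touch points. The exact `SepCache` constructor signature is an assumption here (it mirrors the `generate` arguments listed in [2.2.7](#227-the-detailed-signature-of-generate-function)); please check https://github.com/HKUDS/SepLLM for the authoritative API.

```python
# Hedged sketch of Method 2 (fragments, not a drop-in implementation).

# (1) In `XXXModel.forward` (or `XXXForCausalLM.forward`): build the cache once.
#     The constructor arguments below mirror the `generate` signature in Section 2.2.7,
#     but the exact `SepCache.__init__` signature should be checked against the SepLLM repo.
if past_key_values is None:
    past_key_values = SepCache(
        init_cache_size=4,
        sep_cache_size=128,
        local_size=256,
        cache_size=512,
        USE_MAX_SEP_CACHE=True,
        separator_token_ids=separator_token_ids,
        PADDING_ID=PADDING_ID,
        layer_num=self.config.num_hidden_layers,
    )

# (2) In `XXXAttention.forward`: thread `input_ids` down and pass it to `update`,
#     exactly as in Section 2.2.4.
key_states, value_states = past_key_values.update(
    key_states=key_states,
    value_states=value_states,
    input_ids=input_ids,
    layer_idx=self.layer_idx,
    PREFILLING_FLAG=q_len > 1,
)
```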

### 3.3 Important Note
The shape of `input_ids` is `[batch_size, seq_len]` during prefilling, and `[batch_size, 1]` during generation.

## 4. Other Advanced Usage

Please refer to https://github.com/HKUDS/SepLLM, which contains detailed explanations and examples.