---
license: llama3
language:
- en
base_model:
- meta-llama/Meta-Llama-3-8B-Instruct
tags:
- custom_generate
---
# SepCache - Native Sparse Attention Cache

## Table of Contents

- [1. Abstract](#1-abstract)
- [2. Usage](#2-usage)
  - [2.1 Sample Base Model](#21-sample-base-model)
  - [2.2 Quick Start](#22-quick-start)
    - [2.2.1 Environment Setup](#221-environment-setup)
    - [2.2.2 A Simple Example](#222-a-simple-example)
    - [2.2.3 Frequently-Used Parameters](#223-frequently-used-parameters)
    - [2.2.4 Update Function](#224-update-function)
    - [2.2.5 Monkey Patch Demo](#225-monkey-patch-demo)
    - [2.2.6 Downstream Task Evaluation](#226-downstream-task-evaluation)
    - [2.2.7 The Detailed Signature of `generate` Function](#227-the-detailed-signature-of-generate-function)
- [3. Adaptation for Other Models](#3-adaptation-for-other-models)
  - [3.1 Method 1 - Monkey Patching](#31-method-1---monkey-patching)
  - [3.2 Method 2 - Direct Code Modification](#32-method-2---direct-code-modification-recommended-for-simplicity)
  - [3.3 Important Note](#33-important-note)
- [4. Other Advanced Usage](#4-other-advanced-usage)
- [5. Citation](#5-citation)

---
## 1. Abstract

`SepCache` is a simple yet effective, native sparse attention `Cache` class proposed in the [`SepLLM` paper - ICML 2025](https://icml.cc/virtual/2025/poster/45536), which most closely aligns with the semantic distribution of natural language. In the training phase, `SepLLM` condenses each segment's information into the KV of the separator token that divides the segment. In the inference phase, the corresponding `SepCache` only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.

Notably, `SepCache` also delivers strong performance across many tasks in training-free scenarios. Moreover, `SepLLM` (or simply `SepCache`) is the **most suitable baseline method for sparse attention mechanisms and KV compression/management**, as it is the natively sparse attention mechanism that best aligns with the natural semantic distribution of language.

See https://github.com/HKUDS/SepLLM for more details and advanced usage.
## 2. Usage

### 2.1 Sample Base Model

We recommend using models from the **Llama 3 series**. Our example model is based on `meta-llama/Meta-Llama-3-8B-Instruct`, for which we have already prepared a targeted `monkey patch`.

For other models, using `SepCache` requires minor modifications to the corresponding `modeling_xxx.py` file or writing a **custom monkey patch**. These changes are **very simple** -- you only need to pass arguments like `input_ids` to the `update` function of `SepCache` when calling it.

A detailed guide on how to modify your `modeling_xxx.py` file or `monkey patch` file to adapt `SepCache` to any model is given in [Section 3](#3-adaptation-for-other-models).
### 2.2 Quick Start

#### 2.2.1 Environment Setup
You need to install `transformers>=4.53.0,<4.54.0`, and we recommend `lm_eval>=0.4.9` for running evaluations. We suggest managing your Python environment with `conda` for better dependency control.

```bash
conda create -n sepcache python=3.10
conda activate sepcache
pip install transformers==4.53.0
pip install lm_eval==0.4.9
```
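If generation with `custom_generate` later misbehaves, a quick first check is the installed `transformers` version. The snippet below is a small, optional sanity check based on the version range stated above:

```python
import transformers

# SepCache's custom `generate` assumes the transformers 4.53.x series.
assert transformers.__version__.startswith("4.53"), transformers.__version__
```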
					
						
#### 2.2.2 A Simple Example
You can use `SepCache` by specifying `custom_generate="transformers-community/sep_cache"` or `custom_generate="Gausson/sep_cache"` when calling the `generate` function. In our demo, we have already prepared sample monkey patching for the `Llama 3 series` models and provided some common parameters for initializing `SepCache`.

```python
# requires `transformers>=4.53.0,<4.54.0`
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prepare the model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")

messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Use SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose=True,  # to see which functions are actually being monkey patched for `SepCache`

    # Using SepCache
    custom_generate="transformers-community/sep_cache",  # alternatively, `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size=4,
    sep_cache_size=128,
    local_size=256,
    cache_size=512,
    USE_MAX_SEP_CACHE=True,
    model_type='llama'
)

print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sepcache" in str(type(gen_out.past_key_values)).lower()
```

It is worth noting that you must specify the `separator_token_ids: List[int]` and `PADDING_ID: int` parameters when initializing `SepCache`. We did not do so in the example above because, for convenience, we specified `model_type="llama"`, in which case `separator_token_ids` and `PADDING_ID` are filled in automatically.
However, when you use a tokenizer for a model outside the Llama 3 series, you need to specify the values of `separator_token_ids` and `PADDING_ID` explicitly, based on the tokenizer you are using. For example, the following example uses the values obtained from a Llama 3 series tokenizer.

```python
# Use SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose=True,  # to see which functions are actually being monkey patched for `SepCache`

    # Using SepCache
    custom_generate="transformers-community/sep_cache",  # alternatively, `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size=4,
    sep_cache_size=128,
    local_size=256,
    cache_size=512,
    USE_MAX_SEP_CACHE=True,
    separator_token_ids=[128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256, 262],
    PADDING_ID=128009
)
```
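If you are targeting a different tokenizer, the ids can also be collected programmatically. The following is a hedged sketch, not part of the shipped demo: the `separators` list is illustrative and should be adjusted to the separator conventions of your tokenizer, and the padding-id fallback is an assumption.

```python
# Hypothetical helper: derive `separator_token_ids` / `PADDING_ID` for an arbitrary tokenizer.
separators = [".", ",", "?", "!", ";", ":", "\n", " .", " ,", " ?", " !", " ;", " :"]
separator_token_ids = sorted({
    tok_id
    for sep in separators
    for tok_id in tokenizer(sep, add_special_tokens=False)["input_ids"]
})
PADDING_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

print(separator_token_ids, PADDING_ID)  # pass these to `generate(...)` as shown above
```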
					
						
#### 2.2.3 Frequently-Used Parameters

Below, we provide explanations and examples for the most commonly used parameters when initializing `SepCache`. These parameters can be passed through the `generate` function.

```
`SepCache` stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.

Frequently-Used Parameters:

`init_cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for initial tokens.
    In the paper, the hyperparameter `a` is an abbreviated alias for `init_cache_size`.

`sep_cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for separator tokens.
    In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.

`local_size: Union[int, List]`:
    The maximum number of KVs to be stored for local tokens (i.e., the sliding window).
    In the paper, the hyperparameter `w` is an abbreviated alias for `local_size`.

`cache_size: Union[int, List]`:
    The maximum number of KVs to be stored for all tokens, i.e., the size of the whole KV cache.
    In the paper, the hyperparameter `c` is an abbreviated alias for `cache_size`.

Concerning these four parameters above:
    When a list is passed (its length must be `layer_num`), it specifies a different value for each layer.
    When an integer is passed, the same setting is used for all layers.

`USE_MAX_SEP_CACHE: bool`:
    If True, the cache keeps at most `sep_cache_size` separators' KVs.
    If the number exceeds this limit, older separators' KVs are discarded, keeping only the most recent `sep_cache_size` KVs.

`separator_token_ids: List[int]`:
    The token ids of the separator tokens for the current model's tokenizer.
    For some models, such as the Llama 3 series, setting `model_type='llama'` allows you
    to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).

`PADDING_ID: int`:
    The token id of the padding token. You can simply set `PADDING_ID` to the id of the "<|endoftext|>" token of the tokenizer for the pretrained model.
```
Important Notes:
- When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers) and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
- You must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`. Here, `left_padding_offset` denotes the number of padding tokens in the record with the most left padding within a runtime batch; it can only be determined at runtime.
- To guarantee that the above inequality always holds at runtime, you can intentionally leave a sufficient margin between the two sides of `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size` (i.e., `a` + `s` + `w` < `c` in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094)) to leave room for `left_padding_offset`.

**More importantly: in practice, there is no need to do positional encoding (PE) shifting (as in [StreamingLLM](https://github.com/mit-han-lab/streaming-llm/)) if the actual sequence length does not exceed the pretrained maximum PE length, which applies to most downstream tasks. So, for most basic usage, just set `APPLY_PE_SHIFT=False` (the default) and `APPLY_PES_INSIDE=False` at initialization.**
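For example, the per-layer form of these parameters can be passed directly through `generate`. The following is a minimal sketch assuming a 32-layer model such as `Meta-Llama-3-8B-Instruct`; the numbers are illustrative and simply chosen so that `a` + `s` + `w` < `c` holds for every layer.

```python
layer_num = 32  # Meta-Llama-3-8B-Instruct has 32 decoder layers

gen_out = model.generate(
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    custom_generate="transformers-community/sep_cache",
    trust_remote_code=True,

    # Same setting for every layer ...
    init_cache_size=4,
    sep_cache_size=128,
    # ... or a per-layer setting: a list whose length equals `layer_num`.
    local_size=[256] * 16 + [320] * 16,  # e.g., a larger sliding window for the upper half of the layers
    cache_size=1024,
    USE_MAX_SEP_CACHE=True,
    model_type='llama',
)
```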
					
						
#### 2.2.4 Update Function
After initialization, another key point to note is that when using the `update` function of `SepCache` to update the **keys/values** and the **past token ids** (the latter is necessary in SepCache), the current `input_ids` must also be provided.
```python
key_states, value_states = past_key_values.update(
    key_states=key_states,
    value_states=value_states,
    input_ids=input_ids,        # required
    layer_idx=layer_idx,
    PREFILLING_FLAG=q_len > 1,  # `q_len` is the sequence length of the current `query_states`
)
```
#### 2.2.5 Monkey Patch Demo
The monkey patch adapts the model to the `update` function of `SepCache` described in [2.2.4 Update Function](#224-update-function), i.e., it passes the current `input_ids` as a parameter to the `update` function. It is worth noting that during the prefilling stage, the shape of the `input_ids` tensor is `[batch_size, seq_len]`, while during the decoding stage of auto-regressive models, the shape of the `input_ids` tensor is `[batch_size, 1]`.

In our `custom_generate/generate.py` file, we provide the `monkey_patching` function, which works by replacing the `forward` function of all the related instances of the `XXXAttention` class (for the Llama 3 series, `LlamaAttention`) with our customized forward function (specified by the `model_atten_forward` parameter of `monkey_patching`).
```python
def monkey_patching(model_obj,
                    model_atten_forward,  # The `forward` function used to patch.
                    possible_inner_model_names: List[str] = ["model", "transformer", "gpt_neox"],  # In the `XXXForCausalLM` class, the possible attribute names of the inner model, e.g., "model", "transformer", "gpt_neox", etc.
                    possible_layers_names: List[str] = ["layers", "h"],  # In the `XXXModel` class, the possible attribute names of the decoder layers, e.g., "layers", "h", etc.
                    atten_attr_name_pattern_list: List[str] = ["attention", "self_attn"],  # In the `XXXDecoderLayer` class, the possible attribute name patterns of the self-attention module, e.g., "attention", "self_attn", etc.
                    atten_attr_name_pattern_exclude: List[str] = ["norm", "layer"],  # In the `XXXDecoderLayer` class, the attribute name patterns to be excluded. Some attributes such as "post_attention_norm" would otherwise match, and we do not want to modify their `forward` functions - only that of `XXXAttention`. Excluding patterns like "norm" lets us accurately find the correct `forward` function to replace.
                    verbose=True):
    """
    This `monkey_patching` function is to
        - find the `forward` function of the `XXXAttention` class;
        - replace all the related `forward` functions of the instances of the `XXXAttention` class with `model_atten_forward`.
    """
    ## To avoid the argument check failure, i.e., let "sepllm_kwargs" pass the check.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    ## Get the inner model object.
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)

    ## Get the decoder layers (an `nn.ModuleList`) object.
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)

    ## Replace all the related `forward` functions of the `XXXAttention` class's instances.
    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
```

The `monkey_patching` function primarily does three things:
- Precisely locate the `forward` function of all instances of the `XXXAttention` class.
- Replace that `forward` function with the `model_atten_forward` function you provide.
- Return the decoder layers found during the process, typically of type `nn.ModuleList`. This return value (`model_layers`) is only used later to determine the number of layers in the current model (obtained via `len(model_layers)`).

In addition, `monkey_patching` replaces `transformers.generation.GenerationMixin._validate_model_kwargs` with our `_validate_model_kwargs` to bypass some argument checks, as we pass an additional `sepllm_kwargs` parameter to wrap the `input_ids` for eventual transmission to `SepCache`'s `update` function.

**Please ensure that the `monkey_patching` function accurately locates and replaces the `forward` function of the `XXXAttention` class. The current `monkey_patching` is designed for the `Llama 3 series` models; for other models, you need to modify `monkey_patching` appropriately so that it targets and replaces the correct functions!** You can monitor the monkey patching process by setting `verbose=True` in the `monkey_patching` function (or `monkey_patch_verbose=True` when calling `generate`).
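As an illustration, a patch call adapted to a Pythia/GPT-NeoX-style model might look like the sketch below. This is a hedged sketch only: `my_neox_atten_forward` is a hypothetical customized attention forward (it must pass `input_ids` to `SepCache`'s `update`, as in 2.2.4), and the attribute-name hints follow the usual `gpt_neox.layers[...].attention` layout.

```python
model_layers = monkey_patching(
    model_obj=model,
    model_atten_forward=my_neox_atten_forward,   # hypothetical: your customized attention forward
    possible_inner_model_names=["gpt_neox"],     # `GPTNeoXForCausalLM.gpt_neox`
    possible_layers_names=["layers"],            # `GPTNeoXModel.layers`
    atten_attr_name_pattern_list=["attention"],  # `GPTNeoXLayer.attention`
    atten_attr_name_pattern_exclude=["norm"],    # skip e.g. `post_attention_layernorm`
    verbose=True,                                # print which `forward` functions get replaced
)
layer_num = len(model_layers)                    # later needed when initializing `SepCache`
```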
					
						
The `truncate_input_ids_4_autoregression` function in the `custom_generate/generate.py` file is used to truncate the `input_ids` tensor to shape `[batch_size, 1]` during decoding.

```python
def truncate_input_ids_4_autoregression(input_ids, key_states):
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids
```
#### 2.2.6 Downstream Task Evaluation
We recommend using `lm_eval==0.4.9` for downstream task evaluation. You can pass model-related parameters via `--model_args` and generation-related parameters (including those required for initializing `SepCache`) via `--gen_kwargs`. Notably, you typically need to pass the `separator_token_ids` list as a string in the format `"id1;id2;id3"` (as shown in the example below).
```bash
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct,attn_implementation=flash_attention_2 \
    --tasks gsm8k_cot \
    --gen_kwargs custom_generate=transformers-community/sep_cache,trust_remote_code=True,monkey_patch_verbose=True,init_cache_size=4,sep_cache_size=128,local_size=256,cache_size=512,separator_token_ids="128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262",PADDING_ID=128009 \
    --device cuda:0 \
    --batch_size 80 2>&1 | tee log.txt
```
Note: `SepCache` is typically used in combination with `Flash Attention` to maximize generation efficiency.

<img width="1022" height="248" alt="1752618213617" src="https://github.com/user-attachments/assets/87e2e745-9677-4101-895e-dd6fc7b6039d" />
#### 2.2.7 The Detailed Signature of `generate` Function
Here is the detailed signature of our customized `generate` function for `SepCache` in the `custom_generate/generate.py` file:

```python
def generate(model,
             ## For SepCache
             init_cache_size: Union[int, List] = 4,
             sep_cache_size: Union[int, List] = 128,
             local_size: Union[int, List] = 256,
             cache_size: Union[int, List] = 512,
             SEP_ACCUMULATION: bool = True,
             USE_MAX_SEP_CACHE: bool = False,
             SEP_PADDING_IN_BATCH: bool = False,
             separator_token_ids: List[int] = None,  ## required for initialization if `model_type` is not provided
             PADDING_ID: int = None,  ## required for initialization if `model_type` is not provided

             ## For inheritance & initialization states
             past_tok_ids: List[torch.Tensor] = None,  ## the token ids corresponding to the saved KVs for all layers in SepCache
             key_cache: List[torch.Tensor] = None,
             value_cache: List[torch.Tensor] = None,

             ## For debugging
             PRINT_KV_RATIO_INSIDE: bool = False,
             print_KV_inside_per_steps: int = 1000,
             _seen_tokens: int = 0,
             _kept_kv_ratio: List[Tuple[int]] = None,

             ### For positional encoding shifting
             APPLY_PE_SHIFT: bool = False,
             APPLY_PES_INSIDE: bool = False,
             _shifted_position_ids: List[torch.Tensor] = None,
             _rope_unsqueeze_dim: int = 1,  ## the `unsqueeze_dim` used when applying RoPE
             _rope_seq_dim: int = 1,  ## the seq_len dimension of the `cos`/`sin` tensors
             pe_scaling_factor: float = 1.0,
             pe_dim: int = 128,  ## the number of dims for positional encoding; typically just set `head_dim`
             max_position_embeddings: int = 8192,
             base: int = 10000,  ## the base for RoPE

             ## For basic transformer architecture
             k_seq_dim: int = 2,  ## the dimension of seq_len in key tensors
             v_seq_dim: int = 2,  ## the dimension of seq_len in value tensors
             layer_num: int = None,  ## required for initialization

             model_type: str = 'llama',  ## the model type for running the example; choose from ['llama', 'pythia', 'falcon']
             device=None,

             ## For verbosity of monkey patching
             monkey_patch_verbose: bool = False,

             **kwargs
             ):
    ...
```
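As a usage note for the debugging options above, the following hedged example (with illustrative values) turns on periodic reporting of the kept-KV ratio during generation:

```python
gen_out = model.generate(
    **model_inputs,
    max_new_tokens=100,
    return_dict_in_generate=True,
    custom_generate="transformers-community/sep_cache",
    trust_remote_code=True,
    model_type='llama',
    USE_MAX_SEP_CACHE=True,

    # Debugging: periodically print the kept-KV ratio (every `print_KV_inside_per_steps` update steps).
    PRINT_KV_RATIO_INSIDE=True,
    print_KV_inside_per_steps=100,
)
```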
					
						
## 3. Adaptation for Other Models

Adapting `SepCache` to other models is simple. There are two approaches:

### 3.1 Method 1 - Monkey Patching
- Modify the `monkey_patching` function so that it correctly locates the `forward` function of your model's `XXXAttention` class (e.g., `LlamaAttention` for Llama 3).
- Write your custom `model_atten_forward` function and use `monkey_patching` to replace the `forward` function of all `XXXAttention` instances with it. The key modification is passing `input_ids` to `SepCache`'s `update` function.

### 3.2 Method 2 - Direct Code Modification (Recommended for Simplicity)
Simply edit your `modeling_xxx.py` file to implement the following (a sketch is given below):

- Initialize `past_key_values` as a `SepCache` instance at the appropriate location (e.g., in the `forward` function of the `XXXForCausalLM` or `XXXModel` class).
- Modify the `forward` function of the `XXXAttention` class to pass `input_ids` to `SepCache`'s `update` function.
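The sketch below illustrates these two edits. It is schematic rather than drop-in code: the import path of `SepCache`, the config attribute names, and how `input_ids` reaches the attention module all depend on your `modeling_xxx.py`, so treat every identifier here as an assumption to adapt.

```python
# (1) In `XXXModel.forward` (or `XXXForCausalLM.forward`): build a SepCache instead of the default cache.
#     The constructor arguments mirror Sections 2.2.3 and 2.2.7; `layer_num` must match the model.
if use_cache and past_key_values is None:
    past_key_values = SepCache(                    # import path depends on your setup (assumption)
        init_cache_size=4,
        sep_cache_size=128,
        local_size=256,
        cache_size=512,
        USE_MAX_SEP_CACHE=True,
        separator_token_ids=separator_token_ids,   # see Section 2.2.2 for how to obtain these
        PADDING_ID=PADDING_ID,
        layer_num=self.config.num_hidden_layers,
    )

# (2) In `XXXAttention.forward`: forward `input_ids` into the cache update.
#     `input_ids` has shape [batch_size, seq_len] while prefilling and [batch_size, 1] while decoding.
key_states, value_states = past_key_values.update(
    key_states=key_states,
    value_states=value_states,
    input_ids=input_ids,
    layer_idx=self.layer_idx,
    PREFILLING_FLAG=q_len > 1,
)
```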
					
						
### 3.3 Important Note
The shape of `input_ids` is `[batch_size, seq_len]` during prefilling and `[batch_size, 1]` during decoding.

## 4. Other Advanced Usage

Please refer to https://github.com/HKUDS/SepLLM, which provides detailed explanations and examples.

## 5. Citation
If you find our work helpful, please consider giving us a like ❤️ and citing our paper. We greatly appreciate your support 😄
```
@inproceedings{chen2025sepllm,
  title={{SepLLM: Accelerate Large Language Models by Compressing One Segment into One Separator}},
  author={Chen, Guoxuan and Shi, Han and Li, Jiawei and Gao, Yihang and Ren, Xiaozhe and Chen, Yimeng and Jiang, Xin and Li, Zhenguo and Liu, Weiyang and Huang, Chao},
  booktitle={International Conference on Machine Learning},
  year={2025},
  note={Also available at arXiv:2412.12094}
}
```