---
library_name: transformers
tags:
- custom_generate
- sampling
- kvcache
---

# Sampling with KV Cache

## Description

A clean, hackable implementation of sampling (also called ancestral sampling or multinomial sampling) with full KV cache support. This is a simplified alternative to the complex generation mixin in transformers, designed for readability and ease of modification while maintaining full performance.

The implementation supports both sampling and greedy decoding modes, with optional temperature scaling and top-k/top-p filtering.

## Base model

- [HuggingFaceTB/SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct)

## Model compatibility

Most transformer LLM/VLM models trained for causal language modeling.

## Relevant Arguments

- `temperature` (float): Sampling temperature (default: 1.0, higher = more random)
- `top_k` (int): Only consider the top-k most probable tokens (default: None)
- `top_p` (float): Only consider tokens with cumulative probability <= `top_p` (default: None)
- `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)
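
These arguments can be combined in a single `generate` call. A minimal sketch with hypothetical values, reusing the `model` and `inputs` defined in the usage example below:

```py
# Hypothetical combination of the documented arguments; the values are illustrative.
gen_out = model.generate(
    **inputs,
    custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers",
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    trust_remote_code=True,
)
```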

### Logits Processing Order

Logits processors are applied in sequence: `temperature → top_k → top_p → softmax` (same as HuggingFace's `LogitsProcessor` system). Temperature scaling occurs before top-p filtering, affecting the probability distribution that top-p operates on.

For example, with `temperature=1.0`, `top_p=0.9` might include tokens A, B, C. With `temperature=0.5`, probability mass is much more concentrated, so `top_p=0.9` might only include token A.
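
As a rough, self-contained illustration of this ordering (not the repository's actual code; `sample_next_token` is a hypothetical helper operating on a `(batch, vocab_size)` logits tensor):

```py
import torch

def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None, do_sample=True):
    """Sketch of the temperature -> top_k -> top_p -> softmax chain for one step."""
    logits = logits / temperature                                  # temperature scaling
    if top_k is not None:                                          # keep only the top-k logits
        kth_value = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p is not None:                                          # nucleus (top-p) filtering
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs   # mass of more probable tokens
        sorted_mask = mass_before > top_p                          # drop tokens once top_p mass is kept
        mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)    # map back to vocabulary order
        logits = logits.masked_fill(mask, float("-inf"))
    probs = torch.softmax(logits, dim=-1)                          # softmax last
    if do_sample:
        return torch.multinomial(probs, num_samples=1)             # ancestral sampling
    return probs.argmax(dim=-1, keepdim=True)                      # greedy decoding
```

For instance, `sample_next_token(torch.randn(1, 50), temperature=0.5, top_p=0.9)` draws one token id per batch row from a sharpened, truncated distribution.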

## Outputs

When `return_dict_in_generate=True`, returns a dictionary with:

- `sequences`: Generated token IDs
- `scores`: Log probabilities of sampled tokens (with temperature/sampling modifications)
- `logprobs`: Original model log probabilities (T=1, no modifications)

Otherwise, returns a tensor of generated token IDs.

## Example usage

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")

inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

# Basic sampling
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", trust_remote_code=True)

# With temperature
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", temperature=0.8, trust_remote_code=True)

# With top-k
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", top_k=50, trust_remote_code=True)

# With top-p (nucleus sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", top_p=0.9, trust_remote_code=True)

# Greedy decoding (no sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", do_sample=False, trust_remote_code=True)

# Get detailed output with probabilities
gen_out = model.generate(
    **inputs,
    custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers",
    return_dict_in_generate=True,
    trust_remote_code=True,
)
print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
print(f"Sampling scores: {gen_out['scores']}")
print(f"Model log probabilities: {gen_out['logprobs']}")
```
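
Building on the detailed-output call above, the returned log probabilities can be aggregated to score the sampled continuation. A hedged sketch, assuming `scores` and `logprobs` are per-token log-probability tensors of shape `(batch, generated_length)` (the exact layout may differ):

```py
# Assumption: one log probability per generated token, shape (batch, generated_length).
total_sampling_logp = gen_out["scores"].sum(dim=-1)    # log-prob under the modified (sampling) distribution
total_model_logp = gen_out["logprobs"].sum(dim=-1)     # log-prob under the unmodified model (T=1)
print(f"Continuation log-prob (sampling distribution): {total_sampling_logp}")
print(f"Continuation log-prob (unmodified model): {total_model_logp}")
```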

## Algorithm

1. Initialize KV cache and prepare input sequences
2. For each generation step:
   - Get logits from the model for the current sequence
   - Apply temperature scaling to logits
   - Optionally apply top-k filtering (keep only top-k tokens)
   - Optionally apply top-p filtering (nucleus sampling)
   - Convert to probabilities using softmax
   - Sample from the probability distribution (or take argmax for greedy)
   - Append the selected token to the sequence
   - Update KV cache and track sequence completion
3. Return generated sequences and probability information
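
For readers who want to see the whole loop in one place, here is a minimal, self-contained sketch of the algorithm above using plain forward passes and transformers' `DynamicCache`. It is not the repository's actual code and omits top-k/top-p filtering, batching, and padding details:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

inputs = tokenizer(["The quick brown"], return_tensors="pt")
generated = inputs["input_ids"]
temperature, max_new_tokens = 0.8, 20

with torch.no_grad():
    # 1. Initialize the KV cache and prefill it with the prompt
    out = model(input_ids=generated, past_key_values=DynamicCache(), use_cache=True)
    for _ in range(max_new_tokens):
        # 2. One token per step: temperature -> softmax -> sample
        logits = out.logits[:, -1, :] / temperature
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)   # use argmax for greedy decoding
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:        # track sequence completion
            break
        # Feed only the new token; previous keys/values come from the cache
        out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)

# 3. Return/inspect the generated sequence
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```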