---
library_name: transformers
tags:
- custom_generate
- sampling
- kvcache
---
# Sampling with KV Cache
## Description
A clean, hackable implementation of sampling (also called ancestral sampling or multinomial sampling) with full KV cache support. This is a simplified alternative to the complex generation mixin in transformers, designed for readability and ease of modification while maintaining full performance.
The implementation supports both sampling and greedy decoding modes, with optional temperature scaling and top-k/top-p filtering.
## Base model
- [HuggingFaceTB/SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct)
## Model compatibility
Most transformer LLM/VLM models trained for causal language modeling.
## Relevant Arguments
- `temperature` (float): Sampling temperature (default: 1.0, higher = more random)
- `top_k` (int): Only consider top-k most probable tokens (default: None)
- `top_p` (float): Nucleus sampling; keep the smallest set of most-probable tokens whose cumulative probability reaches `top_p` (default: None)
- `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)
### Logits Processing Order
Logits processors are applied in sequence: `temperature → top_k → top_p`, with the softmax taken last over the surviving tokens (the same ordering as HuggingFace's `LogitsProcessor` system). Temperature scaling occurs before top-p filtering, so it changes the probability distribution that top-p operates on.
For example, with `temperature=1.0`, `top_p=0.9` might include tokens A, B, and C. With `temperature=0.5`, the probability mass is much more concentrated, so `top_p=0.9` might include only token A.
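A minimal sketch of this filtering order in PyTorch (the `filter_logits` helper below is illustrative, not part of this repo):
```py
import torch

def filter_logits(logits, temperature=1.0, top_k=None, top_p=None):
    # Temperature first: rescales the whole distribution
    logits = logits / temperature
    if top_k is not None:
        # Mask everything below the k-th largest logit
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p is not None:
        # Nucleus filtering on the temperature-scaled distribution
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the boundary token survives
        remove[..., 0] = False                      # always keep the most probable token
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    return torch.softmax(logits, dim=-1)  # softmax last, over the surviving tokens
```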
## Outputs
When `return_dict_in_generate=True`, returns a dictionary with:
- `sequences`: Generated token IDs
- `scores`: Log probabilities of sampled tokens (with temperature/sampling modifications)
- `logprobs`: Original model log probabilities (T=1, no modifications)
Otherwise, returns a tensor of generated token IDs.
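As an illustration, if `scores` and `logprobs` come back as per-token tensors of shape `(batch, new_tokens)` (an assumption; verify against the actual return type), sequence-level log probabilities are simple sums:
```py
# Assumes (batch, new_tokens) tensors of per-token log probs — check the actual shape.
seq_score = gen_out["scores"].sum(dim=-1)      # log-prob under the modified sampling distribution
seq_logprob = gen_out["logprobs"].sum(dim=-1)  # log-prob under the unmodified model (T=1)
```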
## Example usage
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
# Basic sampling
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", trust_remote_code=True)
# With temperature
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", temperature=0.8, trust_remote_code=True)
# With top-k
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", top_k=50, trust_remote_code=True)
# With top-p (nucleus sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", top_p=0.9, trust_remote_code=True)
# Greedy decoding (no sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", do_sample=False, trust_remote_code=True)
# Get detailed output with probabilities
gen_out = model.generate(
    **inputs,
    custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers",
    return_dict_in_generate=True,
    trust_remote_code=True,
)
print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
print(f"Sampling scores: {gen_out['scores']}")
print(f"Model log probabilities: {gen_out['logprobs']}")
```
## Algorithm
1. Initialize KV cache and prepare input sequences
2. For each generation step:
- Get next-token logits from the model (after the first step, only the newest token is fed; the KV cache covers the rest of the sequence)
- Apply temperature scaling to logits
- Optionally apply top-k filtering (keep only top-k tokens)
- Optionally apply top-p filtering (nucleus sampling)
- Convert to probabilities using softmax
- Sample from the probability distribution (or take argmax for greedy)
- Append the selected token to the sequence
- Update KV cache and track sequence completion
3. Return generated sequences and probability information
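A condensed sketch of this loop using the public `transformers` cache API (illustrative, not the repo's actual code; EOS tracking, attention masks, and top-k/top-p filtering are omitted for brevity):
```py
import torch
from transformers import DynamicCache

@torch.no_grad()
def sample_loop(model, input_ids, max_new_tokens=20, temperature=1.0, do_sample=True):
    past_key_values = DynamicCache()  # step 1: initialize an empty KV cache
    sequences = input_ids
    next_inputs = input_ids           # the first pass prefills the cache with the whole prompt
    for _ in range(max_new_tokens):   # step 2: one token per iteration
        out = model(next_inputs, past_key_values=past_key_values, use_cache=True)
        logits = out.logits[:, -1, :] / temperature          # temperature scaling
        probs = torch.softmax(logits, dim=-1)                # top-k/top-p would filter logits before this
        if do_sample:
            next_token = torch.multinomial(probs, 1)         # ancestral sampling
        else:
            next_token = probs.argmax(dim=-1, keepdim=True)  # greedy decoding
        sequences = torch.cat([sequences, next_token], dim=-1)
        past_key_values = out.past_key_values
        next_inputs = next_token      # cache holds the prefix; feed only the new token
    return sequences                  # step 3 (the real code also returns scores/logprobs)
```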