---
library_name: transformers
tags:
  - custom_generate
  - sampling
  - kvcache
---

# Sampling with KV Cache

## Description

A clean, hackable implementation of sampling (also called ancestral sampling or multinomial sampling). It is a simplified alternative to the complex generation code in `transformers`' `GenerationMixin`, designed for readability and ease of modification while maintaining full performance.

The implementation supports both sampling and greedy decoding modes, with optional temperature scaling and top-k/top-p filtering.

## Base model

## Model compatibility

Most transformer LLM/VLM models trained for causal language modeling.

## Relevant Arguments

- `temperature` (float): Sampling temperature (default: 1.0, higher = more random)
- `top_k` (int): Only consider the top-k most probable tokens (default: None)
- `top_p` (float): Only consider tokens with cumulative probability <= `top_p` (default: None)
- `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)
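
The arguments can also be combined in a single call. The snippet below is only an illustrative sketch; the specific values (`temperature=0.7`, `top_k=50`, `top_p=0.95`) are arbitrary choices, not recommendations from this repository.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

# Temperature, top-k, and top-p applied together (arbitrary example values).
gen_out = model.generate(
    **inputs,
    custom_generate="manueldeprada/sampling",
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    trust_remote_code=True,
)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
```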

## Logits Processing Order

Logits processors are applied in sequence: temperature → top-k → top-p → softmax (the same order as Hugging Face's `LogitsProcessor` pipeline). Temperature scaling therefore happens before top-p filtering and changes the probability distribution that top-p operates on.

For example, with temperature=1.0, top_p=0.9 might include tokens A, B, C. With temperature=0.5, probability mass is much more concentrated, so top_p=0.9 might only include token A.
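
A toy illustration of this interaction, with made-up logit values rather than anything taken from a real model or from this repository's code:

```python
import torch

# Hypothetical unnormalized scores for four tokens A, B, C, D.
logits = torch.tensor([3.0, 1.0, 0.5, 0.0])
top_p = 0.9

for temperature in (1.0, 0.5):
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Keep the smallest set of tokens whose cumulative probability reaches top_p:
    # a token survives if the mass of the tokens ranked before it is still <= top_p.
    keep = (cumulative - sorted_probs) <= top_p
    kept = ["ABCD"[i] for i in sorted_idx[keep].tolist()]
    print(f"temperature={temperature}: nucleus = {kept}")

# temperature=1.0 keeps A, B, C; temperature=0.5 concentrates the mass and keeps only A.
```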

## Outputs

When `return_dict_in_generate=True`, returns a dictionary with:

- `sequences`: Generated token IDs
- `scores`: Log probabilities of the sampled tokens (with temperature/sampling modifications)
- `logprobs`: Original model log probabilities (T=1, no modifications)

Otherwise, returns a tensor of generated token IDs.

## Example usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")

inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

# Basic sampling
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling", trust_remote_code=True)

# With temperature
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling", temperature=0.8, trust_remote_code=True)

# With top-k
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling", top_k=50, trust_remote_code=True)

# With top-p (nucleus sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling", top_p=0.9, trust_remote_code=True)

# Greedy decoding (no sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling", do_sample=False, trust_remote_code=True)

# Get detailed output with probabilities
gen_out = model.generate(
    **inputs,
    custom_generate="manueldeprada/sampling",
    return_dict_in_generate=True,
    trust_remote_code=True
)
print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
print(f"Sampling scores: {gen_out['scores']}")
print(f"Model log probabilities: {gen_out['logprobs']}")
```

## Algorithm

1. Prepare the input sequences.
2. For each generation step:
   - Get logits from the model for the current sequence
   - Apply temperature scaling to the logits
   - Optionally apply top-k filtering (keep only the top-k tokens)
   - Optionally apply top-p filtering (nucleus sampling)
   - Convert the logits to probabilities with softmax
   - Sample from the probability distribution (or take the argmax for greedy decoding)
   - Append the selected token to the sequence
   - Track sequence completion
3. Return the generated sequences and probability information.
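
For readers who want the loop spelled out, here is a minimal, self-contained PyTorch sketch of the algorithm above. It is illustrative only and is not the repository's actual code: the helper names `top_k_top_p_filter` and `sample_with_kv_cache` are hypothetical, and details such as attention masks for padded batches, `return_dict_in_generate` outputs, and score tracking are omitted.

```python
import torch


def top_k_top_p_filter(logits, top_k=None, top_p=None):
    """Hypothetical helper: mask logits outside the top-k / top-p sets with -inf."""
    if top_k is not None:
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cumulative = sorted_probs.cumsum(dim=-1)
        # Drop a token once the mass of the tokens ranked before it exceeds top_p,
        # so the most probable token is always kept.
        drop = (cumulative - sorted_probs) > top_p
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return logits


@torch.no_grad()
def sample_with_kv_cache(model, input_ids, eos_token_id, max_new_tokens=32,
                         temperature=1.0, top_k=None, top_p=None, do_sample=True):
    sequences = input_ids
    past_key_values = None
    finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

    for _ in range(max_new_tokens):
        # Get logits for the current sequence, reusing the KV cache from previous steps.
        out = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        logits = out.logits[:, -1, :]

        # Temperature scaling, then optional top-k / top-p filtering.
        logits = logits / temperature
        logits = top_k_top_p_filter(logits, top_k=top_k, top_p=top_p)

        # Softmax, then sample (or take the argmax for greedy decoding).
        probs = torch.softmax(logits, dim=-1)
        if do_sample:
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = probs.argmax(dim=-1, keepdim=True)

        # Append the token, pad already-finished sequences with EOS, track completion.
        next_token = torch.where(finished[:, None], torch.full_like(next_token, eos_token_id), next_token)
        sequences = torch.cat([sequences, next_token], dim=-1)
        finished |= next_token.squeeze(-1) == eos_token_id
        if finished.all():
            break
        input_ids = next_token  # only the new token is fed; the cache holds the rest

    return sequences
```

With a model and tokenizer loaded as in the example usage above, `sample_with_kv_cache(model, inputs["input_ids"], tokenizer.eos_token_id)` would return the prompt plus a sampled continuation.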