Commit 0da6031 (parent: 8f67839): eliminate hf helpers

Files changed:
- README.md (+20, −56)
- custom_generate/generate.py (+185, −100)
README.md (CHANGED)

The README drops its old "ancestral sampling" framing: the title and description are rewritten, the example `custom_generate=` repository paths and the `gen_out['logps']` output key are updated, the `lens` (final sequence lengths) entry disappears from the outputs list, and the following section is removed:

## Helper Functions for Custom Generation

The implementation provides two key helper functions that you can use to build your own generation strategies:

### `init_gen(model_kwargs, model, max_new_tokens, bos_token_id)`
Initializes the generation process and prepares the KV cache:
- Sets up input sequences and model inputs
- Prepares the KV cache for generation
- Returns updated `model_kwargs` and `input_ids`

### `ps_next(model, model_kwargs, input_ids)`
Gets the next token logits and updates the KV cache:
- Runs the model forward pass
- Extracts logits for the last token
- Updates the KV cache
- Returns updated `model_kwargs` and `logits`

### Example: Custom Generation Loop

```py
from ancestral_sampling.generate import init_gen, ps_next

def custom_generation(model, model_kwargs, max_new_tokens=20, temperature=1.0):
    # Initialize generation
    model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)

    for i in range(max_new_tokens):
        # Get next token logits
        model_kwargs, logits = ps_next(model, model_kwargs, input_ids)

        # Your custom logic here
        probs = (logits / temperature).softmax(-1)
        next_token = torch.multinomial(probs, 1)

        # Append token and continue
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Add your stopping conditions
        if next_token.item() == eos_token_id:
            break

    return input_ids
```

The updated README:
---
library_name: transformers
tags:
- custom_generate
- sampling
- kvcache
---

# Sampling with KV Cache

## Description
A clean, hackable implementation of sampling (also called ancestral sampling or multinomial sampling) with full KV cache support. This is a simplified alternative to the complex generation mixin in transformers, designed for readability and ease of modification while maintaining full performance.

The implementation supports both sampling and greedy decoding modes, with optional temperature scaling and top-k/top-p filtering.

[... unchanged section not shown in the diff ...]

## Model compatibility
Most transformer LLM/VLM models trained for causal language modeling.

## Relevant Arguments
- `temperature` (float): Sampling temperature (default: 1.0; higher = more random)
- `top_k` (int): Only consider the top-k most probable tokens (default: None)
- `top_p` (float): Only consider tokens with cumulative probability <= top_p (default: None)
- `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)

### Logits Processing Order
Logits processors are applied in sequence: `temperature → softmax → top_k → top_p` (the same order as HuggingFace's `LogitsProcessor` system). Temperature scaling happens before top-p filtering, so it changes the probability distribution that top-p operates on.

For example, with `temperature=1.0`, `top_p=0.9` might include tokens A, B, C. With `temperature=0.5`, the probability mass is much more concentrated, so `top_p=0.9` might only include token A.
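
A toy illustration of that interaction (numbers invented for this example, not taken from the repository):

```py
import torch

logits = torch.tensor([2.0, 1.0, 0.5])   # toy logits for three tokens A, B, C
for t in (1.0, 0.5):
    ps = (logits / t).softmax(-1)
    print(f"T={t}: probs={[round(p, 2) for p in ps.tolist()]}")
# T=1.0: probs=[0.63, 0.23, 0.14] -> a 0.9 nucleus needs all of A, B and C
# T=0.5: probs=[0.84, 0.11, 0.04] -> the 0.9 nucleus shrinks to just A and B
```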
## Outputs
When `return_dict_in_generate=True`, returns a dictionary with:
- `sequences`: Generated token IDs
- `scores`: Log probabilities of the sampled tokens (with temperature/sampling modifications applied)
- `logprobs`: Original model log probabilities (T=1, no modifications)

Otherwise, returns a tensor of generated token IDs.
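
Both `scores` and `logprobs` come back with shape `(batch_size, generated_steps)` and hold per-step log probabilities (zeroed once a sequence has finished), so sequence-level log probabilities are just a sum over the step dimension. A small sketch, assuming `gen_out` was produced with `return_dict_in_generate=True` as in the example below:

```py
seq_logprob_sampling = gen_out["scores"].sum(dim=-1)   # log p(sequence) under the modified (temperature/top-k/top-p) distribution
seq_logprob_model = gen_out["logprobs"].sum(dim=-1)    # log p(sequence) under the unmodified model distribution (T=1)
```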
## Example usage

```py
# ... (imports plus model/tokenizer setup for "Qwen/Qwen2.5-0.5B-Instruct"; unchanged lines not shown in the diff)
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

# Basic sampling
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", trust_remote_code=True)

# With temperature
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", temperature=0.8, trust_remote_code=True)

# With top-k
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", top_k=50, trust_remote_code=True)

# With top-p (nucleus sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", top_p=0.9, trust_remote_code=True)

# Greedy decoding (no sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", do_sample=False, trust_remote_code=True)

# Get detailed output with probabilities
gen_out = model.generate(
    **inputs,
    custom_generate="manueldeprada/sampling_with_kvcache",
    return_dict_in_generate=True,
    trust_remote_code=True
)
print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
print(f"Sampling scores: {gen_out['scores']}")
print(f"Model log probabilities: {gen_out['logprobs']}")
```

## Algorithm

[... earlier steps unchanged and not shown in the diff ...]
   - Update KV cache and track sequence completion
3. Return generated sequences and probability information
custom_generate/generate.py (CHANGED)

The KV-cache bookkeeping that earlier revisions delegated to generation helpers from `transformers` (the "hf helpers" of the commit title) is now implemented locally in `prepare_inputs_for_generation` and `update_model_kwargs_for_generation`. The `ancestral_sampling` entry point becomes `sampling_with_kvcache` (its `T` argument is renamed `temperature`), `ps_next` becomes `next_logits_with_cache_update`, the top-k/top-p filtering now lives in dedicated `_apply_top_k`/`_apply_top_p` helpers, per-sequence length (`lens`) tracking is dropped, and the dictionary returned by `generate` keeps `sequences` and `scores` but now reports the unmodified model log probabilities as `logprobs` and no longer includes `prompt_lens` or `lens`. The new file:
```py
import torch
from transformers import Cache, DynamicCache
from transformers.generation.utils import ModelOutput
from typing import Optional, Any


def prepare_inputs_for_generation(
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    input_ids = input_ids[:, cache_position].clone(memory_format=torch.contiguous_format)
    cur_len = input_ids.shape[1]
    model_inputs = {
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "input_ids": input_ids,
        "inputs_embeds": None,
        "attention_mask": attention_mask,
    }
    if attention_mask is not None and kwargs.get("position_ids") is None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        kwargs["position_ids"] = position_ids
    if past_key_values is not None:
        for name in ("position_ids", "token_type_ids"):
            if name in kwargs:
                kwargs[name] = kwargs[name][:, -cur_len:].clone(memory_format=torch.contiguous_format)
    model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})
    return model_inputs


def update_model_kwargs_for_generation(
    outputs: ModelOutput,
    model_kwargs: dict[str, Any],
    num_new_tokens: int = 1,
) -> dict[str, Any]:
    model_kwargs["past_key_values"] = getattr(outputs, "past_key_values")
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )
    model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
    return model_kwargs
```
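For example (hypothetical sizes): if the prefill runs with `cache_position = tensor([0, 1, 2, 3, 4])` for a 5-token prompt, `prepare_inputs_for_generation` slices out the whole prompt, and `update_model_kwargs_for_generation` then advances `cache_position` to `tensor([5])` and appends a column of ones to the attention mask, so every later forward pass feeds only the newly sampled token while the KV cache covers the earlier positions.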
```py
def next_logits_with_cache_update(model, model_kwargs, input_ids):
    """
    Gets the next token logits and updates the KV cache:
    - Runs the model forward pass
    - Extracts logits for the last token
    - Updates the KV cache
    - Returns updated `model_kwargs` and `logits`

    Args:
        model: The language model
        model_kwargs: Model keyword arguments including KV cache
        input_ids: Current input token IDs

    Returns:
        Updated model_kwargs, logits for the next token
    """
    model_inputs = prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        outputs = model(**model_inputs, return_dict=True)

    logits = outputs.logits[:, -1].detach()
    model_kwargs = update_model_kwargs_for_generation(outputs, model_kwargs)
    del outputs
    return model_kwargs, logits


def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
    """
    Initializes the generation process and prepares the KV cache:
    - Sets up input sequences and model inputs
    - Prepares the KV cache for generation
    - Returns updated `model_kwargs` and `input_ids`

    Args:
        model_kwargs: Model keyword arguments
        model: The language model
        max_new_tokens: Maximum number of new tokens to generate
        bos_token_id: Beginning-of-sequence token ID

    Returns:
        Model keyword arguments and input token IDs
    """
    input_ids = model_kwargs.pop("input_ids")
    model_kwargs["past_key_values"] = DynamicCache() if model_kwargs.get("past_key_values") is None else model_kwargs["past_key_values"]
    assert isinstance(model_kwargs["past_key_values"], Cache), "past_key_values must be a Cache object"
    cache_position = torch.ones(input_ids.shape[1], dtype=torch.int64, device=input_ids.device).cumsum(0) - 1
    cache_position = cache_position[model_kwargs["past_key_values"].get_seq_length() :]
    model_kwargs["cache_position"] = cache_position
    return model_kwargs, input_ids
```
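Together, `init_gen` and `next_logits_with_cache_update` are enough to drive a hand-rolled decoding loop. The custom-loop example that this commit removes from the README still applies with the new names; a minimal adaptation (`model_kwargs` is the tokenizer output, i.e. `input_ids` and `attention_mask`, and `eos_token_id`/`bos_token_id` are assumed to be supplied by the caller, which the original example left implicit):

```py
import torch

def custom_generation(model, model_kwargs, eos_token_id, bos_token_id,
                      max_new_tokens=20, temperature=1.0):
    # Prepare input_ids and the KV cache
    model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)

    for _ in range(max_new_tokens):
        # Forward pass for the current position(s); the KV cache is updated in place
        model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)

        # Custom decoding logic goes here (plain temperature sampling as an example)
        probs = (logits / temperature).softmax(-1)
        next_token = torch.multinomial(probs, 1)

        # Append the token and apply a simple, batch-size-1 stopping condition
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == eos_token_id:
            break

    return input_ids
```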
```py
def _apply_top_k(ps, model):
    """Apply top-k filtering to probabilities."""
    if not hasattr(model, "generation_config") or not hasattr(
        model.generation_config, "top_k"
    ):
        return ps

    top_k = model.generation_config.top_k
    if top_k is None or top_k >= ps.size(-1):
        return ps

    indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)


def _apply_top_p(ps, model):
    """Apply top-p (nucleus) filtering to probabilities."""
    if not hasattr(model, "generation_config") or not hasattr(
        model.generation_config, "top_p"
    ):
        return ps

    top_p = model.generation_config.top_p
    if top_p is None or top_p >= 1.0:
        return ps

    sorted_probs, sorted_indices = torch.sort(ps, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)
```
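A quick, illustrative check of the nucleus filter on a toy distribution (the `SimpleNamespace` below is a stand-in for a model carrying a `generation_config`; it is not part of the repository):

```py
from types import SimpleNamespace
import torch

fake_model = SimpleNamespace(generation_config=SimpleNamespace(top_p=0.8))
ps = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
print(_apply_top_p(ps.clone(), fake_model))
# tensor([[0.5263, 0.3158, 0.1579, 0.0000]])
# The token that first pushes the cumulative mass past top_p (the 0.15 one) is still kept
# because of the one-position shift in the implementation; only the remaining tail is
# zeroed out, and the surviving probabilities are renormalised.
```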
```py
def sampling_with_kvcache(
    model_kwargs,
    model,
    eos_token_ids,
    pad_token_id,
    bos_token_id,
    do_sample=True,
    max_new_tokens=20,
    temperature=1.0,
):
    """
    Sampling implementation with proper KV caching.

    Args:
        prompts: List of input prompts
        model: The language model
        ... (unchanged lines not shown in the diff)
        pad_token_id: Padding token ID
        bos_token_id: Beginning-of-sequence token ID
        max_new_tokens: Maximum number of new tokens to generate

    Returns:
        Generated sequences, log probabilities, and metadata
    """
    # Initialize the generation process and prepare the KV cache
    model_kwargs, input_ids = init_gen(
        model_kwargs, model, max_new_tokens, bos_token_id
    )
    batch_size, _ = input_ids.shape

    # Keeps track of which sequences are finished and their lengths
    active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
    # Modified log probabilities of the sequences
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
    # Unfiltered sequence log probabilities (temperature=1, no sampling processors applied)
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)

    for i in range(max_new_tokens):
        # Get the next token probabilities and update the KV cache
        model_kwargs, logits = next_logits_with_cache_update(
            model, model_kwargs, input_ids
        )
        # Store original model probabilities (temperature=1, no sampling processors applied)
        model_ps = logits.softmax(-1)

        # Logit processors (temperature, top-k, top-p). We can chain these!
        ps = (logits / temperature).softmax(-1)
        ps = _apply_top_k(ps, model)
        ps = _apply_top_p(ps, model)

        # Sample the next token and gather the log probabilities
        if do_sample:  # Sampling
            next_token_ids = (
                torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
            )
        else:  # Greedy decoding
            next_token_ids = (
                torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs
                + pad_token_id * ~active_seqs
            )
        next_token_logprobs = ps.gather(-1, next_token_ids).log()
        next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()

        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        scores[:, i] = (next_token_logprobs * active_seqs).squeeze()
        logprobs[:, i] = (next_token_model_logprobs * active_seqs).squeeze()

        active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
        if active_seqs.sum() == 0:
            break
    return input_ids.detach().cpu(), scores[:, : i + 1], logprobs[:, : i + 1]


def generate(model, **kwargs):
    """
    Sampling strategy - multinomial sampling with temperature and optional top-k/top-p filtering.
    Simple implementation with proper KV caching support.

    Args:
        model: The language model
        model_kwargs: Model keyword arguments from the tokenizer
        ... (unchanged lines not shown in the diff)
        top_k: Only consider top-k most probable tokens
        top_p: Only consider tokens with cumulative probability <= top_p
        **kwargs: Additional arguments

    Returns:
        Generated token IDs
    """
    generation_config = model.generation_config
    max_new_tokens = kwargs.get("max_new_tokens", generation_config.max_new_tokens)
    max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
    do_sample = kwargs.get("do_sample", True)
    eos_token_ids = kwargs.get("eos_token_ids", generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError(
            "Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument."
        )
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
    if eos_token_ids is not None and eos_token_ids.ndim == 0:
        eos_token_ids = eos_token_ids.unsqueeze(0)

    pad_token_id = kwargs.get(
        "pad_token_id",
        generation_config.pad_token_id
        if generation_config.pad_token_id is not None
        else eos_token_ids[0],
    )
    bos_token_id = kwargs.get("bos_token_id", generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError(
            "Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument."
        )
    temperature = kwargs.get("temperature", 1.0)
    return_dict = kwargs.get("return_dict_in_generate", False)

    generated_ids, scores, logprobs = sampling_with_kvcache(
        model_kwargs=kwargs,
        model=model,
        eos_token_ids=eos_token_ids,
        # (one unchanged argument line is not shown in the diff)
        bos_token_id=bos_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )

    if return_dict:
        return {
            "sequences": generated_ids,
            "scores": scores,
            "logprobs": logprobs,
        }
    else:
        return generated_ids
```