manueldeprada (HF Staff) committed
Commit 0da6031 · 1 Parent(s): 8f67839

eliminate hf helpers

Files changed (2):
  1. README.md +20 -56
  2. custom_generate/generate.py +185 -100
README.md CHANGED
@@ -2,13 +2,14 @@
  library_name: transformers
  tags:
  - custom_generate
- - ancestral_sampling
  ---

- # Multinomial (Ancestral) Sampling simple implementation

  ## Description
- A clean, hackable implementation of ancestral sampling (multinomial sampling) with full KV cache support. This is a simplified alternative to the complex generation mixin in transformers, designed for readability and ease of modification while maintaining full performance.

  The implementation supports both sampling and greedy decoding modes, with optional temperature scaling and top-k/top-p filtering.

@@ -18,19 +19,23 @@ The implementation supports both sampling and greedy decoding modes, with option
  ## Model compatibility
  Most transformer LLM/VLM models trained for causal language modeling.

- ## Additional Arguments
  - `temperature` (float): Sampling temperature (default: 1.0, higher = more random)
  - `top_k` (int): Only consider top-k most probable tokens (default: None)
  - `top_p` (float): Only consider tokens with cumulative probability <= top_p (default: None)
  - `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)

- ## Output Type changes
  When `return_dict_in_generate=True`, returns a dictionary with:
  - `sequences`: Generated token IDs
  - `scores`: Log probabilities of sampled tokens (with temperature/sampling modifications)
- - `logps`: Original model log probabilities (T=1, no modifications)
- - `prompt_lens`: Length of input prompts
- - `lens`: Final sequence lengths

  ## Example usage

@@ -43,30 +48,30 @@ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", devic
  inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

  # Basic sampling
- gen_out = model.generate(**inputs, custom_generate="manueldeprada/ancestral_sampling", trust_remote_code=True)

  # With temperature
- gen_out = model.generate(**inputs, custom_generate="manueldeprada/ancestral_sampling", temperature=0.8, trust_remote_code=True)

  # With top-k
- gen_out = model.generate(**inputs, custom_generate="manueldeprada/ancestral_sampling", top_k=50, trust_remote_code=True)

  # With top-p (nucleus sampling)
- gen_out = model.generate(**inputs, custom_generate="manueldeprada/ancestral_sampling", top_p=0.9, trust_remote_code=True)

  # Greedy decoding (no sampling)
- gen_out = model.generate(**inputs, custom_generate="manueldeprada/ancestral_sampling", do_sample=False, trust_remote_code=True)

  # Get detailed output with probabilities
  gen_out = model.generate(
      **inputs,
-     custom_generate="manueldeprada/ancestral_sampling",
      return_dict_in_generate=True,
      trust_remote_code=True
  )
  print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
  print(f"Sampling scores: {gen_out['scores']}")
- print(f"Model log probabilities: {gen_out['logps']}")
  ```

  ## Algorithm
@@ -82,47 +87,6 @@ print(f"Model log probabilities: {gen_out['logps']}")
    - Update KV cache and track sequence completion
  3. Return generated sequences and probability information

- ## Helper Functions for Custom Generation
-
- The implementation provides two key helper functions that you can use to build your own generation strategies:
-
- ### `init_gen(model_kwargs, model, max_new_tokens, bos_token_id)`
- Initializes the generation process and prepares the KV cache:
- - Sets up input sequences and model inputs
- - Prepares the KV cache for generation
- - Returns updated `model_kwargs` and `input_ids`
-
- ### `ps_next(model, model_kwargs, input_ids)`
- Gets the next token logits and updates the KV cache:
- - Runs the model forward pass
- - Extracts logits for the last token
- - Updates the KV cache
- - Returns updated `model_kwargs` and `logits`
-
- ### Example: Custom Generation Loop
-
- ```py
- from ancestral_sampling.generate import init_gen, ps_next
-
- def custom_generation(model, model_kwargs, max_new_tokens=20, temperature=1.0):
-     # Initialize generation
-     model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
-
-     for i in range(max_new_tokens):
-         # Get next token logits
-         model_kwargs, logits = ps_next(model, model_kwargs, input_ids)
-
-         # Your custom logic here
-         probs = (logits / temperature).softmax(-1)
-         next_token = torch.multinomial(probs, 1)
-
-         # Append token and continue
-         input_ids = torch.cat([input_ids, next_token], dim=-1)
-
-         # Add your stopping conditions
-         if next_token.item() == eos_token_id:
-             break
-
-     return input_ids
- ```

  library_name: transformers
  tags:
  - custom_generate
+ - sampling
+ - kvcache
  ---

+ # Sampling with KV Cache

  ## Description
+ A clean, hackable implementation of sampling (also called ancestral sampling or multinomial sampling) with full KV cache support. This is a simplified alternative to the complex generation mixin in transformers, designed for readability and ease of modification while maintaining full performance.

  The implementation supports both sampling and greedy decoding modes, with optional temperature scaling and top-k/top-p filtering.

  ## Model compatibility
  Most transformer LLM/VLM models trained for causal language modeling.

+ ## Relevant Arguments
  - `temperature` (float): Sampling temperature (default: 1.0, higher = more random)
  - `top_k` (int): Only consider top-k most probable tokens (default: None)
  - `top_p` (float): Only consider tokens with cumulative probability <= top_p (default: None)
  - `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)

+ ### Logits Processing Order
+ Logits processors are applied in sequence: `temperature → softmax → top_k → top_p` (the same ordering as HuggingFace's `LogitsProcessor` system). Temperature scaling occurs before top-p filtering, so it affects the probability distribution that top-p operates on.
+
+ For example, with `temperature=1.0`, `top_p=0.9` might include tokens A, B, and C; with `temperature=0.5` the probability mass is much more concentrated, so `top_p=0.9` might include only token A.
+
+ ## Outputs
  When `return_dict_in_generate=True`, returns a dictionary with:
  - `sequences`: Generated token IDs
  - `scores`: Log probabilities of sampled tokens (with temperature/sampling modifications)
+ - `logprobs`: Original model log probabilities (T=1, no modifications)
+ Otherwise, returns a tensor of generated token IDs.

  ## Example usage

  inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

  # Basic sampling
+ gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", trust_remote_code=True)

  # With temperature
+ gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", temperature=0.8, trust_remote_code=True)

  # With top-k
+ gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", top_k=50, trust_remote_code=True)

  # With top-p (nucleus sampling)
+ gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", top_p=0.9, trust_remote_code=True)

  # Greedy decoding (no sampling)
+ gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", do_sample=False, trust_remote_code=True)

  # Get detailed output with probabilities
  gen_out = model.generate(
      **inputs,
+     custom_generate="manueldeprada/sampling_with_kvcache",
      return_dict_in_generate=True,
      trust_remote_code=True
  )
  print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
  print(f"Sampling scores: {gen_out['scores']}")
+ print(f"Model log probabilities: {gen_out['logprobs']}")
  ```

  ## Algorithm

    - Update KV cache and track sequence completion
  3. Return generated sequences and probability information

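A quick way to see the interaction described in the new "Logits Processing Order" section: the sketch below is a minimal, self-contained illustration (its `keep_top_p` helper is a toy stand-in for the repository's `_apply_top_p`, and the logits are made up). It applies temperature before the nucleus cut and counts how many tokens survive.

```py
import torch

# Toy next-token distribution over a 5-token vocabulary.
logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.0]])

def keep_top_p(ps, top_p):
    # Zero out the lowest-probability tokens whose cumulative mass exceeds top_p,
    # then renormalize (same idea as the _apply_top_p helper in generate.py below).
    sorted_ps, sorted_idx = torch.sort(ps, descending=True)
    cum = torch.cumsum(sorted_ps, dim=-1)
    remove = cum > top_p
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    ps = ps.masked_fill(remove.scatter(1, sorted_idx, remove), 0.0)
    return ps / ps.sum(dim=-1, keepdim=True)

for temperature in (1.0, 0.5):
    ps = (logits / temperature).softmax(-1)    # temperature first, then softmax
    kept = keep_top_p(ps, top_p=0.9)           # then the nucleus filter
    print(temperature, int((kept > 0).sum()), "tokens kept")
```

With the flatter distribution (temperature 1.0) three of the five tokens stay inside the 0.9 nucleus; at temperature 0.5 only two do.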
 
custom_generate/generate.py CHANGED
@@ -1,87 +1,157 @@
  import torch
- from transformers import GenerationConfig


- def ps_next(model, model_kwargs, input_ids):
      """
-     Auxiliary function to get the next token probabilities and update the KV cache.
-
      Args:
          model: The language model
          model_kwargs: Model keyword arguments including KV cache
          input_ids: Current input token IDs
-         T: Temperature for sampling
-
      Returns:
-         Updated model_kwargs, probabilities at temperature T, probabilities at T=1
      """
-     model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
      with torch.no_grad():
          outputs = model(**model_inputs, return_dict=True)
-
      logits = outputs.logits[:, -1].detach()
-     model_kwargs = model._update_model_kwargs_for_generation(
-         outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
-     )
      del outputs
      return model_kwargs, logits

  def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
      """
-     Auxiliary function to initialize the generation process and prepare the KV cache.
-
      Args:
          model_kwargs: Model keyword arguments
          model: The language model
          max_new_tokens: Maximum number of new tokens to generate
-
      Returns:
          Model keyword arguments and input token IDs
      """

-     input_ids, model_input_name, model_kwargs = model._prepare_model_inputs(
-         None, bos_token_id, model_kwargs
-     )
-
-     batch_size = input_ids.shape[0]
-     model._prepare_cache_for_generation(
-         model.generation_config, model_kwargs, None, batch_size,
-         max_cache_length=max_new_tokens, device=input_ids.device
      )
-
-     # Get initial cache position
-     model_kwargs = model._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
-     return model_kwargs, input_ids

- def _apply_top_k_top_p(ps, model):
-     if hasattr(model, 'generation_config') and hasattr(model.generation_config, 'top_k') and model.generation_config.top_k is not None:
-         top_k = model.generation_config.top_k
-         top_k = min(top_k, ps.size(-1))
-         indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
-         ps[indices_to_remove] = 0.0
-         ps = ps / ps.sum(dim=-1, keepdim=True)
-
-     # Apply top-p filtering if specified
-     if hasattr(model, 'generation_config') and hasattr(model.generation_config, 'top_p') and model.generation_config.top_p is not None:
-         top_p = model.generation_config.top_p
-         if top_p < 1.0:
-             sorted_probs, sorted_indices = torch.sort(ps, descending=True)
-             cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-
-             # Remove tokens with cumulative probability above the threshold
-             sorted_indices_to_remove = cumulative_probs > top_p
-             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-             sorted_indices_to_remove[..., 0] = 0
-
-             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-             ps[indices_to_remove] = 0.0
-             ps = ps / ps.sum(dim=-1, keepdim=True)
-     return ps
-
- def ancestral_sampling(model_kwargs, model, eos_token_ids, pad_token_id, bos_token_id, do_sample=True, max_new_tokens=20, T=1.0):
      """
-     Ancestral sampling implementation with proper KV caching.
-
      Args:
          prompts: List of input prompts
          model: The language model
@@ -90,55 +160,64 @@ def ancestral_sampling(model_kwargs, model, eos_token_ids, pad_token_id, bos_tok
          pad_token_id: Padding token ID
          bos_token_id: Beginning-of-sequence token ID
          max_new_tokens: Maximum number of new tokens to generate
-
      Returns:
          Generated sequences, log probabilities, and metadata
      """
      # Initialize the generation process and prepare the KV cache
-     model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
-     batch_size, max_prompts_len = input_ids.shape
-     prompts_len = (input_ids != pad_token_id).sum(dim=-1)

      # Keeps track of which sequences are finished and their lengths
      active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
-     lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
      # Modified log probabilities of the sequences
      scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
-     # Unfiltered sequence log probabilities (T=1, no sampling modifications)
-     logps = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)

      for i in range(max_new_tokens):
          # Get the next token probabilities and update the KV cache
-         model_kwargs, logits = ps_next(model, model_kwargs, input_ids)
-         # Original model probabilities (T=1, no sampling modifications)
          model_ps = logits.softmax(-1)
-         # Sampling probabilities (T, with sampling modifications)
-         ps = (logits/T).softmax(-1)
-         ps = _apply_top_k_top_p(ps, model)
-
          # Sample the next token and gather the log probabilities
-         if do_sample:
-             next_token_ids = torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
-         else:
-             next_token_ids = torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs + pad_token_id * ~active_seqs
-         next_token_logps = ps.gather(-1, next_token_ids).log()
-         next_token_model_logps = model_ps.gather(-1, next_token_ids).log()
-
          input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
-         scores[:, i] = (next_token_logps * active_seqs).squeeze()
-         logps[:, i] = (next_token_model_logps * active_seqs).squeeze()
-
-         lens += active_seqs.squeeze(-1).long()
          active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
          if active_seqs.sum() == 0:
-             break
-     return input_ids.detach().cpu(), scores[:,:i+1], logps[:,:i+1], prompts_len, lens.tolist()

  def generate(model, **kwargs):
      """
-     Ancestral sampling strategy - multinomial sampling with temperature and optional top-k/top-p filtering.
      Simple implementation with proper KV caching support.
-
      Args:
          model: The language model
          model_kwargs: Model keyword arguments from the tokenizer
@@ -147,29 +226,38 @@ def generate(model, **kwargs):
          top_k: Only consider top-k most probable tokens
          top_p: Only consider tokens with cumulative probability <= top_p
          **kwargs: Additional arguments
-
      Returns:
          Generated token IDs
      """
      generation_config = model.generation_config
-     max_new_tokens = kwargs.get('max_new_tokens', generation_config.max_new_tokens)
      max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
-     do_sample = kwargs.get('do_sample', True)
-     eos_token_ids = kwargs.get('eos_token_ids', generation_config.eos_token_id)
      if eos_token_ids is None:
-         raise ValueError("Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument.")
      eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
      if eos_token_ids is not None and eos_token_ids.ndim == 0:
          eos_token_ids = eos_token_ids.unsqueeze(0)
-
-     pad_token_id = kwargs.get('pad_token_id', generation_config.pad_token_id if generation_config.pad_token_id is not None else eos_token_ids[0])
-     bos_token_id = kwargs.get('bos_token_id', generation_config.bos_token_id)
      if bos_token_id is None:
-         raise ValueError("Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument.")
-     T = kwargs.get('temperature', 1.0)
-     return_dict = kwargs.get('return_dict_in_generate', False)

-     generated_ids, scores, logps, prompt_lens, lens = ancestral_sampling(
          model_kwargs=kwargs,
          model=model,
          eos_token_ids=eos_token_ids,
@@ -177,17 +265,14 @@ def generate(model, **kwargs):
          pad_token_id=pad_token_id,
          bos_token_id=bos_token_id,
          do_sample=do_sample,
          max_new_tokens=max_new_tokens,
-         T=T,
      )

      if return_dict:
          return {
              "sequences": generated_ids,
              "scores": scores,
-             "logps": logps,
-             "prompt_lens": prompt_lens,
-             "lens": lens,
          }
      else:
          return generated_ids
-

  import torch
+ from transformers import Cache, DynamicCache
+ from transformers.generation.utils import ModelOutput
+ from typing import Optional, Any

+ def prepare_inputs_for_generation(
+     input_ids: torch.LongTensor,
+     past_key_values: Optional[Cache] = None,
+     attention_mask: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     **kwargs,
+ ):
+     input_ids = input_ids[:, cache_position].clone(memory_format=torch.contiguous_format)
+     cur_len = input_ids.shape[1]
+     model_inputs = {"cache_position": cache_position,
+         "past_key_values": past_key_values,
+         "input_ids": input_ids,
+         "inputs_embeds": None,
+         "attention_mask": attention_mask,
+     }
+     if attention_mask is not None and kwargs.get("position_ids") is None:
+         position_ids = attention_mask.long().cumsum(-1) - 1
+         position_ids.masked_fill_(attention_mask == 0, 1)
+         kwargs["position_ids"] = position_ids
+     if past_key_values is not None:
+         for name in ("position_ids", "token_type_ids"):
+             if name in kwargs:
+                 kwargs[name] = kwargs[name][:, -cur_len:].clone(memory_format=torch.contiguous_format)
+     model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})
+     return model_inputs

+ def update_model_kwargs_for_generation(
+     outputs: ModelOutput,
+     model_kwargs: dict[str, Any],
+     num_new_tokens: int = 1,
+ ) -> dict[str, Any]:
+     model_kwargs["past_key_values"] = getattr(outputs, "past_key_values")
+     if "token_type_ids" in model_kwargs:
+         token_type_ids = model_kwargs["token_type_ids"]
+         model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+     if "attention_mask" in model_kwargs:
+         attention_mask = model_kwargs["attention_mask"]
+         model_kwargs["attention_mask"] = torch.cat(
+             [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+         )
+     model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+     return model_kwargs
+
+
+ def next_logits_with_cache_update(model, model_kwargs, input_ids):
      """
+     Gets the next token logits and updates the KV cache:
+     - Runs the model forward pass
+     - Extracts logits for the last token
+     - Updates the KV cache
+     - Returns updated `model_kwargs` and `logits`
+
      Args:
          model: The language model
          model_kwargs: Model keyword arguments including KV cache
          input_ids: Current input token IDs
+
      Returns:
+         Updated model_kwargs, logits for the next token
      """
+     model_inputs = prepare_inputs_for_generation(input_ids, **model_kwargs)
      with torch.no_grad():
          outputs = model(**model_inputs, return_dict=True)
+
      logits = outputs.logits[:, -1].detach()
+     model_kwargs = update_model_kwargs_for_generation(outputs, model_kwargs)
      del outputs
      return model_kwargs, logits

+
  def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
      """
+     Initializes the generation process and prepares the KV cache:
+     - Sets up input sequences and model inputs
+     - Prepares the KV cache for generation
+     - Returns updated `model_kwargs` and `input_ids`
+
      Args:
          model_kwargs: Model keyword arguments
          model: The language model
          max_new_tokens: Maximum number of new tokens to generate
+         bos_token_id: Beginning-of-sequence token ID
+
      Returns:
          Model keyword arguments and input token IDs
      """
+     input_ids = model_kwargs.pop("input_ids")
+     model_kwargs["past_key_values"] = DynamicCache() if model_kwargs.get("past_key_values") is None else model_kwargs["past_key_values"]
+     assert isinstance(model_kwargs["past_key_values"], Cache), "past_key_values must be a Cache object"
+     cache_position = torch.ones(input_ids.shape[1], dtype=torch.int64, device=input_ids.device).cumsum(0) - 1
+     cache_position = cache_position[model_kwargs["past_key_values"].get_seq_length() :]
+     model_kwargs["cache_position"] = cache_position
+     return model_kwargs, input_ids

+
+ def _apply_top_k(ps, model):
+     """Apply top-k filtering to probabilities."""
+     if not hasattr(model, "generation_config") or not hasattr(
+         model.generation_config, "top_k"
+     ):
+         return ps
+
+     top_k = model.generation_config.top_k
+     if top_k is None or top_k >= ps.size(-1):
+         return ps
+
+     indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
+     ps[indices_to_remove] = 0.0
+     return ps / ps.sum(dim=-1, keepdim=True)
+
+
+ def _apply_top_p(ps, model):
+     """Apply top-p (nucleus) filtering to probabilities."""
+     if not hasattr(model, "generation_config") or not hasattr(
+         model.generation_config, "top_p"
+     ):
+         return ps
+
+     top_p = model.generation_config.top_p
+     if top_p is None or top_p >= 1.0:
+         return ps
+
+     sorted_probs, sorted_indices = torch.sort(ps, descending=True)
+     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+     sorted_indices_to_remove = cumulative_probs > top_p
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = 0
+
+     indices_to_remove = sorted_indices_to_remove.scatter(
+         1, sorted_indices, sorted_indices_to_remove
      )
+     ps[indices_to_remove] = 0.0
+     return ps / ps.sum(dim=-1, keepdim=True)
+

+ def sampling_with_kvcache(
+     model_kwargs,
+     model,
+     eos_token_ids,
+     pad_token_id,
+     bos_token_id,
+     do_sample=True,
+     max_new_tokens=20,
+     temperature=1.0,
+ ):
      """
+     Sampling implementation with proper KV caching.
+
      Args:
          prompts: List of input prompts
          model: The language model

          pad_token_id: Padding token ID
          bos_token_id: Beginning-of-sequence token ID
          max_new_tokens: Maximum number of new tokens to generate
+
      Returns:
          Generated sequences, log probabilities, and metadata
      """
      # Initialize the generation process and prepare the KV cache
+     model_kwargs, input_ids = init_gen(
+         model_kwargs, model, max_new_tokens, bos_token_id
+     )
+     batch_size, _ = input_ids.shape

      # Keeps track of which sequences are finished and their lengths
      active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
      # Modified log probabilities of the sequences
      scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
+     # Unfiltered sequence log probabilities (temperature=1, no sampling processors applied)
+     logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)

      for i in range(max_new_tokens):
          # Get the next token probabilities and update the KV cache
+         model_kwargs, logits = next_logits_with_cache_update(
+             model, model_kwargs, input_ids
+         )
+         # Store original model probabilities (temperature=1, no sampling processors applied)
          model_ps = logits.softmax(-1)
+
+         # Logit processors (temperature, top-k, top-p). We can chain these!
+         ps = (logits / temperature).softmax(-1)
+         ps = _apply_top_k(ps, model)
+         ps = _apply_top_p(ps, model)
+
          # Sample the next token and gather the log probabilities
+         if do_sample:  # Sampling
+             next_token_ids = (
+                 torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
+             )
+         else:  # Greedy decoding
+             next_token_ids = (
+                 torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs
+                 + pad_token_id * ~active_seqs
+             )
+         next_token_logprobs = ps.gather(-1, next_token_ids).log()
+         next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()
+
          input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
+         scores[:, i] = (next_token_logprobs * active_seqs).squeeze()
+         logprobs[:, i] = (next_token_model_logprobs * active_seqs).squeeze()
+
          active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
          if active_seqs.sum() == 0:
+             break
+     return input_ids.detach().cpu(), scores[:, : i + 1], logprobs[:, : i + 1]
+

  def generate(model, **kwargs):
      """
+     Sampling strategy - multinomial sampling with temperature and optional top-k/top-p filtering.
      Simple implementation with proper KV caching support.
+
      Args:
          model: The language model
          model_kwargs: Model keyword arguments from the tokenizer

          top_k: Only consider top-k most probable tokens
          top_p: Only consider tokens with cumulative probability <= top_p
          **kwargs: Additional arguments
+
      Returns:
          Generated token IDs
      """
      generation_config = model.generation_config
+     max_new_tokens = kwargs.get("max_new_tokens", generation_config.max_new_tokens)
      max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
+     do_sample = kwargs.get("do_sample", True)
+     eos_token_ids = kwargs.get("eos_token_ids", generation_config.eos_token_id)
      if eos_token_ids is None:
+         raise ValueError(
+             "Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument."
+         )
      eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
      if eos_token_ids is not None and eos_token_ids.ndim == 0:
          eos_token_ids = eos_token_ids.unsqueeze(0)
+
+     pad_token_id = kwargs.get(
+         "pad_token_id",
+         generation_config.pad_token_id
+         if generation_config.pad_token_id is not None
+         else eos_token_ids[0],
+     )
+     bos_token_id = kwargs.get("bos_token_id", generation_config.bos_token_id)
      if bos_token_id is None:
+         raise ValueError(
+             "Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument."
+         )
+     temperature = kwargs.get("temperature", 1.0)
+     return_dict = kwargs.get("return_dict_in_generate", False)

+     generated_ids, scores, logprobs = sampling_with_kvcache(
          model_kwargs=kwargs,
          model=model,
          eos_token_ids=eos_token_ids,

          bos_token_id=bos_token_id,
          do_sample=do_sample,
          max_new_tokens=max_new_tokens,
+         temperature=temperature,
      )

      if return_dict:
          return {
              "sequences": generated_ids,
              "scores": scores,
+             "logprobs": logprobs,
          }
      else:
          return generated_ids
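The "Custom Generation Loop" example removed from the README above can still be reproduced against the renamed helpers. The sketch below is an illustrative adaptation, not taken from the commit: the flat `generate` import path and the Qwen checkpoint are assumptions, and the temperature-0.8 rule stands in for whatever custom decision logic you want to plug in.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes the repo's custom_generate/generate.py is importable locally as `generate`
# (hypothetical path, e.g. after cloning the Hub repo).
from generate import init_gen, next_logits_with_cache_update

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

model_kwargs = dict(tok("The quick brown", return_tensors="pt"))
eos_token_id = tok.eos_token_id

# Prepare the KV cache and pull the prompt ids out of model_kwargs.
model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens=20,
                                   bos_token_id=tok.bos_token_id)

for _ in range(20):
    # One forward pass; the cache, attention mask and cache_position are updated.
    model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)

    # Custom decision rule goes here (plain temperature sampling as an example).
    probs = (logits / 0.8).softmax(-1)
    next_token = torch.multinomial(probs, 1)

    input_ids = torch.cat([input_ids, next_token], dim=-1)
    if next_token.item() == eos_token_id:
        break

print(tok.decode(input_ids[0], skip_special_tokens=True))
```

On the first call `cache_position` spans the whole prompt, so the prompt is encoded once; afterwards `update_model_kwargs_for_generation` advances it by one token per step, so each later forward pass only feeds the newly sampled token while the KV cache holds the rest.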