import torch


def next_logits_with_cache_update(model, model_kwargs, input_ids):
    """
    Gets the next-token logits and updates the KV cache:
    - Runs the model forward pass
    - Extracts the logits for the last token
    - Updates the KV cache
    - Returns updated `model_kwargs` and `logits`

    Args:
        model: The language model
        model_kwargs: Model keyword arguments, including the KV cache
        input_ids: Current input token IDs

    Returns:
        Updated model_kwargs, logits for the next token
    """
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        outputs = model(**model_inputs, return_dict=True)
    logits = outputs.logits[:, -1].detach()
    model_kwargs = model._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )
    del outputs
    return model_kwargs, logits


def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
    """
    Initializes the generation process and prepares the KV cache:
    - Sets up input sequences and model inputs
    - Prepares the KV cache for generation
    - Returns updated `model_kwargs` and `input_ids`

    Args:
        model_kwargs: Model keyword arguments
        model: The language model
        max_new_tokens: Maximum number of new tokens to generate
        bos_token_id: Beginning-of-sequence token ID

    Returns:
        Model keyword arguments and input token IDs
    """
    input_ids, model_input_name, model_kwargs = model._prepare_model_inputs(
        None, bos_token_id, model_kwargs
    )
    batch_size = input_ids.shape[0]
    model._prepare_cache_for_generation(
        model.generation_config, model_kwargs, None, batch_size,
        max_cache_length=max_new_tokens, device=input_ids.device
    )
    # Get the initial cache position
    model_kwargs = model._get_initial_cache_position(
        input_ids.shape[1], input_ids.device, model_kwargs
    )
    return model_kwargs, input_ids


def _apply_top_k(ps, model):
    """Apply top-k filtering to probabilities (in place), then renormalize."""
    if not hasattr(model, 'generation_config') or not hasattr(model.generation_config, 'top_k'):
        return ps
    top_k = model.generation_config.top_k
    if top_k is None or top_k >= ps.size(-1):
        return ps
    # Zero out every probability smaller than the k-th largest, then renormalize
    indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)


def _apply_top_p(ps, model):
    """Apply top-p (nucleus) filtering to probabilities (in place), then renormalize."""
    if not hasattr(model, 'generation_config') or not hasattr(model.generation_config, 'top_p'):
        return ps
    top_p = model.generation_config.top_p
    if top_p is None or top_p >= 1.0:
        return ps
    sorted_probs, sorted_indices = torch.sort(ps, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Remove tokens outside the smallest set whose cumulative probability exceeds top_p,
    # shifting the mask right so the first token that crosses the threshold is kept
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)
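
# A minimal sanity check (illustrative numbers, not from any real model): the two helpers
# above only read `model.generation_config`, so a `types.SimpleNamespace` stand-in is
# enough to try them out interactively. Note that both helpers modify `ps` in place,
# hence the `.clone()` calls.
#
#   >>> import types
#   >>> fake_model = types.SimpleNamespace(
#   ...     generation_config=types.SimpleNamespace(top_k=2, top_p=0.9))
#   >>> ps = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
#   >>> _apply_top_k(ps.clone(), fake_model)   # keep the 2 largest, renormalize
#   tensor([[0.6250, 0.3750, 0.0000, 0.0000]])
#   >>> _apply_top_p(ps.clone(), fake_model)   # keep the smallest prefix with cumulative prob > 0.9
#   tensor([[0.5263, 0.3158, 0.1579, 0.0000]])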


def sampling_with_kvcache(model_kwargs, model, eos_token_ids, pad_token_id, bos_token_id,
                          do_sample=True, max_new_tokens=20, temperature=1.0):
    """
    Sampling implementation with proper KV caching.

    Args:
        model_kwargs: Model keyword arguments from the tokenizer
        model: The language model
        eos_token_ids: Tensor of end-of-sequence token IDs
        pad_token_id: Padding token ID
        bos_token_id: Beginning-of-sequence token ID
        do_sample: Sample from the filtered distribution if True, decode greedily if False
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Sampling temperature (higher = more random)

    Returns:
        Generated sequences, per-token log probabilities under the sampling
        distribution (`scores`), and under the unmodified model (`logprobs`)
    """
    # Initialize the generation process and prepare the KV cache
    model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
    batch_size, _ = input_ids.shape
    # Keeps track of which sequences are still unfinished
    active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
    # Modified log probabilities of the sequences (after temperature/top-k/top-p)
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=model.device)
    # Unfiltered sequence log probabilities (temperature=1, no sampling processors applied)
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=model.device)
    for i in range(max_new_tokens):
        # Get the next-token logits and update the KV cache
        model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)
        # Original model probabilities (temperature=1, no sampling processors applied)
        model_ps = logits.softmax(-1)
        # Logit processors (temperature, top-k, top-p). These can be chained.
        ps = (logits / temperature).softmax(-1)
        ps = _apply_top_k(ps, model)
        ps = _apply_top_p(ps, model)
        # Sample the next token; finished sequences get the pad token instead
        if do_sample:  # Sampling
            next_token_ids = torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
        else:  # Greedy decoding
            next_token_ids = torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs + pad_token_id * ~active_seqs
        next_token_logprobs = ps.gather(-1, next_token_ids).log()
        next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()
        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        # Use `torch.where` so finished sequences contribute 0 rather than `-inf * 0 = nan`
        # when the pad token has zero probability after filtering
        scores[:, i] = torch.where(active_seqs, next_token_logprobs, 0.0).squeeze(-1)
        logprobs[:, i] = torch.where(active_seqs, next_token_model_logprobs, 0.0).squeeze(-1)
        # Mark sequences that just produced an EOS token as finished
        active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
        if active_seqs.sum() == 0:
            break
    return input_ids.detach().cpu(), scores[:, :i + 1], logprobs[:, :i + 1]
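
# Note: summing `scores` (resp. `logprobs`) over the token dimension gives each sequence's
# total log probability under the filtered sampling distribution (resp. under the raw
# model distribution), since entries after EOS are zero.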


def generate(model, **kwargs):
    """
    Sampling strategy: multinomial sampling with temperature and optional top-k/top-p
    filtering (read from `model.generation_config`), with proper KV caching support.

    Args:
        model: The language model
        **kwargs: Model keyword arguments from the tokenizer (e.g. `input_ids`,
            `attention_mask`) plus generation options: `max_new_tokens`, `do_sample`,
            `temperature` (higher = more random), `eos_token_ids`, `pad_token_id`,
            `bos_token_id`, and `return_dict_in_generate`

    Returns:
        Generated token IDs, or a dict with `sequences`, `scores`, and `logprobs`
        when `return_dict_in_generate=True`
    """
    generation_config = model.generation_config
    # Pop generation options out of `kwargs` so that only model inputs are forwarded as model_kwargs
    max_new_tokens = kwargs.pop('max_new_tokens', generation_config.max_new_tokens)
    max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
    do_sample = kwargs.pop('do_sample', True)
    eos_token_ids = kwargs.pop('eos_token_ids', generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError(
            "The model generation config does not have an EOS token id. "
            "You must pass one to generate() with the eos_token_ids argument."
        )
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
    if eos_token_ids.ndim == 0:
        eos_token_ids = eos_token_ids.unsqueeze(0)
    pad_token_id = kwargs.pop(
        'pad_token_id',
        generation_config.pad_token_id if generation_config.pad_token_id is not None else eos_token_ids[0]
    )
    bos_token_id = kwargs.pop('bos_token_id', generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError(
            "The model generation config does not have a BOS token id. "
            "You must pass one to generate() with the bos_token_id argument."
        )
    temperature = kwargs.pop('temperature', 1.0)
    return_dict = kwargs.pop('return_dict_in_generate', False)
    generated_ids, scores, logprobs = sampling_with_kvcache(
        model_kwargs=kwargs,
        model=model,
        eos_token_ids=eos_token_ids,
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    if return_dict:
        return {
            "sequences": generated_ids,
            "scores": scores,
            "logprobs": logprobs,
        }
    else:
        return generated_ids
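

# Minimal usage sketch (untested). Assumptions: a Hugging Face causal LM and its tokenizer,
# and a recent transformers release whose private generation helpers match the calls above.
# The checkpoint name "gpt2" is only an illustration; substitute your own model.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    # Tokenizer outputs become the model kwargs; generation options are popped off by generate()
    inputs = tokenizer(["The quick brown fox"], return_tensors="pt").to(model.device)
    out = generate(
        model,
        **inputs,
        max_new_tokens=20,
        do_sample=True,
        temperature=0.8,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out["sequences"], skip_special_tokens=True))
    print(out["scores"].sum(-1))    # log-prob of each sample under the sampling distribution
    print(out["logprobs"].sum(-1))  # log-prob of each sample under the raw model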