import torch
def next_logits_with_cache_update(model, model_kwargs, input_ids):
"""
Gets the next token logits and updates the KV cache:
- Runs the model forward pass
- Extracts logits for the last token
- Updates the KV cache
- Returns updated `model_kwargs` and `logits`
Args:
model: The language model
model_kwargs: Model keyword arguments including KV cache
input_ids: Current input token IDs
Returns:
Updated model_kwargs, logits for the next token
"""
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
with torch.no_grad():
outputs = model(**model_inputs, return_dict=True)
logits = outputs.logits[:, -1].detach()
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
)
del outputs
return model_kwargs, logits
def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
"""
Initializes the generation process and prepares the KV cache:
- Sets up input sequences and model inputs
- Prepares the KV cache for generation
- Returns updated `model_kwargs` and `input_ids`
Args:
model_kwargs: Model keyword arguments
model: The language model
max_new_tokens: Maximum number of new tokens to generate
bos_token_id: Beginning-of-sequence token ID
Returns:
Model keyword arguments and input token IDs
"""
input_ids, model_input_name, model_kwargs = model._prepare_model_inputs(
None, bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0]
model._prepare_cache_for_generation(
model.generation_config, model_kwargs, None, batch_size,
max_cache_length=max_new_tokens, device=input_ids.device
)
# Get initial cache position
model_kwargs = model._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
return model_kwargs, input_ids
def _apply_top_k(ps, model):
"""Apply top-k filtering to probabilities."""
if not hasattr(model, 'generation_config') or not hasattr(model.generation_config, 'top_k'):
return ps
top_k = model.generation_config.top_k
if top_k is None or top_k >= ps.size(-1):
return ps
indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
ps[indices_to_remove] = 0.0
return ps / ps.sum(dim=-1, keepdim=True)
def _apply_top_p(ps, model):
"""Apply top-p (nucleus) filtering to probabilities."""
if not hasattr(model, 'generation_config') or not hasattr(model.generation_config, 'top_p'):
return ps
top_p = model.generation_config.top_p
if top_p is None or top_p >= 1.0:
return ps
sorted_probs, sorted_indices = torch.sort(ps, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
ps[indices_to_remove] = 0.0
return ps / ps.sum(dim=-1, keepdim=True)
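
# Illustration (not part of the original code): chaining the two filters above on a toy
# distribution. With generation_config.top_k=2 and top_p=0.9, only the two most probable tokens
# survive and the result is renormalized; `types.SimpleNamespace` stands in for a real model here.
#
#   import types
#   fake_model = types.SimpleNamespace(generation_config=types.SimpleNamespace(top_k=2, top_p=0.9))
#   ps = torch.tensor([[0.6, 0.25, 0.1, 0.05]])
#   _apply_top_p(_apply_top_k(ps, fake_model), fake_model)
#   # -> tensor([[0.7059, 0.2941, 0.0000, 0.0000]])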
def sampling_with_kvcache(model_kwargs, model, eos_token_ids, pad_token_id, bos_token_id, do_sample=True, max_new_tokens=20, temperature=1.0):
"""
Sampling implementation with proper KV caching.
Args:
prompts: List of input prompts
model: The language model
max_new_tokens: Maximum number of new tokens to generate
eos_token_ids: List of end-of-sequence token IDs
pad_token_id: Padding token ID
bos_token_id: Beginning-of-sequence token ID
max_new_tokens: Maximum number of new tokens to generate
Returns:
Generated sequences, log probabilities, and metadata
"""
# Initialize the generation process and prepare the KV cache
model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
batch_size, _ = input_ids.shape
    # Keeps track of which sequences are still active (i.e. have not yet produced an EOS token)
active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
    # Modified log probabilities of the sequences (after temperature/top-k/top-p processors)
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=input_ids.device)
    # Unfiltered sequence log probabilities (temperature=1, no sampling processors applied)
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=input_ids.device)
for i in range(max_new_tokens):
# Get the next token probabilities and update the KV cache
model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)
# Store original model probabilities (temperature=1, no sampling processors applied)
model_ps = logits.softmax(-1)
# Logit processors (temperature, top-k, top-p). We can chain these!
ps = (logits/temperature).softmax(-1)
ps = _apply_top_k(ps, model)
ps = _apply_top_p(ps, model)
# Sample the next token and gather the log probabilities
if do_sample: # Sampling
next_token_ids = torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
else: # Greedy decoding
next_token_ids = torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs + pad_token_id * ~active_seqs
next_token_logprobs = ps.gather(-1, next_token_ids).log()
next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()
input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
scores[:, i] = (next_token_logprobs * active_seqs).squeeze()
logprobs[:, i] = (next_token_model_logprobs * active_seqs).squeeze()
active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
if active_seqs.sum() == 0:
break
    return input_ids.detach().cpu(), scores[:, :i + 1].cpu(), logprobs[:, :i + 1].cpu()
def generate(model, **kwargs):
"""
Sampling strategy - multinomial sampling with temperature and optional top-k/top-p filtering.
Simple implementation with proper KV caching support.
Args:
model: The language model
model_kwargs: Model keyword arguments from the tokenizer
generation_config: Generation configuration
temperature: Sampling temperature (higher = more random)
top_k: Only consider top-k most probable tokens
top_p: Only consider tokens with cumulative probability <= top_p
**kwargs: Additional arguments
Returns:
Generated token IDs
"""
generation_config = model.generation_config
    # Pop generation options so that only model inputs (e.g. input_ids, attention_mask) remain in kwargs
    max_new_tokens = kwargs.pop('max_new_tokens', generation_config.max_new_tokens)
    max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
    do_sample = kwargs.pop('do_sample', True)
    eos_token_ids = kwargs.pop('eos_token_ids', generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError("Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument.")
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
    if eos_token_ids.ndim == 0:
        eos_token_ids = eos_token_ids.unsqueeze(0)
    pad_token_id = kwargs.pop('pad_token_id', generation_config.pad_token_id if generation_config.pad_token_id is not None else eos_token_ids[0])
    bos_token_id = kwargs.pop('bos_token_id', generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError("Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument.")
    temperature = kwargs.pop('temperature', 1.0)
    return_dict = kwargs.pop('return_dict_in_generate', False)
generated_ids, scores, logprobs = sampling_with_kvcache(
model_kwargs=kwargs,
model=model,
eos_token_ids=eos_token_ids,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
temperature=temperature,
)
if return_dict:
return {
"sequences": generated_ids,
"scores": scores,
"logprobs": logprobs,
}
else:
return generated_ids
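

# A minimal usage sketch, not part of the implementation above. It assumes the `transformers`
# library and the publicly available "gpt2" checkpoint purely for illustration; any causal LM
# whose generation config defines BOS/EOS tokens should work the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The capital of France is", return_tensors="pt")
    # The tokenizer outputs (input_ids, attention_mask) become the model kwargs; generation
    # options are passed alongside them and popped inside generate().
    out = generate(
        model,
        **inputs,
        max_new_tokens=20,
        do_sample=True,
        temperature=0.8,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out["sequences"], skip_special_tokens=True))
    print(out["scores"])    # per-step log-probs after temperature/top-k/top-p
    print(out["logprobs"])  # per-step unfiltered model log-probs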