import torch

# Keyword arguments consumed by `generate()` itself. They are removed before the
# remaining kwargs are handed to the model as `model_kwargs`, because
# `prepare_inputs_for_generation` may forward unknown kwargs on to the model's
# forward call (NOTE(review): Hugging Face's default implementation does this).
_GENERATION_KWARGS = frozenset(
    {
        "max_new_tokens",
        "do_sample",
        "eos_token_ids",
        "pad_token_id",
        "bos_token_id",
        "temperature",
        "top_k",
        "top_p",
        "return_dict_in_generate",
    }
)


def next_logits_with_cache_update(model, model_kwargs, input_ids):
    """
    Gets the next token logits and updates the KV cache:
    - Builds model inputs from `input_ids` and `model_kwargs`
    - Runs the model forward pass without gradient tracking
    - Extracts the logits of the last position
    - Lets the model update `model_kwargs` (KV cache etc.) for the next step

    Args:
        model: The language model
        model_kwargs: Model keyword arguments including the KV cache
        input_ids: Current input token IDs, shape (batch, seq_len)

    Returns:
        Updated model_kwargs, logits for the next token (presumably shape
        (batch, vocab) -- assumes the model returns (batch, seq, vocab) logits)
    """
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        outputs = model(**model_inputs, return_dict=True)
    logits = outputs.logits[:, -1].detach()
    model_kwargs = model._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )
    # Drop the reference to the full outputs (all-position logits, cache views)
    # as soon as the pieces we need have been extracted.
    del outputs
    return model_kwargs, logits


def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
    """
    Initializes the generation process and prepares the KV cache:
    - Extracts the prompt token IDs from `model_kwargs`
    - Prepares a KV cache large enough for prompt + generated tokens
    - Sets the initial cache position

    Args:
        model_kwargs: Model keyword arguments (expected to contain the prompt)
        model: The language model
        max_new_tokens: Maximum number of new tokens to generate
        bos_token_id: Beginning-of-sequence token ID

    Returns:
        Model keyword arguments and input token IDs
    """
    input_ids, _, model_kwargs = model._prepare_model_inputs(None, bos_token_id, model_kwargs)
    batch_size = input_ids.shape[0]
    prompt_length = input_ids.shape[1]
    # The cache has to hold the prompt as well as every generated token; sizing it
    # with `max_new_tokens` alone is too small for fixed-size (static) caches.
    model._prepare_cache_for_generation(
        model.generation_config,
        model_kwargs,
        None,
        batch_size,
        max_cache_length=prompt_length + max_new_tokens,
        device=input_ids.device,
    )
    # Get initial cache position
    model_kwargs = model._get_initial_cache_position(prompt_length, input_ids.device, model_kwargs)
    return model_kwargs, input_ids


def _apply_top_k(ps, model, top_k=None):
    """Apply top-k filtering to probabilities.

    Args:
        ps: Probabilities over the vocabulary (last dim); zeroed in place.
        model: Model whose `generation_config.top_k` is used when `top_k` is None.
        top_k: Optional override. None (after the fallback), values <= 0 and
            values >= vocab size disable filtering.

    Returns:
        Renormalized probabilities keeping only the top-k entries per row
        (ties with the k-th value are kept), or `ps` unchanged when disabled.
    """
    if top_k is None:
        top_k = getattr(getattr(model, "generation_config", None), "top_k", None)
    if top_k is None or top_k <= 0 or top_k >= ps.size(-1):
        return ps
    kth_largest = torch.topk(ps, top_k, dim=-1).values[..., -1:]
    ps[ps < kth_largest] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)


def _apply_top_p(ps, model, top_p=None):
    """Apply top-p (nucleus) filtering to probabilities.

    Args:
        ps: Probabilities over the vocabulary (last dim); zeroed in place.
        model: Model whose `generation_config.top_p` is used when `top_p` is None.
        top_p: Optional override. None (after the fallback) or >= 1.0 disables it.

    Returns:
        Renormalized probabilities keeping the smallest prefix of the sorted
        distribution whose mass exceeds `top_p` (the most likely token is always
        kept), or `ps` unchanged when disabled.
    """
    if top_p is None:
        top_p = getattr(getattr(model, "generation_config", None), "top_p", None)
    if top_p is None or top_p >= 1.0:
        return ps
    sorted_probs, sorted_indices = torch.sort(ps, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Shift the mask right by one so the token that crosses the threshold is kept.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    ps[to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)


def sampling_with_kvcache(model_kwargs, model, eos_token_ids, pad_token_id, bos_token_id,
                          do_sample=True, max_new_tokens=20, temperature=1.0,
                          top_k=None, top_p=None):
    """
    Sampling / greedy decoding loop with KV caching.

    Args:
        model_kwargs: Model keyword arguments (prompt, attention mask, ...)
        model: The language model
        eos_token_ids: End-of-sequence token ID(s): int, list or tensor
        pad_token_id: Token written for sequences that already finished
        bos_token_id: Beginning-of-sequence token ID
        do_sample: Sample from the filtered distribution if True, else argmax
        max_new_tokens: Maximum number of new tokens to generate (may be 0)
        temperature: Softmax temperature; must be > 0
        top_k: Optional top-k override (None -> model.generation_config.top_k)
        top_p: Optional top-p override (None -> model.generation_config.top_p)

    Returns:
        (sequences, scores, logprobs), all on CPU:
        - sequences: prompt followed by generated tokens, (batch, prompt + steps)
        - scores: log-probs of the chosen tokens under the processed
          distribution (temperature, top-k, top-p), (batch, steps)
        - logprobs: log-probs of the chosen tokens under the raw model
          distribution, (batch, steps)
        Entries for steps after a sequence finished are 0.

    Raises:
        ValueError: if `temperature` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Initialize the generation process and prepare the KV cache
    model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
    batch_size = input_ids.shape[0]
    device = input_ids.device

    eos_token_ids = torch.as_tensor(eos_token_ids, device=device).reshape(-1)
    pad_token = torch.as_tensor(pad_token_id, dtype=input_ids.dtype, device=device)

    # Which sequences are still generating.
    active_seqs = torch.ones((batch_size, 1), dtype=torch.bool, device=device)
    # Log-probs under the processed (temperature / top-k / top-p) distribution.
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=device)
    # Log-probs under the raw model distribution (temperature=1, no processors).
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=device)

    steps = 0
    for step in range(max_new_tokens):
        # Get the next token logits and update the KV cache
        model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)

        model_ps = logits.softmax(-1)
        ps = (logits / temperature).softmax(-1)
        ps = _apply_top_k(ps, model, top_k)
        ps = _apply_top_p(ps, model, top_p)

        if do_sample:
            candidates = torch.multinomial(ps, 1)
        else:
            candidates = ps.argmax(dim=-1, keepdim=True)
        # Finished sequences keep emitting the pad token.
        next_token_ids = torch.where(active_seqs, candidates, pad_token)

        # The pad token can have probability 0 (-> log = -inf); masking with
        # masked_fill instead of multiplying avoids -inf * 0 = NaN.
        inactive = ~active_seqs
        next_token_logprobs = ps.gather(-1, next_token_ids).log().masked_fill(inactive, 0.0)
        next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log().masked_fill(inactive, 0.0)
        scores[:, step] = next_token_logprobs.squeeze(-1)
        logprobs[:, step] = next_token_model_logprobs.squeeze(-1)

        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
        steps = step + 1
        if not active_seqs.any():
            break

    return input_ids.detach().cpu(), scores[:, :steps].cpu(), logprobs[:, :steps].cpu()


def generate(model, **kwargs):
    """
    Sampling strategy - multinomial sampling with temperature and optional
    top-k/top-p filtering (or greedy decoding), with KV caching support.

    Args:
        model: The language model
        **kwargs: Model inputs from the tokenizer (e.g. `input_ids`,
            `attention_mask`) plus these generation options, which are NOT
            forwarded to the model:
            max_new_tokens: Maximum new tokens (default: generation_config
                value, or 512 if that is None)
            do_sample: Sample (True, default) or greedy decode (False)
            eos_token_ids: EOS id(s) (default: generation_config.eos_token_id)
            pad_token_id: Pad id for finished sequences (default:
                generation_config.pad_token_id, else the first EOS id)
            bos_token_id: BOS id (default: generation_config.bos_token_id)
            temperature: Sampling temperature, must be > 0 (default 1.0)
            top_k: Only consider top-k most probable tokens
                (default: model.generation_config.top_k)
            top_p: Only consider tokens within cumulative probability top_p
                (default: model.generation_config.top_p)
            return_dict_in_generate: Return a dict instead of only the IDs

    Returns:
        Generated token IDs (CPU), or a dict with "sequences", "scores" and
        "logprobs" when `return_dict_in_generate` is True.

    Raises:
        ValueError: if no EOS or BOS token id is available, or temperature <= 0.
    """
    generation_config = model.generation_config

    max_new_tokens = kwargs.get("max_new_tokens", generation_config.max_new_tokens)
    if max_new_tokens is None:
        max_new_tokens = 512
    do_sample = kwargs.get("do_sample", True)

    eos_token_ids = kwargs.get("eos_token_ids", generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError("Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument.")
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device).reshape(-1)

    pad_token_id = kwargs.get("pad_token_id", generation_config.pad_token_id)
    if pad_token_id is None:
        pad_token_id = eos_token_ids[0]

    bos_token_id = kwargs.get("bos_token_id", generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError("Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument.")

    temperature = kwargs.get("temperature", 1.0)
    top_k = kwargs.get("top_k")
    top_p = kwargs.get("top_p")
    return_dict = kwargs.get("return_dict_in_generate", False)

    # Only genuine model inputs go to the model; generation options stay here.
    model_kwargs = {key: value for key, value in kwargs.items() if key not in _GENERATION_KWARGS}

    generated_ids, scores, logprobs = sampling_with_kvcache(
        model_kwargs=model_kwargs,
        model=model,
        eos_token_ids=eos_token_ids,
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    if return_dict:
        return {
            "sequences": generated_ids,
            "scores": scores,
            "logprobs": logprobs,
        }
    return generated_ids