import torch
from transformers import Cache, DynamicCache
from transformers.generation.utils import ModelOutput
from typing import Optional, Any


def prepare_inputs_for_generation(
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """Builds the inputs for one forward pass, keeping only the positions selected by
    `cache_position`: the full prompt on the first step, the newly generated token afterwards."""
    input_ids = input_ids[:, cache_position].clone(memory_format=torch.contiguous_format)
    model_inputs = {
        "cache_position": cache_position,
        "past_key_values": kwargs.pop("past_key_values", None),
        "input_ids": input_ids,
        "inputs_embeds": None,
        "attention_mask": attention_mask,
    }
    if attention_mask is not None and kwargs.get("position_ids") is None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        # Keep only the positions of the tokens that are fed to the model in this step
        kwargs["position_ids"] = position_ids[:, -input_ids.shape[1] :]
    model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})
    return model_inputs


def update_model_kwargs_for_generation(
    outputs: ModelOutput,
    model_kwargs: dict[str, Any],
    num_new_tokens: int = 1,
) -> dict[str, Any]:
    """Extends the attention mask (and token type ids) by one position and advances
    `cache_position`. The KV cache itself is a `Cache` object that the model updates in
    place during the forward pass, so it does not need to be read back from `outputs`."""
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat(
            [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
        )

    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )

    # Only the newly generated positions are processed on the next step
    past_positions = model_kwargs.pop("cache_position")
    model_kwargs["cache_position"] = past_positions[-1:] + num_new_tokens
    return model_kwargs


def next_logits_with_cache_update(model, model_kwargs, input_ids):
    """
    Gets the next token logits and updates the KV cache:
    - Runs the model forward pass
    - Extracts logits for the last token
    - Updates the KV cache
    - Returns updated `model_kwargs` and `logits`

    Args:
        model: The language model
        model_kwargs: Model keyword arguments including KV cache
        input_ids: Current input token IDs

    Returns:
        Updated model_kwargs, logits for the next token
    """
    model_inputs = prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        outputs = model(**model_inputs, return_dict=True)
    logits = outputs.logits[:, -1].detach()
    model_kwargs = update_model_kwargs_for_generation(outputs, model_kwargs)
    del outputs
    return model_kwargs, logits


def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
    """
    Initializes the generation process and prepares the KV cache:
    - Sets up input sequences and model inputs
    - Prepares the KV cache for generation
    - Returns updated `model_kwargs` and `input_ids`

    Args:
        model_kwargs: Model keyword arguments
        model: The language model
        max_new_tokens: Maximum number of new tokens to generate
        bos_token_id: Beginning-of-sequence token ID

    Returns:
        Model keyword arguments and input token IDs
    """
    input_ids = model_kwargs.pop("input_ids")
    # Fresh KV cache; the model fills it in place during the forward passes
    model_kwargs["past_key_values"] = DynamicCache()
    # One cache position per prompt token: [0, 1, ..., seq_len - 1]
    model_kwargs["cache_position"] = torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device)
    return model_kwargs, input_ids


def _apply_top_k(ps, model):
    """Apply top-k filtering to probabilities."""
    if not hasattr(model, "generation_config") or not hasattr(
        model.generation_config, "top_k"
    ):
        return ps
    top_k = model.generation_config.top_k
    if top_k is None or top_k >= ps.size(-1):
        return ps

    # Zero out everything below the k-th largest probability, then renormalize
    indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)
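# Worked toy example (assumed numbers, not taken from the code above) of the top-k filter:
# with top_k = 2 and ps = [[0.50, 0.30, 0.15, 0.05]], the top-2 threshold is 0.30, so the
# last two entries are zeroed and the survivors are renormalized:
#     [[0.50, 0.30, 0.15, 0.05]]  ->  [[0.625, 0.375, 0.0, 0.0]]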
def _apply_top_p(ps, model):
    """Apply top-p (nucleus) filtering to probabilities."""
    if not hasattr(model, "generation_config") or not hasattr(
        model.generation_config, "top_p"
    ):
        return ps
    top_p = model.generation_config.top_p
    if top_p is None or top_p >= 1.0:
        return ps

    sorted_probs, sorted_indices = torch.sort(ps, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the mask right so the first token that crosses the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)


def sampling(
    model_kwargs,
    model,
    eos_token_ids,
    pad_token_id,
    bos_token_id,
    do_sample=True,
    max_new_tokens=20,
    temperature=1.0,
):
    """
    Sampling implementation with proper KV caching.

    Args:
        model_kwargs: Model keyword arguments from the tokenizer (input_ids, attention_mask, ...)
        model: The language model
        eos_token_ids: Tensor of end-of-sequence token IDs
        pad_token_id: Padding token ID
        bos_token_id: Beginning-of-sequence token ID
        do_sample: Whether to sample (True) or decode greedily (False)
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Sampling temperature

    Returns:
        Generated sequences, per-step scores, and per-step model log probabilities
    """
    # Initialize the generation process and prepare the KV cache
    model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
    batch_size, _ = input_ids.shape

    # Keeps track of which sequences are still unfinished
    active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
    # Modified log probabilities of the sequences (temperature and top-k/top-p applied)
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=input_ids.device)
    # Unfiltered sequence log probabilities (temperature=1, no sampling processors applied)
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=input_ids.device)

    for i in range(max_new_tokens):
        # Get the next token probabilities and update the KV cache
        model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)

        # Store original model probabilities (temperature=1, no sampling processors applied)
        model_ps = logits.softmax(-1)
        # Logit processors (temperature, top-k, top-p). We can chain these!
        ps = (logits / temperature).softmax(-1)
        ps = _apply_top_k(ps, model)
        ps = _apply_top_p(ps, model)

        # Sample the next token; finished sequences get the padding token instead
        if do_sample:
            # Sampling
            next_token_ids = torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
        else:
            # Greedy decoding
            next_token_ids = (
                torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs + pad_token_id * ~active_seqs
            )

        # Gather the log probabilities of the chosen tokens. Use where() so that finished
        # sequences contribute exactly 0, even if the padding token has zero probability
        # (whose log would be -inf)
        next_token_logprobs = ps.gather(-1, next_token_ids).log()
        next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()
        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        scores[:, i] = torch.where(active_seqs, next_token_logprobs, 0.0).squeeze()
        logprobs[:, i] = torch.where(active_seqs, next_token_model_logprobs, 0.0).squeeze()

        # Mark sequences that just produced an EOS token as finished
        active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
        if active_seqs.sum() == 0:
            break

    return input_ids.detach().cpu(), scores[:, : i + 1], logprobs[:, : i + 1]
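# Note on the values returned by `sampling` above: `scores[:, i]` is the log probability of
# the i-th sampled token under the processed distribution (temperature, top-k, top-p applied),
# while `logprobs[:, i]` is its log probability under the raw model distribution. For example,
# `scores.sum(dim=-1)` gives the per-sequence log probability of the generated continuation
# under the sampling distribution; steps after a sequence finishes contribute 0.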
def generate(model, **kwargs):
    """
    Sampling strategy - multinomial sampling with temperature and optional top-k/top-p
    filtering (both read from `model.generation_config`). Simple implementation with
    proper KV caching support.

    Args:
        model: The language model
        **kwargs: Tokenizer outputs (input_ids, attention_mask, ...) plus optional
            generation arguments: max_new_tokens, do_sample, temperature, eos_token_ids,
            pad_token_id, bos_token_id and return_dict_in_generate

    Returns:
        Generated token IDs, or a dict with sequences, scores and logprobs when
        return_dict_in_generate=True
    """
    generation_config = model.generation_config

    # Pop the generation arguments so that only the model inputs are left in `kwargs`
    max_new_tokens = kwargs.pop("max_new_tokens", generation_config.max_new_tokens)
    max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
    do_sample = kwargs.pop("do_sample", True)

    eos_token_ids = kwargs.pop("eos_token_ids", generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError(
            "Model generation config does not have an EOS token id. You must provide it to "
            "generate() with the eos_token_ids argument."
        )
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
    if eos_token_ids.ndim == 0:
        eos_token_ids = eos_token_ids.unsqueeze(0)

    pad_token_id = kwargs.pop(
        "pad_token_id",
        generation_config.pad_token_id
        if generation_config.pad_token_id is not None
        else eos_token_ids[0],
    )

    bos_token_id = kwargs.pop("bos_token_id", generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError(
            "Model generation config does not have a BOS token id. You must provide it to "
            "generate() with the bos_token_id argument."
        )

    temperature = kwargs.pop("temperature", 1.0)
    return_dict = kwargs.pop("return_dict_in_generate", False)

    generated_ids, scores, logprobs = sampling(
        model_kwargs=kwargs,
        model=model,
        eos_token_ids=eos_token_ids,
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )

    if return_dict:
        return {
            "sequences": generated_ids,
            "scores": scores,
            "logprobs": logprobs,
        }
    else:
        return generated_ids
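# Usage sketch (illustrative only): the checkpoint below is an assumption, not something
# prescribed by the code above. Any causal LM whose generation config defines BOS and EOS
# token ids should work the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    inputs = tokenizer(["The capital of France is"], return_tensors="pt").to(model.device)
    outputs = generate(
        model,
        **inputs,
        max_new_tokens=20,
        do_sample=True,
        temperature=0.7,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(outputs["sequences"], skip_special_tokens=True))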