# manueldeprada — "eliminate hf helpers" (commit 0da6031)
import torch
from transformers import Cache, DynamicCache
from transformers.generation.utils import ModelOutput
from typing import Optional, Any
def prepare_inputs_for_generation(
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """
    Build the keyword arguments for a single model forward pass.

    Only the columns of `input_ids` selected by `cache_position` are fed to the
    model. When an attention mask is given and no `position_ids` were supplied,
    position ids are derived from the mask: the running count of attended
    tokens minus one, with masked slots set to 1. When a cache is present,
    `position_ids` and `token_type_ids` (if given) are trimmed to the last
    `len(cache_position)` columns. Remaining keyword arguments are passed
    through unless they would overwrite one of the core entries.

    Returns:
        dict of model inputs.
    """
    new_tokens = input_ids[:, cache_position].clone(memory_format=torch.contiguous_format)
    n_new = new_tokens.shape[1]

    if attention_mask is not None and kwargs.get("position_ids") is None:
        positions = attention_mask.long().cumsum(-1) - 1
        positions.masked_fill_(attention_mask == 0, 1)
        kwargs["position_ids"] = positions

    if past_key_values is not None:
        for key in ("position_ids", "token_type_ids"):
            if key not in kwargs:
                continue
            kwargs[key] = kwargs[key][:, -n_new:].clone(memory_format=torch.contiguous_format)

    inputs = dict(
        cache_position=cache_position,
        past_key_values=past_key_values,
        input_ids=new_tokens,
        inputs_embeds=None,
        attention_mask=attention_mask,
    )
    # Extra kwargs never override the core entries set above.
    for key, value in kwargs.items():
        inputs.setdefault(key, value)
    return inputs
def update_model_kwargs_for_generation(
    outputs: ModelOutput,
    model_kwargs: dict[str, Any],
    num_new_tokens: int = 1,
) -> dict[str, Any]:
    """
    Advance `model_kwargs` after a forward pass so it is ready for the next step.

    - Stores the KV cache returned in `outputs.past_key_values`.
    - Extends `token_type_ids` (if present) by repeating its last column.
    - Extends `attention_mask` (if present) by one column of ones.
    - Sets `cache_position` to its last entry plus `num_new_tokens`.

    The dict is updated in place and also returned.
    """
    model_kwargs["past_key_values"] = outputs.past_key_values

    if "token_type_ids" in model_kwargs:
        segment_ids = model_kwargs["token_type_ids"]
        last_segment = segment_ids[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([segment_ids, last_segment], dim=-1)

    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        extra_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([mask, extra_column], dim=-1)

    last_position = model_kwargs["cache_position"][-1:]
    model_kwargs["cache_position"] = last_position + num_new_tokens
    return model_kwargs
def next_logits_with_cache_update(model, model_kwargs, input_ids):
    """
    Run one forward pass and return the next-token logits, advancing the KV cache.

    Builds the model inputs with `prepare_inputs_for_generation`, runs `model`
    under `torch.no_grad()`, then updates `model_kwargs` (cache, attention mask,
    cache position) with `update_model_kwargs_for_generation`.

    Args:
        model: The language model.
        model_kwargs: Model keyword arguments, including the KV cache and `cache_position`.
        input_ids: All token IDs generated so far (prompt included).
    Returns:
        Tuple `(model_kwargs, logits)`. `logits` is a float32 copy of the model's
        logits at the last sequence position.
    """
    model_inputs = prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        outputs = model(**model_inputs, return_dict=True)
    # Copy instead of slicing a view. A view would keep the full
    # (batch, seq_len, vocab) logits buffer alive, which defeats `del outputs`
    # after a long prefill. Upcasting to float32 keeps softmax and sampling
    # precise for half-precision models.
    logits = outputs.logits[:, -1].to(copy=True, dtype=torch.float32)
    model_kwargs = update_model_kwargs_for_generation(outputs, model_kwargs)
    del outputs
    return model_kwargs, logits
def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
    """
    Prepare `model_kwargs` for the first decoding step.

    - Pops `input_ids` out of `model_kwargs`.
    - Ensures `past_key_values` holds a `Cache`, creating an empty
      `DynamicCache` when none was supplied.
    - Sets `cache_position` to the prompt positions not already stored in the
      cache.

    `model`, `max_new_tokens` and `bos_token_id` are accepted but not used here.

    Returns:
        Tuple `(model_kwargs, input_ids)`.
    """
    input_ids = model_kwargs.pop("input_ids")

    cache = model_kwargs.get("past_key_values")
    if cache is None:
        cache = DynamicCache()
    model_kwargs["past_key_values"] = cache
    assert isinstance(cache, Cache), "past_key_values must be a Cache object"

    prompt_len = input_ids.shape[1]
    all_positions = torch.arange(prompt_len, dtype=torch.int64, device=input_ids.device)
    # Skip positions the (possibly pre-filled) cache already holds.
    model_kwargs["cache_position"] = all_positions[cache.get_seq_length():]
    return model_kwargs, input_ids
def _apply_top_k(ps, model):
    """
    Apply top-k filtering to probabilities.

    Keeps the `top_k` largest probabilities along the last dimension (values
    tied with the k-th largest are kept too), zeroes the rest and renormalizes.
    `top_k` is read from `model.generation_config.top_k`.

    Filtering is skipped, and `ps` is returned unchanged, when:
    - the attribute is missing or None;
    - `top_k` is non-positive (0 disables top-k, as in transformers);
    - `top_k` is not smaller than the last dimension.

    Args:
        ps: Probabilities, with the vocabulary on the last dimension.
        model: Object exposing `generation_config.top_k`.
    Returns:
        A new filtered and renormalized tensor, or `ps` itself when skipped.
        The input tensor is never modified.
    """
    top_k = getattr(getattr(model, "generation_config", None), "top_k", None)
    if top_k is None or top_k <= 0 or top_k >= ps.size(-1):
        return ps
    kth_largest = torch.topk(ps, top_k, dim=-1).values[..., -1:]
    filtered = ps.masked_fill(ps < kth_largest, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)
def _apply_top_p(ps, model):
    """
    Apply top-p (nucleus) filtering to probabilities.

    Sorts each distribution along the last dimension in descending order. It
    then drops every token for which the tokens ranked strictly above it
    already have cumulative probability > `top_p`. The token that crosses the
    threshold, and the single most likely token, are therefore always kept.
    The result is renormalized. `top_p` is read from
    `model.generation_config.top_p`.

    Filtering is skipped, and `ps` is returned unchanged, when the attribute is
    missing or None, or when `top_p >= 1.0`.

    Args:
        ps: Probabilities of any rank, with the vocabulary on the last dimension.
        model: Object exposing `generation_config.top_p`.
    Returns:
        A new filtered and renormalized tensor, or `ps` itself when skipped.
        The input tensor is never modified.
    """
    top_p = getattr(getattr(model, "generation_config", None), "top_p", None)
    if top_p is None or top_p >= 1.0:
        return ps
    sorted_probs, sorted_indices = torch.sort(ps, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Shift right by one so each token is judged by the mass ranked above it;
    # position 0 (the most likely token) is never removed.
    sorted_to_remove = torch.zeros_like(cumulative_probs, dtype=torch.bool)
    sorted_to_remove[..., 1:] = cumulative_probs[..., :-1] > top_p
    # Map the removal mask back from sorted order to vocabulary order.
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    filtered = ps.masked_fill(to_remove, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)
def sampling_with_kvcache(
    model_kwargs,
    model,
    eos_token_ids,
    pad_token_id,
    bos_token_id,
    do_sample=True,
    max_new_tokens=20,
    temperature=1.0,
):
    """
    Sampling (or greedy) decoding loop with KV caching.

    Args:
        model_kwargs: Model keyword arguments. Must contain `input_ids`, which
            `init_gen` pops; the remaining entries are forwarded to the model.
        model: The language model.
        eos_token_ids: Tensor of end-of-sequence token IDs. A sequence finishes
            once it emits any of them.
        pad_token_id: Token appended to sequences that have already finished.
        bos_token_id: Beginning-of-sequence token ID (passed to `init_gen`).
        do_sample: If True, sample from the processed distribution; otherwise
            pick its argmax (greedy decoding).
        max_new_tokens: Maximum number of new tokens to generate (may be 0).
        temperature: Logits are divided by this before softmax. It affects
            sampling and `scores` only.
    Returns:
        Tuple `(sequences, scores, logprobs)`, all on CPU:
        - `sequences`: the prompt followed by the generated tokens.
        - `scores`: per-step log-probability of the chosen token under the
          processed distribution (temperature, top-k, top-p).
        - `logprobs`: per-step log-probability of the chosen token under the
          raw model distribution (temperature 1, no filtering).
        Both score tensors have one column per decoding step actually run. The
        entries for already-finished sequences are 0.
    """
    # Initialize the generation process and prepare the KV cache
    model_kwargs, input_ids = init_gen(
        model_kwargs, model, max_new_tokens, bos_token_id
    )
    batch_size = input_ids.shape[0]
    device = input_ids.device
    # Which sequences are still generating, shape (batch_size, 1)
    active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
    # Kept on the generation device to avoid a cross-device copy every step;
    # moved to CPU on return.
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=device)
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=device)
    steps = 0  # tracked explicitly so max_new_tokens == 0 is well-defined
    for i in range(max_new_tokens):
        # Get the next token logits and update the KV cache
        model_kwargs, logits = next_logits_with_cache_update(
            model, model_kwargs, input_ids
        )
        # Raw model log-probabilities; log_softmax is more stable than softmax().log()
        model_logprobs = logits.log_softmax(-1)
        # Processed distribution: temperature, then top-k, then top-p
        ps = (logits / temperature).softmax(-1)
        ps = _apply_top_k(ps, model)
        ps = _apply_top_p(ps, model)

        if do_sample:
            chosen = torch.multinomial(ps, 1)
        else:
            chosen = torch.argmax(ps, dim=-1).unsqueeze(-1)
        # Finished sequences only receive padding
        next_token_ids = chosen * active_seqs + pad_token_id * ~active_seqs

        # A pad token can have probability 0 after filtering, giving log(0) = -inf,
        # and -inf * 0 = nan. So mask finished rows instead of multiplying.
        chosen_scores = ps.gather(-1, next_token_ids).log()
        chosen_logprobs = model_logprobs.gather(-1, next_token_ids)
        scores[:, i] = chosen_scores.masked_fill(~active_seqs, 0.0)[:, 0]
        logprobs[:, i] = chosen_logprobs.masked_fill(~active_seqs, 0.0)[:, 0]

        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
        steps = i + 1
        if not active_seqs.any():
            break
    return (
        input_ids.detach().cpu(),
        scores[:, :steps].cpu(),
        logprobs[:, :steps].cpu(),
    )
def generate(model, **kwargs):
    """
    Multinomial sampling (or greedy decoding) with temperature and optional
    top-k/top-p filtering, using a KV cache.

    Args:
        model: The language model. `model.generation_config` supplies defaults
            for `max_new_tokens`, the EOS/PAD/BOS token ids, and the `top_k` /
            `top_p` filters. The filters are read only from the generation
            config, not from kwargs.
        **kwargs: Model inputs from the tokenizer (must include `input_ids`,
            e.g. also `attention_mask`) plus these optional generation controls:
            - max_new_tokens: default is generation_config.max_new_tokens, else 512.
            - do_sample: default True.
            - eos_token_ids: default is generation_config.eos_token_id.
            - pad_token_id: default is generation_config.pad_token_id, else the
              first EOS id.
            - bos_token_id: default is generation_config.bos_token_id.
            - temperature: default 1.0.
            - return_dict_in_generate: default False.
            The generation controls are consumed here and are not forwarded to
            the model. Every other keyword is forwarded to the model's forward
            pass.
    Returns:
        Generated token IDs (CPU tensor). If `return_dict_in_generate` is set,
        returns a dict with `sequences`, `scores` and `logprobs` instead.
    Raises:
        ValueError: if no EOS or no BOS token id is available.
    """
    generation_config = model.generation_config

    max_new_tokens = kwargs.get("max_new_tokens", generation_config.max_new_tokens)
    max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
    do_sample = kwargs.get("do_sample", True)

    eos_token_ids = kwargs.get("eos_token_ids", generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError(
            "Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument."
        )
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
    if eos_token_ids.ndim == 0:
        eos_token_ids = eos_token_ids.unsqueeze(0)

    # An explicit pad_token_id=None falls back too, instead of crashing later
    # in the token arithmetic.
    pad_token_id = kwargs.get("pad_token_id")
    if pad_token_id is None:
        pad_token_id = generation_config.pad_token_id
    if pad_token_id is None:
        pad_token_id = eos_token_ids[0]

    bos_token_id = kwargs.get("bos_token_id", generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError(
            "Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument."
        )

    temperature = kwargs.get("temperature")
    temperature = 1.0 if temperature is None else temperature
    return_dict = kwargs.get("return_dict_in_generate", False)

    # Generation controls are consumed here. Forwarding them in model_kwargs
    # would pass them on to model.forward() via prepare_inputs_for_generation.
    control_keys = {
        "max_new_tokens",
        "do_sample",
        "eos_token_ids",
        "pad_token_id",
        "bos_token_id",
        "temperature",
        "return_dict_in_generate",
    }
    model_kwargs = {k: v for k, v in kwargs.items() if k not in control_keys}

    generated_ids, scores, logprobs = sampling_with_kvcache(
        model_kwargs=model_kwargs,
        model=model,
        eos_token_ids=eos_token_ids,
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    if return_dict:
        return {
            "sequences": generated_ids,
            "scores": scores,
            "logprobs": logprobs,
        }
    return generated_ids