from typing import Any, Optional

import torch
from transformers import Cache, DynamicCache
from transformers.generation.utils import ModelOutput


def prepare_inputs_for_generation(
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    **kwargs,
):
    # Feed the model only the positions that are not yet in the KV cache:
    # the full prompt on the first (prefill) step, a single token afterwards.
    input_ids = input_ids[:, cache_position].clone(memory_format=torch.contiguous_format)
    model_inputs = {
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "input_ids": input_ids,
        "inputs_embeds": None,
        "attention_mask": attention_mask,
        "use_cache": True,
    }
    if attention_mask is not None and kwargs.get("position_ids") is None:
        # Derive position ids from the attention mask so left-padded prompts get
        # correct positions, then keep only the positions being fed this step.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        kwargs["position_ids"] = position_ids[:, -input_ids.shape[1]:]
    if kwargs.get("token_type_ids") is not None:
        # Keep token type ids aligned with the (possibly single-token) input.
        kwargs["token_type_ids"] = kwargs["token_type_ids"][:, -input_ids.shape[1]:]
    # Forward any remaining kwargs (e.g. position_ids) untouched.
    model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})
    return model_inputs
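

# A minimal sketch (comments only, never executed by this module) of how the
# slicing above behaves; the tensors below are made up for illustration.
#
#   prompt_ids = torch.tensor([[11, 12, 13]])
#   prefill = prepare_inputs_for_generation(
#       prompt_ids,
#       attention_mask=torch.ones(1, 3, dtype=torch.long),
#       cache_position=torch.arange(3),
#   )
#   prefill["input_ids"]   # tensor([[11, 12, 13]])  -> whole prompt is encoded
#
#   step_ids = torch.tensor([[11, 12, 13, 14]])
#   decode = prepare_inputs_for_generation(
#       step_ids,
#       attention_mask=torch.ones(1, 4, dtype=torch.long),
#       cache_position=torch.tensor([3]),
#   )
#   decode["input_ids"]    # tensor([[14]])          -> only the newest token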


def update_model_kwargs_for_generation(
    outputs: ModelOutput,
    model_kwargs: dict[str, Any],
    num_new_tokens: int = 1,
) -> dict[str, Any]:
    # Keep a reference to the KV cache returned by the forward pass so the next
    # step can reuse it instead of re-encoding the whole sequence.
    if getattr(outputs, "past_key_values", None) is not None:
        model_kwargs["past_key_values"] = outputs.past_key_values
    if "token_type_ids" in model_kwargs:
        # Repeat the last token type id for the newly generated token.
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
    if "attention_mask" in model_kwargs:
        # Extend the attention mask by one column for the new token.
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )
    # With a KV cache, only the freshly generated positions are fed on the next
    # step, so the cache positions advance past the last cached index.
    model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
    return model_kwargs
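

# Illustrative bookkeeping (values are hypothetical), for a 5-token prompt:
#
#   after init_gen:     cache_position = tensor([0, 1, 2, 3, 4])   # prefill
#   after 1st update:   cache_position = tensor([5])               # decode step 1
#   after 2nd update:   cache_position = tensor([6])               # decode step 2
#
# while attention_mask gains one column of ones per generated token and the
# DynamicCache in past_key_values keeps growing inside the model.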


def next_logits_with_cache_update(model, model_kwargs, input_ids):
    """
    Gets the next token logits and updates the KV cache:
    - Runs the model forward pass
    - Extracts logits for the last token
    - Updates the KV cache
    - Returns updated `model_kwargs` and `logits`

    Args:
        model: The language model
        model_kwargs: Model keyword arguments including KV cache
        input_ids: Current input token IDs

    Returns:
        Updated model_kwargs, logits for the next token
    """
    model_inputs = prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        outputs = model(**model_inputs, return_dict=True)

    # Only the logits for the last position are needed to pick the next token.
    logits = outputs.logits[:, -1].detach()
    model_kwargs = update_model_kwargs_for_generation(outputs, model_kwargs)
    del outputs
    return model_kwargs, logits


def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
    """
    Initializes the generation process and prepares the KV cache:
    - Sets up input sequences and model inputs
    - Prepares the KV cache for generation
    - Returns updated `model_kwargs` and `input_ids`

    Args:
        model_kwargs: Model keyword arguments
        model: The language model
        max_new_tokens: Maximum number of new tokens to generate
        bos_token_id: Beginning-of-sequence token ID

    Returns:
        Model keyword arguments and input token IDs
    """
    input_ids = model_kwargs.pop("input_ids")
    # Start from an empty dynamic cache; the first (prefill) forward pass fills it.
    model_kwargs["past_key_values"] = DynamicCache()
    # Cache positions of the prompt tokens: 0, 1, ..., prompt_len - 1.
    model_kwargs["cache_position"] = torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device)
    return model_kwargs, input_ids


def _apply_top_k(ps, model):
    """Apply top-k filtering to probabilities."""
    top_k = getattr(getattr(model, "generation_config", None), "top_k", None)
    if top_k is None or top_k >= ps.size(-1):
        return ps

    # Zero out everything below the k-th largest probability, then renormalize.
    indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)


def _apply_top_p(ps, model):
    """Apply top-p (nucleus) filtering to probabilities."""
    top_p = getattr(getattr(model, "generation_config", None), "top_p", None)
    if top_p is None or top_p >= 1.0:
        return ps

    # Sort probabilities in descending order and drop the tail whose cumulative
    # mass exceeds top_p, always keeping at least the most likely token.
    sorted_probs, sorted_indices = torch.sort(ps, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the removal mask back to the original (unsorted) token order.
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)
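

# A toy worked example (comments only, values made up), assuming the model's
# generation_config has top_k=2 and top_p=0.9:
#
#   ps = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
#   ps = _apply_top_k(ps, model)   # -> tensor([[0.625, 0.375, 0.0, 0.0]])
#   ps = _apply_top_p(ps, model)   # -> unchanged: the 0.375 token survives because
#                                  #    the cumulative mass *before* it (0.625) <= 0.9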


def sampling(
    model_kwargs,
    model,
    eos_token_ids,
    pad_token_id,
    bos_token_id,
    do_sample=True,
    max_new_tokens=20,
    temperature=1.0,
):
    """
    Sampling implementation with proper KV caching.

    Args:
        model_kwargs: Model keyword arguments from the tokenizer (input_ids, attention_mask, ...)
        model: The language model
        eos_token_ids: Tensor of end-of-sequence token IDs
        pad_token_id: Padding token ID
        bos_token_id: Beginning-of-sequence token ID
        do_sample: Whether to sample (True) or decode greedily (False)
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Sampling temperature

    Returns:
        Generated sequences, sampling log probabilities, and model log probabilities
    """
    model_kwargs, input_ids = init_gen(
        model_kwargs, model, max_new_tokens, bos_token_id
    )
    batch_size, _ = input_ids.shape

    # One flag per sequence; a sequence becomes inactive once it emits EOS.
    active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)

    # Log probabilities under the sampling distribution (after temperature and
    # top-k/top-p) and under the raw model distribution, one column per new token.
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=input_ids.device)
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype, device=input_ids.device)

    for i in range(max_new_tokens):
        model_kwargs, logits = next_logits_with_cache_update(
            model, model_kwargs, input_ids
        )

        # Unmodified model distribution (no temperature or filtering).
        model_ps = logits.softmax(-1)

        # Sampling distribution: temperature, then optional top-k/top-p filtering.
        ps = (logits / temperature).softmax(-1)
        ps = _apply_top_k(ps, model)
        ps = _apply_top_p(ps, model)

        if do_sample:
            next_token_ids = (
                torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
            )
        else:
            next_token_ids = (
                torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs
                + pad_token_id * ~active_seqs
            )
        next_token_logprobs = ps.gather(-1, next_token_ids).log()
        next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()

        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        # Mask finished sequences so padded steps do not contribute -inf/NaN entries.
        scores[:, i] = torch.where(active_seqs, next_token_logprobs, 0.0).squeeze(-1)
        logprobs[:, i] = torch.where(active_seqs, next_token_model_logprobs, 0.0).squeeze(-1)

        # Deactivate sequences that just produced an EOS token.
        active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
        if active_seqs.sum() == 0:
            break
    return input_ids.detach().cpu(), scores[:, : i + 1].cpu(), logprobs[:, : i + 1].cpu()


def generate(model, **kwargs):
    """
    Sampling strategy - multinomial sampling with temperature and optional top-k/top-p filtering.
    Simple implementation with proper KV caching support.

    Args:
        model: The language model
        **kwargs: Tokenizer outputs (input_ids, attention_mask, ...) plus optional
            generation arguments such as max_new_tokens, do_sample, temperature,
            eos_token_ids, pad_token_id, bos_token_id and return_dict_in_generate.
            Top-k/top-p are read from model.generation_config.

    Returns:
        Generated token IDs, or a dict with sequences, scores and logprobs when
        return_dict_in_generate=True
    """
    generation_config = model.generation_config
    # Pop generation arguments so that only model inputs remain in kwargs.
    max_new_tokens = kwargs.pop("max_new_tokens", generation_config.max_new_tokens)
    max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
    do_sample = kwargs.pop("do_sample", True)
    eos_token_ids = kwargs.pop("eos_token_ids", generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError(
            "Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument."
        )
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
    if eos_token_ids.ndim == 0:
        eos_token_ids = eos_token_ids.unsqueeze(0)

    pad_token_id = kwargs.pop(
        "pad_token_id",
        generation_config.pad_token_id
        if generation_config.pad_token_id is not None
        else eos_token_ids[0],
    )
    bos_token_id = kwargs.pop("bos_token_id", generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError(
            "Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument."
        )
    temperature = kwargs.pop("temperature", 1.0)
    return_dict = kwargs.pop("return_dict_in_generate", False)

    generated_ids, scores, logprobs = sampling(
        model_kwargs=kwargs,
        model=model,
        eos_token_ids=eos_token_ids,
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )

    if return_dict:
        return {
            "sequences": generated_ids,
            "scores": scores,
            "logprobs": logprobs,
        }
    return generated_ids
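

# Minimal usage sketch, assuming a recent transformers version and a causal LM
# checkpoint that supports cache_position/DynamicCache; the "gpt2" name below is
# only an example, and the prompt text is made up.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no dedicated pad token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer(["The capital of France is"], return_tensors="pt", padding=True)
    out = generate(
        model,
        **inputs,
        max_new_tokens=20,
        do_sample=True,
        temperature=0.8,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out["sequences"], skip_special_tokens=True))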