import torch


def next_logits_with_cache_update(model, model_kwargs, input_ids):
    """
    Gets the next token logits and updates the KV cache:
    - Runs the model forward pass
    - Extracts logits for the last token
    - Updates the KV cache
    - Returns updated `model_kwargs` and `logits`

    Args:
        model: The language model
        model_kwargs: Model keyword arguments including KV cache
        input_ids: Current input token IDs

    Returns:
        Updated model_kwargs, logits for the next token
    """
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        outputs = model(**model_inputs, return_dict=True)

    logits = outputs.logits[:, -1].detach()
    model_kwargs = model._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )
    del outputs
    return model_kwargs, logits
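
# Illustrative sketch (not called anywhere by itself): assuming `model`,
# `model_kwargs`, and `input_ids` have been set up (e.g. by `init_gen` below),
# a single greedy decode step with this helper would look like
#
#     model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)
#     next_token = logits.argmax(-1, keepdim=True)            # shape (batch_size, 1)
#     input_ids = torch.cat([input_ids, next_token], dim=-1)
#
# Because `model_kwargs` carries the KV cache forward, each forward pass only
# processes the newly appended token while attending to the cached keys/values.
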
def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
    """
    Initializes the generation process and prepares the KV cache:
    - Sets up input sequences and model inputs
    - Prepares the KV cache for generation
    - Returns updated `model_kwargs` and `input_ids`

    Args:
        model_kwargs: Model keyword arguments
        model: The language model
        max_new_tokens: Maximum number of new tokens to generate
        bos_token_id: Beginning-of-sequence token ID

    Returns:
        Model keyword arguments and input token IDs
    """
    input_ids, model_input_name, model_kwargs = model._prepare_model_inputs(
        None, bos_token_id, model_kwargs
    )

    batch_size = input_ids.shape[0]
    model._prepare_cache_for_generation(
        model.generation_config, model_kwargs, None, batch_size,
        max_cache_length=max_new_tokens, device=input_ids.device
    )

    model_kwargs = model._get_initial_cache_position(
        input_ids.shape[1], input_ids.device, model_kwargs
    )
    return model_kwargs, input_ids
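
# Minimal sketch of how `init_gen` is typically fed (the tokenizer, prompt, and
# `max_new_tokens` value here are illustrative assumptions):
#
#     model_kwargs = dict(tokenizer(["Hello"], return_tensors="pt"))  # input_ids, attention_mask
#     model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens=20,
#                                        bos_token_id=tokenizer.bos_token_id)
#
# Afterwards `model_kwargs` typically also holds the freshly prepared KV cache and
# the initial cache position, and `input_ids` has shape (batch_size, prompt_length).
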
def _apply_top_k(ps, model):
    """Apply top-k filtering to probabilities."""
    if not hasattr(model, 'generation_config') or not hasattr(model.generation_config, 'top_k'):
        return ps

    top_k = model.generation_config.top_k
    if top_k is None or top_k >= ps.size(-1):
        return ps

    indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)
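
# Worked example (illustrative numbers): with ps = [0.5, 0.3, 0.15, 0.05] and
# top_k = 2, the two largest probabilities are kept and renormalized:
#
#     [0.5, 0.3, 0.0, 0.0] / 0.8  ->  [0.625, 0.375, 0.0, 0.0]
#
# Note that top_k is read from model.generation_config, and the filter is a
# no-op when it is unset.
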
def _apply_top_p(ps, model):
    """Apply top-p (nucleus) filtering to probabilities."""
    if not hasattr(model, 'generation_config') or not hasattr(model.generation_config, 'top_p'):
        return ps

    top_p = model.generation_config.top_p
    if top_p is None or top_p >= 1.0:
        return ps

    sorted_probs, sorted_indices = torch.sort(ps, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mask tokens once the cumulative probability exceeds top_p; the one-position
    # shift keeps the first token that crosses the threshold.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the mask from sorted order back to the original vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    ps[indices_to_remove] = 0.0
    return ps / ps.sum(dim=-1, keepdim=True)
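
# Worked example (illustrative numbers): with ps = [0.5, 0.3, 0.15, 0.05] and
# top_p = 0.8, the sorted cumulative probabilities are [0.5, 0.8, 0.95, 1.0].
# Tokens are dropped only once the cumulative mass exceeds top_p, and the
# one-position shift keeps the first token that crosses the threshold, so only
# the 0.05 tail is removed:
#
#     [0.5, 0.3, 0.15, 0.0] / 0.95  ->  [0.526, 0.316, 0.158, 0.0]
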
def sampling_with_kvcache(model_kwargs, model, eos_token_ids, pad_token_id, bos_token_id, do_sample=True, max_new_tokens=20, temperature=1.0):
    """
    Sampling implementation with proper KV caching.

    Args:
        model_kwargs: Model keyword arguments (e.g. tokenized inputs)
        model: The language model
        eos_token_ids: Tensor of end-of-sequence token IDs
        pad_token_id: Padding token ID
        bos_token_id: Beginning-of-sequence token ID
        do_sample: Whether to sample (True) or pick the most probable token (False)
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Sampling temperature (higher = more random)

    Returns:
        Generated sequences, per-token log probabilities under the sampling
        distribution, and per-token log probabilities under the model distribution
    """
    model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
    batch_size, _ = input_ids.shape

    # Tracks which sequences are still generating, i.e. have not emitted an EOS token.
    active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)

    # Per-token log probabilities under the adjusted sampling distribution
    # (temperature / top-k / top-p) ...
    scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
    # ... and under the unmodified model distribution.
    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)

    for i in range(max_new_tokens):
        model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)

        model_ps = logits.softmax(-1)

        ps = (logits / temperature).softmax(-1)
        ps = _apply_top_k(ps, model)
        ps = _apply_top_p(ps, model)

        if do_sample:
            next_token_ids = torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
        else:
            next_token_ids = torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs + pad_token_id * ~active_seqs
        next_token_logprobs = ps.gather(-1, next_token_ids).log()
        next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()

        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        # Finished sequences contribute 0; masked_fill avoids the nan that
        # log(0) * 0 would produce when a filtered-out pad token is gathered.
        scores[:, i] = next_token_logprobs.masked_fill(~active_seqs, 0.0).squeeze(-1)
        logprobs[:, i] = next_token_model_logprobs.masked_fill(~active_seqs, 0.0).squeeze(-1)

        # Deactivate sequences that just emitted an EOS token.
        active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
        if active_seqs.sum() == 0:
            break
    return input_ids.detach().cpu(), scores[:, :i + 1], logprobs[:, :i + 1]
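
# Usage sketch for the return values: `scores` holds log-probabilities of the
# chosen tokens under the adjusted sampling distribution (temperature / top-k /
# top-p), while `logprobs` holds log-probabilities of the same tokens under the
# raw model distribution. A sequence-level model log-likelihood of the generated
# continuation can therefore be computed as
#
#     sequences, scores, logprobs = sampling_with_kvcache(...)
#     seq_logprob = logprobs.sum(dim=-1)   # shape (batch_size,)
#
# Positions after a sequence's EOS token contribute 0 to the sum.
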
def generate(model, **kwargs):
    """
    Sampling strategy - multinomial sampling with temperature and optional top-k/top-p filtering.
    Simple implementation with proper KV caching support.

    Args:
        model: The language model
        **kwargs: Model keyword arguments from the tokenizer (e.g. input_ids,
            attention_mask) plus optional generation arguments: max_new_tokens,
            do_sample, temperature, eos_token_ids, pad_token_id, bos_token_id,
            and return_dict_in_generate. top_k and top_p are read from
            model.generation_config.

    Returns:
        Generated token IDs, or a dict with sequences, scores, and logprobs
        when return_dict_in_generate=True
    """
    generation_config = model.generation_config
    # Pop generation arguments out of kwargs so that only model inputs
    # (e.g. input_ids, attention_mask) are forwarded to the model.
    max_new_tokens = kwargs.pop('max_new_tokens', generation_config.max_new_tokens)
    max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
    do_sample = kwargs.pop('do_sample', True)
    eos_token_ids = kwargs.pop('eos_token_ids', generation_config.eos_token_id)
    if eos_token_ids is None:
        raise ValueError("Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument.")
    eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
    if eos_token_ids.ndim == 0:
        eos_token_ids = eos_token_ids.unsqueeze(0)

    pad_token_id = kwargs.pop(
        'pad_token_id',
        generation_config.pad_token_id if generation_config.pad_token_id is not None else eos_token_ids[0]
    )
    bos_token_id = kwargs.pop('bos_token_id', generation_config.bos_token_id)
    if bos_token_id is None:
        raise ValueError("Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument.")
    temperature = kwargs.pop('temperature', 1.0)
    return_dict = kwargs.pop('return_dict_in_generate', False)

    generated_ids, scores, logprobs = sampling_with_kvcache(
        model_kwargs=kwargs,
        model=model,
        eos_token_ids=eos_token_ids,
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )

    if return_dict:
        return {
            "sequences": generated_ids,
            "scores": scores,
            "logprobs": logprobs,
        }
    else:
        return generated_ids
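
# Example usage: a minimal sketch, assuming a Hugging Face causal LM is available
# (the "gpt2" checkpoint, prompt, and settings below are illustrative, not part of
# the implementation above).
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer(["The quick brown fox"], return_tensors="pt")
    out = generate(
        model,
        **inputs,
        max_new_tokens=20,
        temperature=0.8,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out["sequences"], skip_special_tokens=True))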