from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import (
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
    GenerateNonBeamOutput,
)
from transformers.models.opt.modeling_opt import OPTConfig, OPTDecoder, OPTForCausalLM, OPTModel
from transformers.utils import logging

logger = logging.get_logger(__name__)
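

# Subclassing OPTConfig gives the model its own `model_type` ("mesh_opt") so it
# can be registered with the Auto* factories at the bottom of this module while
# reusing all OPT hyperparameters unchanged.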
class BBoxOPTConfig(OPTConfig):
    model_type = "mesh_opt"


class BBoxOPTDecoder(OPTDecoder):
    config_class = BBoxOPTConfig


class BBoxOPTModel(OPTModel):
    config_class = BBoxOPTConfig

    def __init__(self, config: BBoxOPTConfig):
        # Skip OPTModel.__init__ so the stock OPTDecoder is never built;
        # initialize from its parent (the pretrained-model base class) instead
        # and attach the custom decoder directly.
        super(OPTModel, self).__init__(config)
        self.decoder = BBoxOPTDecoder(config)

        self.post_init()


class BBoxOPT(OPTForCausalLM):
    config_class = BBoxOPTConfig

    def __init__(self, config: BBoxOPTConfig):
        # Same trick as above: bypass OPTForCausalLM.__init__ to avoid creating
        # a stock OPTModel, then wire in the custom backbone and the LM head.
        super(OPTForCausalLM, self).__init__(config)
        self.model = BBoxOPTModel(config)

        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        self.post_init()
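
    # The method below mirrors the stock `GenerationMixin._sample` loop;
    # overriding it here provides a single place to customize greedy and
    # multinomial decoding for this model.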
    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling**
        (or greedy decoding when `do_sample=False`) and can be used for text-decoder, text-to-text, speech-to-text,
        and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of classes derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of classes derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until `max_length` (needed for ZeRO stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model-specific kwargs to be forwarded to the `forward` function of the model. If the
                model is an encoder-decoder model, the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
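
        # Per-step output buffers; only materialized when the caller asked for
        # them via `return_dict_in_generate`.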
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # For encoder-decoder models, keep the encoder outputs so they can be
        # attached to the returned generation object.
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # Track which sequences in the batch are still generating.
        batch_size, cur_len = input_ids.shape
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
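
        # Autoregressive decoding loop: one forward pass per generated token,
        # until every sequence is finished or `max_length` is reached.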
        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ) and cur_len < max_length:
            # Prepare model inputs and variable output controls (some models
            # won't accept all output controls).
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            # Forward pass to get next-token logits.
            outputs = self(**model_inputs, return_dict=True)

            # Under synced GPUs (ZeRO stage 3), finished peers still run the
            # forward pass but skip the rest of the step.
            if synced_gpus and this_peer_finished:
                continue

            # Clone to avoid keeping a reference to the full logits tensor,
            # which can be very large for the first forward pass.
            next_token_logits = outputs.logits.clone()[:, -1, :].float()

            # Apply the logits processors to the raw scores.
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store per-step outputs when requested.
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # Token selection: multinomial sampling or greedy argmax.
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # Finished sequences keep emitting the padding token.
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # Append the chosen tokens and update the cache/kwargs for the next step.
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # Free this step's outputs before the next forward pass to reduce
            # peak memory.
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids


# Register the custom config/model pair so AutoConfig and AutoModelForCausalLM
# resolve the "mesh_opt" model type to the classes defined above.
AutoConfig.register("mesh_opt", BBoxOPTConfig)
AutoModelForCausalLM.register(BBoxOPTConfig, BBoxOPT)
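

# Minimal usage sketch (illustrative addition, not part of the original
# pipeline). The tiny config values below are placeholders chosen only to keep
# the smoke test fast; any valid OPT hyperparameters would do.
if __name__ == "__main__":
    config = BBoxOPTConfig(
        vocab_size=512,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        ffn_dim=128,
        word_embed_proj_dim=64,
        max_position_embeddings=64,
    )
    # The registrations above let the Auto* factory build the custom class.
    model = AutoModelForCausalLM.from_config(config)
    assert isinstance(model, BBoxOPT)

    prompt = torch.randint(0, config.vocab_size, (1, 8))
    generated = model.generate(prompt, max_length=16, do_sample=True)
    print(generated.shape)  # e.g. torch.Size([1, 16]); shorter if EOS is sampled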