from transformers.models.bart import BartForConditionalGeneration
import torch
from transformers.generation_beam_search import BeamScorer
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple, Union, Dict, Any
from transformers.generation_logits_process import LogitsProcessorList
from transformers.generation_utils import BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput
from torch.nn import functional as F
from transformers.file_utils import ModelOutput
import torch.nn

BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]

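# `BartForConditionalGeneration_GroupBeam` overrides `beam_search` and
# `group_beam_search` so that, at every decoding step, each sample's candidate
# token scores are replaced by the scores averaged over the whole batch
# ("group" scores).  Tokens listed in model_kwargs["decoder_ori_input_ids"]
# are exempt and keep their per-sample scores.  `decoder_ori_input_ids` is
# expected to contain one sequence of token ids per sample in the batch.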
class BartForConditionalGeneration_GroupBeam(BartForConditionalGeneration):
    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
| r""" | |
| Generates sequences for models with a language modeling head using beam search decoding. | |
| Parameters: | |
| input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
| The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
| :obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
| beam_scorer (:obj:`BeamScorer`): | |
| An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are | |
| constructed, stored and sorted during generation. For more information, the documentation of | |
| :class:`~transformers.BeamScorer` should be read. | |
| logits_processor (:obj:`LogitsProcessorList`, `optional`): | |
| An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
| :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling | |
| head applied at each generation step. | |
| max_length (:obj:`int`, `optional`, defaults to 20): | |
| The maximum length of the sequence to be generated. | |
| pad_token_id (:obj:`int`, `optional`): | |
| The id of the `padding` token. | |
| eos_token_id (:obj:`int`, `optional`): | |
| The id of the `end-of-sequence` token. | |
| output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
| returned tensors for more details. | |
| output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return trhe hidden states of all layers. See ``hidden_states`` under returned tensors | |
| for more details. | |
| output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
| return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
| model_kwargs: | |
| Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If | |
| model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. | |
| Return: | |
| :class:`~transformers.generation_utilsBeamSearchDecoderOnlyOutput`, | |
| :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or obj:`torch.LongTensor`: A | |
| :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a | |
| :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if | |
| ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a | |
| :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if | |
| ``model.config.is_encoder_decoder=True``. | |
| Examples:: | |
| >>> from transformers import ( | |
| ... AutoTokenizer, | |
| ... AutoModelForSeq2SeqLM, | |
| ... LogitsProcessorList, | |
| ... MinLengthLogitsProcessor, | |
| ... BeamSearchScorer, | |
| ... ) | |
| >>> import torch | |
| >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") | |
| >>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large") | |
| >>> encoder_input_str = "translate English to German: How old are you?" | |
| >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids | |
| >>> # lets run beam search using 3 beams | |
| >>> num_beams = 3 | |
| >>> # define decoder start token ids | |
| >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) | |
| >>> input_ids = input_ids * model.config.decoder_start_token_id | |
| >>> # add encoder_outputs to model keyword arguments | |
| >>> model_kwargs = { | |
| ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) | |
| ... } | |
| >>> # instantiate beam scorer | |
| >>> beam_scorer = BeamSearchScorer( | |
| ... batch_size=1, | |
| ... max_length=model.config.max_length, | |
| ... num_beams=num_beams, | |
| ... device=model.device, | |
| ... ) | |
| >>> # instantiate logits processors | |
| >>> logits_processor = LogitsProcessorList([ | |
| ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), | |
| ... ]) | |
| >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) | |
| >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
| """ | |
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # adjust tokens for Bart, *e.g.*
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )

            next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

            next_token_scores = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            # m = torch.nn.LayerNorm(num_beams * vocab_size)
            # next_token_scores = m(next_token_scores)
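            # group scoring: replace each sample's candidate scores with the scores
            # averaged over the whole batch, then restore the per-sample scores for the
            # tokens listed in model_kwargs['decoder_ori_input_ids'] so that those tokens
            # compete with their own (non-averaged) scores in the top-k selection below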
            next_token_scores_group = torch.sum(next_token_scores, dim=0, keepdim=True).expand(
                batch_size, -1
            ) / batch_size
            for i in range(next_token_scores.size(0)):
                # tmin = torch.min(next_token_scores_group[i])
                # for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                #     next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin
                for t in model_kwargs['decoder_ori_input_ids'][i]:
                    for j in range(num_beams):
                        # if t not in input_ids[i] or t == 1:
                        next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
            next_token_scores, next_tokens = torch.topk(
                next_token_scores_group, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # next_token_scores_group = next_token_scores_group.expand(batch_size, -1)
            # next_tokens_group = next_tokens_group.expand(batch_size, -1)
            # next_token_scores, next_tokens = torch.topk(
            #     next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            # )
            # for i in range(next_token_scores.size(0)):
            #     j1 = 0
            #     for j in range(next_token_scores.size(1)):
            #         if next_tokens[i][j] not in model_kwargs['decoder_ori_input_ids'][i]:
            #             next_tokens[i][j] = next_tokens_group[i][j1]
            #             j1 += 1
            # next_token_scores = next_token_scores_group
            # del next_token_scores_group, next_tokens_group

            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if beam_scorer.is_done:
                break
        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ):
| r""" | |
| Generates sequences for models with a language modeling head using beam search decoding. | |
| Parameters: | |
| input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
| The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty | |
| :obj:`torch.LongTensor` of shape :obj:`(1,)`. | |
| beam_scorer (:obj:`BeamScorer`): | |
| An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are | |
| constructed, stored and sorted during generation. For more information, the documentation of | |
| :class:`~transformers.BeamScorer` should be read. | |
| logits_processor (:obj:`LogitsProcessorList`, `optional`): | |
| An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from | |
| :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling | |
| head applied at each generation step. | |
| max_length (:obj:`int`, `optional`, defaults to 20): | |
| The maximum length of the sequence to be generated. | |
| pad_token_id (:obj:`int`, `optional`): | |
| The id of the `padding` token. | |
| eos_token_id (:obj:`int`, `optional`): | |
| The id of the `end-of-sequence` token. | |
| output_attentions (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under | |
| returned tensors for more details. | |
| output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return trhe hidden states of all layers. See ``hidden_states`` under returned tensors | |
| for more details. | |
| output_scores (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. | |
| return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): | |
| Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |
| model_kwargs: | |
| Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If | |
| model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. | |
| Return: | |
| :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`, | |
| :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or obj:`torch.LongTensor`: A | |
| :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a | |
| :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if | |
| :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if | |
| ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a | |
| :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if | |
| ``model.config.is_encoder_decoder=True``. | |
| Examples:: | |
| >>> from transformers import ( | |
| ... AutoTokenizer, | |
| ... AutoModelForSeq2SeqLM, | |
| ... LogitsProcessorList, | |
| ... MinLengthLogitsProcessor, | |
| ... HammingDiversityLogitsProcessor, | |
| ... BeamSearchScorer, | |
| ... ) | |
| >>> import torch | |
| >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") | |
| >>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large") | |
| >>> encoder_input_str = "translate English to German: How old are you?" | |
| >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids | |
| >>> # lets run diverse beam search using 6 beams | |
| >>> num_beams = 6 | |
| >>> # define decoder start token ids | |
| >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) | |
| >>> input_ids = input_ids * model.config.decoder_start_token_id | |
| >>> # add encoder_outputs to model keyword arguments | |
| >>> model_kwargs = { | |
| ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) | |
| ... } | |
| >>> # instantiate beam scorer | |
| >>> beam_scorer = BeamSearchScorer( | |
| ... batch_size=1, | |
| ... max_length=model.config.max_length, | |
| ... num_beams=num_beams, | |
| ... device=model.device, | |
| ... num_beam_groups=3 | |
| ... ) | |
| >>> # instantiate logits processors | |
| >>> logits_processor = LogitsProcessorList([ | |
| ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), | |
| ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), | |
| ... ]) | |
| >>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) | |
| >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) | |
| """ | |
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that
        # the beams in the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))
        while cur_len < max_length:
            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []
                if output_scores:
                    processed_score = torch.zeros_like(outputs.logits[:, -1, :]).half()  # .float()
                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # adjust tokens for Bart, *e.g.*
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

                next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
                    next_token_scores
                )
                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores.half()  # .float()

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
                ###
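                # same group-scoring step as in `beam_search`, applied within the current
                # beam group: batch-average the candidate scores, but keep the per-sample
                # scores for the tokens listed in model_kwargs['decoder_ori_input_ids']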
                next_token_scores_group = torch.sum(next_token_scores, dim=0, keepdim=True).expand(
                    batch_size, -1
                ) / batch_size
                for i in range(next_token_scores.size(0)):
                    # tmin = torch.min(next_token_scores_group[i])
                    # for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                    #     next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin
                    for t in model_kwargs['decoder_ori_input_ids'][i]:
                        for j in range(group_size):
                            # if t not in input_ids[i] or t == 1:
                            next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores_group, 2 * group_size, dim=1, largest=True, sorted=True
                )
                ###
                # next_token_scores, next_tokens = torch.topk(
                #     next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                # )

                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size
                # stateless
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
                )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1
            if beam_scorer.is_done:
                break
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
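

if __name__ == "__main__":
    # Minimal usage sketch, mirroring the docstring example above but driving the
    # overridden `beam_search` directly.  The checkpoint name, the input sentences
    # and the toy `decoder_ori_input_ids` below are illustrative assumptions only.
    from transformers import BartTokenizer, BeamSearchScorer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration_GroupBeam.from_pretrained("facebook/bart-large")
    model.eval()

    src = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown fox leaps over a sleeping dog.",
    ]
    enc = tokenizer(src, return_tensors="pt", padding=True)

    num_beams = 3
    batch_size = enc.input_ids.shape[0]

    # every beam starts from the decoder start token
    input_ids = torch.full(
        (batch_size * num_beams, 1), model.config.decoder_start_token_id, dtype=torch.long
    )

    model_kwargs = {
        "attention_mask": enc.attention_mask.repeat_interleave(num_beams, dim=0),
        "encoder_outputs": model.get_encoder()(
            enc.input_ids.repeat_interleave(num_beams, dim=0),
            attention_mask=enc.attention_mask.repeat_interleave(num_beams, dim=0),
            return_dict=True,
        ),
        # one list of token ids per sample; these tokens keep their per-sample
        # scores instead of the batch-averaged group scores
        "decoder_ori_input_ids": tokenizer(src, add_special_tokens=False).input_ids,
    }

    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        max_length=model.config.max_length,
        num_beams=num_beams,
        device=model.device,
    )

    with torch.no_grad():
        outputs = model.beam_search(input_ids, beam_scorer, **model_kwargs)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))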