# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class SeqGenerationHead(BaseModule):
| """Generation head for multi-modal pre-trained task, adopted by BLIP. | |
| Normally used for generation task. | |
| Args: | |
| decoder (dict): Decoder for blip generation head. | |
| init_cfg (dict, optional): the config to control the initialization. | |
| Defaults to None. | |
| """ | |

    def __init__(
        self,
        decoder: dict,
        ignore_index: int = -100,
        loss: dict = dict(type='LabelSmoothLoss', label_smooth_val=0.1),
        init_cfg: Optional[dict] = None,
    ) -> None:
        super(SeqGenerationHead, self).__init__(init_cfg=init_cfg)
        self.decoder = MODELS.build(decoder)
        self.loss_fn = MODELS.build(loss)
        self.ignore_index = ignore_index

    def forward(self, input_ids: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                encoder_attention_mask: torch.Tensor, labels: torch.Tensor):
        """Forward to get decoder output.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Attention mask of the
                image embedding hidden states.
            labels (torch.Tensor): Decoder targets for calculating the loss.

        Returns:
            dict[str, Tensor]: A dictionary of decoder outputs.
        """
        decoder_out = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
            return_dict=True,
        )
        return decoder_out
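
    # Note: with ``return_dict=True`` the decoder is expected to return a
    # HuggingFace-style, dict-like output exposing at least ``logits``;
    # ``loss`` below relies on that key.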

    def loss(self, input_ids, encoder_hidden_states, encoder_attention_mask,
             labels):
        """Calculate losses from the extracted features.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Attention mask of the
                image embedding hidden states.
            labels (torch.Tensor): Decoder targets for calculating the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        decoder_out = self(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
        )
        prediction_scores = decoder_out['logits']
        # we are doing next-token prediction;
        # shift prediction scores and input ids by one
        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
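        # E.g. with logits of shape (B, L, V) and labels of shape (B, L):
        # the scores keep positions 0..L-2 (each predicting the next token)
        # and the labels keep positions 1..L-1 (the tokens to be predicted).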
        vocab_size = prediction_scores.shape[-1]

        # mask ignored index
        if (labels == self.ignore_index).any():
            labels = labels.view(-1).clone()
            ignore_mask = (labels == self.ignore_index)
            labels.masked_fill_(ignore_mask, 0)
            weight = torch.logical_not(ignore_mask)
            avg_factor = max(weight.sum(), 1)
        else:
            weight = None
            avg_factor = labels.size(0)
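        # Worked example with hypothetical numbers: shifted labels
        # [[5, -100], [7, 9]] are flattened to [5, -100, 7, 9], the ignored
        # slot is zeroed to [5, 0, 7, 9] with weight [1, 0, 1, 1] and
        # avg_factor 3, so the ignored position contributes nothing.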
        lm_loss = self.loss_fn(
            shifted_prediction_scores.view(-1, vocab_size),
            labels,
            weight=weight,
            avg_factor=avg_factor,
        )
        losses = {
            'seq_gen_lm_loss': lm_loss,
        }
        return losses

    def predict(self,
                input_ids,
                encoder_hidden_states,
                sep_token_id,
                pad_token_id,
                use_nucleus_sampling=False,
                num_beams=3,
                max_length=20,
                min_length=2,
                top_p=0.9,
                repetition_penalty=1.0,
                **kwargs):
| """Decoder prediction method. | |
| Args: | |
| input_ids (torch.Tensor): The tokenized input text tensor. | |
| encoder_hidden_states (torch.Tensor): Hidden states from image | |
| embeddings. | |
| sep_token_id (int): Tokenid of separation token. | |
| pad_token_id (int): Tokenid of pad token. | |
| use_nucleus_sampling (bool): Whether to use nucleus sampling in | |
| prediction. Defaults to False. | |
| num_beams (int): Number of beams used in predition. | |
| Defaults to 3. | |
| max_length (int): Max length of generated text in predition. | |
| Defaults to 20. | |
| min_length (int): Min length of generated text in predition. | |
| Defaults to 20. | |
| top_p (float): | |
| If < 1.0, only keep the top tokens with cumulative probability | |
| >= top_p (nucleus filtering). Defaults to 0.9. | |
| repetition_penalty (float): The parameter for repetition penalty. | |
| Defaults to 1.0. | |
| **kwarg: Other arguments that might used in generation. | |
| Returns: | |
| dict[str, Tensor]: a dictionary of generation outputs. | |
| """ | |
        device = encoder_hidden_states.device
        # TODO: In old versions of transformers, an additional
        # repeat_interleave of the hidden states would need to be added here.

        # the image embeddings carry no padding, so every position is attended
        image_atts = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long).to(device)

        model_kwargs = {
            'encoder_hidden_states': encoder_hidden_states,
            'encoder_attention_mask': image_atts,
        }
        model_kwargs.update(kwargs)

        if use_nucleus_sampling:
            # nucleus sampling
            outputs = self.decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                eos_token_id=sep_token_id,
                pad_token_id=pad_token_id,
                # the sampling branch fixes the repetition penalty to 1.1,
                # following the reference BLIP implementation
                repetition_penalty=1.1,
                **model_kwargs)
        else:
            # beam search
            outputs = self.decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                eos_token_id=sep_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=repetition_penalty,
                **model_kwargs)

        return outputs
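

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module). To stay
# self-contained it wires the head to a tiny stub decoder rather than a real
# BLIP text decoder such as ``XBertLMHeadDecoder``; the stub, its config
# values and all tensor shapes below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    @MODELS.register_module()
    class StubDecoder(nn.Module):
        """Hypothetical stand-in for a BLIP-style text decoder."""

        def __init__(self, vocab_size=100, hidden_size=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size)

        def forward(self, input_ids, encoder_hidden_states,
                    encoder_attention_mask, labels=None, return_dict=True):
            # A real decoder cross-attends to the image hidden states; the
            # stub simply projects token embeddings to vocabulary logits.
            return {'logits': self.lm_head(self.embed(input_ids))}

    head = SeqGenerationHead(
        decoder=dict(type='StubDecoder', vocab_size=100, hidden_size=32),
        # num_classes is passed explicitly here so LabelSmoothLoss can
        # one-hot the integer targets; assumed to match the vocab size.
        loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1,
                  num_classes=100))

    input_ids = torch.randint(0, 100, (2, 8))   # (batch, text length)
    labels = input_ids.clone()
    labels[:, :2] = -100                        # ignore the prompt tokens
    image_feats = torch.rand(2, 5, 32)          # (batch, image tokens, dim)
    image_mask = torch.ones(2, 5, dtype=torch.long)

    losses = head.loss(input_ids, image_feats, image_mask, labels)
    print(losses['seq_gen_lm_loss'])            # a scalar loss tensor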