Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from mmengine.model import BaseModule | |
| from mmpretrain.models.utils.box_utils import (box_cxcywh_to_xyxy, | |
| generalized_box_iou) | |
| from mmpretrain.registry import MODELS, TOKENIZER | |
class GroundingHead(BaseModule):
    """Bbox coordinate generation head for multi-modal pre-trained task,
    adapted by BLIP. Normally used for visual grounding.

    Box regression is framed as token generation: each of the four box
    coordinates is quantized into one of ``image_res + 1`` bins, and each
    bin is mapped onto a reserved ``[unusedNNN]`` token in the tokenizer
    vocabulary. The decoder then predicts the four coordinate tokens
    autoregressively.

    Args:
        decoder (dict, optional): Config used to build the text decoder via
            ``MODELS.build``. Defaults to None (no decoder is built).
        tokenizer (dict | object, optional): Tokenizer config built via
            ``TOKENIZER.build``, or an already-built tokenizer instance.
            Must provide ``convert_tokens_to_ids`` and ``__len__``.
        box_l1_loss_coeff (float): Weight of the L1 box-distance term used
            to re-weight the sequence loss. Defaults to 4.0.
        box_giou_loss_coeff (float): Weight of the GIoU box-distance term
            used to re-weight the sequence loss. Defaults to 2.0.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(
        self,
        decoder: dict = None,
        tokenizer: dict = None,
        box_l1_loss_coeff=4.0,
        box_giou_loss_coeff=2.0,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super(GroundingHead, self).__init__(init_cfg=init_cfg)
        ''' init the decoder from med_config'''
        self.decoder = None
        if decoder:
            self.decoder = MODELS.build(decoder)
        # Per-token (unreduced) loss so it can be re-weighted per box in
        # loss() before the final mean.
        self.loss_fn = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=-100)
        self.box_l1_loss_coeff = box_l1_loss_coeff
        self.box_giou_loss_coeff = box_giou_loss_coeff
        if isinstance(tokenizer, dict):
            self.tokenizer = TOKENIZER.build(tokenizer)
        else:
            self.tokenizer = tokenizer
        # NOTE(review): __init__ requires a usable tokenizer — the lines
        # below dereference it unconditionally, so tokenizer=None crashes.
        # Number of quantization bins per coordinate; the image_res + 1
        # reserved tokens cover the inclusive bin range [0, image_res].
        self.image_res = 640
        # '[unused339]' is the sequence-start (prompt) token; the
        # following image_res + 1 reserved tokens encode the quantized
        # coordinate values.
        prefix_ids = torch.tensor(
            self.tokenizer.convert_tokens_to_ids(['[unused339]']))
        target_ids = torch.tensor(
            self.tokenizer.convert_tokens_to_ids(
                [f'[unused{340+_}]' for _ in range(self.image_res + 1)]))
        self.register_buffer('prefix_ids', prefix_ids)
        self.register_buffer('target_ids', target_ids)
        # Additive logit mask: 0 inside the coordinate-token range,
        # -10000 (effectively -inf after softmax) everywhere else, so the
        # decoder can only emit valid coordinate tokens.
        bbox_prob_mask = torch.zeros(len(self.tokenizer))
        bbox_prob_mask[self.target_ids[0]:self.target_ids[-1] + 1] = 1
        bbox_prob_mask = (1.0 - bbox_prob_mask) * -10000.0
        self.register_buffer('bbox_prob_mask', bbox_prob_mask)
        # Vocabulary id of coordinate bin 0; used to map token id <-> bin.
        self.bin_start_idx = self.target_ids[0]

    def forward(self, text_embedding, text_embedding_mask,
                encoder_hidden_states, encoder_attention_mask):
        """Run the decoder on a learned localization prompt and regress a
        box from the prompt position's final hidden state.

        NOTE(review): ``self.prompt`` and ``self.box_head`` are not
        created in ``__init__`` within this file — presumably they are
        attached by a subclass or external code; confirm before calling.
        """
        # localize prompt token, text embedding
        # Concatenate image features and text features along the sequence
        # dimension to form the cross-attention memory.
        merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding],
                                     1)
        merge_att_mask = torch.cat(
            [encoder_attention_mask, text_embedding_mask], 1)
        # One copy of the prompt embedding per sample in the batch,
        # shaped (batch, 1, embed_dim).
        loc_prompt = self.prompt.weight.T
        loc_prompt = torch.repeat_interleave(loc_prompt,
                                             merge_att_mask.shape[0],
                                             0).unsqueeze(1)
        loc_prompt_mask = torch.ones(loc_prompt.shape[:-1]).long().to(
            loc_prompt.device)
        decoder_out = self.decoder(
            inputs_embeds=loc_prompt,
            attention_mask=loc_prompt_mask,
            encoder_hidden_states=merged_encode_hs,
            encoder_attention_mask=merge_att_mask,
            output_hidden_states=True,
            labels=None,
        )
        # Final-layer hidden state at the single prompt position.
        decoder_hs = decoder_out.hidden_states[-1][:, 0, :]
        box_pred = self.box_head(decoder_hs)
        return decoder_out, decoder_hs, box_pred

    def loss(self,
             text_embedding,
             text_embedding_mask,
             encoder_hidden_states,
             encoder_attention_mask,
             decoder_targets,
             return_scores=False):
        """Calculate the coordinate-token sequence loss, re-weighted by the
        L1 and GIoU distances between predicted and target boxes.

        Args:
            text_embedding (Tensor): Text features, concatenated after the
                image features along the sequence dimension.
            text_embedding_mask (Tensor): Attention mask for the text
                features.
            encoder_hidden_states (Tensor): Image encoder features.
            encoder_attention_mask (Tensor): Attention mask for the image
                features.
            decoder_targets (Tensor): Normalized target boxes in
                (cx, cy, w, h) format with values in [0, 1]; assumed shape
                (batch, 4) — TODO confirm against caller.
            return_scores (bool): Accepted but unused in this
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding],
                                     1)
        merge_att_mask = torch.cat(
            [encoder_attention_mask, text_embedding_mask], 1)
        # Quantize each normalized coordinate into a bin and shift it into
        # the reserved coordinate-token id range.
        answer_targets = (decoder_targets *
                          self.image_res).long() + self.bin_start_idx
        # Teacher forcing: decoder input is [prefix token, 4 target tokens].
        prefix_ids = torch.repeat_interleave(self.prefix_ids,
                                             merge_att_mask.shape[0],
                                             0).unsqueeze(-1)
        prefix_ids = torch.cat([prefix_ids, answer_targets], dim=1)
        answer_output = self.decoder(
            prefix_ids,
            encoder_hidden_states=merged_encode_hs,
            encoder_attention_mask=merge_att_mask,
            labels=None,
            return_dict=True,
        )
        # Mask out all non-coordinate vocabulary entries from the logits.
        prob_mask = self.bbox_prob_mask.view(1, 1,
                                             self.bbox_prob_mask.shape[-1])
        prediction_scores = answer_output.logits + prob_mask
        # Standard causal-LM shift: position t predicts token t+1.
        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
        labels = prefix_ids[:, 1:].contiguous()
        vocab_size = len(self.tokenizer)
        loss_seq_init = self.loss_fn(
            shifted_prediction_scores.view(-1, vocab_size), labels.view(-1))
        # Decode the argmax tokens back into normalized coordinates to
        # compute box-quality weights (no gradient through this path).
        with torch.no_grad():
            pred_box = (torch.argmax(
                prediction_scores[:, :-1, :].contiguous(), dim=-1) -
                        self.bin_start_idx) / self.image_res
            weight_bbox = F.l1_loss(
                pred_box, decoder_targets, reduction='none').clamp(
                    0, 5) * self.box_l1_loss_coeff
            weight_giou = (1 - torch.diag(
                generalized_box_iou(
                    box_cxcywh_to_xyxy(pred_box),
                    box_cxcywh_to_xyxy(decoder_targets)))
                           ) * self.box_giou_loss_coeff
        bs = text_embedding.shape[0]
        loss_seq = loss_seq_init[:].view(bs, -1, 4)
        # NOTE(review): with (bs, 1, 4) * (bs, 4), PyTorch broadcasting
        # yields a (bs, bs, 4) tensor, mixing weights across batch
        # samples before the mean — confirm this is intended upstream.
        loss_seq = loss_seq * weight_bbox
        loss_seq = loss_seq * weight_giou.unsqueeze(1)
        loss_seq = loss_seq.mean()
        losses = {
            'loss_seq': loss_seq,
            'loss_seq_init': loss_seq_init.mean(),
            'loss': loss_seq,
            'box_l1': weight_bbox.mean(-1).mean().detach(),
            'box_giou': weight_giou.mean().detach()
        }
        return losses

    def predict(
        self,
        text_embedding,
        text_embedding_mask,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Generates the bbox coordinates at inference time.

        Greedily decodes 4 coordinate tokens (one per box coordinate),
        then converts them to normalized xyxy boxes via process_bbox().
        """
        merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding],
                                     1)
        merge_att_mask = torch.cat(
            [encoder_attention_mask, text_embedding_mask], 1)
        prefix_ids = torch.repeat_interleave(self.prefix_ids,
                                             merge_att_mask.shape[0],
                                             0).unsqueeze(-1)
        # Autoregressive greedy decoding: one coordinate token per step.
        # NOTE(review): the full prefix is re-encoded each step (no KV
        # cache is passed); acceptable for 4 steps.
        for _ in range(4):
            decoder_output = self.decoder(
                prefix_ids,
                encoder_hidden_states=merged_encode_hs,
                encoder_attention_mask=merge_att_mask,
                labels=None,
                return_dict=True,
            )
            # Restrict logits to the coordinate-token range, then append
            # the argmax token of the last position to the prefix.
            prob_mask = self.bbox_prob_mask.view(1, 1,
                                                 self.bbox_prob_mask.shape[-1])
            prediction_scores = decoder_output.logits + prob_mask
            prefix_ids = torch.cat([
                prefix_ids,
                torch.argmax(prediction_scores[:, -1, :], dim=-1).unsqueeze(1)
            ],
                                   dim=1)
        # Drop the prompt token; convert cxcywh bins (0-1) to xyxy (0-1).
        pred_box = self.process_bbox(prefix_ids[:, 1:])  # xywh 0-1 to xyxy 0-1
        return pred_box

    def process_bbox(self, bbox):
        """Convert coordinate-token ids to normalized xyxy boxes.

        Args:
            bbox (Tensor): Token ids of shape (batch, 4), each in the
                reserved coordinate-token range.

        Returns:
            Tensor: Boxes in xyxy format, clipped to [0, 1].
        """
        # Token id -> bin index -> normalized coordinate in [0, 1].
        bbox = bbox - self.bin_start_idx
        bbox = torch.true_divide(bbox, self.image_res)
        bbox = box_cxcywh_to_xyxy(bbox)
        bbox = torch.clip(bbox, 0, 1)
        # NOTE(review): this assert is always true right after the clip
        # above; it only guards against future edits removing the clip.
        assert torch.all(bbox <= 1)
        return bbox