# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List

import torchvision.transforms.functional as vis_F
from torchvision.transforms import InterpolationMode
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from torchvision.ops import roi_align
from transformers import (
    AutoTokenizer,
    BertModel,
    BertTokenizer,
    RobertaModel,
    RobertaTokenizerFast,
)

from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import (
    NestedTensor,
    accuracy,
    get_world_size,
    interpolate,
    inverse_sigmoid,
    is_dist_avail_and_initialized,
    nested_tensor_from_tensor_list,
)
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.util.visualizer import COCOVisualizer
from groundingdino.util.vl_utils import create_positive_map_from_span

from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
    BertModelWarper,
    generate_masks_with_special_tokens,
    generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
from .matcher import build_matcher

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from groundingdino.util.visualizer import renorm


def numpy_2_cv2(np_img):
    if np.min(np_img) < 0:
        raise Exception("image min is less than 0. Img min: " + str(np.min(np_img)))
    if np.max(np_img) > 1:
        raise Exception("image max is greater than 1. Img max: " + str(np.max(np_img)))
    np_img = (np_img * 255).astype(np.uint8)
    # Need to somehow ensure image is in RGB format. Note this line shows up in SAM demo:
    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    cv2_image = np.asarray(np_img)
    return cv2_image


def vis_exemps(image, exemp, f_name):
    plt.imshow(image)
    plt.gca().add_patch(
        Rectangle(
            (exemp[0], exemp[1]),
            exemp[2] - exemp[0],
            exemp[3] - exemp[1],
            edgecolor="red",
            facecolor="none",
            lw=1,
        )
    )
    plt.savefig(f_name)
    plt.close()
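
# Usage sketch (hypothetical values): vis_exemps(img, [10.0, 20.0, 50.0, 80.0], "exemplar_0.jpg")
# draws the (x0, y0, x1, y1) exemplar box, in pixel coordinates of `img` (an HxWx3 array),
# and writes the figure to disk. It appears to be used only for debugging the exemplar crops
# in the cropped branch of GroundingDINO.forward below.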


class GroundingDINO(nn.Module):
    """This is the Cross-Attention Detector module that performs object detection"""

    def __init__(
        self,
        backbone,
        transformer,
        num_queries,
        aux_loss=False,
        iter_update=False,
        query_dim=2,
        num_feature_levels=1,
        nheads=8,
        # two stage
        two_stage_type="no",  # ['no', 'standard']
        dec_pred_bbox_embed_share=True,
        two_stage_class_embed_share=True,
        two_stage_bbox_embed_share=True,
        num_patterns=0,
        dn_number=100,
        dn_box_noise_scale=0.4,
        dn_label_noise_ratio=0.5,
        dn_labelbook_size=100,
        text_encoder_type="bert-base-uncased",
        sub_sentence_present=True,
        max_text_len=256,
    ):
        """Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_queries: number of object queries, i.e. detection slots. This is the maximal number
                         of objects the model can detect in a single image. For COCO, we recommend
                         100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.hidden_dim = hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads
        self.max_text_len = max_text_len
        self.sub_sentence_present = sub_sentence_present

        # setting query dim
        self.query_dim = query_dim
        assert query_dim == 4

        # visual exemplar cropping
        self.feature_map_proj = nn.Conv2d((256 + 512 + 1024), hidden_dim, kernel_size=1)

        # for dn training
        self.num_patterns = num_patterns
        self.dn_number = dn_number
        self.dn_box_noise_scale = dn_box_noise_scale
        self.dn_label_noise_ratio = dn_label_noise_ratio
        self.dn_labelbook_size = dn_labelbook_size

        # bert
        self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
        self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
        self.bert.pooler.dense.weight.requires_grad_(False)
        self.bert.pooler.dense.bias.requires_grad_(False)
        self.bert = BertModelWarper(bert_model=self.bert)

        self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
        nn.init.constant_(self.feat_map.bias.data, 0)
        nn.init.xavier_uniform_(self.feat_map.weight.data)
        # freeze

        # special tokens
        self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
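        # These ids come from the tokenizer's vocabulary for [CLS], [SEP], "." and "?".
        # They delimit phrases in the caption and are used both when building the
        # phrase-level self-attention masks and when deciding where to splice the
        # visual exemplar tokens (see add_exemplar_tokens below).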

        # prepare input projection layers
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            assert (
                two_stage_type == "no"
            ), "two_stage_type should be no if num_feature_levels=1 !!!"
            self.input_proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                ]
            )

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.box_pred_damping = box_pred_damping = None

        self.iter_update = iter_update
        assert iter_update, "Why not iter_update?"

        # prepare pred layers
        self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
        # prepare class & box embed
        _class_embed = ContrastiveEmbed()

        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
        else:
            box_embed_layerlist = [
                copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
            ]
        class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
        self.bbox_embed = nn.ModuleList(box_embed_layerlist)
        self.class_embed = nn.ModuleList(class_embed_layerlist)
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed

        # two stage
        self.two_stage_type = two_stage_type
        assert two_stage_type in [
            "no",
            "standard",
        ], "unknown param {} of two_stage_type".format(two_stage_type)
        if two_stage_type != "no":
            if two_stage_bbox_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)

            if two_stage_class_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)

        self.refpoint_embed = None

        self._reset_parameters()

    def _reset_parameters(self):
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def init_ref_points(self, use_num_queries):
        self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)

    def add_exemplar_tokens(self, tokenized, text_dict, exemplar_tokens, labels):
        input_ids = tokenized["input_ids"]
        device = input_ids.device
        new_input_ids = []
        encoded_text = text_dict["encoded_text"]
        new_encoded_text = []
        text_token_mask = text_dict["text_token_mask"]
        new_text_token_mask = []
        position_ids = text_dict["position_ids"]
        text_self_attention_masks = text_dict["text_self_attention_masks"]

        for sample_ind in range(len(labels)):
            label = labels[sample_ind][0]
            exemplars = exemplar_tokens[sample_ind]
            label_count = -1
            assert len(input_ids[sample_ind]) == len(position_ids[sample_ind])
            for token_ind in range(len(input_ids[sample_ind])):
                input_id = input_ids[sample_ind][token_ind]
                if (input_id not in self.specical_tokens) and (
                    token_ind == 0
                    or (input_ids[sample_ind][token_ind - 1] in self.specical_tokens)
                ):
                    label_count += 1
                    if label_count == label:
                        # Get the index where to insert the exemplar tokens.
                        ind_to_insert_exemplar = token_ind
                        while (
                            input_ids[sample_ind][ind_to_insert_exemplar]
                            not in self.specical_tokens
                        ):
                            ind_to_insert_exemplar += 1
                        break
            # Handle no text case.
            if label_count == -1:
                ind_to_insert_exemplar = 1
            # * token indicates exemplar.
            new_input_ids.append(
                torch.cat(
                    [
                        input_ids[sample_ind][:ind_to_insert_exemplar],
                        torch.tensor([1008] * exemplars.shape[0]).to(device),
                        input_ids[sample_ind][ind_to_insert_exemplar:],
                    ]
                )
            )
            new_encoded_text.append(
                torch.cat(
                    [
                        encoded_text[sample_ind][:ind_to_insert_exemplar, :],
                        exemplars,
                        encoded_text[sample_ind][ind_to_insert_exemplar:, :],
                    ]
                )
            )
            new_text_token_mask.append(
                torch.full((len(new_input_ids[sample_ind]),), True).to(device)
            )

        tokenized["input_ids"] = torch.stack(new_input_ids)
        print(tokenized["input_ids"])
        (
            text_self_attention_masks,
            position_ids,
            _,
        ) = generate_masks_with_special_tokens_and_transfer_map(
            tokenized, self.specical_tokens, None
        )
        return {
            "encoded_text": torch.stack(new_encoded_text),
            "text_token_mask": torch.stack(new_text_token_mask),
            "position_ids": position_ids,
            "text_self_attention_masks": text_self_attention_masks,
        }
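
    # Illustrative sketch of what add_exemplar_tokens does (token ids are hypothetical):
    # for a caption tokenized as [CLS] <phrase tokens> . [SEP] and three exemplar boxes,
    # three copies of id 1008 (the "*" exemplar marker noted above) are spliced in just
    # before the special token ("." or [SEP]) that ends the phrase selected by `label`,
    # and the matching rows of `encoded_text` receive the ROI-pooled exemplar features,
    # so downstream attention treats exemplars as extra "text" tokens of that phrase.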

    def combine_features(self, features):
        (bs, c, h, w) = (
            features[0].decompose()[0].shape[-4],
            features[0].decompose()[0].shape[-3],
            features[0].decompose()[0].shape[-2],
            features[0].decompose()[0].shape[-1],
        )
        x = torch.cat(
            [
                F.interpolate(
                    feat.decompose()[0],
                    size=(h, w),
                    mode="bilinear",
                    align_corners=True,
                )
                for feat in features
            ],
            dim=1,
        )
        x = self.feature_map_proj(x)
        return x
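
    # combine_features resizes every backbone level to the spatial size of the first
    # (highest-resolution) level, concatenates them along the channel dimension, and
    # projects the result to hidden_dim with the 1x1 conv self.feature_map_proj. The
    # (256 + 512 + 1024) input channels of that conv assume the three feature levels
    # returned by the default Swin backbone configuration.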

    def forward(
        self,
        samples: NestedTensor,
        exemplar_images: NestedTensor,
        exemplars: List,
        labels,
        targets: List = None,
        cropped=False,
        orig_img=None,
        crop_width=0,
        crop_height=0,
        **kw,
    ):
        """The forward expects a NestedTensor, which consists of:
            - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
            - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

        It returns a dict with the following elements:
            - "pred_logits": the classification logits (including no-object) for all queries.
                             Shape = [batch_size x num_queries x num_classes]
            - "pred_boxes": the normalized box coordinates for all queries, represented as
                            (center_x, center_y, width, height). These values are normalized in [0, 1],
                            relative to the size of each individual image (disregarding possible padding).
                            See PostProcess for information on how to retrieve the unnormalized bounding box.
            - "aux_outputs": optional, only returned when auxiliary losses are activated. It is a list of
                             dictionaries containing the two above keys for each decoder layer.
        """
        if targets is None:
            captions = kw["captions"]
        else:
            captions = [t["caption"] for t in targets]

        # encoder texts
        tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
            samples.device
        )
        one_hot_token = tokenized

        (
            text_self_attention_masks,
            position_ids,
            cate_to_token_mask_list,
        ) = generate_masks_with_special_tokens_and_transfer_map(
            tokenized, self.specical_tokens, self.tokenizer
        )

        if text_self_attention_masks.shape[1] > self.max_text_len:
            text_self_attention_masks = text_self_attention_masks[
                :, : self.max_text_len, : self.max_text_len
            ]
            position_ids = position_ids[:, : self.max_text_len]
            tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
            tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
            tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]

        # extract text embeddings
        if self.sub_sentence_present:
            tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
            tokenized_for_encoder["attention_mask"] = text_self_attention_masks
            tokenized_for_encoder["position_ids"] = position_ids
        else:
            tokenized_for_encoder = tokenized

        bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768

        encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, 195, d_model
        text_token_mask = tokenized.attention_mask.bool()  # bs, 195
        # text_token_mask: True for nomask, False for mask
        # text_self_attention_masks: True for nomask, False for mask

        if encoded_text.shape[1] > self.max_text_len:
            encoded_text = encoded_text[:, : self.max_text_len, :]
            text_token_mask = text_token_mask[:, : self.max_text_len]
            position_ids = position_ids[:, : self.max_text_len]
            text_self_attention_masks = text_self_attention_masks[
                :, : self.max_text_len, : self.max_text_len
            ]

        text_dict = {
            "encoded_text": encoded_text,  # bs, 195, d_model
            "text_token_mask": text_token_mask,  # bs, 195
            "position_ids": position_ids,  # bs, 195
            "text_self_attention_masks": text_self_attention_masks,  # bs, 195, 195
        }
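
        # With sub_sentence_present (the default), BERT attends through the phrase-level
        # block-diagonal mask from generate_masks_with_special_tokens_and_transfer_map, so
        # tokens of different "."-separated phrases do not attend to each other and position
        # ids restart at each phrase. The "195" in the shape comments above is just an
        # example sequence length carried over from the original code.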

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        if not cropped:
            features, poss = self.backbone(samples)
            features_exemp, _ = self.backbone(exemplar_images)
            combined_features = self.combine_features(features_exemp)
            # Get visual exemplar tokens.
            bs = len(exemplars)
            num_exemplars = exemplars[0].shape[0]
            print(exemplars)
            print(num_exemplars)
            if num_exemplars > 0:
                exemplar_tokens = (
                    roi_align(
                        combined_features,
                        boxes=exemplars,
                        output_size=(1, 1),
                        spatial_scale=(1 / 8),
                        aligned=True,
                    )
                    .squeeze(-1)
                    .squeeze(-1)
                    .reshape(bs, num_exemplars, -1)
                )
            else:
                exemplar_tokens = None
        else:
            features, poss = self.backbone(samples)
            (h, w) = (
                samples.decompose()[0][0].shape[1],
                samples.decompose()[0][0].shape[2],
            )
            (orig_img_h, orig_img_w) = orig_img.shape[1], orig_img.shape[2]
            bs = len(samples.decompose()[0])
            exemp_imgs = []
            new_exemplars = []
            ind = 0
            for exemp in exemplars[0]:
                center_x = (exemp[0] + exemp[2]) / 2
                center_y = (exemp[1] + exemp[3]) / 2
                start_x = max(int(center_x - crop_width / 2), 0)
                end_x = min(int(center_x + crop_width / 2), orig_img_w)
                start_y = max(int(center_y - crop_height / 2), 0)
                end_y = min(int(center_y + crop_height / 2), orig_img_h)
                scale_x = w / (end_x - start_x)
                scale_y = h / (end_y - start_y)
                exemp_imgs.append(
                    vis_F.resize(
                        orig_img[:, start_y:end_y, start_x:end_x],
                        (h, w),
                        interpolation=InterpolationMode.BICUBIC,
                    )
                )
                new_exemplars.append(
                    [
                        (exemp[0] - start_x) * scale_x,
                        (exemp[1] - start_y) * scale_y,
                        (exemp[2] - start_x) * scale_x,
                        (exemp[3] - start_y) * scale_y,
                    ]
                )
                vis_exemps(
                    renorm(exemp_imgs[-1].cpu()).permute(1, 2, 0).numpy(),
                    [coord.item() for coord in new_exemplars[-1]],
                    str(ind) + ".jpg",
                )
                vis_exemps(
                    renorm(orig_img.cpu()).permute(1, 2, 0).numpy(),
                    [coord.item() for coord in exemplars[0][ind]],
                    "orig-" + str(ind) + ".jpg",
                )
                ind += 1
            exemp_imgs = nested_tensor_from_tensor_list(exemp_imgs)
            features_exemp, _ = self.backbone(exemp_imgs)
            combined_features = self.combine_features(features_exemp)
            new_exemplars = [
                torch.tensor(exemp).unsqueeze(0).to(samples.device) for exemp in new_exemplars
            ]
            # Get visual exemplar tokens.
            exemplar_tokens = (
                roi_align(
                    combined_features,
                    boxes=new_exemplars,
                    output_size=(1, 1),
                    spatial_scale=(1 / 8),
                    aligned=True,
                )
                .squeeze(-1)
                .squeeze(-1)
                .reshape(3, 256)
            )
            exemplar_tokens = torch.stack([exemplar_tokens] * bs)

        if exemplar_tokens is not None:
            text_dict = self.add_exemplar_tokens(
                tokenized, text_dict, exemplar_tokens, labels
            )
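
        # Each exemplar token is a single hidden_dim vector pooled from the combined feature
        # map with a 1x1 roi_align window (spatial_scale=1/8 matches that map's stride);
        # add_exemplar_tokens then splices these vectors into the text sequence so the
        # transformer can attend to them exactly like caption tokens. Note the cropped
        # branch above assumes exactly three exemplars of dimension 256 (reshape(3, 256)).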

        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                poss.append(pos_l)

        input_query_bbox = input_query_label = attn_mask = dn_meta = None
        hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
            srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
        )

        # deformable-detr-like anchor update
        outputs_coord_list = []
        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
            zip(reference[:-1], self.bbox_embed, hs)
        ):
            layer_delta_unsig = layer_bbox_embed(layer_hs)
            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
            layer_outputs_unsig = layer_outputs_unsig.sigmoid()
            outputs_coord_list.append(layer_outputs_unsig)
        outputs_coord_list = torch.stack(outputs_coord_list)

        outputs_class = torch.stack(
            [
                layer_cls_embed(layer_hs, text_dict)
                for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
            ]
        )
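
        # outputs_class stacks, per decoder layer, the ContrastiveEmbed similarity of each
        # query against every text/exemplar token (padded to max_text_len), while
        # outputs_coord_list stacks the layer-wise box predictions refined from the previous
        # layer's reference points (deformable-DETR-style iterative refinement).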

        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}

        # Used to calculate losses
        bs, len_td = text_dict["text_token_mask"].shape
        out["text_mask"] = torch.zeros(bs, self.max_text_len, dtype=torch.bool).to(
            samples.device
        )
        for b in range(bs):
            for j in range(len_td):
                if text_dict["text_token_mask"][b][j] == True:
                    out["text_mask"][b][j] = True

        # for intermediate outputs
        if self.aux_loss:
            out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord_list)
        out["token"] = one_hot_token

        # for encoder output
        if hs_enc is not None:
            # prepare intermediate outputs
            interm_coord = ref_enc[-1]
            interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
            out["interm_outputs"] = {
                "pred_logits": interm_class,
                "pred_boxes": interm_coord,
            }
            out["interm_outputs_for_matching_pre"] = {
                "pred_logits": interm_class,
                "pred_boxes": init_box_proposal,
            }

        # Example output structure (shapes from a batch of 4):
        # outputs['pred_logits'].shape -> torch.Size([4, 900, 256])
        # outputs['pred_boxes'].shape -> torch.Size([4, 900, 4])
        # outputs['text_mask'].shape -> torch.Size([256])
        # outputs['aux_outputs'][0].keys() -> dict_keys(['pred_logits', 'pred_boxes', 'one_hot', 'text_mask'])
        # outputs['token'] -> <class 'transformers.tokenization_utils_base.BatchEncoding'>
        # outputs['interm_outputs'].keys() -> dict_keys(['pred_logits', 'pred_boxes', 'one_hot', 'text_mask'])
        # outputs['interm_outputs_for_matching_pre'].keys() -> dict_keys(['pred_logits', 'pred_boxes'])
        # outputs['one_hot'].shape -> torch.Size([4, 900, 256])
        return out

    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]


class SetCriterion(nn.Module):
    def __init__(self, matcher, weight_dict, focal_alpha, focal_gamma, losses):
        """Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
        This is not really a loss; it is intended for logging purposes only and does not propagate gradients.
        """
        pred_logits = outputs["pred_logits"]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat(
            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
        )

        # Note: the L1 term is computed on the box centers (x, y) only; width and height
        # are supervised through the GIoU term below.
        loss_bbox = F.l1_loss(src_boxes[:, :2], target_boxes[:, :2], reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes),
                box_ops.box_cxcywh_to_xyxy(target_boxes),
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes

        # Logged components: since loss_bbox only has (x, y) columns here, loss_xy equals
        # loss_bbox and loss_hw is always zero.
        with torch.no_grad():
            losses["loss_xy"] = loss_bbox[..., :2].sum() / num_boxes
            losses["loss_hw"] = loss_bbox[..., 2:].sum() / num_boxes

        return losses

    def token_sigmoid_binary_focal_loss(self, outputs, targets, indices, num_boxes):
        pred_logits = outputs["pred_logits"]
        new_targets = outputs["one_hot"].to(pred_logits.device)
        text_mask = outputs["text_mask"]

        assert new_targets.dim() == 3
        assert pred_logits.dim() == 3  # batch x from x to

        bs, n, _ = pred_logits.shape
        alpha = self.focal_alpha
        gamma = self.focal_gamma
        if text_mask is not None:
            # ODVG: each sample has different mask
            text_mask = text_mask.repeat(1, pred_logits.size(1)).view(
                outputs["text_mask"].shape[0], -1, outputs["text_mask"].shape[1]
            )
            pred_logits = torch.masked_select(pred_logits, text_mask)
            new_targets = torch.masked_select(new_targets, text_mask)

        new_targets = new_targets.float()
        p = torch.sigmoid(pred_logits)
        ce_loss = F.binary_cross_entropy_with_logits(pred_logits, new_targets, reduction="none")
        p_t = p * new_targets + (1 - p) * (1 - new_targets)
        loss = ce_loss * ((1 - p_t) ** gamma)

        if alpha >= 0:
            alpha_t = alpha * new_targets + (1 - alpha) * (1 - new_targets)
            loss = alpha_t * loss

        total_num_pos = 0
        for batch_indices in indices:
            total_num_pos += len(batch_indices[0])
        num_pos_avg_per_gpu = max(total_num_pos, 1.0)
        loss = loss.sum() / num_pos_avg_per_gpu

        losses = {"loss_ce": loss}
        return losses
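
    # The method above is the standard sigmoid focal loss,
    # FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t),
    # applied per (query, token) logit and normalized by the number of matched
    # (positive) queries rather than by num_boxes.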

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            "labels": self.token_sigmoid_binary_focal_loss,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets, cat_list, caption, return_indices=False):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
            return_indices: used for vis. if True, the layer0-5 indices will be returned as well.
        """
        device = next(iter(outputs.values())).device
        one_hot = torch.zeros(
            outputs["pred_logits"].size(), dtype=torch.int64
        )  # torch.Size([bs, 900, 256])
        token = outputs["token"]

        label_map_list = []
        indices = []
        for j in range(len(cat_list)):  # bs
            label_map = []
            for i in range(len(cat_list[j])):
                label_id = torch.tensor([i])
                per_label = create_positive_map_exemplar(
                    token["input_ids"][j], label_id, [101, 102, 1012, 1029]  # [CLS], [SEP], ".", "?"
                )
                label_map.append(per_label)
            label_map = torch.stack(label_map, dim=0).squeeze(1)
            label_map_list.append(label_map)
        for j in range(len(cat_list)):  # bs
            for_match = {
                "pred_logits": outputs["pred_logits"][j].unsqueeze(0),
                "pred_boxes": outputs["pred_boxes"][j].unsqueeze(0),
            }
            inds = self.matcher(for_match, [targets[j]], label_map_list[j])
            indices.extend(inds)
        # indices: a list of size batch_size, containing tuples of (index_i, index_j) where:
        # - index_i is the indices of the selected predictions (in order)
        # - index_j is the indices of the corresponding selected targets (in order)

        tgt_ids = [v["labels"].cpu() for v in targets]
        # len(tgt_ids) == bs
        for i in range(len(indices)):
            tgt_ids[i] = tgt_ids[i][indices[i][1]]
            one_hot[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to(torch.long)
        outputs["one_hot"] = one_hot

        if return_indices:
            indices0_copy = indices
            indices_list = []

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes_list = [len(t["labels"]) for t in targets]
        num_boxes = sum(num_boxes_list)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for idx, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = []
                for j in range(len(cat_list)):  # bs
                    aux_output_single = {
                        "pred_logits": aux_outputs["pred_logits"][j].unsqueeze(0),
                        "pred_boxes": aux_outputs["pred_boxes"][j].unsqueeze(0),
                    }
                    inds = self.matcher(aux_output_single, [targets[j]], label_map_list[j])
                    indices.extend(inds)
                one_hot_aux = torch.zeros(outputs["pred_logits"].size(), dtype=torch.int64)
                tgt_ids = [v["labels"].cpu() for v in targets]
                for i in range(len(indices)):
                    tgt_ids[i] = tgt_ids[i][indices[i][1]]
                    one_hot_aux[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to(torch.long)
                aux_outputs["one_hot"] = one_hot_aux
                aux_outputs["text_mask"] = outputs["text_mask"]
                if return_indices:
                    indices_list.append(indices)
                for loss in self.losses:
                    kwargs = {}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f"_{idx}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # interm_outputs loss
        if "interm_outputs" in outputs:
            interm_outputs = outputs["interm_outputs"]
            indices = []
            for j in range(len(cat_list)):  # bs
                interm_output_single = {
                    "pred_logits": interm_outputs["pred_logits"][j].unsqueeze(0),
                    "pred_boxes": interm_outputs["pred_boxes"][j].unsqueeze(0),
                }
                inds = self.matcher(interm_output_single, [targets[j]], label_map_list[j])
                indices.extend(inds)
            one_hot_aux = torch.zeros(outputs["pred_logits"].size(), dtype=torch.int64)
            tgt_ids = [v["labels"].cpu() for v in targets]
            for i in range(len(indices)):
                tgt_ids[i] = tgt_ids[i][indices[i][1]]
                one_hot_aux[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to(torch.long)
            interm_outputs["one_hot"] = one_hot_aux
            interm_outputs["text_mask"] = outputs["text_mask"]
            if return_indices:
                indices_list.append(indices)
            for loss in self.losses:
                kwargs = {}
                l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_boxes, **kwargs)
                l_dict = {k + "_interm": v for k, v in l_dict.items()}
                losses.update(l_dict)

        if return_indices:
            indices_list.append(indices0_copy)
            return losses, indices_list

        return losses


class PostProcess(nn.Module):
    """This module converts the model's output into the format expected by the coco api"""

    def __init__(
        self,
        num_select=100,
        text_encoder_type="text_encoder_type",
        nms_iou_threshold=-1,
        use_coco_eval=False,
        args=None,
    ) -> None:
        super().__init__()
        self.num_select = num_select
        self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)

        if args.use_coco_eval:
            from pycocotools.coco import COCO

            coco = COCO(args.coco_val_path)
            category_dict = coco.loadCats(coco.getCatIds())
            cat_list = [item["name"] for item in category_dict]
        else:
            cat_list = args.label_list
        caption = " . ".join(cat_list) + " ."
        tokenized = self.tokenizer(caption, padding="longest", return_tensors="pt")
        label_list = torch.arange(len(cat_list))
        pos_map = create_positive_map(tokenized, label_list, cat_list, caption)

        # build a mapping from label_id to pos_map
        if args.use_coco_eval:
            # Map the 80 contiguous label indices to the 91 COCO category ids.
            id_map = {
                0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10,
                10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21,
                20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34,
                30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44,
                40: 46, 41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55,
                50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65,
                60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79,
                70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90,
            }
            new_pos_map = torch.zeros((91, 256))
            for k, v in id_map.items():
                new_pos_map[v] = pos_map[k]
            pos_map = new_pos_map

        self.nms_iou_threshold = nms_iou_threshold
        self.positive_map = pos_map

    def forward(self, outputs, target_sizes, not_to_xyxy=False, test=False):
        """Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        num_select = self.num_select
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        prob_to_token = out_logits.sigmoid()
        pos_maps = self.positive_map.to(prob_to_token.device)
        for label_ind in range(len(pos_maps)):
            if pos_maps[label_ind].sum() != 0:
                pos_maps[label_ind] = pos_maps[label_ind] / pos_maps[label_ind].sum()
        prob_to_label = prob_to_token @ pos_maps.T

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = prob_to_label
        topk_values, topk_indexes = torch.topk(prob.view(prob.shape[0], -1), num_select, dim=1)
        scores = topk_values
        topk_boxes = torch.div(topk_indexes, prob.shape[2], rounding_mode="trunc")
        labels = topk_indexes % prob.shape[2]

        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)

        # if test:
        #     assert not not_to_xyxy
        #     boxes[:, :, 2:] = boxes[:, :, 2:] - boxes[:, :, :2]

        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        if self.nms_iou_threshold > 0:
            item_indices = [
                nms(b, s, iou_threshold=self.nms_iou_threshold) for b, s in zip(boxes, scores)
            ]
            results = [
                {"scores": s[i], "labels": l[i], "boxes": b[i]}
                for s, l, b, i in zip(scores, labels, boxes, item_indices)
            ]
        else:
            results = [
                {"scores": s, "labels": l, "boxes": b}
                for s, l, b in zip(scores, labels, boxes)
            ]

        return results


def build_groundingdino(args):
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    dn_labelbook_size = args.dn_labelbook_size
    dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
    sub_sentence_present = args.sub_sentence_present
    model = GroundingDINO(
        backbone,
        transformer,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
        iter_update=True,
        query_dim=4,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
        two_stage_type=args.two_stage_type,
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        num_patterns=args.num_patterns,
        dn_number=0,  # denoising queries are disabled here regardless of args
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_labelbook_size=dn_labelbook_size,
        text_encoder_type=args.text_encoder_type,
        sub_sentence_present=sub_sentence_present,
        max_text_len=args.max_text_len,
    )

    matcher = build_matcher(args)

    # prepare weight dict
    weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef}
    weight_dict["loss_giou"] = args.giou_loss_coef
    clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)
    clean_weight_dict = copy.deepcopy(weight_dict)

    # TODO this is a hack
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in clean_weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    if args.two_stage_type != "no":
        interm_weight_dict = {}
        try:
            no_interm_box_loss = args.no_interm_box_loss
        except:
            no_interm_box_loss = False
        _coeff_weight_dict = {
            "loss_ce": 1.0,
            "loss_bbox": 1.0 if not no_interm_box_loss else 0.0,
            "loss_giou": 1.0 if not no_interm_box_loss else 0.0,
        }
        try:
            interm_loss_coef = args.interm_loss_coef
        except:
            interm_loss_coef = 1.0
        interm_weight_dict.update(
            {
                k + "_interm": v * interm_loss_coef * _coeff_weight_dict[k]
                for k, v in clean_weight_dict_wo_dn.items()
            }
        )
        weight_dict.update(interm_weight_dict)

    # losses = ['labels', 'boxes', 'cardinality']
    losses = ["labels", "boxes"]

    criterion = SetCriterion(
        matcher=matcher,
        weight_dict=weight_dict,
        focal_alpha=args.focal_alpha,
        focal_gamma=args.focal_gamma,
        losses=losses,
    )
    criterion.to(device)
    postprocessors = {
        "bbox": PostProcess(
            num_select=args.num_select,
            text_encoder_type=args.text_encoder_type,
            nms_iou_threshold=args.nms_iou_threshold,
            args=args,
        )
    }

    return model, criterion, postprocessors


def create_positive_map(tokenized, tokens_positive, cat_list, caption):
    """construct a map such that positive_map[i,j] = True iff box i is associated to token j"""
    positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)

    for j, label in enumerate(tokens_positive):
        start_ind = caption.find(cat_list[label])
        end_ind = start_ind + len(cat_list[label]) - 1
        beg_pos = tokenized.char_to_token(start_ind)
        try:
            end_pos = tokenized.char_to_token(end_ind)
        except:
            end_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end_ind - 1)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end_ind - 2)
            except:
                end_pos = None
        if beg_pos is None or end_pos is None:
            continue
        if beg_pos < 0 or end_pos < 0:
            continue
        if beg_pos > end_pos:
            continue
        # assert beg_pos is not None and end_pos is not None
        positive_map[j, beg_pos : end_pos + 1].fill_(1)
    return positive_map


def create_positive_map_exemplar(input_ids, label, special_tokens):
    tokens_positive = torch.zeros(256, dtype=torch.float)
    count = -1
    for token_ind in range(len(input_ids)):
        input_id = input_ids[token_ind]
        if (input_id not in special_tokens) and (
            token_ind == 0 or (input_ids[token_ind - 1] in special_tokens)
        ):
            count += 1
            if count == label:
                ind_to_insert_ones = token_ind
                while input_ids[ind_to_insert_ones] not in special_tokens:
                    tokens_positive[ind_to_insert_ones] = 1
                    ind_to_insert_ones += 1
                break
    return tokens_positive
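

# Illustrative sketch of create_positive_map_exemplar (token ids are hypothetical):
# with input_ids [101, <phrase id>, 1008, 1008, 1012, 102], i.e. [CLS] <phrase> * * . [SEP],
# label 0, and special_tokens [101, 102, 1012, 1029], the returned 256-dim map is 1 at
# positions 1-3 (the phrase token plus the spliced exemplar markers) and 0 elsewhere,
# since the special tokens delimit where each labeled phrase begins and ends.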