# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
import warnings
from typing import Optional

import torch
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn

from groundingdino.util.misc import inverse_sigmoid

from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer
from .utils import (
    MLP,
    _get_activation_fn,
    _get_clones,
    gen_encoder_output_proposals,
    gen_sineembed_for_position,
    get_sine_pos_embed,
)


class Transformer(nn.Module):
    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_queries=300,
        num_encoder_layers=6,
        num_unicoder_layers=0,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.0,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        query_dim=4,
        num_patterns=0,
        # for deformable encoder
        num_feature_levels=1,
        enc_n_points=4,
        dec_n_points=4,
        # init query
        learnable_tgt_init=False,
        # two stage
        two_stage_type="no",  # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
        embed_init_tgt=False,
        # for text
        use_text_enhancer=False,
        use_fusion_layer=False,
        use_checkpoint=False,
        use_transformer_ckpt=False,
        use_text_cross_attention=False,
        text_dropout=0.1,
        fusion_dropout=0.1,
        fusion_droppath=0.0,
    ):
        super().__init__()
        self.num_feature_levels = num_feature_levels
        self.num_encoder_layers = num_encoder_layers
        self.num_unicoder_layers = num_unicoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.num_queries = num_queries
        assert query_dim == 4
        # choose encoder layer type
        encoder_layer = DeformableTransformerEncoderLayer(
            d_model,
            dim_feedforward,
            dropout,
            activation,
            num_feature_levels,
            nhead,
            enc_n_points,
        )

        if use_text_enhancer:
            text_enhance_layer = TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead // 2,
                dim_feedforward=dim_feedforward // 2,
                dropout=text_dropout,
            )
        else:
            text_enhance_layer = None

        if use_fusion_layer:
            feature_fusion_layer = BiAttentionBlock(
                v_dim=d_model,
                l_dim=d_model,
                embed_dim=dim_feedforward // 2,
                num_heads=nhead // 2,
                dropout=fusion_dropout,
                drop_path=fusion_droppath,
            )
        else:
            feature_fusion_layer = None

        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        assert encoder_norm is None
        self.encoder = TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            d_model=d_model,
            num_queries=num_queries,
            text_enhance_layer=text_enhance_layer,
            feature_fusion_layer=feature_fusion_layer,
            use_checkpoint=use_checkpoint,
            use_transformer_ckpt=use_transformer_ckpt,
        )
        # choose decoder layer type
        decoder_layer = DeformableTransformerDecoderLayer(
            d_model,
            dim_feedforward,
            dropout,
            activation,
            num_feature_levels,
            nhead,
            dec_n_points,
            use_text_cross_attention=use_text_cross_attention,
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            d_model=d_model,
            query_dim=query_dim,
            num_feature_levels=num_feature_levels,
        )

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_queries = num_queries  # useful for single stage model only
        self.num_patterns = num_patterns
        if not isinstance(num_patterns, int):
            warnings.warn(
                "num_patterns should be int but got {}; falling back to 0".format(type(num_patterns))
            )
            self.num_patterns = 0
        if num_feature_levels > 1:
            if self.num_encoder_layers > 0:
                self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
            else:
                self.level_embed = None

        self.learnable_tgt_init = learnable_tgt_init
        assert learnable_tgt_init, "learnable_tgt_init is expected to be True"
        self.embed_init_tgt = embed_init_tgt
        if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None

        # for two stage
        self.two_stage_type = two_stage_type
        assert two_stage_type in [
            "no",
            "standard",
        ], "unknown param {} of two_stage_type".format(two_stage_type)
        if two_stage_type == "standard":
            # anchor selection at the output of encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            self.two_stage_wh_embedding = None

        if two_stage_type == "no":
            self.init_ref_points(num_queries)  # init self.refpoint_embed

        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the fraction of non-padded width/height for each image in the batch."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
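
    # Added note (not from the upstream file): get_valid_ratio reads the padding
    # mask (True = padded). For a single 6x4 mask whose bottom row and rightmost
    # column are padding, ~mask[:, :, 0] counts 5 valid rows and ~mask[:, 0, :]
    # counts 3 valid columns, so the returned ratio is (3/4, 5/6) in (w, h) order.
    # These ratios rescale reference points so that deformable attention samples
    # only inside the un-padded region of each image.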

    def init_ref_points(self, use_num_queries):
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)

    def forward(
        self,
        srcs,
        masks,
        refpoint_embed,
        pos_embeds,
        tgt,
        attn_mask=None,
        text_dict=None,
    ):
        """
        Input:
            - srcs: List of multi features [bs, ci, hi, wi]
            - masks: List of multi masks [bs, hi, wi]
            - refpoint_embed: [bs, num_dn, 4]. None during inference
            - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
            - tgt: [bs, num_dn, d_model]. None during inference
            - attn_mask: self-attention mask for the (denoising) queries
            - text_dict: dict with "encoded_text", "text_token_mask",
              "position_ids" and "text_self_attention_masks"
        """
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )
        # index of the first token of each level in the flattened sequence,
        # e.g. shapes [[100, 150], [50, 75]] -> level_start_index [0, 15000]
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
        # two stage
        enc_topk_proposals = enc_refpoint_embed = None

        #########################################################
        # Begin Encoder
        #########################################################
        memory, memory_text = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            memory_text=text_dict["encoded_text"],
            text_attention_mask=~text_dict["text_token_mask"],
            # we ~ the mask. False means use the token; True means pad the token
            position_ids=text_dict["position_ids"],
            text_self_attention_masks=text_dict["text_self_attention_masks"],
        )
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c
        # - mask_flatten: bs, \sum{hw}
        # - lvl_pos_embed_flatten: bs, \sum{hw}, c
        # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        #########################################################
        text_dict["encoded_text"] = memory_text
        if self.two_stage_type == "standard":
            # use the encoder output as region proposals
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes
            )
            output_memory = self.enc_output_norm(self.enc_output(output_memory))

            if text_dict is not None:
                enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
            else:
                enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)

            topk_logits = enc_outputs_class_unselected.max(-1)[0]
            enc_outputs_coord_unselected = (
                self.enc_out_bbox_embed(output_memory) + output_proposals
            )  # (bs, \sum{hw}, 4) unsigmoid
            topk = self.num_queries

            topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq

            # gather boxes
            refpoint_embed_undetach = torch.gather(
                enc_outputs_coord_unselected,
                1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
            )  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(
                output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
            ).sigmoid()  # sigmoid

            # gather tgt
            tgt_undetach = torch.gather(
                output_memory,
                1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model),
            )
            if self.embed_init_tgt:
                tgt_ = (
                    self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
                )  # nq, bs, d_model
            else:
                tgt_ = tgt_undetach.detach()

            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
        elif self.two_stage_type == "no":
            tgt_ = (
                self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
            )  # nq, bs, d_model
            refpoint_embed_ = (
                self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
            )  # nq, bs, 4

            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

            if self.num_patterns > 0:
                # assumes a `self.patterns` embedding has been set up when num_patterns > 0
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
                    self.num_queries, 1
                )  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat

            init_box_proposal = refpoint_embed_.sigmoid()

        else:
            raise NotImplementedError(
                "unknown two_stage_type {}".format(self.two_stage_type)
            )
        #########################################################
        # End preparing tgt
        # - tgt: bs, NQ, d_model
        # - refpoint_embed(unsigmoid): bs, NQ, 4
        #########################################################

        #########################################################
        # Begin Decoder
        #########################################################
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            memory_text=text_dict["encoded_text"],
            text_attention_mask=~text_dict["text_token_mask"],
            # we ~ the mask. False means use the token; True means pad the token
        )
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################

        #########################################################
        # Begin postprocess
        #########################################################
        if self.two_stage_type == "standard":
            hs_enc = tgt_undetach.unsqueeze(0)
            ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        else:
            hs_enc = ref_enc = None
        #########################################################
        # End postprocess
        # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
        # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, query_dim) or None
        #########################################################

        return hs, references, hs_enc, ref_enc, init_box_proposal
        # hs: (n_dec, bs, nq, d_model)
        # references: sigmoid coordinates. (n_dec+1, bs, nq, 4)
        # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
        # ref_enc: sigmoid coordinates. \
        #           (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
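

# Added summary (not part of the upstream file): Transformer.forward flattens the
# multi-scale image features, runs the image-text encoder (fusion + text
# enhancement + deformable self-attention), and, for two_stage_type == "standard",
# scores every encoder token against the text to select the top num_queries
# locations as the initial boxes (and optionally the initial query content). The
# decoder then refines those boxes layer by layer and the method returns the
# per-layer decoder states (hs), the per-layer reference boxes (references), the
# encoder-side proposals (hs_enc, ref_enc) and init_box_proposal.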


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        encoder_layer,
        num_layers,
        d_model=256,
        num_queries=300,
        enc_layer_share=False,
        text_enhance_layer=None,
        feature_fusion_layer=None,
        use_checkpoint=False,
        use_transformer_ckpt=False,
    ):
        """Encoder stack that interleaves image deformable self-attention, text
        self-attention (text_enhance_layer) and image-text fusion (feature_fusion_layer).

        Args:
            encoder_layer: deformable encoder layer to clone.
            num_layers (int): number of encoder layers.
            d_model (int, optional): model dimension. Defaults to 256.
            num_queries (int, optional): number of queries. Defaults to 300.
            enc_layer_share (bool, optional): share layers across the stack. Defaults to False.
            text_enhance_layer (optional): text self-attention layer to clone.
            feature_fusion_layer (optional): image-text bi-attention layer to clone.
            use_checkpoint (bool, optional): checkpoint the fusion layers. Defaults to False.
            use_transformer_ckpt (bool, optional): checkpoint the deformable layers. Defaults to False.
        """
        super().__init__()
        # prepare layers
        self.layers = []
        self.text_layers = []
        self.fusion_layers = []
        if num_layers > 0:
            self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
            if text_enhance_layer is not None:
                self.text_layers = _get_clones(
                    text_enhance_layer, num_layers, layer_share=enc_layer_share
                )
            if feature_fusion_layer is not None:
                self.fusion_layers = _get_clones(
                    feature_fusion_layer, num_layers, layer_share=enc_layer_share
                )
        else:
            self.layers = []
            del encoder_layer

            if text_enhance_layer is not None:
                self.text_layers = []
                del text_enhance_layer
            if feature_fusion_layer is not None:
                self.fusion_layers = []
                del feature_fusion_layer

        self.query_scale = None
        self.num_queries = num_queries
        self.num_layers = num_layers
        self.d_model = d_model

        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
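
    # Added note (not from the upstream file): the reference points above are the
    # pixel centers (x + 0.5, y + 0.5) of every location at every level, normalized
    # by the valid (non-padded) extent of that level; broadcasting against
    # valid_ratios then yields one point per feature level for each location,
    # giving a tensor of shape [bs, sum(hi*wi), num_levels, 2] consumed by MSDeformAttn.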

    def forward(
        self,
        # for images
        src: Tensor,
        pos: Tensor,
        spatial_shapes: Tensor,
        level_start_index: Tensor,
        valid_ratios: Tensor,
        key_padding_mask: Tensor,
        # for texts
        memory_text: Tensor = None,
        text_attention_mask: Tensor = None,
        pos_text: Tensor = None,
        text_self_attention_masks: Tensor = None,
        position_ids: Tensor = None,
    ):
        """
        Input:
            - src: [bs, sum(hi*wi), 256]
            - pos: pos embed for src. [bs, sum(hi*wi), 256]
            - spatial_shapes: h, w of each level [num_level, 2]
            - level_start_index: [num_level] start index of each level in sum(hi*wi)
            - valid_ratios: [bs, num_level, 2]
            - key_padding_mask: [bs, sum(hi*wi)]
            - memory_text: bs, n_text, 256
            - text_attention_mask: bs, n_text
                False for no padding; True for padding
            - pos_text: bs, n_text, 256
            - position_ids: bs, n_text
        Intermediate:
            - reference_points: [bs, sum(hi*wi), num_level, 2]
        Outputs:
            - output: [bs, sum(hi*wi), 256]
        """
        output = src

        # preparation and reshape
        if self.num_layers > 0:
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=src.device
            )

        if self.text_layers:
            # generate pos_text
            bs, n_text, text_dim = memory_text.shape
            if pos_text is None and position_ids is None:
                pos_text = (
                    torch.arange(n_text, device=memory_text.device)
                    .float()
                    .unsqueeze(0)
                    .unsqueeze(-1)
                    .repeat(bs, 1, 1)
                )
                pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
            if position_ids is not None:
                pos_text = get_sine_pos_embed(
                    position_ids[..., None], num_pos_feats=256, exchange_xy=False
                )
        # main process
        for layer_id, layer in enumerate(self.layers):
            # image-text fusion
            if self.fusion_layers:
                if self.use_checkpoint:
                    output, memory_text = checkpoint.checkpoint(
                        self.fusion_layers[layer_id],
                        output,
                        memory_text,
                        key_padding_mask,
                        text_attention_mask,
                    )
                else:
                    output, memory_text = self.fusion_layers[layer_id](
                        v=output,
                        l=memory_text,
                        attention_mask_v=key_padding_mask,
                        attention_mask_l=text_attention_mask,
                    )

            # text self-attention
            if self.text_layers:
                memory_text = self.text_layers[layer_id](
                    src=memory_text.transpose(0, 1),
                    src_mask=~text_self_attention_masks,  # note we use ~ for mask here
                    src_key_padding_mask=text_attention_mask,
                    pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
                ).transpose(0, 1)

            # deformable image self-attention
            if self.use_transformer_ckpt:
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    pos,
                    reference_points,
                    spatial_shapes,
                    level_start_index,
                    key_padding_mask,
                )
            else:
                output = layer(
                    src=output,
                    pos=pos,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    key_padding_mask=key_padding_mask,
                )

        return output, memory_text
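
    # Added note (not from the upstream file): each encoder iteration applies, in
    # order, (1) bidirectional image-text fusion (BiAttentionBlock), (2) text
    # self-attention restricted by text_self_attention_masks, and (3) multi-scale
    # deformable self-attention over the image tokens. use_checkpoint and
    # use_transformer_ckpt wrap steps (1) and (3) in torch.utils.checkpoint to
    # trade recomputation for memory.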


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
        return_intermediate=False,
        d_model=256,
        query_dim=4,
        num_feature_levels=1,
    ):
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, "support return_intermediate only"
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        self.query_pos_sine_scale = None

        self.query_scale = None
        self.bbox_embed = None
        self.class_embed = None

        self.d_model = d_model
        self.ref_anchor_head = None

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
        # for text
        memory_text: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
    ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4
            - valid_ratios/spatial_shapes: bs, nlevel, 2
        """
        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):

            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
                )  # nq, bs, nlevel, 4
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = (
                    reference_points[:, :, None] * valid_ratios[None, :]
                )
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :]
            )  # nq, bs, 256*2

            # conditional query
            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            # main process
            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
            )
            if output.isnan().any() | output.isinf().any():
                print(f"output of layer_id {layer_id} contains nan or inf")
                try:
                    num_nan = output.isnan().sum().item()
                    num_inf = output.isinf().sum().item()
                    print(f"num_nan {num_nan}, num_inf {num_inf}")
                except Exception as e:
                    print(e)

            # iter update
            if self.bbox_embed is not None:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(self.norm(output))

        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
        ]
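
    # Added note (not from the upstream file): self.bbox_embed / self.class_embed
    # are attached by the wrapping detection model. When bbox_embed is set, each
    # layer predicts a box delta in inverse-sigmoid space,
    # new_ref = sigmoid(delta + inverse_sigmoid(ref)); the undetached refinement is
    # stored for the loss while the detached copy conditions the next layer.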


class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(
            embed_dim=d_model,
            num_levels=n_levels,
            num_heads=n_heads,
            num_points=n_points,
            batch_first=True,
        )
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(
        self,
        src,
        pos,
        reference_points,
        spatial_shapes,
        level_start_index,
        key_padding_mask=None,
    ):
        # self attention
        src2 = self.self_attn(
            query=self.with_pos_embed(src, pos),
            reference_points=reference_points,
            value=src,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            key_padding_mask=key_padding_mask,
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src


class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
        use_text_feat_guide=False,
        use_text_cross_attention=False,
    ):
        super().__init__()

        # cross attention
        self.cross_attn = MSDeformAttn(
            embed_dim=d_model,
            num_levels=n_levels,
            num_heads=n_heads,
            num_points=n_points,
            batch_first=True,
        )
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # cross attention text
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            self.catext_norm = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_proj = None
        self.use_text_feat_guide = use_text_feat_guide
        assert not use_text_feat_guide
        self.use_text_cross_attention = use_text_cross_attention

    def rm_self_attn_modules(self):
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        with torch.cuda.amp.autocast(enabled=False):
            tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        memory_text: Optional[Tensor] = None,  # bs, num_token, d_model
        text_attention_mask: Optional[Tensor] = None,  # bs, num_token
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
    ):
        """
        Input:
            - tgt/tgt_query_pos: nq, bs, d_model
        """
        assert cross_attn_mask is None

        # self attention
        if self.self_attn is not None:
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

        # cross attention to text tokens
        if self.use_text_cross_attention:
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text.transpose(0, 1),
                memory_text.transpose(0, 1),
                key_padding_mask=text_attention_mask,
            )[0]
            tgt = tgt + self.catext_dropout(tgt2)
            tgt = self.catext_norm(tgt)

        # deformable cross attention to image memory
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
            reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
            value=memory.transpose(0, 1),
            spatial_shapes=memory_spatial_shapes,
            level_start_index=memory_level_start_index,
            key_padding_mask=memory_key_padding_mask,
        ).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt
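
    # Added note (not from the upstream file): this layer keeps the (nq, bs, d_model)
    # layout of the vanilla DETR decoder, so queries are transposed to batch-first
    # around the MSDeformAttn call (constructed with batch_first=True), and the
    # batch-first text features are transposed to sequence-first for
    # nn.MultiheadAttention in the text cross-attention branch.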


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        num_queries=args.num_queries,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
        query_dim=args.query_dim,
        activation=args.transformer_activation,
        num_patterns=args.num_patterns,
        num_feature_levels=args.num_feature_levels,
        enc_n_points=args.enc_n_points,
        dec_n_points=args.dec_n_points,
        learnable_tgt_init=True,
        # two stage
        two_stage_type=args.two_stage_type,  # ['no', 'standard', 'early']
        embed_init_tgt=args.embed_init_tgt,
        use_text_enhancer=args.use_text_enhancer,
        use_fusion_layer=args.use_fusion_layer,
        use_checkpoint=args.use_checkpoint,
        use_transformer_ckpt=args.use_transformer_ckpt,
        use_text_cross_attention=args.use_text_cross_attention,
        text_dropout=args.text_dropout,
        fusion_dropout=args.fusion_dropout,
        fusion_droppath=args.fusion_droppath,
    )
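

# ---------------------------------------------------------------------------
# Added usage sketch (not part of the upstream file). It only illustrates how
# build_transformer is typically driven from a config namespace; the attribute
# values below are hypothetical, and actually running the model additionally
# requires the multi-scale deformable attention op plus a text backbone that
# produces the `text_dict` entries consumed in Transformer.forward.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(
#         hidden_dim=256, dropout=0.0, nheads=8, num_queries=900,
#         dim_feedforward=2048, enc_layers=6, dec_layers=6, pre_norm=False,
#         query_dim=4, transformer_activation="relu", num_patterns=0,
#         num_feature_levels=4, enc_n_points=4, dec_n_points=4,
#         two_stage_type="standard", embed_init_tgt=True,
#         use_text_enhancer=True, use_fusion_layer=True,
#         use_checkpoint=False, use_transformer_ckpt=False,
#         use_text_cross_attention=True, text_dropout=0.0,
#         fusion_dropout=0.0, fusion_droppath=0.1,
#     )
#     transformer = build_transformer(args)
# ---------------------------------------------------------------------------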