Removed extraneous print statements
models/GroundingDINO/transformer.py (CHANGED)
```diff
@@ -237,7 +237,6 @@ class Transformer(nn.Module):
 
         """
         # prepare input for encoder
-        print("inside transformer forward")
         src_flatten = []
         mask_flatten = []
         lvl_pos_embed_flatten = []
@@ -274,7 +273,6 @@ class Transformer(nn.Module):
         #########################################################
         # Begin Encoder
         #########################################################
-        print("begin transformer encoder")
         memory, memory_text = self.encoder(
             src_flatten,
             pos=lvl_pos_embed_flatten,
@@ -288,7 +286,6 @@ class Transformer(nn.Module):
             position_ids=text_dict["position_ids"],
             text_self_attention_masks=text_dict["text_self_attention_masks"],
         )
-        print("got encoder output")
         #########################################################
         # End Encoder
         # - memory: bs, \sum{hw}, c
```
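As a side note on the `# - memory: bs, \sum{hw}, c` comment above (context only, not part of this change), here is a minimal sketch of how flattened multi-level features arrive at that shape; the level sizes are illustrative:

```python
import torch

# Illustrative only: two feature levels for a batch of 2, channel dim c=256.
c = 256
feats = [torch.randn(2, c, 32, 32), torch.randn(2, c, 16, 16)]

# Each level goes from (bs, c, h, w) to (bs, h*w, c); concatenating the levels
# along the token dimension yields the (bs, \sum{hw}, c) layout the encoder
# memory is documented to have.
src_flatten = [f.flatten(2).transpose(1, 2) for f in feats]
memory_like = torch.cat(src_flatten, dim=1)
assert memory_like.shape == (2, 32 * 32 + 16 * 16, c)
```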
```diff
@@ -303,11 +300,9 @@ class Transformer(nn.Module):
         # import ipdb; ipdb.set_trace()
 
         if self.two_stage_type == "standard":  # use the encoder output as proposals
-            print("standard two stage")
             output_memory, output_proposals = gen_encoder_output_proposals(
                 memory, mask_flatten, spatial_shapes
             )
-            print("got output proposals")
             output_memory = self.enc_output_norm(self.enc_output(output_memory))
 
             if text_dict is not None:
@@ -324,29 +319,22 @@ class Transformer(nn.Module):
             topk = self.num_queries
 
             topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq
-            print("got topk proposals")
             # gather boxes
-            print("gather 1")
             refpoint_embed_undetach = torch.gather(
                 enc_outputs_coord_unselected,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
             )  # unsigmoid
-            print("gathered 1")
             refpoint_embed_ = refpoint_embed_undetach.detach()
-            print("gather 2")
             init_box_proposal = torch.gather(
                 output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
             ).sigmoid()  # sigmoid
-            print("gathered 2")
-            print("gather 3")
             # gather tgt
             tgt_undetach = torch.gather(
                 output_memory,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model),
             )
-            print("gathered 3")
             if self.embed_init_tgt:
                 tgt_ = (
                     self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
```
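For context on the proposal-selection hunk above (the change itself only removes the prints), here is a minimal sketch of the `torch.topk` / `torch.gather` pattern used there; the tensor names and sizes are illustrative, not taken from the file:

```python
import torch

# Illustrative sizes: bs=2 images, N=6 encoder tokens, nq=3 selected queries.
bs, N, nq = 2, 6, 3
topk_logits = torch.randn(bs, N)                      # one score per encoder token
enc_outputs_coord_unselected = torch.randn(bs, N, 4)  # one (unsigmoided) box per token

# Indices of the nq highest-scoring tokens per image: shape (bs, nq).
topk_proposals = torch.topk(topk_logits, nq, dim=1)[1]

# Repeat the indices along the last dim so gather picks whole 4-d boxes: (bs, nq, 4).
index = topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, index)
assert refpoint_embed_undetach.shape == (bs, nq, 4)
```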
```diff
@@ -401,7 +389,6 @@ class Transformer(nn.Module):
         # memory torch.Size([2, 16320, 256])
 
         # import pdb;pdb.set_trace()
-        print("going through decoder")
         hs, references = self.decoder(
             tgt=tgt.transpose(0, 1),
             memory=memory.transpose(0, 1),
@@ -416,7 +403,6 @@ class Transformer(nn.Module):
             text_attention_mask=~text_dict["text_token_mask"],
             # we ~ the mask . False means use the token; True means pad the token
         )
-        print("got decoder output")
         #########################################################
         # End Decoder
         # hs: n_dec, bs, nq, d_model
```
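A note on the `~text_dict["text_token_mask"]` context line above: the inversion matches the common PyTorch convention that `True` entries in a key padding mask mean "ignore this token", as the in-code comment says. A tiny sketch with an illustrative mask:

```python
import torch

# Suppose token_mask is True for real tokens and False for padding.
token_mask = torch.tensor([[True, True, False],
                           [True, False, False]])

# Attention modules expect the opposite meaning, so the mask is inverted:
# True now marks positions to pad/ignore.
key_padding_mask = ~token_mask
assert key_padding_mask.tolist() == [[False, False, True], [False, True, True]]
```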
```diff
@@ -560,7 +546,6 @@ class TransformerEncoder(nn.Module):
         """
 
         output = src
-        print("inside transformer encoder")
         # preparation and reshape
         if self.num_layers > 0:
             reference_points = self.get_reference_points(
@@ -591,10 +576,8 @@ class TransformerEncoder(nn.Module):
             # if output.isnan().any() or memory_text.isnan().any():
             #     if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
             #         import ipdb; ipdb.set_trace()
-            print("layer_id: " + str(layer_id))
             if self.fusion_layers:
                 if self.use_checkpoint:
-                    print("using checkpoint")
                     output, memory_text = checkpoint.checkpoint(
                         self.fusion_layers[layer_id],
                         output,
@@ -602,30 +585,24 @@ class TransformerEncoder(nn.Module):
                         key_padding_mask,
                         text_attention_mask,
                     )
-                    print("got checkpoint output")
                 else:
-                    print("not using checkpoint")
                     output, memory_text = self.fusion_layers[layer_id](
                         v=output,
                         l=memory_text,
                         attention_mask_v=key_padding_mask,
                         attention_mask_l=text_attention_mask,
                     )
-                    print("got fusion output")
 
             if self.text_layers:
-                print("getting text layers")
                 memory_text = self.text_layers[layer_id](
                     src=memory_text.transpose(0, 1),
                     src_mask=~text_self_attention_masks,  # note we use ~ for mask here
                     src_key_padding_mask=text_attention_mask,
                     pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
                 ).transpose(0, 1)
-                print("got text output")
 
             # main process
             if self.use_transformer_ckpt:
-                print("use transformer ckpt")
                 output = checkpoint.checkpoint(
                     layer,
                     output,
```
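The `checkpoint.checkpoint(...)` calls in the hunks above are PyTorch activation checkpointing: the wrapped layer's intermediate activations are recomputed during backward instead of being stored, trading extra compute for lower memory. A minimal, self-contained sketch of the same pattern (the module and sizes are illustrative, not taken from the file):

```python
import torch
import torch.nn as nn
from torch.utils import checkpoint

layer = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
x = torch.randn(8, 256, requires_grad=True)

# Forward runs normally, but the activations inside `layer` are not kept;
# they are recomputed when backward() reaches this segment.
y = checkpoint.checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```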
```diff
@@ -635,9 +612,7 @@ class TransformerEncoder(nn.Module):
                     level_start_index,
                     key_padding_mask,
                 )
-                print("got output")
             else:
-                print("not use transformer ckpt")
                 output = layer(
                     src=output,
                     pos=pos,
@@ -646,7 +621,6 @@ class TransformerEncoder(nn.Module):
                     level_start_index=level_start_index,
                     key_padding_mask=key_padding_mask,
                 )
-                print("got output")
 
         return output, memory_text
 
@@ -847,7 +821,6 @@ class DeformableTransformerEncoderLayer(nn.Module):
     ):
         # self attention
         # import ipdb; ipdb.set_trace()
-        print("deformable self-attention")
         src2 = self.self_attn(
             query=self.with_pos_embed(src, pos),
             reference_points=reference_points,
```
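If this kind of tracing is wanted again later, one option (not part of this change) is to route the messages through the standard `logging` module so they can be switched on by level instead of editing the code; a minimal sketch with an illustrative logger name:

```python
import logging

logger = logging.getLogger("groundingdino.transformer")

# Inside forward(), the removed print() calls would become, e.g.:
#     logger.debug("begin transformer encoder")
#     logger.debug("layer_id: %d", layer_id)

# These stay silent by default and only appear when debugging is enabled:
logging.basicConfig(level=logging.DEBUG)
logger.debug("begin transformer encoder")
```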