abreza committed on
Commit
1e29c76
·
1 Parent(s): ba5557d

update xpose code

Browse files
Files changed (30) hide show
  1. requirements.txt +2 -3
  2. src/utils/dependencies/XPose/config_model/UniPose_SwinT.py +2 -2
  3. src/utils/dependencies/XPose/models/UniPose/attention.py +27 -2
  4. src/utils/dependencies/XPose/models/UniPose/backbone.py +4 -2
  5. src/utils/dependencies/XPose/models/UniPose/deformable_transformer.py +12 -5
  6. src/utils/dependencies/XPose/models/UniPose/fuse_modules.py +10 -6
  7. src/utils/dependencies/XPose/models/UniPose/mask_generate.py +6 -0
  8. src/utils/dependencies/XPose/models/UniPose/ops/modules/ms_deform_attn.py +4 -1
  9. src/utils/dependencies/XPose/models/UniPose/ops/setup.py +0 -3
  10. src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.cu +2 -2
  11. src/utils/dependencies/XPose/models/UniPose/position_encoding.py +1 -0
  12. src/utils/dependencies/XPose/models/UniPose/swin_transformer.py +6 -8
  13. src/utils/dependencies/XPose/models/UniPose/transformer_deformable.py +24 -18
  14. src/utils/dependencies/XPose/models/UniPose/transformer_vanilla.py +6 -2
  15. src/utils/dependencies/XPose/models/UniPose/unipose.py +23 -14
  16. src/utils/dependencies/XPose/models/UniPose/utils.py +1 -1
  17. src/utils/dependencies/XPose/transforms.py +1 -0
  18. src/utils/dependencies/XPose/util/__init__.py +1 -0
  19. src/utils/dependencies/XPose/util/addict.py +0 -159
  20. src/utils/dependencies/XPose/util/box_ops.py +1 -1
  21. src/utils/dependencies/XPose/util/config.py +13 -6
  22. src/utils/dependencies/XPose/util/get_param_dicts.py +61 -0
  23. src/utils/dependencies/XPose/util/instance.txt +863 -0
  24. src/utils/dependencies/XPose/util/logger.py +95 -0
  25. src/utils/dependencies/XPose/util/metrics.py +181 -0
  26. src/utils/dependencies/XPose/util/optim.py +70 -0
  27. src/utils/dependencies/XPose/util/plot_utils.py +112 -0
  28. src/utils/dependencies/XPose/util/slio.py +173 -0
  29. src/utils/dependencies/XPose/util/time_counter.py +60 -0
  30. src/utils/dependencies/XPose/util/utils.py +499 -0
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu118
2
-
3
- torch==1.12.1
4
- torchvision==0.13.1
5
  torchaudio==2.3.0
6
 
7
  numpy==1.26.4
 
1
  --extra-index-url https://download.pytorch.org/whl/cu118
2
+ torch
3
+ torchvision==0.18.0
 
4
  torchaudio==2.3.0
5
 
6
  numpy==1.26.4
src/utils/dependencies/XPose/config_model/UniPose_SwinT.py CHANGED
@@ -108,7 +108,7 @@ shuffle_type = None
108
  use_text_enhancer = True
109
  use_fusion_layer = True
110
 
111
- use_checkpoint = False # True
112
  use_transformer_ckpt = True
113
  text_encoder_type = 'bert-base-uncased'
114
 
@@ -122,4 +122,4 @@ binary_query_selection = False
122
  use_cdn = True
123
  ffn_extra_layernorm = False
124
 
125
- fix_size=False
 
108
  use_text_enhancer = True
109
  use_fusion_layer = True
110
 
111
+ use_checkpoint = True
112
  use_transformer_ckpt = True
113
  text_encoder_type = 'bert-base-uncased'
114
 
 
122
  use_cdn = True
123
  ffn_extra_layernorm = False
124
 
125
+ fix_size=False
src/utils/dependencies/XPose/models/UniPose/attention.py CHANGED
@@ -23,19 +23,44 @@ Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/m
23
  and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
24
  """
25
 
 
 
 
 
 
 
 
26
  import warnings
 
 
27
  import torch
 
28
  from torch.nn.modules.linear import Linear
 
29
  from torch.nn.init import constant_
 
 
30
  from torch.nn.modules.module import Module
31
- from torch._jit_internal import Optional, Tuple
 
 
 
 
 
 
 
 
 
 
 
32
  try:
33
  from torch.overrides import has_torch_function, handle_torch_function
34
  except:
35
  from torch._overrides import has_torch_function, handle_torch_function
36
- from torch.nn.functional import linear, pad, softmax, dropout
37
  Tensor = torch.Tensor
38
 
 
 
39
  class MultiheadAttention(Module):
40
  r"""Allows the model to jointly attend to information
41
  from different representation subspaces.
 
23
  and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
24
  """
25
 
26
+ import copy
27
+ from typing import Optional, List
28
+
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from torch import nn, Tensor
32
+
33
  import warnings
34
+ from typing import Tuple, Optional
35
+
36
  import torch
37
+ from torch import Tensor
38
  from torch.nn.modules.linear import Linear
39
+ from torch.nn.init import xavier_uniform_
40
  from torch.nn.init import constant_
41
+ from torch.nn.init import xavier_normal_
42
+ from torch.nn.parameter import Parameter
43
  from torch.nn.modules.module import Module
44
+ from torch.nn import functional as F
45
+
46
+ import warnings
47
+ import math
48
+
49
+ from torch._C import _infer_size, _add_docstr
50
+ from torch.nn import _reduction as _Reduction
51
+ from torch.nn.modules import utils
52
+ from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
53
+ from torch.nn import grad
54
+ from torch import _VF
55
+ from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
56
  try:
57
  from torch.overrides import has_torch_function, handle_torch_function
58
  except:
59
  from torch._overrides import has_torch_function, handle_torch_function
 
60
  Tensor = torch.Tensor
61
 
62
+ from torch.nn.functional import linear, pad, softmax, dropout
63
+
64
  class MultiheadAttention(Module):
65
  r"""Allows the model to jointly attend to information
66
  from different representation subspaces.
src/utils/dependencies/XPose/models/UniPose/backbone.py CHANGED
@@ -16,18 +16,20 @@
16
  Backbone modules.
17
  """
18
 
 
 
19
  import torch
20
  import torch.nn.functional as F
21
  import torchvision
22
  from torch import nn
23
  from torchvision.models._utils import IntermediateLayerGetter
24
- from typing import Dict, List
25
 
26
- from util.misc import NestedTensor, is_main_process
27
 
28
  from .position_encoding import build_position_encoding
29
  from .swin_transformer import build_swin_transformer
30
 
 
31
  class FrozenBatchNorm2d(torch.nn.Module):
32
  """
33
  BatchNorm2d where the batch statistics and the affine parameters are fixed.
 
16
  Backbone modules.
17
  """
18
 
19
+ from typing import Dict, List
20
+
21
  import torch
22
  import torch.nn.functional as F
23
  import torchvision
24
  from torch import nn
25
  from torchvision.models._utils import IntermediateLayerGetter
 
26
 
27
+ from util.misc import NestedTensor, clean_state_dict, is_main_process
28
 
29
  from .position_encoding import build_position_encoding
30
  from .swin_transformer import build_swin_transformer
31
 
32
+
33
  class FrozenBatchNorm2d(torch.nn.Module):
34
  """
35
  BatchNorm2d where the batch statistics and the affine parameters are fixed.
src/utils/dependencies/XPose/models/UniPose/deformable_transformer.py CHANGED
@@ -16,16 +16,21 @@
16
  # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
  # ------------------------------------------------------------------------
18
 
19
- import math
 
20
  import copy
21
- import torch
22
- import torch.utils.checkpoint as checkpoint
23
- from torch import nn, Tensor
24
  from typing import Optional
 
25
  from util.misc import inverse_sigmoid
26
 
 
 
 
 
 
27
  from .transformer_vanilla import TransformerEncoderLayer
28
  from .fuse_modules import BiAttentionBlock
 
29
  from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position, get_sine_pos_embed
30
  from .ops.modules import MSDeformAttn
31
 
@@ -580,7 +585,7 @@ class TransformerEncoder(nn.Module):
580
  reference_points_list = []
581
  for lvl, (H_, W_) in enumerate(spatial_shapes):
582
  ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
583
- torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),)
584
  ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
585
  ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
586
  ref = torch.stack((ref_x, ref_y), -1)
@@ -1228,3 +1233,5 @@ def build_deformable_transformer(args):
1228
  binary_query_selection=binary_query_selection,
1229
  ffn_extra_layernorm=ffn_extra_layernorm,
1230
  )
 
 
 
16
  # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
  # ------------------------------------------------------------------------
18
 
19
+ import math, random
20
+ import os
21
  import copy
 
 
 
22
  from typing import Optional
23
+
24
  from util.misc import inverse_sigmoid
25
 
26
+ import torch
27
+ from torch import nn, Tensor
28
+
29
+ import torch.utils.checkpoint as checkpoint
30
+
31
  from .transformer_vanilla import TransformerEncoderLayer
32
  from .fuse_modules import BiAttentionBlock
33
+
34
  from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position, get_sine_pos_embed
35
  from .ops.modules import MSDeformAttn
36
 
 
585
  reference_points_list = []
586
  for lvl, (H_, W_) in enumerate(spatial_shapes):
587
  ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
588
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
589
  ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
590
  ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
591
  ref = torch.stack((ref_x, ref_y), -1)
 
1233
  binary_query_selection=binary_query_selection,
1234
  ffn_extra_layernorm=ffn_extra_layernorm,
1235
  )
1236
+
1237
+
src/utils/dependencies/XPose/models/UniPose/fuse_modules.py CHANGED
@@ -1,9 +1,12 @@
1
- import torch
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
 
 
4
 
5
- # from timm.models.layers import DropPath
6
- from src.modules.util import DropPath
7
 
8
  class FeatureResizer(nn.Module):
9
  """
@@ -178,7 +181,7 @@ class BiMultiHeadAttention(nn.Module):
178
 
179
  if self.stable_softmax_2d:
180
  attn_weights = attn_weights - attn_weights.max()
181
-
182
  if self.clamp_min_for_underflow:
183
  attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
184
  if self.clamp_max_for_overflow:
@@ -261,8 +264,8 @@ class BiAttentionBlock(nn.Module):
261
 
262
  # add layer scale for training stability
263
  self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
264
- self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=False)
265
- self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=False)
266
 
267
  def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
268
  v = self.layer_norm_v(v)
@@ -272,3 +275,4 @@ class BiAttentionBlock(nn.Module):
272
  v = v + self.drop_path(self.gamma_v * delta_v)
273
  l = l + self.drop_path(self.gamma_l * delta_l)
274
  return v, l
 
 
1
+ from typing import List
2
+ import torch, os
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ import pdb
6
+ import math
7
+ from timm.models.layers import DropPath
8
 
9
+ from transformers.activations import ACT2FN
 
10
 
11
  class FeatureResizer(nn.Module):
12
  """
 
181
 
182
  if self.stable_softmax_2d:
183
  attn_weights = attn_weights - attn_weights.max()
184
+
185
  if self.clamp_min_for_underflow:
186
  attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
187
  if self.clamp_max_for_overflow:
 
264
 
265
  # add layer scale for training stability
266
  self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
267
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
268
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
269
 
270
  def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
271
  v = self.layer_norm_v(v)
 
275
  v = v + self.drop_path(self.gamma_v * delta_v)
276
  l = l + self.drop_path(self.gamma_l * delta_l)
277
  return v, l
278
+
src/utils/dependencies/XPose/models/UniPose/mask_generate.py CHANGED
@@ -1,4 +1,10 @@
1
  import torch
 
 
 
 
 
 
2
 
3
 
4
  def prepare_for_mask(kpt_mask):
 
1
  import torch
2
+ from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
3
+ accuracy, get_world_size, interpolate,
4
+ is_dist_avail_and_initialized, inverse_sigmoid)
5
+ # from .DABDETR import sigmoid_focal_loss
6
+ from util import box_ops
7
+ import torch.nn.functional as F
8
 
9
 
10
  def prepare_for_mask(kpt_mask):
src/utils/dependencies/XPose/models/UniPose/ops/modules/ms_deform_attn.py CHANGED
@@ -20,7 +20,10 @@ from torch import nn
20
  import torch.nn.functional as F
21
  from torch.nn.init import xavier_uniform_, constant_
22
 
23
- from src.utils.dependencies.XPose.models.UniPose.ops.functions.ms_deform_attn_func import MSDeformAttnFunction
 
 
 
24
 
25
 
26
  def _is_power_of_2(n):
 
20
  import torch.nn.functional as F
21
  from torch.nn.init import xavier_uniform_, constant_
22
 
23
+ try:
24
+ from src.utils.dependencies.XPose.models.UniPose.ops.functions.ms_deform_attn_func import MSDeformAttnFunction
25
+ except:
26
+ warnings.warn('Failed to import MSDeformAttnFunction.')
27
 
28
 
29
  def _is_power_of_2(n):
src/utils/dependencies/XPose/models/UniPose/ops/setup.py CHANGED
@@ -41,12 +41,10 @@ def get_extensions():
41
  sources += source_cuda
42
  define_macros += [("WITH_CUDA", None)]
43
  extra_compile_args["nvcc"] = [
44
- # "-allow-unsupported-compiler",
45
  "-DCUDA_HAS_FP16=1",
46
  "-D__CUDA_NO_HALF_OPERATORS__",
47
  "-D__CUDA_NO_HALF_CONVERSIONS__",
48
  "-D__CUDA_NO_HALF2_OPERATORS__",
49
- # "-std=c++14",
50
  ]
51
  else:
52
  raise NotImplementedError('Cuda is not availabel')
@@ -64,7 +62,6 @@ def get_extensions():
64
  ]
65
  return ext_modules
66
 
67
-
68
  setup(
69
  name="MultiScaleDeformableAttention",
70
  version="1.0",
 
41
  sources += source_cuda
42
  define_macros += [("WITH_CUDA", None)]
43
  extra_compile_args["nvcc"] = [
 
44
  "-DCUDA_HAS_FP16=1",
45
  "-D__CUDA_NO_HALF_OPERATORS__",
46
  "-D__CUDA_NO_HALF_CONVERSIONS__",
47
  "-D__CUDA_NO_HALF2_OPERATORS__",
 
48
  ]
49
  else:
50
  raise NotImplementedError('Cuda is not availabel')
 
62
  ]
63
  return ext_modules
64
 
 
65
  setup(
66
  name="MultiScaleDeformableAttention",
67
  version="1.0",
src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.cu CHANGED
@@ -61,7 +61,7 @@ at::Tensor ms_deform_attn_cuda_forward(
61
  for (int n = 0; n < batch/im2col_step_; ++n)
62
  {
63
  auto columns = output_n.select(0, n);
64
- AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
65
  ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66
  value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67
  spatial_shapes.data<int64_t>(),
@@ -131,7 +131,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
131
  for (int n = 0; n < batch/im2col_step_; ++n)
132
  {
133
  auto grad_output_g = grad_output_n.select(0, n);
134
- AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
135
  ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136
  grad_output_g.data<scalar_t>(),
137
  value.data<scalar_t>() + n * im2col_step_ * per_value_size,
 
61
  for (int n = 0; n < batch/im2col_step_; ++n)
62
  {
63
  auto columns = output_n.select(0, n);
64
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65
  ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66
  value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67
  spatial_shapes.data<int64_t>(),
 
131
  for (int n = 0; n < batch/im2col_step_; ++n)
132
  {
133
  auto grad_output_g = grad_output_n.select(0, n);
134
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135
  ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136
  grad_output_g.data<scalar_t>(),
137
  value.data<scalar_t>() + n * im2col_step_ * per_value_size,
src/utils/dependencies/XPose/models/UniPose/position_encoding.py CHANGED
@@ -15,6 +15,7 @@
15
  Various positional encodings for the transformer.
16
  """
17
  import math
 
18
  import torch
19
  from torch import nn
20
 
 
15
  Various positional encodings for the transformer.
16
  """
17
  import math
18
+ import os
19
  import torch
20
  from torch import nn
21
 
src/utils/dependencies/XPose/models/UniPose/swin_transformer.py CHANGED
@@ -4,10 +4,8 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  import torch.utils.checkpoint as checkpoint
6
  import numpy as np
7
-
8
  from util.misc import NestedTensor
9
- # from timm.models.layers import DropPath, to_2tuple, trunc_normal_
10
- from src.modules.util import DropPath, to_2tuple, trunc_normal_
11
 
12
 
13
 
@@ -489,8 +487,8 @@ class SwinTransformer(nn.Module):
489
  self.frozen_stages = frozen_stages
490
  self.dilation = dilation
491
 
492
- # if use_checkpoint:
493
- # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
494
 
495
  # split image into non-overlapping patches
496
  self.patch_embed = PatchEmbed(
@@ -634,7 +632,7 @@ class SwinTransformer(nn.Module):
634
  # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
635
  # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
636
 
637
- # collect for nesttensors
638
  outs_dict = {}
639
  for idx, out_i in enumerate(outs):
640
  m = tensor_list.mask
@@ -661,7 +659,7 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw):
661
  depths=[ 2, 2, 6, 2 ],
662
  num_heads=[ 3, 6, 12, 24],
663
  window_size=7
664
- ),
665
  'swin_B_224_22k': dict(
666
  embed_dim=128,
667
  depths=[ 2, 2, 18, 2 ],
@@ -698,4 +696,4 @@ if __name__ == "__main__":
698
  y = model.forward_raw(x)
699
  import ipdb; ipdb.set_trace()
700
  x = torch.rand(2, 3, 384, 384)
701
- y = model.forward_raw(x)
 
4
  import torch.nn.functional as F
5
  import torch.utils.checkpoint as checkpoint
6
  import numpy as np
7
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
8
  from util.misc import NestedTensor
 
 
9
 
10
 
11
 
 
487
  self.frozen_stages = frozen_stages
488
  self.dilation = dilation
489
 
490
+ if use_checkpoint:
491
+ print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
492
 
493
  # split image into non-overlapping patches
494
  self.patch_embed = PatchEmbed(
 
632
  # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
633
  # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
634
 
635
+ # collect for nesttensors
636
  outs_dict = {}
637
  for idx, out_i in enumerate(outs):
638
  m = tensor_list.mask
 
659
  depths=[ 2, 2, 6, 2 ],
660
  num_heads=[ 3, 6, 12, 24],
661
  window_size=7
662
+ ),
663
  'swin_B_224_22k': dict(
664
  embed_dim=128,
665
  depths=[ 2, 2, 18, 2 ],
 
696
  y = model.forward_raw(x)
697
  import ipdb; ipdb.set_trace()
698
  x = torch.rand(2, 3, 384, 384)
699
+ y = model.forward_raw(x)
src/utils/dependencies/XPose/models/UniPose/transformer_deformable.py CHANGED
@@ -12,15 +12,19 @@
12
  # ------------------------------------------------------------------------
13
 
14
  import copy
 
 
15
  import math
 
16
  import torch
 
17
  from torch import nn, Tensor
18
- from torch.nn.init import xavier_uniform_, constant_, normal_
19
- from typing import Optional
20
 
21
  from util.misc import inverse_sigmoid
22
  from .ops.modules import MSDeformAttn
23
- from .utils import MLP, _get_activation_fn, gen_sineembed_for_position
 
24
 
25
  class DeformableTransformer(nn.Module):
26
  def __init__(self, d_model=256, nhead=8,
@@ -45,7 +49,7 @@ class DeformableTransformer(nn.Module):
45
  decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
46
  dropout, activation,
47
  num_feature_levels, nhead, dec_n_points)
48
- self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec,
49
  use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed)
50
 
51
  self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
@@ -158,7 +162,7 @@ class DeformableTransformer(nn.Module):
158
  lvl_pos_embed_flatten.append(lvl_pos_embed)
159
  src_flatten.append(src)
160
  mask_flatten.append(mask)
161
- src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
162
  mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
163
  lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
164
  spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
@@ -187,7 +191,7 @@ class DeformableTransformer(nn.Module):
187
  pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
188
  query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
189
  elif self.use_dab:
190
- reference_points = query_embed[..., self.d_model:].sigmoid()
191
  tgt = query_embed[..., :self.d_model]
192
  tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
193
  init_reference_out = reference_points
@@ -195,15 +199,15 @@ class DeformableTransformer(nn.Module):
195
  query_embed, tgt = torch.split(query_embed, c, dim=1)
196
  query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
197
  tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
198
- reference_points = self.reference_points(query_embed).sigmoid()
199
  # bs, num_quires, 2
200
  init_reference_out = reference_points
201
 
202
  # decoder
203
  # import ipdb; ipdb.set_trace()
204
  hs, inter_references = self.decoder(tgt, reference_points, memory,
205
- spatial_shapes, level_start_index, valid_ratios,
206
- query_pos=query_embed if not self.use_dab else None,
207
  src_padding_mask=mask_flatten)
208
 
209
  inter_references_out = inter_references
@@ -387,7 +391,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
387
  tgt = self.norm3(tgt)
388
  return tgt
389
 
390
- def forward_sa(self,
391
  # for tgt
392
  tgt: Optional[Tensor], # nq, bs, d_model
393
  tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
@@ -431,9 +435,9 @@ class DeformableTransformerDecoderLayer(nn.Module):
431
  else:
432
  raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
433
 
434
- return tgt
435
 
436
- def forward_ca(self,
437
  # for tgt
438
  tgt: Optional[Tensor], # nq, bs, d_model
439
  tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
@@ -468,9 +472,9 @@ class DeformableTransformerDecoderLayer(nn.Module):
468
  tgt = tgt + self.dropout1(tgt2)
469
  tgt = self.norm1(tgt)
470
 
471
- return tgt
472
 
473
- def forward(self,
474
  # for tgt
475
  tgt: Optional[Tensor], # nq, bs, d_model
476
  tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
@@ -530,7 +534,7 @@ class DeformableTransformerDecoder(nn.Module):
530
  self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
531
 
532
 
533
- def forward(self, tgt, reference_points, src, src_spatial_shapes,
534
  src_level_start_index, src_valid_ratios,
535
  query_pos=None, src_padding_mask=None):
536
  output = tgt
@@ -547,14 +551,14 @@ class DeformableTransformerDecoder(nn.Module):
547
  else:
548
  assert reference_points.shape[-1] == 2
549
  reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
550
-
551
  if self.use_dab:
552
  # import ipdb; ipdb.set_trace()
553
- query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2
554
  raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256
555
  pos_scale = self.query_scale(output) if layer_id != 0 else 1
556
  query_pos = pos_scale * raw_query_pos
557
-
558
  output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
559
 
560
  # hack implementation for iterative bounding box refinement
@@ -593,3 +597,5 @@ def build_deforamble_transformer(args):
593
  use_dab=args.ddetr_use_dab,
594
  high_dim_query_update=args.ddetr_high_dim_query_update,
595
  no_sine_embed=args.ddetr_no_sine_embed)
 
 
 
12
  # ------------------------------------------------------------------------
13
 
14
  import copy
15
+ import os
16
+ from typing import Optional, List
17
  import math
18
+
19
  import torch
20
+ import torch.nn.functional as F
21
  from torch import nn, Tensor
22
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
 
23
 
24
  from util.misc import inverse_sigmoid
25
  from .ops.modules import MSDeformAttn
26
+
27
+ from .utils import sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
28
 
29
  class DeformableTransformer(nn.Module):
30
  def __init__(self, d_model=256, nhead=8,
 
49
  decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
50
  dropout, activation,
51
  num_feature_levels, nhead, dec_n_points)
52
+ self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec,
53
  use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed)
54
 
55
  self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
 
162
  lvl_pos_embed_flatten.append(lvl_pos_embed)
163
  src_flatten.append(src)
164
  mask_flatten.append(mask)
165
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
166
  mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
167
  lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
168
  spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
 
191
  pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
192
  query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
193
  elif self.use_dab:
194
+ reference_points = query_embed[..., self.d_model:].sigmoid()
195
  tgt = query_embed[..., :self.d_model]
196
  tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
197
  init_reference_out = reference_points
 
199
  query_embed, tgt = torch.split(query_embed, c, dim=1)
200
  query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
201
  tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
202
+ reference_points = self.reference_points(query_embed).sigmoid()
203
  # bs, num_quires, 2
204
  init_reference_out = reference_points
205
 
206
  # decoder
207
  # import ipdb; ipdb.set_trace()
208
  hs, inter_references = self.decoder(tgt, reference_points, memory,
209
+ spatial_shapes, level_start_index, valid_ratios,
210
+ query_pos=query_embed if not self.use_dab else None,
211
  src_padding_mask=mask_flatten)
212
 
213
  inter_references_out = inter_references
 
391
  tgt = self.norm3(tgt)
392
  return tgt
393
 
394
+ def forward_sa(self,
395
  # for tgt
396
  tgt: Optional[Tensor], # nq, bs, d_model
397
  tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
 
435
  else:
436
  raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
437
 
438
+ return tgt
439
 
440
+ def forward_ca(self,
441
  # for tgt
442
  tgt: Optional[Tensor], # nq, bs, d_model
443
  tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
 
472
  tgt = tgt + self.dropout1(tgt2)
473
  tgt = self.norm1(tgt)
474
 
475
+ return tgt
476
 
477
+ def forward(self,
478
  # for tgt
479
  tgt: Optional[Tensor], # nq, bs, d_model
480
  tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
 
534
  self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
535
 
536
 
537
+ def forward(self, tgt, reference_points, src, src_spatial_shapes,
538
  src_level_start_index, src_valid_ratios,
539
  query_pos=None, src_padding_mask=None):
540
  output = tgt
 
551
  else:
552
  assert reference_points.shape[-1] == 2
553
  reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
554
+
555
  if self.use_dab:
556
  # import ipdb; ipdb.set_trace()
557
+ query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2
558
  raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256
559
  pos_scale = self.query_scale(output) if layer_id != 0 else 1
560
  query_pos = pos_scale * raw_query_pos
561
+
562
  output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
563
 
564
  # hack implementation for iterative bounding box refinement
 
597
  use_dab=args.ddetr_use_dab,
598
  high_dim_query_update=args.ddetr_high_dim_query_update,
599
  no_sine_embed=args.ddetr_no_sine_embed)
600
+
601
+
src/utils/dependencies/XPose/models/UniPose/transformer_vanilla.py CHANGED
@@ -8,11 +8,15 @@ Copy-paste from torch.nn.Transformer with modifications:
8
  * extra LN at the end of encoder is removed
9
  * decoder returns a stack of activations from all decoding layers
10
  """
 
 
 
 
11
  import torch
 
12
  from torch import Tensor, nn
13
- from typing import List, Optional
14
 
15
- from .utils import _get_activation_fn, _get_clones
16
 
17
 
18
  class TextTransformer(nn.Module):
 
8
  * extra LN at the end of encoder is removed
9
  * decoder returns a stack of activations from all decoding layers
10
  """
11
+ import copy
12
+ import os
13
+ from typing import List, Optional
14
+ import pdb
15
  import torch
16
+ import torch.nn.functional as F
17
  from torch import Tensor, nn
 
18
 
19
+ from .utils import gen_encoder_output_proposals, sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position, _get_clones
20
 
21
 
22
  class TextTransformer(nn.Module):
src/utils/dependencies/XPose/models/UniPose/unipose.py CHANGED
@@ -6,21 +6,30 @@
6
  # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
7
  # Copyright (c) 2020 SenseTime. All Rights Reserved.
8
  # ------------------------------------------------------------------------
9
- import os
10
  import copy
 
 
 
11
  import torch
12
  import torch.nn.functional as F
13
  from torch import nn
14
- from typing import List
15
-
16
  from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
17
- from util.misc import NestedTensor, nested_tensor_from_tensor_list,inverse_sigmoid
18
-
19
- from .utils import MLP
 
20
  from .backbone import build_backbone
 
 
 
21
  from ..registry import MODULE_BUILD_FUNCS
22
  from .mask_generate import prepare_for_mask, post_process
23
- from .deformable_transformer import build_deformable_transformer
 
 
 
 
24
 
25
 
26
  class UniPose(nn.Module):
@@ -107,12 +116,12 @@ class UniPose(nn.Module):
107
 
108
 
109
  device = "cuda" if torch.cuda.is_available() else "cpu"
110
- # model, _ = clip.load("ViT-B/32", device=device)
111
- # self.clip_model = model
112
- # visual_parameters = list(self.clip_model.visual.parameters())
113
- # #
114
- # for param in visual_parameters:
115
- # param.requires_grad = False
116
 
117
  self.pos_proj = nn.Linear(hidden_dim, 768)
118
  self.padding = nn.Embedding(1, 768)
@@ -531,7 +540,7 @@ def build_unipose(args):
531
  sub_sentence_present = args.sub_sentence_present
532
  except:
533
  sub_sentence_present = True
534
- # print('********* sub_sentence_present', sub_sentence_present)
535
 
536
  model = UniPose(
537
  backbone,
 
6
  # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
7
  # Copyright (c) 2020 SenseTime. All Rights Reserved.
8
  # ------------------------------------------------------------------------
 
9
  import copy
10
+ import math
11
+ import os
12
+ from typing import List
13
  import torch
14
  import torch.nn.functional as F
15
  from torch import nn
16
+ from torchvision.ops.boxes import nms
 
17
  from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
18
+ from util import box_ops
19
+ from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
20
+ accuracy, get_world_size, interpolate,
21
+ is_dist_avail_and_initialized, inverse_sigmoid)
22
  from .backbone import build_backbone
23
+ from .deformable_transformer import build_deformable_transformer
24
+ from .utils import sigmoid_focal_loss, MLP
25
+
26
  from ..registry import MODULE_BUILD_FUNCS
27
  from .mask_generate import prepare_for_mask, post_process
28
+ import random
29
+ from .utils import sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
30
+ from pathlib import Path
31
+ import clip
32
+
33
 
34
 
35
  class UniPose(nn.Module):
 
116
 
117
 
118
  device = "cuda" if torch.cuda.is_available() else "cpu"
119
+ model, _ = clip.load("ViT-B/32", device=device)
120
+ self.clip_model = model
121
+ visual_parameters = list(self.clip_model.visual.parameters())
122
+ #
123
+ for param in visual_parameters:
124
+ param.requires_grad = False
125
 
126
  self.pos_proj = nn.Linear(hidden_dim, 768)
127
  self.padding = nn.Embedding(1, 768)
 
540
  sub_sentence_present = args.sub_sentence_present
541
  except:
542
  sub_sentence_present = True
543
+ print('********* sub_sentence_present', sub_sentence_present)
544
 
545
  model = UniPose(
546
  backbone,
src/utils/dependencies/XPose/models/UniPose/utils.py CHANGED
@@ -345,4 +345,4 @@ class OKSLoss(nn.Module):
345
  linear=self.linear,
346
  sigmas=self.sigmas,
347
  eps=self.eps)
348
- return loss
 
345
  linear=self.linear,
346
  sigmas=self.sigmas,
347
  eps=self.eps)
348
+ return loss
src/utils/dependencies/XPose/transforms.py CHANGED
@@ -24,6 +24,7 @@ def crop(image, target, region):
24
  i, j, h, w = region
25
  id2catname = target["id2catname"]
26
  caption_list = target["caption_list"]
 
27
  target["size"] = torch.tensor([h, w])
28
 
29
  fields = ["labels", "area", "iscrowd", "positive_map","keypoints"]
 
24
  i, j, h, w = region
25
  id2catname = target["id2catname"]
26
  caption_list = target["caption_list"]
27
+ # should we do something wrt the original size?
28
  target["size"] = torch.tensor([h, w])
29
 
30
  fields = ["labels", "area", "iscrowd", "positive_map","keypoints"]
src/utils/dependencies/XPose/util/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
src/utils/dependencies/XPose/util/addict.py DELETED
@@ -1,159 +0,0 @@
1
- import copy
2
-
3
-
4
- class Dict(dict):
5
-
6
- def __init__(__self, *args, **kwargs):
7
- object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
8
- object.__setattr__(__self, '__key', kwargs.pop('__key', None))
9
- object.__setattr__(__self, '__frozen', False)
10
- for arg in args:
11
- if not arg:
12
- continue
13
- elif isinstance(arg, dict):
14
- for key, val in arg.items():
15
- __self[key] = __self._hook(val)
16
- elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
17
- __self[arg[0]] = __self._hook(arg[1])
18
- else:
19
- for key, val in iter(arg):
20
- __self[key] = __self._hook(val)
21
-
22
- for key, val in kwargs.items():
23
- __self[key] = __self._hook(val)
24
-
25
- def __setattr__(self, name, value):
26
- if hasattr(self.__class__, name):
27
- raise AttributeError("'Dict' object attribute "
28
- "'{0}' is read-only".format(name))
29
- else:
30
- self[name] = value
31
-
32
- def __setitem__(self, name, value):
33
- isFrozen = (hasattr(self, '__frozen') and
34
- object.__getattribute__(self, '__frozen'))
35
- if isFrozen and name not in super(Dict, self).keys():
36
- raise KeyError(name)
37
- super(Dict, self).__setitem__(name, value)
38
- try:
39
- p = object.__getattribute__(self, '__parent')
40
- key = object.__getattribute__(self, '__key')
41
- except AttributeError:
42
- p = None
43
- key = None
44
- if p is not None:
45
- p[key] = self
46
- object.__delattr__(self, '__parent')
47
- object.__delattr__(self, '__key')
48
-
49
- def __add__(self, other):
50
- if not self.keys():
51
- return other
52
- else:
53
- self_type = type(self).__name__
54
- other_type = type(other).__name__
55
- msg = "unsupported operand type(s) for +: '{}' and '{}'"
56
- raise TypeError(msg.format(self_type, other_type))
57
-
58
- @classmethod
59
- def _hook(cls, item):
60
- if isinstance(item, dict):
61
- return cls(item)
62
- elif isinstance(item, (list, tuple)):
63
- return type(item)(cls._hook(elem) for elem in item)
64
- return item
65
-
66
- def __getattr__(self, item):
67
- return self.__getitem__(item)
68
-
69
- def __missing__(self, name):
70
- if object.__getattribute__(self, '__frozen'):
71
- raise KeyError(name)
72
- return self.__class__(__parent=self, __key=name)
73
-
74
- def __delattr__(self, name):
75
- del self[name]
76
-
77
- def to_dict(self):
78
- base = {}
79
- for key, value in self.items():
80
- if isinstance(value, type(self)):
81
- base[key] = value.to_dict()
82
- elif isinstance(value, (list, tuple)):
83
- base[key] = type(value)(
84
- item.to_dict() if isinstance(item, type(self)) else
85
- item for item in value)
86
- else:
87
- base[key] = value
88
- return base
89
-
90
- def copy(self):
91
- return copy.copy(self)
92
-
93
- def deepcopy(self):
94
- return copy.deepcopy(self)
95
-
96
- def __deepcopy__(self, memo):
97
- other = self.__class__()
98
- memo[id(self)] = other
99
- for key, value in self.items():
100
- other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
101
- return other
102
-
103
- def update(self, *args, **kwargs):
104
- other = {}
105
- if args:
106
- if len(args) > 1:
107
- raise TypeError()
108
- other.update(args[0])
109
- other.update(kwargs)
110
- for k, v in other.items():
111
- if ((k not in self) or
112
- (not isinstance(self[k], dict)) or
113
- (not isinstance(v, dict))):
114
- self[k] = v
115
- else:
116
- self[k].update(v)
117
-
118
- def __getnewargs__(self):
119
- return tuple(self.items())
120
-
121
- def __getstate__(self):
122
- return self
123
-
124
- def __setstate__(self, state):
125
- self.update(state)
126
-
127
- def __or__(self, other):
128
- if not isinstance(other, (Dict, dict)):
129
- return NotImplemented
130
- new = Dict(self)
131
- new.update(other)
132
- return new
133
-
134
- def __ror__(self, other):
135
- if not isinstance(other, (Dict, dict)):
136
- return NotImplemented
137
- new = Dict(other)
138
- new.update(self)
139
- return new
140
-
141
- def __ior__(self, other):
142
- self.update(other)
143
- return self
144
-
145
- def setdefault(self, key, default=None):
146
- if key in self:
147
- return self[key]
148
- else:
149
- self[key] = default
150
- return default
151
-
152
- def freeze(self, shouldFreeze=True):
153
- object.__setattr__(self, '__frozen', shouldFreeze)
154
- for key, val in self.items():
155
- if isinstance(val, Dict):
156
- val.freeze(shouldFreeze)
157
-
158
- def unfreeze(self):
159
- self.freeze(False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/dependencies/XPose/util/box_ops.py CHANGED
@@ -136,4 +136,4 @@ if __name__ == '__main__':
136
  x = torch.rand(5, 4)
137
  y = torch.rand(3, 4)
138
  iou, union = box_iou(x, y)
139
- import ipdb; ipdb.set_trace()
 
136
  x = torch.rand(5, 4)
137
  y = torch.rand(3, 4)
138
  iou, union = box_iou(x, y)
139
+ import ipdb; ipdb.set_trace()
src/utils/dependencies/XPose/util/config.py CHANGED
@@ -1,15 +1,17 @@
1
  # ==========================================================
2
  # Modified from mmcv
3
  # ==========================================================
4
- import sys
5
  import os.path as osp
6
  import ast
7
  import tempfile
8
  import shutil
9
  from importlib import import_module
 
10
  from argparse import Action
11
 
12
- from .addict import Dict
 
13
 
14
  BASE_KEY = '_base_'
15
  DELETE_KEY = '_delete_'
@@ -81,8 +83,6 @@ class Config(object):
81
  temp_config_file = tempfile.NamedTemporaryFile(
82
  dir=temp_config_dir, suffix='.py')
83
  temp_config_name = osp.basename(temp_config_file.name)
84
- # close temp file before copy
85
- temp_config_file.close()
86
  shutil.copyfile(filename,
87
  osp.join(temp_config_dir, temp_config_name))
88
  temp_module_name = osp.splitext(temp_config_name)[0]
@@ -97,8 +97,8 @@ class Config(object):
97
  }
98
  # delete imported module
99
  del sys.modules[temp_module_name]
100
-
101
-
102
  elif filename.lower().endswith(('.yml', '.yaml', '.json')):
103
  from .slio import slload
104
  cfg_dict = slload(filename)
@@ -304,6 +304,13 @@ class Config(object):
304
 
305
  cfg_dict = self._cfg_dict.to_dict()
306
  text = _format_dict(cfg_dict, outest_level=True)
 
 
 
 
 
 
 
307
  return text
308
 
309
 
 
1
  # ==========================================================
2
  # Modified from mmcv
3
  # ==========================================================
4
+ import os, sys
5
  import os.path as osp
6
  import ast
7
  import tempfile
8
  import shutil
9
  from importlib import import_module
10
+
11
  from argparse import Action
12
 
13
+ from addict import Dict
14
+ from yapf.yapflib.yapf_api import FormatCode
15
 
16
  BASE_KEY = '_base_'
17
  DELETE_KEY = '_delete_'
 
83
  temp_config_file = tempfile.NamedTemporaryFile(
84
  dir=temp_config_dir, suffix='.py')
85
  temp_config_name = osp.basename(temp_config_file.name)
 
 
86
  shutil.copyfile(filename,
87
  osp.join(temp_config_dir, temp_config_name))
88
  temp_module_name = osp.splitext(temp_config_name)[0]
 
97
  }
98
  # delete imported module
99
  del sys.modules[temp_module_name]
100
+ # close temp file
101
+ temp_config_file.close()
102
  elif filename.lower().endswith(('.yml', '.yaml', '.json')):
103
  from .slio import slload
104
  cfg_dict = slload(filename)
 
304
 
305
  cfg_dict = self._cfg_dict.to_dict()
306
  text = _format_dict(cfg_dict, outest_level=True)
307
+ # copied from setup.cfg
308
+ yapf_style = dict(
309
+ based_on_style='pep8',
310
+ blank_line_before_nested_class_or_def=True,
311
+ split_before_expression_after_opening_paren=True)
312
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
313
+
314
  return text
315
 
316
 
src/utils/dependencies/XPose/util/get_param_dicts.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ def match_name_keywords(n: str, name_keywords: list):
7
+ out = False
8
+ for b in name_keywords:
9
+ if b in n:
10
+ out = True
11
+ break
12
+ return out
13
+
14
+
15
+ def get_param_dict(args, model_without_ddp: nn.Module):
16
+ try:
17
+ param_dict_type = args.param_dict_type
18
+ except:
19
+ param_dict_type = 'default'
20
+ assert param_dict_type in ['default', 'ddetr_in_mmdet', 'large_wd']
21
+
22
+ # by default
23
+ if param_dict_type == 'default':
24
+ param_dicts = [
25
+ {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and "bert" not in n and p.requires_grad]},
26
+ {
27
+ "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
28
+ "lr": args.lr_backbone,
29
+ },
30
+ {
31
+ "params": [p for n, p in model_without_ddp.named_parameters() if "bert" in n and p.requires_grad],
32
+ "lr": args.lr_backbone,
33
+ }
34
+ ]
35
+
36
+ param_name_dicts = [
37
+ {"params": [n for n, p in model_without_ddp.named_parameters() if "backbone" not in n and "bert" not in n and p.requires_grad]},
38
+ {
39
+ "params": [n for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
40
+ "lr": args.lr_backbone,
41
+ },
42
+ {
43
+ "params": [n for n, p in model_without_ddp.named_parameters() if "bert" in n and p.requires_grad],
44
+ "lr": args.lr_backbone,
45
+ }
46
+ ]
47
+
48
+ print('param_name_dicts: ', json.dumps(param_name_dicts, indent=2))
49
+
50
+ return param_dicts, param_name_dicts
51
+
52
+
53
+
54
+
55
+ raise NotImplementedError
56
+
57
+
58
+
59
+ # print("param_dicts: {}".format(param_dicts))
60
+
61
+ return param_dicts, None
src/utils/dependencies/XPose/util/instance.txt ADDED
@@ -0,0 +1,863 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ AnimalKindom (850 but only providing 402 non-overlapping animal classes):
2
+ Orange Clownfish
3
+ Fish
4
+ Damsel Fish
5
+ Sterlet Fish
6
+ Trout
7
+ Danube Bleak Fish
8
+ Grayling Fish
9
+ Catfish
10
+ Sea Toad Fish
11
+ Pike Perch Fish
12
+ Sardine
13
+ Archer Fish
14
+ Giant Trevally
15
+ Atlantic Blue Tang Fish
16
+ Trout Young
17
+ Stonefish
18
+ Yellow Watchman Goby Fish
19
+ Salmon
20
+ Danube Salmon
21
+ Mimic Blenny Fish
22
+ Pink Skunk Clownfish
23
+ Barracuda Fish
24
+ Surgeonfish
25
+ Keeltail Needlefish
26
+ Royal Grammas Fish
27
+ Tench Fish
28
+ Clownfish
29
+ Barramundi Fish
30
+ Sergeant Major Fish
31
+ Gray Angelfish
32
+ Butterfly Fish
33
+ Yellow Wrasse Fish
34
+ Lionfish
35
+ Toadfish
36
+ Perch Fish
37
+ Round Face Bat Fish
38
+ Goby Fish
39
+ Horned Adder
40
+ Atheris Hispida Viper
41
+ Fer-De-Lance Snake
42
+ Naja Nivea Snake
43
+ Elegant Bronzeback Snake
44
+ S Viper"]
45
+ Black Necked Spitting Cobra
46
+ Dice Snake
47
+ Boa
48
+ Dispholidus Typus Snake
49
+ Red Spitting Cobra
50
+ Wild Red-Tailed Boa
51
+ S Spitting Cobra"]
52
+ Coronella Austriaca Snake
53
+ Lesser Sunda Pit Viper
54
+ Oriental Whip Snake
55
+ Snake
56
+ Mojave Rattlesnake
57
+ Annulated Tree Boa
58
+ Bushmaster Snake
59
+ Sidewinder Rattlesnake
60
+ Vipera Berus Snake
61
+ Reticulated Python
62
+ King Cobra
63
+ Atheris Squamigera
64
+ Lichanura Trivirgata Snake
65
+ Crotalus Willardi Ridge Nosed Rattlesnake
66
+ Slender Hognosed Pit Viper
67
+ Paradise Tree Snake
68
+ Coral Mimic Snake
69
+ Many Horned Adder
70
+ Dendroaspis Polylepis Black Mamba
71
+ Northern Pacific Rattlesnake
72
+ S Boa"]
73
+ Mozambique Spitting Cobra
74
+ Grass Snake
75
+ Thelotornis Snake
76
+ Black Headed Python
77
+ Metlapilcoatlus Mexicanus Jumping Pit Viper
78
+ Natrix Natrix Snake
79
+ Atheris Nitschei Viper
80
+ Puff Adder
81
+ Rhamnophis Aethiopissa Snake
82
+ Coral Snake
83
+ Montpellier Snake
84
+ Bothriechis
85
+ Lampropeltis Pyromelana Snake
86
+ Lampropeltis Zonata Snake
87
+ Natrix Tessellata Snake
88
+ Thamnophis Cyrtopsis Snake
89
+ Rat Snake
90
+ White Speckled Rattlesnake
91
+ Namaqua Dwarf Adder
92
+ Aesculapian Snake
93
+ Nose-Horned Viper
94
+ Zamenis Longissiumus Snake
95
+ Rhinoceros Viper
96
+ Rattlesnake
97
+ Red Bellied Black Snake
98
+ Eyelash Pit Viper
99
+ Snouted Cobra
100
+ JamesonS Mamba
101
+ Gaboon Viper
102
+ Lampropeltis Splendida Snake
103
+ Laticauda Saintgironsi Sea Krait
104
+ Pituophis Catenifer Snake
105
+ Eastern Montpellier Snake
106
+ Black Mamba
107
+ Javan Spitting Cobra
108
+ Sea Snake
109
+ Gyalopion Canum Snake
110
+ Mojave Rattlesnake Young
111
+ Spectacled Cobra
112
+ Hognosed Pit Viper
113
+ Dog Faced Water Snake
114
+ Cobra
115
+ Tree Snake
116
+ Lampropeltis Getula Snake
117
+ S Mamba"]
118
+ Bothrops Asper
119
+ Rinkhals Snake
120
+ Malabar Pit Viper
121
+ Western Diamondback Young
122
+ Micruroides Euryxanthus Snake
123
+ Western Diamondback
124
+ Python
125
+ Viper
126
+ Horse
127
+ Horse Young
128
+ Texas Brown Tarantula
129
+ Tarantula
130
+ Diving Bell Water Spider
131
+ Turret Spider
132
+ Avicularia Spider
133
+ Golden Orb Spider
134
+ Spider
135
+ Salticidae Jumping Spider
136
+ Salticidae 3 Spider
137
+ Jumping Spider
138
+ Orb Spider
139
+ Latrodectus Hersperus Western Widow Spider
140
+ Redback Spider
141
+ Western Widow Spider
142
+ Nursery Web Spider
143
+ Grass Spider
144
+ Araneus Diadematus Spider
145
+ Tetragnatha Versicolor
146
+ Cyclosa Conica
147
+ Daddy Longlegs Spider
148
+ Portia Jumping Spider
149
+ Portia 3 Spider
150
+ Habronattus Clypeatus
151
+ Chimpanzee
152
+ Gorilla
153
+ Mountain Gorilla
154
+ Stump-Tailed Macaque
155
+ Pig-Tailed Macaque
156
+ Monkey
157
+ Orangutan
158
+ Grey Langur
159
+ Mitred Leaf Monkey
160
+ Gibbon
161
+ Red Ruffed Lemur
162
+ Proboscis Monkey
163
+ Sumatran Orangutan
164
+ Red-Ruffed Lemur
165
+ White-Faced Saki Monkey
166
+ Bornean Orangutan
167
+ Mandrill
168
+ Maroon Macaque
169
+ Rhesus Macaque
170
+ Raffles Banded Langur
171
+ Western Chimpanzee Young
172
+ Diana Monkey
173
+ Lesser Spot Nosed Monkey
174
+ S Monkey"]
175
+ Sooty Mangabey
176
+ King Colobus
177
+ Olive Colobus
178
+ Capuchin Monkey
179
+ Red-Blacked Squirrel Monkey
180
+ Ring-Tailed Lemur
181
+ Iguana
182
+ Marine Iguana
183
+ Rhacodactylus Trachyrhynchus Gecko
184
+ Gecko
185
+ Lizard
186
+ Common Basilisk Lizard
187
+ Basilisk Lizard
188
+ Monitor Lizard
189
+ Strange-Horned Chameleon
190
+ Chameleon
191
+ Leaf-Tailed Gecko
192
+ Clouded Monitor Lizard
193
+ S Chameleon"]
194
+ Indian Chameleon
195
+ Namaqua Dwarf Chameleon
196
+ Malayan Water Monitor Lizard
197
+ Green Iguana
198
+ White Wig Marine Iguana
199
+ Skink
200
+ Black Bearded Draco
201
+ Yellow Striped Tree Skink
202
+ Side Blotched Lizard
203
+ Frilled Neck Lizard
204
+ Leopard Gecko
205
+ Morning Gecko
206
+ Giant Ground Gecko
207
+ Western Dwarf Chameleon
208
+ Whooper Swan
209
+ Goose
210
+ Puffin
211
+ Mallard Duck
212
+ Greylag Goose
213
+ S Duck"]
214
+ African Finfoot
215
+ Wandering Alabatross
216
+ Duck
217
+ Common Goldeneye
218
+ Garganey
219
+ Black Stork
220
+ Black Swan
221
+ Сommon Eider
222
+ Carolina Duck
223
+ Smew
224
+ Lanius Excubitor
225
+ Song Thrush Bird
226
+ Wedge Tailed Eagle
227
+ Great Grey Shrike
228
+ Anthus Pratensis Bird
229
+ Red-Throated Pipit
230
+ Tit Bird
231
+ Hoopoe
232
+ Woodpecker
233
+ Botaurus Stellaris Bird
234
+ Little Crake Bird
235
+ White-Throated Dipper
236
+ Raven
237
+ Citrine Wagtail Bird
238
+ Yellowhammer Young
239
+ Grey Heron
240
+ Alauda Arvensis Bird
241
+ Bird
242
+ Common Cuckoo Bird
243
+ Tringa Erythropus Bird
244
+ White Throated Dipper Bird
245
+ Skylark
246
+ Mistle Thrush
247
+ Robin Bird
248
+ Australian Bowerbird
249
+ Turtle Dove
250
+ Black-Winged Stilt
251
+ Wood Warbler
252
+ Common Crane
253
+ Eurasian Wren Bird
254
+ Common Quail
255
+ Nightingale Bird
256
+ Tawny Owl
257
+ Grebe Bird
258
+ Water Dipper Bird
259
+ Yellowhammer
260
+ Hazel Grouse Bird
261
+ Greater Racket Tail Drongo
262
+ Golden Oriole
263
+ Great Egret
264
+ Turdus Merula Blackbird
265
+ Nuthatch Bird
266
+ Eagle
267
+ Luscinia Luscinia Nightingale Bird
268
+ Singing Nightingale
269
+ Water Rail Bird
270
+ Gull
271
+ Azure Tit Bird
272
+ Numenius Arquata Bird
273
+ Golden Eagle
274
+ Remiz Pendulinus Bird
275
+ Goldfinch
276
+ Common Whitethroat Bird
277
+ Red-Backed Shrike Bird
278
+ Grasshopper Warbler
279
+ Shoebill Bird
280
+ Common Rosefinch Bird
281
+ Owl
282
+ Chaffinch Bird
283
+ Bluethroat
284
+ Green Woodpecker
285
+ Common Snipe
286
+ Whinchat Bird
287
+ Ostrich
288
+ Boreal Owl
289
+ European Robin Bird
290
+ Larus Canus Bird
291
+ Hawk
292
+ Three-Toed Woodpecker
293
+ Thrush Nightingale Bird
294
+ Jack Snipe Bird
295
+ Red Crossbill
296
+ Chiffchaff Bird
297
+ Shorebird
298
+ Bullfinch
299
+ Red-Backed Shrike Bird Young
300
+ Circus Aeruginosus Bird
301
+ Kingfisher
302
+ White-Backed Woodpecker
303
+ Tringa Ochropus Bird
304
+ Stock Dove
305
+ Heron
306
+ Citrine Wagtail
307
+ Vanellus Vanellus Bird
308
+ Tringa Nebularia Bird
309
+ Eurasian Wryneck Bird
310
+ Tachybaptus Ruficollis Bird
311
+ Quail
312
+ Little Egret
313
+ Stork
314
+ Green Mamba
315
+ Pufferfish
316
+ Bullfrog
317
+ Frog
318
+ Corroboree Frog
319
+ Desert Rain Frog
320
+ African Clawed Toad
321
+ Mountain Yellow-Legged Frog
322
+ Tropical Reed Frog
323
+ Mimic Poison Frog
324
+ Red-Eyed Tree Frog
325
+ S Frog"]
326
+ Water Lily Frog
327
+ Amazon Milk Frog
328
+ Toad
329
+ Marbled Rubber Frog
330
+ Sand Frog
331
+ Golden Poison Frog
332
+ Rain Frog
333
+ Monster Frog
334
+ S Warbler Bird"]
335
+ Great Snipe
336
+ European Serin Bird
337
+ Calidris Apina Bird
338
+ Peacock
339
+ Tringa Glareola Bird
340
+ Cuckoo Bird
341
+ Barred Warbler Bird
342
+ Pacman Frog
343
+ Ardea Alba Egret
344
+ Tern
345
+ Motacilla Alba Bird
346
+ Motacilla Flava
347
+ Anas Crecca Bird
348
+ Blue Poison Dart Frog
349
+ Marsh Harrier Bird
350
+ Glass Frog
351
+ Great Reed Warbler Bird
352
+ Banded Rubber Frog
353
+ Tomato Frog
354
+ Hornbill
355
+ Woodlark Bird
356
+ Starling Bird
357
+ Common Buzzard
358
+ Gallinago Gallinago Bird
359
+ White And Gray Wagtail Bird
360
+ Strawberry Poison-Dart Frog
361
+ Corncrake
362
+ Bald Eagle
363
+ S Harrier Young"]
364
+ Charadrius Dubius Bird
365
+ Pelican
366
+ Flamingo Young
367
+ Socotran Cormorant
368
+ Sparrowhawk
369
+ S Harrier"]
370
+ Pygmy Owl
371
+ Philomachus Pugnax Ruff Bird
372
+ Wren
373
+ Common Wood Pigeon
374
+ Grass Warbler Bird
375
+ Whiskered Tern Bird
376
+ Icterine Warbler Bird
377
+ Crowned Eagle
378
+ Crane
379
+ Hummingbird
380
+ Ardeotis Kori Bird
381
+ Guttural Toad
382
+ Crested Grebe Bird
383
+ Reed Bunting Bird
384
+ White Cockatoo Bird
385
+ Sedge Warbler Bird
386
+ Goldcrest Bird
387
+ Montagus Harrier Young
388
+ European Turtle Dove
389
+ Asian Glossy Starling Bird
390
+ Spotted Wood Owl
391
+ Sagittarius Serpentarius Bird
392
+ Parrot
393
+ Anas Platyrhynchos Bird
394
+ Phalacrocorax Carbo Bird
395
+ White-Breasted Waterhen
396
+ African Bullfrog
397
+ Stork-Billed Kingfisher
398
+ Oriental Pied Hornbill
399
+ Flamingo
400
+ Banded Woodpecker
401
+ Foam Nest Frog
402
+ Vulture
403
+ Larus Ridibundus Bird
404
+
405
+ AP-10K & APT-36K:
406
+ monkey
407
+ elephant
408
+ leopard
409
+ horse
410
+ jaguar
411
+ panda
412
+ marmot
413
+ deer
414
+ noisy night monkey
415
+ orangutan
416
+ sheep
417
+ spider-monkey
418
+ bison
419
+ zebra
420
+ dog
421
+ weasel
422
+ bat
423
+ uakari
424
+ raccoon
425
+ tiger
426
+ rat
427
+ rhino
428
+ chimpanzee
429
+ antelope
430
+ argali sheep
431
+ gorilla
432
+ buffalo
433
+ bobcat
434
+ hippo
435
+ mouse
436
+ moose
437
+ howling-monkey
438
+ black-bear
439
+ wolf
440
+ squirrel
441
+ skunk
442
+ king cheetah
443
+ cheetah
444
+ spider monkey
445
+ hamster
446
+ arctic fox
447
+ polar bear
448
+ rabbit
449
+ panther
450
+ cow
451
+ brown bear
452
+ otter
453
+ beaver
454
+ pig
455
+ fox
456
+ alouatta
457
+ giraffe
458
+ polar-bear
459
+ raccon
460
+ snow leopard
461
+ lion
462
+ cat
463
+ mole
464
+ black bear
465
+
466
+ Desert Locust
467
+
468
+ Vinegar Fly
469
+
470
+
471
+ CUB-200-2011:
472
+ grebe_body
473
+ gull_body
474
+ kingfisher_body
475
+ sparrow_body
476
+ tern_body
477
+ warbler_body
478
+ woodpecker_body
479
+ wren_body
480
+
481
+
482
+
483
+ Carfusion:
484
+ bus
485
+ car
486
+ suv
487
+
488
+
489
+ Deepfashion2:
490
+ short sleeve top
491
+ long sleeve top
492
+ short sleeve outwear
493
+ long sleeve outwear
494
+ vest
495
+ sling
496
+ shorts
497
+ trousers
498
+ skirt
499
+ short sleeve dress
500
+ long sleeve dress
501
+ vest dress
502
+ sling dress
503
+
504
+
505
+
506
+ Keypoint-5:
507
+ bed
508
+ chair
509
+ sofa
510
+ swivelchair
511
+ table
512
+
513
+ AnimalWeb:
514
+ blackbuck
515
+ small asian mongoose
516
+ common dwarf mongoose
517
+ galapagos sea lion
518
+ margay
519
+ nilgai
520
+ Humboldt penguin
521
+ oryx
522
+ tammar wallaby
523
+ monkey
524
+ swamp wallaby
525
+ muntjac deer
526
+ blue-eyed black lemur
527
+ binturong
528
+ hamadryas baboon
529
+ Adelie penguin
530
+ Australian cattle dog
531
+ howler
532
+ striped hyena
533
+ vole
534
+ zebu
535
+ woodchuck
536
+ proboscis monkey
537
+ whiptail wallaby
538
+ anoa
539
+ hippopotamus
540
+ crested penguin
541
+ addax
542
+ red-bellied squirrel
543
+ suni
544
+ feral cat
545
+ galagos
546
+ banteng
547
+ Weddell seal
548
+ zebra
549
+ Ethiopian wolf
550
+ snow leopard
551
+ common chimpanzee
552
+ giant schnauzer
553
+ lemur
554
+ jaguarundi
555
+ Asian golden cat
556
+ gray wolf
557
+ anteater
558
+ golden jackal
559
+ banded palm civet
560
+ cougar
561
+ Barbary macaque
562
+ giant otter
563
+ agouti
564
+ emperor penguin
565
+ feral horse
566
+ yellow-footed rock wallaby
567
+ raccoon
568
+ topi
569
+ opossum
570
+ central chimpanzee
571
+ pygmy rabbit
572
+ fishing cat
573
+ reedbuck
574
+ Mediterranean monk seal
575
+ domestic cat
576
+ kangaroo
577
+ boar
578
+ rusty-spotted cat
579
+ spider monkey
580
+ echidna
581
+ Chinese goral
582
+ ringtail
583
+ kultarr
584
+ Californian sea lion
585
+ guanaco
586
+ muriqui
587
+ gerbil
588
+ wildebeest
589
+ bison
590
+ Australian terrier
591
+ hyrax
592
+ clouded leopard
593
+ goat
594
+ badger
595
+ beaver
596
+ Przewalski horse
597
+ camel
598
+ beating mongoose
599
+ field mouse
600
+ collared peccary
601
+ tree shrew
602
+ wombat
603
+ titi
604
+ steenbuck steenbok
605
+ Australian sea lion
606
+ buffalo
607
+ chamois
608
+ baikal seal
609
+ brush-tailed rock wallaby
610
+ bongo
611
+ Barbary sheep
612
+ great dane
613
+ cheetah
614
+ long-nosed mongoose
615
+ cape buffalo
616
+ waterbuck
617
+ rhesus monkey
618
+ jungle cat
619
+ black-and-white ruffed lemur
620
+ Japanese serow
621
+ potto
622
+ dall sheep
623
+ indri
624
+ large-spotted genet
625
+ Amur leopard
626
+ Owston's palm civet
627
+ dingo
628
+ gibbons
629
+ Doberman
630
+ giraffe
631
+ fox
632
+ quokka
633
+ Amur tiger
634
+ wild ass
635
+ walrus
636
+ common genet
637
+ bilby
638
+ hamster
639
+ yellow-eyed penguin
640
+ panda
641
+ agile wallaby
642
+ bengal slow loris
643
+ marmoset
644
+ brown hyena
645
+ gorilla
646
+ aardvark
647
+ swift fox
648
+ Magellanic penguin
649
+ bear
650
+ Anatolian shepherd dog
651
+ irish wolfhound
652
+ husky
653
+ kinkajou
654
+ brown rat
655
+ tarsiers
656
+ matschie's tree kangaroo
657
+ black-backed jackal
658
+ ocelot
659
+ grey seal
660
+ bullmastiff
661
+ gentoo penguin
662
+ gerenuk
663
+ bearded seal
664
+ hooded seal
665
+ monte
666
+ leopard cat
667
+ western chimpanzee
668
+ african penguin
669
+ dikdik
670
+ komondor
671
+ coypu
672
+ dalmatian
673
+ armadillo
674
+ marsh mongoose
675
+ rusty-spotted genet
676
+ lion
677
+ bharal
678
+ wolverine
679
+ visayan warty pig
680
+ lutung
681
+ bornean slow loris
682
+ caiman
683
+ aardwolf
684
+ german shepherd dog
685
+ sunda slow loris
686
+ mangabey
687
+ cape gray mongoose
688
+ crowned lemur
689
+ harp seal
690
+ gelada baboon
691
+ wallaroo
692
+ hare
693
+ goodfellow's tree kangaroo
694
+ elk
695
+ muskox
696
+ capybara
697
+ toque macaque
698
+ roe deer
699
+ eastern lesser bamboo lemur
700
+ leopard
701
+ wapiti
702
+ gray fox
703
+ alpaca
704
+ guinea pig
705
+ crabeater seal
706
+ black rhino
707
+ little blue penguin
708
+ bighorn sheep
709
+ caracal
710
+ tamarin
711
+ hawaiian monk seal
712
+ lumholtz's tree kangaroo
713
+ koala
714
+ gundi
715
+ onager
716
+ cacomistle
717
+ red-ruffed lemur
718
+ orangutan
719
+ bobcat
720
+ black-footed cat
721
+ alaskan hare
722
+ debrazza's monkey
723
+ swamp rabbit
724
+ white wolf
725
+ sharpe's grysbok
726
+ urial
727
+ feral goat
728
+ serval
729
+ degu
730
+ golden bamboo lemur
731
+ deer mouse
732
+ coatis
733
+ wildcat
734
+ roan antelope
735
+ dugong
736
+ fennec fox
737
+ southern elephant seal
738
+ saluki
739
+ golden langur
740
+ oribi
741
+ red-tail monkey
742
+ chital
743
+ dormouse
744
+ woolly monkey
745
+ leopard seal
746
+ possum
747
+ arctic wolf
748
+ japanese macaque
749
+ vervet monkey
750
+ bamboo lemur
751
+ aye-aye
752
+ night monkey
753
+ blue monkey
754
+ sand cat
755
+ bull
756
+ cape fox
757
+ klipspringer
758
+ border collie
759
+ mouflon
760
+ chipmunk
761
+ potoroo
762
+ bushbuck
763
+ northern elephant seal
764
+ patagonian mara
765
+ bandicoot
766
+ feral cattle
767
+ babirusa
768
+ harvest mouse
769
+ alaskan malamute
770
+ servaline genet
771
+ olive baboon
772
+ italian greyhound
773
+ white-headed lemur
774
+ chihuahua
775
+ red-necked wallaby
776
+ fallow deer
777
+ pygmy slow loris
778
+ australian shepherd
779
+ eastern chimpanzee
780
+ colobus
781
+ chinstrap penguin
782
+ deer
783
+ common warthog
784
+ dunnart
785
+ wisent
786
+ hedgehog
787
+ douc langur
788
+ tasmanian devil
789
+ colo
790
+ flying squirrel
791
+ canadian lynx
792
+ ferret
793
+ ribbon seal
794
+ platypus
795
+ cotton rat
796
+ oncilla
797
+ geoffroy's cat
798
+ horse
799
+ pardine genet
800
+ slender mongoose
801
+ liger
802
+ mareeba rock wallaby
803
+ olingos
804
+ bonobo
805
+ harbour seal
806
+ pademelon
807
+ domestic dog
808
+ chow chow
809
+ gharial
810
+ quoll
811
+ capuchin monkey
812
+ corsac fox
813
+ dassie
814
+ bolognese dog
815
+ ruddy mongoose
816
+ rhinoceros
817
+ red panda
818
+ king penguin
819
+ dachshund
820
+ common brown lemur
821
+ pekingese dog
822
+ western lesser bamboo lemur
823
+ banded mongoose
824
+ grey langur
825
+ patas monkey
826
+ francois langur
827
+ white-tailed deer
828
+ african wild dog
829
+ collared brown lemur
830
+ weasel
831
+ mexican wolf
832
+ hartebeest
833
+ uakari
834
+ viverrata ngalunga malayan civet
835
+ dhole
836
+ eurasian lynx
837
+ hog deer
838
+ bushbaby
839
+ grizzly bear
840
+ caribou
841
+ german pinscher
842
+ jaguar
843
+ donkey
844
+ duiker
845
+ spotted hyena
846
+ golden retriever
847
+ pantanal cat
848
+ spotted-necked otter
849
+ asian palm civet
850
+ alpine ibex
851
+ jackrabbit
852
+ greater bamboo lemur
853
+ kiang
854
+ common kusimanse
855
+ pallas cat
856
+ stripe-necked mongoose
857
+ parma wallaby
858
+ yak
859
+ balinese cat
860
+ spotted seal
861
+ french bulldog
862
+ zonkey
863
+ arctic fox
src/utils/dependencies/XPose/util/logger.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import functools
3
+ import logging
4
+ import os
5
+ import sys
6
+ from termcolor import colored
7
+
8
+
9
+ class _ColorfulFormatter(logging.Formatter):
10
+ def __init__(self, *args, **kwargs):
11
+ self._root_name = kwargs.pop("root_name") + "."
12
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
13
+ if len(self._abbrev_name):
14
+ self._abbrev_name = self._abbrev_name + "."
15
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
16
+
17
+ def formatMessage(self, record):
18
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
19
+ log = super(_ColorfulFormatter, self).formatMessage(record)
20
+ if record.levelno == logging.WARNING:
21
+ prefix = colored("WARNING", "red", attrs=["blink"])
22
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
23
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
24
+ else:
25
+ return log
26
+ return prefix + " " + log
27
+
28
+
29
+ # so that calling setup_logger multiple times won't add many handlers
30
@functools.lru_cache()
def setup_logger(
    output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None
):
    """
    Initialize the logger called ``name`` and set its verbosity level to "DEBUG".

    The result is cached per argument combination, so repeated calls with
    the same arguments return the same logger without adding handlers again.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of the calling process. Only rank 0 logs
            to stdout; ranks > 0 append ".rank<N>" to the log file name.
        color (bool): whether the stdout handler uses the colorful formatter.
        name (str): the root module name of this logger
        abbrev_name (str): abbreviation shown instead of ``name`` by the
            colorful formatter. Defaults to ``name``.

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = name

    plain_formatter = logging.Formatter(
        '[%(asctime)s.%(msecs)03d]: %(message)s',
        datefmt='%m/%d %H:%M:%S'
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        # A bare file name has an empty dirname, and os.makedirs("") raises,
        # so only create the directory when there is one.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open ``filename`` in append mode; the handle is cached and never closed here."""
    return open(filename, "a")
src/utils/dependencies/XPose/util/metrics.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
2
+ """
3
+ Various utilities related to track and report metrics
4
+ """
5
+ import datetime
6
+ import time
7
+ from collections import defaultdict, deque
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+
12
+ from util.misc import is_dist_avail_and_initialized
13
+
14
+
15
class SmoothedValue:
    """Keep a sliding window of recent values plus running global totals.

    Window statistics (median, avg, max, value) come from the last
    ``window_size`` values; ``global_avg`` comes from ``total / count``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, num=1):
        """Record ``value`` in the window and add it ``num`` times to the totals."""
        self.deque.append(value)
        self.total += value * num
        self.count += num

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes (no-op when not distributed).
        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
            dist.barrier()
            dist.all_reduce(packed)
            count, total = packed.tolist()
            self.count = int(count)
            self.total = total

    @property
    def median(self):
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
73
+
74
+
75
class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic printing."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one scalar per keyword into the meter of the same name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Read ``meters`` through __dict__: if it is not set yet (for example
        # while the object is being copied or unpickled), ``self.meters``
        # would re-enter __getattr__ and recurse without end.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of ``iterable``, printing progress every ``print_freq`` items.

        ``iterable`` must support ``len()``. A line is printed on every
        ``print_freq``-th item and on the last one, followed by a total-time
        summary once the iterable is exhausted.
        """
        if not header:
            header = ""
        num_iters = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(num_iters))) + "d"
        fields = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if use_cuda:
            fields.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                values = {
                    "eta": str(datetime.timedelta(seconds=int(eta_seconds))),
                    "meters": str(self),
                    "time": str(iter_time),
                    "data": str(data_time),
                }
                if use_cuda:
                    values["memory"] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_iters, **values))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary from dividing by zero on an empty iterable.
        print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / max(num_iters, 1)))
163
+
164
+
165
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor; top-k is taken along dim 1.
        target: tensor of target indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        list of tensors, one percentage per k. For an empty ``target`` a
        single-element list holding a zero tensor is returned.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` comes from a transposed tensor and may be non-contiguous,
        # so ``view(-1)`` can fail for k > 1; ``reshape`` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
src/utils/dependencies/XPose/util/optim.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
2
+ """Collections of utilities related to optimization."""
3
+ from bisect import bisect_right
4
+ import os
5
+
6
+ import torch
7
+
8
+
9
def update_ema(model, model_ema, decay):
    """Blend the live model's weights into the running-average model in place.

    For every entry of ``model_ema.state_dict()``:
        w_ema = w_ema * decay + (1 - decay) * w
    Args:
        model: model being optimized; a ``.module`` attribute is unwrapped first
        model_ema: model holding the running average
        decay: weight given to the existing average
    """
    with torch.no_grad():
        # A wrapper such as DDP exposes the real model as ``.module``.
        source = model.module if hasattr(model, "module") else model
        live_state = source.state_dict()
        for key, ema_tensor in model_ema.state_dict().items():
            current = live_state[key].detach()
            ema_tensor.copy_(ema_tensor * decay + (1.0 - decay) * current)
27
+
28
+
29
def adjust_learning_rate(
    optimizer,
    epoch: int,
    curr_step: int,
    args,
):
    """Linearly warm up the learning rate of the optimizer's last param group.

    Only the warm-up phase is handled here: during epoch 0, while
    ``curr_step <= args.num_warmup_steps``, the last param group's lr is set
    to ``args.lr_backbone * curr_step / max(1, num_warmup_steps)``. In every
    other case the optimizer is left untouched.

    Args:
        optimizer: torch optimizer to update.
        epoch(int): number of the current epoch.
        curr_step(int): number of optimization steps taken so far.
        args: training arguments; reads
            - num_warmup_steps(int): length of the warm-up. If the attribute
              is missing, the function returns without doing anything.
            - lr_backbone(float): learning rate reached at the end of warm-up.
    """
    try:
        num_warmup_steps = args.num_warmup_steps
    except AttributeError:
        # Warm-up not configured; only the missing attribute is tolerated.
        return

    if epoch > 0:
        return

    if curr_step > num_warmup_steps:
        return

    text_encoder_gamma = float(curr_step) / float(max(1, num_warmup_steps))
    optimizer.param_groups[-1]["lr"] = args.lr_backbone * text_encoder_gamma
69
+
70
+
src/utils/dependencies/XPose/util/plot_utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Plotting utilities to visualize training logs.
3
+ """
4
+ import torch
5
+ import pandas as pd
6
+ import numpy as np
7
+ import seaborn as sns
8
+ import matplotlib.pyplot as plt
9
+
10
+ from pathlib import Path, PurePath
11
+
12
+
13
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.
               - returns (fig, axs), where axs is a 1-D array with one Axes per field;
                 returns None if a log file is missing from one of the directories.

    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
    # convert single Path to list to avoid 'not iterable' error

    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
            Expect list[Path] or single Path obj, received {type(logs)}")

    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
    for log_dir in logs:
        if not isinstance(log_dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(log_dir)}: \n{log_dir}")
        if not log_dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{log_dir}")
        # verify log_name exists
        fn = Path(log_dir / log_name)
        if not fn.exists():
            print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return

    # load log file(s) and plot
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    # squeeze=False always yields a 2-D Axes array; taking row 0 gives a 1-D
    # array even for a single field, where the default would return a bare
    # Axes that can be neither indexed nor zipped.
    fig, axs_grid = plt.subplots(ncols=len(fields), figsize=(16, 5), squeeze=False)
    axs = axs_grid[0]

    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )
    for ax, field in zip(axs, fields):
        if field == 'mAP':
            ax.legend([Path(p).name for p in logs])
        else:
            ax.legend(['train', 'test'])
        ax.set_title(field)

    return fig, axs
80
+
81
def plot_precision_recall(files, naming_scheme='iter'):
    """Plot precision/recall and score/recall curves for saved evaluation files.

    Each file is read with ``torch.load`` and must provide the keys
    'precision', 'params', 'scores' and 'recall'. Curve labels come from the
    third-to-last path component ('exp_id') or from the file stem ('iter').

    Returns the matplotlib ``(fig, axs)`` pair; raises ValueError for an
    unknown ``naming_scheme``.
    """
    if naming_scheme not in ('exp_id', 'iter'):
        raise ValueError(f'not supported {naming_scheme}')
    if naming_scheme == 'exp_id':
        # name becomes exp_id
        names = [path.parts[-3] for path in files]
    else:
        names = [path.stem for path in files]

    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    palette = sns.color_palette("Blues", n_colors=len(files))
    for path, color, name in zip(files, palette, names):
        data = torch.load(path)
        recall = data['params'].recThrs
        # precision is n_iou, n_points, n_cat, n_area, max_det;
        # take precision for all classes, all areas and 100 detections
        precision = data['precision'][0, :, :, 0, -1].mean(1)
        scores = data['scores'][0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        summary = (f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
                   f'score={scores.mean():0.3f}, ' +
                   f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}')
        print(summary)
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    for ax, title in zip(axs[:2], ('Precision / Recall', 'Scores / Recall')):
        ax.set_title(title)
        ax.legend(names)
    return fig, axs
src/utils/dependencies/XPose/util/slio.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==========================================================
2
+ # Modified from mmcv
3
+ # ==========================================================
4
+
5
+ import json, pickle, yaml
6
+ try:
7
+ from yaml import CLoader as Loader, CDumper as Dumper
8
+ except ImportError:
9
+ from yaml import Loader, Dumper
10
+
11
+ from pathlib import Path
12
+ from abc import ABCMeta, abstractmethod
13
+
14
+ # ===========================
15
+ # Register handler
16
+ # ===========================
17
+
18
class BaseFileHandler(metaclass=ABCMeta):
    """Abstract serializer interface shared by the json/yaml/pickle handlers.

    Subclasses implement the three file-object/string primitives; the
    path-based helpers below open the file and delegate to them.
    """

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        """Deserialize and return the content of an open file object."""

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        """Serialize ``obj`` into an open file object."""

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        """Serialize ``obj`` and return the result."""

    def load_from_path(self, filepath, mode='r', **kwargs):
        """Open ``filepath`` with ``mode`` and load its content."""
        with open(filepath, mode) as stream:
            loaded = self.load_from_fileobj(stream, **kwargs)
        return loaded

    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
        """Open ``filepath`` with ``mode`` and write ``obj`` into it."""
        with open(filepath, mode) as stream:
            self.dump_to_fileobj(obj, stream, **kwargs)
39
+
40
class JsonHandler(BaseFileHandler):
    """Handler that serializes with the standard ``json`` module."""

    def load_from_fileobj(self, file, **kwargs):
        # Accept **kwargs like the base class and the other handlers, so
        # keyword arguments given to ``slload`` are forwarded to ``json.load``
        # instead of raising TypeError.
        return json.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        return json.dumps(obj, **kwargs)
50
+
51
class PickleHandler(BaseFileHandler):
    """Handler that serializes with ``pickle`` (binary files, protocol 2 by default).

    NOTE(review): unpickling executes arbitrary code; only load files from
    trusted sources.
    """

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        # Pickle data is binary, so the file is always opened in 'rb' mode.
        return super().load_from_path(filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        protocol = kwargs.pop('protocol', 2)
        return pickle.dumps(obj, protocol=protocol, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        protocol = kwargs.pop('protocol', 2)
        pickle.dump(obj, file, protocol=protocol, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        # Binary counterpart of the base implementation: always 'wb' mode.
        super().dump_to_path(obj, filepath, mode='wb', **kwargs)
71
+
72
class YamlHandler(BaseFileHandler):
    """Handler that serializes with ``yaml``, using the module-level Loader/Dumper.

    NOTE(review): the default ``Loader`` is the full loader, not the safe
    one; do not use it on untrusted input without passing a safe Loader.
    """

    def load_from_fileobj(self, file, **kwargs):
        # Caller-supplied keywords override the default Loader.
        options = {'Loader': Loader, **kwargs}
        return yaml.load(file, **options)

    def dump_to_fileobj(self, obj, file, **kwargs):
        options = {'Dumper': Dumper, **kwargs}
        yaml.dump(obj, file, **options)

    def dump_to_str(self, obj, **kwargs):
        options = {'Dumper': Dumper, **kwargs}
        return yaml.dump(obj, **options)
85
+
86
# Registry mapping a file extension / format name to its handler instance.
# Each key gets its own handler object.
file_handlers = dict(
    json=JsonHandler(),
    yaml=YamlHandler(),
    yml=YamlHandler(),
    pickle=PickleHandler(),
    pkl=PickleHandler(),
)
93
+
94
+ # ===========================
95
+ # load and dump
96
+ # ===========================
97
+
98
def is_str(x):
    """Return whether the input is a ``str`` instance.

    Note: This helper is deprecated since python 2 is no longer supported.
    """
    return isinstance(x, str)
104
+
105
def slload(file, file_format=None, **kwargs):
    """Load data from a json/yaml/pickle file.

    Args:
        file (str or :obj:`Path` or file-like object): path of the file, or
            an object with a ``read`` method.
        file_format (str, optional): one of the keys of ``file_handlers``
            ("json", "yaml", "yml", "pickle", "pkl"). When omitted and
            ``file`` is a path, the text after the last "." is used.
        **kwargs: forwarded to the selected handler.

    Returns:
        The content from the file.

    Raises:
        TypeError: if the format is not supported, or ``file`` is neither a
            path nor a readable object.
    """
    if isinstance(file, Path):
        file = str(file)
    path_given = is_str(file)
    if file_format is None and path_given:
        file_format = file.split('.')[-1]
    if file_format not in file_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = file_handlers[file_format]
    if path_given:
        return handler.load_from_path(file, **kwargs)
    if hasattr(file, 'read'):
        return handler.load_from_fileobj(file, **kwargs)
    raise TypeError('"file" must be a filepath str or a file-object')
136
+
137
+
138
def sldump(obj, file=None, file_format=None, **kwargs):
    """Dump data to a json/yaml/pickle string or file.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): target path
            or object with a ``write`` method. When omitted, the serialized
            data is returned instead of being written.
        file_format (str, optional): one of the keys of ``file_handlers``.
            When omitted and ``file`` is a path, the text after the last "."
            is used.
        **kwargs: forwarded to the selected handler.

    Returns:
        The serialized data when ``file`` is None, otherwise None.

    Raises:
        ValueError: if both ``file`` and ``file_format`` are None.
        TypeError: if the format is not supported, or ``file`` is neither a
            path nor a writable object.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if is_str(file):
            file_format = file.split('.')[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in file_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = file_handlers[file_format]
    if file is None:
        return handler.dump_to_str(obj, **kwargs)
    if is_str(file):
        handler.dump_to_path(obj, file, **kwargs)
        return None
    if hasattr(file, 'write'):
        handler.dump_to_fileobj(obj, file, **kwargs)
        return None
    raise TypeError('"file" must be a filename str or a file-object')
src/utils/dependencies/XPose/util/time_counter.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+
4
class TimeCounter:
    """Stopwatch that records the time elapsed between consecutive marks."""

    def __init__(self) -> None:
        # Start from a cleared state so ``timeit`` also works before the
        # first explicit ``clear`` call.
        self.clear()

    def clear(self):
        """Forget all recorded intervals and restart the reference time."""
        self.timedict = {}
        self.basetime = time.perf_counter()

    def timeit(self, name):
        """Store the seconds since the previous mark under ``name`` and restart."""
        now = time.perf_counter()
        self.timedict[name] = now - self.basetime
        self.basetime = now
16
+
17
+
18
class TimeHolder:
    """Accumulate timing dictionaries and report the average per key."""

    def __init__(self) -> None:
        self.timedict = {}

    def update(self, _timedict:dict):
        """Add every ``name -> value`` pair to that name's running average."""
        for key, elapsed in _timedict.items():
            meter = self.timedict.get(key)
            if meter is None:
                meter = AverageMeter(name=key, val_only=True)
                self.timedict[key] = meter
            meter.update(val=elapsed)

    def final_res(self):
        """Return a dict mapping each name to its average value."""
        averages = {}
        for key, meter in self.timedict.items():
            averages[key] = meter.avg
        return averages

    def __str__(self):
        return json.dumps(self.final_res(), indent=2)
33
+
34
+
35
class AverageMeter(object):
    """Track the latest value together with the running sum, count and mean."""

    def __init__(self, name, fmt=':f', val_only=False):
        self.name = name
        self.fmt = fmt
        self.val_only = val_only
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # val_only meters print just the latest value; others add the mean.
        template = '{name} {val' + self.fmt + '}'
        if not self.val_only:
            template += ' ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)
src/utils/dependencies/XPose/util/utils.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from copy import deepcopy
3
+ from typing import Any, Dict, Iterable, List
4
+ import json
5
+ import warnings
6
+
7
+ import torch
8
+ import numpy as np
9
+
10
def slprint(x, name="x"):
    """Recursively print the shapes/types found inside ``x``.

    Tensors and arrays print their shape, tuples/lists recurse into their
    first 10 items, dicts recurse into every value, and anything else
    prints its type.
    """
    if isinstance(x, (torch.Tensor, np.ndarray)):
        print(f"{name}.shape:", x.shape)
        return
    if isinstance(x, (tuple, list)):
        print("type x:", type(x))
        for idx, item in enumerate(x[:10]):
            slprint(item, f"{name}[{idx}]")
        return
    if isinstance(x, dict):
        for key, val in x.items():
            slprint(val, f"{name}[{key}]")
        return
    print(f"{name}.type:", type(x))
22
+
23
def clean_state_dict(state_dict):
    """Return an OrderedDict copy of ``state_dict`` with a leading 'module.' removed from keys."""
    prefix = 'module.'
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # remove `module.`
        cleaned[key] = value
    return cleaned
30
+
31
def renorm(img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) \
        -> torch.FloatTensor:
    """Undo a per-channel normalization: ``img * std + mean``.

    Args:
        img: tensor(3,H,W) or tensor(B,3,H,W)
        mean: per-channel mean that was subtracted (never mutated here)
        std: per-channel std that was divided by (never mutated here)

    Returns:
        tensor with the same shape as ``img``
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    if img.dim() == 3:
        channel_dim, to_last, to_first = 0, (1, 2, 0), (2, 0, 1)
    else:  # img.dim() == 4
        channel_dim, to_last, to_first = 1, (0, 2, 3, 1), (0, 3, 1, 2)
    assert img.size(channel_dim) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_dim, img.size(channel_dim), str(img.size()))

    # Create the statistics on the image's device so CUDA inputs work too.
    stats_dtype = torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=stats_dtype, device=img.device)
    std_t = torch.as_tensor(std, dtype=stats_dtype, device=img.device)

    # Move channels last so the length-3 statistics broadcast, then restore.
    img_res = img.permute(*to_last) * std_t + mean_t
    return img_res.permute(*to_first)
50
+
51
+
52
+
53
class CocoClassMapper():
    """Translate between original category ids and compact zero-based indices.

    ``category_map_str`` maps an original id (as a string) to a 1-based
    compact id; the two lookup tables derived from it are 0-based.
    """

    def __init__(self) -> None:
        self.category_map_str = {"1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "10": 10, "11": 11, "13": 12, "14": 13, "15": 14, "16": 15, "17": 16, "18": 17, "19": 18, "20": 19, "21": 20, "22": 21, "23": 22, "24": 23, "25": 24, "27": 25, "28": 26, "31": 27, "32": 28, "33": 29, "34": 30, "35": 31, "36": 32, "37": 33, "38": 34, "39": 35, "40": 36, "41": 37, "42": 38, "43": 39, "44": 40, "46": 41, "47": 42, "48": 43, "49": 44, "50": 45, "51": 46, "52": 47, "53": 48, "54": 49, "55": 50, "56": 51, "57": 52, "58": 53, "59": 54, "60": 55, "61": 56, "62": 57, "63": 58, "64": 59, "65": 60, "67": 61, "70": 62, "72": 63, "73": 64, "74": 65, "75": 66, "76": 67, "77": 68, "78": 69, "79": 70, "80": 71, "81": 72, "82": 73, "84": 74, "85": 75, "86": 76, "87": 77, "88": 78, "89": 79, "90": 80}
        self.origin2compact_mapper = {}
        self.compact2origin_mapper = {}
        for origin_str, one_based in self.category_map_str.items():
            origin = int(origin_str)
            compact = one_based - 1
            self.origin2compact_mapper[origin] = compact
            self.compact2origin_mapper[int(compact)] = origin

    def origin2compact(self, idx):
        """Return the 0-based compact index for original id ``idx``."""
        return self.origin2compact_mapper[int(idx)]

    def compact2origin(self, idx):
        """Return the original id for 0-based compact index ``idx``."""
        return self.compact2origin_mapper[int(idx)]
64
+
65
def to_device(item, device):
    """Move a tensor, or every tensor nested in lists/dicts, to ``device``.

    Lists and dicts are rebuilt recursively; any other type raises
    NotImplementedError.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device)
    if isinstance(item, list):
        return [to_device(element, device) for element in item]
    if isinstance(item, dict):
        moved = {}
        for key, value in item.items():
            moved[key] = to_device(value, device)
        return moved
    raise NotImplementedError("Call Shilong if you use other containers! type: {}".format(type(item)))
74
+
75
+
76
+
77
+ #
78
def get_gaussian_mean(x, axis, other_axis, softmax=True):
    """Weighted mean position along one spatial axis of a batch of maps.

    Args:
        x (float): Input images(BxCxHxW)
        axis (int): The index for weighted mean
        other_axis (int): The other index, summed out first
        softmax (bool): normalize the weights with softmax if True,
            otherwise divide by their sum (plus 1e-6)

    Returns: weighted index for axis, BxC, on a 0..1 coordinate scale

    """
    profile = torch.sum(x, axis=other_axis)
    if softmax:
        weights = torch.softmax(profile, axis=2)
    else:
        weights = profile / (profile.sum(2, keepdim=True) + 1e-6)
    n_batch, n_channel = x.shape[0], x.shape[1]
    # Evenly spaced coordinates in [0, 1], one per position along ``axis``.
    coords = torch.linspace(0, 1, x.shape[axis]).to(x.device)
    grid = coords.repeat([n_batch, n_channel, 1])
    return torch.sum(grid * weights, dim=2)
102
+
103
def get_expected_points_from_map(hm, softmax=True):
    """Soft-argmax: turn each map into an expected (x, y) position.

    B,C,H,W -> B,C,2 with both coordinates on a 0..1 scale.

    Args:
        hm (float): Input images(BxCxHxW)
        softmax (bool): forwarded to ``get_gaussian_mean``

    Returns:
        weighted index for axis, BxCx2. float between 0 and 1.

    """
    # Unpacking also rejects inputs that are not 4-dimensional.
    B, C, H, W = hm.shape
    mean_along_h = get_gaussian_mean(hm, 2, 3, softmax=softmax)  # B,C
    mean_along_w = get_gaussian_mean(hm, 3, 2, softmax=softmax)  # B,C
    return torch.stack([mean_along_w, mean_along_h], dim=2)
121
+
122
+ # Positional encoding (section 5.1)
123
+ # borrow from nerf
124
class Embedder:
    """Frequency (positional) encoder built from a list of periodic functions.

    Expected keyword arguments: ``input_dims``, ``include_input``,
    ``max_freq_log2``, ``num_freqs``, ``log_sampling``, ``periodic_fns``.
    After construction ``embed_fns`` holds the callables and ``out_dim``
    the size of the concatenated output.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``embed_fns`` and compute ``out_dim`` from the stored kwargs."""
        cfg = self.kwargs
        dims = cfg['input_dims']
        fns = []
        total_dim = 0
        if cfg['include_input']:
            fns.append(lambda x: x)
            total_dim += dims

        max_freq = cfg['max_freq_log2']
        n_freqs = cfg['num_freqs']

        if cfg['log_sampling']:
            bands = 2.**torch.linspace(0., max_freq, steps=n_freqs)
        else:
            bands = torch.linspace(2.**0., 2.**max_freq, steps=n_freqs)

        for freq in bands:
            for p_fn in cfg['periodic_fns']:
                # Default arguments bind the current loop values to each lambda.
                fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                total_dim += dims

        self.embed_fns = fns
        self.out_dim = total_dim

    def embed(self, inputs):
        """Apply every embedding function and concatenate along the last dim."""
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, -1)
155
+
156
+
157
def get_embedder(multires, i=0):
    """Return ``(embed_fn, out_dim)`` for 3-dimensional inputs.

    With ``i == -1`` an identity module and dimension 3 are returned.
    Otherwise an ``Embedder`` with ``multires`` log-sampled frequencies
    (sin and cos, input included) is built.
    """
    import torch.nn as nn
    if i == -1:
        return nn.Identity(), 3

    embedder_obj = Embedder(
        include_input=True,
        input_dims=3,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
174
+
175
class APOPMeter():
    """Accumulate binary confusion-matrix counts (tp, fp, tn, fn)."""

    def __init__(self) -> None:
        self.tp = 0
        self.fp = 0
        self.tn = 0
        self.fn = 0

    def update(self, pred, gt):
        """
        Add the counts obtained by comparing ``pred`` with ``gt``.

        Input:
            pred, gt: Tensor() of identical shape, compared against 0 and 1
        """
        assert pred.shape == gt.shape
        self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
        self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
        self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
        # False negatives: predicted 0 where the ground truth is 1. This used
        # to add the false-positive count to ``tn`` a second time.
        self.fn += torch.logical_and(pred == 0, gt == 1).sum().item()

    def update_cm(self, tp, fp, tn, fn):
        """Add precomputed confusion-matrix counts."""
        self.tp += tp
        self.fp += fp
        self.tn += tn
        self.fn += fn
198
+
199
def inverse_sigmoid(x, eps=1e-5):
    """Return ``log(x / (1 - x))`` with clamping for numerical safety.

    ``x`` is first clamped to ``[0, 1]``; numerator and denominator are then
    each clamped to at least ``eps`` so the logarithm stays finite.
    """
    prob = x.clamp(min=0, max=1)
    numerator = prob.clamp(min=eps)
    denominator = (1 - prob).clamp(min=eps)
    return torch.log(numerator / denominator)
204
+
205
+ import argparse
206
+ from util.config import Config
207
def get_raw_dict(args):
    """Return the plain dict held by ``args``.

    Supported inputs are an ``argparse.Namespace`` (returns ``vars(args)``),
    a ``dict`` (returned as-is) and a ``Config`` (returns its ``_cfg_dict``).

    e.g:
        >>> with open(path, 'w') as f:
                json.dump(get_raw_dict(args), f, indent=2)

    Raises:
        NotImplementedError: for any other input type.
    """
    if isinstance(args, argparse.Namespace):
        return vars(args)
    if isinstance(args, dict):
        return args
    if isinstance(args, Config):
        return args._cfg_dict
    raise NotImplementedError("Unknown type {}".format(type(args)))
223
+
224
+
225
def stat_tensors(tensor):
    """Summarise a 1-D tensor.

    Args:
        tensor: a 1-dimensional tensor (asserted).

    Returns:
        dict with 0-d tensors under the keys ``max``, ``min``, ``mean``,
        ``var``, ``std`` and ``entropy``.  ``entropy`` is the Shannon entropy
        ``-sum(p * log(p))`` of ``softmax(tensor)``; it is non-negative and
        largest for a uniform distribution.
    """
    assert tensor.dim() == 1
    probs = tensor.softmax(0)
    # The small constant keeps log() finite when a probability underflows.
    # The leading minus sign was missing before, which made the reported
    # "entropy" the negated entropy (always <= 0).
    entropy = -(probs * torch.log(probs + 1e-9)).sum()

    variance = tensor.var()
    return {
        'max': tensor.max(),
        'min': tensor.min(),
        'mean': tensor.mean(),
        'var': variance,
        'std': variance ** 0.5,
        'entropy': entropy
    }
238
+
239
+
240
class NiceRepr:
    """Mixin that builds ``__str__`` and ``__repr__`` from ``__nice__``.

    Subclasses should override ``__nice__`` to return a short summary string.
    If a subclass defines ``__len__`` instead, the default ``__nice__``
    returns its length.  When neither is available, a ``RuntimeWarning`` is
    issued and the plain ``object`` representation is used.

    Example:
        >>> class Foo(NiceRepr):
        ...    def __nice__(self):
        ...        return 'info'
        >>> foo = Foo()
        >>> assert str(foo) == '<Foo(info)>'
        >>> assert repr(foo).startswith('<Foo(info) at ')

    Example:
        >>> class Baz(NiceRepr):
        ...    def __len__(self):
        ...        return 5
        >>> baz = Baz()
        >>> assert str(baz) == '<Baz(5)>'
    """

    def __nice__(self):
        """str: a "nice" summary string describing this module"""
        if not hasattr(self, '__len__'):
            # Without a length the subclass has to supply its own summary.
            raise NotImplementedError(
                f'Define the __nice__ method for {self.__class__!r}')
        # Objects with a length commonly use it as their summary.
        return str(len(self))

    def __repr__(self):
        """str: the string of the module"""
        try:
            summary = self.__nice__()
        except NotImplementedError as ex:
            warnings.warn(str(ex), category=RuntimeWarning)
            return object.__repr__(self)
        name = self.__class__.__name__
        return f'<{name}({summary}) at {hex(id(self))}>'

    def __str__(self):
        """str: the string of the module"""
        try:
            summary = self.__nice__()
        except NotImplementedError as ex:
            warnings.warn(str(ex), category=RuntimeWarning)
            return object.__repr__(self)
        name = self.__class__.__name__
        return f'<{name}({summary})>'
304
+
305
+
306
+
307
def ensure_rng(rng=None):
    """Coerce ``rng`` into a random number generator.

    Args:
        rng (int | numpy.random.RandomState | None):
            ``None`` selects numpy's global random state, an ``int`` is used
            as the seed of a new ``RandomState``; anything else is returned
            unchanged.

    Returns:
        (numpy.random.RandomState) : rng -
            a numpy random number generator

    References:
        .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270  # noqa: E501
    """
    if rng is None:
        return np.random.mtrand._rand
    if isinstance(rng, int):
        return np.random.RandomState(rng)
    return rng
336
+
337
def random_boxes(num=1, scale=1, rng=None):
    """Simple version of ``kwimage.Boxes.random``

    Args:
        num: number of boxes to draw.
        scale: factor applied to the unit-square coordinates.
        rng: seed, random state or ``None``; resolved through ``ensure_rng``.

    Returns:
        Tensor: shape (n, 4) in x1, y1, x2, y2 format.

    References:
        https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390

    Example:
        >>> num = 3
        >>> scale = 512
        >>> rng = 0
        >>> boxes = random_boxes(num, scale, rng)
        >>> print(boxes)
        tensor([[280.9925, 278.9802, 308.6148, 366.1769],
                [216.9113, 330.6978, 224.0446, 456.5878],
                [405.3632, 196.3221, 493.3953, 270.7942]])
    """
    rng = ensure_rng(rng)

    coords = rng.rand(num, 4).astype(np.float32)

    # Sort each coordinate pair so that (x1, y1) <= (x2, y2).  The
    # min/max results are fresh arrays, so writing back into ``coords``
    # below cannot corrupt them.
    ordered_columns = (
        np.minimum(coords[:, 0], coords[:, 2]),
        np.minimum(coords[:, 1], coords[:, 3]),
        np.maximum(coords[:, 0], coords[:, 2]),
        np.maximum(coords[:, 1], coords[:, 3]),
    )
    for col, values in enumerate(ordered_columns):
        coords[:, col] = values * scale

    return torch.from_numpy(coords)
372
+
373
+
374
class ModelEma(torch.nn.Module):
    """Keeps an exponential moving average (EMA) copy of a model.

    Args:
        model: module to track; a deep copy is stored in ``self.module`` and
            put in eval mode.
        decay: weight given to the current EMA value on every ``update``.
        device: if set, the EMA copy lives on this device and incoming
            values are moved there before blending.
    """

    def __init__(self, model, decay=0.9997, device=None):
        super(ModelEma, self).__init__()
        # Independent copy that accumulates the moving average of weights.
        ema_copy = deepcopy(model)
        ema_copy.eval()
        self.module = ema_copy

        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite each EMA state entry with ``update_fn(ema, model)``.

        Entries are paired by iteration order of the two ``state_dict()``s.
        """
        with torch.no_grad():
            ema_values = self.module.state_dict().values()
            src_values = model.state_dict().values()
            for ema_v, model_v in zip(ema_values, src_values):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend ``model``'s state into the EMA using ``self.decay``."""
        def blend(e, m):
            return self.decay * e + (1. - self.decay) * m

        self._update(model, update_fn=blend)

    def set(self, model):
        """Copy ``model``'s state into the EMA without averaging."""
        def take_new(e, m):
            return m

        self._update(model, update_fn=take_new)
400
+
401
class BestMetricSingle():
    """Tracks the best value of one metric and the epoch it occurred in.

    Args:
        init_res: starting value that new results are compared against.
        better: ``'large'`` if higher is better, ``'small'`` if lower is.
    """

    def __init__(self, init_res=0.0, better='large') -> None:
        self.init_res = init_res
        self.best_res = init_res
        self.best_ep = -1

        self.better = better
        assert better in ['large', 'small']

    def isbetter(self, new_res, old_res):
        """Return whether ``new_res`` strictly improves on ``old_res``."""
        direction = self.better
        if direction == 'large':
            return new_res > old_res
        if direction == 'small':
            return new_res < old_res

    def update(self, new_res, ep):
        """Record ``new_res`` if it is a new best; return True if it was."""
        if not self.isbetter(new_res, self.best_res):
            return False
        self.best_res, self.best_ep = new_res, ep
        return True

    def __str__(self) -> str:
        return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)

    def __repr__(self) -> str:
        return self.__str__()

    def summary(self) -> dict:
        """Return the best value and its epoch as a dict."""
        return dict(best_res=self.best_res, best_ep=self.best_ep)
434
+
435
+
436
class BestMetricHolder():
    """Tracks the best metric overall and, optionally, per model variant.

    Args:
        init_res: starting value for every tracker.
        better: ``'large'`` or ``'small'``, forwarded to ``BestMetricSingle``.
        use_ema: when True, separate trackers for EMA and regular results
            are kept in addition to the overall one.
    """

    def __init__(self, init_res=0.0, better='large', use_ema=False) -> None:
        self.best_all = BestMetricSingle(init_res, better)
        self.use_ema = use_ema
        if use_ema:
            self.best_ema = BestMetricSingle(init_res, better)
            self.best_regular = BestMetricSingle(init_res, better)

    def update(self, new_res, epoch, is_ema=False):
        """
        return if the results is the best.
        """
        if self.use_ema:
            # Feed the variant-specific tracker first; the return value is
            # always decided by the overall tracker.
            variant = self.best_ema if is_ema else self.best_regular
            variant.update(new_res, epoch)
        return self.best_all.update(new_res, epoch)

    def summary(self):
        """Return the tracked bests as a flat dict.

        Without EMA tracking this is the overall summary; otherwise keys are
        prefixed with ``all_``, ``regular_`` and ``ema_``.
        """
        if not self.use_ema:
            return self.best_all.summary()

        res = {}
        prefixed = (
            ('all', self.best_all),
            ('regular', self.best_regular),
            ('ema', self.best_ema),
        )
        for prefix, tracker in prefixed:
            for k, v in tracker.summary().items():
                res[f'{prefix}_{k}'] = v
        return res

    def __repr__(self) -> str:
        return json.dumps(self.summary(), indent=2)

    def __str__(self) -> str:
        return self.__repr__()
474
+
475
+
476
def targets_to(targets: List[Dict[str, Any]], device):
    """Moves the target dicts to the given device.

    Values stored under one of the excluded keys below are passed through
    unchanged; every other value has ``.to(device)`` called on it.  A new
    list of new dicts is returned.
    """
    excluded_keys = (
        "questionId",
        "tokens_positive",
        "strings_positive",
        "tokens",
        "dataset_name",
        "sentence_id",
        "original_img_id",
        "nb_eval",
        "task_id",
        "original_id",
        "token_span",
        "caption",
        "dataset_type",
        "caption_list",
        "id2catname",
        "valid_kpt_num",
        "image_id_ref",
        "image_id_current",
        "test_id",
    )

    moved = []
    for target in targets:
        entry = {}
        for key, value in target.items():
            entry[key] = value if key in excluded_keys else value.to(device)
        moved.append(entry)
    return moved