Spaces:
Runtime error
Runtime error
update xpose code
Browse files- requirements.txt +2 -3
- src/utils/dependencies/XPose/config_model/UniPose_SwinT.py +2 -2
- src/utils/dependencies/XPose/models/UniPose/attention.py +27 -2
- src/utils/dependencies/XPose/models/UniPose/backbone.py +4 -2
- src/utils/dependencies/XPose/models/UniPose/deformable_transformer.py +12 -5
- src/utils/dependencies/XPose/models/UniPose/fuse_modules.py +10 -6
- src/utils/dependencies/XPose/models/UniPose/mask_generate.py +6 -0
- src/utils/dependencies/XPose/models/UniPose/ops/modules/ms_deform_attn.py +4 -1
- src/utils/dependencies/XPose/models/UniPose/ops/setup.py +0 -3
- src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.cu +2 -2
- src/utils/dependencies/XPose/models/UniPose/position_encoding.py +1 -0
- src/utils/dependencies/XPose/models/UniPose/swin_transformer.py +6 -8
- src/utils/dependencies/XPose/models/UniPose/transformer_deformable.py +24 -18
- src/utils/dependencies/XPose/models/UniPose/transformer_vanilla.py +6 -2
- src/utils/dependencies/XPose/models/UniPose/unipose.py +23 -14
- src/utils/dependencies/XPose/models/UniPose/utils.py +1 -1
- src/utils/dependencies/XPose/transforms.py +1 -0
- src/utils/dependencies/XPose/util/__init__.py +1 -0
- src/utils/dependencies/XPose/util/addict.py +0 -159
- src/utils/dependencies/XPose/util/box_ops.py +1 -1
- src/utils/dependencies/XPose/util/config.py +13 -6
- src/utils/dependencies/XPose/util/get_param_dicts.py +61 -0
- src/utils/dependencies/XPose/util/instance.txt +863 -0
- src/utils/dependencies/XPose/util/logger.py +95 -0
- src/utils/dependencies/XPose/util/metrics.py +181 -0
- src/utils/dependencies/XPose/util/optim.py +70 -0
- src/utils/dependencies/XPose/util/plot_utils.py +112 -0
- src/utils/dependencies/XPose/util/slio.py +173 -0
- src/utils/dependencies/XPose/util/time_counter.py +60 -0
- src/utils/dependencies/XPose/util/utils.py +499 -0
requirements.txt
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
--extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
-
|
3 |
-
|
4 |
-
torchvision==0.13.1
|
5 |
torchaudio==2.3.0
|
6 |
|
7 |
numpy==1.26.4
|
|
|
1 |
--extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
+
torch
|
3 |
+
torchvision==0.18.0
|
|
|
4 |
torchaudio==2.3.0
|
5 |
|
6 |
numpy==1.26.4
|
src/utils/dependencies/XPose/config_model/UniPose_SwinT.py
CHANGED
@@ -108,7 +108,7 @@ shuffle_type = None
|
|
108 |
use_text_enhancer = True
|
109 |
use_fusion_layer = True
|
110 |
|
111 |
-
use_checkpoint =
|
112 |
use_transformer_ckpt = True
|
113 |
text_encoder_type = 'bert-base-uncased'
|
114 |
|
@@ -122,4 +122,4 @@ binary_query_selection = False
|
|
122 |
use_cdn = True
|
123 |
ffn_extra_layernorm = False
|
124 |
|
125 |
-
fix_size=False
|
|
|
108 |
use_text_enhancer = True
|
109 |
use_fusion_layer = True
|
110 |
|
111 |
+
use_checkpoint = True
|
112 |
use_transformer_ckpt = True
|
113 |
text_encoder_type = 'bert-base-uncased'
|
114 |
|
|
|
122 |
use_cdn = True
|
123 |
ffn_extra_layernorm = False
|
124 |
|
125 |
+
fix_size=False
|
src/utils/dependencies/XPose/models/UniPose/attention.py
CHANGED
@@ -23,19 +23,44 @@ Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/m
|
|
23 |
and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
|
24 |
"""
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
import warnings
|
|
|
|
|
27 |
import torch
|
|
|
28 |
from torch.nn.modules.linear import Linear
|
|
|
29 |
from torch.nn.init import constant_
|
|
|
|
|
30 |
from torch.nn.modules.module import Module
|
31 |
-
from torch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
try:
|
33 |
from torch.overrides import has_torch_function, handle_torch_function
|
34 |
except:
|
35 |
from torch._overrides import has_torch_function, handle_torch_function
|
36 |
-
from torch.nn.functional import linear, pad, softmax, dropout
|
37 |
Tensor = torch.Tensor
|
38 |
|
|
|
|
|
39 |
class MultiheadAttention(Module):
|
40 |
r"""Allows the model to jointly attend to information
|
41 |
from different representation subspaces.
|
|
|
23 |
and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
|
24 |
"""
|
25 |
|
26 |
+
import copy
|
27 |
+
from typing import Optional, List
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F
|
31 |
+
from torch import nn, Tensor
|
32 |
+
|
33 |
import warnings
|
34 |
+
from typing import Tuple, Optional
|
35 |
+
|
36 |
import torch
|
37 |
+
from torch import Tensor
|
38 |
from torch.nn.modules.linear import Linear
|
39 |
+
from torch.nn.init import xavier_uniform_
|
40 |
from torch.nn.init import constant_
|
41 |
+
from torch.nn.init import xavier_normal_
|
42 |
+
from torch.nn.parameter import Parameter
|
43 |
from torch.nn.modules.module import Module
|
44 |
+
from torch.nn import functional as F
|
45 |
+
|
46 |
+
import warnings
|
47 |
+
import math
|
48 |
+
|
49 |
+
from torch._C import _infer_size, _add_docstr
|
50 |
+
from torch.nn import _reduction as _Reduction
|
51 |
+
from torch.nn.modules import utils
|
52 |
+
from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
|
53 |
+
from torch.nn import grad
|
54 |
+
from torch import _VF
|
55 |
+
from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
|
56 |
try:
|
57 |
from torch.overrides import has_torch_function, handle_torch_function
|
58 |
except:
|
59 |
from torch._overrides import has_torch_function, handle_torch_function
|
|
|
60 |
Tensor = torch.Tensor
|
61 |
|
62 |
+
from torch.nn.functional import linear, pad, softmax, dropout
|
63 |
+
|
64 |
class MultiheadAttention(Module):
|
65 |
r"""Allows the model to jointly attend to information
|
66 |
from different representation subspaces.
|
src/utils/dependencies/XPose/models/UniPose/backbone.py
CHANGED
@@ -16,18 +16,20 @@
|
|
16 |
Backbone modules.
|
17 |
"""
|
18 |
|
|
|
|
|
19 |
import torch
|
20 |
import torch.nn.functional as F
|
21 |
import torchvision
|
22 |
from torch import nn
|
23 |
from torchvision.models._utils import IntermediateLayerGetter
|
24 |
-
from typing import Dict, List
|
25 |
|
26 |
-
from util.misc import NestedTensor, is_main_process
|
27 |
|
28 |
from .position_encoding import build_position_encoding
|
29 |
from .swin_transformer import build_swin_transformer
|
30 |
|
|
|
31 |
class FrozenBatchNorm2d(torch.nn.Module):
|
32 |
"""
|
33 |
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
|
|
16 |
Backbone modules.
|
17 |
"""
|
18 |
|
19 |
+
from typing import Dict, List
|
20 |
+
|
21 |
import torch
|
22 |
import torch.nn.functional as F
|
23 |
import torchvision
|
24 |
from torch import nn
|
25 |
from torchvision.models._utils import IntermediateLayerGetter
|
|
|
26 |
|
27 |
+
from util.misc import NestedTensor, clean_state_dict, is_main_process
|
28 |
|
29 |
from .position_encoding import build_position_encoding
|
30 |
from .swin_transformer import build_swin_transformer
|
31 |
|
32 |
+
|
33 |
class FrozenBatchNorm2d(torch.nn.Module):
|
34 |
"""
|
35 |
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
src/utils/dependencies/XPose/models/UniPose/deformable_transformer.py
CHANGED
@@ -16,16 +16,21 @@
|
|
16 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
17 |
# ------------------------------------------------------------------------
|
18 |
|
19 |
-
import math
|
|
|
20 |
import copy
|
21 |
-
import torch
|
22 |
-
import torch.utils.checkpoint as checkpoint
|
23 |
-
from torch import nn, Tensor
|
24 |
from typing import Optional
|
|
|
25 |
from util.misc import inverse_sigmoid
|
26 |
|
|
|
|
|
|
|
|
|
|
|
27 |
from .transformer_vanilla import TransformerEncoderLayer
|
28 |
from .fuse_modules import BiAttentionBlock
|
|
|
29 |
from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position, get_sine_pos_embed
|
30 |
from .ops.modules import MSDeformAttn
|
31 |
|
@@ -580,7 +585,7 @@ class TransformerEncoder(nn.Module):
|
|
580 |
reference_points_list = []
|
581 |
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
582 |
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
583 |
-
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)
|
584 |
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
585 |
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
586 |
ref = torch.stack((ref_x, ref_y), -1)
|
@@ -1228,3 +1233,5 @@ def build_deformable_transformer(args):
|
|
1228 |
binary_query_selection=binary_query_selection,
|
1229 |
ffn_extra_layernorm=ffn_extra_layernorm,
|
1230 |
)
|
|
|
|
|
|
16 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
17 |
# ------------------------------------------------------------------------
|
18 |
|
19 |
+
import math, random
|
20 |
+
import os
|
21 |
import copy
|
|
|
|
|
|
|
22 |
from typing import Optional
|
23 |
+
|
24 |
from util.misc import inverse_sigmoid
|
25 |
|
26 |
+
import torch
|
27 |
+
from torch import nn, Tensor
|
28 |
+
|
29 |
+
import torch.utils.checkpoint as checkpoint
|
30 |
+
|
31 |
from .transformer_vanilla import TransformerEncoderLayer
|
32 |
from .fuse_modules import BiAttentionBlock
|
33 |
+
|
34 |
from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position, get_sine_pos_embed
|
35 |
from .ops.modules import MSDeformAttn
|
36 |
|
|
|
585 |
reference_points_list = []
|
586 |
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
587 |
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
588 |
+
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
|
589 |
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
590 |
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
591 |
ref = torch.stack((ref_x, ref_y), -1)
|
|
|
1233 |
binary_query_selection=binary_query_selection,
|
1234 |
ffn_extra_layernorm=ffn_extra_layernorm,
|
1235 |
)
|
1236 |
+
|
1237 |
+
|
src/utils/dependencies/XPose/models/UniPose/fuse_modules.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
-
import
|
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
from src.modules.util import DropPath
|
7 |
|
8 |
class FeatureResizer(nn.Module):
|
9 |
"""
|
@@ -178,7 +181,7 @@ class BiMultiHeadAttention(nn.Module):
|
|
178 |
|
179 |
if self.stable_softmax_2d:
|
180 |
attn_weights = attn_weights - attn_weights.max()
|
181 |
-
|
182 |
if self.clamp_min_for_underflow:
|
183 |
attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
|
184 |
if self.clamp_max_for_overflow:
|
@@ -261,8 +264,8 @@ class BiAttentionBlock(nn.Module):
|
|
261 |
|
262 |
# add layer scale for training stability
|
263 |
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
264 |
-
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=
|
265 |
-
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=
|
266 |
|
267 |
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
268 |
v = self.layer_norm_v(v)
|
@@ -272,3 +275,4 @@ class BiAttentionBlock(nn.Module):
|
|
272 |
v = v + self.drop_path(self.gamma_v * delta_v)
|
273 |
l = l + self.drop_path(self.gamma_l * delta_l)
|
274 |
return v, l
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import torch, os
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
5 |
+
import pdb
|
6 |
+
import math
|
7 |
+
from timm.models.layers import DropPath
|
8 |
|
9 |
+
from transformers.activations import ACT2FN
|
|
|
10 |
|
11 |
class FeatureResizer(nn.Module):
|
12 |
"""
|
|
|
181 |
|
182 |
if self.stable_softmax_2d:
|
183 |
attn_weights = attn_weights - attn_weights.max()
|
184 |
+
|
185 |
if self.clamp_min_for_underflow:
|
186 |
attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
|
187 |
if self.clamp_max_for_overflow:
|
|
|
264 |
|
265 |
# add layer scale for training stability
|
266 |
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
267 |
+
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
|
268 |
+
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
|
269 |
|
270 |
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
271 |
v = self.layer_norm_v(v)
|
|
|
275 |
v = v + self.drop_path(self.gamma_v * delta_v)
|
276 |
l = l + self.drop_path(self.gamma_l * delta_l)
|
277 |
return v, l
|
278 |
+
|
src/utils/dependencies/XPose/models/UniPose/mask_generate.py
CHANGED
@@ -1,4 +1,10 @@
|
|
1 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def prepare_for_mask(kpt_mask):
|
|
|
1 |
import torch
|
2 |
+
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
|
3 |
+
accuracy, get_world_size, interpolate,
|
4 |
+
is_dist_avail_and_initialized, inverse_sigmoid)
|
5 |
+
# from .DABDETR import sigmoid_focal_loss
|
6 |
+
from util import box_ops
|
7 |
+
import torch.nn.functional as F
|
8 |
|
9 |
|
10 |
def prepare_for_mask(kpt_mask):
|
src/utils/dependencies/XPose/models/UniPose/ops/modules/ms_deform_attn.py
CHANGED
@@ -20,7 +20,10 @@ from torch import nn
|
|
20 |
import torch.nn.functional as F
|
21 |
from torch.nn.init import xavier_uniform_, constant_
|
22 |
|
23 |
-
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
def _is_power_of_2(n):
|
|
|
20 |
import torch.nn.functional as F
|
21 |
from torch.nn.init import xavier_uniform_, constant_
|
22 |
|
23 |
+
try:
|
24 |
+
from src.utils.dependencies.XPose.models.UniPose.ops.functions.ms_deform_attn_func import MSDeformAttnFunction
|
25 |
+
except:
|
26 |
+
warnings.warn('Failed to import MSDeformAttnFunction.')
|
27 |
|
28 |
|
29 |
def _is_power_of_2(n):
|
src/utils/dependencies/XPose/models/UniPose/ops/setup.py
CHANGED
@@ -41,12 +41,10 @@ def get_extensions():
|
|
41 |
sources += source_cuda
|
42 |
define_macros += [("WITH_CUDA", None)]
|
43 |
extra_compile_args["nvcc"] = [
|
44 |
-
# "-allow-unsupported-compiler",
|
45 |
"-DCUDA_HAS_FP16=1",
|
46 |
"-D__CUDA_NO_HALF_OPERATORS__",
|
47 |
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
48 |
"-D__CUDA_NO_HALF2_OPERATORS__",
|
49 |
-
# "-std=c++14",
|
50 |
]
|
51 |
else:
|
52 |
raise NotImplementedError('Cuda is not availabel')
|
@@ -64,7 +62,6 @@ def get_extensions():
|
|
64 |
]
|
65 |
return ext_modules
|
66 |
|
67 |
-
|
68 |
setup(
|
69 |
name="MultiScaleDeformableAttention",
|
70 |
version="1.0",
|
|
|
41 |
sources += source_cuda
|
42 |
define_macros += [("WITH_CUDA", None)]
|
43 |
extra_compile_args["nvcc"] = [
|
|
|
44 |
"-DCUDA_HAS_FP16=1",
|
45 |
"-D__CUDA_NO_HALF_OPERATORS__",
|
46 |
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
47 |
"-D__CUDA_NO_HALF2_OPERATORS__",
|
|
|
48 |
]
|
49 |
else:
|
50 |
raise NotImplementedError('Cuda is not availabel')
|
|
|
62 |
]
|
63 |
return ext_modules
|
64 |
|
|
|
65 |
setup(
|
66 |
name="MultiScaleDeformableAttention",
|
67 |
version="1.0",
|
src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.cu
CHANGED
@@ -61,7 +61,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|
61 |
for (int n = 0; n < batch/im2col_step_; ++n)
|
62 |
{
|
63 |
auto columns = output_n.select(0, n);
|
64 |
-
AT_DISPATCH_FLOATING_TYPES(value.
|
65 |
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
66 |
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
67 |
spatial_shapes.data<int64_t>(),
|
@@ -131,7 +131,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|
131 |
for (int n = 0; n < batch/im2col_step_; ++n)
|
132 |
{
|
133 |
auto grad_output_g = grad_output_n.select(0, n);
|
134 |
-
AT_DISPATCH_FLOATING_TYPES(value.
|
135 |
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
136 |
grad_output_g.data<scalar_t>(),
|
137 |
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
|
|
61 |
for (int n = 0; n < batch/im2col_step_; ++n)
|
62 |
{
|
63 |
auto columns = output_n.select(0, n);
|
64 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
65 |
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
66 |
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
67 |
spatial_shapes.data<int64_t>(),
|
|
|
131 |
for (int n = 0; n < batch/im2col_step_; ++n)
|
132 |
{
|
133 |
auto grad_output_g = grad_output_n.select(0, n);
|
134 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
135 |
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
136 |
grad_output_g.data<scalar_t>(),
|
137 |
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
src/utils/dependencies/XPose/models/UniPose/position_encoding.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15 |
Various positional encodings for the transformer.
|
16 |
"""
|
17 |
import math
|
|
|
18 |
import torch
|
19 |
from torch import nn
|
20 |
|
|
|
15 |
Various positional encodings for the transformer.
|
16 |
"""
|
17 |
import math
|
18 |
+
import os
|
19 |
import torch
|
20 |
from torch import nn
|
21 |
|
src/utils/dependencies/XPose/models/UniPose/swin_transformer.py
CHANGED
@@ -4,10 +4,8 @@ import torch.nn as nn
|
|
4 |
import torch.nn.functional as F
|
5 |
import torch.utils.checkpoint as checkpoint
|
6 |
import numpy as np
|
7 |
-
|
8 |
from util.misc import NestedTensor
|
9 |
-
# from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
10 |
-
from src.modules.util import DropPath, to_2tuple, trunc_normal_
|
11 |
|
12 |
|
13 |
|
@@ -489,8 +487,8 @@ class SwinTransformer(nn.Module):
|
|
489 |
self.frozen_stages = frozen_stages
|
490 |
self.dilation = dilation
|
491 |
|
492 |
-
|
493 |
-
|
494 |
|
495 |
# split image into non-overlapping patches
|
496 |
self.patch_embed = PatchEmbed(
|
@@ -634,7 +632,7 @@ class SwinTransformer(nn.Module):
|
|
634 |
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
635 |
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
636 |
|
637 |
-
# collect for nesttensors
|
638 |
outs_dict = {}
|
639 |
for idx, out_i in enumerate(outs):
|
640 |
m = tensor_list.mask
|
@@ -661,7 +659,7 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw):
|
|
661 |
depths=[ 2, 2, 6, 2 ],
|
662 |
num_heads=[ 3, 6, 12, 24],
|
663 |
window_size=7
|
664 |
-
),
|
665 |
'swin_B_224_22k': dict(
|
666 |
embed_dim=128,
|
667 |
depths=[ 2, 2, 18, 2 ],
|
@@ -698,4 +696,4 @@ if __name__ == "__main__":
|
|
698 |
y = model.forward_raw(x)
|
699 |
import ipdb; ipdb.set_trace()
|
700 |
x = torch.rand(2, 3, 384, 384)
|
701 |
-
y = model.forward_raw(x)
|
|
|
4 |
import torch.nn.functional as F
|
5 |
import torch.utils.checkpoint as checkpoint
|
6 |
import numpy as np
|
7 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
8 |
from util.misc import NestedTensor
|
|
|
|
|
9 |
|
10 |
|
11 |
|
|
|
487 |
self.frozen_stages = frozen_stages
|
488 |
self.dilation = dilation
|
489 |
|
490 |
+
if use_checkpoint:
|
491 |
+
print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
|
492 |
|
493 |
# split image into non-overlapping patches
|
494 |
self.patch_embed = PatchEmbed(
|
|
|
632 |
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
633 |
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
634 |
|
635 |
+
# collect for nesttensors
|
636 |
outs_dict = {}
|
637 |
for idx, out_i in enumerate(outs):
|
638 |
m = tensor_list.mask
|
|
|
659 |
depths=[ 2, 2, 6, 2 ],
|
660 |
num_heads=[ 3, 6, 12, 24],
|
661 |
window_size=7
|
662 |
+
),
|
663 |
'swin_B_224_22k': dict(
|
664 |
embed_dim=128,
|
665 |
depths=[ 2, 2, 18, 2 ],
|
|
|
696 |
y = model.forward_raw(x)
|
697 |
import ipdb; ipdb.set_trace()
|
698 |
x = torch.rand(2, 3, 384, 384)
|
699 |
+
y = model.forward_raw(x)
|
src/utils/dependencies/XPose/models/UniPose/transformer_deformable.py
CHANGED
@@ -12,15 +12,19 @@
|
|
12 |
# ------------------------------------------------------------------------
|
13 |
|
14 |
import copy
|
|
|
|
|
15 |
import math
|
|
|
16 |
import torch
|
|
|
17 |
from torch import nn, Tensor
|
18 |
-
from torch.nn.init import xavier_uniform_, constant_, normal_
|
19 |
-
from typing import Optional
|
20 |
|
21 |
from util.misc import inverse_sigmoid
|
22 |
from .ops.modules import MSDeformAttn
|
23 |
-
|
|
|
24 |
|
25 |
class DeformableTransformer(nn.Module):
|
26 |
def __init__(self, d_model=256, nhead=8,
|
@@ -45,7 +49,7 @@ class DeformableTransformer(nn.Module):
|
|
45 |
decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
|
46 |
dropout, activation,
|
47 |
num_feature_levels, nhead, dec_n_points)
|
48 |
-
self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec,
|
49 |
use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed)
|
50 |
|
51 |
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
@@ -158,7 +162,7 @@ class DeformableTransformer(nn.Module):
|
|
158 |
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
159 |
src_flatten.append(src)
|
160 |
mask_flatten.append(mask)
|
161 |
-
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
|
162 |
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
|
163 |
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
164 |
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
|
@@ -187,7 +191,7 @@ class DeformableTransformer(nn.Module):
|
|
187 |
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
|
188 |
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
|
189 |
elif self.use_dab:
|
190 |
-
reference_points = query_embed[..., self.d_model:].sigmoid()
|
191 |
tgt = query_embed[..., :self.d_model]
|
192 |
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
|
193 |
init_reference_out = reference_points
|
@@ -195,15 +199,15 @@ class DeformableTransformer(nn.Module):
|
|
195 |
query_embed, tgt = torch.split(query_embed, c, dim=1)
|
196 |
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
|
197 |
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
|
198 |
-
reference_points = self.reference_points(query_embed).sigmoid()
|
199 |
# bs, num_quires, 2
|
200 |
init_reference_out = reference_points
|
201 |
|
202 |
# decoder
|
203 |
# import ipdb; ipdb.set_trace()
|
204 |
hs, inter_references = self.decoder(tgt, reference_points, memory,
|
205 |
-
spatial_shapes, level_start_index, valid_ratios,
|
206 |
-
query_pos=query_embed if not self.use_dab else None,
|
207 |
src_padding_mask=mask_flatten)
|
208 |
|
209 |
inter_references_out = inter_references
|
@@ -387,7 +391,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|
387 |
tgt = self.norm3(tgt)
|
388 |
return tgt
|
389 |
|
390 |
-
def forward_sa(self,
|
391 |
# for tgt
|
392 |
tgt: Optional[Tensor], # nq, bs, d_model
|
393 |
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
@@ -431,9 +435,9 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|
431 |
else:
|
432 |
raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
|
433 |
|
434 |
-
return tgt
|
435 |
|
436 |
-
def forward_ca(self,
|
437 |
# for tgt
|
438 |
tgt: Optional[Tensor], # nq, bs, d_model
|
439 |
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
@@ -468,9 +472,9 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|
468 |
tgt = tgt + self.dropout1(tgt2)
|
469 |
tgt = self.norm1(tgt)
|
470 |
|
471 |
-
return tgt
|
472 |
|
473 |
-
def forward(self,
|
474 |
# for tgt
|
475 |
tgt: Optional[Tensor], # nq, bs, d_model
|
476 |
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
@@ -530,7 +534,7 @@ class DeformableTransformerDecoder(nn.Module):
|
|
530 |
self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
|
531 |
|
532 |
|
533 |
-
def forward(self, tgt, reference_points, src, src_spatial_shapes,
|
534 |
src_level_start_index, src_valid_ratios,
|
535 |
query_pos=None, src_padding_mask=None):
|
536 |
output = tgt
|
@@ -547,14 +551,14 @@ class DeformableTransformerDecoder(nn.Module):
|
|
547 |
else:
|
548 |
assert reference_points.shape[-1] == 2
|
549 |
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
|
550 |
-
|
551 |
if self.use_dab:
|
552 |
# import ipdb; ipdb.set_trace()
|
553 |
-
query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2
|
554 |
raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256
|
555 |
pos_scale = self.query_scale(output) if layer_id != 0 else 1
|
556 |
query_pos = pos_scale * raw_query_pos
|
557 |
-
|
558 |
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
|
559 |
|
560 |
# hack implementation for iterative bounding box refinement
|
@@ -593,3 +597,5 @@ def build_deforamble_transformer(args):
|
|
593 |
use_dab=args.ddetr_use_dab,
|
594 |
high_dim_query_update=args.ddetr_high_dim_query_update,
|
595 |
no_sine_embed=args.ddetr_no_sine_embed)
|
|
|
|
|
|
12 |
# ------------------------------------------------------------------------
|
13 |
|
14 |
import copy
|
15 |
+
import os
|
16 |
+
from typing import Optional, List
|
17 |
import math
|
18 |
+
|
19 |
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
from torch import nn, Tensor
|
22 |
+
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
|
|
|
23 |
|
24 |
from util.misc import inverse_sigmoid
|
25 |
from .ops.modules import MSDeformAttn
|
26 |
+
|
27 |
+
from .utils import sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
|
28 |
|
29 |
class DeformableTransformer(nn.Module):
|
30 |
def __init__(self, d_model=256, nhead=8,
|
|
|
49 |
decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
|
50 |
dropout, activation,
|
51 |
num_feature_levels, nhead, dec_n_points)
|
52 |
+
self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec,
|
53 |
use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed)
|
54 |
|
55 |
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
|
|
162 |
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
163 |
src_flatten.append(src)
|
164 |
mask_flatten.append(mask)
|
165 |
+
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
|
166 |
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
|
167 |
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
168 |
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
|
|
|
191 |
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
|
192 |
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
|
193 |
elif self.use_dab:
|
194 |
+
reference_points = query_embed[..., self.d_model:].sigmoid()
|
195 |
tgt = query_embed[..., :self.d_model]
|
196 |
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
|
197 |
init_reference_out = reference_points
|
|
|
199 |
query_embed, tgt = torch.split(query_embed, c, dim=1)
|
200 |
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
|
201 |
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
|
202 |
+
reference_points = self.reference_points(query_embed).sigmoid()
|
203 |
# bs, num_quires, 2
|
204 |
init_reference_out = reference_points
|
205 |
|
206 |
# decoder
|
207 |
# import ipdb; ipdb.set_trace()
|
208 |
hs, inter_references = self.decoder(tgt, reference_points, memory,
|
209 |
+
spatial_shapes, level_start_index, valid_ratios,
|
210 |
+
query_pos=query_embed if not self.use_dab else None,
|
211 |
src_padding_mask=mask_flatten)
|
212 |
|
213 |
inter_references_out = inter_references
|
|
|
391 |
tgt = self.norm3(tgt)
|
392 |
return tgt
|
393 |
|
394 |
+
def forward_sa(self,
|
395 |
# for tgt
|
396 |
tgt: Optional[Tensor], # nq, bs, d_model
|
397 |
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
|
|
435 |
else:
|
436 |
raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
|
437 |
|
438 |
+
return tgt
|
439 |
|
440 |
+
def forward_ca(self,
|
441 |
# for tgt
|
442 |
tgt: Optional[Tensor], # nq, bs, d_model
|
443 |
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
|
|
472 |
tgt = tgt + self.dropout1(tgt2)
|
473 |
tgt = self.norm1(tgt)
|
474 |
|
475 |
+
return tgt
|
476 |
|
477 |
+
def forward(self,
|
478 |
# for tgt
|
479 |
tgt: Optional[Tensor], # nq, bs, d_model
|
480 |
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
|
|
534 |
self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
|
535 |
|
536 |
|
537 |
+
def forward(self, tgt, reference_points, src, src_spatial_shapes,
|
538 |
src_level_start_index, src_valid_ratios,
|
539 |
query_pos=None, src_padding_mask=None):
|
540 |
output = tgt
|
|
|
551 |
else:
|
552 |
assert reference_points.shape[-1] == 2
|
553 |
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
|
554 |
+
|
555 |
if self.use_dab:
|
556 |
# import ipdb; ipdb.set_trace()
|
557 |
+
query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2
|
558 |
raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256
|
559 |
pos_scale = self.query_scale(output) if layer_id != 0 else 1
|
560 |
query_pos = pos_scale * raw_query_pos
|
561 |
+
|
562 |
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
|
563 |
|
564 |
# hack implementation for iterative bounding box refinement
|
|
|
597 |
use_dab=args.ddetr_use_dab,
|
598 |
high_dim_query_update=args.ddetr_high_dim_query_update,
|
599 |
no_sine_embed=args.ddetr_no_sine_embed)
|
600 |
+
|
601 |
+
|
src/utils/dependencies/XPose/models/UniPose/transformer_vanilla.py
CHANGED
@@ -8,11 +8,15 @@ Copy-paste from torch.nn.Transformer with modifications:
|
|
8 |
* extra LN at the end of encoder is removed
|
9 |
* decoder returns a stack of activations from all decoding layers
|
10 |
"""
|
|
|
|
|
|
|
|
|
11 |
import torch
|
|
|
12 |
from torch import Tensor, nn
|
13 |
-
from typing import List, Optional
|
14 |
|
15 |
-
from .utils import
|
16 |
|
17 |
|
18 |
class TextTransformer(nn.Module):
|
|
|
8 |
* extra LN at the end of encoder is removed
|
9 |
* decoder returns a stack of activations from all decoding layers
|
10 |
"""
|
11 |
+
import copy
|
12 |
+
import os
|
13 |
+
from typing import List, Optional
|
14 |
+
import pdb
|
15 |
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
from torch import Tensor, nn
|
|
|
18 |
|
19 |
+
from .utils import gen_encoder_output_proposals, sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position, _get_clones
|
20 |
|
21 |
|
22 |
class TextTransformer(nn.Module):
|
src/utils/dependencies/XPose/models/UniPose/unipose.py
CHANGED
@@ -6,21 +6,30 @@
|
|
6 |
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
|
7 |
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
8 |
# ------------------------------------------------------------------------
|
9 |
-
import os
|
10 |
import copy
|
|
|
|
|
|
|
11 |
import torch
|
12 |
import torch.nn.functional as F
|
13 |
from torch import nn
|
14 |
-
from
|
15 |
-
|
16 |
from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
|
17 |
-
from util
|
18 |
-
|
19 |
-
|
|
|
20 |
from .backbone import build_backbone
|
|
|
|
|
|
|
21 |
from ..registry import MODULE_BUILD_FUNCS
|
22 |
from .mask_generate import prepare_for_mask, post_process
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
class UniPose(nn.Module):
|
@@ -107,12 +116,12 @@ class UniPose(nn.Module):
|
|
107 |
|
108 |
|
109 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
#
|
114 |
-
|
115 |
-
|
116 |
|
117 |
self.pos_proj = nn.Linear(hidden_dim, 768)
|
118 |
self.padding = nn.Embedding(1, 768)
|
@@ -531,7 +540,7 @@ def build_unipose(args):
|
|
531 |
sub_sentence_present = args.sub_sentence_present
|
532 |
except:
|
533 |
sub_sentence_present = True
|
534 |
-
|
535 |
|
536 |
model = UniPose(
|
537 |
backbone,
|
|
|
6 |
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
|
7 |
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
8 |
# ------------------------------------------------------------------------
|
|
|
9 |
import copy
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
from typing import List
|
13 |
import torch
|
14 |
import torch.nn.functional as F
|
15 |
from torch import nn
|
16 |
+
from torchvision.ops.boxes import nms
|
|
|
17 |
from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
|
18 |
+
from util import box_ops
|
19 |
+
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
|
20 |
+
accuracy, get_world_size, interpolate,
|
21 |
+
is_dist_avail_and_initialized, inverse_sigmoid)
|
22 |
from .backbone import build_backbone
|
23 |
+
from .deformable_transformer import build_deformable_transformer
|
24 |
+
from .utils import sigmoid_focal_loss, MLP
|
25 |
+
|
26 |
from ..registry import MODULE_BUILD_FUNCS
|
27 |
from .mask_generate import prepare_for_mask, post_process
|
28 |
+
import random
|
29 |
+
from .utils import sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
|
30 |
+
from pathlib import Path
|
31 |
+
import clip
|
32 |
+
|
33 |
|
34 |
|
35 |
class UniPose(nn.Module):
|
|
|
116 |
|
117 |
|
118 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
119 |
+
model, _ = clip.load("ViT-B/32", device=device)
|
120 |
+
self.clip_model = model
|
121 |
+
visual_parameters = list(self.clip_model.visual.parameters())
|
122 |
+
#
|
123 |
+
for param in visual_parameters:
|
124 |
+
param.requires_grad = False
|
125 |
|
126 |
self.pos_proj = nn.Linear(hidden_dim, 768)
|
127 |
self.padding = nn.Embedding(1, 768)
|
|
|
540 |
sub_sentence_present = args.sub_sentence_present
|
541 |
except:
|
542 |
sub_sentence_present = True
|
543 |
+
print('********* sub_sentence_present', sub_sentence_present)
|
544 |
|
545 |
model = UniPose(
|
546 |
backbone,
|
src/utils/dependencies/XPose/models/UniPose/utils.py
CHANGED
@@ -345,4 +345,4 @@ class OKSLoss(nn.Module):
|
|
345 |
linear=self.linear,
|
346 |
sigmas=self.sigmas,
|
347 |
eps=self.eps)
|
348 |
-
return loss
|
|
|
345 |
linear=self.linear,
|
346 |
sigmas=self.sigmas,
|
347 |
eps=self.eps)
|
348 |
+
return loss
|
src/utils/dependencies/XPose/transforms.py
CHANGED
@@ -24,6 +24,7 @@ def crop(image, target, region):
|
|
24 |
i, j, h, w = region
|
25 |
id2catname = target["id2catname"]
|
26 |
caption_list = target["caption_list"]
|
|
|
27 |
target["size"] = torch.tensor([h, w])
|
28 |
|
29 |
fields = ["labels", "area", "iscrowd", "positive_map","keypoints"]
|
|
|
24 |
i, j, h, w = region
|
25 |
id2catname = target["id2catname"]
|
26 |
caption_list = target["caption_list"]
|
27 |
+
# should we do something wrt the original size?
|
28 |
target["size"] = torch.tensor([h, w])
|
29 |
|
30 |
fields = ["labels", "area", "iscrowd", "positive_map","keypoints"]
|
src/utils/dependencies/XPose/util/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
src/utils/dependencies/XPose/util/addict.py
DELETED
@@ -1,159 +0,0 @@
|
|
1 |
-
import copy
|
2 |
-
|
3 |
-
|
4 |
-
class Dict(dict):
|
5 |
-
|
6 |
-
def __init__(__self, *args, **kwargs):
|
7 |
-
object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
|
8 |
-
object.__setattr__(__self, '__key', kwargs.pop('__key', None))
|
9 |
-
object.__setattr__(__self, '__frozen', False)
|
10 |
-
for arg in args:
|
11 |
-
if not arg:
|
12 |
-
continue
|
13 |
-
elif isinstance(arg, dict):
|
14 |
-
for key, val in arg.items():
|
15 |
-
__self[key] = __self._hook(val)
|
16 |
-
elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
|
17 |
-
__self[arg[0]] = __self._hook(arg[1])
|
18 |
-
else:
|
19 |
-
for key, val in iter(arg):
|
20 |
-
__self[key] = __self._hook(val)
|
21 |
-
|
22 |
-
for key, val in kwargs.items():
|
23 |
-
__self[key] = __self._hook(val)
|
24 |
-
|
25 |
-
def __setattr__(self, name, value):
|
26 |
-
if hasattr(self.__class__, name):
|
27 |
-
raise AttributeError("'Dict' object attribute "
|
28 |
-
"'{0}' is read-only".format(name))
|
29 |
-
else:
|
30 |
-
self[name] = value
|
31 |
-
|
32 |
-
def __setitem__(self, name, value):
|
33 |
-
isFrozen = (hasattr(self, '__frozen') and
|
34 |
-
object.__getattribute__(self, '__frozen'))
|
35 |
-
if isFrozen and name not in super(Dict, self).keys():
|
36 |
-
raise KeyError(name)
|
37 |
-
super(Dict, self).__setitem__(name, value)
|
38 |
-
try:
|
39 |
-
p = object.__getattribute__(self, '__parent')
|
40 |
-
key = object.__getattribute__(self, '__key')
|
41 |
-
except AttributeError:
|
42 |
-
p = None
|
43 |
-
key = None
|
44 |
-
if p is not None:
|
45 |
-
p[key] = self
|
46 |
-
object.__delattr__(self, '__parent')
|
47 |
-
object.__delattr__(self, '__key')
|
48 |
-
|
49 |
-
def __add__(self, other):
|
50 |
-
if not self.keys():
|
51 |
-
return other
|
52 |
-
else:
|
53 |
-
self_type = type(self).__name__
|
54 |
-
other_type = type(other).__name__
|
55 |
-
msg = "unsupported operand type(s) for +: '{}' and '{}'"
|
56 |
-
raise TypeError(msg.format(self_type, other_type))
|
57 |
-
|
58 |
-
@classmethod
|
59 |
-
def _hook(cls, item):
|
60 |
-
if isinstance(item, dict):
|
61 |
-
return cls(item)
|
62 |
-
elif isinstance(item, (list, tuple)):
|
63 |
-
return type(item)(cls._hook(elem) for elem in item)
|
64 |
-
return item
|
65 |
-
|
66 |
-
def __getattr__(self, item):
|
67 |
-
return self.__getitem__(item)
|
68 |
-
|
69 |
-
def __missing__(self, name):
|
70 |
-
if object.__getattribute__(self, '__frozen'):
|
71 |
-
raise KeyError(name)
|
72 |
-
return self.__class__(__parent=self, __key=name)
|
73 |
-
|
74 |
-
def __delattr__(self, name):
|
75 |
-
del self[name]
|
76 |
-
|
77 |
-
def to_dict(self):
|
78 |
-
base = {}
|
79 |
-
for key, value in self.items():
|
80 |
-
if isinstance(value, type(self)):
|
81 |
-
base[key] = value.to_dict()
|
82 |
-
elif isinstance(value, (list, tuple)):
|
83 |
-
base[key] = type(value)(
|
84 |
-
item.to_dict() if isinstance(item, type(self)) else
|
85 |
-
item for item in value)
|
86 |
-
else:
|
87 |
-
base[key] = value
|
88 |
-
return base
|
89 |
-
|
90 |
-
def copy(self):
|
91 |
-
return copy.copy(self)
|
92 |
-
|
93 |
-
def deepcopy(self):
|
94 |
-
return copy.deepcopy(self)
|
95 |
-
|
96 |
-
def __deepcopy__(self, memo):
|
97 |
-
other = self.__class__()
|
98 |
-
memo[id(self)] = other
|
99 |
-
for key, value in self.items():
|
100 |
-
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
|
101 |
-
return other
|
102 |
-
|
103 |
-
def update(self, *args, **kwargs):
|
104 |
-
other = {}
|
105 |
-
if args:
|
106 |
-
if len(args) > 1:
|
107 |
-
raise TypeError()
|
108 |
-
other.update(args[0])
|
109 |
-
other.update(kwargs)
|
110 |
-
for k, v in other.items():
|
111 |
-
if ((k not in self) or
|
112 |
-
(not isinstance(self[k], dict)) or
|
113 |
-
(not isinstance(v, dict))):
|
114 |
-
self[k] = v
|
115 |
-
else:
|
116 |
-
self[k].update(v)
|
117 |
-
|
118 |
-
def __getnewargs__(self):
|
119 |
-
return tuple(self.items())
|
120 |
-
|
121 |
-
def __getstate__(self):
|
122 |
-
return self
|
123 |
-
|
124 |
-
def __setstate__(self, state):
|
125 |
-
self.update(state)
|
126 |
-
|
127 |
-
def __or__(self, other):
|
128 |
-
if not isinstance(other, (Dict, dict)):
|
129 |
-
return NotImplemented
|
130 |
-
new = Dict(self)
|
131 |
-
new.update(other)
|
132 |
-
return new
|
133 |
-
|
134 |
-
def __ror__(self, other):
|
135 |
-
if not isinstance(other, (Dict, dict)):
|
136 |
-
return NotImplemented
|
137 |
-
new = Dict(other)
|
138 |
-
new.update(self)
|
139 |
-
return new
|
140 |
-
|
141 |
-
def __ior__(self, other):
|
142 |
-
self.update(other)
|
143 |
-
return self
|
144 |
-
|
145 |
-
def setdefault(self, key, default=None):
|
146 |
-
if key in self:
|
147 |
-
return self[key]
|
148 |
-
else:
|
149 |
-
self[key] = default
|
150 |
-
return default
|
151 |
-
|
152 |
-
def freeze(self, shouldFreeze=True):
|
153 |
-
object.__setattr__(self, '__frozen', shouldFreeze)
|
154 |
-
for key, val in self.items():
|
155 |
-
if isinstance(val, Dict):
|
156 |
-
val.freeze(shouldFreeze)
|
157 |
-
|
158 |
-
def unfreeze(self):
|
159 |
-
self.freeze(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/dependencies/XPose/util/box_ops.py
CHANGED
@@ -136,4 +136,4 @@ if __name__ == '__main__':
|
|
136 |
x = torch.rand(5, 4)
|
137 |
y = torch.rand(3, 4)
|
138 |
iou, union = box_iou(x, y)
|
139 |
-
import ipdb; ipdb.set_trace()
|
|
|
136 |
x = torch.rand(5, 4)
|
137 |
y = torch.rand(3, 4)
|
138 |
iou, union = box_iou(x, y)
|
139 |
+
import ipdb; ipdb.set_trace()
|
src/utils/dependencies/XPose/util/config.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1 |
# ==========================================================
|
2 |
# Modified from mmcv
|
3 |
# ==========================================================
|
4 |
-
import sys
|
5 |
import os.path as osp
|
6 |
import ast
|
7 |
import tempfile
|
8 |
import shutil
|
9 |
from importlib import import_module
|
|
|
10 |
from argparse import Action
|
11 |
|
12 |
-
from
|
|
|
13 |
|
14 |
BASE_KEY = '_base_'
|
15 |
DELETE_KEY = '_delete_'
|
@@ -81,8 +83,6 @@ class Config(object):
|
|
81 |
temp_config_file = tempfile.NamedTemporaryFile(
|
82 |
dir=temp_config_dir, suffix='.py')
|
83 |
temp_config_name = osp.basename(temp_config_file.name)
|
84 |
-
# close temp file before copy
|
85 |
-
temp_config_file.close()
|
86 |
shutil.copyfile(filename,
|
87 |
osp.join(temp_config_dir, temp_config_name))
|
88 |
temp_module_name = osp.splitext(temp_config_name)[0]
|
@@ -97,8 +97,8 @@ class Config(object):
|
|
97 |
}
|
98 |
# delete imported module
|
99 |
del sys.modules[temp_module_name]
|
100 |
-
|
101 |
-
|
102 |
elif filename.lower().endswith(('.yml', '.yaml', '.json')):
|
103 |
from .slio import slload
|
104 |
cfg_dict = slload(filename)
|
@@ -304,6 +304,13 @@ class Config(object):
|
|
304 |
|
305 |
cfg_dict = self._cfg_dict.to_dict()
|
306 |
text = _format_dict(cfg_dict, outest_level=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
return text
|
308 |
|
309 |
|
|
|
1 |
# ==========================================================
|
2 |
# Modified from mmcv
|
3 |
# ==========================================================
|
4 |
+
import os, sys
|
5 |
import os.path as osp
|
6 |
import ast
|
7 |
import tempfile
|
8 |
import shutil
|
9 |
from importlib import import_module
|
10 |
+
|
11 |
from argparse import Action
|
12 |
|
13 |
+
from addict import Dict
|
14 |
+
from yapf.yapflib.yapf_api import FormatCode
|
15 |
|
16 |
BASE_KEY = '_base_'
|
17 |
DELETE_KEY = '_delete_'
|
|
|
83 |
temp_config_file = tempfile.NamedTemporaryFile(
|
84 |
dir=temp_config_dir, suffix='.py')
|
85 |
temp_config_name = osp.basename(temp_config_file.name)
|
|
|
|
|
86 |
shutil.copyfile(filename,
|
87 |
osp.join(temp_config_dir, temp_config_name))
|
88 |
temp_module_name = osp.splitext(temp_config_name)[0]
|
|
|
97 |
}
|
98 |
# delete imported module
|
99 |
del sys.modules[temp_module_name]
|
100 |
+
# close temp file
|
101 |
+
temp_config_file.close()
|
102 |
elif filename.lower().endswith(('.yml', '.yaml', '.json')):
|
103 |
from .slio import slload
|
104 |
cfg_dict = slload(filename)
|
|
|
304 |
|
305 |
cfg_dict = self._cfg_dict.to_dict()
|
306 |
text = _format_dict(cfg_dict, outest_level=True)
|
307 |
+
# copied from setup.cfg
|
308 |
+
yapf_style = dict(
|
309 |
+
based_on_style='pep8',
|
310 |
+
blank_line_before_nested_class_or_def=True,
|
311 |
+
split_before_expression_after_opening_paren=True)
|
312 |
+
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
|
313 |
+
|
314 |
return text
|
315 |
|
316 |
|
src/utils/dependencies/XPose/util/get_param_dicts.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
def match_name_keywords(n: str, name_keywords: list):
|
7 |
+
out = False
|
8 |
+
for b in name_keywords:
|
9 |
+
if b in n:
|
10 |
+
out = True
|
11 |
+
break
|
12 |
+
return out
|
13 |
+
|
14 |
+
|
15 |
+
def get_param_dict(args, model_without_ddp: nn.Module):
|
16 |
+
try:
|
17 |
+
param_dict_type = args.param_dict_type
|
18 |
+
except:
|
19 |
+
param_dict_type = 'default'
|
20 |
+
assert param_dict_type in ['default', 'ddetr_in_mmdet', 'large_wd']
|
21 |
+
|
22 |
+
# by default
|
23 |
+
if param_dict_type == 'default':
|
24 |
+
param_dicts = [
|
25 |
+
{"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and "bert" not in n and p.requires_grad]},
|
26 |
+
{
|
27 |
+
"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
|
28 |
+
"lr": args.lr_backbone,
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"params": [p for n, p in model_without_ddp.named_parameters() if "bert" in n and p.requires_grad],
|
32 |
+
"lr": args.lr_backbone,
|
33 |
+
}
|
34 |
+
]
|
35 |
+
|
36 |
+
param_name_dicts = [
|
37 |
+
{"params": [n for n, p in model_without_ddp.named_parameters() if "backbone" not in n and "bert" not in n and p.requires_grad]},
|
38 |
+
{
|
39 |
+
"params": [n for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
|
40 |
+
"lr": args.lr_backbone,
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"params": [n for n, p in model_without_ddp.named_parameters() if "bert" in n and p.requires_grad],
|
44 |
+
"lr": args.lr_backbone,
|
45 |
+
}
|
46 |
+
]
|
47 |
+
|
48 |
+
print('param_name_dicts: ', json.dumps(param_name_dicts, indent=2))
|
49 |
+
|
50 |
+
return param_dicts, param_name_dicts
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
raise NotImplementedError
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
# print("param_dicts: {}".format(param_dicts))
|
60 |
+
|
61 |
+
return param_dicts, None
|
src/utils/dependencies/XPose/util/instance.txt
ADDED
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AnimalKindom (850 but only providing 402 non-overlapping animal classes):
|
2 |
+
Orange Clownfish
|
3 |
+
Fish
|
4 |
+
Damsel Fish
|
5 |
+
Sterlet Fish
|
6 |
+
Trout
|
7 |
+
Danube Bleak Fish
|
8 |
+
Grayling Fish
|
9 |
+
Catfish
|
10 |
+
Sea Toad Fish
|
11 |
+
Pike Perch Fish
|
12 |
+
Sardine
|
13 |
+
Archer Fish
|
14 |
+
Giant Trevally
|
15 |
+
Atlantic Blue Tang Fish
|
16 |
+
Trout Young
|
17 |
+
Stonefish
|
18 |
+
Yellow Watchman Goby Fish
|
19 |
+
Salmon
|
20 |
+
Danube Salmon
|
21 |
+
Mimic Blenny Fish
|
22 |
+
Pink Skunk Clownfish
|
23 |
+
Barracuda Fish
|
24 |
+
Surgeonfish
|
25 |
+
Keeltail Needlefish
|
26 |
+
Royal Grammas Fish
|
27 |
+
Tench Fish
|
28 |
+
Clownfish
|
29 |
+
Barramundi Fish
|
30 |
+
Sergeant Major Fish
|
31 |
+
Gray Angelfish
|
32 |
+
Butterfly Fish
|
33 |
+
Yellow Wrasse Fish
|
34 |
+
Lionfish
|
35 |
+
Toadfish
|
36 |
+
Perch Fish
|
37 |
+
Round Face Bat Fish
|
38 |
+
Goby Fish
|
39 |
+
Horned Adder
|
40 |
+
Atheris Hispida Viper
|
41 |
+
Fer-De-Lance Snake
|
42 |
+
Naja Nivea Snake
|
43 |
+
Elegant Bronzeback Snake
|
44 |
+
S Viper"]
|
45 |
+
Black Necked Spitting Cobra
|
46 |
+
Dice Snake
|
47 |
+
Boa
|
48 |
+
Dispholidus Typus Snake
|
49 |
+
Red Spitting Cobra
|
50 |
+
Wild Red-Tailed Boa
|
51 |
+
S Spitting Cobra"]
|
52 |
+
Coronella Austriaca Snake
|
53 |
+
Lesser Sunda Pit Viper
|
54 |
+
Oriental Whip Snake
|
55 |
+
Snake
|
56 |
+
Mojave Rattlesnake
|
57 |
+
Annulated Tree Boa
|
58 |
+
Bushmaster Snake
|
59 |
+
Sidewinder Rattlesnake
|
60 |
+
Vipera Berus Snake
|
61 |
+
Reticulated Python
|
62 |
+
King Cobra
|
63 |
+
Atheris Squamigera
|
64 |
+
Lichanura Trivirgata Snake
|
65 |
+
Crotalus Willardi Ridge Nosed Rattlesnake
|
66 |
+
Slender Hognosed Pit Viper
|
67 |
+
Paradise Tree Snake
|
68 |
+
Coral Mimic Snake
|
69 |
+
Many Horned Adder
|
70 |
+
Dendroaspis Polylepis Black Mamba
|
71 |
+
Northern Pacific Rattlesnake
|
72 |
+
S Boa"]
|
73 |
+
Mozambique Spitting Cobra
|
74 |
+
Grass Snake
|
75 |
+
Thelotornis Snake
|
76 |
+
Black Headed Python
|
77 |
+
Metlapilcoatlus Mexicanus Jumping Pit Viper
|
78 |
+
Natrix Natrix Snake
|
79 |
+
Atheris Nitschei Viper
|
80 |
+
Puff Adder
|
81 |
+
Rhamnophis Aethiopissa Snake
|
82 |
+
Coral Snake
|
83 |
+
Montpellier Snake
|
84 |
+
Bothriechis
|
85 |
+
Lampropeltis Pyromelana Snake
|
86 |
+
Lampropeltis Zonata Snake
|
87 |
+
Natrix Tessellata Snake
|
88 |
+
Thamnophis Cyrtopsis Snake
|
89 |
+
Rat Snake
|
90 |
+
White Speckled Rattlesnake
|
91 |
+
Namaqua Dwarf Adder
|
92 |
+
Aesculapian Snake
|
93 |
+
Nose-Horned Viper
|
94 |
+
Zamenis Longissiumus Snake
|
95 |
+
Rhinoceros Viper
|
96 |
+
Rattlesnake
|
97 |
+
Red Bellied Black Snake
|
98 |
+
Eyelash Pit Viper
|
99 |
+
Snouted Cobra
|
100 |
+
JamesonS Mamba
|
101 |
+
Gaboon Viper
|
102 |
+
Lampropeltis Splendida Snake
|
103 |
+
Laticauda Saintgironsi Sea Krait
|
104 |
+
Pituophis Catenifer Snake
|
105 |
+
Eastern Montpellier Snake
|
106 |
+
Black Mamba
|
107 |
+
Javan Spitting Cobra
|
108 |
+
Sea Snake
|
109 |
+
Gyalopion Canum Snake
|
110 |
+
Mojave Rattlesnake Young
|
111 |
+
Spectacled Cobra
|
112 |
+
Hognosed Pit Viper
|
113 |
+
Dog Faced Water Snake
|
114 |
+
Cobra
|
115 |
+
Tree Snake
|
116 |
+
Lampropeltis Getula Snake
|
117 |
+
S Mamba"]
|
118 |
+
Bothrops Asper
|
119 |
+
Rinkhals Snake
|
120 |
+
Malabar Pit Viper
|
121 |
+
Western Diamondback Young
|
122 |
+
Micruroides Euryxanthus Snake
|
123 |
+
Western Diamondback
|
124 |
+
Python
|
125 |
+
Viper
|
126 |
+
Horse
|
127 |
+
Horse Young
|
128 |
+
Texas Brown Tarantula
|
129 |
+
Tarantula
|
130 |
+
Diving Bell Water Spider
|
131 |
+
Turret Spider
|
132 |
+
Avicularia Spider
|
133 |
+
Golden Orb Spider
|
134 |
+
Spider
|
135 |
+
Salticidae Jumping Spider
|
136 |
+
Salticidae 3 Spider
|
137 |
+
Jumping Spider
|
138 |
+
Orb Spider
|
139 |
+
Latrodectus Hersperus Western Widow Spider
|
140 |
+
Redback Spider
|
141 |
+
Western Widow Spider
|
142 |
+
Nursery Web Spider
|
143 |
+
Grass Spider
|
144 |
+
Araneus Diadematus Spider
|
145 |
+
Tetragnatha Versicolor
|
146 |
+
Cyclosa Conica
|
147 |
+
Daddy Longlegs Spider
|
148 |
+
Portia Jumping Spider
|
149 |
+
Portia 3 Spider
|
150 |
+
Habronattus Clypeatus
|
151 |
+
Chimpanzee
|
152 |
+
Gorilla
|
153 |
+
Mountain Gorilla
|
154 |
+
Stump-Tailed Macaque
|
155 |
+
Pig-Tailed Macaque
|
156 |
+
Monkey
|
157 |
+
Orangutan
|
158 |
+
Grey Langur
|
159 |
+
Mitred Leaf Monkey
|
160 |
+
Gibbon
|
161 |
+
Red Ruffed Lemur
|
162 |
+
Proboscis Monkey
|
163 |
+
Sumatran Orangutan
|
164 |
+
Red-Ruffed Lemur
|
165 |
+
White-Faced Saki Monkey
|
166 |
+
Bornean Orangutan
|
167 |
+
Mandrill
|
168 |
+
Maroon Macaque
|
169 |
+
Rhesus Macaque
|
170 |
+
Raffles Banded Langur
|
171 |
+
Western Chimpanzee Young
|
172 |
+
Diana Monkey
|
173 |
+
Lesser Spot Nosed Monkey
|
174 |
+
S Monkey"]
|
175 |
+
Sooty Mangabey
|
176 |
+
King Colobus
|
177 |
+
Olive Colobus
|
178 |
+
Capuchin Monkey
|
179 |
+
Red-Blacked Squirrel Monkey
|
180 |
+
Ring-Tailed Lemur
|
181 |
+
Iguana
|
182 |
+
Marine Iguana
|
183 |
+
Rhacodactylus Trachyrhynchus Gecko
|
184 |
+
Gecko
|
185 |
+
Lizard
|
186 |
+
Common Basilisk Lizard
|
187 |
+
Basilisk Lizard
|
188 |
+
Monitor Lizard
|
189 |
+
Strange-Horned Chameleon
|
190 |
+
Chameleon
|
191 |
+
Leaf-Tailed Gecko
|
192 |
+
Clouded Monitor Lizard
|
193 |
+
S Chameleon"]
|
194 |
+
Indian Chameleon
|
195 |
+
Namaqua Dwarf Chameleon
|
196 |
+
Malayan Water Monitor Lizard
|
197 |
+
Green Iguana
|
198 |
+
White Wig Marine Iguana
|
199 |
+
Skink
|
200 |
+
Black Bearded Draco
|
201 |
+
Yellow Striped Tree Skink
|
202 |
+
Side Blotched Lizard
|
203 |
+
Frilled Neck Lizard
|
204 |
+
Leopard Gecko
|
205 |
+
Morning Gecko
|
206 |
+
Giant Ground Gecko
|
207 |
+
Western Dwarf Chameleon
|
208 |
+
Whooper Swan
|
209 |
+
Goose
|
210 |
+
Puffin
|
211 |
+
Mallard Duck
|
212 |
+
Greylag Goose
|
213 |
+
S Duck"]
|
214 |
+
African Finfoot
|
215 |
+
Wandering Alabatross
|
216 |
+
Duck
|
217 |
+
Common Goldeneye
|
218 |
+
Garganey
|
219 |
+
Black Stork
|
220 |
+
Black Swan
|
221 |
+
Сommon Eider
|
222 |
+
Carolina Duck
|
223 |
+
Smew
|
224 |
+
Lanius Excubitor
|
225 |
+
Song Thrush Bird
|
226 |
+
Wedge Tailed Eagle
|
227 |
+
Great Grey Shrike
|
228 |
+
Anthus Pratensis Bird
|
229 |
+
Red-Throated Pipit
|
230 |
+
Tit Bird
|
231 |
+
Hoopoe
|
232 |
+
Woodpecker
|
233 |
+
Botaurus Stellaris Bird
|
234 |
+
Little Crake Bird
|
235 |
+
White-Throated Dipper
|
236 |
+
Raven
|
237 |
+
Citrine Wagtail Bird
|
238 |
+
Yellowhammer Young
|
239 |
+
Grey Heron
|
240 |
+
Alauda Arvensis Bird
|
241 |
+
Bird
|
242 |
+
Common Cuckoo Bird
|
243 |
+
Tringa Erythropus Bird
|
244 |
+
White Throated Dipper Bird
|
245 |
+
Skylark
|
246 |
+
Mistle Thrush
|
247 |
+
Robin Bird
|
248 |
+
Australian Bowerbird
|
249 |
+
Turtle Dove
|
250 |
+
Black-Winged Stilt
|
251 |
+
Wood Warbler
|
252 |
+
Common Crane
|
253 |
+
Eurasian Wren Bird
|
254 |
+
Common Quail
|
255 |
+
Nightingale Bird
|
256 |
+
Tawny Owl
|
257 |
+
Grebe Bird
|
258 |
+
Water Dipper Bird
|
259 |
+
Yellowhammer
|
260 |
+
Hazel Grouse Bird
|
261 |
+
Greater Racket Tail Drongo
|
262 |
+
Golden Oriole
|
263 |
+
Great Egret
|
264 |
+
Turdus Merula Blackbird
|
265 |
+
Nuthatch Bird
|
266 |
+
Eagle
|
267 |
+
Luscinia Luscinia Nightingale Bird
|
268 |
+
Singing Nightingale
|
269 |
+
Water Rail Bird
|
270 |
+
Gull
|
271 |
+
Azure Tit Bird
|
272 |
+
Numenius Arquata Bird
|
273 |
+
Golden Eagle
|
274 |
+
Remiz Pendulinus Bird
|
275 |
+
Goldfinch
|
276 |
+
Common Whitethroat Bird
|
277 |
+
Red-Backed Shrike Bird
|
278 |
+
Grasshopper Warbler
|
279 |
+
Shoebill Bird
|
280 |
+
Common Rosefinch Bird
|
281 |
+
Owl
|
282 |
+
Chaffinch Bird
|
283 |
+
Bluethroat
|
284 |
+
Green Woodpecker
|
285 |
+
Common Snipe
|
286 |
+
Whinchat Bird
|
287 |
+
Ostrich
|
288 |
+
Boreal Owl
|
289 |
+
European Robin Bird
|
290 |
+
Larus Canus Bird
|
291 |
+
Hawk
|
292 |
+
Three-Toed Woodpecker
|
293 |
+
Thrush Nightingale Bird
|
294 |
+
Jack Snipe Bird
|
295 |
+
Red Crossbill
|
296 |
+
Chiffchaff Bird
|
297 |
+
Shorebird
|
298 |
+
Bullfinch
|
299 |
+
Red-Backed Shrike Bird Young
|
300 |
+
Circus Aeruginosus Bird
|
301 |
+
Kingfisher
|
302 |
+
White-Backed Woodpecker
|
303 |
+
Tringa Ochropus Bird
|
304 |
+
Stock Dove
|
305 |
+
Heron
|
306 |
+
Citrine Wagtail
|
307 |
+
Vanellus Vanellus Bird
|
308 |
+
Tringa Nebularia Bird
|
309 |
+
Eurasian Wryneck Bird
|
310 |
+
Tachybaptus Ruficollis Bird
|
311 |
+
Quail
|
312 |
+
Little Egret
|
313 |
+
Stork
|
314 |
+
Green Mamba
|
315 |
+
Pufferfish
|
316 |
+
Bullfrog
|
317 |
+
Frog
|
318 |
+
Corroboree Frog
|
319 |
+
Desert Rain Frog
|
320 |
+
African Clawed Toad
|
321 |
+
Mountain Yellow-Legged Frog
|
322 |
+
Tropical Reed Frog
|
323 |
+
Mimic Poison Frog
|
324 |
+
Red-Eyed Tree Frog
|
325 |
+
S Frog"]
|
326 |
+
Water Lily Frog
|
327 |
+
Amazon Milk Frog
|
328 |
+
Toad
|
329 |
+
Marbled Rubber Frog
|
330 |
+
Sand Frog
|
331 |
+
Golden Poison Frog
|
332 |
+
Rain Frog
|
333 |
+
Monster Frog
|
334 |
+
S Warbler Bird"]
|
335 |
+
Great Snipe
|
336 |
+
European Serin Bird
|
337 |
+
Calidris Apina Bird
|
338 |
+
Peacock
|
339 |
+
Tringa Glareola Bird
|
340 |
+
Cuckoo Bird
|
341 |
+
Barred Warbler Bird
|
342 |
+
Pacman Frog
|
343 |
+
Ardea Alba Egret
|
344 |
+
Tern
|
345 |
+
Motacilla Alba Bird
|
346 |
+
Motacilla Flava
|
347 |
+
Anas Crecca Bird
|
348 |
+
Blue Poison Dart Frog
|
349 |
+
Marsh Harrier Bird
|
350 |
+
Glass Frog
|
351 |
+
Great Reed Warbler Bird
|
352 |
+
Banded Rubber Frog
|
353 |
+
Tomato Frog
|
354 |
+
Hornbill
|
355 |
+
Woodlark Bird
|
356 |
+
Starling Bird
|
357 |
+
Common Buzzard
|
358 |
+
Gallinago Gallinago Bird
|
359 |
+
White And Gray Wagtail Bird
|
360 |
+
Strawberry Poison-Dart Frog
|
361 |
+
Corncrake
|
362 |
+
Bald Eagle
|
363 |
+
S Harrier Young"]
|
364 |
+
Charadrius Dubius Bird
|
365 |
+
Pelican
|
366 |
+
Flamingo Young
|
367 |
+
Socotran Cormorant
|
368 |
+
Sparrowhawk
|
369 |
+
S Harrier"]
|
370 |
+
Pygmy Owl
|
371 |
+
Philomachus Pugnax Ruff Bird
|
372 |
+
Wren
|
373 |
+
Common Wood Pigeon
|
374 |
+
Grass Warbler Bird
|
375 |
+
Whiskered Tern Bird
|
376 |
+
Icterine Warbler Bird
|
377 |
+
Crowned Eagle
|
378 |
+
Crane
|
379 |
+
Hummingbird
|
380 |
+
Ardeotis Kori Bird
|
381 |
+
Guttural Toad
|
382 |
+
Crested Grebe Bird
|
383 |
+
Reed Bunting Bird
|
384 |
+
White Cockatoo Bird
|
385 |
+
Sedge Warbler Bird
|
386 |
+
Goldcrest Bird
|
387 |
+
Montagus Harrier Young
|
388 |
+
European Turtle Dove
|
389 |
+
Asian Glossy Starling Bird
|
390 |
+
Spotted Wood Owl
|
391 |
+
Sagittarius Serpentarius Bird
|
392 |
+
Parrot
|
393 |
+
Anas Platyrhynchos Bird
|
394 |
+
Phalacrocorax Carbo Bird
|
395 |
+
White-Breasted Waterhen
|
396 |
+
African Bullfrog
|
397 |
+
Stork-Billed Kingfisher
|
398 |
+
Oriental Pied Hornbill
|
399 |
+
Flamingo
|
400 |
+
Banded Woodpecker
|
401 |
+
Foam Nest Frog
|
402 |
+
Vulture
|
403 |
+
Larus Ridibundus Bird
|
404 |
+
|
405 |
+
AP-10K & APT-36K:
|
406 |
+
monkey
|
407 |
+
elephant
|
408 |
+
leopard
|
409 |
+
horse
|
410 |
+
jaguar
|
411 |
+
panda
|
412 |
+
marmot
|
413 |
+
deer
|
414 |
+
noisy night monkey
|
415 |
+
orangutan
|
416 |
+
sheep
|
417 |
+
spider-monkey
|
418 |
+
bison
|
419 |
+
zebra
|
420 |
+
dog
|
421 |
+
weasel
|
422 |
+
bat
|
423 |
+
uakari
|
424 |
+
raccoon
|
425 |
+
tiger
|
426 |
+
rat
|
427 |
+
rhino
|
428 |
+
chimpanzee
|
429 |
+
antelope
|
430 |
+
argali sheep
|
431 |
+
gorilla
|
432 |
+
buffalo
|
433 |
+
bobcat
|
434 |
+
hippo
|
435 |
+
mouse
|
436 |
+
moose
|
437 |
+
howling-monkey
|
438 |
+
black-bear
|
439 |
+
wolf
|
440 |
+
squirrel
|
441 |
+
skunk
|
442 |
+
king cheetah
|
443 |
+
cheetah
|
444 |
+
spider monkey
|
445 |
+
hamster
|
446 |
+
arctic fox
|
447 |
+
polar bear
|
448 |
+
rabbit
|
449 |
+
panther
|
450 |
+
cow
|
451 |
+
brown bear
|
452 |
+
otter
|
453 |
+
beaver
|
454 |
+
pig
|
455 |
+
fox
|
456 |
+
alouatta
|
457 |
+
giraffe
|
458 |
+
polar-bear
|
459 |
+
raccon
|
460 |
+
snow leopard
|
461 |
+
lion
|
462 |
+
cat
|
463 |
+
mole
|
464 |
+
black bear
|
465 |
+
|
466 |
+
Desert Locust
|
467 |
+
|
468 |
+
Vinegar Fly
|
469 |
+
|
470 |
+
|
471 |
+
CUB-200-2011:
|
472 |
+
grebe_body
|
473 |
+
gull_body
|
474 |
+
kingfisher_body
|
475 |
+
sparrow_body
|
476 |
+
tern_body
|
477 |
+
warbler_body
|
478 |
+
woodpecker_body
|
479 |
+
wren_body
|
480 |
+
|
481 |
+
|
482 |
+
|
483 |
+
Carfusion:
|
484 |
+
bus
|
485 |
+
car
|
486 |
+
suv
|
487 |
+
|
488 |
+
|
489 |
+
Deepfashion2:
|
490 |
+
short sleeve top
|
491 |
+
long sleeve top
|
492 |
+
short sleeve outwear
|
493 |
+
long sleeve outwear
|
494 |
+
vest
|
495 |
+
sling
|
496 |
+
shorts
|
497 |
+
trousers
|
498 |
+
skirt
|
499 |
+
short sleeve dress
|
500 |
+
long sleeve dress
|
501 |
+
vest dress
|
502 |
+
sling dress
|
503 |
+
|
504 |
+
|
505 |
+
|
506 |
+
Keypoint-5:
|
507 |
+
bed
|
508 |
+
chair
|
509 |
+
sofa
|
510 |
+
swivelchair
|
511 |
+
table
|
512 |
+
|
513 |
+
AnimalWeb:
|
514 |
+
blackbuck
|
515 |
+
small asian mongoose
|
516 |
+
common dwarf mongoose
|
517 |
+
galapagos sea lion
|
518 |
+
margay
|
519 |
+
nilgai
|
520 |
+
Humboldt penguin
|
521 |
+
oryx
|
522 |
+
tammar wallaby
|
523 |
+
monkey
|
524 |
+
swamp wallaby
|
525 |
+
muntjac deer
|
526 |
+
blue-eyed black lemur
|
527 |
+
binturong
|
528 |
+
hamadryas baboon
|
529 |
+
Adelie penguin
|
530 |
+
Australian cattle dog
|
531 |
+
howler
|
532 |
+
striped hyena
|
533 |
+
vole
|
534 |
+
zebu
|
535 |
+
woodchuck
|
536 |
+
proboscis monkey
|
537 |
+
whiptail wallaby
|
538 |
+
anoa
|
539 |
+
hippopotamus
|
540 |
+
crested penguin
|
541 |
+
addax
|
542 |
+
red-bellied squirrel
|
543 |
+
suni
|
544 |
+
feral cat
|
545 |
+
galagos
|
546 |
+
banteng
|
547 |
+
Weddell seal
|
548 |
+
zebra
|
549 |
+
Ethiopian wolf
|
550 |
+
snow leopard
|
551 |
+
common chimpanzee
|
552 |
+
giant schnauzer
|
553 |
+
lemur
|
554 |
+
jaguarundi
|
555 |
+
Asian golden cat
|
556 |
+
gray wolf
|
557 |
+
anteater
|
558 |
+
golden jackal
|
559 |
+
banded palm civet
|
560 |
+
cougar
|
561 |
+
Barbary macaque
|
562 |
+
giant otter
|
563 |
+
agouti
|
564 |
+
emperor penguin
|
565 |
+
feral horse
|
566 |
+
yellow-footed rock wallaby
|
567 |
+
raccoon
|
568 |
+
topi
|
569 |
+
opossum
|
570 |
+
central chimpanzee
|
571 |
+
pygmy rabbit
|
572 |
+
fishing cat
|
573 |
+
reedbuck
|
574 |
+
Mediterranean monk seal
|
575 |
+
domestic cat
|
576 |
+
kangaroo
|
577 |
+
boar
|
578 |
+
rusty-spotted cat
|
579 |
+
spider monkey
|
580 |
+
echidna
|
581 |
+
Chinese goral
|
582 |
+
ringtail
|
583 |
+
kultarr
|
584 |
+
Californian sea lion
|
585 |
+
guanaco
|
586 |
+
muriqui
|
587 |
+
gerbil
|
588 |
+
wildebeest
|
589 |
+
bison
|
590 |
+
Australian terrier
|
591 |
+
hyrax
|
592 |
+
clouded leopard
|
593 |
+
goat
|
594 |
+
badger
|
595 |
+
beaver
|
596 |
+
Przewalski horse
|
597 |
+
camel
|
598 |
+
beating mongoose
|
599 |
+
field mouse
|
600 |
+
collared peccary
|
601 |
+
tree shrew
|
602 |
+
wombat
|
603 |
+
titi
|
604 |
+
steenbuck steenbok
|
605 |
+
Australian sea lion
|
606 |
+
buffalo
|
607 |
+
chamois
|
608 |
+
baikal seal
|
609 |
+
brush-tailed rock wallaby
|
610 |
+
bongo
|
611 |
+
Barbary sheep
|
612 |
+
great dane
|
613 |
+
cheetah
|
614 |
+
long-nosed mongoose
|
615 |
+
cape buffalo
|
616 |
+
waterbuck
|
617 |
+
rhesus monkey
|
618 |
+
jungle cat
|
619 |
+
black-and-white ruffed lemur
|
620 |
+
Japanese serow
|
621 |
+
potto
|
622 |
+
dall sheep
|
623 |
+
indri
|
624 |
+
large-spotted genet
|
625 |
+
Amur leopard
|
626 |
+
Owston's palm civet
|
627 |
+
dingo
|
628 |
+
gibbons
|
629 |
+
Doberman
|
630 |
+
giraffe
|
631 |
+
fox
|
632 |
+
quokka
|
633 |
+
Amur tiger
|
634 |
+
wild ass
|
635 |
+
walrus
|
636 |
+
common genet
|
637 |
+
bilby
|
638 |
+
hamster
|
639 |
+
yellow-eyed penguin
|
640 |
+
panda
|
641 |
+
agile wallaby
|
642 |
+
bengal slow loris
|
643 |
+
marmoset
|
644 |
+
brown hyena
|
645 |
+
gorilla
|
646 |
+
aardvark
|
647 |
+
swift fox
|
648 |
+
Magellanic penguin
|
649 |
+
bear
|
650 |
+
Anatolian shepherd dog
|
651 |
+
irish wolfhound
|
652 |
+
husky
|
653 |
+
kinkajou
|
654 |
+
brown rat
|
655 |
+
tarsiers
|
656 |
+
matschie's tree kangaroo
|
657 |
+
black-backed jackal
|
658 |
+
ocelot
|
659 |
+
grey seal
|
660 |
+
bullmastiff
|
661 |
+
gentoo penguin
|
662 |
+
gerenuk
|
663 |
+
bearded seal
|
664 |
+
hooded seal
|
665 |
+
monte
|
666 |
+
leopard cat
|
667 |
+
western chimpanzee
|
668 |
+
african penguin
|
669 |
+
dikdik
|
670 |
+
komondor
|
671 |
+
coypu
|
672 |
+
dalmatian
|
673 |
+
armadillo
|
674 |
+
marsh mongoose
|
675 |
+
rusty-spotted genet
|
676 |
+
lion
|
677 |
+
bharal
|
678 |
+
wolverine
|
679 |
+
visayan warty pig
|
680 |
+
lutung
|
681 |
+
bornean slow loris
|
682 |
+
caiman
|
683 |
+
aardwolf
|
684 |
+
german shepherd dog
|
685 |
+
sunda slow loris
|
686 |
+
mangabey
|
687 |
+
cape gray mongoose
|
688 |
+
crowned lemur
|
689 |
+
harp seal
|
690 |
+
gelada baboon
|
691 |
+
wallaroo
|
692 |
+
hare
|
693 |
+
goodfellow's tree kangaroo
|
694 |
+
elk
|
695 |
+
muskox
|
696 |
+
capybara
|
697 |
+
toque macaque
|
698 |
+
roe deer
|
699 |
+
eastern lesser bamboo lemur
|
700 |
+
leopard
|
701 |
+
wapiti
|
702 |
+
gray fox
|
703 |
+
alpaca
|
704 |
+
guinea pig
|
705 |
+
crabeater seal
|
706 |
+
black rhino
|
707 |
+
little blue penguin
|
708 |
+
bighorn sheep
|
709 |
+
caracal
|
710 |
+
tamarin
|
711 |
+
hawaiian monk seal
|
712 |
+
lumholtz's tree kangaroo
|
713 |
+
koala
|
714 |
+
gundi
|
715 |
+
onager
|
716 |
+
cacomistle
|
717 |
+
red-ruffed lemur
|
718 |
+
orangutan
|
719 |
+
bobcat
|
720 |
+
black-footed cat
|
721 |
+
alaskan hare
|
722 |
+
debrazza's monkey
|
723 |
+
swamp rabbit
|
724 |
+
white wolf
|
725 |
+
sharpe's grysbok
|
726 |
+
urial
|
727 |
+
feral goat
|
728 |
+
serval
|
729 |
+
degu
|
730 |
+
golden bamboo lemur
|
731 |
+
deer mouse
|
732 |
+
coatis
|
733 |
+
wildcat
|
734 |
+
roan antelope
|
735 |
+
dugong
|
736 |
+
fennec fox
|
737 |
+
southern elephant seal
|
738 |
+
saluki
|
739 |
+
golden langur
|
740 |
+
oribi
|
741 |
+
red-tail monkey
|
742 |
+
chital
|
743 |
+
dormouse
|
744 |
+
woolly monkey
|
745 |
+
leopard seal
|
746 |
+
possum
|
747 |
+
arctic wolf
|
748 |
+
japanese macaque
|
749 |
+
vervet monkey
|
750 |
+
bamboo lemur
|
751 |
+
aye-aye
|
752 |
+
night monkey
|
753 |
+
blue monkey
|
754 |
+
sand cat
|
755 |
+
bull
|
756 |
+
cape fox
|
757 |
+
klipspringer
|
758 |
+
border collie
|
759 |
+
mouflon
|
760 |
+
chipmunk
|
761 |
+
potoroo
|
762 |
+
bushbuck
|
763 |
+
northern elephant seal
|
764 |
+
patagonian mara
|
765 |
+
bandicoot
|
766 |
+
feral cattle
|
767 |
+
babirusa
|
768 |
+
harvest mouse
|
769 |
+
alaskan malamute
|
770 |
+
servaline genet
|
771 |
+
olive baboon
|
772 |
+
italian greyhound
|
773 |
+
white-headed lemur
|
774 |
+
chihuahua
|
775 |
+
red-necked wallaby
|
776 |
+
fallow deer
|
777 |
+
pygmy slow loris
|
778 |
+
australian shepherd
|
779 |
+
eastern chimpanzee
|
780 |
+
colobus
|
781 |
+
chinstrap penguin
|
782 |
+
deer
|
783 |
+
common warthog
|
784 |
+
dunnart
|
785 |
+
wisent
|
786 |
+
hedgehog
|
787 |
+
douc langur
|
788 |
+
tasmanian devil
|
789 |
+
colo
|
790 |
+
flying squirrel
|
791 |
+
canadian lynx
|
792 |
+
ferret
|
793 |
+
ribbon seal
|
794 |
+
platypus
|
795 |
+
cotton rat
|
796 |
+
oncilla
|
797 |
+
geoffroy's cat
|
798 |
+
horse
|
799 |
+
pardine genet
|
800 |
+
slender mongoose
|
801 |
+
liger
|
802 |
+
mareeba rock wallaby
|
803 |
+
olingos
|
804 |
+
bonobo
|
805 |
+
harbour seal
|
806 |
+
pademelon
|
807 |
+
domestic dog
|
808 |
+
chow chow
|
809 |
+
gharial
|
810 |
+
quoll
|
811 |
+
capuchin monkey
|
812 |
+
corsac fox
|
813 |
+
dassie
|
814 |
+
bolognese dog
|
815 |
+
ruddy mongoose
|
816 |
+
rhinoceros
|
817 |
+
red panda
|
818 |
+
king penguin
|
819 |
+
dachshund
|
820 |
+
common brown lemur
|
821 |
+
pekingese dog
|
822 |
+
western lesser bamboo lemur
|
823 |
+
banded mongoose
|
824 |
+
grey langur
|
825 |
+
patas monkey
|
826 |
+
francois langur
|
827 |
+
white-tailed deer
|
828 |
+
african wild dog
|
829 |
+
collared brown lemur
|
830 |
+
weasel
|
831 |
+
mexican wolf
|
832 |
+
hartebeest
|
833 |
+
uakari
|
834 |
+
viverrata ngalunga malayan civet
|
835 |
+
dhole
|
836 |
+
eurasian lynx
|
837 |
+
hog deer
|
838 |
+
bushbaby
|
839 |
+
grizzly bear
|
840 |
+
caribou
|
841 |
+
german pinscher
|
842 |
+
jaguar
|
843 |
+
donkey
|
844 |
+
duiker
|
845 |
+
spotted hyena
|
846 |
+
golden retriever
|
847 |
+
pantanal cat
|
848 |
+
spotted-necked otter
|
849 |
+
asian palm civet
|
850 |
+
alpine ibex
|
851 |
+
jackrabbit
|
852 |
+
greater bamboo lemur
|
853 |
+
kiang
|
854 |
+
common kusimanse
|
855 |
+
pallas cat
|
856 |
+
stripe-necked mongoose
|
857 |
+
parma wallaby
|
858 |
+
yak
|
859 |
+
balinese cat
|
860 |
+
spotted seal
|
861 |
+
french bulldog
|
862 |
+
zonkey
|
863 |
+
arctic fox
|
src/utils/dependencies/XPose/util/logger.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
import functools
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
from termcolor import colored
|
7 |
+
|
8 |
+
|
9 |
+
class _ColorfulFormatter(logging.Formatter):
|
10 |
+
def __init__(self, *args, **kwargs):
|
11 |
+
self._root_name = kwargs.pop("root_name") + "."
|
12 |
+
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
13 |
+
if len(self._abbrev_name):
|
14 |
+
self._abbrev_name = self._abbrev_name + "."
|
15 |
+
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
16 |
+
|
17 |
+
def formatMessage(self, record):
|
18 |
+
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
19 |
+
log = super(_ColorfulFormatter, self).formatMessage(record)
|
20 |
+
if record.levelno == logging.WARNING:
|
21 |
+
prefix = colored("WARNING", "red", attrs=["blink"])
|
22 |
+
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
23 |
+
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
24 |
+
else:
|
25 |
+
return log
|
26 |
+
return prefix + " " + log
|
27 |
+
|
28 |
+
|
29 |
+
# so that calling setup_logger multiple times won't add many handlers
|
30 |
+
@functools.lru_cache()
|
31 |
+
def setup_logger(
|
32 |
+
output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
Initialize the detectron2 logger and set its verbosity level to "INFO".
|
36 |
+
|
37 |
+
Args:
|
38 |
+
output (str): a file name or a directory to save log. If None, will not save log file.
|
39 |
+
If ends with ".txt" or ".log", assumed to be a file name.
|
40 |
+
Otherwise, logs will be saved to `output/log.txt`.
|
41 |
+
name (str): the root module name of this logger
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
logging.Logger: a logger
|
45 |
+
"""
|
46 |
+
logger = logging.getLogger(name)
|
47 |
+
logger.setLevel(logging.DEBUG)
|
48 |
+
logger.propagate = False
|
49 |
+
|
50 |
+
if abbrev_name is None:
|
51 |
+
abbrev_name = name
|
52 |
+
|
53 |
+
plain_formatter = logging.Formatter(
|
54 |
+
'[%(asctime)s.%(msecs)03d]: %(message)s',
|
55 |
+
datefmt='%m/%d %H:%M:%S'
|
56 |
+
)
|
57 |
+
# stdout logging: master only
|
58 |
+
if distributed_rank == 0:
|
59 |
+
ch = logging.StreamHandler(stream=sys.stdout)
|
60 |
+
ch.setLevel(logging.DEBUG)
|
61 |
+
if color:
|
62 |
+
formatter = _ColorfulFormatter(
|
63 |
+
colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
|
64 |
+
datefmt="%m/%d %H:%M:%S",
|
65 |
+
root_name=name,
|
66 |
+
abbrev_name=str(abbrev_name),
|
67 |
+
)
|
68 |
+
else:
|
69 |
+
formatter = plain_formatter
|
70 |
+
ch.setFormatter(formatter)
|
71 |
+
logger.addHandler(ch)
|
72 |
+
|
73 |
+
# file logging: all workers
|
74 |
+
if output is not None:
|
75 |
+
if output.endswith(".txt") or output.endswith(".log"):
|
76 |
+
filename = output
|
77 |
+
else:
|
78 |
+
filename = os.path.join(output, "log.txt")
|
79 |
+
if distributed_rank > 0:
|
80 |
+
filename = filename + f".rank{distributed_rank}"
|
81 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
82 |
+
|
83 |
+
fh = logging.StreamHandler(_cached_log_stream(filename))
|
84 |
+
fh.setLevel(logging.DEBUG)
|
85 |
+
fh.setFormatter(plain_formatter)
|
86 |
+
logger.addHandler(fh)
|
87 |
+
|
88 |
+
return logger
|
89 |
+
|
90 |
+
|
91 |
+
# cache the opened file object, so that different calls to `setup_logger`
|
92 |
+
# with the same file name can safely write to the same file.
|
93 |
+
@functools.lru_cache(maxsize=None)
|
94 |
+
def _cached_log_stream(filename):
|
95 |
+
return open(filename, "a")
|
src/utils/dependencies/XPose/util/metrics.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Various utilities related to track and report metrics
|
4 |
+
"""
|
5 |
+
import datetime
|
6 |
+
import time
|
7 |
+
from collections import defaultdict, deque
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
from util.misc import is_dist_avail_and_initialized
|
13 |
+
|
14 |
+
|
15 |
+
class SmoothedValue:
|
16 |
+
"""Track a series of values and provide access to smoothed values over a
|
17 |
+
window or the global series average.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, window_size=20, fmt=None):
|
21 |
+
if fmt is None:
|
22 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
23 |
+
self.deque = deque(maxlen=window_size)
|
24 |
+
self.total = 0.0
|
25 |
+
self.count = 0
|
26 |
+
self.fmt = fmt
|
27 |
+
|
28 |
+
def update(self, value, num=1):
|
29 |
+
self.deque.append(value)
|
30 |
+
self.count += num
|
31 |
+
self.total += value * num
|
32 |
+
|
33 |
+
def synchronize_between_processes(self):
|
34 |
+
"""
|
35 |
+
Distributed synchronization of the metric
|
36 |
+
Warning: does not synchronize the deque!
|
37 |
+
"""
|
38 |
+
if not is_dist_avail_and_initialized():
|
39 |
+
return
|
40 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
41 |
+
dist.barrier()
|
42 |
+
dist.all_reduce(t)
|
43 |
+
t = t.tolist()
|
44 |
+
self.count = int(t[0])
|
45 |
+
self.total = t[1]
|
46 |
+
|
47 |
+
@property
|
48 |
+
def median(self):
|
49 |
+
d = torch.tensor(list(self.deque))
|
50 |
+
return d.median().item()
|
51 |
+
|
52 |
+
@property
|
53 |
+
def avg(self):
|
54 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
55 |
+
return d.mean().item()
|
56 |
+
|
57 |
+
@property
|
58 |
+
def global_avg(self):
|
59 |
+
return self.total / self.count
|
60 |
+
|
61 |
+
@property
|
62 |
+
def max(self):
|
63 |
+
return max(self.deque)
|
64 |
+
|
65 |
+
@property
|
66 |
+
def value(self):
|
67 |
+
return self.deque[-1]
|
68 |
+
|
69 |
+
def __str__(self):
|
70 |
+
return self.fmt.format(
|
71 |
+
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
class MetricLogger(object):
|
76 |
+
def __init__(self, delimiter="\t"):
|
77 |
+
self.meters = defaultdict(SmoothedValue)
|
78 |
+
self.delimiter = delimiter
|
79 |
+
|
80 |
+
def update(self, **kwargs):
|
81 |
+
for k, v in kwargs.items():
|
82 |
+
if isinstance(v, torch.Tensor):
|
83 |
+
v = v.item()
|
84 |
+
assert isinstance(v, (float, int))
|
85 |
+
self.meters[k].update(v)
|
86 |
+
|
87 |
+
def __getattr__(self, attr):
|
88 |
+
if attr in self.meters:
|
89 |
+
return self.meters[attr]
|
90 |
+
if attr in self.__dict__:
|
91 |
+
return self.__dict__[attr]
|
92 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
93 |
+
|
94 |
+
def __str__(self):
|
95 |
+
loss_str = []
|
96 |
+
for name, meter in self.meters.items():
|
97 |
+
loss_str.append("{}: {}".format(name, str(meter)))
|
98 |
+
return self.delimiter.join(loss_str)
|
99 |
+
|
100 |
+
def synchronize_between_processes(self):
|
101 |
+
for meter in self.meters.values():
|
102 |
+
meter.synchronize_between_processes()
|
103 |
+
|
104 |
+
def add_meter(self, name, meter):
|
105 |
+
self.meters[name] = meter
|
106 |
+
|
107 |
+
def log_every(self, iterable, print_freq, header=None):
|
108 |
+
i = 0
|
109 |
+
if not header:
|
110 |
+
header = ""
|
111 |
+
start_time = time.time()
|
112 |
+
end = time.time()
|
113 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
114 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
115 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
116 |
+
if torch.cuda.is_available():
|
117 |
+
log_msg = self.delimiter.join(
|
118 |
+
[
|
119 |
+
header,
|
120 |
+
"[{0" + space_fmt + "}/{1}]",
|
121 |
+
"eta: {eta}",
|
122 |
+
"{meters}",
|
123 |
+
"time: {time}",
|
124 |
+
"data: {data}",
|
125 |
+
"max mem: {memory:.0f}",
|
126 |
+
]
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
log_msg = self.delimiter.join(
|
130 |
+
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
|
131 |
+
)
|
132 |
+
MB = 1024.0 * 1024.0
|
133 |
+
for obj in iterable:
|
134 |
+
data_time.update(time.time() - end)
|
135 |
+
yield obj
|
136 |
+
iter_time.update(time.time() - end)
|
137 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
138 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
139 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
140 |
+
if torch.cuda.is_available():
|
141 |
+
print(
|
142 |
+
log_msg.format(
|
143 |
+
i,
|
144 |
+
len(iterable),
|
145 |
+
eta=eta_string,
|
146 |
+
meters=str(self),
|
147 |
+
time=str(iter_time),
|
148 |
+
data=str(data_time),
|
149 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
150 |
+
)
|
151 |
+
)
|
152 |
+
else:
|
153 |
+
print(
|
154 |
+
log_msg.format(
|
155 |
+
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
156 |
+
)
|
157 |
+
)
|
158 |
+
i += 1
|
159 |
+
end = time.time()
|
160 |
+
total_time = time.time() - start_time
|
161 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
162 |
+
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
|
163 |
+
|
164 |
+
|
165 |
+
@torch.no_grad()
|
166 |
+
def accuracy(output, target, topk=(1,)):
|
167 |
+
"""Computes the precision@k for the specified values of k"""
|
168 |
+
if target.numel() == 0:
|
169 |
+
return [torch.zeros([], device=output.device)]
|
170 |
+
maxk = max(topk)
|
171 |
+
batch_size = target.size(0)
|
172 |
+
|
173 |
+
_, pred = output.topk(maxk, 1, True, True)
|
174 |
+
pred = pred.t()
|
175 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
176 |
+
|
177 |
+
res = []
|
178 |
+
for k in topk:
|
179 |
+
correct_k = correct[:k].view(-1).float().sum(0)
|
180 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
181 |
+
return res
|
src/utils/dependencies/XPose/util/optim.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
|
2 |
+
"""Collections of utilities related to optimization."""
|
3 |
+
from bisect import bisect_right
|
4 |
+
import os
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def update_ema(model, model_ema, decay):
|
10 |
+
"""Apply exponential moving average update.
|
11 |
+
|
12 |
+
The weights are updated in-place as follow:
|
13 |
+
w_ema = w_ema * decay + (1 - decay) * w
|
14 |
+
Args:
|
15 |
+
model: active model that is being optimized
|
16 |
+
model_ema: running average model
|
17 |
+
decay: exponential decay parameter
|
18 |
+
"""
|
19 |
+
with torch.no_grad():
|
20 |
+
if hasattr(model, "module"):
|
21 |
+
# unwrapping DDP
|
22 |
+
model = model.module
|
23 |
+
msd = model.state_dict()
|
24 |
+
for k, ema_v in model_ema.state_dict().items():
|
25 |
+
model_v = msd[k].detach()
|
26 |
+
ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
|
27 |
+
|
28 |
+
|
29 |
+
def adjust_learning_rate(
|
30 |
+
optimizer,
|
31 |
+
epoch: int,
|
32 |
+
curr_step: int,
|
33 |
+
args,
|
34 |
+
):
|
35 |
+
"""Adjust the lr according to the schedule.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
Optimizer: torch optimizer to update.
|
39 |
+
epoch(int): number of the current epoch.
|
40 |
+
curr_step(int): number of optimization step taken so far.
|
41 |
+
num_training_step(int): total number of optimization steps.
|
42 |
+
args: additional training dependent args:
|
43 |
+
- lr_drop(int): number of epochs before dropping the learning rate.
|
44 |
+
- fraction_warmup_steps(float) fraction of steps over which the lr will be increased to its peak.
|
45 |
+
- lr(float): base learning rate
|
46 |
+
- lr_backbone(float): learning rate of the backbone
|
47 |
+
- text_encoder_backbone(float): learning rate of the text encoder
|
48 |
+
- schedule(str): the requested learning rate schedule:
|
49 |
+
"step": all lrs divided by 10 after lr_drop epochs
|
50 |
+
"multistep": divided by 2 after lr_drop epochs, then by 2 after every 50 epochs
|
51 |
+
"linear_with_warmup": same as "step" for backbone + transformer, but for the text encoder, linearly
|
52 |
+
increase for a fraction of the training, then linearly decrease back to 0.
|
53 |
+
"all_linear_with_warmup": same as "linear_with_warmup" for all learning rates involved.
|
54 |
+
|
55 |
+
"""
|
56 |
+
try:
|
57 |
+
num_warmup_steps = args.num_warmup_steps
|
58 |
+
except:
|
59 |
+
return
|
60 |
+
|
61 |
+
if epoch > 0:
|
62 |
+
return
|
63 |
+
|
64 |
+
if curr_step > num_warmup_steps:
|
65 |
+
return
|
66 |
+
|
67 |
+
text_encoder_gamma = float(curr_step) / float(max(1, num_warmup_steps))
|
68 |
+
optimizer.param_groups[-1]["lr"] = args.lr_backbone * text_encoder_gamma
|
69 |
+
|
70 |
+
|
src/utils/dependencies/XPose/util/plot_utils.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Plotting utilities to visualize training logs.
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
import seaborn as sns
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from pathlib import Path, PurePath
|
11 |
+
|
12 |
+
|
13 |
+
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
14 |
+
'''
|
15 |
+
Function to plot specific fields from training log(s). Plots both training and test results.
|
16 |
+
|
17 |
+
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
18 |
+
- fields = which results to plot from each log file - plots both training and test for each field.
|
19 |
+
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
20 |
+
- log_name = optional, name of log file if different than default 'log.txt'.
|
21 |
+
|
22 |
+
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
23 |
+
- solid lines are training results, dashed lines are test results.
|
24 |
+
|
25 |
+
'''
|
26 |
+
func_name = "plot_utils.py::plot_logs"
|
27 |
+
|
28 |
+
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
29 |
+
# convert single Path to list to avoid 'not iterable' error
|
30 |
+
|
31 |
+
if not isinstance(logs, list):
|
32 |
+
if isinstance(logs, PurePath):
|
33 |
+
logs = [logs]
|
34 |
+
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
35 |
+
else:
|
36 |
+
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
37 |
+
Expect list[Path] or single Path obj, received {type(logs)}")
|
38 |
+
|
39 |
+
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
40 |
+
for i, dir in enumerate(logs):
|
41 |
+
if not isinstance(dir, PurePath):
|
42 |
+
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
43 |
+
if not dir.exists():
|
44 |
+
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
45 |
+
# verify log_name exists
|
46 |
+
fn = Path(dir / log_name)
|
47 |
+
if not fn.exists():
|
48 |
+
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
49 |
+
print(f"--> full path of missing log file: {fn}")
|
50 |
+
return
|
51 |
+
|
52 |
+
# load log file(s) and plot
|
53 |
+
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
54 |
+
|
55 |
+
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
56 |
+
|
57 |
+
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
58 |
+
for j, field in enumerate(fields):
|
59 |
+
if field == 'mAP':
|
60 |
+
coco_eval = pd.DataFrame(
|
61 |
+
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
62 |
+
).ewm(com=ewm_col).mean()
|
63 |
+
axs[j].plot(coco_eval, c=color)
|
64 |
+
else:
|
65 |
+
df.interpolate().ewm(com=ewm_col).mean().plot(
|
66 |
+
y=[f'train_{field}', f'test_{field}'],
|
67 |
+
ax=axs[j],
|
68 |
+
color=[color] * 2,
|
69 |
+
style=['-', '--']
|
70 |
+
)
|
71 |
+
for ax, field in zip(axs, fields):
|
72 |
+
if field == 'mAP':
|
73 |
+
ax.legend([Path(p).name for p in logs])
|
74 |
+
ax.set_title(field)
|
75 |
+
else:
|
76 |
+
ax.legend([f'train', f'test'])
|
77 |
+
ax.set_title(field)
|
78 |
+
|
79 |
+
return fig, axs
|
80 |
+
|
81 |
+
def plot_precision_recall(files, naming_scheme='iter'):
|
82 |
+
if naming_scheme == 'exp_id':
|
83 |
+
# name becomes exp_id
|
84 |
+
names = [f.parts[-3] for f in files]
|
85 |
+
elif naming_scheme == 'iter':
|
86 |
+
names = [f.stem for f in files]
|
87 |
+
else:
|
88 |
+
raise ValueError(f'not supported {naming_scheme}')
|
89 |
+
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
90 |
+
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
91 |
+
data = torch.load(f)
|
92 |
+
# precision is n_iou, n_points, n_cat, n_area, max_det
|
93 |
+
precision = data['precision']
|
94 |
+
recall = data['params'].recThrs
|
95 |
+
scores = data['scores']
|
96 |
+
# take precision for all classes, all areas and 100 detections
|
97 |
+
precision = precision[0, :, :, 0, -1].mean(1)
|
98 |
+
scores = scores[0, :, :, 0, -1].mean(1)
|
99 |
+
prec = precision.mean()
|
100 |
+
rec = data['recall'][0, :, 0, -1].mean()
|
101 |
+
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
102 |
+
f'score={scores.mean():0.3f}, ' +
|
103 |
+
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
104 |
+
)
|
105 |
+
axs[0].plot(recall, precision, c=color)
|
106 |
+
axs[1].plot(recall, scores, c=color)
|
107 |
+
|
108 |
+
axs[0].set_title('Precision / Recall')
|
109 |
+
axs[0].legend(names)
|
110 |
+
axs[1].set_title('Scores / Recall')
|
111 |
+
axs[1].legend(names)
|
112 |
+
return fig, axs
|
src/utils/dependencies/XPose/util/slio.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ==========================================================
|
2 |
+
# Modified from mmcv
|
3 |
+
# ==========================================================
|
4 |
+
|
5 |
+
import json, pickle, yaml
|
6 |
+
try:
|
7 |
+
from yaml import CLoader as Loader, CDumper as Dumper
|
8 |
+
except ImportError:
|
9 |
+
from yaml import Loader, Dumper
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
from abc import ABCMeta, abstractmethod
|
13 |
+
|
14 |
+
# ===========================
|
15 |
+
# Rigister handler
|
16 |
+
# ===========================
|
17 |
+
|
18 |
+
class BaseFileHandler(metaclass=ABCMeta):
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def load_from_fileobj(self, file, **kwargs):
|
22 |
+
pass
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def dump_to_fileobj(self, obj, file, **kwargs):
|
26 |
+
pass
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def dump_to_str(self, obj, **kwargs):
|
30 |
+
pass
|
31 |
+
|
32 |
+
def load_from_path(self, filepath, mode='r', **kwargs):
|
33 |
+
with open(filepath, mode) as f:
|
34 |
+
return self.load_from_fileobj(f, **kwargs)
|
35 |
+
|
36 |
+
def dump_to_path(self, obj, filepath, mode='w', **kwargs):
|
37 |
+
with open(filepath, mode) as f:
|
38 |
+
self.dump_to_fileobj(obj, f, **kwargs)
|
39 |
+
|
40 |
+
class JsonHandler(BaseFileHandler):
|
41 |
+
|
42 |
+
def load_from_fileobj(self, file):
|
43 |
+
return json.load(file)
|
44 |
+
|
45 |
+
def dump_to_fileobj(self, obj, file, **kwargs):
|
46 |
+
json.dump(obj, file, **kwargs)
|
47 |
+
|
48 |
+
def dump_to_str(self, obj, **kwargs):
|
49 |
+
return json.dumps(obj, **kwargs)
|
50 |
+
|
51 |
+
class PickleHandler(BaseFileHandler):
|
52 |
+
|
53 |
+
def load_from_fileobj(self, file, **kwargs):
|
54 |
+
return pickle.load(file, **kwargs)
|
55 |
+
|
56 |
+
def load_from_path(self, filepath, **kwargs):
|
57 |
+
return super(PickleHandler, self).load_from_path(
|
58 |
+
filepath, mode='rb', **kwargs)
|
59 |
+
|
60 |
+
def dump_to_str(self, obj, **kwargs):
|
61 |
+
kwargs.setdefault('protocol', 2)
|
62 |
+
return pickle.dumps(obj, **kwargs)
|
63 |
+
|
64 |
+
def dump_to_fileobj(self, obj, file, **kwargs):
|
65 |
+
kwargs.setdefault('protocol', 2)
|
66 |
+
pickle.dump(obj, file, **kwargs)
|
67 |
+
|
68 |
+
def dump_to_path(self, obj, filepath, **kwargs):
|
69 |
+
super(PickleHandler, self).dump_to_path(
|
70 |
+
obj, filepath, mode='wb', **kwargs)
|
71 |
+
|
72 |
+
class YamlHandler(BaseFileHandler):
|
73 |
+
|
74 |
+
def load_from_fileobj(self, file, **kwargs):
|
75 |
+
kwargs.setdefault('Loader', Loader)
|
76 |
+
return yaml.load(file, **kwargs)
|
77 |
+
|
78 |
+
def dump_to_fileobj(self, obj, file, **kwargs):
|
79 |
+
kwargs.setdefault('Dumper', Dumper)
|
80 |
+
yaml.dump(obj, file, **kwargs)
|
81 |
+
|
82 |
+
def dump_to_str(self, obj, **kwargs):
|
83 |
+
kwargs.setdefault('Dumper', Dumper)
|
84 |
+
return yaml.dump(obj, **kwargs)
|
85 |
+
|
86 |
+
file_handlers = {
|
87 |
+
'json': JsonHandler(),
|
88 |
+
'yaml': YamlHandler(),
|
89 |
+
'yml': YamlHandler(),
|
90 |
+
'pickle': PickleHandler(),
|
91 |
+
'pkl': PickleHandler()
|
92 |
+
}
|
93 |
+
|
94 |
+
# ===========================
|
95 |
+
# load and dump
|
96 |
+
# ===========================
|
97 |
+
|
98 |
+
def is_str(x):
|
99 |
+
"""Whether the input is an string instance.
|
100 |
+
|
101 |
+
Note: This method is deprecated since python 2 is no longer supported.
|
102 |
+
"""
|
103 |
+
return isinstance(x, str)
|
104 |
+
|
105 |
+
def slload(file, file_format=None, **kwargs):
|
106 |
+
"""Load data from json/yaml/pickle files.
|
107 |
+
|
108 |
+
This method provides a unified api for loading data from serialized files.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
file (str or :obj:`Path` or file-like object): Filename or a file-like
|
112 |
+
object.
|
113 |
+
file_format (str, optional): If not specified, the file format will be
|
114 |
+
inferred from the file extension, otherwise use the specified one.
|
115 |
+
Currently supported formats include "json", "yaml/yml" and
|
116 |
+
"pickle/pkl".
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
The content from the file.
|
120 |
+
"""
|
121 |
+
if isinstance(file, Path):
|
122 |
+
file = str(file)
|
123 |
+
if file_format is None and is_str(file):
|
124 |
+
file_format = file.split('.')[-1]
|
125 |
+
if file_format not in file_handlers:
|
126 |
+
raise TypeError(f'Unsupported format: {file_format}')
|
127 |
+
|
128 |
+
handler = file_handlers[file_format]
|
129 |
+
if is_str(file):
|
130 |
+
obj = handler.load_from_path(file, **kwargs)
|
131 |
+
elif hasattr(file, 'read'):
|
132 |
+
obj = handler.load_from_fileobj(file, **kwargs)
|
133 |
+
else:
|
134 |
+
raise TypeError('"file" must be a filepath str or a file-object')
|
135 |
+
return obj
|
136 |
+
|
137 |
+
|
138 |
+
def sldump(obj, file=None, file_format=None, **kwargs):
|
139 |
+
"""Dump data to json/yaml/pickle strings or files.
|
140 |
+
|
141 |
+
This method provides a unified api for dumping data as strings or to files,
|
142 |
+
and also supports custom arguments for each file format.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
obj (any): The python object to be dumped.
|
146 |
+
file (str or :obj:`Path` or file-like object, optional): If not
|
147 |
+
specified, then the object is dump to a str, otherwise to a file
|
148 |
+
specified by the filename or file-like object.
|
149 |
+
file_format (str, optional): Same as :func:`load`.
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
bool: True for success, False otherwise.
|
153 |
+
"""
|
154 |
+
if isinstance(file, Path):
|
155 |
+
file = str(file)
|
156 |
+
if file_format is None:
|
157 |
+
if is_str(file):
|
158 |
+
file_format = file.split('.')[-1]
|
159 |
+
elif file is None:
|
160 |
+
raise ValueError(
|
161 |
+
'file_format must be specified since file is None')
|
162 |
+
if file_format not in file_handlers:
|
163 |
+
raise TypeError(f'Unsupported format: {file_format}')
|
164 |
+
|
165 |
+
handler = file_handlers[file_format]
|
166 |
+
if file is None:
|
167 |
+
return handler.dump_to_str(obj, **kwargs)
|
168 |
+
elif is_str(file):
|
169 |
+
handler.dump_to_path(obj, file, **kwargs)
|
170 |
+
elif hasattr(file, 'write'):
|
171 |
+
handler.dump_to_fileobj(obj, file, **kwargs)
|
172 |
+
else:
|
173 |
+
raise TypeError('"file" must be a filename str or a file-object')
|
src/utils/dependencies/XPose/util/time_counter.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
|
4 |
+
class TimeCounter:
|
5 |
+
def __init__(self) -> None:
|
6 |
+
pass
|
7 |
+
|
8 |
+
def clear(self):
|
9 |
+
self.timedict = {}
|
10 |
+
self.basetime = time.perf_counter()
|
11 |
+
|
12 |
+
def timeit(self, name):
|
13 |
+
nowtime = time.perf_counter() - self.basetime
|
14 |
+
self.timedict[name] = nowtime
|
15 |
+
self.basetime = time.perf_counter()
|
16 |
+
|
17 |
+
|
18 |
+
class TimeHolder:
|
19 |
+
def __init__(self) -> None:
|
20 |
+
self.timedict = {}
|
21 |
+
|
22 |
+
def update(self, _timedict:dict):
|
23 |
+
for k,v in _timedict.items():
|
24 |
+
if k not in self.timedict:
|
25 |
+
self.timedict[k] = AverageMeter(name=k, val_only=True)
|
26 |
+
self.timedict[k].update(val=v)
|
27 |
+
|
28 |
+
def final_res(self):
|
29 |
+
return {k:v.avg for k,v in self.timedict.items()}
|
30 |
+
|
31 |
+
def __str__(self):
|
32 |
+
return json.dumps(self.final_res(), indent=2)
|
33 |
+
|
34 |
+
|
35 |
+
class AverageMeter(object):
|
36 |
+
"""Computes and stores the average and current value"""
|
37 |
+
def __init__(self, name, fmt=':f', val_only=False):
|
38 |
+
self.name = name
|
39 |
+
self.fmt = fmt
|
40 |
+
self.val_only = val_only
|
41 |
+
self.reset()
|
42 |
+
|
43 |
+
def reset(self):
|
44 |
+
self.val = 0
|
45 |
+
self.avg = 0
|
46 |
+
self.sum = 0
|
47 |
+
self.count = 0
|
48 |
+
|
49 |
+
def update(self, val, n=1):
|
50 |
+
self.val = val
|
51 |
+
self.sum += val * n
|
52 |
+
self.count += n
|
53 |
+
self.avg = self.sum / self.count
|
54 |
+
|
55 |
+
def __str__(self):
|
56 |
+
if self.val_only:
|
57 |
+
fmtstr = '{name} {val' + self.fmt + '}'
|
58 |
+
else:
|
59 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
60 |
+
return fmtstr.format(**self.__dict__)
|
src/utils/dependencies/XPose/util/utils.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from copy import deepcopy
|
3 |
+
from typing import Any, Dict, Iterable, List
|
4 |
+
import json
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
def slprint(x, name="x"):
|
11 |
+
if isinstance(x, (torch.Tensor, np.ndarray)):
|
12 |
+
print(f"{name}.shape:", x.shape)
|
13 |
+
elif isinstance(x, (tuple, list)):
|
14 |
+
print("type x:", type(x))
|
15 |
+
for i in range(min(10, len(x))):
|
16 |
+
slprint(x[i], f"{name}[{i}]")
|
17 |
+
elif isinstance(x, dict):
|
18 |
+
for k, v in x.items():
|
19 |
+
slprint(v, f"{name}[{k}]")
|
20 |
+
else:
|
21 |
+
print(f"{name}.type:", type(x))
|
22 |
+
|
23 |
+
def clean_state_dict(state_dict):
|
24 |
+
new_state_dict = OrderedDict()
|
25 |
+
for k, v in state_dict.items():
|
26 |
+
if k[:7] == 'module.':
|
27 |
+
k = k[7:] # remove `module.`
|
28 |
+
new_state_dict[k] = v
|
29 |
+
return new_state_dict
|
30 |
+
|
31 |
+
def renorm(img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) \
|
32 |
+
-> torch.FloatTensor:
|
33 |
+
# img: tensor(3,H,W) or tensor(B,3,H,W)
|
34 |
+
# return: same as img
|
35 |
+
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
|
36 |
+
if img.dim() == 3:
|
37 |
+
assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (img.size(0), str(img.size()))
|
38 |
+
img_perm = img.permute(1,2,0)
|
39 |
+
mean = torch.Tensor(mean)
|
40 |
+
std = torch.Tensor(std)
|
41 |
+
img_res = img_perm * std + mean
|
42 |
+
return img_res.permute(2,0,1)
|
43 |
+
else: # img.dim() == 4
|
44 |
+
assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (img.size(1), str(img.size()))
|
45 |
+
img_perm = img.permute(0,2,3,1)
|
46 |
+
mean = torch.Tensor(mean)
|
47 |
+
std = torch.Tensor(std)
|
48 |
+
img_res = img_perm * std + mean
|
49 |
+
return img_res.permute(0,3,1,2)
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
class CocoClassMapper():
|
54 |
+
def __init__(self) -> None:
|
55 |
+
self.category_map_str = {"1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "10": 10, "11": 11, "13": 12, "14": 13, "15": 14, "16": 15, "17": 16, "18": 17, "19": 18, "20": 19, "21": 20, "22": 21, "23": 22, "24": 23, "25": 24, "27": 25, "28": 26, "31": 27, "32": 28, "33": 29, "34": 30, "35": 31, "36": 32, "37": 33, "38": 34, "39": 35, "40": 36, "41": 37, "42": 38, "43": 39, "44": 40, "46": 41, "47": 42, "48": 43, "49": 44, "50": 45, "51": 46, "52": 47, "53": 48, "54": 49, "55": 50, "56": 51, "57": 52, "58": 53, "59": 54, "60": 55, "61": 56, "62": 57, "63": 58, "64": 59, "65": 60, "67": 61, "70": 62, "72": 63, "73": 64, "74": 65, "75": 66, "76": 67, "77": 68, "78": 69, "79": 70, "80": 71, "81": 72, "82": 73, "84": 74, "85": 75, "86": 76, "87": 77, "88": 78, "89": 79, "90": 80}
|
56 |
+
self.origin2compact_mapper = {int(k):v-1 for k,v in self.category_map_str.items()}
|
57 |
+
self.compact2origin_mapper = {int(v-1):int(k) for k,v in self.category_map_str.items()}
|
58 |
+
|
59 |
+
def origin2compact(self, idx):
|
60 |
+
return self.origin2compact_mapper[int(idx)]
|
61 |
+
|
62 |
+
def compact2origin(self, idx):
|
63 |
+
return self.compact2origin_mapper[int(idx)]
|
64 |
+
|
65 |
+
def to_device(item, device):
|
66 |
+
if isinstance(item, torch.Tensor):
|
67 |
+
return item.to(device)
|
68 |
+
elif isinstance(item, list):
|
69 |
+
return [to_device(i, device) for i in item]
|
70 |
+
elif isinstance(item, dict):
|
71 |
+
return {k: to_device(v, device) for k,v in item.items()}
|
72 |
+
else:
|
73 |
+
raise NotImplementedError("Call Shilong if you use other containers! type: {}".format(type(item)))
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
#
|
78 |
+
def get_gaussian_mean(x, axis, other_axis, softmax=True):
|
79 |
+
"""
|
80 |
+
|
81 |
+
Args:
|
82 |
+
x (float): Input images(BxCxHxW)
|
83 |
+
axis (int): The index for weighted mean
|
84 |
+
other_axis (int): The other index
|
85 |
+
|
86 |
+
Returns: weighted index for axis, BxC
|
87 |
+
|
88 |
+
"""
|
89 |
+
mat2line = torch.sum(x, axis=other_axis)
|
90 |
+
# mat2line = mat2line / mat2line.mean() * 10
|
91 |
+
if softmax:
|
92 |
+
u = torch.softmax(mat2line, axis=2)
|
93 |
+
else:
|
94 |
+
u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6)
|
95 |
+
size = x.shape[axis]
|
96 |
+
ind = torch.linspace(0, 1, size).to(x.device)
|
97 |
+
batch = x.shape[0]
|
98 |
+
channel = x.shape[1]
|
99 |
+
index = ind.repeat([batch, channel, 1])
|
100 |
+
mean_position = torch.sum(index * u, dim=2)
|
101 |
+
return mean_position
|
102 |
+
|
103 |
+
def get_expected_points_from_map(hm, softmax=True):
|
104 |
+
"""get_gaussian_map_from_points
|
105 |
+
B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
|
106 |
+
softargmax function
|
107 |
+
|
108 |
+
Args:
|
109 |
+
hm (float): Input images(BxCxHxW)
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
weighted index for axis, BxCx2. float between 0 and 1.
|
113 |
+
|
114 |
+
"""
|
115 |
+
# hm = 10*hm
|
116 |
+
B,C,H,W = hm.shape
|
117 |
+
y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
|
118 |
+
x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
|
119 |
+
# return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
|
120 |
+
return torch.stack([x_mean, y_mean], dim=2)
|
121 |
+
|
122 |
+
# Positional encoding (section 5.1)
|
123 |
+
# borrow from nerf
|
124 |
+
class Embedder:
|
125 |
+
def __init__(self, **kwargs):
|
126 |
+
self.kwargs = kwargs
|
127 |
+
self.create_embedding_fn()
|
128 |
+
|
129 |
+
def create_embedding_fn(self):
|
130 |
+
embed_fns = []
|
131 |
+
d = self.kwargs['input_dims']
|
132 |
+
out_dim = 0
|
133 |
+
if self.kwargs['include_input']:
|
134 |
+
embed_fns.append(lambda x : x)
|
135 |
+
out_dim += d
|
136 |
+
|
137 |
+
max_freq = self.kwargs['max_freq_log2']
|
138 |
+
N_freqs = self.kwargs['num_freqs']
|
139 |
+
|
140 |
+
if self.kwargs['log_sampling']:
|
141 |
+
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
|
142 |
+
else:
|
143 |
+
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
|
144 |
+
|
145 |
+
for freq in freq_bands:
|
146 |
+
for p_fn in self.kwargs['periodic_fns']:
|
147 |
+
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
|
148 |
+
out_dim += d
|
149 |
+
|
150 |
+
self.embed_fns = embed_fns
|
151 |
+
self.out_dim = out_dim
|
152 |
+
|
153 |
+
def embed(self, inputs):
|
154 |
+
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
155 |
+
|
156 |
+
|
157 |
+
def get_embedder(multires, i=0):
|
158 |
+
import torch.nn as nn
|
159 |
+
if i == -1:
|
160 |
+
return nn.Identity(), 3
|
161 |
+
|
162 |
+
embed_kwargs = {
|
163 |
+
'include_input' : True,
|
164 |
+
'input_dims' : 3,
|
165 |
+
'max_freq_log2' : multires-1,
|
166 |
+
'num_freqs' : multires,
|
167 |
+
'log_sampling' : True,
|
168 |
+
'periodic_fns' : [torch.sin, torch.cos],
|
169 |
+
}
|
170 |
+
|
171 |
+
embedder_obj = Embedder(**embed_kwargs)
|
172 |
+
embed = lambda x, eo=embedder_obj : eo.embed(x)
|
173 |
+
return embed, embedder_obj.out_dim
|
174 |
+
|
175 |
+
class APOPMeter():
|
176 |
+
def __init__(self) -> None:
|
177 |
+
self.tp = 0
|
178 |
+
self.fp = 0
|
179 |
+
self.tn = 0
|
180 |
+
self.fn = 0
|
181 |
+
|
182 |
+
def update(self, pred, gt):
|
183 |
+
"""
|
184 |
+
Input:
|
185 |
+
pred, gt: Tensor()
|
186 |
+
"""
|
187 |
+
assert pred.shape == gt.shape
|
188 |
+
self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
|
189 |
+
self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
|
190 |
+
self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
|
191 |
+
self.tn += torch.logical_and(pred == 1, gt == 0).sum().item()
|
192 |
+
|
193 |
+
def update_cm(self, tp, fp, tn, fn):
|
194 |
+
self.tp += tp
|
195 |
+
self.fp += fp
|
196 |
+
self.tn += tn
|
197 |
+
self.tn += fn
|
198 |
+
|
199 |
+
def inverse_sigmoid(x, eps=1e-5):
|
200 |
+
x = x.clamp(min=0, max=1)
|
201 |
+
x1 = x.clamp(min=eps)
|
202 |
+
x2 = (1 - x).clamp(min=eps)
|
203 |
+
return torch.log(x1/x2)
|
204 |
+
|
205 |
+
import argparse
|
206 |
+
from util.config import Config
|
207 |
+
def get_raw_dict(args):
|
208 |
+
"""
|
209 |
+
return the dicf contained in args.
|
210 |
+
|
211 |
+
e.g:
|
212 |
+
>>> with open(path, 'w') as f:
|
213 |
+
json.dump(get_raw_dict(args), f, indent=2)
|
214 |
+
"""
|
215 |
+
if isinstance(args, argparse.Namespace):
|
216 |
+
return vars(args)
|
217 |
+
elif isinstance(args, dict):
|
218 |
+
return args
|
219 |
+
elif isinstance(args, Config):
|
220 |
+
return args._cfg_dict
|
221 |
+
else:
|
222 |
+
raise NotImplementedError("Unknown type {}".format(type(args)))
|
223 |
+
|
224 |
+
|
225 |
+
def stat_tensors(tensor):
|
226 |
+
assert tensor.dim() == 1
|
227 |
+
tensor_sm = tensor.softmax(0)
|
228 |
+
entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()
|
229 |
+
|
230 |
+
return {
|
231 |
+
'max': tensor.max(),
|
232 |
+
'min': tensor.min(),
|
233 |
+
'mean': tensor.mean(),
|
234 |
+
'var': tensor.var(),
|
235 |
+
'std': tensor.var() ** 0.5,
|
236 |
+
'entropy': entropy
|
237 |
+
}
|
238 |
+
|
239 |
+
|
240 |
+
class NiceRepr:
|
241 |
+
"""Inherit from this class and define ``__nice__`` to "nicely" print your
|
242 |
+
objects.
|
243 |
+
|
244 |
+
Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
|
245 |
+
Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
|
246 |
+
If the inheriting class has a ``__len__``, method then the default
|
247 |
+
``__nice__`` method will return its length.
|
248 |
+
|
249 |
+
Example:
|
250 |
+
>>> class Foo(NiceRepr):
|
251 |
+
... def __nice__(self):
|
252 |
+
... return 'info'
|
253 |
+
>>> foo = Foo()
|
254 |
+
>>> assert str(foo) == '<Foo(info)>'
|
255 |
+
>>> assert repr(foo).startswith('<Foo(info) at ')
|
256 |
+
|
257 |
+
Example:
|
258 |
+
>>> class Bar(NiceRepr):
|
259 |
+
... pass
|
260 |
+
>>> bar = Bar()
|
261 |
+
>>> import pytest
|
262 |
+
>>> with pytest.warns(None) as record:
|
263 |
+
>>> assert 'object at' in str(bar)
|
264 |
+
>>> assert 'object at' in repr(bar)
|
265 |
+
|
266 |
+
Example:
|
267 |
+
>>> class Baz(NiceRepr):
|
268 |
+
... def __len__(self):
|
269 |
+
... return 5
|
270 |
+
>>> baz = Baz()
|
271 |
+
>>> assert str(baz) == '<Baz(5)>'
|
272 |
+
"""
|
273 |
+
|
274 |
+
def __nice__(self):
|
275 |
+
"""str: a "nice" summary string describing this module"""
|
276 |
+
if hasattr(self, '__len__'):
|
277 |
+
# It is a common pattern for objects to use __len__ in __nice__
|
278 |
+
# As a convenience we define a default __nice__ for these objects
|
279 |
+
return str(len(self))
|
280 |
+
else:
|
281 |
+
# In all other cases force the subclass to overload __nice__
|
282 |
+
raise NotImplementedError(
|
283 |
+
f'Define the __nice__ method for {self.__class__!r}')
|
284 |
+
|
285 |
+
def __repr__(self):
|
286 |
+
"""str: the string of the module"""
|
287 |
+
try:
|
288 |
+
nice = self.__nice__()
|
289 |
+
classname = self.__class__.__name__
|
290 |
+
return f'<{classname}({nice}) at {hex(id(self))}>'
|
291 |
+
except NotImplementedError as ex:
|
292 |
+
warnings.warn(str(ex), category=RuntimeWarning)
|
293 |
+
return object.__repr__(self)
|
294 |
+
|
295 |
+
def __str__(self):
|
296 |
+
"""str: the string of the module"""
|
297 |
+
try:
|
298 |
+
classname = self.__class__.__name__
|
299 |
+
nice = self.__nice__()
|
300 |
+
return f'<{classname}({nice})>'
|
301 |
+
except NotImplementedError as ex:
|
302 |
+
warnings.warn(str(ex), category=RuntimeWarning)
|
303 |
+
return object.__repr__(self)
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
def ensure_rng(rng=None):
|
308 |
+
"""Coerces input into a random number generator.
|
309 |
+
|
310 |
+
If the input is None, then a global random state is returned.
|
311 |
+
|
312 |
+
If the input is a numeric value, then that is used as a seed to construct a
|
313 |
+
random state. Otherwise the input is returned as-is.
|
314 |
+
|
315 |
+
Adapted from [1]_.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
rng (int | numpy.random.RandomState | None):
|
319 |
+
if None, then defaults to the global rng. Otherwise this can be an
|
320 |
+
integer or a RandomState class
|
321 |
+
Returns:
|
322 |
+
(numpy.random.RandomState) : rng -
|
323 |
+
a numpy random number generator
|
324 |
+
|
325 |
+
References:
|
326 |
+
.. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
|
327 |
+
"""
|
328 |
+
|
329 |
+
if rng is None:
|
330 |
+
rng = np.random.mtrand._rand
|
331 |
+
elif isinstance(rng, int):
|
332 |
+
rng = np.random.RandomState(rng)
|
333 |
+
else:
|
334 |
+
rng = rng
|
335 |
+
return rng
|
336 |
+
|
337 |
+
def random_boxes(num=1, scale=1, rng=None):
|
338 |
+
"""Simple version of ``kwimage.Boxes.random``
|
339 |
+
|
340 |
+
Returns:
|
341 |
+
Tensor: shape (n, 4) in x1, y1, x2, y2 format.
|
342 |
+
|
343 |
+
References:
|
344 |
+
https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
|
345 |
+
|
346 |
+
Example:
|
347 |
+
>>> num = 3
|
348 |
+
>>> scale = 512
|
349 |
+
>>> rng = 0
|
350 |
+
>>> boxes = random_boxes(num, scale, rng)
|
351 |
+
>>> print(boxes)
|
352 |
+
tensor([[280.9925, 278.9802, 308.6148, 366.1769],
|
353 |
+
[216.9113, 330.6978, 224.0446, 456.5878],
|
354 |
+
[405.3632, 196.3221, 493.3953, 270.7942]])
|
355 |
+
"""
|
356 |
+
rng = ensure_rng(rng)
|
357 |
+
|
358 |
+
tlbr = rng.rand(num, 4).astype(np.float32)
|
359 |
+
|
360 |
+
tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
|
361 |
+
tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
|
362 |
+
br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
|
363 |
+
br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
|
364 |
+
|
365 |
+
tlbr[:, 0] = tl_x * scale
|
366 |
+
tlbr[:, 1] = tl_y * scale
|
367 |
+
tlbr[:, 2] = br_x * scale
|
368 |
+
tlbr[:, 3] = br_y * scale
|
369 |
+
|
370 |
+
boxes = torch.from_numpy(tlbr)
|
371 |
+
return boxes
|
372 |
+
|
373 |
+
|
374 |
+
class ModelEma(torch.nn.Module):
|
375 |
+
def __init__(self, model, decay=0.9997, device=None):
|
376 |
+
super(ModelEma, self).__init__()
|
377 |
+
# make a copy of the model for accumulating moving average of weights
|
378 |
+
self.module = deepcopy(model)
|
379 |
+
self.module.eval()
|
380 |
+
|
381 |
+
# import ipdb; ipdb.set_trace()
|
382 |
+
|
383 |
+
self.decay = decay
|
384 |
+
self.device = device # perform ema on different device from model if set
|
385 |
+
if self.device is not None:
|
386 |
+
self.module.to(device=device)
|
387 |
+
|
388 |
+
def _update(self, model, update_fn):
|
389 |
+
with torch.no_grad():
|
390 |
+
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
391 |
+
if self.device is not None:
|
392 |
+
model_v = model_v.to(device=self.device)
|
393 |
+
ema_v.copy_(update_fn(ema_v, model_v))
|
394 |
+
|
395 |
+
def update(self, model):
|
396 |
+
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
|
397 |
+
|
398 |
+
def set(self, model):
|
399 |
+
self._update(model, update_fn=lambda e, m: m)
|
400 |
+
|
401 |
+
class BestMetricSingle():
|
402 |
+
def __init__(self, init_res=0.0, better='large') -> None:
|
403 |
+
self.init_res = init_res
|
404 |
+
self.best_res = init_res
|
405 |
+
self.best_ep = -1
|
406 |
+
|
407 |
+
self.better = better
|
408 |
+
assert better in ['large', 'small']
|
409 |
+
|
410 |
+
def isbetter(self, new_res, old_res):
|
411 |
+
if self.better == 'large':
|
412 |
+
return new_res > old_res
|
413 |
+
if self.better == 'small':
|
414 |
+
return new_res < old_res
|
415 |
+
|
416 |
+
def update(self, new_res, ep):
|
417 |
+
if self.isbetter(new_res, self.best_res):
|
418 |
+
self.best_res = new_res
|
419 |
+
self.best_ep = ep
|
420 |
+
return True
|
421 |
+
return False
|
422 |
+
|
423 |
+
def __str__(self) -> str:
|
424 |
+
return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)
|
425 |
+
|
426 |
+
def __repr__(self) -> str:
|
427 |
+
return self.__str__()
|
428 |
+
|
429 |
+
def summary(self) -> dict:
|
430 |
+
return {
|
431 |
+
'best_res': self.best_res,
|
432 |
+
'best_ep': self.best_ep,
|
433 |
+
}
|
434 |
+
|
435 |
+
|
436 |
+
class BestMetricHolder():
|
437 |
+
def __init__(self, init_res=0.0, better='large', use_ema=False) -> None:
|
438 |
+
self.best_all = BestMetricSingle(init_res, better)
|
439 |
+
self.use_ema = use_ema
|
440 |
+
if use_ema:
|
441 |
+
self.best_ema = BestMetricSingle(init_res, better)
|
442 |
+
self.best_regular = BestMetricSingle(init_res, better)
|
443 |
+
|
444 |
+
|
445 |
+
def update(self, new_res, epoch, is_ema=False):
|
446 |
+
"""
|
447 |
+
return if the results is the best.
|
448 |
+
"""
|
449 |
+
if not self.use_ema:
|
450 |
+
return self.best_all.update(new_res, epoch)
|
451 |
+
else:
|
452 |
+
if is_ema:
|
453 |
+
self.best_ema.update(new_res, epoch)
|
454 |
+
return self.best_all.update(new_res, epoch)
|
455 |
+
else:
|
456 |
+
self.best_regular.update(new_res, epoch)
|
457 |
+
return self.best_all.update(new_res, epoch)
|
458 |
+
|
459 |
+
def summary(self):
|
460 |
+
if not self.use_ema:
|
461 |
+
return self.best_all.summary()
|
462 |
+
|
463 |
+
res = {}
|
464 |
+
res.update({f'all_{k}':v for k,v in self.best_all.summary().items()})
|
465 |
+
res.update({f'regular_{k}':v for k,v in self.best_regular.summary().items()})
|
466 |
+
res.update({f'ema_{k}':v for k,v in self.best_ema.summary().items()})
|
467 |
+
return res
|
468 |
+
|
469 |
+
def __repr__(self) -> str:
|
470 |
+
return json.dumps(self.summary(), indent=2)
|
471 |
+
|
472 |
+
def __str__(self) -> str:
|
473 |
+
return self.__repr__()
|
474 |
+
|
475 |
+
|
476 |
+
def targets_to(targets: List[Dict[str, Any]], device):
|
477 |
+
"""Moves the target dicts to the given device."""
|
478 |
+
excluded_keys = [
|
479 |
+
"questionId",
|
480 |
+
"tokens_positive",
|
481 |
+
"strings_positive",
|
482 |
+
"tokens",
|
483 |
+
"dataset_name",
|
484 |
+
"sentence_id",
|
485 |
+
"original_img_id",
|
486 |
+
"nb_eval",
|
487 |
+
"task_id",
|
488 |
+
"original_id",
|
489 |
+
"token_span",
|
490 |
+
"caption",
|
491 |
+
"dataset_type",
|
492 |
+
"caption_list",
|
493 |
+
"id2catname",
|
494 |
+
"valid_kpt_num",
|
495 |
+
"image_id_ref",
|
496 |
+
"image_id_current",
|
497 |
+
"test_id"
|
498 |
+
]
|
499 |
+
return [{k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets]
|