|
import argparse |
|
|
|
import numpy as np |
|
import torch |
|
from tensorflow.python.training import py_checkpoint_reader |
|
|
|
# Print tensors with 20 digits of precision so converted values can be
# inspected without display rounding.
torch.set_printoptions(precision=20)
|
|
|
|
|
def tf2pth(v):
    """Permute the axes of a checkpoint array according to its rank.

    Args:
        v (np.ndarray): Array read from the TensorFlow checkpoint.

    Returns:
        np.ndarray: For a 4-D input, a C-contiguous array with axes ordered
        ``(3, 2, 0, 1)``; for a 2-D input, the C-contiguous transpose; for
        any other rank, ``v`` itself, untouched.
    """
    axes_by_rank = {4: (3, 2, 0, 1), 2: (1, 0)}
    axes = axes_by_rank.get(v.ndim)
    if axes is None:
        return v
    return np.ascontiguousarray(v.transpose(axes))
|
|
|
|
|
# Placeholder held by every BiFPN fusion-weight slot until the real value has
# been read from the checkpoint; a weight vector is exported only once all of
# its slots are above this value.
_UNSET_FUSION_WEIGHT = -1e4

# TF batch-norm variable name -> PyTorch batch-norm entry name.
_BN_PARAMS = {
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var',
}

# TF separable-conv variable name -> PyTorch parameter name.
_SEPARABLE_CONV_PARAMS = {
    'depthwise_kernel': 'depthwise_conv.weight',
    'pointwise_kernel': 'pointwise_conv.weight',
    'bias': 'pointwise_conv.bias',
}

# Variables found inside a resample scope -> PyTorch names.
_RESAMPLE_PARAMS = {
    'conv2d/kernel': 'down_conv.weight',
    'conv2d/bias': 'down_conv.bias',
    'bn/beta': 'bn.bias',
    'bn/gamma': 'bn.weight',
    'bn/moving_mean': 'bn.running_mean',
    'bn/moving_variance': 'bn.running_var',
}

# Backbone variables that are mapped the same way for every ``blocks_<i>``.
_BLOCK_SHARED_PARAMS = {
    'depthwise_conv2d/depthwise_kernel': 'depthwise_conv.conv.weight',
    'se/conv2d/kernel': 'se.conv1.conv.weight',
    'se/conv2d/bias': 'se.conv1.conv.bias',
    'se/conv2d_1/kernel': 'se.conv2.conv.weight',
    'se/conv2d_1/bias': 'se.conv2.conv.bias',
}

# Index of a TF ``blocks_<i>`` scope -> position under ``backbone.layers``.
_BLOCK_IDX_TO_LAYER = {
    0: '1.0',
    1: '2.0',
    2: '2.1',
    3: '3.0',
    4: '3.1',
    5: '4.0',
    6: '4.1',
    7: '4.2',
    8: '4.3',
    9: '4.4',
    10: '4.5',
    11: '5.0',
    12: '5.1',
    13: '5.2',
    14: '5.3',
    15: '5.4',
}

# Description of every ``fnode<i>`` scope of a BiFPN cell:
#   combine  - TF scope holding the node's conv and batch norm
#   conv     - PyTorch module receiving them
#   resample - (TF scope, PyTorch module) of the extra projection that is
#              only mapped for cell 0, or None
#   weight   - name of the exported fusion-weight vector
#   size     - number of entries in that vector
#   slots    - TF names treated as fusion weights; position ``i`` fills entry
#              ``i``. fnode7 lists ``WSM_2`` although its vector has only two
#              entries, so that name is accepted and ignored.
_FPN_NODES = {
    0: dict(combine='op_after_combine5', conv='conv6_up', resample=None,
            weight='p6_w1', size=2, slots=('WSM', 'WSM_1')),
    1: dict(combine='op_after_combine6', conv='conv5_up',
            resample=('resample_0_2_6', 'p5_down_channel'),
            weight='p5_w1', size=2, slots=('WSM', 'WSM_1')),
    2: dict(combine='op_after_combine7', conv='conv4_up',
            resample=('resample_0_1_7', 'p4_down_channel'),
            weight='p4_w1', size=2, slots=('WSM', 'WSM_1')),
    3: dict(combine='op_after_combine8', conv='conv3_up',
            resample=('resample_0_0_8', 'p3_down_channel'),
            weight='p3_w1', size=2, slots=('WSM', 'WSM_1')),
    4: dict(combine='op_after_combine9', conv='conv4_down',
            resample=('resample_0_1_9', 'p4_level_connection'),
            weight='p4_w2', size=3, slots=('WSM', 'WSM_1', 'WSM_2')),
    5: dict(combine='op_after_combine10', conv='conv5_down',
            resample=('resample_0_2_10', 'p5_level_connection'),
            weight='p5_w2', size=3, slots=('WSM', 'WSM_1', 'WSM_2')),
    6: dict(combine='op_after_combine11', conv='conv6_down', resample=None,
            weight='p6_w2', size=3, slots=('WSM', 'WSM_1', 'WSM_2')),
    7: dict(combine='op_after_combine12', conv='conv7_down', resample=None,
            weight='p7_w2', size=2, slots=('WSM', 'WSM_1', 'WSM_2')),
}

# Head scope -> (tag of the predictor layer, PyTorch name stem, character
# position of the conv index in the layer name, character position of the
# batch-norm index in the layer name).
_HEADS = {
    'box_net': ('box-predict', 'reg', 4, 9),
    'class_net': ('class-predict', 'cls', 6, 11),
}


def _scoped(table, tf_scope, pt_scope):
    """Return ``table`` with a TF scope and a PyTorch scope prepended."""
    return {
        tf_scope + '/' + tf_name: pt_scope + '.' + pt_name
        for tf_name, pt_name in table.items()
    }


def _backbone_block_mapping(idx):
    """Build the variable-name mapping for backbone scope ``blocks_<idx>``."""
    mapping = dict(_BLOCK_SHARED_PARAMS)
    if idx == 0:
        # Block 0 maps no expand_conv parameters: its only plain conv is the
        # linear one and its batch norms start at the depthwise conv.
        mapping['conv2d/kernel'] = 'linear_conv.conv.weight'
        bn_owners = ('depthwise_conv', 'linear_conv')
    else:
        mapping['conv2d/kernel'] = 'expand_conv.conv.weight'
        mapping['conv2d_1/kernel'] = 'linear_conv.conv.weight'
        bn_owners = ('expand_conv', 'depthwise_conv', 'linear_conv')
    for pos, owner in enumerate(bn_owners):
        # TF names the scopes tpu_batch_normalization, ..._1, ..._2.
        tf_scope = 'tpu_batch_normalization'
        if pos:
            tf_scope += '_' + str(pos)
        mapping.update(_scoped(_BN_PARAMS, tf_scope, owner + '.bn'))
    return mapping


def _convert_backbone(seg, v, m):
    """Store one backbone variable (stem or ``blocks_<i>``) into ``m``."""
    stage = seg[1]
    if stage == 'stem':
        prefix = 'backbone.layers.0'
        mapping = {'conv2d/kernel': 'conv.weight'}
        mapping.update(_scoped(_BN_PARAMS, 'tpu_batch_normalization', 'bn'))
    elif stage.startswith('blocks_'):
        idx = int(stage[7:])
        prefix = '.'.join(['backbone', 'layers', _BLOCK_IDX_TO_LAYER[idx]])
        mapping = _backbone_block_mapping(idx)
    else:
        # Other scopes below the backbone name are not exported.
        return
    suffix = mapping['/'.join(seg[2:])]
    m[prefix + '.' + suffix] = v


def _fpn_node_mapping(spec, fpn_idx):
    """Build the variable-name mapping of one BiFPN node."""
    mapping = _scoped(_SEPARABLE_CONV_PARAMS, spec['combine'] + '/conv',
                      spec['conv'])
    mapping.update(
        _scoped(_BN_PARAMS, spec['combine'] + '/bn', spec['conv'] + '.bn'))
    if fpn_idx == 0 and spec['resample'] is not None:
        # Resample projections are only mapped for the first BiFPN cell.
        tf_scope, pt_module = spec['resample']
        mapping.update(_scoped(_RESAMPLE_PARAMS, tf_scope, pt_module))
    return mapping


def _convert_fpn(seg, v, m, fusion):
    """Store one ``fpn_cells`` variable, collecting fusion weights."""
    fpn_idx = int(seg[1][5:])
    prefix = '.'.join(['neck', 'bifpn', str(fpn_idx)])
    spec = _FPN_NODES.get(int(seg[2][5]))
    if spec is None:
        return
    slot = seg[3]
    if slot not in spec['slots']:
        suffix = _fpn_node_mapping(spec, fpn_idx)['/'.join(seg[3:])]
        if 'depthwise_conv' in suffix:
            v = v.transpose(1, 0)
        m[prefix + '.' + suffix] = v
    else:
        pos = spec['slots'].index(slot)
        if pos < spec['size']:
            fusion[spec['weight']][fpn_idx][pos] = v
    # Export the fusion-weight vector as soon as every slot has been filled.
    collected = fusion[spec['weight']][fpn_idx]
    if torch.min(collected) > _UNSET_FUSION_WEIGHT:
        m[prefix + '.' + spec['weight']] = collected


def _convert_head(seg, v, m):
    """Store one ``box_net`` / ``class_net`` variable into ``m``."""
    predict_tag, stem, conv_pos, bn_pos = _HEADS[seg[0]]
    layer = seg[1]
    if predict_tag in layer:
        prefix = '.'.join(['bbox_head', stem + '_header'])
        table = _SEPARABLE_CONV_PARAMS
    elif 'bn' in layer:
        conv_idx = int(layer[conv_pos])
        # The TF batch-norm index is shifted down by 3 -- presumably it is
        # the feature level (starting at 3); TODO confirm.
        bn_idx = int(layer[bn_pos]) - 3
        prefix = '.'.join(
            ['bbox_head', stem + '_bn_list',
             str(conv_idx),
             str(bn_idx)])
        table = _BN_PARAMS
    else:
        conv_idx = int(layer[conv_pos])
        prefix = '.'.join(['bbox_head', stem + '_conv_list', str(conv_idx)])
        table = _SEPARABLE_CONV_PARAMS
    suffix = table['/'.join(seg[2:])]
    if 'depthwise_conv' in suffix:
        v = v.transpose(1, 0)
    m[prefix + '.' + suffix] = v


def convert_key(model_name, bifpn_repeats, weights):
    """Rename TensorFlow EfficientDet variables to PyTorch state-dict keys.

    Variable names are split on ``/`` and dispatched on their first segment:
    ``model_name`` (backbone), ``resample_p6``, ``fpn_cells``, ``box_net``
    and ``class_net``. Names containing ``Exponential`` or ``global_step``,
    names without any ``/``, and names from any other scope are ignored.
    Depthwise kernels get their first two axes swapped. The scalar ``WSM*``
    variables of each BiFPN node are gathered into one float64 vector per
    node, exported once all of its entries have been seen.

    Args:
        model_name (str): First path segment of the backbone variables.
        bifpn_repeats (int): Number of BiFPN cells to allocate fusion-weight
            vectors for.
        weights (dict): TF variable name -> tensor. Values must support
            ``transpose(1, 0)``.

    Returns:
        dict: PyTorch parameter name -> tensor.

    Raises:
        KeyError: A variable inside a recognised scope has no mapping.
        IndexError: A recognised-scope name is too short to parse, or a
            BiFPN cell index is not below ``bifpn_repeats``.
    """
    fusion = {
        spec['weight']: [
            torch.tensor(
                [_UNSET_FUSION_WEIGHT] * spec['size'], dtype=torch.float64)
            for _ in range(bifpn_repeats)
        ]
        for spec in _FPN_NODES.values()
    }
    m = dict()
    for k, v in weights.items():
        if 'Exponential' in k or 'global_step' in k:
            continue
        seg = k.split('/')
        if len(seg) == 1:
            continue
        # Guard the length: a two-segment name from an unrelated scope must
        # reach the dispatch below and be ignored, not raise IndexError.
        if len(seg) > 2 and seg[2] == 'depthwise_conv2d':
            v = v.transpose(1, 0)

        scope = seg[0]
        if scope == model_name:
            _convert_backbone(seg, v, m)
        elif scope == 'resample_p6':
            suffix = _RESAMPLE_PARAMS['/'.join(seg[1:])]
            m['neck.bifpn.0.p5_to_p6.0' + '.' + suffix] = v
        elif scope == 'fpn_cells':
            _convert_fpn(seg, v, m, fusion)
        elif scope in _HEADS:
            _convert_head(seg, v, m)
    return m
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace: Carries ``backbone``, ``tensorflow_weight`` and
        ``out_weight``. Each is a string, or ``None`` when the option was
        not given.
    """
    # (flag, help text) for every option; all of them are plain strings.
    options = (
        ('--backbone',
         'efficientnet model name, like efficientnet-b0'),
        ('--tensorflow_weight',
         'efficientdet tensorflow weight name, like efficientdet-d0/model'),
        ('--out_weight',
         'efficientdet pytorch weight name like demo.pth'),
    )
    parser = argparse.ArgumentParser(
        description='convert efficientdet weight from tensorflow to pytorch')
    for flag, help_text in options:
        parser.add_argument(flag, type=str, help=help_text)
    return parser.parse_args()
|
|
|
|
|
def main():
    """Convert a TensorFlow checkpoint and save it as a PyTorch file.

    Reads every variable of the checkpoint named by ``--tensorflow_weight``,
    passes each array through ``tf2pth``, renames the variables with
    ``convert_key`` and saves ``{'state_dict': ...}`` to ``--out_weight``.
    """
    # Number of BiFPN repeats, keyed by the digit read from the backbone name.
    bifpn_repeats_by_variant = {0: 3, 1: 4, 2: 5, 3: 6, 4: 7, 5: 7, 6: 8, 7: 8}

    cli = parse_args()
    backbone = cli.backbone

    checkpoint = py_checkpoint_reader.NewCheckpointReader(
        cli.tensorflow_weight)
    tensors = dict()
    for name in checkpoint.get_variable_to_shape_map():
        tensors[name] = torch.as_tensor(tf2pth(checkpoint.get_tensor(name)))

    # Character 14 of the backbone name is taken as the variant digit, e.g.
    # the trailing '0' of 'efficientnet-b0'.
    repeats = bifpn_repeats_by_variant[int(backbone[14])]
    state_dict = convert_key(backbone, repeats, tensors)
    torch.save({'state_dict': state_dict}, cli.out_weight)
|
|
|
|
|
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|