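"""Convert an EfficientDet TensorFlow checkpoint into an
MMDetection-compatible PyTorch state dict.

Example (argument values as suggested by the parser help below):

    python convert_tf_to_pt.py --backbone efficientnet-b0 \
        --tensorflow_weight efficientdet-d0/model --out_weight demo.pth
"""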
import argparse
import numpy as np
import torch
from tensorflow.python.training import py_checkpoint_reader
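# Full-precision printing, presumably to ease manual comparison of converted
# values against the original TF checkpoint.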
torch.set_printoptions(precision=20)


def tf2pth(v):
    """Rearrange a TF checkpoint array into the matching PyTorch layout."""
    if v.ndim == 4:
        # HWIO conv kernel -> OIHW
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    elif v.ndim == 2:
        # dense kernel (in, out) -> (out, in)
        return np.ascontiguousarray(v.transpose())
    return v


def convert_key(model_name, bifpn_repeats, weights):
    """Translate TF variable names into MMDetection state-dict keys."""
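    # Each BiFPN fusion op owns one scalar weight per input. TF stores them
    # as separate variables (WSM, WSM_1, WSM_2); gather them into one tensor
    # per cell, with -1e4 marking slots that are not yet filled.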
p6_w1 = [
torch.tensor([-1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
p5_w1 = [
torch.tensor([-1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
p4_w1 = [
torch.tensor([-1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
p3_w1 = [
torch.tensor([-1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
p4_w2 = [
torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
p5_w2 = [
torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
p6_w2 = [
torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
p7_w2 = [
torch.tensor([-1e4, -1e4], dtype=torch.float64)
for _ in range(bifpn_repeats)
]
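    # TF names MBConv blocks 'blocks_{idx}' in one flat sequence, while MMDet
    # addresses them as 'backbone.layers.{stage}.{block}'. The table only
    # covers 16 blocks, so deeper backbones would need more entries.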
idx2key = {
0: '1.0',
1: '2.0',
2: '2.1',
3: '3.0',
4: '3.1',
5: '4.0',
6: '4.1',
7: '4.2',
8: '4.3',
9: '4.4',
10: '4.5',
11: '5.0',
12: '5.1',
13: '5.2',
14: '5.3',
15: '5.4'
}
m = dict()
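    # 'Exponential*' variables are EMA shadow copies and 'global_step' is the
    # training counter; neither belongs in the converted state dict.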
for k, v in weights.items():
if 'Exponential' in k or 'global_step' in k:
continue
seg = k.split('/')
if len(seg) == 1:
continue
        if len(seg) > 2 and seg[2] == 'depthwise_conv2d':
            # tf2pth leaves depthwise kernels as (multiplier, C, H, W);
            # PyTorch's depthwise conv expects (C, multiplier, H, W).
            v = v.transpose(1, 0)
if seg[0] == model_name:
if seg[1] == 'stem':
prefix = 'backbone.layers.0'
mapping = {
'conv2d/kernel': 'conv.weight',
'tpu_batch_normalization/beta': 'bn.bias',
'tpu_batch_normalization/gamma': 'bn.weight',
'tpu_batch_normalization/moving_mean': 'bn.running_mean',
'tpu_batch_normalization/moving_variance':
'bn.running_var',
}
suffix = mapping['/'.join(seg[2:])]
m[prefix + '.' + suffix] = v
elif seg[1].startswith('blocks_'):
idx = int(seg[1][7:])
prefix = '.'.join(['backbone', 'layers', idx2key[idx]])
base_mapping = {
'depthwise_conv2d/depthwise_kernel':
'depthwise_conv.conv.weight',
'se/conv2d/kernel': 'se.conv1.conv.weight',
'se/conv2d/bias': 'se.conv1.conv.bias',
'se/conv2d_1/kernel': 'se.conv2.conv.weight',
'se/conv2d_1/bias': 'se.conv2.conv.bias'
}
if idx == 0:
mapping = {
'conv2d/kernel':
'linear_conv.conv.weight',
'tpu_batch_normalization/beta':
'depthwise_conv.bn.bias',
'tpu_batch_normalization/gamma':
'depthwise_conv.bn.weight',
'tpu_batch_normalization/moving_mean':
'depthwise_conv.bn.running_mean',
'tpu_batch_normalization/moving_variance':
'depthwise_conv.bn.running_var',
'tpu_batch_normalization_1/beta':
'linear_conv.bn.bias',
'tpu_batch_normalization_1/gamma':
'linear_conv.bn.weight',
'tpu_batch_normalization_1/moving_mean':
'linear_conv.bn.running_mean',
'tpu_batch_normalization_1/moving_variance':
'linear_conv.bn.running_var',
}
else:
mapping = {
'depthwise_conv2d/depthwise_kernel':
'depthwise_conv.conv.weight',
'conv2d/kernel':
'expand_conv.conv.weight',
'conv2d_1/kernel':
'linear_conv.conv.weight',
'tpu_batch_normalization/beta':
'expand_conv.bn.bias',
'tpu_batch_normalization/gamma':
'expand_conv.bn.weight',
'tpu_batch_normalization/moving_mean':
'expand_conv.bn.running_mean',
'tpu_batch_normalization/moving_variance':
'expand_conv.bn.running_var',
'tpu_batch_normalization_1/beta':
'depthwise_conv.bn.bias',
'tpu_batch_normalization_1/gamma':
'depthwise_conv.bn.weight',
'tpu_batch_normalization_1/moving_mean':
'depthwise_conv.bn.running_mean',
'tpu_batch_normalization_1/moving_variance':
'depthwise_conv.bn.running_var',
'tpu_batch_normalization_2/beta':
'linear_conv.bn.bias',
'tpu_batch_normalization_2/gamma':
'linear_conv.bn.weight',
'tpu_batch_normalization_2/moving_mean':
'linear_conv.bn.running_mean',
'tpu_batch_normalization_2/moving_variance':
'linear_conv.bn.running_var',
}
mapping.update(base_mapping)
suffix = mapping['/'.join(seg[2:])]
m[prefix + '.' + suffix] = v
elif seg[0] == 'resample_p6':
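            # Standalone downsample that derives P6 from the backbone's P5;
            # it belongs to the first BiFPN cell in the MMDet neck.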
prefix = 'neck.bifpn.0.p5_to_p6.0'
mapping = {
'conv2d/kernel': 'down_conv.weight',
'conv2d/bias': 'down_conv.bias',
'bn/beta': 'bn.bias',
'bn/gamma': 'bn.weight',
'bn/moving_mean': 'bn.running_mean',
'bn/moving_variance': 'bn.running_var',
}
suffix = mapping['/'.join(seg[1:])]
m[prefix + '.' + suffix] = v
elif seg[0] == 'fpn_cells':
fpn_idx = int(seg[1][5:])
prefix = '.'.join(['neck', 'bifpn', str(fpn_idx)])
fnode_id = int(seg[2][5])
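            # fnode0-3 form the top-down path (fusing P6..P3) and fnode4-7
            # the bottom-up path (P4..P7). Only the first cell (fpn_idx == 0)
            # also owns the down-channel / level-connection convs that adapt
            # the raw backbone features.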
if fnode_id == 0:
mapping = {
'op_after_combine5/conv/depthwise_kernel':
'conv6_up.depthwise_conv.weight',
'op_after_combine5/conv/pointwise_kernel':
'conv6_up.pointwise_conv.weight',
'op_after_combine5/conv/bias':
'conv6_up.pointwise_conv.bias',
'op_after_combine5/bn/beta':
'conv6_up.bn.bias',
'op_after_combine5/bn/gamma':
'conv6_up.bn.weight',
'op_after_combine5/bn/moving_mean':
'conv6_up.bn.running_mean',
'op_after_combine5/bn/moving_variance':
'conv6_up.bn.running_var',
}
                if seg[3] not in ('WSM', 'WSM_1'):
suffix = mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p6_w1[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p6_w1[fpn_idx][1] = v
if torch.min(p6_w1[fpn_idx]) > -1e4:
m[prefix + '.p6_w1'] = p6_w1[fpn_idx]
elif fnode_id == 1:
base_mapping = {
'op_after_combine6/conv/depthwise_kernel':
'conv5_up.depthwise_conv.weight',
'op_after_combine6/conv/pointwise_kernel':
'conv5_up.pointwise_conv.weight',
'op_after_combine6/conv/bias':
'conv5_up.pointwise_conv.bias',
'op_after_combine6/bn/beta':
'conv5_up.bn.bias',
'op_after_combine6/bn/gamma':
'conv5_up.bn.weight',
'op_after_combine6/bn/moving_mean':
'conv5_up.bn.running_mean',
'op_after_combine6/bn/moving_variance':
'conv5_up.bn.running_var',
}
if fpn_idx == 0:
mapping = {
'resample_0_2_6/conv2d/kernel':
'p5_down_channel.down_conv.weight',
'resample_0_2_6/conv2d/bias':
'p5_down_channel.down_conv.bias',
'resample_0_2_6/bn/beta':
'p5_down_channel.bn.bias',
'resample_0_2_6/bn/gamma':
'p5_down_channel.bn.weight',
'resample_0_2_6/bn/moving_mean':
'p5_down_channel.bn.running_mean',
'resample_0_2_6/bn/moving_variance':
'p5_down_channel.bn.running_var',
}
base_mapping.update(mapping)
                if seg[3] not in ('WSM', 'WSM_1'):
suffix = base_mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p5_w1[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p5_w1[fpn_idx][1] = v
if torch.min(p5_w1[fpn_idx]) > -1e4:
m[prefix + '.p5_w1'] = p5_w1[fpn_idx]
elif fnode_id == 2:
base_mapping = {
'op_after_combine7/conv/depthwise_kernel':
'conv4_up.depthwise_conv.weight',
'op_after_combine7/conv/pointwise_kernel':
'conv4_up.pointwise_conv.weight',
'op_after_combine7/conv/bias':
'conv4_up.pointwise_conv.bias',
'op_after_combine7/bn/beta':
'conv4_up.bn.bias',
'op_after_combine7/bn/gamma':
'conv4_up.bn.weight',
'op_after_combine7/bn/moving_mean':
'conv4_up.bn.running_mean',
'op_after_combine7/bn/moving_variance':
'conv4_up.bn.running_var',
}
if fpn_idx == 0:
mapping = {
'resample_0_1_7/conv2d/kernel':
'p4_down_channel.down_conv.weight',
'resample_0_1_7/conv2d/bias':
'p4_down_channel.down_conv.bias',
'resample_0_1_7/bn/beta':
'p4_down_channel.bn.bias',
'resample_0_1_7/bn/gamma':
'p4_down_channel.bn.weight',
'resample_0_1_7/bn/moving_mean':
'p4_down_channel.bn.running_mean',
'resample_0_1_7/bn/moving_variance':
'p4_down_channel.bn.running_var',
}
base_mapping.update(mapping)
                if seg[3] not in ('WSM', 'WSM_1'):
suffix = base_mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p4_w1[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p4_w1[fpn_idx][1] = v
if torch.min(p4_w1[fpn_idx]) > -1e4:
m[prefix + '.p4_w1'] = p4_w1[fpn_idx]
elif fnode_id == 3:
base_mapping = {
'op_after_combine8/conv/depthwise_kernel':
'conv3_up.depthwise_conv.weight',
'op_after_combine8/conv/pointwise_kernel':
'conv3_up.pointwise_conv.weight',
'op_after_combine8/conv/bias':
'conv3_up.pointwise_conv.bias',
'op_after_combine8/bn/beta':
'conv3_up.bn.bias',
'op_after_combine8/bn/gamma':
'conv3_up.bn.weight',
'op_after_combine8/bn/moving_mean':
'conv3_up.bn.running_mean',
'op_after_combine8/bn/moving_variance':
'conv3_up.bn.running_var',
}
if fpn_idx == 0:
mapping = {
'resample_0_0_8/conv2d/kernel':
'p3_down_channel.down_conv.weight',
'resample_0_0_8/conv2d/bias':
'p3_down_channel.down_conv.bias',
'resample_0_0_8/bn/beta':
'p3_down_channel.bn.bias',
'resample_0_0_8/bn/gamma':
'p3_down_channel.bn.weight',
'resample_0_0_8/bn/moving_mean':
'p3_down_channel.bn.running_mean',
'resample_0_0_8/bn/moving_variance':
'p3_down_channel.bn.running_var',
}
base_mapping.update(mapping)
                if seg[3] not in ('WSM', 'WSM_1'):
suffix = base_mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p3_w1[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p3_w1[fpn_idx][1] = v
if torch.min(p3_w1[fpn_idx]) > -1e4:
m[prefix + '.p3_w1'] = p3_w1[fpn_idx]
elif fnode_id == 4:
base_mapping = {
'op_after_combine9/conv/depthwise_kernel':
'conv4_down.depthwise_conv.weight',
'op_after_combine9/conv/pointwise_kernel':
'conv4_down.pointwise_conv.weight',
'op_after_combine9/conv/bias':
'conv4_down.pointwise_conv.bias',
'op_after_combine9/bn/beta':
'conv4_down.bn.bias',
'op_after_combine9/bn/gamma':
'conv4_down.bn.weight',
'op_after_combine9/bn/moving_mean':
'conv4_down.bn.running_mean',
'op_after_combine9/bn/moving_variance':
'conv4_down.bn.running_var',
}
if fpn_idx == 0:
mapping = {
'resample_0_1_9/conv2d/kernel':
'p4_level_connection.down_conv.weight',
'resample_0_1_9/conv2d/bias':
'p4_level_connection.down_conv.bias',
'resample_0_1_9/bn/beta':
'p4_level_connection.bn.bias',
'resample_0_1_9/bn/gamma':
'p4_level_connection.bn.weight',
'resample_0_1_9/bn/moving_mean':
'p4_level_connection.bn.running_mean',
'resample_0_1_9/bn/moving_variance':
'p4_level_connection.bn.running_var',
}
base_mapping.update(mapping)
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
suffix = base_mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p4_w2[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p4_w2[fpn_idx][1] = v
elif seg[3] == 'WSM_2':
p4_w2[fpn_idx][2] = v
if torch.min(p4_w2[fpn_idx]) > -1e4:
m[prefix + '.p4_w2'] = p4_w2[fpn_idx]
elif fnode_id == 5:
base_mapping = {
'op_after_combine10/conv/depthwise_kernel':
'conv5_down.depthwise_conv.weight',
'op_after_combine10/conv/pointwise_kernel':
'conv5_down.pointwise_conv.weight',
'op_after_combine10/conv/bias':
'conv5_down.pointwise_conv.bias',
'op_after_combine10/bn/beta':
'conv5_down.bn.bias',
'op_after_combine10/bn/gamma':
'conv5_down.bn.weight',
'op_after_combine10/bn/moving_mean':
'conv5_down.bn.running_mean',
'op_after_combine10/bn/moving_variance':
'conv5_down.bn.running_var',
}
if fpn_idx == 0:
mapping = {
'resample_0_2_10/conv2d/kernel':
'p5_level_connection.down_conv.weight',
'resample_0_2_10/conv2d/bias':
'p5_level_connection.down_conv.bias',
'resample_0_2_10/bn/beta':
'p5_level_connection.bn.bias',
'resample_0_2_10/bn/gamma':
'p5_level_connection.bn.weight',
'resample_0_2_10/bn/moving_mean':
'p5_level_connection.bn.running_mean',
'resample_0_2_10/bn/moving_variance':
'p5_level_connection.bn.running_var',
}
base_mapping.update(mapping)
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
suffix = base_mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p5_w2[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p5_w2[fpn_idx][1] = v
elif seg[3] == 'WSM_2':
p5_w2[fpn_idx][2] = v
if torch.min(p5_w2[fpn_idx]) > -1e4:
m[prefix + '.p5_w2'] = p5_w2[fpn_idx]
elif fnode_id == 6:
base_mapping = {
'op_after_combine11/conv/depthwise_kernel':
'conv6_down.depthwise_conv.weight',
'op_after_combine11/conv/pointwise_kernel':
'conv6_down.pointwise_conv.weight',
'op_after_combine11/conv/bias':
'conv6_down.pointwise_conv.bias',
'op_after_combine11/bn/beta':
'conv6_down.bn.bias',
'op_after_combine11/bn/gamma':
'conv6_down.bn.weight',
'op_after_combine11/bn/moving_mean':
'conv6_down.bn.running_mean',
'op_after_combine11/bn/moving_variance':
'conv6_down.bn.running_var',
}
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
suffix = base_mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p6_w2[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p6_w2[fpn_idx][1] = v
elif seg[3] == 'WSM_2':
p6_w2[fpn_idx][2] = v
if torch.min(p6_w2[fpn_idx]) > -1e4:
m[prefix + '.p6_w2'] = p6_w2[fpn_idx]
elif fnode_id == 7:
base_mapping = {
'op_after_combine12/conv/depthwise_kernel':
'conv7_down.depthwise_conv.weight',
'op_after_combine12/conv/pointwise_kernel':
'conv7_down.pointwise_conv.weight',
'op_after_combine12/conv/bias':
'conv7_down.pointwise_conv.bias',
'op_after_combine12/bn/beta':
'conv7_down.bn.bias',
'op_after_combine12/bn/gamma':
'conv7_down.bn.weight',
'op_after_combine12/bn/moving_mean':
'conv7_down.bn.running_mean',
'op_after_combine12/bn/moving_variance':
'conv7_down.bn.running_var',
}
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
suffix = base_mapping['/'.join(seg[3:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[3] == 'WSM':
p7_w2[fpn_idx][0] = v
elif seg[3] == 'WSM_1':
p7_w2[fpn_idx][1] = v
if torch.min(p7_w2[fpn_idx]) > -1e4:
m[prefix + '.p7_w2'] = p7_w2[fpn_idx]
elif seg[0] == 'box_net':
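            # 'box-predict' is the final regression conv, 'box-{i}-bn-{j}'
            # the per-level BNs (TF level indices start at 3, hence the -3
            # offset below) and plain 'box-{i}' the shared separable convs.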
if 'box-predict' in seg[1]:
prefix = '.'.join(['bbox_head', 'reg_header'])
base_mapping = {
'depthwise_kernel': 'depthwise_conv.weight',
'pointwise_kernel': 'pointwise_conv.weight',
'bias': 'pointwise_conv.bias'
}
suffix = base_mapping['/'.join(seg[2:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif 'bn' in seg[1]:
bbox_conv_idx = int(seg[1][4])
bbox_bn_idx = int(seg[1][9]) - 3
prefix = '.'.join([
'bbox_head', 'reg_bn_list',
str(bbox_conv_idx),
str(bbox_bn_idx)
])
base_mapping = {
'beta': 'bias',
'gamma': 'weight',
'moving_mean': 'running_mean',
'moving_variance': 'running_var'
}
suffix = base_mapping['/'.join(seg[2:])]
m[prefix + '.' + suffix] = v
else:
bbox_conv_idx = int(seg[1][4])
prefix = '.'.join(
['bbox_head', 'reg_conv_list',
str(bbox_conv_idx)])
base_mapping = {
'depthwise_kernel': 'depthwise_conv.weight',
'pointwise_kernel': 'pointwise_conv.weight',
'bias': 'pointwise_conv.bias'
}
suffix = base_mapping['/'.join(seg[2:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif seg[0] == 'class_net':
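            # class_net mirrors box_net: 'class-predict', per-level
            # 'class-{i}-bn-{j}' and shared 'class-{i}' convs.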
if 'class-predict' in seg[1]:
prefix = '.'.join(['bbox_head', 'cls_header'])
base_mapping = {
'depthwise_kernel': 'depthwise_conv.weight',
'pointwise_kernel': 'pointwise_conv.weight',
'bias': 'pointwise_conv.bias'
}
suffix = base_mapping['/'.join(seg[2:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
elif 'bn' in seg[1]:
cls_conv_idx = int(seg[1][6])
cls_bn_idx = int(seg[1][11]) - 3
prefix = '.'.join([
'bbox_head', 'cls_bn_list',
str(cls_conv_idx),
str(cls_bn_idx)
])
base_mapping = {
'beta': 'bias',
'gamma': 'weight',
'moving_mean': 'running_mean',
'moving_variance': 'running_var'
}
suffix = base_mapping['/'.join(seg[2:])]
m[prefix + '.' + suffix] = v
else:
cls_conv_idx = int(seg[1][6])
prefix = '.'.join(
['bbox_head', 'cls_conv_list',
str(cls_conv_idx)])
base_mapping = {
'depthwise_kernel': 'depthwise_conv.weight',
'pointwise_kernel': 'pointwise_conv.weight',
'bias': 'pointwise_conv.bias'
}
suffix = base_mapping['/'.join(seg[2:])]
if 'depthwise_conv' in suffix:
v = v.transpose(1, 0)
m[prefix + '.' + suffix] = v
return m


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert EfficientDet weights from TensorFlow to PyTorch')
    parser.add_argument(
        '--backbone',
        type=str,
        help='EfficientNet model name, e.g. efficientnet-b0')
    parser.add_argument(
        '--tensorflow_weight',
        type=str,
        help='EfficientDet TensorFlow checkpoint prefix, '
        'e.g. efficientdet-d0/model')
    parser.add_argument(
        '--out_weight',
        type=str,
        help='output PyTorch weight file name, e.g. demo.pth')
    args = parser.parse_args()
    return args


def main():
args = parse_args()
model_name = args.backbone
ori_weight_name = args.tensorflow_weight
out_name = args.out_weight
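    # BiFPN depth per compound coefficient (d0-d7), following the
    # EfficientDet scaling configuration.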
repeat_map = {
0: 3,
1: 4,
2: 5,
3: 6,
4: 7,
5: 7,
6: 8,
7: 8,
}
reader = py_checkpoint_reader.NewCheckpointReader(ori_weight_name)
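    # Read every checkpoint variable and convert it to the PyTorch layout.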
weights = {
n: torch.as_tensor(tf2pth(reader.get_tensor(n)))
for (n, _) in reader.get_variable_to_shape_map().items()
}
    # 'efficientnet-bX': character 14 is the compound coefficient X.
    bifpn_repeats = repeat_map[int(model_name[14])]
out = convert_key(model_name, bifpn_repeats, weights)
result = {'state_dict': out}
torch.save(result, out_name)


if __name__ == '__main__':
    main()