|
|
|
import argparse |
|
from collections import OrderedDict |
|
|
|
import numpy as np |
|
import torch |
|
from mmengine.fileio import load |
|
|
|
|
|
def _convert_res_key(k, key_name_split):
    """Translate one ``backbone.bottom_up.res*`` key.

    Args:
        k (str): The full source key (used only for the printed notices).
        key_name_split (list[str]): ``k`` split on ``'.'``.

    Returns:
        str | None: The renamed key, or ``None`` (after printing a notice)
        when the key matches none of the shortcut / conv / norm layouts.
    """
    # The trailing digit of e.g. 'res2' is shifted down by one: res2 -> layer1.
    res_id = int(key_name_split[2][-1]) - 1
    prefix = f'backbone.layer{res_id}.{key_name_split[3]}'
    param = key_name_split[-1]

    # Shortcut branch: '<...>.shortcut.<param>' is the conv (downsample.0),
    # '<...>.shortcut.<sub>.<param>' is the layer nested in it (downsample.1).
    if 'shortcut' in key_name_split[4]:
        if key_name_split[-2] == 'shortcut':
            return f'{prefix}.downsample.0.{param}'
        if key_name_split[-3] == 'shortcut':
            return f'{prefix}.downsample.1.{param}'
        print(f'Unvalid key {k}')
        return None

    # Plain convolution: '<...>.convN.<param>' -> 'convN.<param>'.
    if 'conv' in key_name_split[-2]:
        conv_id = int(key_name_split[-2][-1])
        return f'{prefix}.conv{conv_id}.{param}'

    # Norm nested in a convolution: '<...>.convN.norm.<param>' -> 'bnN.<param>'.
    if key_name_split[-2] == 'norm':
        conv_id = int(key_name_split[-3][-1])
        return f'{prefix}.bn{conv_id}.{param}'

    print(f'{k} is invalid')
    return None


def _convert_key(k):
    """Map a single source checkpoint key to its destination name.

    Args:
        k (str): Key found under ``'model'`` in the source checkpoint.

    Returns:
        str | None: The destination key, or ``None`` if the key is not
        converted (a notice is printed in that case).
    """
    key_name_split = k.split('.')
    param = key_name_split[-1]

    if 'backbone.fpn_lateral' in k:
        # The trailing digit of 'fpn_lateralN' becomes the index N - 2.
        lateral_id = int(key_name_split[-2][-1])
        return f'neck.lateral_convs.{lateral_id - 2}.conv.{param}'
    if 'backbone.fpn_output' in k:
        lateral_id = int(key_name_split[-2][-1])
        return f'neck.fpn_convs.{lateral_id - 2}.conv.{param}'
    # The norm test must precede the plain stem-conv test, because the
    # latter substring also matches the norm keys.
    if 'backbone.bottom_up.stem.conv1.norm.' in k:
        return f'backbone.bn1.{param}'
    if 'backbone.bottom_up.stem.conv1.' in k:
        return f'backbone.conv1.{param}'
    if 'backbone.bottom_up.res' in k:
        return _convert_res_key(k, key_name_split)
    if key_name_split[0] == 'head':
        # 'head.xxx' -> 'bbox_head.xxx'
        return f'bbox_{k}'

    print(f'{k} is not converted!!')
    return None


def convert(src, dst):
    """Rename the keys of a source checkpoint and save the result.

    The source file is read with ``torch.load`` when its path ends with
    ``'pth'`` and with ``mmengine.fileio.load`` otherwise. Every entry of
    ``src_model['model']`` whose key can be translated is stored under its
    new name; NumPy arrays are converted to tensors. The result is written
    to ``dst`` as ``dict(state_dict=..., meta={})``.

    Keys that cannot be translated are reported with ``print`` and skipped.

    Args:
        src (str): Path of the source checkpoint.
        dst (str): Path the converted checkpoint is saved to.

    Raises:
        ValueError: If a converted entry is neither a ``numpy.ndarray`` nor
            a ``torch.Tensor``.
    """
    if src.endswith('pth'):
        # Map to CPU so that a checkpoint saved from a GPU can be converted
        # on a machine where that device is not available.
        src_model = torch.load(src, map_location='cpu')
    else:
        src_model = load(src)

    dst_state_dict = OrderedDict()
    for k, v in src_model['model'].items():
        name = _convert_key(k)
        if name is None:
            # Skip unrecognised keys. Falling through here would store the
            # value under the name left over from the previous iteration.
            continue

        if not isinstance(v, (np.ndarray, torch.Tensor)):
            raise ValueError(
                'Unsupported type found in checkpoint! {}: {}'.format(
                    k, type(v)))
        if isinstance(v, torch.Tensor):
            dst_state_dict[name] = v
        else:
            dst_state_dict[name] = torch.from_numpy(v)

    mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
    torch.save(mmdet_model, dst)
|
|
|
|
|
def main():
    """Read the ``src`` and ``dst`` paths from the command line and convert."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    # Both arguments are positional; register them in declaration order.
    positional_args = (
        ('src', 'src detectron model path'),
        ('dst', 'save path'),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)

    parsed = cli.parse_args()
    convert(parsed.src, parsed.dst)
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|