# MMDet / mmdetection / projects / DiffusionDet / model_converters /
# diffusiondet_resnet_to_mmdet.py
# Committed by Saurabh1105 in 6c9ac8f: "MMdet Model for Image Segmentation"
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
import numpy as np
import torch
from mmengine.fileio import load
def _convert_res_key(k, parts):
    """Map a ``backbone.bottom_up.res{N}.{block}...`` key to MMDetection.

    Args:
        k (str): The full detectron2 parameter name (used in messages).
        parts (list[str]): ``k.split('.')``.

    Returns:
        str | None: The MMDetection name, or ``None`` (after printing a
        message) when the key matches none of the known layouts.
    """
    # res{N} -> layer{N - 1} (res2 -> layer1, ...); parts[3] is the block id.
    prefix = f'backbone.layer{int(parts[2][-1]) - 1}.{parts[3]}'
    # deal with short cut
    if 'shortcut' in parts[4]:
        # ...shortcut.<param>        -> ...downsample.0.<param>
        # ...shortcut.<sub>.<param>  -> ...downsample.1.<param>
        if parts[-2] == 'shortcut':
            return f'{prefix}.downsample.0.{parts[-1]}'
        if parts[-3] == 'shortcut':
            return f'{prefix}.downsample.1.{parts[-1]}'
        print(f'Unvalid key {k}')
        return None
    # deal with conv: ...conv{j}.<param> -> ...conv{j}.<param>
    if 'conv' in parts[-2]:
        return f'{prefix}.conv{int(parts[-2][-1])}.{parts[-1]}'
    # deal with BN: ...conv{j}.norm.<param> -> ...bn{j}.<param>
    if parts[-2] == 'norm':
        return f'{prefix}.bn{int(parts[-3][-1])}.{parts[-1]}'
    print(f'{k} is invalid')
    return None


def _convert_key(k):
    """Return the MMDetection parameter name for detectron2 key ``k``.

    Returns:
        str | None: The converted name, or ``None`` (after printing a
        message) when ``k`` is not converted and should be skipped.
    """
    parts = k.split('.')
    if 'backbone.fpn_lateral' in k:
        # backbone.fpn_lateral{i}.<param> -> neck.lateral_convs.{i-2}.conv.*
        lateral_id = int(parts[-2][-1])
        return f'neck.lateral_convs.{lateral_id - 2}.conv.{parts[-1]}'
    if 'backbone.fpn_output' in k:
        # backbone.fpn_output{i}.<param> -> neck.fpn_convs.{i-2}.conv.*
        lateral_id = int(parts[-2][-1])
        return f'neck.fpn_convs.{lateral_id - 2}.conv.{parts[-1]}'
    # The stem-norm test must come first: those keys also contain the
    # stem-conv substring checked next.
    if 'backbone.bottom_up.stem.conv1.norm.' in k:
        return f'backbone.bn1.{parts[-1]}'
    if 'backbone.bottom_up.stem.conv1.' in k:
        return f'backbone.conv1.{parts[-1]}'
    if 'backbone.bottom_up.res' in k:
        return _convert_res_key(k, parts)
    if parts[0] == 'head':
        # d2: head.xxx -> mmdet: bbox_head.xxx
        return f'bbox_{k}'
    # some base parameters such as beta will not convert
    print(f'{k} is not converted!!')
    return None


def convert(src, dst):
    """Convert a DiffusionDet (detectron2) checkpoint to MMDetection format.

    Args:
        src (str): Source checkpoint path. Paths ending in ``pth`` are read
            with ``torch.load``; anything else with mmengine's ``load``.
            The loaded object must have a ``'model'`` entry mapping
            parameter names to numpy arrays or tensors.
        dst (str): Output path. ``{'state_dict': ..., 'meta': {}}`` is
            written there with ``torch.save``.

    Raises:
        ValueError: If a converted entry is neither a numpy array nor a
            ``torch.Tensor``.
    """
    if src.endswith('pth'):
        # Load onto CPU so checkpoints saved on GPU convert on any host.
        src_model = torch.load(src, map_location='cpu')
    else:
        src_model = load(src)

    dst_state_dict = OrderedDict()
    for k, v in src_model['model'].items():
        name = _convert_key(k)
        if name is None:
            # Skip unmapped keys; they must never be stored under a name
            # left over from a previous iteration.
            continue
        if not isinstance(v, (np.ndarray, torch.Tensor)):
            raise ValueError(
                'Unsupported type found in checkpoint! {}: {}'.format(
                    k, type(v)))
        dst_state_dict[name] = (
            torch.from_numpy(v) if isinstance(v, np.ndarray) else v)
    mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
    torch.save(mmdet_model, dst)
def main():
    """Command-line entry point: convert checkpoint SRC and save it to DST."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    positional_args = (
        ('src', 'src detectron model path'),
        ('dst', 'save path'),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)
    parsed = cli.parse_args()
    convert(parsed.src, parsed.dst)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()