|  |  | 
					
						
						|  | import argparse | 
					
						
						|  | import tempfile | 
					
						
						|  | from collections import OrderedDict | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from mmengine import Config | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def parse_config(config_strings): | 
					
						
						|  | temp_file = tempfile.NamedTemporaryFile() | 
					
						
						|  | config_path = f'{temp_file.name}.py' | 
					
						
						|  | with open(config_path, 'w') as f: | 
					
						
						|  | f.write(config_strings) | 
					
						
						|  |  | 
					
						
						|  | config = Config.fromfile(config_path) | 
					
						
						|  |  | 
					
						
						|  | if config.model.bbox_head.type != 'SSDHead': | 
					
						
						|  | raise AssertionError('This is not a SSD model.') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def convert(in_file, out_file): | 
					
						
						|  | checkpoint = torch.load(in_file) | 
					
						
						|  | in_state_dict = checkpoint.pop('state_dict') | 
					
						
						|  | out_state_dict = OrderedDict() | 
					
						
						|  | meta_info = checkpoint['meta'] | 
					
						
						|  | parse_config('#' + meta_info['config']) | 
					
						
						|  | for key, value in in_state_dict.items(): | 
					
						
						|  | if 'extra' in key: | 
					
						
						|  | layer_idx = int(key.split('.')[2]) | 
					
						
						|  | new_key = 'neck.extra_layers.{}.{}.conv.'.format( | 
					
						
						|  | layer_idx // 2, layer_idx % 2) + key.split('.')[-1] | 
					
						
						|  | elif 'l2_norm' in key: | 
					
						
						|  | new_key = 'neck.l2_norm.weight' | 
					
						
						|  | elif 'bbox_head' in key: | 
					
						
						|  | new_key = key[:21] + '.0' + key[21:] | 
					
						
						|  | else: | 
					
						
						|  | new_key = key | 
					
						
						|  | out_state_dict[new_key] = value | 
					
						
						|  | checkpoint['state_dict'] = out_state_dict | 
					
						
						|  |  | 
					
						
						|  | if torch.__version__ >= '1.6': | 
					
						
						|  | torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) | 
					
						
						|  | else: | 
					
						
						|  | torch.save(checkpoint, out_file) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(): | 
					
						
						|  | parser = argparse.ArgumentParser(description='Upgrade SSD version') | 
					
						
						|  | parser.add_argument('in_file', help='input checkpoint file') | 
					
						
						|  | parser.add_argument('out_file', help='output checkpoint file') | 
					
						
						|  |  | 
					
						
						|  | args = parser.parse_args() | 
					
						
						|  | convert(args.in_file, args.out_file) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
					
						
						|  |  |