{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f2254819-deaf-48ba-848c-471f51ee1221",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4fe03089-ec4d-4bda-9b02-46cb320e516a",
"metadata": {},
"outputs": [],
"source": [
"origin_checkpoint_path = '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/16task_90k_bertbase_lr1e-3_wd0.2_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_changeweight_stage2_224size/bertbase_womoe_pretrain2/89999/mp_rank_00_model_states.pt'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "edc282cd-8345-4321-b0a0-3e21d64bfa35",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['module', 'buffer_names', 'optimizer', 'lr_scheduler', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', 'ds_config', 'ds_version', 'iteration'])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"origin_checkpoint = torch.load(origin_checkpoint_path, 'cpu')\n",
"origin_checkpoint.keys()\n",
"# list(origin_checkpoint['module'].keys())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "79d9f479-3144-4791-82ba-71fec264aa29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"201"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(list(origin_checkpoint['module'].keys()))"
]
},
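{
"cell_type": "markdown",
"id": "inspect-module-keys-note",
"metadata": {},
"source": [
"Optional: a minimal inspection sketch. It assumes `origin_checkpoint['module']` maps parameter names to tensors (as the cells above suggest) and just prints the first few names and shapes before remapping."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "inspect-module-keys",
"metadata": {},
"outputs": [],
"source": [
"# a sketch: peek at a few parameter names and shapes in the original state dict\n",
"from itertools import islice\n",
"\n",
"for name, tensor in islice(origin_checkpoint['module'].items(), 5):\n",
"    print(name, tuple(tensor.shape))"
]
},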
{
"cell_type": "code",
"execution_count": 8,
"id": "3452947d-4593-4431-a772-3a8ad4882c03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['model', 'trainer', 'amp_scaler', 'scheduler', 'iteration'])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# new_checkpoint_path = 'new_exp/model_Epoch_00160_Iter_0000159.pth'\n",
"# new_checkpoint = torch.load(new_checkpoint_path, 'cpu')\n",
"# new_checkpoint.keys()\n",
"# list(new_checkpoint['model'].keys())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ffdcf5c5-ffd4-4379-89d7-37ce05c4c0f2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"41"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# len(list(new_checkpoint['model'].keys()))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fec7a303-a30c-4e92-9452-b534a52d67e9",
"metadata": {},
"outputs": [],
"source": [
"mapping_dict = {\n",
"\n",
" 'encoder.': 'fused_encoder.',\n",
" 'attention.self.qkv_proj.weight': 'self_attn.in_proj_weight',\n",
" 'attention.self.qkv_proj.bias': 'self_attn.in_proj_bias',\n",
" 'attention.output.dense': 'self_attn.out_proj',\n",
" 'attention_output.residual_scale': 'gamma_1',\n",
" 'ffn.dense.': 'linear1.',\n",
" 'ffn.dense2.': 'linear2.',\n",
" 'ffn_output.residual_scale': 'gamma_2',\n",
" 'LayerNormModules.0.': 'norm1.',\n",
" 'LayerNormModules.1.': 'norm2.',\n",
" 'predictor.': 'loss_prepare.',\n",
" \n",
"}\n"
]
},
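{
"cell_type": "markdown",
"id": "mapping-demo-note",
"metadata": {},
"source": [
"A quick illustration (not part of the conversion): the next cell applies the same substring replacement used by the conversion loop below to a single made-up key name, just to show what the mapping does. The example key is hypothetical; real names come from `origin_checkpoint['module']`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mapping-demo",
"metadata": {},
"outputs": [],
"source": [
"# illustration only: the example key below is hypothetical\n",
"example_key = 'encoder.layer.0.attention.self.qkv_proj.weight'\n",
"for origin_str, target_str in mapping_dict.items():\n",
"    example_key = example_key.replace(origin_str, target_str)\n",
"print(example_key)  # fused_encoder.layer.0.self_attn.in_proj_weight"
]
},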
{
"cell_type": "code",
"execution_count": 6,
"id": "897ff2f0-1232-4d25-9c13-7ea9568da362",
"metadata": {},
"outputs": [],
"source": [
"new_checkpoint = { } \n",
"\n",
"module_checkpoint = origin_checkpoint['module']\n",
"\n",
"for k, v in module_checkpoint.items():\n",
" if k.endswith('residual_scale'):\n",
" v.squeeze_(1).squeeze_(0)\n",
" if k.startswith('visual_embed'):\n",
" continue\n",
" for origin_str, target_str in mapping_dict.items():\n",
" if origin_str in k:\n",
" k = k.replace(origin_str, target_str)\n",
" \n",
" new_checkpoint[k] = v.float()\n",
"\n",
"# merge type embedding in video_embed \n",
"new_checkpoint['video_embed.embeddings.bias'] = new_checkpoint['video_embed.embeddings.bias'] + new_checkpoint['video_embed.embeddings_type.weight'][0]"
]
},
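{
"cell_type": "markdown",
"id": "conversion-check-note",
"metadata": {},
"source": [
"Optional consistency check (a sketch): report how many parameters were converted and how many were renamed under the fused encoder, and confirm that the skipped `visual_embed.*` entries are gone and the merged bias exists."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "conversion-check",
"metadata": {},
"outputs": [],
"source": [
"# a sketch: basic consistency checks on the converted state dict\n",
"print('converted parameters:', len(new_checkpoint))\n",
"print('parameters renamed into fused_encoder:',\n",
"      sum('fused_encoder.' in k for k in new_checkpoint))\n",
"# visual_embed.* entries were skipped above, so none should remain\n",
"assert not any(k.startswith('visual_embed') for k in new_checkpoint)\n",
"assert 'video_embed.embeddings.bias' in new_checkpoint"
]
},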
{
"cell_type": "code",
"execution_count": 7,
"id": "3c26719f-7451-4c0a-85c3-640c820dfe98",
"metadata": {},
"outputs": [],
"source": [
"\n",
"torch.save({ 'model': new_checkpoint}, '/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth')"
]
}
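,
{
"cell_type": "markdown",
"id": "reload-check-note",
"metadata": {},
"source": [
"Optional: a minimal sketch that reloads the file just written and confirms it has the expected structure (a single 'model' entry holding the converted state dict)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reload-check",
"metadata": {},
"outputs": [],
"source": [
"# a sketch: reload the saved checkpoint and check its structure\n",
"saved = torch.load('/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth', 'cpu')\n",
"print(saved.keys())\n",
"print(len(saved['model']))"
]
}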
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}